Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions MaxText/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def get_train_input_output_trees(func, input_args, input_kwargs):
p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree)
return p_train_step

def calculate_tokens_training_per_device(config):
"""Calculate training Tokens per device"""
return config.max_target_length * config.per_device_batch_size

def calculate_tflops_training_per_device(config, log=True):
"""Calculate training TFLOP"""
Expand Down
8 changes: 6 additions & 2 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ def load_next_batch(train_iter, example_batch, config):
return next(train_iter)


def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr, per_device_tokens):
"""Records scalar metrics to be written to tensorboard"""
metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()})
metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops})
metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
metrics["scalar"].update({"perf/per_device_tokens": per_device_tokens})
metrics["scalar"].update({"perf/per_device_tokens_per_sec": per_device_tokens / step_time_delta.total_seconds()})
metrics["scalar"].update({"learning/current_learning_rate": lr})


Expand Down Expand Up @@ -147,6 +149,7 @@ def write_metrics_to_tensorboard(writer, metrics, step, config):
max_logging.log(
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"Tokens/s/device: {metrics['scalar']['perf/per_device_tokens_per_sec']:.3f}, "
f"loss: {metrics['scalar']['learning/loss']:.3f}"
)

Expand Down Expand Up @@ -483,6 +486,7 @@ def train_loop(config, state=None):
num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
per_device_tokens = maxtext_utils.calculate_tokens_training_per_device(config)

# Write train config params, num model params, and XLA flags to tensorboard
max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
Expand Down Expand Up @@ -542,7 +546,7 @@ def train_loop(config, state=None):
state, metrics = p_train_step(state, example_batch, nextrng)

new_time = datetime.datetime.now()
record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step), per_device_tokens)
last_step_completion = new_time

if checkpoint_manager is not None:
Expand Down