diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index 2ea1eca16f..d145f42897 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -38,6 +38,7 @@ def format_training_message( wall_time: float, eta: int | None = None, current_time: datetime.datetime | None = None, + step_time: float | None = None, ) -> str: """Format the summary message for one training interval. @@ -52,6 +53,9 @@ def format_training_message( current_time : datetime.datetime | None, optional Current local time used to estimate the finish timestamp. This is only used when ``eta`` is provided. + step_time : float | None, optional + Average wall-clock time per training step over this interval, in + seconds. Shown only when provided. Returns ------- @@ -59,6 +63,8 @@ def format_training_message( The formatted training message. """ msg = f"Batch {batch:7d}: total wall time = {wall_time:.2f} s" + if step_time is not None: + msg += f", avg = {step_time:.4f} s/step" if isinstance(eta, int): eta_seconds = int(eta) msg += ( diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index ea3b27bf72..70a880094b 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -1452,6 +1452,7 @@ def run(self) -> None: self.wrapper.train() wall_start = time.time() last_log_time = wall_start + last_log_step = self.start_step for step_id in range(self.start_step, self.num_steps): cur_lr = float(self.lr_schedule.value(step_id)) @@ -1465,9 +1466,6 @@ def run(self) -> None: ) task_key = self.model_keys[model_index] - if self.timing_in_training: - t_start = time.time() - # --- forward / backward --- self.optimizer.zero_grad(set_to_none=True) input_dict, label_dict = self.get_data(is_train=True, task_key=task_key) @@ -1488,9 +1486,6 @@ def run(self) -> None: self._optimizer_step() - if self.timing_in_training: - t_end = time.time() - # --- display --- display_step_id = step_id + 1 if self.display_in_training and ( @@ -1598,9 +1593,14 
@@ def _to_float(v: Any) -> float: current_time = time.time() wall_elapsed = current_time - wall_start interval_wall_time = current_time - last_log_time + # average wall time per step over the interval since the + # last log (number of steps counted exactly once across + # intervals via last_log_step) + interval_steps = max(1, display_step_id - last_log_step) + step_time = interval_wall_time / interval_steps last_log_time = current_time + last_log_step = display_step_id if self.timing_in_training: - step_time = t_end - t_start steps_completed_since_restart = max( 1, display_step_id - self.start_step, @@ -1619,9 +1619,9 @@ def _to_float(v: Any) -> float: current_time, tz=datetime.timezone.utc, ).astimezone(), + step_time=step_time, ) ) - log.info("step=%d step_time=%.4fs", display_step_id, step_time) else: log.info( format_training_message( diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index a349d2e63c..cbb8368074 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -8,6 +8,7 @@ 4. Loss decreases over those steps """ +import datetime import os import shutil import tempfile @@ -18,6 +19,9 @@ import torch +from deepmd.loggers.training import ( + format_training_message, +) from deepmd.pt_expt.entrypoints.main import ( get_trainer, ) @@ -1684,5 +1688,60 @@ def test_compiled_matches_eager_per_task(self) -> None: shutil.rmtree(tmpdir, ignore_errors=True) +class TestFormatTrainingMessageStepTime(unittest.TestCase): + """The pt_expt trainer reports the average wall time per step over each + display interval by passing ``step_time`` to ``format_training_message`` + (replacing the former standalone ``step=... step_time=...`` debug line). + These tests cover both branches of the optional ``step_time``/``eta`` + arguments so the "avg = ... s/step" segment is rendered only when requested. 
+    """
+
+    def test_without_step_time(self) -> None:
+        """``step_time=None`` (default) omits the step-time segment."""
+        msg = format_training_message(batch=100, wall_time=18.41)
+        self.assertEqual(msg, "Batch     100: total wall time = 18.41 s")
+        self.assertNotIn("s/step", msg)
+
+    def test_with_step_time(self) -> None:
+        """``step_time`` is rendered with 4 decimals after the wall time."""
+        msg = format_training_message(batch=100, wall_time=18.41, step_time=0.1841)
+        self.assertEqual(
+            msg,
+            "Batch     100: total wall time = 18.41 s, avg = 0.1841 s/step",
+        )
+
+    def test_step_time_zero_is_shown(self) -> None:
+        """A literal ``0.0`` step time is still shown (not treated as absent)."""
+        msg = format_training_message(batch=1, wall_time=0.5, step_time=0.0)
+        self.assertIn("avg = 0.0000 s/step", msg)
+
+    def test_with_step_time_and_eta(self) -> None:
+        """Step time appears before the eta segment."""
+        current_time = datetime.datetime(
+            2026, 6, 7, 5, 21, 29, tzinfo=datetime.timezone.utc
+        )
+        msg = format_training_message(
+            batch=100,
+            wall_time=18.41,
+            eta=100,
+            current_time=current_time,
+            step_time=0.1841,
+        )
+        self.assertIn("total wall time = 18.41 s, avg = 0.1841 s/step, eta = ", msg)
+        # ordering: wall time -> step time -> eta
+        self.assertLess(msg.index("s/step"), msg.index("eta ="))
+
+    def test_eta_without_step_time(self) -> None:
+        """Eta still works when no step time is supplied."""
+        current_time = datetime.datetime(
+            2026, 6, 7, 5, 21, 29, tzinfo=datetime.timezone.utc
+        )
+        msg = format_training_message(
+            batch=100, wall_time=18.41, eta=100, current_time=current_time
+        )
+        self.assertNotIn("s/step", msg)
+        self.assertIn("eta = ", msg)
+
+
 if __name__ == "__main__":
     unittest.main()