From 189aff8595cbd2e8f2140c8b1911585bb7bfe90f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 7 Jun 2026 09:55:29 +0800 Subject: [PATCH 1/2] fix(pt_expt): report average step time in the normal training log The pt_expt trainer emitted a separate, always-on `step=N step_time=...s` debug line (leftover from #5397) in addition to the normal progress message, where step_time was a single instantaneous sample of the last step. Replace it by folding an average-per-step time into the standard progress message: `format_training_message` gains an optional `step_time` argument (rendered as `step time = X.XXXX s`, default-off so pt/pd/jax are unaffected), and the trainer now reports interval_wall_time / steps_since_last_log. The per-step time.time() bracketing (t_start/t_end) is dropped from the hot loop. Tests covering the with/without step_time and eta branches are added to source/tests/pt_expt/test_training.py. --- deepmd/loggers/training.py | 6 +++ deepmd/pt_expt/train/training.py | 16 ++++---- source/tests/pt_expt/test_training.py | 59 +++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 8 deletions(-) diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index 2ea1eca16f..f09199abd1 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -38,6 +38,7 @@ def format_training_message( wall_time: float, eta: int | None = None, current_time: datetime.datetime | None = None, + step_time: float | None = None, ) -> str: """Format the summary message for one training interval. @@ -52,6 +53,9 @@ def format_training_message( current_time : datetime.datetime | None, optional Current local time used to estimate the finish timestamp. This is only used when ``eta`` is provided. + step_time : float | None, optional + Average wall-clock time per training step over this interval, in + seconds. Shown only when provided. Returns ------- @@ -59,6 +63,8 @@ def format_training_message( The formatted training message. """ msg = f"Batch {batch:7d}: total wall time = {wall_time:.2f} s" + if step_time is not None: + msg += f", step time = {step_time:.4f} s" if isinstance(eta, int): eta_seconds = int(eta) msg += ( diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index ea3b27bf72..70a880094b 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -1452,6 +1452,7 @@ def run(self) -> None: self.wrapper.train() wall_start = time.time() last_log_time = wall_start + last_log_step = self.start_step for step_id in range(self.start_step, self.num_steps): cur_lr = float(self.lr_schedule.value(step_id)) @@ -1465,9 +1466,6 @@ def run(self) -> None: ) task_key = self.model_keys[model_index] - if self.timing_in_training: - t_start = time.time() - # --- forward / backward --- self.optimizer.zero_grad(set_to_none=True) input_dict, label_dict = self.get_data(is_train=True, task_key=task_key) @@ -1488,9 +1486,6 @@ def run(self) -> None: self._optimizer_step() - if self.timing_in_training: - t_end = time.time() - # --- display --- display_step_id = step_id + 1 if self.display_in_training and ( @@ -1598,9 +1593,14 @@ def _to_float(v: Any) -> float: current_time = time.time() wall_elapsed = current_time - wall_start interval_wall_time = current_time - last_log_time + # average wall time per step over the interval since the + # last log (number of steps counted exactly once across + # intervals via last_log_step) + interval_steps = max(1, display_step_id - last_log_step) + step_time = interval_wall_time / interval_steps last_log_time = current_time + last_log_step = display_step_id if self.timing_in_training: - step_time = t_end - t_start steps_completed_since_restart = max( 1, display_step_id - self.start_step, @@ -1619,9 +1619,9 @@ def _to_float(v: Any) -> float: current_time, tz=datetime.timezone.utc, ).astimezone(), + step_time=step_time, ) ) - log.info("step=%d step_time=%.4fs", display_step_id, step_time) else: log.info( format_training_message( diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index a349d2e63c..0079480e3f 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -8,6 +8,7 @@ 4. Loss decreases over those steps """ +import datetime import os import shutil import tempfile @@ -18,6 +19,9 @@ import torch +from deepmd.loggers.training import ( + format_training_message, +) from deepmd.pt_expt.entrypoints.main import ( get_trainer, ) @@ -1684,5 +1688,60 @@ def test_compiled_matches_eager_per_task(self) -> None: shutil.rmtree(tmpdir, ignore_errors=True) +class TestFormatTrainingMessageStepTime(unittest.TestCase): + """The pt_expt trainer reports the average wall time per step over each + display interval by passing ``step_time`` to ``format_training_message`` + (replacing the former standalone ``step=... step_time=...`` debug line). + These tests cover both branches of the optional ``step_time``/``eta`` + arguments so the "step time" segment is rendered only when requested. + """ + + def test_without_step_time(self) -> None: + """``step_time=None`` (default) omits the step-time segment.""" + msg = format_training_message(batch=100, wall_time=18.41) + self.assertEqual(msg, "Batch 100: total wall time = 18.41 s") + self.assertNotIn("step time", msg) + + def test_with_step_time(self) -> None: + """``step_time`` is rendered with 4 decimals after the wall time.""" + msg = format_training_message(batch=100, wall_time=18.41, step_time=0.1841) + self.assertEqual( + msg, + "Batch 100: total wall time = 18.41 s, step time = 0.1841 s", + ) + + def test_step_time_zero_is_shown(self) -> None: + """A literal ``0.0`` step time is still shown (not treated as absent).""" + msg = format_training_message(batch=1, wall_time=0.5, step_time=0.0) + self.assertIn("step time = 0.0000 s", msg) + + def test_with_step_time_and_eta(self) -> None: + """Step time appears before the eta segment.""" + current_time = datetime.datetime( + 2026, 6, 7, 5, 21, 29, tzinfo=datetime.timezone.utc + ) + msg = format_training_message( + batch=100, + wall_time=18.41, + eta=100, + current_time=current_time, + step_time=0.1841, + ) + self.assertIn("total wall time = 18.41 s, step time = 0.1841 s, eta = ", msg) + # ordering: wall time -> step time -> eta + self.assertLess(msg.index("step time"), msg.index("eta =")) + + def test_eta_without_step_time(self) -> None: + """Eta still works when no step time is supplied.""" + current_time = datetime.datetime( + 2026, 6, 7, 5, 21, 29, tzinfo=datetime.timezone.utc + ) + msg = format_training_message( + batch=100, wall_time=18.41, eta=100, current_time=current_time + ) + self.assertNotIn("step time", msg) + self.assertIn("eta = ", msg) + + if __name__ == "__main__": unittest.main() From e2639d0a34b537990fda8aa8771bc90d083c109b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 8 Jun 2026 10:08:11 +0800 Subject: [PATCH 2/2] fix(pt_expt): render step time as 'avg = X.XXXX s/step' Use an explicit per-step rate unit (avg = 0.1841 s/step) instead of the redundant 'step time = 0.1841 s', keeping the interval 'total wall time' unchanged. --- deepmd/loggers/training.py | 2 +- source/tests/pt_expt/test_training.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index f09199abd1..d145f42897 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -64,7 +64,7 @@ def format_training_message( """ msg = f"Batch {batch:7d}: total wall time = {wall_time:.2f} s" if step_time is not None: - msg += f", step time = {step_time:.4f} s" + msg += f", avg = {step_time:.4f} s/step" if isinstance(eta, int): eta_seconds = int(eta) msg += ( diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 0079480e3f..cbb8368074 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -1693,27 +1693,27 @@ class TestFormatTrainingMessageStepTime(unittest.TestCase): display interval by passing ``step_time`` to ``format_training_message`` (replacing the former standalone ``step=... step_time=...`` debug line). These tests cover both branches of the optional ``step_time``/``eta`` - arguments so the "step time" segment is rendered only when requested. + arguments so the "avg = ... s/step" segment is rendered only when requested. """ def test_without_step_time(self) -> None: """``step_time=None`` (default) omits the step-time segment.""" msg = format_training_message(batch=100, wall_time=18.41) self.assertEqual(msg, "Batch 100: total wall time = 18.41 s") - self.assertNotIn("step time", msg) + self.assertNotIn("s/step", msg) def test_with_step_time(self) -> None: """``step_time`` is rendered with 4 decimals after the wall time.""" msg = format_training_message(batch=100, wall_time=18.41, step_time=0.1841) self.assertEqual( msg, - "Batch 100: total wall time = 18.41 s, step time = 0.1841 s", + "Batch 100: total wall time = 18.41 s, avg = 0.1841 s/step", ) def test_step_time_zero_is_shown(self) -> None: """A literal ``0.0`` step time is still shown (not treated as absent).""" msg = format_training_message(batch=1, wall_time=0.5, step_time=0.0) - self.assertIn("step time = 0.0000 s", msg) + self.assertIn("avg = 0.0000 s/step", msg) def test_with_step_time_and_eta(self) -> None: """Step time appears before the eta segment.""" @@ -1727,9 +1727,9 @@ def test_with_step_time_and_eta(self) -> None: current_time=current_time, step_time=0.1841, ) - self.assertIn("total wall time = 18.41 s, step time = 0.1841 s, eta = ", msg) + self.assertIn("total wall time = 18.41 s, avg = 0.1841 s/step, eta = ", msg) # ordering: wall time -> step time -> eta - self.assertLess(msg.index("step time"), msg.index("eta =")) + self.assertLess(msg.index("s/step"), msg.index("eta =")) def test_eta_without_step_time(self) -> None: """Eta still works when no step time is supplied.""" @@ -1739,7 +1739,7 @@ def test_eta_without_step_time(self) -> None: msg = format_training_message( batch=100, wall_time=18.41, eta=100, current_time=current_time ) - self.assertNotIn("step time", msg) + self.assertNotIn("s/step", msg) self.assertIn("eta = ", msg)