Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/loggers/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def format_training_message(
wall_time: float,
eta: int | None = None,
current_time: datetime.datetime | None = None,
step_time: float | None = None,
) -> str:
"""Format the summary message for one training interval.

Expand All @@ -52,13 +53,18 @@ def format_training_message(
current_time : datetime.datetime | None, optional
Current local time used to estimate the finish timestamp. This is only
used when ``eta`` is provided.
step_time : float | None, optional
Average wall-clock time per training step over this interval, in
seconds. Shown only when provided.

Returns
-------
str
The formatted training message.
"""
msg = f"Batch {batch:7d}: total wall time = {wall_time:.2f} s"
if step_time is not None:
msg += f", avg = {step_time:.4f} s/step"
if isinstance(eta, int):
eta_seconds = int(eta)
msg += (
Expand Down
16 changes: 8 additions & 8 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1452,6 +1452,7 @@ def run(self) -> None:
self.wrapper.train()
wall_start = time.time()
last_log_time = wall_start
last_log_step = self.start_step

for step_id in range(self.start_step, self.num_steps):
cur_lr = float(self.lr_schedule.value(step_id))
Expand All @@ -1465,9 +1466,6 @@ def run(self) -> None:
)
task_key = self.model_keys[model_index]

if self.timing_in_training:
t_start = time.time()

# --- forward / backward ---
self.optimizer.zero_grad(set_to_none=True)
input_dict, label_dict = self.get_data(is_train=True, task_key=task_key)
Expand All @@ -1488,9 +1486,6 @@ def run(self) -> None:

self._optimizer_step()

if self.timing_in_training:
t_end = time.time()

# --- display ---
display_step_id = step_id + 1
if self.display_in_training and (
Expand Down Expand Up @@ -1598,9 +1593,14 @@ def _to_float(v: Any) -> float:
current_time = time.time()
wall_elapsed = current_time - wall_start
interval_wall_time = current_time - last_log_time
# average wall time per step over the interval since the
# last log (number of steps counted exactly once across
# intervals via last_log_step)
interval_steps = max(1, display_step_id - last_log_step)
step_time = interval_wall_time / interval_steps
last_log_time = current_time
last_log_step = display_step_id
if self.timing_in_training:
step_time = t_end - t_start
steps_completed_since_restart = max(
1,
display_step_id - self.start_step,
Expand All @@ -1619,9 +1619,9 @@ def _to_float(v: Any) -> float:
current_time,
tz=datetime.timezone.utc,
).astimezone(),
step_time=step_time,
)
)
log.info("step=%d step_time=%.4fs", display_step_id, step_time)
else:
log.info(
format_training_message(
Expand Down
59 changes: 59 additions & 0 deletions source/tests/pt_expt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
4. Loss decreases over those steps
"""

import datetime
import os
import shutil
import tempfile
Expand All @@ -18,6 +19,9 @@

import torch

from deepmd.loggers.training import (
format_training_message,
)
from deepmd.pt_expt.entrypoints.main import (
get_trainer,
)
Expand Down Expand Up @@ -1684,5 +1688,60 @@ def test_compiled_matches_eager_per_task(self) -> None:
shutil.rmtree(tmpdir, ignore_errors=True)


class TestFormatTrainingMessageStepTime(unittest.TestCase):
"""The pt_expt trainer reports the average wall time per step over each
display interval by passing ``step_time`` to ``format_training_message``
(replacing the former standalone ``step=... step_time=...`` debug line).
These tests cover both branches of the optional ``step_time``/``eta``
arguments so the "avg = ... s/step" segment is rendered only when requested.
"""

def test_without_step_time(self) -> None:
"""``step_time=None`` (default) omits the step-time segment."""
msg = format_training_message(batch=100, wall_time=18.41)
self.assertEqual(msg, "Batch 100: total wall time = 18.41 s")
self.assertNotIn("s/step", msg)

def test_with_step_time(self) -> None:
"""``step_time`` is rendered with 4 decimals after the wall time."""
msg = format_training_message(batch=100, wall_time=18.41, step_time=0.1841)
self.assertEqual(
msg,
"Batch 100: total wall time = 18.41 s, avg = 0.1841 s/step",
)

def test_step_time_zero_is_shown(self) -> None:
"""A literal ``0.0`` step time is still shown (not treated as absent)."""
msg = format_training_message(batch=1, wall_time=0.5, step_time=0.0)
self.assertIn("avg = 0.0000 s/step", msg)

def test_with_step_time_and_eta(self) -> None:
"""Step time appears before the eta segment."""
current_time = datetime.datetime(
2026, 6, 7, 5, 21, 29, tzinfo=datetime.timezone.utc
)
msg = format_training_message(
batch=100,
wall_time=18.41,
eta=100,
current_time=current_time,
step_time=0.1841,
)
self.assertIn("total wall time = 18.41 s, avg = 0.1841 s/step, eta = ", msg)
# ordering: wall time -> step time -> eta
self.assertLess(msg.index("s/step"), msg.index("eta ="))

def test_eta_without_step_time(self) -> None:
"""Eta still works when no step time is supplied."""
current_time = datetime.datetime(
2026, 6, 7, 5, 21, 29, tzinfo=datetime.timezone.utc
)
msg = format_training_message(
batch=100, wall_time=18.41, eta=100, current_time=current_time
)
self.assertNotIn("s/step", msg)
self.assertIn("eta = ", msg)


if __name__ == "__main__":
unittest.main()
Loading