From dfd08cff142c76c27841d44fd89d29a0fb2ccfe3 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 7 May 2025 16:57:42 +0800 Subject: [PATCH] add eta for pt --- deepmd/loggers/training.py | 7 ++++++- deepmd/pd/train/training.py | 20 ++------------------ deepmd/pt/train/training.py | 4 ++++ 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/deepmd/loggers/training.py b/deepmd/loggers/training.py index 16d3eb1618..5de7926460 100644 --- a/deepmd/loggers/training.py +++ b/deepmd/loggers/training.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import datetime from typing import ( Optional, ) @@ -7,9 +8,13 @@ def format_training_message( batch: int, wall_time: float, + eta: Optional[int] = None, ) -> str: """Format a training message.""" - return f"batch {batch:7d}: total wall time = {wall_time:.2f} s" + msg = f"batch {batch:7d}: total wall time = {wall_time:.2f} s" + if isinstance(eta, int): + msg += f", eta = {datetime.timedelta(seconds=int(eta))!s}" + return msg def format_training_message_per_task( diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 8c1e55490f..c914ee46a8 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import datetime import functools import logging import time @@ -30,6 +29,7 @@ symlink_prefix_files, ) from deepmd.loggers.training import ( + format_training_message, format_training_message_per_task, ) from deepmd.pd.loss import ( @@ -78,22 +78,6 @@ log = logging.getLogger(__name__) -from typing import ( - Optional, -) - - -def format_training_message( - batch: int, - wall_time: float, - eta: Optional[int] = None, -): - """Format a training message.""" - msg = f"batch {batch:7d}: total wall time = {wall_time:.2f} s" - if isinstance(eta, int): - msg += f", eta = {datetime.timedelta(seconds=int(eta))!s}" - return msg - class Trainer: def __init__( @@ -863,7 +847,7 @@ def log_loss_valid(_task_key="Default"): self.t0 = current_time if self.rank == 0 and self.timing_in_training: eta = int( - (self.num_steps - _step_id - 1) / self.disp_freq * train_time + (self.num_steps - display_step_id) / self.disp_freq * train_time ) log.info( format_training_message( diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 63e0180ace..114cae18bf 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -899,10 +899,14 @@ def log_loss_valid(_task_key="Default"): train_time = current_time - self.t0 self.t0 = current_time if self.rank == 0 and self.timing_in_training: + eta = int( + (self.num_steps - display_step_id) / self.disp_freq * train_time + ) log.info( format_training_message( batch=display_step_id, wall_time=train_time, + eta=eta, ) ) # the first training time is not accurate