diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 7a6ff0ebde..193dcd8cb9 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -935,12 +935,19 @@ def log_loss_valid(_task_key="Default"): eta=eta, ) ) - # the first training time is not accurate if ( - (_step_id + 1 - self.start_step) > self.disp_freq - or self.num_steps - self.start_step < 2 * self.disp_freq + (self.num_steps - self.start_step) + <= 2 * self.disp_freq # not enough steps + or (_step_id - self.start_step) + >= self.disp_freq # skip first disp_freq steps ): self.total_train_time += train_time + if display_step_id == 1: + self.timed_steps += 1 + else: + self.timed_steps += min( + self.disp_freq, _step_id - self.start_step + ) if fout: if self.lcurve_should_print_header: @@ -951,11 +958,14 @@ def log_loss_valid(_task_key="Default"): ) if ( - ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step) - or (_step_id + 1) == self.num_steps + ( + (display_step_id) % self.save_freq == 0 + and _step_id != self.start_step + ) + or (display_step_id) == self.num_steps ) and (self.rank == 0 or dist.get_rank() == 0): # Handle the case if rank 0 aborted and re-assigned - self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt") + self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt") module = ( self.wrapper.module @@ -982,6 +992,7 @@ def log_loss_valid(_task_key="Default"): self.wrapper.train() self.t0 = time.time() self.total_train_time = 0.0 + self.timed_steps = 0 for step_id in range(self.start_step, self.num_steps): step(step_id) if JIT: @@ -1021,24 +1032,12 @@ def log_loss_valid(_task_key="Default"): with open("checkpoint", "w") as f: f.write(str(self.latest_model)) - elapsed_batch = self.num_steps - self.start_step - if self.timing_in_training and elapsed_batch // self.disp_freq > 0: - if self.start_step >= 2 * self.disp_freq: - log.info( - "average training time: %.4f s/batch (exclude first %d batches)", - self.total_train_time - / ( - elapsed_batch // self.disp_freq * self.disp_freq - - self.disp_freq - ), - self.disp_freq, - ) - else: - log.info( - "average training time: %.4f s/batch", - self.total_train_time - / (elapsed_batch // self.disp_freq * self.disp_freq), - ) + if self.timing_in_training and self.timed_steps: + msg = f"average training time: {self.total_train_time / self.timed_steps:.4f} s/batch" + excluded_steps = self.num_steps - self.start_step - self.timed_steps + if excluded_steps > 0: + msg += f" ({excluded_steps} batches excluded)" + log.info(msg) if JIT: pth_model_path = (