Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,12 +935,19 @@
eta=eta,
)
)
# the first training time is not accurate
if (
(_step_id + 1 - self.start_step) > self.disp_freq
or self.num_steps - self.start_step < 2 * self.disp_freq
(self.num_steps - self.start_step)
<= 2 * self.disp_freq # not enough steps
Comment thread
njzjz marked this conversation as resolved.
or (_step_id - self.start_step)
>= self.disp_freq # skip first disp_freq steps
):
self.total_train_time += train_time
if display_step_id == 1:
self.timed_steps += 1
else:
self.timed_steps += min(

Check warning on line 948 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L948

Added line #L948 was not covered by tests
self.disp_freq, _step_id - self.start_step
)

if fout:
if self.lcurve_should_print_header:
Expand All @@ -951,11 +958,14 @@
)

if (
((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
or (_step_id + 1) == self.num_steps
(
(display_step_id) % self.save_freq == 0
and _step_id != self.start_step
)
or (display_step_id) == self.num_steps
) and (self.rank == 0 or dist.get_rank() == 0):
Comment thread
caic99 marked this conversation as resolved.
# Handle the case if rank 0 aborted and re-assigned
self.latest_model = Path(self.save_ckpt + f"-{_step_id + 1}.pt")
self.latest_model = Path(self.save_ckpt + f"-{display_step_id}.pt")

module = (
self.wrapper.module
Expand All @@ -982,6 +992,7 @@
self.wrapper.train()
self.t0 = time.time()
self.total_train_time = 0.0
self.timed_steps = 0
for step_id in range(self.start_step, self.num_steps):
step(step_id)
if JIT:
Expand Down Expand Up @@ -1021,24 +1032,12 @@
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))

elapsed_batch = self.num_steps - self.start_step
if self.timing_in_training and elapsed_batch // self.disp_freq > 0:
if self.start_step >= 2 * self.disp_freq:
log.info(
"average training time: %.4f s/batch (exclude first %d batches)",
self.total_train_time
/ (
elapsed_batch // self.disp_freq * self.disp_freq
- self.disp_freq
),
self.disp_freq,
)
else:
log.info(
"average training time: %.4f s/batch",
self.total_train_time
/ (elapsed_batch // self.disp_freq * self.disp_freq),
)
if self.timing_in_training and self.timed_steps:
msg = f"average training time: {self.total_train_time / self.timed_steps:.4f} s/batch"
excluded_steps = self.num_steps - self.start_step - self.timed_steps
if excluded_steps > 0:
msg += f" ({excluded_steps} batches excluded)"
log.info(msg)

if JIT:
pth_model_path = (
Expand Down