Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,6 @@ def step(_step_id, task_key="Default") -> None:
# PyTorch Profiler
if self.enable_profiler or self.profiling:
prof.step()
self.wrapper.train()
if isinstance(self.lr_exp, dict):
_lr = self.lr_exp[task_key]
else:
Expand All @@ -682,12 +681,11 @@ def step(_step_id, task_key="Default") -> None:
)
loss.backward()
if self.gradient_max_norm > 0.0:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.wrapper.parameters(), self.gradient_max_norm
torch.nn.utils.clip_grad_norm_(
self.wrapper.parameters(),
self.gradient_max_norm,
error_if_nonfinite=True,
)
if not torch.isfinite(grad_norm).all():
# check local gradnorm single GPU case, trigger NanDetector
raise FloatingPointError("gradients are Nan/Inf")
with torch.device("cpu"):
self.optimizer.step()
self.scheduler.step()
Expand Down Expand Up @@ -766,7 +764,7 @@ def fake_model():
if self.display_in_training and (
display_step_id % self.disp_freq == 0 or display_step_id == 1
):
self.wrapper.eval()
self.wrapper.eval() # Will set to train mode before fininshing validation

def log_loss_train(_loss, _more_loss, _task_key="Default"):
results = {}
Expand Down Expand Up @@ -872,6 +870,7 @@ def log_loss_valid(_task_key="Default"):
learning_rate=None,
)
)
self.wrapper.train()

current_time = time.time()
train_time = current_time - self.t0
Expand Down Expand Up @@ -927,6 +926,7 @@ def log_loss_valid(_task_key="Default"):
f"{task_key}/{item}", more_loss[item], display_step_id
)

self.wrapper.train()
self.t0 = time.time()
self.total_train_time = 0.0
for step_id in range(self.num_steps):
Expand Down