From 5e794872f531163f55eb821bc386f4f8c66c66a2 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 21 May 2025 14:49:58 +0800 Subject: [PATCH 1/2] feat(pt): add AdamW for pt training --- deepmd/pt/train/training.py | 22 +++++++++++++++------- deepmd/utils/argcheck.py | 1 + 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index eab5601d55..a0e2c53575 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -158,6 +158,7 @@ def get_opt_param(params): "kf_limit_pref_e": params.get("kf_limit_pref_e", 1), "kf_start_pref_f": params.get("kf_start_pref_f", 1), "kf_limit_pref_f": params.get("kf_limit_pref_f", 1), + "weight_decay": params.get("weight_decay", 0.001), } return opt_type, opt_param @@ -609,12 +610,19 @@ def warm_up_linear(step, warmup_steps): # TODO add optimizers for multitask # author: iProzd - if self.opt_type == "Adam": - self.optimizer = torch.optim.Adam( - self.wrapper.parameters(), - lr=self.lr_exp.start_lr, - fused=False if DEVICE.type == "cpu" else True, - ) + if self.opt_type in ["Adam", "AdamW"]: + if self.opt_type == "Adam": + self.optimizer = torch.optim.Adam( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + fused=False if DEVICE.type == "cpu" else True, + ) + else: + self.optimizer = torch.optim.AdamW( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + weight_decay=float(self.opt_param["weight_decay"]), + ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict) self.scheduler = torch.optim.lr_scheduler.LambdaLR( @@ -710,7 +718,7 @@ def step(_step_id, task_key="Default") -> None: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) fout1.flush() - if self.opt_type == "Adam": + if self.opt_type in ["Adam", "AdamW"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: pref_lr = _lr.start_lr diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0260700165..dbfb71459e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3177,6 +3177,7 @@ def training_args( "opt_type", choices=[ Argument("Adam", dict, [], [], optional=True), + Argument("AdamW", dict, [], [], optional=True), Argument( "LKF", dict, From 6e20c181d1eefb07f0e01ce6bf66275d8a271c7c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 21 May 2025 23:04:31 +0800 Subject: [PATCH 2/2] Update training.py --- deepmd/pt/train/training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index a0e2c53575..af8f46d7a9 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -622,6 +622,7 @@ def warm_up_linear(step, warmup_steps): self.wrapper.parameters(), lr=self.lr_exp.start_lr, weight_decay=float(self.opt_param["weight_decay"]), + fused=False if DEVICE.type == "cpu" else True, ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict)