From 10735fdd88f2ee3c5873a4f6065eaa71d2f78438 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 11 Jan 2026 00:53:46 +0800 Subject: [PATCH 1/2] chore(pd): sync get_lr from pt to pd --- deepmd/pd/train/training.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 4e5fea081f..78b24645af 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -14,6 +14,7 @@ ) import numpy as np +from deepmd.dpmodel.utils.learning_rate import BaseLR import paddle import paddle.distributed as dist from paddle.distributed import ( @@ -238,13 +239,10 @@ def get_sample(): _stat_file_path.root.close() return get_sample - def get_lr(lr_params): - assert lr_params.get("type", "exp") == "exp", ( - "Only learning rate `exp` is supported!" - ) + def get_lr(lr_params: dict[str, Any]) -> BaseLR: lr_params["stop_steps"] = self.num_steps - self.warmup_steps - lr_exp = LearningRateExp(**lr_params) - return lr_exp + lr_schedule = BaseLR(**lr_params) + return lr_schedule # Optimizer if self.multi_task and training_params.get("optim_dict", None) is not None: From bf21f3131d2799ab7208985c112a342b029edbe8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 10 Jan 2026 16:55:48 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pd/train/training.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 78b24645af..dd0fbdc94b 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -14,7 +14,6 @@ ) import numpy as np -from deepmd.dpmodel.utils.learning_rate import BaseLR import paddle import paddle.distributed as dist from paddle.distributed import ( @@ -31,6 +30,9 @@ from deepmd.common import ( symlink_prefix_files, ) +from deepmd.dpmodel.utils.learning_rate import ( + BaseLR, +) from deepmd.loggers.training import ( format_training_message, format_training_message_per_task, @@ -63,9 +65,6 @@ SAMPLER_RECORD, enable_prim, ) -from deepmd.pd.utils.learning_rate import ( - LearningRateExp, -) from deepmd.pd.utils.stat import ( make_stat_input, )