From 66b088c8e802de2ce4228cb56c28b7e463a9059a Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 7 May 2026 11:50:13 +0800 Subject: [PATCH 1/3] fix(pt): base LambdaLR on configured start_lr --- deepmd/pt/train/training.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index e674eb4b33..5b64581ef5 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -939,12 +939,10 @@ def single_model_finetune( **extra, ) self._load_optimizer_state(optimizer_state_dict) - self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.scheduler = self._create_lr_scheduler( self.optimizer, - lambda step: ( - self.lr_schedule.value(step + self.start_step) / initial_lr - ), - last_epoch=self.start_step - 1, + self.lr_schedule, + self.start_step, ) if self.zero_stage > 0 and self.rank == 0: @@ -975,6 +973,21 @@ def single_model_finetune( if self.rank == 0: self._log_parameter_count() + @staticmethod + def _create_lr_scheduler( + optimizer: torch.optim.Optimizer, + lr_schedule: BaseLR, + start_step: int, + ) -> torch.optim.lr_scheduler.LambdaLR: + base_lr = float(lr_schedule.start_lr) + for group in optimizer.param_groups: + group["initial_lr"] = base_lr + return torch.optim.lr_scheduler.LambdaLR( + optimizer, + lambda step: lr_schedule.value(step) / base_lr, + last_epoch=start_step - 1, + ) + def _create_full_validator( self, *, From b997baee6b0729a7f7459c9878b3a9eb08889a38 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 7 May 2026 19:44:24 +0800 Subject: [PATCH 2/3] fixup --- deepmd/dpmodel/utils/learning_rate.py | 20 +++--- source/tests/pt/test_training.py | 61 +++++++++++++++++++ .../dpmodel/utils/test_learning_rate.py | 12 ++++ 3 files changed, 84 insertions(+), 9 deletions(-) diff --git a/deepmd/dpmodel/utils/learning_rate.py b/deepmd/dpmodel/utils/learning_rate.py index 545376a1b8..d26432528c 100644 --- a/deepmd/dpmodel/utils/learning_rate.py +++ b/deepmd/dpmodel/utils/learning_rate.py @@ -70,7 +70,11 @@ def __init__( The warmup learning rate starts from warmup_start_factor * start_lr. Default is 0.0. """ - # === Step 1. Validate stop_lr and stop_lr_ratio (runtime check) === + # === Step 1. Validate start_lr (runtime check) === + if start_lr <= 0 or not np.isfinite(start_lr): + raise ValueError(f"start_lr ({start_lr}) must be positive and finite.") + + # === Step 2. Validate stop_lr and stop_lr_ratio (runtime check) === has_stop_lr = stop_lr is not None has_stop_lr_ratio = stop_lr_ratio is not None @@ -85,13 +89,13 @@ def __init__( "Got stop_lr=None, stop_lr_ratio=None" ) - # === Step 2. Compute stop_lr from stop_lr_ratio if needed === + # === Step 3. Compute stop_lr from stop_lr_ratio if needed === if stop_lr_ratio is not None: self.stop_lr = start_lr * stop_lr_ratio else: self.stop_lr = stop_lr - # === Step 3. Validate warmup_steps and warmup_ratio (runtime check) === + # === Step 4. Validate warmup_steps and warmup_ratio (runtime check) === has_warmup_steps = warmup_steps != 0 has_warmup_ratio = warmup_ratio is not None @@ -101,13 +105,13 @@ def __init__( f"Got warmup_steps={warmup_steps}, warmup_ratio={warmup_ratio}" ) - # === Step 4. Compute warmup_steps from warmup_ratio if needed === + # === Step 5. Compute warmup_steps from warmup_ratio if needed === if warmup_ratio is not None: self.warmup_steps = int(warmup_ratio * num_steps) else: self.warmup_steps = warmup_steps - # === Step 5. Validate step ranges (runtime check) === + # === Step 6. Validate step ranges (runtime check) === if num_steps < 0: raise ValueError("num_steps must be non-negative") if self.warmup_steps < 0: @@ -117,10 +121,10 @@ def __init__( if num_steps == 0 and self.warmup_steps != 0: raise ValueError("warmup_steps must be 0 when num_steps is 0") - # === Step 6. Compute warmup_start_lr === + # === Step 7. Compute warmup_start_lr === self.warmup_start_lr = warmup_start_factor * start_lr - # === Step 7. Store core parameters === + # === Step 8. Store core parameters === self._start_lr = start_lr self.num_steps = num_steps # Decay phase covers (num_steps - warmup_steps) steps @@ -493,8 +497,6 @@ def __init__( ) # === Validate WSD-specific invariants === - if self._start_lr <= 0: - raise ValueError(f"start_lr ({self._start_lr}) must be positive.") if self.stop_lr <= 0: raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.") if decay_phase_ratio <= 0 or decay_phase_ratio > 1: diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 9e840dd9f2..66b689ba04 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -771,6 +771,67 @@ def test_fitting_stat_consistency(self) -> None: ) +class TestLearningRateRestart(unittest.TestCase): + def setUp(self) -> None: + self._cwd = os.getcwd() + self._tmpdir = tempfile.TemporaryDirectory() + os.chdir(self._tmpdir.name) + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config = convert_optimizer_v31_to_v32(self.config, warning=False) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["learning_rate"] = { + "type": "wsd", + "start_lr": 5e-4, + "stop_lr": 1e-6, + "warmup_steps": 2, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.5, + "decay_type": "cosine", + } + self.config["training"]["numb_steps"] = 3 + self.config["training"]["save_freq"] = 3 + self.config["training"]["disp_freq"] = 1 + self.config["training"]["disp_training"] = False + self.config["training"]["time_training"] = False + + def tearDown(self) -> None: + os.chdir(self._cwd) + self._tmpdir.cleanup() + + def test_restart_scheduler_matches_lr_schedule(self) -> None: + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + restart_model = Path("model-3.pt") + checkpoint = torch.load(restart_model, map_location="cpu", weights_only=True) + stale_initial_lr = trainer.lr_schedule.value(0) + for group in checkpoint["optimizer"]["param_groups"]: + group["initial_lr"] = stale_initial_lr + torch.save(checkpoint, restart_model) + + restart_config = deepcopy(self.config) + restart_config["training"]["numb_steps"] = 5 + restart_trainer = get_trainer( + restart_config, + restart_model=str(restart_model), + ) + + np.testing.assert_allclose( + restart_trainer.scheduler.get_last_lr()[0], + restart_trainer.lr_schedule.value(restart_trainer.start_step), + rtol=1e-12, + ) + restart_trainer.run() + np.testing.assert_allclose( + restart_trainer.scheduler.get_last_lr()[0], + restart_trainer.lr_schedule.value(restart_config["training"]["numb_steps"]), + rtol=1e-12, + ) + class TestFullValidation(unittest.TestCase): def setUp(self) -> None: self._cwd = os.getcwd() diff --git a/source/tests/universal/dpmodel/utils/test_learning_rate.py b/source/tests/universal/dpmodel/utils/test_learning_rate.py index d6eaff35b4..8406f7f352 100644 --- a/source/tests/universal/dpmodel/utils/test_learning_rate.py +++ b/source/tests/universal/dpmodel/utils/test_learning_rate.py @@ -50,6 +50,18 @@ def test_decay_rate_override(self) -> None: self.assertEqual(lr.decay_rate, 0.9) np.testing.assert_allclose(lr.value(1000), 1e-3 * 0.9, rtol=1e-10) + def test_rejects_nonpositive_or_nonfinite_start_lr(self) -> None: + """Test invalid start_lr values are rejected by the base schedule.""" + for start_lr in (0.0, -1e-3, np.inf, np.nan): + with self.subTest(start_lr=start_lr): + with self.assertRaisesRegex(ValueError, "start_lr"): + LearningRateExp( + start_lr=start_lr, + stop_lr=1e-5, + num_steps=10000, + decay_steps=5000, + ) + class TestLearningRateCosineBasic(unittest.TestCase): """Test basic cosine annealing learning rate functionality.""" From aad461eacdf786f1fa71a20954dc2d8d1489185e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 May 2026 11:45:23 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- source/tests/pt/test_training.py | 1 + 1 file changed, 1 insertion(+) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index 66b689ba04..6a299d9d48 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -832,6 +832,7 @@ def test_restart_scheduler_matches_lr_schedule(self) -> None: rtol=1e-12, ) + class TestFullValidation(unittest.TestCase): def setUp(self) -> None: self._cwd = os.getcwd()