Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions deepmd/dpmodel/utils/learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,11 @@ def __init__(
The warmup learning rate starts from warmup_start_factor * start_lr.
Default is 0.0.
"""
# === Step 1. Validate stop_lr and stop_lr_ratio (runtime check) ===
# === Step 1. Validate start_lr (runtime check) ===
if start_lr <= 0 or not np.isfinite(start_lr):
raise ValueError(f"start_lr ({start_lr}) must be positive and finite.")

# === Step 2. Validate stop_lr and stop_lr_ratio (runtime check) ===
has_stop_lr = stop_lr is not None
has_stop_lr_ratio = stop_lr_ratio is not None

Expand All @@ -85,13 +89,13 @@ def __init__(
"Got stop_lr=None, stop_lr_ratio=None"
)

# === Step 2. Compute stop_lr from stop_lr_ratio if needed ===
# === Step 3. Compute stop_lr from stop_lr_ratio if needed ===
if stop_lr_ratio is not None:
self.stop_lr = start_lr * stop_lr_ratio
else:
self.stop_lr = stop_lr

# === Step 3. Validate warmup_steps and warmup_ratio (runtime check) ===
# === Step 4. Validate warmup_steps and warmup_ratio (runtime check) ===
has_warmup_steps = warmup_steps != 0
has_warmup_ratio = warmup_ratio is not None

Expand All @@ -101,13 +105,13 @@ def __init__(
f"Got warmup_steps={warmup_steps}, warmup_ratio={warmup_ratio}"
)

# === Step 4. Compute warmup_steps from warmup_ratio if needed ===
# === Step 5. Compute warmup_steps from warmup_ratio if needed ===
if warmup_ratio is not None:
self.warmup_steps = int(warmup_ratio * num_steps)
else:
self.warmup_steps = warmup_steps

# === Step 5. Validate step ranges (runtime check) ===
# === Step 6. Validate step ranges (runtime check) ===
if num_steps < 0:
raise ValueError("num_steps must be non-negative")
if self.warmup_steps < 0:
Expand All @@ -117,10 +121,10 @@ def __init__(
if num_steps == 0 and self.warmup_steps != 0:
raise ValueError("warmup_steps must be 0 when num_steps is 0")

# === Step 6. Compute warmup_start_lr ===
# === Step 7. Compute warmup_start_lr ===
self.warmup_start_lr = warmup_start_factor * start_lr

# === Step 7. Store core parameters ===
# === Step 8. Store core parameters ===
self._start_lr = start_lr
self.num_steps = num_steps
# Decay phase covers (num_steps - warmup_steps) steps
Expand Down Expand Up @@ -493,8 +497,6 @@ def __init__(
)

# === Validate WSD-specific invariants ===
if self._start_lr <= 0:
raise ValueError(f"start_lr ({self._start_lr}) must be positive.")
if self.stop_lr <= 0:
raise ValueError(f"stop_lr ({self.stop_lr}) must be positive.")
if decay_phase_ratio <= 0 or decay_phase_ratio > 1:
Expand Down
23 changes: 18 additions & 5 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,12 +939,10 @@ def single_model_finetune(
**extra,
)
self._load_optimizer_state(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
self.scheduler = self._create_lr_scheduler(
self.optimizer,
lambda step: (
self.lr_schedule.value(step + self.start_step) / initial_lr
),
last_epoch=self.start_step - 1,
self.lr_schedule,
self.start_step,
)

if self.zero_stage > 0 and self.rank == 0:
Expand Down Expand Up @@ -975,6 +973,21 @@ def single_model_finetune(
if self.rank == 0:
self._log_parameter_count()

@staticmethod
def _create_lr_scheduler(
optimizer: torch.optim.Optimizer,
lr_schedule: BaseLR,
start_step: int,
) -> torch.optim.lr_scheduler.LambdaLR:
base_lr = float(lr_schedule.start_lr)
for group in optimizer.param_groups:
group["initial_lr"] = base_lr
return torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda step: lr_schedule.value(step) / base_lr,
last_epoch=start_step - 1,
Comment thread
OutisLi marked this conversation as resolved.
)

def _create_full_validator(
self,
*,
Expand Down
62 changes: 62 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,68 @@ def test_fitting_stat_consistency(self) -> None:
)


class TestLearningRateRestart(unittest.TestCase):
def setUp(self) -> None:
self._cwd = os.getcwd()
self._tmpdir = tempfile.TemporaryDirectory()
os.chdir(self._tmpdir.name)
input_json = str(Path(__file__).parent / "water/se_atten.json")
with open(input_json) as f:
self.config = json.load(f)
self.config = convert_optimizer_v31_to_v32(self.config, warning=False)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.config["training"]["training_data"]["systems"] = data_file
self.config["training"]["validation_data"]["systems"] = data_file
self.config["model"] = deepcopy(model_se_e2_a)
self.config["learning_rate"] = {
"type": "wsd",
"start_lr": 5e-4,
"stop_lr": 1e-6,
"warmup_steps": 2,
"warmup_start_factor": 0.2,
"decay_phase_ratio": 0.5,
"decay_type": "cosine",
}
self.config["training"]["numb_steps"] = 3
self.config["training"]["save_freq"] = 3
self.config["training"]["disp_freq"] = 1
self.config["training"]["disp_training"] = False
self.config["training"]["time_training"] = False

def tearDown(self) -> None:
os.chdir(self._cwd)
self._tmpdir.cleanup()

def test_restart_scheduler_matches_lr_schedule(self) -> None:
trainer = get_trainer(deepcopy(self.config))
trainer.run()
restart_model = Path("model-3.pt")
checkpoint = torch.load(restart_model, map_location="cpu", weights_only=True)
stale_initial_lr = trainer.lr_schedule.value(0)
for group in checkpoint["optimizer"]["param_groups"]:
group["initial_lr"] = stale_initial_lr
torch.save(checkpoint, restart_model)

restart_config = deepcopy(self.config)
restart_config["training"]["numb_steps"] = 5
restart_trainer = get_trainer(
restart_config,
restart_model=str(restart_model),
)

np.testing.assert_allclose(
restart_trainer.scheduler.get_last_lr()[0],
restart_trainer.lr_schedule.value(restart_trainer.start_step),
rtol=1e-12,
)
restart_trainer.run()
np.testing.assert_allclose(
restart_trainer.scheduler.get_last_lr()[0],
restart_trainer.lr_schedule.value(restart_config["training"]["numb_steps"]),
rtol=1e-12,
)


class TestFullValidation(unittest.TestCase):
def setUp(self) -> None:
self._cwd = os.getcwd()
Expand Down
12 changes: 12 additions & 0 deletions source/tests/universal/dpmodel/utils/test_learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ def test_decay_rate_override(self) -> None:
self.assertEqual(lr.decay_rate, 0.9)
np.testing.assert_allclose(lr.value(1000), 1e-3 * 0.9, rtol=1e-10)

def test_rejects_nonpositive_or_nonfinite_start_lr(self) -> None:
"""Test invalid start_lr values are rejected by the base schedule."""
for start_lr in (0.0, -1e-3, np.inf, np.nan):
with self.subTest(start_lr=start_lr):
with self.assertRaisesRegex(ValueError, "start_lr"):
LearningRateExp(
start_lr=start_lr,
stop_lr=1e-5,
num_steps=10000,
decay_steps=5000,
)


class TestLearningRateCosineBasic(unittest.TestCase):
"""Test basic cosine annealing learning rate functionality."""
Expand Down
Loading