# Patch summary
#
# deepspeed/runtime/lr_schedules.py
#   * get_lr_ratio(): when last_batch_iteration < 0, return the scalar 0.0
#     instead of the one-element list [0.0]. get_lr() multiplies each entry of
#     self.org_lrs by this value.
#   * get_lr(): when last_batch_iteration < 0, return one 0.0 per entry of
#     self.org_lrs instead of a single-element list.
#   NOTE(review): the hunks do not show the enclosing class; the new test
#   exercises WarmupCosineLR, so these are presumably its methods - confirm.
#
# tests/unit/runtime/test_lr_schedulers.py
#   * Import math and WarmupCosineLR.
#   * Add test_warmup_cosine_lr_initializes_all_param_groups: builds an Adam
#     optimizer with two parameter groups (lr 0.0015 and 0.003), constructs
#     WarmupCosineLR(total_num_steps=100, warmup_num_steps=10,
#     warmup_min_ratio=0.0), and asserts that get_lr_ratio() is 0.0 and that
#     get_lr(), get_last_lr() and the optimizer's param groups all report
#     [0.0, 0.0]. After scheduler.step(1) it expects a ratio of
#     log(2) / log(10) applied to each group's original lr.
#
# NOTE(review): line breaks and indentation below were reconstructed from the
# hunk line counts (the patch had been flattened onto two lines); verify the
# whitespace against the target files before applying.
diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index cb722708e683..f9a1fa7ad162 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -833,7 +833,7 @@ def __init__(self,
     def get_lr_ratio(self):
         if self.last_batch_iteration < 0:
             logger.warning("Attempting to get learning rate from scheduler before it has started")
-            return [0.0]
+            return 0.0
 
         if self.last_batch_iteration < self.warmup_num_steps:
             if self.warmup_type == WARMUP_LOG_RATE:
@@ -860,7 +860,7 @@ def step(self, last_batch_iteration=None):
     def get_lr(self):
         if self.last_batch_iteration < 0:
             logger.warning("Attempting to get learning rate from scheduler before it has started")
-            return [0.0]
+            return [0.0 for _ in self.org_lrs]
 
         lr_ratio = self.get_lr_ratio()
         return [org_lr * lr_ratio for org_lr in self.org_lrs]
diff --git a/tests/unit/runtime/test_lr_schedulers.py b/tests/unit/runtime/test_lr_schedulers.py
index 47734c0cd864..1dfa853dbe05 100644
--- a/tests/unit/runtime/test_lr_schedulers.py
+++ b/tests/unit/runtime/test_lr_schedulers.py
@@ -3,6 +3,8 @@
 
 # DeepSpeed Team
 
+import math
+
 import torch
 import deepspeed
 import pytest
@@ -13,7 +15,7 @@
 from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE
 from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE
 from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS
-from deepspeed.runtime.lr_schedules import WARMUP_COSINE_LR, WARMUP_MIN_RATIO, COS_MIN_RATIO
+from deepspeed.runtime.lr_schedules import WARMUP_COSINE_LR, WARMUP_MIN_RATIO, COS_MIN_RATIO, WarmupCosineLR
 
 
 def _verify_continuous_decrease(values):
@@ -518,3 +520,26 @@ def test_lr(self, total_num_steps, warmup_num_steps, cos_min_ratio, warmup_min_r
 
         # Verify decreasing phase
         _verify_continuous_decrease(step_lrs[warmup_num_steps:total_num_steps])
+
+
+def test_warmup_cosine_lr_initializes_all_param_groups():
+    dense = torch.nn.Parameter(torch.zeros(1))
+    expert = torch.nn.Parameter(torch.zeros(1))
+    optimizer = torch.optim.Adam([{"params": [dense], "lr": 0.0015}, {"params": [expert], "lr": 0.003}])
+
+    scheduler = WarmupCosineLR(optimizer=optimizer, total_num_steps=100, warmup_num_steps=10, warmup_min_ratio=0.0)
+
+    assert scheduler.get_lr_ratio() == 0.0
+    assert scheduler.get_lr() == [0.0, 0.0]
+    assert scheduler.get_last_lr() == [0.0, 0.0]
+    assert [group["lr"] for group in optimizer.param_groups] == [0.0, 0.0]
+
+    scheduler.step(1)
+
+    expected_ratio = math.log(2) / math.log(10)
+    expected_lrs = [0.0015 * expected_ratio, 0.003 * expected_ratio]
+
+    assert scheduler.get_lr_ratio() == pytest.approx(expected_ratio)
+    assert scheduler.get_lr() == pytest.approx(expected_lrs)
+    assert scheduler.get_last_lr() == pytest.approx(expected_lrs)
+    assert [group["lr"] for group in optimizer.param_groups] == pytest.approx(expected_lrs)