diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index 0ff74695b553..cb722708e683 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -13,6 +13,7 @@
 from torch.optim import Optimizer
 import math
 from deepspeed.utils import logger
+from torch import is_tensor, ones_like
 
 LR_SCHEDULE = 'lr_schedule'
 LR_RANGE_TEST = 'LRRangeTest'
@@ -249,6 +250,19 @@ def get_lr_from_config(config):
 
 def update_lr(param_groups, lrs):
+    """Assign each learning rate in ``lrs`` to the matching param group.
+
+    A group whose current ``lr`` is a tensor receives a new tensor with the
+    same shape and device (and the same dtype for floating-point LRs), so
+    the stored value keeps a consistent type across steps; a group with a
+    scalar ``lr`` receives the new value unchanged. Pairing stops at the
+    shorter of the two sequences.
+
+    Returns a list with the ``lr`` entry of every group in ``param_groups``.
+    """
     for param_group, lr in zip(param_groups, lrs):
+        # New LR must match the kind of the current LR (scalar vs. tensor).
+        # Scaling ones_like() keeps the existing tensor's shape and device.
+        if is_tensor(param_group['lr']):
+            lr = lr * ones_like(param_group['lr'])
         param_group['lr'] = lr
     return [group['lr'] for group in param_groups]
 