diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index bc1ec7bc1880..12867437d9dd 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1599,15 +1599,21 @@ def _configure_basic_optimizer(self, model_parameters): param_groups = [] if muon_params: accepted_parameters = dict() - for key in ["lr", "momentum", "weight_decay"]: + for key in ["lr", "momentum", "weight_decay", "muon_lr"]: if key in optimizer_parameters: - accepted_parameters[key] = optimizer_parameters[key] + if key == "muon_lr": # muon_lr will override lr + accepted_parameters['lr'] = optimizer_parameters[key] + else: + accepted_parameters[key] = optimizer_parameters[key] param_groups.append(dict(params=muon_params, use_muon=True, **accepted_parameters)) if non_muon_params: accepted_parameters = dict() - for key in ["lr", "betas", "eps", "weight_decay"]: + for key in ["lr", "betas", "eps", "weight_decay", "adam_lr"]: if key in optimizer_parameters: - accepted_parameters[key] = optimizer_parameters[key] + if key == "adam_lr": # adam_lr will override lr + accepted_parameters['lr'] = optimizer_parameters[key] + else: + accepted_parameters[key] = optimizer_parameters[key] param_groups.append(dict(params=non_muon_params, use_muon=False, **accepted_parameters)) optimizer = MuonWithAuxAdam(param_groups) else: