From cdafcc15cd2d6d8863acf29093a236292bce03d8 Mon Sep 17 00:00:00 2001
From: st_bang
Date: Tue, 21 Apr 2026 14:33:03 +0900
Subject: [PATCH] Fix Adam subgroup inconsistency

Adam_Optimizer caches a step counter (_step) together with the running
powers _betta1_t and _betta2_t. When the betas were unchanged, the old
code incremented _step unconditionally and only then compared it with
the requested step. A second call carrying the same step value (as the
new tests do for a second subgroup / parameter key) therefore always
failed that comparison and recomputed the powers with std::pow, while
the first call had obtained them by multiplying the cached values. The
two routes may round differently, so calls that share a step were not
guaranteed to see identical cached powers.

Compare the requested step with the cached one before touching any
state:

  * step == _step + 1: advance _step and multiply the cached powers.
  * step == _step:     leave _step and the cached powers untouched.
  * anything else:     recompute both powers with std::pow and store
                       the requested step.

The same change is applied to csrc/includes/cpu_adam.h and
csrc/xpu/includes/cpu_adam.h.

Add two unit tests that drive two parameters with identical initial
values and identical gradients through the same logical step, once via
step_subgroup() and once via step(), and require equal step counts,
parameters, exp_avg and exp_avg_sq after every step.

Signed-off-by: st_bang
---
 csrc/includes/cpu_adam.h             | 12 +++---
 csrc/xpu/includes/cpu_adam.h         | 12 +++---
 tests/unit/ops/adam/test_cpu_adam.py | 57 ++++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 10 deletions(-)

diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
index e4fae63ce7cd..ddd8f2f933eb 100644
--- a/csrc/includes/cpu_adam.h
+++ b/csrc/includes/cpu_adam.h
@@ -63,14 +63,16 @@ class Adam_Optimizer {
             _betta1_t = std::pow(_betta1, step);
             _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
+            if (step == _step + 1) { // first optimizer step increase
+                _step++;
+                _betta1_t *= _betta1;
+                _betta2_t *= _betta2;
+            } else if (step == _step) { // no need to update step; beta1_t and beta2_t already updated
+                return;
+            } else { // support step increase not equal to 1
                 _betta1_t = std::pow(_betta1, step);
                 _betta2_t = std::pow(_betta2, step);
                 _step = step;
-            } else {
-                _betta1_t *= _betta1;
-                _betta2_t *= _betta2;
             }
         }
     }
diff --git a/csrc/xpu/includes/cpu_adam.h b/csrc/xpu/includes/cpu_adam.h
index 7bc0364c569d..11217f7e5676 100644
--- a/csrc/xpu/includes/cpu_adam.h
+++ b/csrc/xpu/includes/cpu_adam.h
@@ -69,14 +69,16 @@ class Adam_Optimizer {
             _betta1_t = std::pow(_betta1, step);
             _betta2_t = std::pow(_betta2, step);
         } else {
-            _step++;
-            if (_step != step) {
+            if (step == _step + 1) { // first optimizer step increase
+                _step++;
+                _betta1_t *= _betta1;
+                _betta2_t *= _betta2;
+            } else if (step == _step) { // no need to update step; beta1_t and beta2_t already updated
+                return;
+            } else { // support step increase not equal to 1
                 _betta1_t = std::pow(_betta1, step);
                 _betta2_t = std::pow(_betta2, step);
                 _step = step;
-            } else {
-                _betta1_t *= _betta1;
-                _betta2_t *= _betta2;
             }
         }
     }
diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py
index d83b1732e700..074d87e570b4 100644
--- a/tests/unit/ops/adam/test_cpu_adam.py
+++ b/tests/unit/ops/adam/test_cpu_adam.py
@@ -312,3 +312,60 @@ def test_multiple_subgroups(self):
         optimizer.rollback_subgroup(0)
         assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
         assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"
+
+
+    def test_step_subgroup_same_step_idempotent_across_subgroups(self):
+        """Repeated same-step subgroup updates should remain bit-identical."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step_subgroup(0)
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step_subgroup(1)
+
+            assert optimizer.state[0]['step'] == logical_step
+            assert optimizer.state[1]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[0]['exp_avg'], optimizer.state[1]['exp_avg'])
+            assert torch.equal(optimizer.state[0]['exp_avg_sq'], optimizer.state[1]['exp_avg_sq'])
+
+    def test_step_same_step_idempotent_across_param_keys(self):
+        """Repeated optimizer.step() with swapped param keys should be deterministic."""
+        from deepspeed.ops.adam import DeepSpeedCPUAdam
+
+        model_size = 128
+        steps = 4
+        base = torch.randn(model_size, device='cpu', dtype=torch.float32)
+        param_a = torch.nn.Parameter(base.clone())
+        param_b = torch.nn.Parameter(base.clone())
+
+        optimizer = DeepSpeedCPUAdam([param_a])
+        for logical_step in range(1, steps + 1):
+            grad = torch.randn(model_size, device='cpu', dtype=torch.float32)
+
+            optimizer.param_groups[0]['params'] = [param_a]
+            param_a.grad = grad.clone()
+            optimizer.step()
+
+            optimizer.param_groups[0]['params'] = [param_b]
+            param_b.grad = grad.clone()
+            optimizer.step()
+
+            assert optimizer.state[param_a]['step'] == logical_step
+            assert optimizer.state[param_b]['step'] == logical_step
+            assert torch.equal(param_a.data, param_b.data)
+            assert torch.equal(optimizer.state[param_a]['exp_avg'], optimizer.state[param_b]['exp_avg'])
+            assert torch.equal(optimizer.state[param_a]['exp_avg_sq'], optimizer.state[param_b]['exp_avg_sq'])
\ No newline at end of file