Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions csrc/includes/cpu_adam.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,17 @@ class Adam_Optimizer {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
if (step == _step + 1) { // first optimizer step increase
_step++;
_betta1_t *= _betta1;
_betta2_t *= _betta2;
} else if (step ==
_step) { // no need to update step; beta1_t and beta2_t already updated
return;
} else { // support step increase not equal to 1
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions csrc/xpu/includes/cpu_adam.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,17 @@ class Adam_Optimizer {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
if (step == _step + 1) { // first optimizer step increase
_step++;
_betta1_t *= _betta1;
_betta2_t *= _betta2;
} else if (step ==
_step) { // no need to update step; beta1_t and beta2_t already updated
return;
} else { // support step increase not equal to 1
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/ops/adam/test_cpu_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,59 @@ def test_multiple_subgroups(self):
optimizer.rollback_subgroup(0)
assert optimizer.state[0]['step'] == 1, "Subgroup 0 step count should be decremented"
assert optimizer.state[1]['step'] == 1, "Subgroup 1 step count should be unchanged"

def test_step_subgroup_same_step_idempotent_across_subgroups(self):
    """Two subgroups stepped with identical inputs must stay bit-identical.

    Both parameters start as clones of the same tensor and receive clones of
    the same gradient on every iteration. They are pushed through a single
    optimizer as subgroup 0 and subgroup 1, so each logical step issues two
    ``step_subgroup`` calls. After every iteration both subgroups must report
    the same step count, and their parameters and Adam moment buffers must
    compare equal under ``torch.equal``.
    """
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    model_size = 128
    steps = 4
    base = torch.randn(model_size, device='cpu', dtype=torch.float32)
    param_a = torch.nn.Parameter(base.clone())
    param_b = torch.nn.Parameter(base.clone())
    # The position in this tuple is the subgroup id handed to step_subgroup.
    subgroup_params = (param_a, param_b)

    optimizer = DeepSpeedCPUAdam([param_a])
    for logical_step in range(1, steps + 1):
        grad = torch.randn(model_size, device='cpu', dtype=torch.float32)

        # Swap each parameter into the only param group in turn, so both
        # subgroups are updated from the same gradient in the same logical step.
        for subgroup_id, param in enumerate(subgroup_params):
            optimizer.param_groups[0]['params'] = [param]
            param.grad = grad.clone()
            optimizer.step_subgroup(subgroup_id)

        assert optimizer.state[0]['step'] == logical_step
        assert optimizer.state[1]['step'] == logical_step
        assert torch.equal(param_a.data, param_b.data)
        for moment in ('exp_avg', 'exp_avg_sq'):
            assert torch.equal(optimizer.state[0][moment], optimizer.state[1][moment])

def test_step_same_step_idempotent_across_param_keys(self):
    """Calling ``optimizer.step()`` twice per logical step must be deterministic.

    Both parameters start as clones of the same tensor and receive clones of
    the same gradient on every iteration. Each one is swapped into the single
    param group and updated with a full ``optimizer.step()``, so the optimizer
    state is keyed by the parameter objects themselves. After every iteration
    both state entries must report the same step count, and the parameters
    and Adam moment buffers must compare equal under ``torch.equal``.
    """
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    model_size = 128
    steps = 4
    base = torch.randn(model_size, device='cpu', dtype=torch.float32)
    param_a = torch.nn.Parameter(base.clone())
    param_b = torch.nn.Parameter(base.clone())

    optimizer = DeepSpeedCPUAdam([param_a])
    for logical_step in range(1, steps + 1):
        grad = torch.randn(model_size, device='cpu', dtype=torch.float32)

        # Run one full step per parameter, reusing the same param group slot.
        for param in (param_a, param_b):
            optimizer.param_groups[0]['params'] = [param]
            param.grad = grad.clone()
            optimizer.step()

        assert optimizer.state[param_a]['step'] == logical_step
        assert optimizer.state[param_b]['step'] == logical_step
        assert torch.equal(param_a.data, param_b.data)
        for moment in ('exp_avg', 'exp_avg_sq'):
            assert torch.equal(optimizer.state[param_a][moment], optimizer.state[param_b][moment])
Loading