Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2470,12 +2470,20 @@ def _backward_prologue(self):
def _backward_epilogue(self):
self._stop_timers(self.engine_timers.backward_inner_timers)
self._start_timers(self.engine_timers.backward_reduce_timers)
# BF16_Optimizer (without immediate_grad_update) accumulates low
# precision grads into a separate fp32 buffer in backward_epilogue().
# Run it before allreduce so the boundary microbatch is reduced.
bf16_optimizer = isinstance(self.optimizer, BF16_Optimizer)
if bf16_optimizer:
self.optimizer.backward_epilogue()

if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
# Traditional code path that allreduces the module parameter grads
self.allreduce_gradients()

if isinstance(self.optimizer, ZeROOptimizer):
self.optimizer.backward_epilogue()
if not bf16_optimizer:
self.optimizer.backward_epilogue()
self.optimizer.exit_backward()

see_memory_usage("Engine after backward", force=self.memory_breakdown())
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/v1/zero/test_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import deepspeed
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
from deepspeed.runtime.zero.utils import ZeRORuntimeException
Expand Down Expand Up @@ -53,6 +54,68 @@ def dump_state_dict(model):
print(f"{name} {param.data}")


class TestBF16OptimizerGradReduction(DistributedTest):
world_size = 2

def test_boundary_microbatch_grad_is_reduced(self):
if not get_accelerator().is_bf16_supported():
pytest.skip("bfloat16 is not supported on this accelerator")

class ScaleModel(Module):

def __init__(self):
super().__init__()
self.weight = Parameter(torch.ones(4))

def forward(self, x):
return (self.weight * x).sum()

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 2,
"zero_optimization": {
"stage": 1
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-3
}
},
"bf16": {
"enabled": True,
"immediate_grad_update": False,
},
"data_types": {
"grad_accum_dtype": "fp32"
}
}

model = ScaleModel()
engine, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
assert isinstance(engine.optimizer, BF16_Optimizer)

rank = dist.get_rank()
rank_offset = 18 * rank
inputs = [
torch.tensor([1, 3, 5, 7], dtype=torch.bfloat16, device=engine.device) + rank_offset,
torch.tensor([11, 13, 15, 17], dtype=torch.bfloat16, device=engine.device) + rank_offset,
]
for i, x in enumerate(inputs):
engine.set_gradient_accumulation_boundary(i == len(inputs) - 1)
engine.backward(engine(x))

grad = engine.optimizer.fp32_groups_gradients_flat[0].detach().clone()
expected = torch.tensor([15, 17, 19, 21], dtype=grad.dtype, device=grad.device)
torch.testing.assert_close(grad, expected)

gathered_grads = [torch.zeros_like(grad) for _ in range(dist.get_world_size())]
dist.all_gather(gathered_grads, grad)
torch.testing.assert_close(gathered_grads[0], gathered_grads[1])

engine.destroy()


@pytest.mark.parametrize("zero_stage", [1, 2, 3])
class TestZeroUnbalancedGradients(DistributedTest):
world_size = 1
Expand Down
Loading