Add engine.coalesce_grad_reduction() for ZeRO 1/2/3 multi-backward#7992
Add engine.coalesce_grad_reduction() for ZeRO 1/2/3 multi-backward#7992roycho96 wants to merge 6 commits into
Conversation
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
4ea6fc2 to
d4e71b7
Compare
|
Hi @roycho96, Thank you for submitting an interesting PR! Can you help me understand the tradeoff a bit better? When we delay gradient reduction, ZeRO-2 needs to keep the full unreduced gradients, not just shards, until the flush. In that case, do we still have a meaningful advantage over I still think this feature can be beneficial and am happy to merge, but we should clarify what it brings in terms of memory requirements as well as communication efficiency. Once the memory tradeoff gets clear, could you also add a short note to |
Hi @tohtana, thanks for looking at this. You are right about the memory side. I ran a benchmark put numbers on every case. Setup2 GPUs, MLP hidden=4096, 8 layers (~134M params, full bf16 grad = 256 MiB), bf16. Each step does Modes in the table:
Benchmark (one step, N=4 backwards)
N=2 and N=8 sweeps gave the same pattern: baseline collectives scale as N for Z2/Z3, coalesce stays at 1. Z2 + coalesce vs Z1 + no_syncSame window memory (both hold one full grad/rank, 256 MiB here), and both drop to 128 MiB after flush. Z2 + coalesce wins on (a) wire bytes: ring all_reduce moves ~2x payload, reduce_scatter ~1x, so Z2's one reduce_scatter is about half Z1's one all_reduce, and (b) opt state: Z2 shards Adam moments and fp32 master, about 800 MiB/rank less than Z1 for this model. No memory penalty vs Z1 + no_sync, strictly better on comm and steady-state memory. Z3 + coalesce extra costWindow resident is 384 MiB vs 128 MiB for baseline. The 256 MiB gap is one full bf16 grad of the params, held until flush. Peak is about the same (640 MiB either way) because the in-flight backward already needs full-grad room; the accumulator reuses it. After flush both go back to the usual 128 MiB partition. Scaled to a 70B Z3 run at world=64 that is ~2.0 GiB/rank extra during the window, so this is meant for runs where comm is the bottleneck and there is some grad memory to spare, not the tightest Z3 jobs. Comm and wall timeN to 1 reduction is exact across N=2/4/8. Wall time on this 2-GPU rig: Z2 203 ms -> 120 ms (1.7x), Z3 177 ms -> 133 ms (1.3x). The gap grows with rank count and model size. DocsI will add a short note to |
Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
Summary
Adds
engine.coalesce_grad_reduction(), an opt-in context manager that defers ZeRO 1/2/3 gradient reduction across multipleengine.backward()calls inside one optimizer step. On context exit, a single reduction pass populatesaveraged_gradientsfor the nextengine.step().This is the third step of the multi-backward feature:
set_gradient_accumulation_boundary()plus manualengine.backward()(PyTorch-style backward) a first-class API.cpu_offload. Chunks 1 through N-1 were dropped atga_steps=1 + N>1because of an outer gate incopy_grads_in_partition.This PR makes the path efficient. N reduce-scatters collapse into 1 across all ZeRO stages, removing the communication bottleneck that remained after the correctness fix.
Motivation
engine.no_sync()(engine.py) explicitly asserts ZeRO 2/3 are incompatible with the no_sync context manager. ZeRO needs per-backward reduction to partition gradients. The assert enforces that, but it blocks patterns where multipleengine.backward()calls per step are intentional:CachedMultipleNegativesRankingLoss,CachedGISTEmbedLoss,CachedMultipleNegativesSymmetricRankingLossall callengine.backward()once per cached chunk.torch.autograd.backward()inside their forwardBoth rely on the PyTorch-style backward API from #7665. With that API, the user (or a custom autograd Function) issues N
engine.backward()calls perengine.step()and togglesset_gradient_accumulation_boundary()to mark the last one. Without this PR, the pattern issues N reduce-scatters per step on ZeRO 2/3 even when the math only needs 1.What changed
deepspeed/runtime/engine.pyengine.coalesce_grad_reduction()context manager._flush_coalesced_reduction_zero{12,3}). Iterates params explicitly instead of callingreduce_gradients()to bypass theoverlap_commshort-circuit and thecontiguous_gradientssetup_buckets dependency.deepspeed/runtime/zero/stage_1_and_2.pyanddeepspeed/runtime/zero/stage3.py_coalesce_grad_reduction = Falseinit plus a 2-line guard at the top of the per-param reducer entry point. No existing function bodies modified.Compatibility matrix (all bit-exact vs. baseline multi-backward)
All four (contiguous_gradients, overlap_comm) combinations bit-exact vs. baseline multi-backward:
Additional verified:
gradient_accumulation_dtype=fp32(Z2 directly, Z1 with offload via theuse_grad_accum_attribute=Truepath).reduce_bucket_size).TestCoalesceCollectiveCount, patchesdist.all_reduce/reduce/reduce_scatter_fn).optimizer.micro_step_idinvariant. Stays 0 at flush across multiple steps, sopartition_gradsalways takes thecopy_branch instead of the stale-bufferadd_branch (TestCoalesceZero3MicroStepInvariant).Unsupported (NotImplementedError)
grad_accum_dtype=fp32, and no offload. Users on this combo can switch to ZeRO-2.engine.no_sync().