Add engine.coalesce_grad_reduction() for ZeRO 1/2/3 multi-backward #7992

Open
roycho96 wants to merge 6 commits into deepspeedai:master from roycho96:feat/zero-coalesce-grad-reduction

roycho96 (Contributor) commented May 5, 2026

Summary

Adds engine.coalesce_grad_reduction(), an opt-in context manager that defers ZeRO 1/2/3 gradient reduction across multiple engine.backward() calls inside one optimizer step. On context exit, a single reduction pass populates averaged_gradients for the next engine.step().

This is the third step of the multi-backward feature:

This PR makes the path efficient. N reduce-scatters collapse into 1 across all ZeRO stages, removing the communication bottleneck that remained after the correctness fix.

Motivation

engine.no_sync() (engine.py) explicitly asserts ZeRO 2/3 are incompatible with the no_sync context manager. ZeRO needs per-backward reduction to partition gradients. The assert enforces that, but it blocks patterns where multiple engine.backward() calls per step are intentional:

  • Cached contrastive learning (GradCache): sentence-transformers CachedMultipleNegativesRankingLoss, CachedGISTEmbedLoss, CachedMultipleNegativesSymmetricRankingLoss all call engine.backward() once per cached chunk.
  • Custom autograd Functions that invoke torch.autograd.backward() inside their forward

Both rely on the PyTorch-style backward API from #7665. With that API, the user (or a custom autograd Function) issues N engine.backward() calls per engine.step() and toggles set_gradient_accumulation_boundary() to mark the last one. Without this PR, the pattern issues N reduce-scatters per step on ZeRO 2/3 even when the math only needs 1.
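The multi-backward pattern described above can be sketched as a small helper. This is a minimal illustration, not code from the PR: `step_with_coalesced_reduction` and `chunk_loss_fns` are hypothetical names, and the helper only assumes the engine exposes `coalesce_grad_reduction()`, `backward()`, and `step()` as described in this PR. Whether `set_gradient_accumulation_boundary()` still needs to be toggled inside the context is not stated here, so the sketch leaves it out.

```python
from typing import Callable, Iterable


def step_with_coalesced_reduction(engine, chunk_loss_fns: Iterable[Callable[[], object]]) -> int:
    """Run one optimizer step made of several backward passes.

    `engine` is assumed to be a DeepSpeed engine that provides the
    coalesce_grad_reduction() context manager proposed in this PR.
    `chunk_loss_fns` is a hypothetical iterable of zero-argument callables,
    each running the forward pass for one cached chunk and returning its loss.
    """
    n_backward = 0
    with engine.coalesce_grad_reduction():
        for loss_fn in chunk_loss_fns:
            loss = loss_fn()       # forward pass for one chunk
            engine.backward(loss)  # reduction is deferred inside the context
            n_backward += 1
    # The context has exited, so the single deferred reduction has run.
    engine.step()
    return n_backward
```

Passing callables rather than precomputed losses keeps only one chunk's activations alive at a time, which is the point of GradCache-style training.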

What changed

deepspeed/runtime/engine.py

  • New engine.coalesce_grad_reduction() context manager.
  • Stage-aware flush helpers (_flush_coalesced_reduction_zero{12,3}). Iterates params explicitly instead of calling reduce_gradients() to bypass the overlap_comm short-circuit and the contiguous_gradients setup_buckets dependency.

deepspeed/runtime/zero/stage_1_and_2.py and deepspeed/runtime/zero/stage3.py

  • _coalesce_grad_reduction = False init plus a 2-line guard at the top of the per-param reducer entry point. No existing function bodies modified.
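The flag-plus-guard design can be illustrated with a single-process toy model. Everything below is a hypothetical stand-in rather than the PR's code: only the `_coalesce_grad_reduction` flag name comes from the description above, while `ToyZeroOptimizer`, `on_backward`, `reduce_entry_point`, `flush`, and `run` are invented for illustration, and the "collective" is simulated by a counter instead of a real cross-rank reduction.

```python
from contextlib import contextmanager


class ToyZeroOptimizer:
    """Hypothetical stand-in for a ZeRO optimizer; not DeepSpeed code."""

    def __init__(self):
        self._coalesce_grad_reduction = False  # flag name taken from the PR text
        self.local_grads = {}     # unreduced grads accumulated on this rank
        self.reduced_grads = {}   # what a real reduction would hand to step()
        self.num_collectives = 0  # counts simulated reduction passes

    def on_backward(self, grads):
        """Accumulate one backward's grads, then hit the reducer entry point."""
        for name, grad in grads.items():
            self.local_grads[name] = self.local_grads.get(name, 0.0) + grad
        self.reduce_entry_point()

    def reduce_entry_point(self):
        if self._coalesce_grad_reduction:
            return  # guard: skip reduction while coalescing is active
        self.flush()

    def flush(self):
        """Simulate one reduction pass over everything accumulated so far."""
        if not self.local_grads:
            return
        self.num_collectives += 1
        for name, grad in self.local_grads.items():
            self.reduced_grads[name] = self.reduced_grads.get(name, 0.0) + grad
        self.local_grads.clear()


@contextmanager
def coalesce_grad_reduction(opt):
    """Defer reduction inside the block, then flush once on a clean exit."""
    opt._coalesce_grad_reduction = True
    try:
        yield
    finally:
        opt._coalesce_grad_reduction = False
    opt.flush()  # reached only if the block did not raise


def run(n_backward, coalesce):
    """Return (simulated collective count, reduced grads) for one step."""
    opt = ToyZeroOptimizer()
    grads = {"w": 0.5, "b": 0.25}
    if coalesce:
        with coalesce_grad_reduction(opt):
            for _ in range(n_backward):
                opt.on_backward(grads)
    else:
        for _ in range(n_backward):
            opt.on_backward(grads)
    return opt.num_collectives, opt.reduced_grads
```

In this toy, `run(4, coalesce=False)` performs four simulated reductions and `run(4, coalesce=True)` performs one, with identical accumulated values in both cases.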

Compatibility matrix

All four (contiguous_gradients, overlap_comm) combinations are bit-exact vs. baseline multi-backward (F = False, T = True):

| Stage | (F, F) | (T, F) | (F, T) | (T, T) default |
|---|---|---|---|---|
| ZeRO-1 | OK | OK | OK | OK |
| ZeRO-2 | OK | OK | OK | OK |
| ZeRO-3 | OK | OK | OK | OK |

Additionally verified:

  • CPU offload (offload_optimizer Z1/Z2/Z3, offload_param Z3).
  • BF16 with gradient_accumulation_dtype=fp32 (Z2 directly, Z1 with offload via the use_grad_accum_attribute=True path).
  • FP16 with dynamic loss scaling (Z1/Z2/Z3).
  • Multi-bucket flush (small reduce_bucket_size).
  • MoE smoke (ep_size=1, Z1/Z2). MoE ep_size=2 test included but requires world_size=4.
  • Gradient clipping, multi-step state hygiene.
  • N=4 deferred backward issues strictly fewer cross-rank collectives than baseline (TestCoalesceCollectiveCount, patches dist.all_reduce/reduce/reduce_scatter_fn).
  • ZeRO-3 optimizer.micro_step_id invariant. Stays 0 at flush across multiple steps, so partition_grads always takes the copy_ branch instead of the stale-buffer add_ branch (TestCoalesceZero3MicroStepInvariant).

Unsupported (NotImplementedError)

  • ZeRO stage 0.
  • BF16_Optimizer / FP16_Optimizer wrappers. BF16_Optimizer dispatches only for ZeRO-1 with bf16, grad_accum_dtype=fp32, and no offload. Users on this combo can switch to ZeRO-2.
  • PipelineModule (pipeline parallelism schedules its own reductions).
  • Reentry / nesting with engine.no_sync().

Signed-off-by: Sung Hyun Cho <hope5487@gmail.com>
@roycho96 roycho96 force-pushed the feat/zero-coalesce-grad-reduction branch from 4ea6fc2 to d4e71b7 Compare May 5, 2026 10:31
@tohtana tohtana assigned tohtana and unassigned tohtana May 18, 2026
tohtana (Collaborator) commented May 18, 2026

Hi @roycho96,

Thank you for submitting an interesting PR! Can you help me understand the tradeoff a bit better?

When we delay gradient reduction, ZeRO-2 needs to keep the full unreduced gradients, not just shards, until the flush. In that case, do we still have a meaningful advantage over no_sync() + ZeRO-1? ZeRO-3 + coalesce_grad_reduction() is different because parameter partitioning still matters, but it also changes the usual ZeRO-3 memory behavior.

I still think this feature can be beneficial and am happy to merge, but we should clarify what it brings in terms of memory requirements as well as communication efficiency. Once the memory tradeoff gets clear, could you also add a short note to docs/code-docs/source/training.rst near “Gradient Accumulation”?

roycho96 (Contributor, Author) commented May 19, 2026

Hi @tohtana, thanks for looking at this. You are right about the memory side. I ran a benchmark to put numbers on every case.

Setup

2 GPUs, MLP hidden=4096, 8 layers (~134M params, full bf16 grad = 256 MiB), bf16. Each step does N back-to-back engine.backward() calls before engine.step(). The table below is N=4; I also swept N=2 and N=8. All configs bit-exact vs the matching stage's baseline (max param diff = 0).
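The sizes quoted in the setup can be checked with a few lines of arithmetic. This is a back-of-the-envelope sketch that assumes each layer is a single square hidden x hidden weight matrix with biases ignored; the function names are hypothetical.

```python
def param_count(hidden: int, layers: int) -> int:
    """Weights only: one hidden x hidden matrix per layer (biases ignored)."""
    return layers * hidden * hidden


def grad_mib(n_params: int, bytes_per_elem: int = 2) -> float:
    """Size of a full gradient in MiB; bf16 uses 2 bytes per element."""
    return n_params * bytes_per_elem / 2**20


n_params = param_count(hidden=4096, layers=8)  # 134,217,728, i.e. ~134M
full_grad = grad_mib(n_params)                 # 256.0 MiB
shard_per_rank = full_grad / 2                 # 128.0 MiB on 2 GPUs
payload_n4 = 4 * full_grad                     # 1024.0 MiB if reduced 4 times
```

Under these assumptions the full bf16 gradient comes to 256 MiB, a per-rank shard on 2 GPUs to 128 MiB, and four full-gradient reductions to 1024 MiB of payload.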

Modes in the table:

  • baseline: no context manager, just N back-to-back engine.backward(). This is the path before the PR. Z1 only reduces at the boundary, so it stays at 1 collective. Z2/Z3 reduce per backward, so they scale as N.
  • no_sync: PyTorch DDP pattern, N-1 backwards inside engine.no_sync() plus 1 outside. Z2/Z3 are excluded because engine.no_sync() asserts that it is incompatible with gradient partitioning. Z1 is the only stage where this is allowed, so Z1 + no_sync is the only "alternative" available to Z2/Z3 users today, which is why your question pivots on it.
  • coalesce: this PR's engine.coalesce_grad_reduction().

Benchmark (one step, N=4 backwards)

| stage | mode | peak window (MiB) | window resident (MiB) | flush resident (MiB) | wall ms | grad collectives | comm MiB |
|---|---|---|---|---|---|---|---|
| Z1 | baseline | 512 | 128 | 128 | 121 | 1 all_reduce | 256 |
| Z1 | no_sync | 384 | 256 | 128 | 119 | 1 all_reduce | 256 |
| Z1 | coalesce | 384 | 256 | 128 | 128 | 1 all_reduce | 256 |
| Z2 | baseline | 640 | 128 | 128 | 203 | 4 all_reduce | 1024 |
| Z2 | coalesce | 384 | 256 | 128 | 120 | 1 all_reduce | 256 |
| Z3 | baseline | 640 | 128 | 128 | 177 | 4 reduce_scatter | 1024 |
| Z3 | coalesce | 640 | 384 | 128 | 133 | 1 reduce_scatter | 256 |

N=2 and N=8 sweeps gave the same pattern: baseline collectives scale as N for Z2/Z3, coalesce stays at 1.

Z2 + coalesce vs Z1 + no_sync

Same window memory (both hold one full grad/rank, 256 MiB here), and both drop to 128 MiB after flush. Z2 + coalesce wins on (a) wire bytes: ring all_reduce moves ~2x payload, reduce_scatter ~1x, so Z2's one reduce_scatter is about half Z1's one all_reduce, and (b) opt state: Z2 shards Adam moments and fp32 master, about 800 MiB/rank less than Z1 for this model. No memory penalty vs Z1 + no_sync, strictly better on comm and steady-state memory.
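The "~2x vs ~1x" wire-byte comparison follows from the standard ring algorithm cost model, sketched below. The function names are hypothetical, and the formulas describe an idealized ring: a real communication backend may pick a different algorithm, so treat the numbers as an estimate rather than a measurement.

```python
def ring_all_reduce_sent_mib(payload_mib: float, world: int) -> float:
    """Per-rank bytes sent by an idealized ring all_reduce.

    A ring all_reduce is a reduce_scatter followed by an all_gather,
    each moving (world - 1) / world of the payload per rank.
    """
    return 2 * (world - 1) / world * payload_mib


def ring_reduce_scatter_sent_mib(payload_mib: float, world: int) -> float:
    """Per-rank bytes sent by an idealized ring reduce_scatter."""
    return (world - 1) / world * payload_mib


# Benchmark setup from this thread: 256 MiB full bf16 grad, 2 GPUs.
all_reduce_mib = ring_all_reduce_sent_mib(256, world=2)          # 256.0
reduce_scatter_mib = ring_reduce_scatter_sent_mib(256, world=2)  # 128.0
ratio = reduce_scatter_mib / all_reduce_mib                      # 0.5
```

The 0.5 ratio is independent of world size in this model, since both costs scale by the same (world - 1) / world factor.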

Z3 + coalesce extra cost

Window resident is 384 MiB vs 128 MiB for baseline. The 256 MiB gap is one full bf16 grad of the params, held until flush. Peak is about the same (640 MiB either way) because the in-flight backward already needs full-grad room; the accumulator reuses it. After flush both go back to the usual 128 MiB partition. Scaled to a 70B Z3 run at world=64 that is ~2.0 GiB/rank extra during the window, so this is meant for runs where comm is the bottleneck and there is some grad memory to spare, not the tightest Z3 jobs.

Comm and wall time

N to 1 reduction is exact across N=2/4/8. Wall time on this 2-GPU rig: Z2 203 ms -> 120 ms (1.7x), Z3 177 ms -> 133 ms (1.3x). The gap grows with rank count and model size.

Docs

I will add a short note to docs/code-docs/source/training.rst near "Gradient Accumulation" that lays out the window memory cost (Z2 + coalesce matches Z1 + no_sync, Z3 + coalesce holds one extra full grad per rank until flush) and push it as a separate commit on this PR.
