From a09084b09b3f8a875c194ce1cf4ea297b205c97d Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 14 Jan 2026 23:20:04 -0800 Subject: [PATCH 01/11] fix backward with checkpointing and reentrant Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/stage3.py | 8 + tests/unit/v1/zero/test_zero_user_backward.py | 220 ++++++++++++++++++ 2 files changed, 228 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 0828bd7c755b..20a0290a3057 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1963,6 +1963,14 @@ def _pre_step(self): print_rank_0("Finished Tracing at Beginning of Step") + # Clear any stale params from ipg_buckets. This is needed because with + # reentrant checkpointing (use_reentrant=True), the backward pass can + # leave params in the buckets that weren't properly processed, causing + # errors in the next iteration. + for bucket in self.ipg_buckets.values(): + bucket.params.clear() + bucket.elements = 0 + @instrument_w_nvtx def _get_norm_groups(self): norm_groups = [] diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index 106d2425fbbb..6f4b80ad1846 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -1148,3 +1148,223 @@ def test_scale_with_torch_autocast(self, zero_stage): assert len(grads) > 0, "Expected gradients to be computed" model_engine.destroy() + + +class CheckpointedModel(torch.nn.Module): + """Model that uses gradient checkpointing with configurable use_reentrant setting. + + This model is designed to test the interaction between ZeRO-3 and gradient + checkpointing with both reentrant (use_reentrant=True) and non-reentrant + (use_reentrant=False) modes. + """ + + def __init__(self, hidden_dim, use_reentrant=True): + super().__init__() + self.use_reentrant = use_reentrant + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) + + def _checkpointed_block(self, x): + """Block that will be checkpointed""" + x = self.linear1(x) + x = torch.nn.functional.relu(x) + x = self.linear2(x) + return x + + def forward(self, x): + # Use gradient checkpointing on the middle block + if self.training: + from torch.utils.checkpoint import checkpoint + x = checkpoint(self._checkpointed_block, x, use_reentrant=self.use_reentrant) + else: + x = self._checkpointed_block(x) + x = self.linear3(x) + return x + + +@pytest.mark.parametrize("use_reentrant", [True, False]) +class TestZeroUserBackwardWithCheckpointing(DistributedTest): + """Test ZeRO-3 with gradient checkpointing and non-scalar backward. + + This test class validates the interaction between: + 1. ZeRO-3 parameter partitioning + 2. Gradient checkpointing (both reentrant and non-reentrant modes) + 3. Non-scalar backward (tensor.backward(gradient=...)) + + Both use_reentrant=True and use_reentrant=False are supported with ZeRO-3. + Note: When using use_reentrant=True, input tensors should have requires_grad=True + for proper gradient computation through the checkpointed region. + """ + world_size = 2 + + def test_checkpointed_non_scalar_backward_zero3(self, use_reentrant): + """Test that gradient checkpointing works with ZeRO-3 and non-scalar backward. + + Verifies that tensor.backward(gradient=...) works correctly with ZeRO-3 + and gradient checkpointing in both reentrant and non-reentrant modes. + """ + hidden_dim = 8 + batch_size = 2 + zero_stage = 3 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model for reference (no checkpointing issues with DDP) + torch.manual_seed(42) + model_ddp = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + model_ddp = model_ddp.to(device=device, dtype=dtype) + model_ddp = DDP(model_ddp, device_ids=[rank], output_device=rank) + optimizer_ddp = torch.optim.Adam(model_ddp.parameters(), lr=1e-3) + + # Create DeepSpeed model with ZeRO-3 + torch.manual_seed(42) + model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + + config = get_config_dict(zero_stage) + model_engine, _, _, _ = deepspeed.initialize(config=config, + model=model_ds, + model_parameters=model_ds.parameters()) + + # Create input data + # For reentrant checkpointing (use_reentrant=True), inputs need requires_grad=True + # for proper gradient computation through the checkpointed region. + torch.manual_seed(123) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + + # DDP: forward and non-scalar backward + optimizer_ddp.zero_grad() + output_ddp = model_ddp(x) + grad_output = torch.ones_like(output_ddp) + output_ddp.backward(grad_output) + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed with ZeRO-3: forward and non-scalar backward + # This is the pattern used in disaggregated training + output_ds = model_engine(x.detach().requires_grad_(True)) + grad_output_ds = torch.ones_like(output_ds) + + # Non-scalar backward with gradient checkpointing + output_ds.backward(grad_output_ds) + + # Collect and verify gradients + ds_grads = collect_gradients_safe(model_engine) + + # Verify gradients were computed + assert len(ds_grads) > 0, \ + f"No gradients computed with use_reentrant={use_reentrant} and ZeRO-3" + + # Compare gradients with DDP reference + compare_gradients(ddp_grads, ds_grads, + f"with checkpointing use_reentrant={use_reentrant}") + + # Run optimizer step to verify full training loop works + model_engine.step() + + model_engine.destroy() + + def test_checkpointed_scalar_backward_zero3(self, use_reentrant): + """Test that gradient checkpointing works with ZeRO-3 and scalar backward. + + Verifies that scalar loss.backward() works correctly with ZeRO-3 and + gradient checkpointing in both reentrant and non-reentrant modes. + """ + hidden_dim = 8 + batch_size = 2 + zero_stage = 3 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DDP model for reference + torch.manual_seed(42) + model_ddp = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + model_ddp = model_ddp.to(device=device, dtype=dtype) + model_ddp = DDP(model_ddp, device_ids=[rank], output_device=rank) + optimizer_ddp = torch.optim.Adam(model_ddp.parameters(), lr=1e-3) + + # Create DeepSpeed model with ZeRO-3 + torch.manual_seed(42) + model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + + config = get_config_dict(zero_stage) + model_engine, _, _, _ = deepspeed.initialize(config=config, + model=model_ds, + model_parameters=model_ds.parameters()) + + # Create input data + # For reentrant checkpointing (use_reentrant=True), inputs need requires_grad=True + torch.manual_seed(123) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + y = torch.randint(0, hidden_dim, (batch_size,), device=device) + + # DDP: forward with scalar loss and backward + optimizer_ddp.zero_grad() + output_ddp = model_ddp(x) + loss_ddp = torch.nn.functional.cross_entropy(output_ddp, y) + loss_ddp.backward() + ddp_grads = collect_ddp_gradients(model_ddp) + + # DeepSpeed with ZeRO-3: forward with scalar loss and backward + output_ds = model_engine(x.detach().requires_grad_(True)) + loss_ds = torch.nn.functional.cross_entropy(output_ds, y) + loss_ds.backward() + + # Collect and verify gradients + ds_grads = collect_gradients_safe(model_engine) + + # Verify gradients were computed + assert len(ds_grads) > 0, \ + f"No gradients computed with scalar loss, use_reentrant={use_reentrant}" + + # Compare gradients with DDP reference + compare_gradients(ddp_grads, ds_grads, + f"scalar loss with checkpointing use_reentrant={use_reentrant}") + + model_engine.destroy() + + def test_checkpointed_multiple_backward_zero3(self, use_reentrant): + """Test multiple backward passes with checkpointing and ZeRO-3. + + Verifies that consecutive training iterations work correctly with + gradient checkpointing in both reentrant and non-reentrant modes. + """ + + hidden_dim = 8 + batch_size = 2 + zero_stage = 3 + num_iterations = 3 + + # Initialize distributed environment + device, rank, dtype = initialize_distributed() + + # Create DeepSpeed model with ZeRO-3 + torch.manual_seed(42) + model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + + config = get_config_dict(zero_stage) + model_engine, _, _, _ = deepspeed.initialize(config=config, + model=model_ds, + model_parameters=model_ds.parameters()) + + for iteration in range(num_iterations): + # Create input data with different seed each iteration + # For reentrant checkpointing (use_reentrant=True), inputs need requires_grad=True + torch.manual_seed(123 + iteration) + x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + + # Forward and non-scalar backward + output_ds = model_engine(x) + grad_output_ds = torch.ones_like(output_ds) + output_ds.backward(grad_output_ds) + + # Collect and verify gradients + ds_grads = collect_gradients_safe(model_engine) + assert len(ds_grads) > 0, \ + f"No gradients at iteration {iteration} with use_reentrant={use_reentrant}" + + # Run optimizer step + model_engine.step() + + model_engine.destroy() From 8211447d46d2c4f8138e6c719ae5d39462e48568 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Mon, 12 Jan 2026 19:30:59 -0800 Subject: [PATCH 02/11] Update README with newer status badges for CI Signed-off-by: Masahiro Tanaka --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a58372a802cd..b85e663751a5 100755 --- a/README.md +++ b/README.md @@ -91,13 +91,12 @@ DeepSpeed has been integrated with several different popular open-source DL fram | Description | Status | | ----------- | ------ | -| NVIDIA | [![nv-torch-latest-v100](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-torch-latest-v100.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-torch-latest-v100.yml) [![nv-inference](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-inference.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-inference.yml) [![nv-nightly](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-nightly.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-nightly.yml) | +| NVIDIA | [![nv-pre-compile-ops](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-pre-compile-ops.yml/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-pre-compile-ops.yml) [![aws-torch-latest](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-torch-latest.yml/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-torch-latest.yml) | | AMD | [![amd-mi200](https://github.com/deepspeedai/DeepSpeed/actions/workflows/amd-mi200.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/amd-mi200.yml) | | CPU | [![torch-latest-cpu](https://github.com/deepspeedai/DeepSpeed/actions/workflows/cpu-torch-latest.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/cpu-torch-latest.yml) | | Intel Gaudi | [![hpu-gaudi2](https://github.com/deepspeedai/DeepSpeed/actions/workflows/hpu-gaudi2.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/hpu-gaudi2.yml) | | Intel XPU | [![xpu-max1100](https://github.com/deepspeedai/DeepSpeed/actions/workflows/xpu-max1100.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/xpu-max1100.yml) | -| PyTorch Nightly | [![nv-torch-nightly-v100](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-torch-nightly-v100.yml) | -| Integrations | [![nv-transformers-v100](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-transformers-v100.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-transformers-v100.yml) [![nv-lightning-v100](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-lightning-v100.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-lightning-v100.yml) [![nv-accelerate-v100](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-accelerate-v100.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-accelerate-v100.yml) [![nv-mii](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-mii.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-mii.yml) [![nv-ds-chat](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-ds-chat.yml) [![nv-sd](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-sd.yml/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/nv-sd.yml) | +| Integrations | [![aws-accelerate](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-accelerate.yml/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/aws-accelerate.yml) | | Misc | [![Formatting](https://github.com/deepspeedai/DeepSpeed/actions/workflows/formatting.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/formatting.yml) [![pages-build-deployment](https://github.com/deepspeedai/DeepSpeed/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/pages/pages-build-deployment) [![Documentation Status](https://readthedocs.org/projects/deepspeed/badge/?version=latest)](https://deepspeed.readthedocs.io/en/latest/?badge=latest)[![python](https://github.com/deepspeedai/DeepSpeed/actions/workflows/python.yml/badge.svg?branch=master)](https://github.com/deepspeedai/DeepSpeed/actions/workflows/python.yml) | | Huawei Ascend NPU | [![Huawei Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml/badge.svg?branch=main)](https://github.com/Ascend/Ascend-CI/actions/workflows/deepspeed.yaml) | From f6026d19b2b403d97a4d7d9baab85ce03bbc2d81 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Mon, 12 Jan 2026 20:06:43 -0800 Subject: [PATCH 03/11] Add timeout to test workflows (#7774) This PR adds timeout to CI workflows. This will prevent zombie jobs from holding GPU instances. Signed-off-by: Masahiro Tanaka Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Signed-off-by: Masahiro Tanaka --- .github/workflows/aws-accelerate.yml | 1 + .github/workflows/aws-torch-latest.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/aws-accelerate.yml b/.github/workflows/aws-accelerate.yml index 2ab293543ebd..19c68d3d0b25 100644 --- a/.github/workflows/aws-accelerate.yml +++ b/.github/workflows/aws-accelerate.yml @@ -47,6 +47,7 @@ jobs: needs: check-paths if: needs.check-paths.outputs.should_run == 'true' runs-on: [self-hosted, gpu-ci, gpu-l40s, l40s-1gpu, aws] + timeout-minutes: 60 container: image: nvidia/cuda:12.6.3-devel-ubuntu22.04 diff --git a/.github/workflows/aws-torch-latest.yml b/.github/workflows/aws-torch-latest.yml index d321990345dc..45d9b23ff891 100644 --- a/.github/workflows/aws-torch-latest.yml +++ b/.github/workflows/aws-torch-latest.yml @@ -46,6 +46,7 @@ jobs: needs: check-paths if: needs.check-paths.outputs.should_run == 'true' runs-on: [self-hosted, gpu-ci, gpu-l40s, l40s-4gpu, aws] + timeout-minutes: 60 container: image: nvidia/cuda:12.6.3-devel-ubuntu22.04 From 3cf426cdbef443839eac442ecb898bf0a42ef679 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Tue, 13 Jan 2026 07:53:09 -0800 Subject: [PATCH 04/11] Remove cron/PR triggers for outdated V100 tests (#7777) The V100 tests are not needed anymore but this prevents the CI cron jobs from being spun up even though the jobs are disabled. The next step will be to remove the yaml files we do not use anymore/have already ported. Signed-off-by: Masahiro Tanaka --- .github/workflows/nv-accelerate-v100.yml | 13 ------------- .github/workflows/nv-ds-chat.yml | 2 -- .github/workflows/nv-mii.yml | 21 --------------------- .github/workflows/nv-nightly.yml | 8 -------- .github/workflows/nv-sd.yml | 2 -- .github/workflows/nv-torch-latest-v100.yml | 13 ------------- .github/workflows/nv-torch-nightly-v100.yml | 8 -------- .github/workflows/nv-transformers-v100.yml | 12 ------------ 8 files changed, 79 deletions(-) diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index d23ef32742f6..a1b7fd343e0b 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -1,18 +1,5 @@ name: nv-accelerate-v100 -on: - workflow_dispatch: - pull_request: - paths-ignore: - - 'docs/**' - - 'blogs/**' - - 'deepspeed/inference/v2/**' - - 'tests/unit/inference/v2/**' - merge_group: - branches: [ master ] - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/nv-ds-chat.yml b/.github/workflows/nv-ds-chat.yml index 3cefee0c171f..14cd0eff1e25 100644 --- a/.github/workflows/nv-ds-chat.yml +++ b/.github/workflows/nv-ds-chat.yml @@ -1,8 +1,6 @@ name: nv-ds-chat on: - schedule: - - cron: "0 0 * * *" workflow_dispatch: inputs: dse_branch: diff --git a/.github/workflows/nv-mii.yml b/.github/workflows/nv-mii.yml index c81384450029..64b97b080ede 100644 --- a/.github/workflows/nv-mii.yml +++ b/.github/workflows/nv-mii.yml @@ -1,26 +1,5 @@ name: nv-mii -on: - workflow_dispatch: - inputs: - mii_branch: - description: 'DeepSpeed-MII Branch' - required: false - default: 'main' - type: string - pull_request: - paths: - - '.github/workflows/nv-mii.yml' - - 'requirements/**' - - 'setup.py' - - 'deepspeed/__init__.py' - - 'deepspeed/inference/**' - - '!deepspeed/inference/v2/**' # exclude v2 dir - merge_group: - branches: [ master ] - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/nv-nightly.yml b/.github/workflows/nv-nightly.yml index 7f81484c7646..670b0c4eda44 100644 --- a/.github/workflows/nv-nightly.yml +++ b/.github/workflows/nv-nightly.yml @@ -1,13 +1,5 @@ name: nv-nightly -on: - workflow_dispatch: - pull_request: - paths: - - '.github/workflows/nv-nightly.yml' - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/nv-sd.yml b/.github/workflows/nv-sd.yml index af406075b868..ae0d4cb0510f 100644 --- a/.github/workflows/nv-sd.yml +++ b/.github/workflows/nv-sd.yml @@ -2,8 +2,6 @@ name: nv-sd on: workflow_dispatch: - schedule: - - cron: "0 0 * * 0" pull_request: paths: - "deepspeed/ops/transformer/inference/diffusers_**" diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index df62d1fffb32..9dbdb024ffac 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -1,18 +1,5 @@ name: nv-torch-latest-v100 -on: - workflow_dispatch: - pull_request: - paths-ignore: - - 'docs/**' - - 'blogs/**' - - 'deepspeed/inference/v2/**' - - 'tests/unit/inference/v2/**' - merge_group: - branches: [ master ] - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml index 34ac3e5ba514..f88951b9b03e 100644 --- a/.github/workflows/nv-torch-nightly-v100.yml +++ b/.github/workflows/nv-torch-nightly-v100.yml @@ -1,13 +1,5 @@ name: nv-torch-nightly-v100 -on: - workflow_dispatch: - schedule: - - cron: "0 0 * * *" - pull_request: - paths: - - '.github/workflows/nv-torch-nightly-v100.yml' - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true diff --git a/.github/workflows/nv-transformers-v100.yml b/.github/workflows/nv-transformers-v100.yml index 9d1253fd77ca..e9326613273f 100644 --- a/.github/workflows/nv-transformers-v100.yml +++ b/.github/workflows/nv-transformers-v100.yml @@ -1,17 +1,5 @@ name: nv-transformers-v100 -on: - pull_request: - paths-ignore: - - 'docs/**' - - 'blogs/**' - - 'deepspeed/inference/v2/**' - - 'tests/unit/inference/v2/**' - merge_group: - branches: [ master ] - schedule: - - cron: "0 0 * * *" - concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true From 9116f4a970666519da88bcec59c7ffa4f4bb1c7b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 14 Jan 2026 23:33:39 -0800 Subject: [PATCH 05/11] fix yapf formatting in test file Co-Authored-By: Claude Opus 4.5 Signed-off-by: Masahiro Tanaka --- tests/unit/v1/zero/test_zero_user_backward.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index 6f4b80ad1846..2d5d25d66c9a 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -1256,8 +1256,7 @@ def test_checkpointed_non_scalar_backward_zero3(self, use_reentrant): f"No gradients computed with use_reentrant={use_reentrant} and ZeRO-3" # Compare gradients with DDP reference - compare_gradients(ddp_grads, ds_grads, - f"with checkpointing use_reentrant={use_reentrant}") + compare_gradients(ddp_grads, ds_grads, f"with checkpointing use_reentrant={use_reentrant}") # Run optimizer step to verify full training loop works model_engine.step() @@ -1297,7 +1296,7 @@ def test_checkpointed_scalar_backward_zero3(self, use_reentrant): # For reentrant checkpointing (use_reentrant=True), inputs need requires_grad=True torch.manual_seed(123) x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) - y = torch.randint(0, hidden_dim, (batch_size,), device=device) + y = torch.randint(0, hidden_dim, (batch_size, ), device=device) # DDP: forward with scalar loss and backward optimizer_ddp.zero_grad() @@ -1319,8 +1318,7 @@ def test_checkpointed_scalar_backward_zero3(self, use_reentrant): f"No gradients computed with scalar loss, use_reentrant={use_reentrant}" # Compare gradients with DDP reference - compare_gradients(ddp_grads, ds_grads, - f"scalar loss with checkpointing use_reentrant={use_reentrant}") + compare_gradients(ddp_grads, ds_grads, f"scalar loss with checkpointing use_reentrant={use_reentrant}") model_engine.destroy() From f61abc9080f5c69d4a6df16e01f1f877cf24266c Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 15 Jan 2026 09:48:10 -0800 Subject: [PATCH 06/11] added sync in tests Signed-off-by: Masahiro Tanaka --- tests/unit/v1/zero/test_zero_user_backward.py | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index 2d5d25d66c9a..1a6ee11aba03 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -5,6 +5,7 @@ import pytest import torch +import deepspeed.comm as dist import deepspeed from torch.nn.parallel import DistributedDataParallel as DDP @@ -1227,27 +1228,36 @@ def test_checkpointed_non_scalar_backward_zero3(self, use_reentrant): model=model_ds, model_parameters=model_ds.parameters()) - # Create input data - # For reentrant checkpointing (use_reentrant=True), inputs need requires_grad=True - # for proper gradient computation through the checkpointed region. + # Create input data - use separate tensors for DDP and DeepSpeed to avoid + # memory sharing issues during parallel test execution torch.manual_seed(123) - x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + x_ddp = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) # DDP: forward and non-scalar backward optimizer_ddp.zero_grad() - output_ddp = model_ddp(x) + output_ddp = model_ddp(x_ddp) grad_output = torch.ones_like(output_ddp) output_ddp.backward(grad_output) + get_accelerator().synchronize() # Ensure CUDA ops complete + dist.barrier() # Ensure all ranks complete gradient sync ddp_grads = collect_ddp_gradients(model_ddp) # DeepSpeed with ZeRO-3: forward and non-scalar backward # This is the pattern used in disaggregated training - output_ds = model_engine(x.detach().requires_grad_(True)) + # Create fresh tensor with same seed for reproducibility + torch.manual_seed(123) + x_ds = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + output_ds = model_engine(x_ds) grad_output_ds = torch.ones_like(output_ds) # Non-scalar backward with gradient checkpointing output_ds.backward(grad_output_ds) + # Synchronize device before collecting gradients. ZeRO-3 uses async operations + # on separate streams for gradient reduction. With use_reentrant=True checkpointing, + # we need to ensure all operations complete before reading gradient data. + get_accelerator().synchronize() + # Collect and verify gradients ds_grads = collect_gradients_safe(model_engine) @@ -1292,24 +1302,34 @@ def test_checkpointed_scalar_backward_zero3(self, use_reentrant): model=model_ds, model_parameters=model_ds.parameters()) - # Create input data - # For reentrant checkpointing (use_reentrant=True), inputs need requires_grad=True + # Create input data - use separate tensors for DDP and DeepSpeed to avoid + # memory sharing issues during parallel test execution torch.manual_seed(123) - x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + x_ddp = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) y = torch.randint(0, hidden_dim, (batch_size, ), device=device) # DDP: forward with scalar loss and backward optimizer_ddp.zero_grad() - output_ddp = model_ddp(x) + output_ddp = model_ddp(x_ddp) loss_ddp = torch.nn.functional.cross_entropy(output_ddp, y) loss_ddp.backward() + get_accelerator().synchronize() # Ensure CUDA ops complete + dist.barrier() # Ensure all ranks complete gradient sync ddp_grads = collect_ddp_gradients(model_ddp) # DeepSpeed with ZeRO-3: forward with scalar loss and backward - output_ds = model_engine(x.detach().requires_grad_(True)) + # Create fresh tensor with same seed for reproducibility + torch.manual_seed(123) + x_ds = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) + output_ds = model_engine(x_ds) loss_ds = torch.nn.functional.cross_entropy(output_ds, y) loss_ds.backward() + # Synchronize device before collecting gradients. ZeRO-3 uses async operations + # on separate streams for gradient reduction. With use_reentrant=True checkpointing, + # we need to ensure all operations complete before reading gradient data. + get_accelerator().synchronize() + # Collect and verify gradients ds_grads = collect_gradients_safe(model_engine) From 84fa1db03b4cad49ee81a5b75d7da220dc775e97 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 15 Jan 2026 15:14:46 -0800 Subject: [PATCH 07/11] extract function to clear params Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/stage3.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 20a0290a3057..d9585ce9147b 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -122,6 +122,11 @@ def clear(self): self.params.clear() self.elements = 0 + def clear_params(self): + """Clear params and elements but keep buffer for reuse.""" + self.params.clear() + self.elements = 0 + INITIAL_MICRO_STEP_ID = -1 @@ -1968,8 +1973,7 @@ def _pre_step(self): # leave params in the buckets that weren't properly processed, causing # errors in the next iteration. for bucket in self.ipg_buckets.values(): - bucket.params.clear() - bucket.elements = 0 + bucket.clear_params() @instrument_w_nvtx def _get_norm_groups(self): From ddb54e76a3d2023aad77635605412566f3503ea8 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 15 Jan 2026 23:59:05 -0800 Subject: [PATCH 08/11] fix issue with backward count Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/base_optimizer.py | 8 ++++++++ deepspeed/runtime/engine.py | 5 +++++ deepspeed/runtime/zero/stage3.py | 19 +++++++++++++++++++ deepspeed/runtime/zero/stage_1_and_2.py | 7 +++++++ tests/unit/v1/zero/test_zero_user_backward.py | 7 +++++++ 5 files changed, 46 insertions(+) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index b10bb7e57f48..c037e579a6f1 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -24,6 +24,7 @@ def __init__(self): self._remaining_grad_acc_hooks = 0 self._grad_acc_post_hooks = [] self._backward_active_depth = 0 + self._backward_seen_this_step = False def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None: checkpoint_dir = os.path.join(checkpoint_dir, "zero") @@ -152,11 +153,18 @@ def run_grad_acc_post_hooks(self): def enter_backward(self): self._backward_active_depth += 1 + # Track that backward has been active at some point in this step. + # This is used to detect subsequent gradient hook phases with reentrant checkpointing. + self._backward_seen_this_step = True def exit_backward(self): if self._backward_active_depth > 0: self._backward_active_depth -= 1 + def clear_backward_seen_flag(self): + """Clear the backward seen flag at the start of each step.""" + self._backward_seen_this_step = False + def _configure_master_weights(self, fp16_master_weights_and_gradients=False, bf16_master_weights_and_gradients=False, diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 01468d87e42a..51c6974d4577 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2225,6 +2225,11 @@ def forward(self, *inputs, **kwargs): *inputs: Variable length input list **kwargs: variable length keyword arguments """ + # Clear the backward seen flag at the start of each forward pass. + # This is used to track multiple gradient hook phases with reentrant checkpointing. + if isinstance(self.optimizer, ZeROOptimizer): + self.optimizer.clear_backward_seen_flag() + if self.autotuning_profile_model_info(): ma = get_ma_status() diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d9585ce9147b..7a1feee66b09 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1281,6 +1281,22 @@ def reduce_partition_and_remove_grads(*notneeded): if self._remaining_grad_acc_hooks == 0: self._remaining_grad_acc_hooks = count_used_parameters_in_backward( non_leaf_params_requiring_grad) + leaf_module_count + # With reentrant gradient checkpointing, gradient hooks can fire in + # multiple phases within a single backward call. The first phase + # triggers _backward_epilogue which calls exit_backward(), setting + # _backward_active_depth to 0. When the next phase starts, we need + # to re-enter backward to ensure post hooks run for that phase too. + # + # We detect this case by checking: + # 1. _backward_active_depth == 0 (we've exited from previous phase) + # 2. _backward_seen_this_step == True (backward was active earlier) + # + # This distinguishes from TiledFusedLogitsLoss which calls backward() + # during forward - in that case _backward_seen_this_step is False + # because enter_backward() was never called. + if self._backward_active_depth == 0 and getattr(self, '_backward_seen_this_step', + False): + self.enter_backward() self.reduce_ready_partitions_and_remove_grads(param) @@ -1305,6 +1321,9 @@ def reduce_leaf_module_grads(module, grad_input, grad_output): if self._remaining_grad_acc_hooks == 0: self._remaining_grad_acc_hooks = count_used_parameters_in_backward( non_leaf_params_requiring_grad) + leaf_module_count + # Re-enter backward for subsequent phases (see comment in reduce_partition_and_remove_grads) + if self._backward_active_depth == 0 and getattr(self, '_backward_seen_this_step', False): + self.enter_backward() for param in params: # this takes care of grads for MoE experts that didn't participate in the current iteration/layer diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 86f50a1a0c0b..dfce94006260 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1002,6 +1002,13 @@ def grad_handling_hook(*notneeded): if self._remaining_grad_acc_hooks == 0: self._remaining_grad_acc_hooks = count_used_parameters_in_backward( all_params_requiring_grad) + # With reentrant gradient checkpointing, gradient hooks can fire in + # multiple phases within a single backward call. Re-enter backward + # for subsequent phases to ensure post hooks run correctly. + # (See detailed comment in stage3.py reduce_partition_and_remove_grads) + if self._backward_active_depth == 0 and getattr(self, '_backward_seen_this_step', + False): + self.enter_backward() self.process_gradients(param, i) diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index 1a6ee11aba03..99deb9db75cb 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -1257,6 +1257,7 @@ def test_checkpointed_non_scalar_backward_zero3(self, use_reentrant): # on separate streams for gradient reduction. With use_reentrant=True checkpointing, # we need to ensure all operations complete before reading gradient data. get_accelerator().synchronize() + dist.barrier() # Ensure all ranks complete backward before collecting gradients # Collect and verify gradients ds_grads = collect_gradients_safe(model_engine) @@ -1323,12 +1324,14 @@ def test_checkpointed_scalar_backward_zero3(self, use_reentrant): x_ds = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) output_ds = model_engine(x_ds) loss_ds = torch.nn.functional.cross_entropy(output_ds, y) + loss_ds.backward() # Synchronize device before collecting gradients. ZeRO-3 uses async operations # on separate streams for gradient reduction. With use_reentrant=True checkpointing, # we need to ensure all operations complete before reading gradient data. get_accelerator().synchronize() + dist.barrier() # Ensure all ranks complete backward before collecting gradients # Collect and verify gradients ds_grads = collect_gradients_safe(model_engine) @@ -1377,6 +1380,10 @@ def test_checkpointed_multiple_backward_zero3(self, use_reentrant): grad_output_ds = torch.ones_like(output_ds) output_ds.backward(grad_output_ds) + # Synchronize before collecting gradients to ensure async operations complete + get_accelerator().synchronize() + dist.barrier() + # Collect and verify gradients ds_grads = collect_gradients_safe(model_engine) assert len(ds_grads) > 0, \ From 3f6938ea723d86a7c5f75d1460562403814ca6c6 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 16 Jan 2026 16:18:33 -0800 Subject: [PATCH 09/11] fix backward hook state management Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/base_optimizer.py | 255 +++++++++++++++++++++--- deepspeed/runtime/zero/stage3.py | 76 +++---- deepspeed/runtime/zero/stage_1_and_2.py | 18 +- 3 files changed, 271 insertions(+), 78 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index c037e579a6f1..8e1550ace460 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -18,13 +18,218 @@ class DeepSpeedOptimizer(object): pass -class ZeROOptimizer(DeepSpeedOptimizer): +class BackwardHookStateManager: + """Manages backward pass state for ZeRO optimizers. + + This class handles the complex state management needed for gradient accumulation hooks + to work correctly with: + + 1. **Reentrant Gradient Checkpointing** (use_reentrant=True): + With reentrant checkpointing, gradient hooks fire in multiple phases within a + single backward() call. For example, with model: linear1 (checkpointed) -> linear2: + - Phase 1: Hooks for linear2 fire (non-checkpointed params) + - Checkpoint recomputes linear1's forward + - Phase 2: Hooks for linear1 fire (checkpointed params) + + The challenge is that `count_used_parameters_in_backward()` only sees params + currently in the backward graph. During Phase 1, it returns 2 (linear2's params), + but after checkpoint recomputation, it returns 4 (all params). We must NOT run + the epilogue prematurely after Phase 1. + + Solution: Track `_max_expected_hooks_seen` across all phases. The epilogue only + runs when `_hooks_fired_this_backward >= _max_expected_hooks_seen`. + + 2. **TiledFusedLogitsLoss and Similar Custom Autograd Functions**: + Some custom autograd functions call `torch.autograd.backward()` from their + forward pass BEFORE the user calls `engine.backward(loss)`. These internal + backward calls trigger ZeRO's gradient hooks, but we must NOT run the epilogue + until the user's actual backward pass. + + Solution: Track `_backward_active_depth` which is only incremented when + `enter_backward()` is called (from engine.backward or user code). Hooks check + this depth before running the epilogue. + + 3. **Multiple Backward Phases with Exit/Re-entry**: + When the epilogue runs after Phase 1 (with reentrant checkpointing), it calls + `exit_backward()`, setting `_backward_active_depth` to 0. When Phase 2's hooks + fire, we need to re-enter the backward context. + + Solution: `_backward_seen_this_step` flag tracks if backward was ever active + this step. Combined with `_backward_active_depth == 0`, this detects Phase 2 + and calls `enter_backward()` again. + + Attributes: + remaining_grad_acc_hooks: Count of hooks remaining before epilogue should run + backward_active_depth: Nesting depth of backward() calls (0 = not in backward) + backward_seen_this_step: True if enter_backward() was called this step + epilogue_ran_this_backward: True if epilogue ran (for micro_step_id management) + hooks_fired_this_backward: Count of gradient hooks that have fired + max_expected_hooks_seen: Maximum expected hook count seen (grows with reentrant) + """ def __init__(self): - self._remaining_grad_acc_hooks = 0 + self.remaining_grad_acc_hooks = 0 + self._grad_acc_post_hooks = [] + self.backward_active_depth = 0 + self.backward_seen_this_step = False + self.epilogue_ran_this_backward = False + self.hooks_fired_this_backward = 0 + self.max_expected_hooks_seen = 0 + + def register_grad_acc_post_hook(self, hook): + """Register a callback to run when all gradient hooks have fired.""" + self._grad_acc_post_hooks.append(hook) + + def unregister_grad_acc_post_hooks(self): + """Remove all registered gradient accumulation post hooks.""" self._grad_acc_post_hooks = [] - self._backward_active_depth = 0 - self._backward_seen_this_step = False + + def run_grad_acc_post_hooks(self): + """Run all registered post hooks if backward is active. + + Custom autograd Functions (e.g., TiledFusedLogitsLoss) can invoke + `torch.autograd.backward()` from their *forward* pass before the user + ever calls `engine.backward(loss)`. Those early backward calls still + trigger ZeRO's grad hooks, but we must not run the engine's + post-backward logic (which reduces/clears grads) until the outer/user + backward is active. The depth guard filters out only those pre-user + invocations while still allowing backward calls that happen during + the real user backward. + """ + if self.backward_active_depth == 0: + return + for hook in self._grad_acc_post_hooks: + hook() + + def enter_backward(self): + """Enter backward context. Call at the start of backward pass.""" + self.backward_active_depth += 1 + # Track that backward has been active at some point in this step. + # This is used to detect subsequent gradient hook phases with reentrant checkpointing. + self.backward_seen_this_step = True + + def exit_backward(self): + """Exit backward context. Call at the end of backward pass.""" + if self.backward_active_depth > 0: + self.backward_active_depth -= 1 + + def reset_for_new_step(self): + """Reset state at the start of each forward/backward step.""" + self.backward_seen_this_step = False + self.hooks_fired_this_backward = 0 + self.max_expected_hooks_seen = 0 + self.epilogue_ran_this_backward = False + + def reenter_backward_if_needed(self): + """Re-enter backward context for subsequent phases in reentrant checkpointing. + + With reentrant gradient checkpointing, gradient hooks can fire in multiple phases + within a single backward call. When the epilogue runs after a phase, it calls + exit_backward(), setting backward_active_depth to 0. When the next phase starts, + we need to re-enter backward. + + We detect subsequent phases by checking: + 1. remaining_grad_acc_hooks == 0 (epilogue ran or new backward) + 2. backward_active_depth == 0 (we've exited from previous phase) + 3. backward_seen_this_step == True (backward was active earlier) + + This distinguishes from TiledFusedLogitsLoss which calls backward() during forward - + in that case backward_seen_this_step is False because enter_backward() was never called. + """ + if self.remaining_grad_acc_hooks == 0: + if self.backward_active_depth == 0 and self.backward_seen_this_step: + self.enter_backward() + + def update_hook_state_and_maybe_run_epilogue(self, current_expected_count): + """Update hook state after a gradient hook fires and run epilogue if all hooks have fired. + + With reentrant gradient checkpointing, count_used_parameters_in_backward() returns the + count of params that will execute in the current backward graph. This count grows as + checkpointed regions are recomputed. We track the MAXIMUM count seen to ensure we don't + run the epilogue until all params that will ever participate have been processed. + Counters are reset at forward() time via reset_for_new_step(). + + Args: + current_expected_count: The current expected number of hooks, from + count_used_parameters_in_backward() plus any leaf modules. + """ + self.hooks_fired_this_backward += 1 + self.max_expected_hooks_seen = max(self.max_expected_hooks_seen, current_expected_count) + + # Run epilogue only when we've processed ALL params that will participate. + # This is the maximum count we've seen (accounts for late-joining params + # from reentrant checkpointing) and also excludes unused params. + if self.hooks_fired_this_backward >= self.max_expected_hooks_seen: + self.remaining_grad_acc_hooks = 0 + self.run_grad_acc_post_hooks() + else: + self.remaining_grad_acc_hooks = self.max_expected_hooks_seen - self.hooks_fired_this_backward + + +class ZeROOptimizer(DeepSpeedOptimizer): + """Base class for ZeRO optimizer implementations (stages 1, 2, and 3).""" + + def __init__(self): + self._backward_hook_state = BackwardHookStateManager() + + # Delegate backward hook state management to the manager. + # These properties provide backward compatibility with code that accesses + # these attributes directly (e.g., in stage3.py and stage_1_and_2.py). + @property + def _remaining_grad_acc_hooks(self): + return self._backward_hook_state.remaining_grad_acc_hooks + + @_remaining_grad_acc_hooks.setter + def _remaining_grad_acc_hooks(self, value): + self._backward_hook_state.remaining_grad_acc_hooks = value + + @property + def _backward_active_depth(self): + return self._backward_hook_state.backward_active_depth + + @_backward_active_depth.setter + def _backward_active_depth(self, value): + self._backward_hook_state.backward_active_depth = value + + @property + def _backward_seen_this_step(self): + return self._backward_hook_state.backward_seen_this_step + + @_backward_seen_this_step.setter + def _backward_seen_this_step(self, value): + self._backward_hook_state.backward_seen_this_step = value + + @property + def _epilogue_ran_this_backward(self): + return self._backward_hook_state.epilogue_ran_this_backward + + @_epilogue_ran_this_backward.setter + def _epilogue_ran_this_backward(self, value): + self._backward_hook_state.epilogue_ran_this_backward = value + + @property + def _hooks_fired_this_backward(self): + return self._backward_hook_state.hooks_fired_this_backward + + @_hooks_fired_this_backward.setter + def _hooks_fired_this_backward(self, value): + self._backward_hook_state.hooks_fired_this_backward = value + + @property + def _max_expected_hooks_seen(self): + return self._backward_hook_state.max_expected_hooks_seen + + @_max_expected_hooks_seen.setter + def _max_expected_hooks_seen(self, value): + self._backward_hook_state.max_expected_hooks_seen = value + + @property + def _grad_acc_post_hooks(self): + return self._backward_hook_state._grad_acc_post_hooks + + @_grad_acc_post_hooks.setter + def _grad_acc_post_hooks(self, value): + self._backward_hook_state._grad_acc_post_hooks = value def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None: checkpoint_dir = os.path.join(checkpoint_dir, "zero") @@ -132,38 +337,36 @@ def backward(self, loss, **kwargs): self.exit_backward() def register_grad_acc_post_hook(self, hook): - self._grad_acc_post_hooks.append(hook) + """Register a callback to run when all gradient hooks have fired.""" + self._backward_hook_state.register_grad_acc_post_hook(hook) def unregister_grad_acc_post_hooks(self): - self._grad_acc_post_hooks = [] + """Remove all registered gradient accumulation post hooks.""" + self._backward_hook_state.unregister_grad_acc_post_hooks() def run_grad_acc_post_hooks(self): - # Custom autograd Functions (e.g., TiledFusedLogitsLoss) can invoke - # `torch.autograd.backward()` from their *forward* pass before the user - # ever calls `engine.backward(loss)`. Those early backward calls still - # trigger ZeRO's grad hooks, but we must not run the engine's - # post-backward logic (which reduces/clears grads) until the outer/user - # backward is active. The depth guard filters out only those pre-user - # invocations while still allowing backward calls that happen during - # the real user backward. - if self._backward_active_depth == 0: - return - for hook in self._grad_acc_post_hooks: - hook() + """Run all registered post hooks if backward is active.""" + self._backward_hook_state.run_grad_acc_post_hooks() def enter_backward(self): - self._backward_active_depth += 1 - # Track that backward has been active at some point in this step. - # This is used to detect subsequent gradient hook phases with reentrant checkpointing. - self._backward_seen_this_step = True + """Enter backward context. Call at the start of backward pass.""" + self._backward_hook_state.enter_backward() def exit_backward(self): - if self._backward_active_depth > 0: - self._backward_active_depth -= 1 + """Exit backward context. Call at the end of backward pass.""" + self._backward_hook_state.exit_backward() def clear_backward_seen_flag(self): - """Clear the backward seen flag at the start of each step.""" - self._backward_seen_this_step = False + """Clear the backward seen flag and reset hook counters at the start of each step.""" + self._backward_hook_state.reset_for_new_step() + + def reenter_backward_if_needed(self): + """Re-enter backward context for subsequent phases in reentrant checkpointing.""" + self._backward_hook_state.reenter_backward_if_needed() + + def update_hook_state_and_maybe_run_epilogue(self, current_expected_count): + """Update hook state after a gradient hook fires and run epilogue if all hooks have fired.""" + self._backward_hook_state.update_hook_state_and_maybe_run_epilogue(current_expected_count) def _configure_master_weights(self, fp16_master_weights_and_gradients=False, diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 7a1feee66b09..279e03fc4df6 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1245,12 +1245,13 @@ def independent_gradient_partition_epilogue(self): self.__param_id_to_grad_partition[param.ds_id] if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group ] - # this method gets called after every backward. need to increment - # here because if it gets incremented in backward() the micro step - # id will be off by one when we do the reduce and partition at the. - # start of this method. - # TODO. make this less error prone - self.micro_step_id += 1 + # This method gets called after every backward. With reentrant gradient + # checkpointing, it may be called multiple times per backward pass (once per phase). + # We track that the epilogue ran this backward so we can increment micro_step_id + # at the start of the NEXT forward pass. This ensures all phases within a backward + # use the same micro_step_id value (copy semantics for all, not accumulate). + # The increment is deferred to clear_backward_seen_flag() which runs in forward(). + self._epilogue_ran_this_backward = True def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() @@ -1278,31 +1279,15 @@ def wrapper(param): @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): - if self._remaining_grad_acc_hooks == 0: - self._remaining_grad_acc_hooks = count_used_parameters_in_backward( - non_leaf_params_requiring_grad) + leaf_module_count - # With reentrant gradient checkpointing, gradient hooks can fire in - # multiple phases within a single backward call. The first phase - # triggers _backward_epilogue which calls exit_backward(), setting - # _backward_active_depth to 0. When the next phase starts, we need - # to re-enter backward to ensure post hooks run for that phase too. - # - # We detect this case by checking: - # 1. _backward_active_depth == 0 (we've exited from previous phase) - # 2. _backward_seen_this_step == True (backward was active earlier) - # - # This distinguishes from TiledFusedLogitsLoss which calls backward() - # during forward - in that case _backward_seen_this_step is False - # because enter_backward() was never called. - if self._backward_active_depth == 0 and getattr(self, '_backward_seen_this_step', - False): - self.enter_backward() + # Re-enter backward for subsequent phases in reentrant checkpointing + self.reenter_backward_if_needed() self.reduce_ready_partitions_and_remove_grads(param) - self._remaining_grad_acc_hooks -= 1 - if self._remaining_grad_acc_hooks == 0: - self.run_grad_acc_post_hooks() + # Update hook state and run epilogue if all expected hooks have fired + current_expected = count_used_parameters_in_backward( + non_leaf_params_requiring_grad) + leaf_module_count + self.update_hook_state_and_maybe_run_epilogue(current_expected) self._grad_acc_hooks.append(register_grad_hook(param, reduce_partition_and_remove_grads)) @@ -1318,12 +1303,7 @@ def reduce_partition_and_remove_grads(*notneeded): def make_hook(params): def reduce_leaf_module_grads(module, grad_input, grad_output): - if self._remaining_grad_acc_hooks == 0: - self._remaining_grad_acc_hooks = count_used_parameters_in_backward( - non_leaf_params_requiring_grad) + leaf_module_count - # Re-enter backward for subsequent phases (see comment in reduce_partition_and_remove_grads) - if self._backward_active_depth == 0 and getattr(self, '_backward_seen_this_step', False): - self.enter_backward() + self.reenter_backward_if_needed() for param in params: # this takes care of grads for MoE experts that didn't participate in the current iteration/layer @@ -1331,9 +1311,9 @@ def reduce_leaf_module_grads(module, grad_input, grad_output): param.grad = torch.zeros_like(param) self.reduce_ready_partitions_and_remove_grads(param) - self._remaining_grad_acc_hooks -= 1 - if self._remaining_grad_acc_hooks == 0: - self.run_grad_acc_post_hooks() + current_expected = count_used_parameters_in_backward( + non_leaf_params_requiring_grad) + leaf_module_count + self.update_hook_state_and_maybe_run_epilogue(current_expected) return reduce_leaf_module_grads @@ -1859,6 +1839,24 @@ def zero_grad(self, set_to_none=True): p.grad.detach_() p.grad.zero_() + def clear_backward_seen_flag(self): + """Clear the backward seen flag and increment micro_step_id if epilogue ran. + + This override defers the micro_step_id increment from the epilogue to here. + With reentrant gradient checkpointing, the epilogue may be called multiple + times per backward pass, but we only want to increment micro_step_id once + after the backward is complete. By incrementing here at the start of the + NEXT forward, all phases within a backward use the same micro_step_id value. + """ + # Increment micro_step_id if the epilogue ran during the previous backward. + # This is deferred from independent_gradient_partition_epilogue() to ensure + # all phases within a backward use the same micro_step_id (copy semantics). + if self._epilogue_ran_this_backward: + self.micro_step_id += 1 + + # Call base class to reset flags (including _epilogue_ran_this_backward) + super().clear_backward_seen_flag() + def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. """ @@ -1978,6 +1976,10 @@ def reset_cpu_buffers(self): def _pre_step(self): self.micro_step_id = 0 + # Also reset the epilogue flag so the next iteration starts fresh. + # Without this, the flag from the last backward before step() would cause + # an increment in the next forward(), which is wrong. + self._epilogue_ran_this_backward = False print_rank_0("Inside Step function") see_memory_usage("In step before checking overflow", force=False) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index dfce94006260..6dd82d63b185 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -999,22 +999,10 @@ def create_gradient_handling_hooks(self): def wrapper(param, i): def grad_handling_hook(*notneeded): - if self._remaining_grad_acc_hooks == 0: - self._remaining_grad_acc_hooks = count_used_parameters_in_backward( - all_params_requiring_grad) - # With reentrant gradient checkpointing, gradient hooks can fire in - # multiple phases within a single backward call. Re-enter backward - # for subsequent phases to ensure post hooks run correctly. - # (See detailed comment in stage3.py reduce_partition_and_remove_grads) - if self._backward_active_depth == 0 and getattr(self, '_backward_seen_this_step', - False): - self.enter_backward() - + self.reenter_backward_if_needed() self.process_gradients(param, i) - - self._remaining_grad_acc_hooks -= 1 - if self._remaining_grad_acc_hooks == 0: - self.run_grad_acc_post_hooks() + current_expected = count_used_parameters_in_backward(all_params_requiring_grad) + self.update_hook_state_and_maybe_run_epilogue(current_expected) self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook)) From db1ff062ce84fcf593b536c1d4379fabad9121a1 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 16 Jan 2026 18:32:47 -0800 Subject: [PATCH 10/11] fix for zero1 Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/base_optimizer.py | 49 +++++++- deepspeed/runtime/engine.py | 1 + deepspeed/runtime/zero/stage_1_and_2.py | 45 ++++++- tests/unit/v1/zero/test_zero_user_backward.py | 119 ++++++++++++------ 4 files changed, 167 insertions(+), 47 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index 8e1550ace460..f2ca7fd887ac 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -36,8 +36,10 @@ class BackwardHookStateManager: but after checkpoint recomputation, it returns 4 (all params). We must NOT run the epilogue prematurely after Phase 1. - Solution: Track `_max_expected_hooks_seen` across all phases. The epilogue only - runs when `_hooks_fired_this_backward >= _max_expected_hooks_seen`. + Solution: Queue a post-backward callback on the autograd engine at the start of + backward and run the epilogue when the graph task completes. This avoids premature + epilogues across reentrant phases. The `_max_expected_hooks_seen` counter remains + as a fallback when the callback API is unavailable. 2. **TiledFusedLogitsLoss and Similar Custom Autograd Functions**: Some custom autograd functions call `torch.autograd.backward()` from their @@ -65,6 +67,8 @@ class BackwardHookStateManager: epilogue_ran_this_backward: True if epilogue ran (for micro_step_id management) hooks_fired_this_backward: Count of gradient hooks that have fired max_expected_hooks_seen: Maximum expected hook count seen (grows with reentrant) + post_backward_callback_queued: True if a post-backward callback is queued + post_backward_callback_graph_task_id: Graph task id for the queued callback """ def __init__(self): @@ -75,6 +79,8 @@ def __init__(self): self.epilogue_ran_this_backward = False self.hooks_fired_this_backward = 0 self.max_expected_hooks_seen = 0 + self.post_backward_callback_queued = False + self.post_backward_callback_graph_task_id = None def register_grad_acc_post_hook(self, hook): """Register a callback to run when all gradient hooks have fired.""" @@ -119,6 +125,8 @@ def reset_for_new_step(self): self.hooks_fired_this_backward = 0 self.max_expected_hooks_seen = 0 self.epilogue_ran_this_backward = False + self.post_backward_callback_queued = False + self.post_backward_callback_graph_task_id = None def reenter_backward_if_needed(self): """Re-enter backward context for subsequent phases in reentrant checkpointing. @@ -140,6 +148,31 @@ def reenter_backward_if_needed(self): if self.backward_active_depth == 0 and self.backward_seen_this_step: self.enter_backward() + def queue_post_backward_callback(self): + """Queue post-backward hooks to run after the current graph finishes.""" + if self.post_backward_callback_queued: + return True + if self.backward_active_depth == 0: + return False + + engine = getattr(torch.autograd.Variable, "_execution_engine", None) + if engine is None or not hasattr(engine, "queue_callback"): + return False + if not hasattr(torch._C, "_current_graph_task_id"): + return False + + graph_task_id = torch._C._current_graph_task_id() + if graph_task_id == -1: + return False + + def _run_post_backward(): + self.run_grad_acc_post_hooks() + + engine.queue_callback(_run_post_backward) + self.post_backward_callback_queued = True + self.post_backward_callback_graph_task_id = graph_task_id + return True + def update_hook_state_and_maybe_run_epilogue(self, current_expected_count): """Update hook state after a gradient hook fires and run epilogue if all hooks have fired. @@ -156,7 +189,13 @@ def update_hook_state_and_maybe_run_epilogue(self, current_expected_count): self.hooks_fired_this_backward += 1 self.max_expected_hooks_seen = max(self.max_expected_hooks_seen, current_expected_count) - # Run epilogue only when we've processed ALL params that will participate. + # Prefer running post-backward hooks via autograd engine callback when available. + # This avoids premature epilogues with reentrant checkpointing. + if self.queue_post_backward_callback(): + self.remaining_grad_acc_hooks = max(self.max_expected_hooks_seen - self.hooks_fired_this_backward, 0) + return + + # Fallback: Run epilogue only when we've processed ALL params that will participate. # This is the maximum count we've seen (accounts for late-joining params # from reentrant checkpointing) and also excludes unused params. if self.hooks_fired_this_backward >= self.max_expected_hooks_seen: @@ -368,6 +407,10 @@ def update_hook_state_and_maybe_run_epilogue(self, current_expected_count): """Update hook state after a gradient hook fires and run epilogue if all hooks have fired.""" self._backward_hook_state.update_hook_state_and_maybe_run_epilogue(current_expected_count) + def queue_post_backward_callback(self): + """Queue post-backward hooks to run after autograd completes.""" + return self._backward_hook_state.queue_post_backward_callback() + def _configure_master_weights(self, fp16_master_weights_and_gradients=False, bf16_master_weights_and_gradients=False, diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 51c6974d4577..5223dafb0a19 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2345,6 +2345,7 @@ def _backward_prologue(self): if isinstance(self.optimizer, ZeROOptimizer): self.optimizer.backward_prologue() self.optimizer.enter_backward() + self.optimizer.queue_post_backward_callback() if self.zenflow and self.auto_update: self.optimizer.zenflow_state ^= 1 diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 6dd82d63b185..06a9c45e88cc 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -878,16 +878,41 @@ def independent_gradient_partition_epilogue(self): device=get_accelerator().current_device_name(), param_group_idx=i, return_tensor_list=True) + # Clear all_grad_tensors after use. With reentrant checkpointing, + # the epilogue may run multiple times per backward pass. Each time, + # we read the cumulative grad_accum (which PyTorch naturally accumulates) + # and the final phase will have all gradients. self.all_grad_tensors[i] = None self._release_ipg_buffers() - # No need to keep the gradients anymore. - # All gradients required by the step - # are in self.averaged_gradients - self.zero_grad(set_to_none=True) + # Clear param.grad so safe_get_full_grad() goes through the proper _hp_mapping + # path (which does all_reduce for ZeRO-2). Keep grad_accum intact for reentrant + # checkpointing where gradients need to accumulate across multiple phases. + # grad_accum is cleared in clear_backward_seen_flag() at the start of next forward. + self._clear_param_grad_only() + self._epilogue_ran_this_backward = True + see_memory_usage("End ipg_epilogue") + def clear_backward_seen_flag(self): + """Clear the backward seen flag and do deferred cleanup. + + With reentrant gradient checkpointing, the epilogue may run multiple times + per backward pass (once per phase). We defer clearing grad_accum until here + (called at the start of the next forward) to ensure all phases have completed. + + Note: param.grad is cleared in the epilogue via _clear_param_grad_only() to + ensure safe_get_full_grad() works correctly. Only grad_accum is deferred. + """ + if self._epilogue_ran_this_backward: + # Clear grad_accum for next step. param.grad is already cleared in epilogue. + for group in self.bit16_groups: + for p in group: + p.grad_accum = None + + super().clear_backward_seen_flag() + # resets all partition to no reduced # sets remaining grads to the total number of grads in each partition # set is grad computed to false for all grads in partition @@ -1809,6 +1834,18 @@ def zero_grad(self, set_to_none=True): p.grad.detach_() p.grad.zero_() + def _clear_param_grad_only(self): + """Clear only param.grad but keep grad_accum intact. + + This is used at the end of the epilogue to ensure safe_get_full_grad() goes + through the proper _hp_mapping path (which does all_reduce for ZeRO-2), while + preserving grad_accum for reentrant checkpointing where gradients need to + accumulate across multiple backward phases. + """ + for group in self.bit16_groups: + for p in group: + p.grad = None + def _model_parallel_all_reduce(self, tensor, op): """ Perform all reduce within model parallel group, if any. """ diff --git a/tests/unit/v1/zero/test_zero_user_backward.py b/tests/unit/v1/zero/test_zero_user_backward.py index 99deb9db75cb..b094c22c2fb5 100644 --- a/tests/unit/v1/zero/test_zero_user_backward.py +++ b/tests/unit/v1/zero/test_zero_user_backward.py @@ -146,19 +146,25 @@ def collect_ddp_gradients(model_ddp): def compare_gradients(grads_ddp, grads_ds, step_info=""): - """Compare gradients between DDP and DeepSpeed""" + """Compare gradients between DDP and DeepSpeed. + + Uses PyTorch's default tolerances for the tensor dtype (e.g., for bfloat16: + rtol=1.6e-2, atol=1e-5). The 2-layer model keeps differences small enough + to pass with default tolerances even after multiple optimizer steps. + """ step_suffix = f" at {step_info}" if step_info else "" assert len(grads_ddp) == len(grads_ds), \ f"Different number of parameters with gradients{step_suffix}: DDP={len(grads_ddp)}, DeepSpeed={len(grads_ds)}" for name in grads_ddp.keys(): assert name in grads_ds, f"Parameter {name} missing in DeepSpeed gradients{step_suffix}" - # Convert both to fp32 for comparison in case of dtype mismatch - grads_ddp_fp32 = grads_ddp[name].float() - grads_ds_fp32 = grads_ds[name].float() - allclose_on_all_ranks(grads_ddp_fp32, - grads_ds_fp32, - assert_message=f"Gradients differ for parameter {name}{step_suffix}") + grad_ddp = grads_ddp[name] + grad_ds = grads_ds[name] + # If dtypes differ, convert ds to match ddp's dtype + if grad_ds.dtype != grad_ddp.dtype: + grad_ds = grad_ds.to(grad_ddp.dtype) + # Use PyTorch's default tolerances for the dtype + allclose_on_all_ranks(grad_ddp, grad_ds, assert_message=f"Gradients differ for parameter {name}{step_suffix}") def collect_ddp_parameters(model_ddp): @@ -1151,12 +1157,30 @@ def test_scale_with_torch_autocast(self, zero_stage): model_engine.destroy() +class NonCheckpointedModel(torch.nn.Module): + """Model without gradient checkpointing, used as reference for comparison.""" + + def __init__(self, hidden_dim): + super().__init__() + self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) + + def forward(self, x): + x = self.linear1(x) + x = torch.nn.functional.relu(x) + x = self.linear2(x) + return x + + class CheckpointedModel(torch.nn.Module): """Model that uses gradient checkpointing with configurable use_reentrant setting. This model is designed to test the interaction between ZeRO-3 and gradient checkpointing with both reentrant (use_reentrant=True) and non-reentrant (use_reentrant=False) modes. + + Uses 2 layers to minimize numerical divergence from bfloat16 precision + accumulation over multiple optimizer steps. """ def __init__(self, hidden_dim, use_reentrant=True): @@ -1164,50 +1188,48 @@ def __init__(self, hidden_dim, use_reentrant=True): self.use_reentrant = use_reentrant self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim) self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim) - self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim) def _checkpointed_block(self, x): """Block that will be checkpointed""" x = self.linear1(x) x = torch.nn.functional.relu(x) - x = self.linear2(x) return x def forward(self, x): - # Use gradient checkpointing on the middle block + # Use gradient checkpointing on the first block if self.training: from torch.utils.checkpoint import checkpoint x = checkpoint(self._checkpointed_block, x, use_reentrant=self.use_reentrant) else: x = self._checkpointed_block(x) - x = self.linear3(x) + x = self.linear2(x) return x +@pytest.mark.parametrize("zero_stage", [1, 2, 3]) @pytest.mark.parametrize("use_reentrant", [True, False]) class TestZeroUserBackwardWithCheckpointing(DistributedTest): - """Test ZeRO-3 with gradient checkpointing and non-scalar backward. + """Test ZeRO with gradient checkpointing and non-scalar backward. This test class validates the interaction between: - 1. ZeRO-3 parameter partitioning + 1. ZeRO parameter partitioning (stages 1 and 3) 2. Gradient checkpointing (both reentrant and non-reentrant modes) 3. Non-scalar backward (tensor.backward(gradient=...)) - Both use_reentrant=True and use_reentrant=False are supported with ZeRO-3. + Both use_reentrant=True and use_reentrant=False are supported with ZeRO. Note: When using use_reentrant=True, input tensors should have requires_grad=True for proper gradient computation through the checkpointed region. """ world_size = 2 - def test_checkpointed_non_scalar_backward_zero3(self, use_reentrant): - """Test that gradient checkpointing works with ZeRO-3 and non-scalar backward. + def test_checkpointed_non_scalar_backward(self, zero_stage, use_reentrant): + """Test that gradient checkpointing works with ZeRO and non-scalar backward. - Verifies that tensor.backward(gradient=...) works correctly with ZeRO-3 + Verifies that tensor.backward(gradient=...) works correctly with ZeRO and gradient checkpointing in both reentrant and non-reentrant modes. """ hidden_dim = 8 batch_size = 2 - zero_stage = 3 # Initialize distributed environment device, rank, dtype = initialize_distributed() @@ -1274,15 +1296,14 @@ def test_checkpointed_non_scalar_backward_zero3(self, use_reentrant): model_engine.destroy() - def test_checkpointed_scalar_backward_zero3(self, use_reentrant): - """Test that gradient checkpointing works with ZeRO-3 and scalar backward. + def test_checkpointed_scalar_backward(self, zero_stage, use_reentrant): + """Test that gradient checkpointing works with ZeRO and scalar backward. - Verifies that scalar loss.backward() works correctly with ZeRO-3 and + Verifies that scalar loss.backward() works correctly with ZeRO and gradient checkpointing in both reentrant and non-reentrant modes. """ hidden_dim = 8 batch_size = 2 - zero_stage = 3 # Initialize distributed environment device, rank, dtype = initialize_distributed() @@ -1345,51 +1366,69 @@ def test_checkpointed_scalar_backward_zero3(self, use_reentrant): model_engine.destroy() - def test_checkpointed_multiple_backward_zero3(self, use_reentrant): - """Test multiple backward passes with checkpointing and ZeRO-3. + def test_checkpointed_multiple_backward(self, zero_stage, use_reentrant): + """Test multiple backward passes with checkpointing and ZeRO. Verifies that consecutive training iterations work correctly with - gradient checkpointing in both reentrant and non-reentrant modes. + gradient checkpointing. Compares gradients with DDP at all iterations + to verify correctness. Uses PyTorch Adam for both to ensure fair comparison. """ - hidden_dim = 8 batch_size = 2 - zero_stage = 3 num_iterations = 3 # Initialize distributed environment device, rank, dtype = initialize_distributed() - # Create DeepSpeed model with ZeRO-3 + # Create DDP model for reference with PyTorch Adam torch.manual_seed(42) - model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + model_ddp = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + model_ddp = model_ddp.to(device=device, dtype=dtype) + model_ddp = DDP(model_ddp, device_ids=[rank], output_device=rank) + optimizer_ddp = torch.optim.Adam(model_ddp.parameters(), lr=1e-3) + # Create DeepSpeed model WITH checkpointing, using PyTorch Adam + torch.manual_seed(42) + model_ds = CheckpointedModel(hidden_dim=hidden_dim, use_reentrant=use_reentrant) + optimizer_ds = torch.optim.Adam(model_ds.parameters(), lr=1e-3) config = get_config_dict(zero_stage) model_engine, _, _, _ = deepspeed.initialize(config=config, model=model_ds, - model_parameters=model_ds.parameters()) + model_parameters=model_ds.parameters(), + optimizer=optimizer_ds) for iteration in range(num_iterations): - # Create input data with different seed each iteration - # For reentrant checkpointing (use_reentrant=True), inputs need requires_grad=True + # Use same random seed for both models torch.manual_seed(123 + iteration) x = torch.randn(batch_size, hidden_dim, device=device, dtype=dtype, requires_grad=True) - # Forward and non-scalar backward - output_ds = model_engine(x) - grad_output_ds = torch.ones_like(output_ds) - output_ds.backward(grad_output_ds) - - # Synchronize before collecting gradients to ensure async operations complete + # DDP: forward and backward + optimizer_ddp.zero_grad() + x_ddp = x.clone().detach().requires_grad_(True) + output_ddp = model_ddp(x_ddp) + output_ddp.backward(torch.ones_like(output_ddp)) get_accelerator().synchronize() dist.barrier() + ddp_grads = collect_ddp_gradients(model_ddp) - # Collect and verify gradients + # DeepSpeed: forward and backward + x_ds = x.clone().detach().requires_grad_(True) + output_ds = model_engine(x_ds) + output_ds.backward(torch.ones_like(output_ds)) + get_accelerator().synchronize() + dist.barrier() ds_grads = collect_gradients_safe(model_engine) + + # Verify gradients were computed assert len(ds_grads) > 0, \ f"No gradients at iteration {iteration} with use_reentrant={use_reentrant}" - # Run optimizer step + # Compare gradients with DDP - using same optimizer so should match closely + # Small differences at later iterations are expected due to bfloat16 precision + compare_gradients(ddp_grads, ds_grads, f"iteration {iteration} with use_reentrant={use_reentrant}") + + # Run optimizer steps on both models + optimizer_ddp.step() model_engine.step() model_engine.destroy() From e2de9a4aa67ab8963be559742f18410c8d0816fd Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 16 Jan 2026 19:57:12 -0800 Subject: [PATCH 11/11] fix micro step id count Signed-off-by: Masahiro Tanaka --- deepspeed/runtime/zero/stage3.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 279e03fc4df6..d36901150ebb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1825,6 +1825,11 @@ def zero_grad(self, set_to_none=True): Zero FP16 parameter grads. """ self.micro_step_id = 0 + # Reset the epilogue flag so the next forward doesn't increment micro_step_id. + # Without this, calling zero_grad() between backward and forward would cause + # micro_step_id to be incremented at the next forward, leading to incorrect + # gradient accumulation behavior. + self._epilogue_ran_this_backward = False # FP32 grad should never exist. # For speed, set model fp16 grad to None by default