diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6123c0359a03..5f8585d65d76 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1416,6 +1416,14 @@ def _configure_distributed_model(self, model): self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() self.sequence_parallel_size = groups._get_sequence_parallel_world_size() if self.sequence_parallel_size > 1: + # Warn on PyTorch < 2.3 about possible backward-pass rank indexing errors + if not required_torch_version(min_version=2.3): + logger.warning( + "DeepSpeed Sequence Parallelism (Ulysses) with PyTorch < 2.3 may encounter " + "rank indexing errors during backward pass when sp_size < world_size. " + "Please use the weighted all-reduce workaround shown in the regression test " + "(https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/sequence_parallelism/test_ulysses.py) " + "or upgrade to PyTorch 2.3+.") self.communication_data_type = self._config.seq_parallel_communication_data_type self.seq_parallel_group = groups._get_sequence_parallel_group() diff --git a/docs/_tutorials/ulysses-alst-sequence-parallelism.md b/docs/_tutorials/ulysses-alst-sequence-parallelism.md index d03535ca3992..4d426913108a 100644 --- a/docs/_tutorials/ulysses-alst-sequence-parallelism.md +++ b/docs/_tutorials/ulysses-alst-sequence-parallelism.md @@ -222,6 +222,12 @@ In theory you could just average `losses_per_rank`, but the system supports vari ## Nuances +### Note on PyTorch Versions < 2.3 + +If you are using Sequence Parallelism with **PyTorch version < 2.3**, you may encounter an `IndexError: tuple index out of range` during the backward pass when `sequence_parallel_size < world_size`. This is due to a known issue in the backward implementation of the autograd-aware `torch.distributed.nn.functional.all_gather` in older versions, which indexes gradients by global rank rather than by the rank within the sequence-parallel group. + +**Workaround:** We recommend using a **weighted `all_reduce` pattern** instead of `all_gather` for loss averaging. 
You can refer to our [regression test case](https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/sequence_parallelism/test_ulysses.py) for a code example of this workaround. + ### Why do labels need to be pre-shifted? When using batch sharding one can't let the upstream `loss` function do the labels shifting. Here is why: diff --git a/tests/unit/sequence_parallelism/test_ulysses.py b/tests/unit/sequence_parallelism/test_ulysses.py index bd20900a3d1e..7c1371864073 100644 --- a/tests/unit/sequence_parallelism/test_ulysses.py +++ b/tests/unit/sequence_parallelism/test_ulysses.py @@ -255,3 +255,58 @@ def __init__(self): assert torch.allclose( fpdt_output, baseline_output_shuffled, rtol=0.01, atol=0.1 ), f"rank {dist.get_rank()}, sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}" + + +@pytest.mark.parametrize("sp_size", [2]) +class TestUlyssesLossBackward(DistributedTest): + world_size = 4 + + def test_sp_loss_backward_stability(self, sp_size: int) -> None: + """ + Regression test for Issue #7672. + Verifies that using all_reduce for loss aggregation is stable + when sequence_parallel_size < world_size, preventing IndexError. 
+ """ + skip_on_arch(min_arch=8) + + # Setup + dp_size = self.world_size // sp_size + model = SimpleModel(4) + ds_engine, _, _, _ = initialize( + model=model, + config_params={ + "train_batch_size": 8, + "data_parallel_size": dp_size, + "sequence_parallel_size": sp_size + }, + ) + + sp_group = ds_engine.seq_parallel_group + + # Simulate a distinct loss on each rank + rank = dist.get_rank() + local_loss = torch.tensor(float(rank + 1), device=ds_engine.device, requires_grad=True) + local_weight = torch.tensor(1.0, device=ds_engine.device) + + # A. Numerator: sum of the weighted losses over the SP group + weighted_loss = local_loss * local_weight + dist.all_reduce(weighted_loss, op=dist.ReduceOp.SUM, group=sp_group) + + # B. Denominator: sum of the weights over the SP group + total_weight = local_weight.clone() + dist.all_reduce(total_weight, op=dist.ReduceOp.SUM, group=sp_group) + + # C. Calculate the final loss + dist_loss = weighted_loss / total_weight + + # Backward Pass verification + try: + dist_loss.backward() + except IndexError as e: + pytest.fail(f"Backward crashed with IndexError: {e}") + + # Verify Gradients + # Loss = (L1*1 + L2*1) / 2 = 0.5*L1 + 0.5*L2, so d(Loss)/d(L_i) = 0.5 on every rank + expected_grad = 0.5 + assert torch.allclose(local_loss.grad, torch.tensor(expected_grad, device=ds_engine.device)), \ + f"Gradient mismatch! Expected {expected_grad}, got {local_loss.grad}"