Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,14 @@ def _configure_distributed_model(self, model):
self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
if self.sequence_parallel_size > 1:
# Inserted Warning for PyTorch < 2.3
Comment thread
Flink-ddd marked this conversation as resolved.
if not required_torch_version(min_version=2.3):
logger.warning(
"DeepSpeed Sequence Parallelism (Ulysses) with PyTorch < 2.3 may encounter "
"rank indexing errors during backward pass when sp_size < world_size. "
"Please use the weighted all-reduce workaround shown in the regression test "
"(https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/sequence_parallelism/test_ulysses.py) "
"or upgrade to PyTorch 2.3+.")
self.communication_data_type = self._config.seq_parallel_communication_data_type
self.seq_parallel_group = groups._get_sequence_parallel_group()

Expand Down
6 changes: 6 additions & 0 deletions docs/_tutorials/ulysses-alst-sequence-parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ In theory you could just average `losses_per_rank`, but the system supports vari

## Nuances

### Note on PyTorch Versions < 2.3

If you are using Sequence Parallelism with **PyTorch version < 2.3**, you may encounter an `IndexError: tuple index out of range` during the backward pass when `sequence_parallel_size < world_size`. This is due to a known issue in the `torch.distributed.all_gather` backward implementation in older versions.

**Workaround:** We recommend using a **weighted `all_reduce` pattern** instead of `all_gather` for loss averaging. You can refer to our [regression test case](https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/sequence_parallelism/test_ulysses.py) for a code example of this workaround.

### Why do labels need to be pre-shifted?

When using batch sharding one can't let the upstream `loss` function do the labels shifting. Here is why:
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/sequence_parallelism/test_ulysses.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,58 @@ def __init__(self):
assert torch.allclose(
fpdt_output, baseline_output_shuffled, rtol=0.01, atol=0.1
), f"rank {dist.get_rank()}, sp size: {dist.get_world_size(spg)}, input_tensor: {input_tensor.shape}, fpdt_input_tensor: {fpdt_input_tensor.shape}, fpdt_output: {fpdt_output.shape}, baseline_output_shuffled: {baseline_output_shuffled.shape},{torch.max(torch.abs(fpdt_output - baseline_output_shuffled))}"


@pytest.mark.parametrize("sp_size", [2])
class TestUlyssesLossBackward(DistributedTest):
world_size = 4

def test_sp_loss_backward_stability(self, sp_size: int) -> None:
"""
Regression test for Issue #7672.
Verifies that using all_reduce for loss aggregation is stable
when sequence_parallel_size < world_size, preventing IndexError.
"""
skip_on_arch(min_arch=8)

# Setup
dp_size = self.world_size // sp_size
model = SimpleModel(4)
ds_engine, _, _, _ = initialize(
model=model,
config_params={
"train_batch_size": 8,
"data_parallel_size": dp_size,
"sequence_parallel_size": sp_size
},
)

sp_group = ds_engine.seq_parallel_group

# Simulate Loss on each rank
rank = dist.get_rank()
local_loss = torch.tensor(float(rank + 1), device=ds_engine.device, requires_grad=True)
local_weight = torch.tensor(1.0, device=ds_engine.device)

# Numerator: Weighted Loss summation
weighted_loss = local_loss * local_weight
dist.all_reduce(weighted_loss, op=dist.ReduceOp.SUM, group=sp_group)

# B. Denominator: Sum of total weights
total_weight = local_weight.clone()
dist.all_reduce(total_weight, op=dist.ReduceOp.SUM, group=sp_group)

# C. Calculate the final loss
dist_loss = weighted_loss / total_weight

# Backward Pass verification
try:
dist_loss.backward()
except IndexError as e:
pytest.fail(f"Backward crashed with IndexError: {e}")

# Verify Gradients
# Loss = (L1*1 + L2*1) / 2 = 0.5*L1 + 0.5*L2
expected_grad = 0.5
assert torch.allclose(local_loss.grad, torch.tensor(expected_grad, device=ds_engine.device)), \
f"Gradient mismatch! Expected {expected_grad}, got {local_loss.grad}"