Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,9 +1179,13 @@ def allreduce_and_copy_with_multiple_ranks(self,
log=log,
divide=divide,
process_group=process_group)
if self.overlap_comm and not get_accelerator().resolves_data_dependency():
allreduced.record_stream(self.reduction_stream)
for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks):
if dist.get_rank(group=process_group) == bucket_rank:
buf.copy_(synced)
if self.overlap_comm and not get_accelerator().resolves_data_dependency():
buf.record_stream(self.reduction_stream)

def allreduce_and_scatter(self,
bucket,
Expand Down Expand Up @@ -1746,9 +1750,13 @@ def allreduce_and_copy(self,
divide=divide,
process_group=process_group,
)
if self.overlap_comm and not get_accelerator().resolves_data_dependency():
allreduced.record_stream(stream)
if rank is None or rank == dist.get_rank(group=self.dp_process_group):
for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)):
buf.copy_(synced)
if self.overlap_comm and not get_accelerator().resolves_data_dependency():
buf.record_stream(stream)

def allreduce_no_retain(
self,
Expand Down
97 changes: 97 additions & 0 deletions tests/unit/v1/zero/test_overlap_comm_record_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from contextlib import nullcontext

import torch

import deepspeed.runtime.zero.stage_1_and_2 as zero_stage12
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer


class _FakeTensor:

def __init__(self):
self.recorded_streams = []
self.copied_from = None

def copy_(self, other):
self.copied_from = other
return self

def record_stream(self, stream):
self.recorded_streams.append(stream)


class _FakeAccelerator:

def __init__(self, resolves_data_dependency, current_device_name="cpu"):
self._resolves_data_dependency = resolves_data_dependency
self._current_device_name = current_device_name

def resolves_data_dependency(self):
return self._resolves_data_dependency

def stream(self, stream):
return nullcontext()

def current_stream(self):
return object()

def current_device_name(self):
return self._current_device_name

def synchronize(self):
return None


def _build_overlap_optimizer(monkeypatch, *, resolves_data_dependency):
optimizer = DeepSpeedZeroOptimizer.__new__(DeepSpeedZeroOptimizer)
optimizer.overlap_comm = True
optimizer.reduction_stream = object()
optimizer.dp_process_group = object()
optimizer.previous_reduced_grads = {}

allreduced = _FakeTensor()
synced = [_FakeTensor(), _FakeTensor()]

optimizer.allreduce_bucket = lambda *args, **kwargs: allreduced
optimizer.unflatten = lambda allreduced_tensor, small_bucket: synced

monkeypatch.setattr(
zero_stage12,
"get_accelerator",
lambda: _FakeAccelerator(resolves_data_dependency),
)
monkeypatch.setattr(zero_stage12.dist, "get_rank", lambda group=None: 0)
return optimizer, allreduced, synced


def test_allreduce_and_copy_records_stream_for_overlap_comm(monkeypatch):
optimizer, allreduced, synced = _build_overlap_optimizer(monkeypatch, resolves_data_dependency=False)
bucket = [_FakeTensor(), _FakeTensor()]

optimizer.allreduce_and_copy(bucket, torch.float16)

assert allreduced.recorded_streams == [optimizer.reduction_stream]
for buf, expected_synced in zip(bucket, synced):
assert buf.copied_from is expected_synced
assert buf.recorded_streams == [optimizer.reduction_stream]


def test_allreduce_and_copy_with_multiple_ranks_records_only_local_buffers(monkeypatch):
optimizer, allreduced, synced = _build_overlap_optimizer(monkeypatch, resolves_data_dependency=False)
bucket = [_FakeTensor(), _FakeTensor()]

optimizer.allreduce_and_copy_with_multiple_ranks(
bucket,
torch.float16,
bucket_ranks=[0, 1],
)

assert allreduced.recorded_streams == [optimizer.reduction_stream]
assert bucket[0].copied_from is synced[0]
assert bucket[0].recorded_streams == [optimizer.reduction_stream]
assert bucket[1].copied_from is None
assert bucket[1].recorded_streams == []
Loading