Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions csrc/compile/z1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,21 @@ class Z1CustomOpExecutor : public CustomOpExecutor {
std::unordered_map<long, at::Tensor> grad_tensors_;
};

static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
namespace {

at::cuda::CUDAStream get_rs_stream()
{
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
return rs_stream;
}

at::cuda::CUDAStream get_copy_stream()
{
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
return copy_stream;
}

} // namespace

void register_graph_z1(long graph_id, const std::vector<long>& ds_ids)
{
Expand All @@ -100,8 +113,8 @@ void register_graph_z1(long graph_id, const std::vector<long>& ds_ids)
reduce_buckets,
ds_ids,
nccl_comm,
rs_stream,
copy_stream,
get_rs_stream(),
get_copy_stream(),
pre_div_reduce);
}

Expand Down
21 changes: 17 additions & 4 deletions csrc/compile/z2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,21 @@ class Z2CustomOpExecutor : public CustomOpExecutor {
}
};

static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
namespace {

at::cuda::CUDAStream get_rs_stream()
{
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
return rs_stream;
}

at::cuda::CUDAStream get_copy_stream()
{
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
return copy_stream;
}

} // namespace

void register_graph_z2(long graph_id, const std::vector<long>& ds_ids)
{
Expand All @@ -96,8 +109,8 @@ void register_graph_z2(long graph_id, const std::vector<long>& ds_ids)
reduce_buckets,
ds_ids,
nccl_comm,
rs_stream,
copy_stream,
get_rs_stream(),
get_copy_stream(),
pre_div_reduce);
}

Expand Down
48 changes: 38 additions & 10 deletions csrc/compile/z3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,39 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
std::unordered_map<long, long> param_use_count_;
};

static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
namespace {

at::cuda::CUDAStream get_ag_stream()
{
static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true);
return ag_stream;
}

at::cuda::CUDAStream get_rs_stream()
{
static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true);
return rs_stream;
}

at::cuda::CUDAStream get_copy_stream()
{
static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true);
return copy_stream;
}

at::cuda::CUDAStream get_offload_stream()
{
static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true);
return offload_stream;
}

at::cuda::CUDAStream get_reload_stream()
{
static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true);
return reload_stream;
}

} // namespace

void register_graph_z3(long graph_id, const std::vector<long>& ds_ids)
{
Expand All @@ -422,11 +450,11 @@ void register_graph_z3(long graph_id, const std::vector<long>& ds_ids)
reduce_buckets,
ds_ids,
nccl_comm,
ag_stream,
rs_stream,
copy_stream,
offload_stream,
reload_stream,
get_ag_stream(),
get_rs_stream(),
get_copy_stream(),
get_offload_stream(),
get_reload_stream(),
pre_div_reduce);
}

Expand Down
34 changes: 23 additions & 11 deletions csrc/includes/deepcompile.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,7 @@ class DSParam {
grad_buffer_(grad_buffer),
partitioned_(partitioned),
offset_(offset),
persistent_(persistent),
offload_stream_(at::cuda::getStreamFromPool()),
reload_stream_(at::cuda::getStreamFromPool())
persistent_(persistent)
{
}

Expand Down Expand Up @@ -302,18 +300,19 @@ class DSParam {
{
// If a reloaded tensor exists, offload its data back to ds_tensor_
if (ds_reload_tensor_.defined()) {
auto offload_stream = getOffloadStream();
auto comp_stream = at::cuda::getCurrentCUDAStream();
comp_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
// Record completion and wait on the offload stream
comp_done_event_->record(comp_stream);
comp_done_event_->block(offload_stream_);
comp_done_event_->block(offload_stream);
offload_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

{
at::cuda::CUDAStreamGuard guard(offload_stream_);
at::cuda::CUDAStreamGuard guard(offload_stream);
ds_tensor_.copy_(ds_reload_tensor_, /*non_blocking=*/true);
ds_reload_tensor_.reset(); // Clear the reloaded tensor
offload_done_event_->record(offload_stream_);
offload_done_event_->record(offload_stream);
}
// Reset the reload event to indicate that no valid reload is present.
if (reload_done_event_) { reload_done_event_.reset(); }
Expand All @@ -324,26 +323,39 @@ class DSParam {
{
// Reload only if the current ds_tensor_ is on CPU
if (ds_tensor_.device().is_cpu()) {
auto reload_stream = getReloadStream();
auto comp_stream = at::cuda::getCurrentCUDAStream();
comp_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);
// Record and wait on the reload stream
comp_done_event_->record(comp_stream);
comp_done_event_->block(reload_stream_);
comp_done_event_->block(reload_stream);
reload_done_event_ = std::make_shared<at::cuda::CUDAEvent>(cudaEventDisableTiming);

{
at::cuda::CUDAStreamGuard guard(reload_stream_);
at::cuda::CUDAStreamGuard guard(reload_stream);
ds_reload_tensor_ =
at::empty_like(ds_tensor_, ds_tensor_.options().device(torch::kCUDA));
ds_reload_tensor_.copy_(ds_tensor_, /*non_blocking=*/true);
reload_done_event_->record(reload_stream_);
reload_done_event_->record(reload_stream);
}
// Reset offload_done_event if it exists to clear any stale offload state.
if (offload_done_event_) { offload_done_event_.reset(); }
}
}

private:
at::cuda::CUDAStream getOffloadStream()
{
if (!offload_stream_) { offload_stream_.emplace(at::cuda::getStreamFromPool()); }
return *offload_stream_;
}

at::cuda::CUDAStream getReloadStream()
{
if (!reload_stream_) { reload_stream_.emplace(at::cuda::getStreamFromPool()); }
return *reload_stream_;
}

long id_;
std::vector<int64_t> shape_;
at::ScalarType ds_dtype_;
Expand All @@ -355,8 +367,8 @@ class DSParam {
bool persistent_; // for Z3
mutable bool is_reloaded = false;

at::cuda::CUDAStream offload_stream_;
at::cuda::CUDAStream reload_stream_;
std::optional<at::cuda::CUDAStream> offload_stream_;
std::optional<at::cuda::CUDAStream> reload_stream_;
std::shared_ptr<at::cuda::CUDAEvent> comp_done_event_;
std::shared_ptr<at::cuda::CUDAEvent> offload_done_event_;
std::shared_ptr<at::cuda::CUDAEvent> reload_done_event_;
Expand Down
135 changes: 81 additions & 54 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,56 +556,60 @@ def jit_load(self, verbose=True):
if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
self.build_for_cpu = not torch.cuda.is_available()

saved_jit_mode = self.jit_mode
self.jit_mode = True
torch_arch_list_present = "TORCH_CUDA_ARCH_LIST" in os.environ
torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
normalized_arch_list = torch_arch_list.strip() if torch_arch_list is not None else None
self._jit_arch_list = normalized_arch_list or None
from torch.utils.cpp_extension import load

start_build = time.time()
sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()]
extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()]

# Stash TORCH_CUDA_ARCH_LIST to restore after build.
torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")

nvcc_args = self.strip_empty_entries(self.nvcc_args())
cxx_args = self.strip_empty_entries(self.cxx_args())

cxx_args.append("-UC10_USE_GLOG")
nvcc_args.append("-UC10_USE_GLOG")
if isinstance(self, CUDAOpBuilder):
if not self.build_for_cpu and self.enable_bf16:
cxx_args.append("-DBF16_AVAILABLE")
nvcc_args.append("-DBF16_AVAILABLE")
nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__")
nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__")
nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")

if self.is_rocm_pytorch():
cxx_args.append("-D__HIP_PLATFORM_AMD__=1")
os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch()
cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())

op_module = load(name=self.name,
sources=self.strip_empty_entries(sources),
extra_include_paths=self.strip_empty_entries(extra_include_paths),
extra_cflags=cxx_args,
extra_cuda_cflags=nvcc_args,
extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
with_cuda=True if (isinstance(self, CUDAOpBuilder) and not self.build_for_cpu) else None,
verbose=verbose)

build_duration = time.time() - start_build
if verbose:
print(f"Time to load {self.name} op: {build_duration} seconds")

# Restore TORCH_CUDA_ARCH_LIST to its original state.
if torch_arch_list is not None:
os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
elif "TORCH_CUDA_ARCH_LIST" in os.environ:
del os.environ["TORCH_CUDA_ARCH_LIST"]
try:
nvcc_args = self.strip_empty_entries(self.nvcc_args())
cxx_args = self.strip_empty_entries(self.cxx_args())

cxx_args.append("-UC10_USE_GLOG")
nvcc_args.append("-UC10_USE_GLOG")
if isinstance(self, CUDAOpBuilder):
if not self.build_for_cpu and self.enable_bf16:
cxx_args.append("-DBF16_AVAILABLE")
nvcc_args.append("-DBF16_AVAILABLE")
nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__")
nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__")
nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")

if self.is_rocm_pytorch():
cxx_args.append("-D__HIP_PLATFORM_AMD__=1")
os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch()
cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())

op_module = load(name=self.name,
sources=self.strip_empty_entries(sources),
extra_include_paths=self.strip_empty_entries(extra_include_paths),
extra_cflags=cxx_args,
extra_cuda_cflags=nvcc_args,
extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
with_cuda=True if (isinstance(self, CUDAOpBuilder) and not self.build_for_cpu) else None,
verbose=verbose)

build_duration = time.time() - start_build
if verbose:
print(f"Time to load {self.name} op: {build_duration} seconds")

__class__._loaded_ops[self.name] = op_module
__class__._loaded_ops[self.name] = op_module

return op_module
return op_module
finally:
if torch_arch_list_present:
os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
else:
os.environ.pop("TORCH_CUDA_ARCH_LIST", None)
self._jit_arch_list = None
self.jit_mode = saved_jit_mode


class CUDAOpBuilder(OpBuilder):
Expand All @@ -614,11 +618,15 @@ def compute_capability_args(self, cross_compile_archs=None):
"""
Returns nvcc compute capability compile flags.

1. Under ``jit_mode`` the visible-card architectures are detected,
``TORCH_CUDA_ARCH_LIST`` is set accordingly, and an **empty list**
is returned so that PyTorch generates the ``-gencode`` flags
itself (avoiding duplicates). See
https://github.com/deepspeedai/DeepSpeed/issues/7972
1. Under ``jit_mode``, the precedence is:
a. preserved ``TORCH_CUDA_ARCH_LIST`` captured by ``jit_load()``
b. live ``TORCH_CUDA_ARCH_LIST`` from the environment
c. runtime device probing when the process is not in a bad-fork context
d. an error when no explicit arch list exists in a bad-fork context

JIT mode auto-adds ``+PTX`` to the highest compute capability when
no entry already carries it, then sets ``TORCH_CUDA_ARCH_LIST`` so
PyTorch can generate the ``-gencode`` flags itself.
2. ``TORCH_CUDA_ARCH_LIST`` takes priority over ``cross_compile_archs``.
3. If neither is set default compute capabilities will be used.

Expand All @@ -634,14 +642,33 @@ def compute_capability_args(self, cross_compile_archs=None):
"""
ccs = []
if self.jit_mode:
# Compile for underlying architectures since we know those at runtime
for i in range(torch.cuda.device_count()):
CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
cc = f"{CC_MAJOR}.{CC_MINOR}"
if cc not in ccs:
ccs.append(cc)
ccs = sorted(ccs)
ccs[-1] += '+PTX'
arch_string = getattr(self, '_jit_arch_list', None)
if arch_string:
arch_string = arch_string.replace(' ', ';')
ccs = [cc.strip() for cc in arch_string.split(';') if cc.strip()]
else:
arch_string = os.environ.get('TORCH_CUDA_ARCH_LIST', '').strip()
if arch_string:
arch_string = arch_string.replace(' ', ';')
ccs = [cc.strip() for cc in arch_string.split(';') if cc.strip()]
else:
if hasattr(torch.cuda, '_is_in_bad_fork') and torch.cuda._is_in_bad_fork():
raise RuntimeError(
f"DeepSpeed JIT builder for '{self.name}' cannot probe CUDA device capabilities "
"in a forked subprocess where CUDA has already been initialized. Set "
"TORCH_CUDA_ARCH_LIST to specify target architectures explicitly.")
for i in range(torch.cuda.device_count()):
CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
cc = f"{CC_MAJOR}.{CC_MINOR}"
if cc not in ccs:
ccs.append(cc)
if len(ccs) == 0:
raise RuntimeError(f"DeepSpeed JIT builder for '{self.name}' found no CUDA devices. Set "
"TORCH_CUDA_ARCH_LIST or make GPUs visible.")

ccs = sorted(ccs, key=lambda cc: tuple(int(part.split('+')[0]) for part in cc.split('.')))
if not any('+PTX' in cc for cc in ccs):
ccs[-1] += '+PTX'
else:
# Cross-compile mode, compile for various architectures
# env override takes priority
Expand Down
Loading
Loading