diff --git a/csrc/compile/z1.cpp b/csrc/compile/z1.cpp index cbec2dec82ab..d0c804b7f1ad 100644 --- a/csrc/compile/z1.cpp +++ b/csrc/compile/z1.cpp @@ -90,8 +90,21 @@ class Z1CustomOpExecutor : public CustomOpExecutor { std::unordered_map grad_tensors_; }; -static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); -static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); +namespace { + +at::cuda::CUDAStream get_rs_stream() +{ + static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); + return rs_stream; +} + +at::cuda::CUDAStream get_copy_stream() +{ + static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); + return copy_stream; +} + +} // namespace void register_graph_z1(long graph_id, const std::vector& ds_ids) { @@ -100,8 +113,8 @@ void register_graph_z1(long graph_id, const std::vector& ds_ids) reduce_buckets, ds_ids, nccl_comm, - rs_stream, - copy_stream, + get_rs_stream(), + get_copy_stream(), pre_div_reduce); } diff --git a/csrc/compile/z2.cpp b/csrc/compile/z2.cpp index 83d8ccd59085..09290174f146 100644 --- a/csrc/compile/z2.cpp +++ b/csrc/compile/z2.cpp @@ -86,8 +86,21 @@ class Z2CustomOpExecutor : public CustomOpExecutor { } }; -static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); -static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); +namespace { + +at::cuda::CUDAStream get_rs_stream() +{ + static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); + return rs_stream; +} + +at::cuda::CUDAStream get_copy_stream() +{ + static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); + return copy_stream; +} + +} // namespace void register_graph_z2(long graph_id, const std::vector& ds_ids) { @@ -96,8 +109,8 @@ void register_graph_z2(long graph_id, const std::vector& ds_ids) reduce_buckets, ds_ids, nccl_comm, - rs_stream, - copy_stream, + get_rs_stream(), + get_copy_stream(), pre_div_reduce); } diff --git a/csrc/compile/z3.cpp b/csrc/compile/z3.cpp index 28ab171dd11e..fdc146b4ec02 100644 --- a/csrc/compile/z3.cpp +++ b/csrc/compile/z3.cpp @@ -409,11 +409,39 @@ class Z3CustomOpExecutor : public CustomOpExecutor { std::unordered_map param_use_count_; }; -static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true); -static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); -static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); -static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true); -static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true); +namespace { + +at::cuda::CUDAStream get_ag_stream() +{ + static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true); + return ag_stream; +} + +at::cuda::CUDAStream get_rs_stream() +{ + static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); + return rs_stream; +} + +at::cuda::CUDAStream get_copy_stream() +{ + static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); + return copy_stream; +} + +at::cuda::CUDAStream get_offload_stream() +{ + static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true); + return offload_stream; +} + +at::cuda::CUDAStream get_reload_stream() +{ + static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true); + return reload_stream; +} + +} // namespace void register_graph_z3(long graph_id, const std::vector& ds_ids) { @@ -422,11 +450,11 @@ void register_graph_z3(long graph_id, const std::vector& ds_ids) reduce_buckets, ds_ids, nccl_comm, - ag_stream, - rs_stream, - copy_stream, - offload_stream, - reload_stream, + get_ag_stream(), + get_rs_stream(), + get_copy_stream(), + get_offload_stream(), + get_reload_stream(), pre_div_reduce); } diff --git a/csrc/includes/deepcompile.h b/csrc/includes/deepcompile.h index ee3d965970ce..7016d4a99310 100644 --- a/csrc/includes/deepcompile.h +++ b/csrc/includes/deepcompile.h @@ -266,9 +266,7 @@ class DSParam { grad_buffer_(grad_buffer), partitioned_(partitioned), offset_(offset), - persistent_(persistent), - offload_stream_(at::cuda::getStreamFromPool()), - reload_stream_(at::cuda::getStreamFromPool()) + persistent_(persistent) { } @@ -302,18 +300,19 @@ class DSParam { { // If a reloaded tensor exists, offload its data back to ds_tensor_ if (ds_reload_tensor_.defined()) { + auto offload_stream = getOffloadStream(); auto comp_stream = at::cuda::getCurrentCUDAStream(); comp_done_event_ = std::make_shared(cudaEventDisableTiming); // Record completion and wait on the offload stream comp_done_event_->record(comp_stream); - comp_done_event_->block(offload_stream_); + comp_done_event_->block(offload_stream); offload_done_event_ = std::make_shared(cudaEventDisableTiming); { - at::cuda::CUDAStreamGuard guard(offload_stream_); + at::cuda::CUDAStreamGuard guard(offload_stream); ds_tensor_.copy_(ds_reload_tensor_, /*non_blocking=*/true); ds_reload_tensor_.reset(); // Clear the reloaded tensor - offload_done_event_->record(offload_stream_); + offload_done_event_->record(offload_stream); } // Reset the reload event to indicate that no valid reload is present. if (reload_done_event_) { reload_done_event_.reset(); } @@ -324,19 +323,20 @@ class DSParam { { // Reload only if the current ds_tensor_ is on CPU if (ds_tensor_.device().is_cpu()) { + auto reload_stream = getReloadStream(); auto comp_stream = at::cuda::getCurrentCUDAStream(); comp_done_event_ = std::make_shared(cudaEventDisableTiming); // Record and wait on the reload stream comp_done_event_->record(comp_stream); - comp_done_event_->block(reload_stream_); + comp_done_event_->block(reload_stream); reload_done_event_ = std::make_shared(cudaEventDisableTiming); { - at::cuda::CUDAStreamGuard guard(reload_stream_); + at::cuda::CUDAStreamGuard guard(reload_stream); ds_reload_tensor_ = at::empty_like(ds_tensor_, ds_tensor_.options().device(torch::kCUDA)); ds_reload_tensor_.copy_(ds_tensor_, /*non_blocking=*/true); - reload_done_event_->record(reload_stream_); + reload_done_event_->record(reload_stream); } // Reset offload_done_event if it exists to clear any stale offload state. if (offload_done_event_) { offload_done_event_.reset(); } @@ -344,6 +344,18 @@ class DSParam { } private: + at::cuda::CUDAStream getOffloadStream() + { + if (!offload_stream_) { offload_stream_.emplace(at::cuda::getStreamFromPool()); } + return *offload_stream_; + } + + at::cuda::CUDAStream getReloadStream() + { + if (!reload_stream_) { reload_stream_.emplace(at::cuda::getStreamFromPool()); } + return *reload_stream_; + } + long id_; std::vector shape_; at::ScalarType ds_dtype_; @@ -355,8 +367,8 @@ class DSParam { bool persistent_; // for Z3 mutable bool is_reloaded = false; - at::cuda::CUDAStream offload_stream_; - at::cuda::CUDAStream reload_stream_; + std::optional offload_stream_; + std::optional reload_stream_; std::shared_ptr comp_done_event_; std::shared_ptr offload_done_event_; std::shared_ptr reload_done_event_; diff --git a/op_builder/builder.py b/op_builder/builder.py index b2e42d4bd3a8..5c14225446f2 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -556,56 +556,60 @@ def jit_load(self, verbose=True): if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): self.build_for_cpu = not torch.cuda.is_available() + saved_jit_mode = self.jit_mode self.jit_mode = True + torch_arch_list_present = "TORCH_CUDA_ARCH_LIST" in os.environ + torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") + normalized_arch_list = torch_arch_list.strip() if torch_arch_list is not None else None + self._jit_arch_list = normalized_arch_list or None from torch.utils.cpp_extension import load start_build = time.time() sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()] extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()] - # Stash TORCH_CUDA_ARCH_LIST to restore after build. - torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") - - nvcc_args = self.strip_empty_entries(self.nvcc_args()) - cxx_args = self.strip_empty_entries(self.cxx_args()) - - cxx_args.append("-UC10_USE_GLOG") - nvcc_args.append("-UC10_USE_GLOG") - if isinstance(self, CUDAOpBuilder): - if not self.build_for_cpu and self.enable_bf16: - cxx_args.append("-DBF16_AVAILABLE") - nvcc_args.append("-DBF16_AVAILABLE") - nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") - nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") - nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") - - if self.is_rocm_pytorch(): - cxx_args.append("-D__HIP_PLATFORM_AMD__=1") - os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() - cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) - - op_module = load(name=self.name, - sources=self.strip_empty_entries(sources), - extra_include_paths=self.strip_empty_entries(extra_include_paths), - extra_cflags=cxx_args, - extra_cuda_cflags=nvcc_args, - extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), - with_cuda=True if (isinstance(self, CUDAOpBuilder) and not self.build_for_cpu) else None, - verbose=verbose) - - build_duration = time.time() - start_build - if verbose: - print(f"Time to load {self.name} op: {build_duration} seconds") - - # Restore TORCH_CUDA_ARCH_LIST to its original state. - if torch_arch_list is not None: - os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list - elif "TORCH_CUDA_ARCH_LIST" in os.environ: - del os.environ["TORCH_CUDA_ARCH_LIST"] + try: + nvcc_args = self.strip_empty_entries(self.nvcc_args()) + cxx_args = self.strip_empty_entries(self.cxx_args()) + + cxx_args.append("-UC10_USE_GLOG") + nvcc_args.append("-UC10_USE_GLOG") + if isinstance(self, CUDAOpBuilder): + if not self.build_for_cpu and self.enable_bf16: + cxx_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-DBF16_AVAILABLE") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__") + nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + + if self.is_rocm_pytorch(): + cxx_args.append("-D__HIP_PLATFORM_AMD__=1") + os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch() + cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size()) + + op_module = load(name=self.name, + sources=self.strip_empty_entries(sources), + extra_include_paths=self.strip_empty_entries(extra_include_paths), + extra_cflags=cxx_args, + extra_cuda_cflags=nvcc_args, + extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + with_cuda=True if (isinstance(self, CUDAOpBuilder) and not self.build_for_cpu) else None, + verbose=verbose) + + build_duration = time.time() - start_build + if verbose: + print(f"Time to load {self.name} op: {build_duration} seconds") - __class__._loaded_ops[self.name] = op_module + __class__._loaded_ops[self.name] = op_module - return op_module + return op_module + finally: + if torch_arch_list_present: + os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + else: + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + self._jit_arch_list = None + self.jit_mode = saved_jit_mode class CUDAOpBuilder(OpBuilder): @@ -614,11 +618,15 @@ def compute_capability_args(self, cross_compile_archs=None): """ Returns nvcc compute capability compile flags. - 1. Under ``jit_mode`` the visible-card architectures are detected, - ``TORCH_CUDA_ARCH_LIST`` is set accordingly, and an **empty list** - is returned so that PyTorch generates the ``-gencode`` flags - itself (avoiding duplicates). See - https://github.com/deepspeedai/DeepSpeed/issues/7972 + 1. Under ``jit_mode``, the precedence is: + a. preserved ``TORCH_CUDA_ARCH_LIST`` captured by ``jit_load()`` + b. live ``TORCH_CUDA_ARCH_LIST`` from the environment + c. runtime device probing when the process is not in a bad-fork context + d. an error when no explicit arch list exists in a bad-fork context + + JIT mode auto-adds ``+PTX`` to the highest compute capability when + no entry already carries it, then sets ``TORCH_CUDA_ARCH_LIST`` so + PyTorch can generate the ``-gencode`` flags itself. 2. ``TORCH_CUDA_ARCH_LIST`` takes priority over ``cross_compile_archs``. 3. If neither is set default compute capabilities will be used. @@ -634,14 +642,33 @@ def compute_capability_args(self, cross_compile_archs=None): """ ccs = [] if self.jit_mode: - # Compile for underlying architectures since we know those at runtime - for i in range(torch.cuda.device_count()): - CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) - cc = f"{CC_MAJOR}.{CC_MINOR}" - if cc not in ccs: - ccs.append(cc) - ccs = sorted(ccs) - ccs[-1] += '+PTX' + arch_string = getattr(self, '_jit_arch_list', None) + if arch_string: + arch_string = arch_string.replace(' ', ';') + ccs = [cc.strip() for cc in arch_string.split(';') if cc.strip()] + else: + arch_string = os.environ.get('TORCH_CUDA_ARCH_LIST', '').strip() + if arch_string: + arch_string = arch_string.replace(' ', ';') + ccs = [cc.strip() for cc in arch_string.split(';') if cc.strip()] + else: + if hasattr(torch.cuda, '_is_in_bad_fork') and torch.cuda._is_in_bad_fork(): + raise RuntimeError( + f"DeepSpeed JIT builder for '{self.name}' cannot probe CUDA device capabilities " + "in a forked subprocess where CUDA has already been initialized. Set " + "TORCH_CUDA_ARCH_LIST to specify target architectures explicitly.") + for i in range(torch.cuda.device_count()): + CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i) + cc = f"{CC_MAJOR}.{CC_MINOR}" + if cc not in ccs: + ccs.append(cc) + if len(ccs) == 0: + raise RuntimeError(f"DeepSpeed JIT builder for '{self.name}' found no CUDA devices. Set " + "TORCH_CUDA_ARCH_LIST or make GPUs visible.") + + ccs = sorted(ccs, key=lambda cc: tuple(int(part.split('+')[0]) for part in cc.split('.'))) + if not any('+PTX' in cc for cc in ccs): + ccs[-1] += '+PTX' else: # Cross-compile mode, compile for various architectures # env override takes priority diff --git a/tests/unit/ops/test_op_builder.py b/tests/unit/ops/test_op_builder.py new file mode 100644 index 000000000000..218d053cf955 --- /dev/null +++ b/tests/unit/ops/test_op_builder.py @@ -0,0 +1,165 @@ +# Copyright (c) DeepSpeed Team. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import importlib.util +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +BUILDER_PATH = Path(__file__).resolve().parents[3] / "op_builder" / "builder.py" +BUILDER_SPEC = importlib.util.spec_from_file_location("test_op_builder_module", BUILDER_PATH) +builder_module = importlib.util.module_from_spec(BUILDER_SPEC) +BUILDER_SPEC.loader.exec_module(builder_module) +CUDAOpBuilder = builder_module.CUDAOpBuilder + +BUILDER_MODULE = builder_module +CUDA_API = BUILDER_MODULE.torch.cuda #ignore-cuda + + +class _StubCUDAOpBuilder(CUDAOpBuilder): + BUILD_VAR = "STUB_BUILDER" + NAME = "stub" + + def __init__(self): + super().__init__(name="stub") + + def absolute_name(self): + return "deepspeed.ops.stub" + + def sources(self): + return [] + + def include_paths(self): + return [] + + +def make_builder(**overrides): + builder = _StubCUDAOpBuilder() + for key, value in overrides.items(): + setattr(builder, key, value) + return builder + + +def assert_jit_uses_explicit_arch_list(builder, expected_arch_list, env_updates=None): + env_updates = env_updates or {} + + with patch.dict(os.environ, env_updates, clear=False): + if "TORCH_CUDA_ARCH_LIST" not in env_updates: + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "device_count", + side_effect=AssertionError("probe should not be called")) as device_count: + with patch.object(CUDA_API, + "get_device_capability", + side_effect=AssertionError("probe should not be called")) as get_device_capability: + assert builder.compute_capability_args() == [] + assert os.environ["TORCH_CUDA_ARCH_LIST"] == expected_arch_list + + device_count.assert_not_called() + get_device_capability.assert_not_called() + + +def test_jit_mode_prefers_explicit_arch_lists_before_cuda_probe(): + assert_jit_uses_explicit_arch_list(make_builder(jit_mode=True, _jit_arch_list="8.0;8.9"), "8.0;8.9+PTX") + assert_jit_uses_explicit_arch_list(make_builder(jit_mode=True), "8.0;8.9+PTX", {"TORCH_CUDA_ARCH_LIST": "8.0 8.9"}) + + +def test_bad_fork_jit_without_arch_list_raises_actionable_error(): + builder = make_builder(jit_mode=True) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "_is_in_bad_fork", return_value=True): + with patch.object(CUDA_API, "device_count", + side_effect=AssertionError("probe should not be called")) as device_count: + with pytest.raises(RuntimeError, match="TORCH_CUDA_ARCH_LIST"): + builder.compute_capability_args() + + device_count.assert_not_called() + + +def test_jit_mode_probes_devices_when_safe_and_errors_without_visible_gpus(): + builder = make_builder(jit_mode=True) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "_is_in_bad_fork", return_value=False): + with patch.object(CUDA_API, "device_count", return_value=2) as device_count: + with patch.object(CUDA_API, "get_device_capability", side_effect=[(7, 0), + (8, 9)]) as get_device_capability: + assert builder.compute_capability_args() == [] + assert os.environ["TORCH_CUDA_ARCH_LIST"] == "7.0;8.9+PTX" + assert builder.enable_bf16 is False + + device_count.assert_called_once_with() + assert get_device_capability.call_count == 2 + + builder = make_builder(jit_mode=True) + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("TORCH_CUDA_ARCH_LIST", None) + with patch.object(CUDA_API, "_is_in_bad_fork", return_value=False): + with patch.object(CUDA_API, "device_count", return_value=0): + with pytest.raises(RuntimeError, match="no CUDA devices"): + builder.compute_capability_args() + + +def test_jit_load_restores_env_and_state_after_failure(): + builder = make_builder() + + def fail_nvcc_args(): + assert getattr(builder, "_jit_arch_list", None) == "8.9" + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9+PTX" + raise RuntimeError("build failed") + + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": "8.9"}, clear=False): + with patch.object(builder, "is_compatible", return_value=True): + with patch.object(CUDAOpBuilder, "is_rocm_pytorch", return_value=False): + with patch.object(CUDA_API, "is_available", return_value=True): + with patch("torch.utils.cpp_extension.verify_ninja_availability", return_value=None): + with patch.object(builder, "nvcc_args", side_effect=fail_nvcc_args): + with pytest.raises(RuntimeError, match="build failed"): + builder.jit_load(verbose=False) + + assert getattr(builder, "_jit_arch_list", None) is None + assert builder.jit_mode is False + assert os.environ["TORCH_CUDA_ARCH_LIST"] == "8.9" + + +def test_jit_load_restores_state_after_success(): + builder = make_builder() + op_module = MagicMock() + + def successful_nvcc_args(): + assert builder._jit_arch_list == "8.9" + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.9+PTX" + return [] + + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": "8.9"}, clear=False): + with patch.object(builder, "is_compatible", return_value=True): + with patch.object(CUDAOpBuilder, "is_rocm_pytorch", return_value=False): + with patch.object(CUDA_API, "is_available", return_value=True): + with patch("torch.utils.cpp_extension.verify_ninja_availability", return_value=None): + with patch.object(builder, "nvcc_args", side_effect=successful_nvcc_args): + with patch.object(builder, "cxx_args", return_value=[]): + with patch("torch.utils.cpp_extension.load", return_value=op_module): + assert builder.jit_load(verbose=False) is op_module + + assert os.environ["TORCH_CUDA_ARCH_LIST"] == "8.9" + assert getattr(builder, "_jit_arch_list", None) is None + assert builder.jit_mode is False + + +def test_non_jit_branch_unchanged(): + builder = make_builder(jit_mode=False) + + with patch.dict(os.environ, {"TORCH_CUDA_ARCH_LIST": "8.0;8.9+PTX"}, clear=False): + args = builder.compute_capability_args() + + assert args == [ + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_89,code=sm_89", + "-gencode=arch=compute_89,code=compute_89", + ]