From e21277eca88436cb4f3d0c74eee87668d37055e4 Mon Sep 17 00:00:00 2001 From: Cursx <674760201@qq.com> Date: Tue, 14 Apr 2026 23:29:53 +0800 Subject: [PATCH 1/3] fix(op_builder): avoid duplicate -gencode flags in JIT mode (#7972) In JIT mode, compute_capability_args() now sets TORCH_CUDA_ARCH_LIST to the detected GPU architectures and returns an empty list, letting PyTorch generate -gencode flags. Previously the env var was cleared to an empty string (which PyTorch treats as unset, triggering auto-detection) while DeepSpeed also added its own -gencode flags, resulting in duplicates. The jit_load() restore logic is also improved: if TORCH_CUDA_ARCH_LIST was not originally set, it is now removed from os.environ after build instead of being left as an empty string. Fixes #7972 Signed-off-by: Cursx <674760201@qq.com> --- op_builder/builder.py | 52 +++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 308f1822a58f..1313c64247b1 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -563,14 +563,12 @@ def jit_load(self, verbose=True): sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()] extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()] - # Torch will try and apply whatever CCs are in the arch list at compile time, - # we have already set the intended targets ourselves we know that will be - # needed at runtime. This prevents CC collisions such as multiple __half - # implementations. Stash arch list to reset after build. - torch_arch_list = None - if "TORCH_CUDA_ARCH_LIST" in os.environ: - torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") - os.environ["TORCH_CUDA_ARCH_LIST"] = "" + # Stash the original TORCH_CUDA_ARCH_LIST so we can restore it after build. + # In JIT mode, compute_capability_args() will set TORCH_CUDA_ARCH_LIST to + # the detected GPU architectures, letting PyTorch generate the -gencode + # flags. This avoids duplicate flags and the "TORCH_CUDA_ARCH_LIST is not + # set" warning. See https://github.com/deepspeedai/DeepSpeed/issues/7972 + torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") nvcc_args = self.strip_empty_entries(self.nvcc_args()) cxx_args = self.strip_empty_entries(self.cxx_args()) @@ -603,9 +601,12 @@ def jit_load(self, verbose=True): if verbose: print(f"Time to load {self.name} op: {build_duration} seconds") - # Reset arch list so we are not silently removing it for other possible use cases - if torch_arch_list: + # Restore the original TORCH_CUDA_ARCH_LIST so we are not silently + # modifying it for other possible use cases. + if torch_arch_list is not None: os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list + elif "TORCH_CUDA_ARCH_LIST" in os.environ: + del os.environ["TORCH_CUDA_ARCH_LIST"] __class__._loaded_ops[self.name] = op_module @@ -618,18 +619,22 @@ def compute_capability_args(self, cross_compile_archs=None): """ Returns nvcc compute capability compile flags. - 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. - 2. If neither is set default compute capabilities will be used - 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX + 1. Under ``jit_mode`` the visible-card architectures are detected, + ``TORCH_CUDA_ARCH_LIST`` is set accordingly, and an **empty list** + is returned so that PyTorch generates the ``-gencode`` flags + itself (avoiding duplicates). See + https://github.com/deepspeedai/DeepSpeed/issues/7972 + 2. ``TORCH_CUDA_ARCH_LIST`` takes priority over ``cross_compile_archs``. + 3. If neither is set default compute capabilities will be used. Format: - - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: + - ``TORCH_CUDA_ARCH_LIST`` may use ; or whitespace separators. Examples: TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ... TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ... - - `cross_compile_archs` uses ; separator. + - ``cross_compile_archs`` uses ; separator. """ ccs = [] @@ -662,17 +667,26 @@ def compute_capability_args(self, cross_compile_archs=None): raise RuntimeError( f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") - args = [] self.enable_bf16 = True + for cc in ccs: + if int(cc[0]) <= 7: + self.enable_bf16 = False + + if self.jit_mode: + # Let PyTorch handle -gencode flag generation via TORCH_CUDA_ARCH_LIST + # to avoid duplicate flags. Reconstruct the arch-list string from the + # (potentially filtered) ``ccs`` list. + arch_list = ";".join(f"{cc[0]}.{cc[1]}" for cc in ccs) + os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list + return [] + + args = [] for cc in ccs: num = cc[0] + cc[1].split('+')[0] args.append(f'-gencode=arch=compute_{num},code=sm_{num}') if cc[1].endswith('+PTX'): args.append(f'-gencode=arch=compute_{num},code=compute_{num}') - if int(cc[0]) <= 7: - self.enable_bf16 = False - return args def filter_ccs(self, ccs: List[str]): From 392d038ad0343406db90c6e5fa04dde38abb7106 Mon Sep 17 00:00:00 2001 From: Cursx <674760201@qq.com> Date: Wed, 15 Apr 2026 07:59:46 +0800 Subject: [PATCH 2/3] fix(op_builder): sync TORCH_CUDA_ARCH_LIST with filtered archs in non-JIT mode Extend the fix to non-JIT (setup.py) mode: compute_capability_args() now updates TORCH_CUDA_ARCH_LIST to the filtered arch list from filter_ccs() for both JIT and non-JIT paths. Each CUDAExtension still carries its own -gencode flags in extra_compile_args, but BuildExtension will no longer silently re-introduce archs that filter_ccs() removed. Signed-off-by: Cursx <674760201@qq.com> --- op_builder/builder.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 1313c64247b1..bb06706d1252 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -672,14 +672,25 @@ def compute_capability_args(self, cross_compile_archs=None): if int(cc[0]) <= 7: self.enable_bf16 = False + # Synchronise TORCH_CUDA_ARCH_LIST with the (potentially filtered) arch + # list so that PyTorch's BuildExtension / load() generates consistent + # -gencode flags. Without this, filter_ccs() removals are silently + # re-added by PyTorch reading the original, unfiltered env var. + # See https://github.com/deepspeedai/DeepSpeed/issues/7972 + arch_list = ";".join(f"{cc[0]}.{cc[1]}" for cc in ccs) + os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list + if self.jit_mode: - # Let PyTorch handle -gencode flag generation via TORCH_CUDA_ARCH_LIST - # to avoid duplicate flags. Reconstruct the arch-list string from the - # (potentially filtered) ``ccs`` list. - arch_list = ";".join(f"{cc[0]}.{cc[1]}" for cc in ccs) - os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list + # In JIT mode PyTorch's load() will read TORCH_CUDA_ARCH_LIST and + # generate the -gencode flags itself, so we return nothing here to + # avoid duplicates. return [] + # In non-JIT (setup.py) mode we still return explicit -gencode flags + # because each CUDAExtension needs its own per-builder flags in + # extra_compile_args. BuildExtension will also read the env var, which + # may cause harmless duplicates but will no longer reintroduce archs + # that filter_ccs() removed. args = [] for cc in ccs: num = cc[0] + cc[1].split('+')[0] From b5cdc893e1aa4e3ab73e7cea07cde91c2683a1dc Mon Sep 17 00:00:00 2001 From: Cursx <674760201@qq.com> Date: Wed, 15 Apr 2026 08:07:07 +0800 Subject: [PATCH 3/3] style: trim comments to match project conventions Signed-off-by: Cursx <674760201@qq.com> --- op_builder/builder.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index bb06706d1252..b2e42d4bd3a8 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -563,11 +563,7 @@ def jit_load(self, verbose=True): sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()] extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()] - # Stash the original TORCH_CUDA_ARCH_LIST so we can restore it after build. - # In JIT mode, compute_capability_args() will set TORCH_CUDA_ARCH_LIST to - # the detected GPU architectures, letting PyTorch generate the -gencode - # flags. This avoids duplicate flags and the "TORCH_CUDA_ARCH_LIST is not - # set" warning. See https://github.com/deepspeedai/DeepSpeed/issues/7972 + # Stash TORCH_CUDA_ARCH_LIST to restore after build. torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") nvcc_args = self.strip_empty_entries(self.nvcc_args()) @@ -601,8 +597,7 @@ def jit_load(self, verbose=True): if verbose: print(f"Time to load {self.name} op: {build_duration} seconds") - # Restore the original TORCH_CUDA_ARCH_LIST so we are not silently - # modifying it for other possible use cases. + # Restore TORCH_CUDA_ARCH_LIST to its original state. if torch_arch_list is not None: os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list elif "TORCH_CUDA_ARCH_LIST" in os.environ: @@ -672,25 +667,16 @@ def compute_capability_args(self, cross_compile_archs=None): if int(cc[0]) <= 7: self.enable_bf16 = False - # Synchronise TORCH_CUDA_ARCH_LIST with the (potentially filtered) arch - # list so that PyTorch's BuildExtension / load() generates consistent - # -gencode flags. Without this, filter_ccs() removals are silently - # re-added by PyTorch reading the original, unfiltered env var. - # See https://github.com/deepspeedai/DeepSpeed/issues/7972 + # Keep TORCH_CUDA_ARCH_LIST in sync with the filtered arch list so + # PyTorch does not re-add archs that filter_ccs() removed. arch_list = ";".join(f"{cc[0]}.{cc[1]}" for cc in ccs) os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list if self.jit_mode: - # In JIT mode PyTorch's load() will read TORCH_CUDA_ARCH_LIST and - # generate the -gencode flags itself, so we return nothing here to - # avoid duplicates. + # Let PyTorch generate -gencode flags from the env var. return [] - # In non-JIT (setup.py) mode we still return explicit -gencode flags - # because each CUDAExtension needs its own per-builder flags in - # extra_compile_args. BuildExtension will also read the env var, which - # may cause harmless duplicates but will no longer reintroduce archs - # that filter_ccs() removed. + # Non-JIT: return explicit flags per builder for extra_compile_args. args = [] for cc in ccs: num = cc[0] + cc[1].split('+')[0]