From 001f77c363710e3f62e05c5aacbed4b2ff7c8c97 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:30:00 +0000 Subject: [PATCH 1/5] Initial plan From b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:36:18 +0000 Subject: [PATCH 2/5] Revert "fix: update 1 file reformatted." This reverts commit ff886701c392ab03863c227de14fbe1d671d4173. Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com> --- deepspeed/runtime/zero/stage_1_and_2.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 107e47a44042..183fd077f8a9 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -283,11 +283,18 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 + # Check for Muon optimizer usage + self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) + if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + + # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) + if self.reduce_scatter and self.uses_muon: + assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1187,7 +1194,9 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - if not self.reduce_scatter: + # Check if current configuration requires full all-reduce + if not self.reduce_scatter or any(self.group_uses_muon): + # Force full all-reduce for Muon parameters or when reduce_scatter is disabled self.gradient_reduction_w_predivide(tensor, communication_data_type) return From cbc816c90f4bd6e10ab5b67f4d471002ade8cba7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:40:53 +0000 Subject: [PATCH 3/5] Initial plan From 5fcc9a7e4bf58b1d935dcfeab53143d3cf9dbdf7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:43:32 +0000 Subject: [PATCH 4/5] Reapply "fix: update 1 file reformatted." This reverts commit b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35. --- deepspeed/runtime/zero/stage_1_and_2.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 1efea00bcbbd..12f97348a21f 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -284,18 +284,11 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 - # Check for Muon optimizer usage - self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) - if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" - - # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) - if self.reduce_scatter and self.uses_muon: - assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1224,9 +1217,7 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - # Check if current configuration requires full all-reduce - if not self.reduce_scatter or any(self.group_uses_muon): - # Force full all-reduce for Muon parameters or when reduce_scatter is disabled + if not self.reduce_scatter: self.gradient_reduction_w_predivide(tensor, communication_data_type) return From b561325e1bddd98f65154ca30ad43d3024b2681e Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Mon, 30 Mar 2026 10:14:49 +0000 Subject: [PATCH 5/5] refactor(module_inject): consolidate duplicate transpose functions Signed-off-by: nathon-lee --- deepspeed/module_inject/load_checkpoint.py | 9 +-------- deepspeed/module_inject/policy.py | 11 +---------- deepspeed/module_inject/utils.py | 10 ++++++++++ 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 0dba0359f88b..280897e3617a 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -17,6 +17,7 @@ import gc from deepspeed.accelerator import get_accelerator import re +from .utils import transpose def load_model_with_checkpoint(r_module, @@ -42,14 +43,6 @@ def prefix_check(): skip_level_0_prefix = prefix_check() and container.policy.use_load_prefix - def transpose(data): - with torch.no_grad(): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None - return data.reshape(data.shape[-1], data.shape[-2]) - def load(module, prefix): args = (sd[0], prefix, {}, True, [], [], error_msgs) diff --git a/deepspeed/module_inject/policy.py b/deepspeed/module_inject/policy.py index 41df2b85dc0c..dff12b6c64b3 100644 --- a/deepspeed/module_inject/policy.py +++ b/deepspeed/module_inject/policy.py @@ -7,6 +7,7 @@ from deepspeed.utils.types import ActivationFuncType, NormType import torch from deepspeed.accelerator import get_accelerator +from .utils import transpose transformer_param_names = ( 'attn_qkvw', \ @@ -109,16 +110,6 @@ def layernorm(self): raise NotImplementedError -# TODO (lekurile): This function exists in base container as well, consolidate as some point -def transpose(data): - with torch.no_grad(): - data = data.contiguous() - data1 = data.transpose(-1, -2).reshape(-1) - data.reshape(-1).copy_(data1) - data1 = None - return data.reshape(data.shape[-1], data.shape[-2]) - - # TODO (lekurile): This function exists in megatron feature container as well, consolidate as some point def _transpose(x, heads=1, mp_replace=None): heads = heads // mp_replace.mp_size # type: ignore diff --git a/deepspeed/module_inject/utils.py b/deepspeed/module_inject/utils.py index 42822128f9e1..1837063ec63f 100644 --- a/deepspeed/module_inject/utils.py +++ b/deepspeed/module_inject/utils.py @@ -3,9 +3,19 @@ # DeepSpeed Team +import torch from deepspeed.utils import log_dist +def transpose(data): + with torch.no_grad(): + data = data.contiguous() + data1 = data.transpose(-1, -2).reshape(-1) + data.reshape(-1).copy_(data1) + data1 = None + return data.reshape(data.shape[-1], data.shape[-2]) + + # helper function to map between DS policies and DS containers def policy_to_ds_container(**kwargs): from .containers import HFGPT2LayerPolicy, DS_GPT2Container