From 001f77c363710e3f62e05c5aacbed4b2ff7c8c97 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:30:00 +0000 Subject: [PATCH 1/5] Initial plan From b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 06:36:18 +0000 Subject: [PATCH 2/5] Revert "fix: update 1 file reformatted." This reverts commit ff886701c392ab03863c227de14fbe1d671d4173. Co-authored-by: nathon-lee <248585198+nathon-lee@users.noreply.github.com> --- deepspeed/runtime/zero/stage_1_and_2.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 107e47a44042..183fd077f8a9 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -283,11 +283,18 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 + # Check for Muon optimizer usage + self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) + if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" + + # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) + if self.reduce_scatter and self.uses_muon: + assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1187,7 +1194,9 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - if not self.reduce_scatter: + # Check if current configuration requires full all-reduce + if not self.reduce_scatter or any(self.group_uses_muon): + # Force full all-reduce for Muon parameters or when reduce_scatter is disabled self.gradient_reduction_w_predivide(tensor, communication_data_type) return From cbc816c90f4bd6e10ab5b67f4d471002ade8cba7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:40:53 +0000 Subject: [PATCH 3/5] Initial plan From 5fcc9a7e4bf58b1d935dcfeab53143d3cf9dbdf7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 6 Mar 2026 06:43:32 +0000 Subject: [PATCH 4/5] Reapply "fix: update 1 file reformatted." This reverts commit b90aee5a854d5d7b4d9e4c5c951b3c6d61a87c35. --- deepspeed/runtime/zero/stage_1_and_2.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 1efea00bcbbd..12f97348a21f 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -284,18 +284,11 @@ def _enforce_cpu_offload(): self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32 - # Check for Muon optimizer usage - self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params']) - if self.reduce_scatter and self.partition_gradients: valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32) assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'" assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled" assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled" - - # Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2) - if self.reduce_scatter and self.uses_muon: - assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer." # param flattened by groups self.bit16_groups = [] @@ -1224,9 +1217,7 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - # Check if current configuration requires full all-reduce - if not self.reduce_scatter or any(self.group_uses_muon): - # Force full all-reduce for Muon parameters or when reduce_scatter is disabled + if not self.reduce_scatter: self.gradient_reduction_w_predivide(tensor, communication_data_type) return From 7b11a9063ae1777debd9f29bbd70de5680deaf37 Mon Sep 17 00:00:00 2001 From: nathon-lee Date: Tue, 31 Mar 2026 07:04:42 +0000 Subject: [PATCH 5/5] refactor(zero3): factor out defragment method to zero utils Signed-off-by: nathon-lee refactor(zero3): factor out defragment method to zero utils Signed-off-by: nathon-lee --- deepspeed/runtime/zero/stage3.py | 40 ++++---------------------------- deepspeed/runtime/zero/utils.py | 33 ++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c4f19f43de4f..8f28ee4f8685 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -7,7 +7,7 @@ import gc import collections import itertools -from typing import Deque, Dict, Set, List, Tuple, Container, Optional +from typing import Deque, Dict, Set, List, Container, Optional from contextlib import contextmanager from dataclasses import dataclass, field @@ -21,13 +21,13 @@ from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3 -from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer +from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer, defragment from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus @@ -655,38 +655,6 @@ def get_lr(self): """Return the current learning rate.""" return self.optimizer.param_groups[0]["lr"] - # TODO. factor out to a utility outside of stage3 - @staticmethod - def defragment(tensors: List[Tensor]) -> Tensor: - """move provided tensors into a contiguous flat buffer, with some additional - measures taken to reduce memory fragmentation""" - assert len(set(t.dtype for t in tensors)) == 1 - assert len(set(t.device for t in tensors)) == 1 - - cpu_buffer = torch.empty(sum(p.numel() for p in tensors), - dtype=get_only_unique_item(t.dtype for t in tensors), - device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) - orig_device = get_only_unique_item(t.device for t in tensors) - - offset = 0 - for tensor, offset, tensor_numel in tensor_infos: - # move the tensor from device memory to host memory - cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) - tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - - gc.collect() - get_accelerator().empty_cache() - - # copy tensors (now flattened and contiguous) back to GPU - device_buffer = cpu_buffer.to(orig_device) - - # restore device tensors - for tensor, offset, tensor_numel in tensor_infos: - tensor.data = device_buffer.narrow(0, offset, tensor_numel) - - return device_buffer - def _get_param_coordinator(self): return self.parameter_offload.get_param_coordinator() @@ -834,7 +802,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): parameter_partitions = self._get_parameter_partitions() # We need to keep the reference to this buffer to make sure you can free it in `offload_states` - self.lp_param_buffer = __class__.defragment(parameter_partitions) + self.lp_param_buffer = defragment(parameter_partitions) self._set_fp16_partitioned_groups_flat() else: # partitioned params offloaded to CPU when not in use diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index faccddcb5309..139419563352 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -4,6 +4,7 @@ # DeepSpeed Team import os +import gc from typing import List, Tuple import torch @@ -15,6 +16,7 @@ from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import get_only_unique_item # ensure we only warn once, otherwise every iteration will trigger a warning warned = False @@ -200,3 +202,34 @@ def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch. offset += tensor_numel return tensor_infos + + +def defragment(tensors: List[torch.Tensor]) -> torch.Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[torch.Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor, offset, tensor_numel in tensor_infos: + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + gc.collect() + get_accelerator().empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer