diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c4f19f43de4f..8f28ee4f8685 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -7,7 +7,7 @@ import gc import collections import itertools -from typing import Deque, Dict, Set, List, Tuple, Container, Optional +from typing import Deque, Dict, Set, List, Container, Optional from contextlib import contextmanager from dataclasses import dataclass, field @@ -21,13 +21,13 @@ from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_comm_dtypes, is_autocast_initialized, sort_dtypes from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, mask_nan_or_inf_with_val_inplace, count_used_parameters_in_backward from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload import deepspeed.runtime.zenflow.engine_stage3 as zf_engine_stage3 -from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer +from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer, defragment from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus @@ -655,38 +655,6 @@ def get_lr(self): """Return the current learning rate.""" return self.optimizer.param_groups[0]["lr"] - # TODO. factor out to a utility outside of stage3 - @staticmethod - def defragment(tensors: List[Tensor]) -> Tensor: - """move provided tensors into a contiguous flat buffer, with some additional - measures taken to reduce memory fragmentation""" - assert len(set(t.dtype for t in tensors)) == 1 - assert len(set(t.device for t in tensors)) == 1 - - cpu_buffer = torch.empty(sum(p.numel() for p in tensors), - dtype=get_only_unique_item(t.dtype for t in tensors), - device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) - orig_device = get_only_unique_item(t.device for t in tensors) - - offset = 0 - for tensor, offset, tensor_numel in tensor_infos: - # move the tensor from device memory to host memory - cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) - tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - - gc.collect() - get_accelerator().empty_cache() - - # copy tensors (now flattened and contiguous) back to GPU - device_buffer = cpu_buffer.to(orig_device) - - # restore device tensors - for tensor, offset, tensor_numel in tensor_infos: - tensor.data = device_buffer.narrow(0, offset, tensor_numel) - - return device_buffer - def _get_param_coordinator(self): return self.parameter_offload.get_param_coordinator() @@ -834,7 +802,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups): parameter_partitions = self._get_parameter_partitions() # We need to keep the reference to this buffer to make sure you can free it in `offload_states` - self.lp_param_buffer = __class__.defragment(parameter_partitions) + self.lp_param_buffer = defragment(parameter_partitions) self._set_fp16_partitioned_groups_flat() else: # partitioned params offloaded to CPU when not in use diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index faccddcb5309..139419563352 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -4,6 +4,7 @@ # DeepSpeed Team import os +import gc from typing import List, Tuple import torch @@ -15,6 +16,7 @@ from deepspeed.ops.lion import DeepSpeedCPULion, FusedLion from deepspeed.utils.nvtx import instrument_w_nvtx from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import get_only_unique_item # ensure we only warn once, otherwise every iteration will trigger a warning warned = False @@ -200,3 +202,34 @@ def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch. offset += tensor_numel return tensor_infos + + +def defragment(tensors: List[torch.Tensor]) -> torch.Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[torch.Tensor, int, int]] = get_mapping_to_flat_buffer(tensors) + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor, offset, tensor_numel in tensor_infos: + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + gc.collect() + get_accelerator().empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer