Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4526,6 +4526,18 @@ def is_deepcompile_active(self) -> bool:
def is_compiled(self) -> bool:
return self._is_compiled

def _refine_include_states(self, include: Container[OffloadStateTypeEnum]) -> Container[OffloadStateTypeEnum]:
    """Resolve which states an offload request should actually cover.

    Args:
        include: Requested states, or ``None`` to request every member of
            ``OffloadStateTypeEnum``.

    Returns:
        The requested states unchanged when ``zero_use_cpu_optimizer()`` is
        false. Otherwise a new list without ``hp_params`` and
        ``optim_states``, and additionally without ``lp_grads`` when
        ``zero_optimization_partition_weights()`` is true.
    """
    requested = list(OffloadStateTypeEnum) if include is None else include

    if not self.zero_use_cpu_optimizer():
        return requested

    # NOTE(review): presumably these states already live on the host when a
    # CPU optimizer is configured, so moving them again is skipped - confirm.
    skipped = [OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.optim_states]
    if self.zero_optimization_partition_weights():
        skipped.append(OffloadStateTypeEnum.lp_grads)

    return [state for state in requested if state not in skipped]

def offload_states(self,
include: Container[OffloadStateTypeEnum] = None,
device: OffloadDeviceEnum = OffloadDeviceEnum.cpu,
Expand All @@ -4539,8 +4551,7 @@ def offload_states(self,
pin_memory: Optional. Whether to pin the memory of the offloaded states.
non_blocking: Optional. Whether to offload the states asynchronously.
"""
opt_offload_config = self.zero_offload_optimizer()
assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states."
include = self._refine_include_states(include)
param_offload_config = self.zero_offload_param()
assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters."

Expand Down
5 changes: 3 additions & 2 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3269,12 +3269,13 @@ def offload_states(self,

self.empty_partition_cache()

assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, "Offloading is supported only for DeepSpeed FusedAdam."

def needs_offload(target):
# return True
return target not in self.offloaded_states and (include == None or target in include)

if needs_offload(OffloadStateTypeEnum.optim_states) or needs_offload(OffloadStateTypeEnum.hp_params):
assert self.optimizer.__class__ == deepspeed.ops.adam.fused_adam.FusedAdam, "Offloading is supported only for DeepSpeed FusedAdam."

# HP param
if needs_offload(OffloadStateTypeEnum.hp_params):
if pin_memory:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ def validate_grad_device(model, device: torch.device) -> None:
assert p.grad.device.type == device.type, f"Gradient partition on hp_param.grad is on {p.grad.device}, expected {device}"


def is_offload_optimizer_enabled(config_dict):
    """Return True when the config statically offloads the optimizer.

    Reads ``config_dict["zero_optimization"]["offload_optimizer"]["device"]``.
    Any missing level, or a level explicitly set to ``None``, counts as
    "not enabled". A device of ``"none"`` also counts as "not enabled",
    since that value selects no offload target.

    Args:
        config_dict: DeepSpeed configuration dictionary.

    Returns:
        bool: True only if an offload device other than ``"none"`` is set.
    """
    # `or {}` tolerates sections that are present but explicitly None.
    zero_config = config_dict.get("zero_optimization") or {}
    offload_config = zero_config.get("offload_optimizer") or {}
    device = offload_config.get("device")
    # NOTE(review): assumes an enum-valued device compares equal to its
    # string value (str-based enum) - confirm against OffloadDeviceEnum.
    return device is not None and device != "none"


def is_only_offload_optimizer_states(offloaded_states, optimizer_offload_states):
    """Tell whether every requested state is one of the optimizer-side states.

    Args:
        offloaded_states: Iterable of requested states, or ``None``.
        optimizer_offload_states: Iterable of states to compare against.

    Returns:
        bool: False when ``offloaded_states`` is ``None``; otherwise True
        exactly when ``offloaded_states`` contains nothing outside
        ``optimizer_offload_states`` (an empty request therefore yields True).
    """
    if offloaded_states is not None:
        # Subset test: no requested state falls outside the optimizer group.
        return set(offloaded_states) <= set(optimizer_offload_states)
    return False


def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloaded_states, pin_memory, non_blocking):
"""
This function runs a training step, offloads states, reloads them, and verifies correctness for ZeRO-1/2.
Expand All @@ -65,6 +77,10 @@ def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloa
offload_device = OffloadDeviceEnum.cpu
offload_torch_device = torch.device(offload_device.value)
accelerator_device = torch.device(get_accelerator().current_device_name())
optimizer_device = offload_torch_device if is_offload_optimizer_enabled(config_dict) else accelerator_device
offload_only_optimizer_states = is_only_offload_optimizer_states(
offloaded_states, [OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.hp_params])
expect_memory_change = not (is_offload_optimizer_enabled(config_dict) and offload_only_optimizer_states)

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=param_groups, config=config_dict)

Expand All @@ -90,21 +106,25 @@ def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloa
# Gradients exist only between backward() and step(). We must test them here.
grads_expected = [[g.clone().detach() for g in grad_list]
for grad_list in model.optimizer.averaged_gradients.values() if grad_list is not None]
grad_numel = sum(sum(g.numel() for g in grad_list) for grad_list in grads_expected)

alloc_before_offload = get_accelerator().memory_allocated()
model.offload_states(include=offloaded_states,
device=offload_device,
pin_memory=pin_memory,
non_blocking=non_blocking)
alloc_after_offload = get_accelerator().memory_allocated()
assert alloc_after_offload < alloc_before_offload, "FAIL: Allocated memory for grads should decrease after offload"
validate_grad_device(model, offload_torch_device)

if grad_numel > 0:
assert alloc_after_offload < alloc_before_offload, f"FAIL: Allocated memory for grads should decrease after offload {alloc_after_offload=} < {alloc_before_offload=}"
validate_grad_device(model, offload_torch_device)

model.reload_states()
alloc_after_reload = get_accelerator().memory_allocated()

assert alloc_after_reload > alloc_after_offload, "FAIL: Allocated memory for grads should increase after reload"
validate_grad_device(model, accelerator_device)
if grad_numel > 0:
assert alloc_after_reload > alloc_after_offload, f"FAIL: Allocated memory for grads should increase after reload {alloc_after_reload=} > {alloc_after_offload=}"
validate_grad_device(model, accelerator_device)

reloaded_grads = [
grad_list for grad_list in model.optimizer.averaged_gradients.values() if grad_list is not None
Expand Down Expand Up @@ -140,7 +160,9 @@ def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloa
pin_memory=pin_memory,
non_blocking=non_blocking)
alloc_after_offload = get_accelerator().memory_allocated()
assert alloc_after_offload < alloc_before_offload, f"FAIL: Allocated memory for persistent state {offloaded_states} should decrease after offload"

if expect_memory_change:
assert alloc_after_offload < alloc_before_offload, f"FAIL: Allocated memory for persistent state {offloaded_states} should decrease after offload"

if offloaded_states is None or OffloadStateTypeEnum.lp_params in offloaded_states:
validate_lp_params_device(model, offload_torch_device)
Expand All @@ -151,7 +173,9 @@ def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloa

model.reload_states()
alloc_after_reload = get_accelerator().memory_allocated()
assert alloc_after_reload > alloc_after_offload, f"FAIL: Allocated memory for persistent state {offloaded_states} should increase after reload"

if expect_memory_change:
assert alloc_after_reload > alloc_after_offload, f"FAIL: Allocated memory for persistent state {offloaded_states} should increase after reload"

# --- Verify restored data integrity ---
for expected, restored in zip(lp_params_expected, model.parameters()):
Expand All @@ -176,8 +200,8 @@ def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloa

# --- FINAL VALIDATION FOR ALL TESTS ---
validate_lp_params_device(model, accelerator_device)
validate_hp_params_device(model, accelerator_device)
validate_adam_states_device(model, accelerator_device)
validate_hp_params_device(model, optimizer_device)
validate_adam_states_device(model, optimizer_device)

assert torch.any(torch.ne(list(model.parameters())[0], 0.0))

Expand All @@ -189,10 +213,12 @@ def run_model_zero12(model, param_groups, config_dict, hidden_dim, dtype, offloa
@pytest.mark.parametrize("pin_memory", [False, True])
@pytest.mark.parametrize("non_blocking", [False, True])
@pytest.mark.parametrize("zero_stage", [1, 2])
class TestOffloadStatesZero12(DistributedTest):
@pytest.mark.parametrize("static_offload_optimizer", [False, True])
class TestDynamicOffloadStatesZero12(DistributedTest):
world_size = 2

def test_offload_states_zero12(self, included_state, pin_memory, non_blocking, zero_stage):
def test_dynamic_offload_states_zero12(self, included_state, pin_memory, non_blocking, zero_stage,
static_offload_optimizer):
hidden_dim = 1024
config_dict = {
"train_micro_batch_size_per_gpu": 1,
Expand All @@ -209,6 +235,8 @@ def test_offload_states_zero12(self, included_state, pin_memory, non_blocking, z
"enabled": True
}
}
if static_offload_optimizer:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"}
model = SimpleModel(hidden_dim, nlayers=4)
param_groups = [{
"params": [p for n, p in model.named_parameters() if 'bias' not in n],
Expand All @@ -227,24 +255,47 @@ def test_offload_states_zero12(self, included_state, pin_memory, non_blocking, z
# ==============================================================================


def validate_device(model, device: torch.device, offloaded_states) -> None:
def validate_device(model, state_device: dict[OffloadStateTypeEnum, torch.device], offloaded_states) -> None:

def compare_device(state) -> bool:
devices = get_state_devices(model, state)
return len(devices) == 1 and device in devices
return len(devices) == 1 and state_device[state] in devices

for state in OffloadStateTypeEnum:
if offloaded_states is None or state in offloaded_states:
if state == OffloadStateTypeEnum.contiguous_grad_buffer and device == torch.device("cpu"):
if state == OffloadStateTypeEnum.contiguous_grad_buffer and state_device[state] == torch.device("cpu"):
assert len(get_state_devices(model,
state)) == 0, f"State {state} must be removed after offload_states()"
else:
assert compare_device(state), f"State {state} is not on device {device}"
assert compare_device(state), f"State {state} is not on device {state_device[state]}"


def run_model_zero3(model, param_groups, config_dict, hidden_dim, dtype, offloaded_states, pin_memory, non_blocking):
# Currently we only support OffloadDeviceEnum.cpu
offload_device = OffloadDeviceEnum.cpu
offload_torch_device = torch.device(offload_device.value)
accelerator_device = torch.device(get_accelerator().current_device_name())
optimizer_device = offload_torch_device if is_offload_optimizer_enabled(config_dict) else accelerator_device
offload_only_optimizer_states = is_only_offload_optimizer_states(
offloaded_states,
[OffloadStateTypeEnum.optim_states, OffloadStateTypeEnum.hp_params, OffloadStateTypeEnum.lp_grads])
expect_memory_change = not (is_offload_optimizer_enabled(config_dict) and offload_only_optimizer_states)

offload_state_device: dict[OffloadStateTypeEnum, torch.device] = {
OffloadStateTypeEnum.hp_params: offload_torch_device,
OffloadStateTypeEnum.lp_params: offload_torch_device,
OffloadStateTypeEnum.optim_states: offload_torch_device,
OffloadStateTypeEnum.lp_grads: offload_torch_device,
OffloadStateTypeEnum.contiguous_grad_buffer: offload_torch_device,
}

reload_state_device: dict[OffloadStateTypeEnum, torch.device] = {
OffloadStateTypeEnum.hp_params: optimizer_device,
OffloadStateTypeEnum.lp_params: accelerator_device,
OffloadStateTypeEnum.optim_states: optimizer_device,
OffloadStateTypeEnum.lp_grads: optimizer_device,
OffloadStateTypeEnum.contiguous_grad_buffer: accelerator_device,
}

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=param_groups, config=config_dict)
data_loader = random_dataloader(model=model,
Expand All @@ -271,14 +322,17 @@ def run_model_zero3(model, param_groups, config_dict, hidden_dim, dtype, offload
pin_memory=pin_memory,
non_blocking=non_blocking)
alloc_after_offload = get_accelerator().memory_allocated()
assert alloc_after_offload < alloc_before_offload, "Allocated memory should decrease after offload"

validate_device(model, torch.device(offload_device.value), offloaded_states)
if expect_memory_change:
assert alloc_after_offload < alloc_before_offload, f"FAIL: Allocated memory should decrease after offload {alloc_after_offload=} < {alloc_before_offload=}"
validate_device(model, offload_state_device, offloaded_states)

# Reload states
model.reload_states()
assert alloc_after_offload < get_accelerator().memory_allocated(
), "Allocated memory should increase after offload back"
alloc_after_reload = get_accelerator().memory_allocated()

if expect_memory_change:
assert alloc_after_reload > alloc_after_offload, f"FAIL: Allocated memory should increase after offload back {alloc_after_reload=} > {alloc_after_offload=}"

# Verify restored states
hp_param_restored = [safe_get_local_fp32_param(p) for p in model.parameters()]
Expand All @@ -300,7 +354,7 @@ def run_model_zero3(model, param_groups, config_dict, hidden_dim, dtype, offload
for adam_exp_avg_sq_expected, adam_exp_avg_sq_restored in zip(adam_exp_avg_sq, adam_exp_avg_sq_restored):
assert torch.equal(adam_exp_avg_sq_expected, adam_exp_avg_sq_restored)

validate_device(model, torch.device(get_accelerator().current_device_name()), offloaded_states)
validate_device(model, reload_state_device, offloaded_states)

# Needed in ZeRO 3. Not doing so can give memory leak
model.destroy()
Expand All @@ -312,11 +366,12 @@ def run_model_zero3(model, param_groups, config_dict, hidden_dim, dtype, offload
])
@pytest.mark.parametrize("pin_memory", [False, True])
@pytest.mark.parametrize("non_blocking", [False, True])
class TestOffloadStatesZero3(DistributedTest):
@pytest.mark.parametrize("static_offload_optimizer", [False, True])
class TestDynamicOffloadStatesZero3(DistributedTest):
# Need multiple gpus to test possible hanging
world_size = 2

def test_offload_states_zero3(self, included_state, pin_memory, non_blocking):
def test_dynamic_offload_states_zero3(self, included_state, pin_memory, non_blocking, static_offload_optimizer):
hidden_dim = 1024

config_dict = {
Expand All @@ -332,6 +387,8 @@ def test_offload_states_zero3(self, included_state, pin_memory, non_blocking):
}
}
config_dict["bf16"] = {"enabled": True}
if static_offload_optimizer:
config_dict["zero_optimization"]["offload_optimizer"] = {"device": "cpu"}

with deepspeed.zero.Init(config_dict_or_path=config_dict):
model = SimpleModel(hidden_dim, nlayers=4)
Expand Down
Loading