Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
001f77c
Initial plan
Copilot Feb 27, 2026
b90aee5
Revert "fix: update 1 file reformatted."
Copilot Feb 27, 2026
b6da9af
Merge pull request #5 from nathon-lee/copilot/git-revert-ff886701
nathon-lee Feb 27, 2026
bb7f64f
Merge branch 'deepspeedai:master' into master
nathon-lee Mar 6, 2026
cbc816c
Initial plan
Copilot Mar 6, 2026
5fcc9a7
Reapply "fix: update 1 file reformatted."
Copilot Mar 6, 2026
f7c5d75
Merge pull request #6 from nathon-lee/copilot/remove-commits-from-master
nathon-lee Mar 6, 2026
0513f4a
feat: Refactor AutoTP universal checkpoint metadata schema handling
nathon-lee Mar 11, 2026
6bfea51
fix: update unit test file test_autotp_universal_checkpoint.py
nathon-lee Mar 16, 2026
5ab684d
fix: update unit test file test_autotp_universal_checkpoint.py
nathon-lee Mar 16, 2026
90e30f1
Add constant for AutoTP universal-checkpoint metadata key
nathon-lee Mar 17, 2026
3f4ecc7
Remove redundant callable() guard for _mark_uc_metadata hook
nathon-lee Mar 17, 2026
2bf8402
tests: cover uneven sub_param_sizes in AutoTP UC restore
nathon-lee Mar 18, 2026
0f8e4ff
fix: update some logic for _resolve_autotp_partition
nathon-lee Mar 18, 2026
6c8510a
docs: update universal checkpointing and AutoTP checkpoint docs
nathon-lee Mar 18, 2026
4c81483
Merge branch 'master' into feat_uc_autotp
delock Mar 19, 2026
e09f5d1
tests: avoid pytest import file mismatch by renaming AutoTP UC test
nathon-lee Mar 23, 2026
7ac0316
fix: Automatically fix file end line breaks and formatting issues
nathon-lee Mar 23, 2026
a3f96ed
Merge branch 'master' into feat_uc_autotp
delock Mar 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepspeed/checkpoint/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
UNIVERSAL_CHECKPOINT_VERSION_KEY = 'universal_checkpoint_version'
# Reserve version 0.1 for the hardcoded logic used in BLOOM-176B training
UNIVERSAL_CHECKPOINT_VERSION_VALUE = 0.2
# Attribute name used to store AutoTP universal-checkpoint metadata on torch Parameters.
# The universal-checkpoint loader reads this attribute with a `None` default, so a
# parameter that does not carry it is treated as having no AutoTP metadata.
DS_AUTOTP_UC_META = "ds_autotp_universal_checkpoint_meta"

# Vocabulary padding
VOCAB_TENSOR = 'vocab_tensor'
Expand Down
169 changes: 123 additions & 46 deletions deepspeed/checkpoint/universal_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import types
from typing import List, Tuple, Union
from dataclasses import dataclass
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE)
from .constants import (FP32_WEIGHT_KEY, PARAM, VOCAB_TENSOR, CAT_DIM, PARAM_N_SUB_PARAMS, SUB_PARAM_SHAPE,
DS_AUTOTP_UC_META)


@dataclass
Expand All @@ -19,6 +20,82 @@ class SubparamShape:
partition_dim: int


def _get_param_uc_restore_meta(param):
    """Look up the AutoTP universal-checkpoint (UC) metadata attached to ``param``.

    The metadata object is shared by two independent consumers:
    - UC loading (this module), which reads the restore-time fields stored
      at the top level of the metadata
    - `collect_autotp_universal_checkpoint_info()` in `layers.py`, which
      reads the conversion-time fields stored under `conversion`

    Returns:
        Whatever is stored on ``param`` under the ``DS_AUTOTP_UC_META``
        attribute name, or ``None`` when the attribute is not present.
    """
    try:
        return getattr(param, DS_AUTOTP_UC_META)
    except AttributeError:
        # No AutoTP metadata was attached to this parameter.
        return None


def _slice_sub_params_for_rank(full_view, sub_sizes, partition_dim, tp_rank, tp_world_size):
    """Return this rank's flattened share of a parameter built from fused sub-parameters.

    ``full_view`` is walked along ``partition_dim`` in consecutive segments of
    the lengths given by ``sub_sizes``. Each segment is split into
    ``tp_world_size`` chunks, the chunk at index ``tp_rank`` is kept, and the
    kept chunks are concatenated back along ``partition_dim``.
    """
    offset = 0
    rank_chunks = []
    for sub_size in sub_sizes:
        segment = full_view.narrow(partition_dim, offset, sub_size)
        rank_chunks.append(segment.chunk(tp_world_size, dim=partition_dim)[tp_rank])
        offset += sub_size
    return torch.cat(rank_chunks, dim=partition_dim).flatten()


def _resolve_autotp_partition(current_param, ckpt_dict, full_hp_param, tp_rank, tp_world_size):
    """Compute the flattened TP slice of ``full_hp_param`` from AutoTP metadata.

    The metadata is obtained from ``current_param`` through
    ``_get_param_uc_restore_meta`` and the following keys are consulted:
    ``partition_dim``, ``logical_shape``, ``sub_param_shape``,
    ``sub_param_sizes`` and ``replicated``.

    Args:
        current_param: Parameter object that may carry AutoTP UC metadata.
        ckpt_dict: Loaded checkpoint dictionary. Not read by this function;
            kept so the call signature stays stable.
        full_hp_param: Full (unpartitioned) tensor loaded from the checkpoint.
        tp_rank: Index of the chunk to select for this rank.
        tp_world_size: Number of chunks each (sub-)parameter is split into.

    Returns:
        A flattened tensor holding this rank's slice, or ``None`` when the
        parameter has no (or empty) metadata, or when the metadata lacks
        ``partition_dim`` or ``logical_shape`` for a non-replicated parameter.
        ``None`` signals that the caller must pick the slice some other way.
    """
    meta = _get_param_uc_restore_meta(current_param)
    if not meta:
        return None

    partition_dim = meta.get('partition_dim')
    logical_shape = meta.get('logical_shape')
    sub_param_shape = meta.get('sub_param_shape')
    sub_param_sizes = meta.get('sub_param_sizes')

    if meta.get('replicated', False):
        # A replicated parameter is taken whole on every rank, so a partition
        # dimension would be contradictory.
        assert partition_dim is None, 'replicated AutoTP parameters must not define partition_dim'
        return full_hp_param.flatten()

    if partition_dim is None or logical_shape is None:
        return None

    full_view = full_hp_param.view(logical_shape)

    if sub_param_shape is not None:
        if hasattr(sub_param_shape, "shape") and hasattr(sub_param_shape, "partition_dim"):
            # Object carrying its own shape and partition dim (e.g. SubparamShape):
            # its partition_dim takes precedence over the top-level one.
            shape_spec = sub_param_shape.shape
            partition_dim = sub_param_shape.partition_dim
        else:
            shape_spec = sub_param_shape

        sub_dim_sizes = shape_spec[partition_dim]
        # Accept lists as well as tuples, matching the sub_param_sizes branch
        # below; a list wrapped into a 1-tuple would be passed to narrow() as a
        # length and fail.
        if not isinstance(sub_dim_sizes, (tuple, list)):
            sub_dim_sizes = (sub_dim_sizes, )

        return _slice_sub_params_for_rank(full_view, sub_dim_sizes, partition_dim, tp_rank, tp_world_size)

    if sub_param_sizes is not None:
        if not isinstance(sub_param_sizes, (tuple, list)):
            sub_param_sizes = (sub_param_sizes, )

        return _slice_sub_params_for_rank(full_view, sub_param_sizes, partition_dim, tp_rank, tp_world_size)

    # Plain case: a single split along the partition dimension.
    return full_view.chunk(tp_world_size, dim=partition_dim)[tp_rank].flatten()


def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
hp_mapping = self._hp_mapping
hp_mapping.optim_fragment = {}
Expand Down Expand Up @@ -73,52 +150,52 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
padding_size = padded_target_vocab_size - full_hp_param.shape[0]
full_hp_param = torch.nn.functional.pad(full_hp_param, (0, 0, 0, padding_size), "constant", 0)

full_param_numel = full_hp_param.numel()
tp_slice_numel = self.numel()
# if key == FP32_WEIGHT_KEY and 'word_embeddings.weight' in folder:
# print_rank_0(f'{full_hp_param[:10]=}', force=True)


assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'

# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")

sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
# since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse
# special case is when a single parameter is effectively a container for multiple sub parameters
# (more details at PARAM_N_SUB_PARAMS definition)
chunk_dim = ckpt_dict.get(CAT_DIM, 0)
n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
if sub_param_shape:
partition_dim = sub_param_shape.partition_dim
sub_dim_sizes = sub_param_shape.shape[partition_dim]
if not isinstance(sub_dim_sizes, tuple):
sub_dim_sizes = (sub_dim_sizes, )

partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
full_hp_param = full_hp_param.view(partition_shape)

offset = 0
merged_chunks = []
for sub_dim_size in sub_dim_sizes:
sub_params_tp_slice = full_hp_param.narrow(partition_dim,
offset, sub_dim_size).chunk(tp_world_size,
dim=partition_dim)[tp_rank]
merged_chunks.append(sub_params_tp_slice)
offset += sub_dim_size
tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)

elif n_sub_params > 1:
sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
autotp_tp_hp_slice = _resolve_autotp_partition(self, ckpt_dict, full_hp_param, tp_rank, tp_world_size)
if autotp_tp_hp_slice is not None:
tp_hp_slice = autotp_tp_hp_slice
else:
# this performs the opposite of cat when merging TP slices
tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]

tp_hp_slice = tp_hp_slice.flatten()
full_param_numel = full_hp_param.numel()
tp_slice_numel = self.numel()
assert full_param_numel == tp_world_size * tp_slice_numel, \
f'Loading {ckpt_file} full param numel {full_param_numel} != tensor slice numel {tp_slice_numel} * tp_world_size {tp_world_size}'

# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")

sub_param_shape = ckpt_dict.get(SUB_PARAM_SHAPE, None)
# since when we do many to 1 on tp we cat sometimes on dim=0 and other times on dim=1 we have to do exactly the same in reverse
# special case is when a single parameter is effectively a container for multiple sub parameters
# (more details at PARAM_N_SUB_PARAMS definition)
chunk_dim = ckpt_dict.get(CAT_DIM, 0)
n_sub_params = ckpt_dict.get(PARAM_N_SUB_PARAMS, 1)
if sub_param_shape:
partition_dim = sub_param_shape.partition_dim
sub_dim_sizes = sub_param_shape.shape[partition_dim]
if not isinstance(sub_dim_sizes, tuple):
sub_dim_sizes = (sub_dim_sizes, )

partition_shape = [sum(d) if isinstance(d, tuple) else d for d in sub_param_shape.shape]
full_hp_param = full_hp_param.view(partition_shape)

offset = 0
merged_chunks = []
for sub_dim_size in sub_dim_sizes:
sub_params_tp_slice = full_hp_param.narrow(partition_dim,
offset, sub_dim_size).chunk(tp_world_size,
dim=partition_dim)[tp_rank]
merged_chunks.append(sub_params_tp_slice)
offset += sub_dim_size
tp_hp_slice = torch.cat(merged_chunks, dim=partition_dim)

elif n_sub_params > 1:
sub_params = full_hp_param.chunk(n_sub_params, dim=chunk_dim)
sub_params_tp_slice = [p.chunk(tp_world_size, dim=chunk_dim)[tp_rank] for p in sub_params]
tp_hp_slice = torch.cat(sub_params_tp_slice, dim=chunk_dim)
else:
# this performs the opposite of cat when merging TP slices
tp_hp_slice = full_hp_param.chunk(tp_world_size, chunk_dim)[tp_rank]

tp_hp_slice = tp_hp_slice.flatten()

lp_frag_address = hp_mapping.lp_fragment_address
tp_hp_fragment = tp_hp_slice.narrow(0, lp_frag_address.start, lp_frag_address.numel)
Expand Down
Loading
Loading