Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/maxtext/checkpoint_conversion/inspect_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def print_structure(data_dict, output_file=""):
"""Utility to format and print sorted keys and shapes from a flattened dictionary."""
if output_file:
# Save command
save_lines = [f"# {" ".join(sys.orig_argv)}", ""]
save_lines = [f"# {' '.join(sys.orig_argv)}", ""]

for key in sorted(data_dict.keys(), key=natural_sort_key):
line = f"key: {key} | {data_dict[key]}"
Expand Down
219 changes: 74 additions & 145 deletions src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any, Optional

from absl import flags
import contextlib
import datetime
from etils import epath
from flax import nnx
Expand Down Expand Up @@ -661,14 +660,12 @@ def _restore_grain_iterator(
if isinstance(data_iterator, RemoteIteratorWrapper):
grain_restore_args = GrainCheckpointRestore(item=data_iterator)
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
_assert_no_shaped_dtype_struct(restored_state)
return (restored_state, None)

# ElasticIterator: one shared `process_0.json` regardless of shard count.
if not isinstance(data_iterator, list) and isinstance(data_iterator.local_iterator, ElasticIterator):
grain_restore_args = GrainCheckpointRestore(item=data_iterator.local_iterator)
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
_assert_no_shaped_dtype_struct(restored_state)
return (restored_state, None)

directory = checkpoint_manager.directory / str(step) / "iter"
Expand Down Expand Up @@ -717,67 +714,9 @@ def _restore_grain_iterator(

# Call restore once with the composed arguments
restored_state = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args, iter=grain_restore_args))
_assert_no_shaped_dtype_struct(restored_state)
return (restored_state, None)


def is_structural_or_shape_mismatch(e: Exception) -> bool:
"""Helper to check if an exception is likely a PyTree structure or shape mismatch."""
if not isinstance(e, (ValueError, TypeError)):
return False
msg = str(e).lower()
mismatch_keywords = [
"mismatch",
"structure",
"shape",
"tree",
"leaf",
"leaves",
"paths matched",
"shapedtypestruct",
"invalid type",
]
return any(kw in msg for kw in mismatch_keywords)


def _assert_no_shaped_dtype_struct(pytree):
"""Asserts that there are no jax.ShapeDtypeStruct leaves in the restored pytree."""
if isinstance(pytree, jax.ShapeDtypeStruct):
raise ValueError(
"Some parameters in the restored state remained as ShapeDtypeStruct"
f" (indicating structure mismatch): {pytree}."
)

if hasattr(pytree, "keys") and hasattr(pytree, "__getitem__"):
for k in pytree.keys():
_assert_no_shaped_dtype_struct(pytree[k])
elif isinstance(pytree, (list, tuple)):
for v in pytree:
_assert_no_shaped_dtype_struct(v)
else:
leaves = jax.tree_util.tree_leaves(pytree)
if len(leaves) == 1 and leaves[0] is pytree:
return
for leaf in leaves:
_assert_no_shaped_dtype_struct(leaf)


@contextlib.contextmanager
def handle_checkpoint_mismatch(context_name: str, path: str):
"""Context manager to intercept PyTree/shape mismatches and raise descriptive errors."""
try:
yield
except Exception as e:
if is_structural_or_shape_mismatch(e):
raise ValueError(
f"Failed to {context_name} from {path}. This is often caused by a"
" mismatch in the 'scan_layers' configuration (stacked vs unstacked)"
" between your current execution command and the saved checkpoint."
f" Original error: {e}"
) from e
raise


def load_state_if_possible(
checkpoint_manager: CheckpointManager | None,
data_iterator: MultiHostDataLoadIterator | list[MultiHostDataLoadIterator] | None,
Expand Down Expand Up @@ -860,30 +799,26 @@ def map_to_pspec(data):
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
):
checkpoint_path = str(checkpoint_manager.directory / str(step) / "items")
with handle_checkpoint_mismatch("restore NNX checkpoint", checkpoint_path):
restored_nnx = _load_linen_checkpoint_into_nnx(
checkpoint_path,
abstract_unboxed_pre_state,
checkpoint_storage_concurrent_gb,
use_ocdbt,
use_zarr3,
)
_assert_no_shaped_dtype_struct(restored_nnx)
restored_nnx = _load_linen_checkpoint_into_nnx(
checkpoint_path,
abstract_unboxed_pre_state,
checkpoint_storage_concurrent_gb,
use_ocdbt,
use_zarr3,
)
return ({"items": restored_nnx}, None)

if isinstance(abstract_unboxed_pre_state, nnx.State) and isinstance(
checkpoint_manager,
(EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager),
):
checkpoint_path = str(checkpoint_manager.directory / str(step))
with handle_checkpoint_mismatch("restore emergency NNX checkpoint", checkpoint_path):
restored = _restore_emergency_linen_checkpoint_into_nnx(
checkpoint_manager,
step,
abstract_unboxed_pre_state,
map_to_pspec,
)
_assert_no_shaped_dtype_struct(restored)
restored = _restore_emergency_linen_checkpoint_into_nnx(
checkpoint_manager,
step,
abstract_unboxed_pre_state,
map_to_pspec,
)
return (
restored,
None,
Expand All @@ -902,49 +837,46 @@ def map_to_pspec(data):
)

checkpoint_path = str(checkpoint_manager.directory / str(step))
with handle_checkpoint_mismatch("restore checkpoint", checkpoint_path):
match (checkpoint_manager, dataset_type, data_iterator):
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
# or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
# 'data_iterator' can be any value and aren't used in this pattern.
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager,
(
EmergencyCheckpointManager,
EmergencyReplicatorCheckpointManager,
),
):
restored = checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state
_assert_no_shaped_dtype_struct(restored)
return (
restored,
None,
)
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
# PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
case (
match (checkpoint_manager, dataset_type, data_iterator):
# Case 1: Matches if 'checkpoint_manager' is an instance of either EmergencyCheckpointManager
# or EmergencyReplicatorCheckpointManager. The '_' indicates that 'dataset_type' and
# 'data_iterator' can be any value and aren't used in this pattern.
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager,
(
EmergencyCheckpointManager,
EmergencyReplicatorCheckpointManager,
),
):
restored = checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state
return (
restored,
None,
)
# Case 2: Matches if dataset type is "grain" and the data iterator is not a
# PlaceHolderDataIterator and a specific checkpoint file exists for the iterator
case (
checkpoint_manager,
dataset_type,
data_iterator,
) if (
dataset_type == "grain"
and data_iterator
and not isinstance(data_iterator, PlaceHolderDataIterator)
and (checkpoint_manager.directory / str(step) / "iter").exists()
):
return _restore_grain_iterator(
checkpoint_manager,
dataset_type,
step,
data_iterator,
) if (
dataset_type == "grain"
and data_iterator
and not isinstance(data_iterator, PlaceHolderDataIterator)
and (checkpoint_manager.directory / str(step) / "iter").exists()
):
return _restore_grain_iterator(
checkpoint_manager,
step,
data_iterator,
checkpoint_args,
expansion_factor_real_data,
)
# Case 3: Default/Fallback case.
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
case _:
restored = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args))
_assert_no_shaped_dtype_struct(restored)
return (restored, None)
checkpoint_args,
expansion_factor_real_data,
)
# Case 3: Default/Fallback case.
# This case acts as a wildcard ('_') and matches if none of the preceding cases were met.
case _:
restored = checkpoint_manager.restore(step, args=Composite(items=checkpoint_args))
return (restored, None)

if source_checkpoint_layout == "safetensors_dynamic":
path = load_parameters_from_path or load_full_state_from_path
Expand All @@ -957,30 +889,26 @@ def map_to_pspec(data):
else:
params = abstract_unboxed_pre_state.params

with handle_checkpoint_mismatch("load parameters", load_parameters_from_path):
restored_params = load_params_from_path(
load_parameters_from_path,
params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
_assert_no_shaped_dtype_struct(restored_params)
restored_params = load_params_from_path(
load_parameters_from_path,
params,
checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
return None, restored_params
elif load_full_state_from_path != "":
max_logging.log(f"Loading full state from path: {load_full_state_from_path}")
with handle_checkpoint_mismatch("load full state", load_full_state_from_path):
restored_state = _load_full_state_from_path(
path=load_full_state_from_path,
abstract_unboxed_pre_state=abstract_unboxed_pre_state,
enable_orbax_v1=enable_orbax_v1,
checkpoint_conversion_fn=checkpoint_conversion_fn,
source_checkpoint_layout=source_checkpoint_layout,
checkpoint_storage_concurrent_gb=checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
_assert_no_shaped_dtype_struct(restored_state)
restored_state = _load_full_state_from_path(
path=load_full_state_from_path,
abstract_unboxed_pre_state=abstract_unboxed_pre_state,
enable_orbax_v1=enable_orbax_v1,
checkpoint_conversion_fn=checkpoint_conversion_fn,
source_checkpoint_layout=source_checkpoint_layout,
checkpoint_storage_concurrent_gb=checkpoint_storage_concurrent_gb,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
)
return {"items": restored_state}, None
else:
max_logging.log("No existing checkpoints found, not restoring checkpoint.")
Expand Down Expand Up @@ -1201,11 +1129,12 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

custom_metadata = None
if config and hasattr(config, "lora") and config.lora:
lora_rank = getattr(config.lora, "lora_rank", 0)
if lora_rank > 0 and hasattr(config.lora, "model_dump"):
custom_metadata = {"lora": config.lora.model_dump()}
custom_metadata = {}
if config:
if hasattr(config, "scan_layers"):
custom_metadata["scan_layers"] = config.scan_layers
if hasattr(config, "lora") and config.lora and getattr(config.lora, "lora_rank", 0) > 0:
custom_metadata["lora"] = config.lora.model_dump()

match (checkpoint_manager, config, data_iterator):
case (checkpoint_manager, _, _) if isinstance(
Expand Down
Loading
Loading