Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from maxtext.utils import max_logging
from maxtext.utils import max_utils
from maxtext.utils.globals import HF_IDS
from maxtext.utils.lora_utils import sync_lora_metadata


flags.DEFINE_bool(
Expand Down Expand Up @@ -451,6 +452,9 @@ def main(argv: Sequence[str]) -> None:
if not load_parameters_path and not lora_restore_path:
raise ValueError("Either load_parameters_path or lora_restore_path must be specified.")

if lora_restore_path:
sync_lora_metadata(config)

# Load Maxtext checkpoint using Orbax (now smart enough to load both if present)
max_logging.log("\nLoading Orbax checkpoint(s)...")
start = time.time()
Expand Down
31 changes: 30 additions & 1 deletion src/maxtext/common/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,6 +1014,26 @@ def load_params_from_path(
return restored["params"]


def load_checkpoint_metadata(checkpoint_path: str | epath.Path) -> dict[str, Any]:
  """Loads custom metadata from an Orbax checkpoint at the specified path.

  This is a best-effort read: any exception raised while opening the
  checkpoint or reading its metadata is logged and treated as "no metadata"
  instead of being propagated to the caller.

  Args:
    checkpoint_path: Path to the checkpoint directory.

  Returns:
    A dictionary of custom metadata if found, otherwise an empty dictionary.
  """
  path = epath.Path(checkpoint_path)
  try:
    ckptr = ocp.StandardCheckpointer()
    try:
      metadata = ckptr.metadata(path)
    finally:
      # The checkpointer exists only for this one lookup; close it on both the
      # success and the failure path so it is not left open afterwards.
      ckptr.close()
    if metadata and metadata.custom_metadata:
      return metadata.custom_metadata
  except Exception as e:  # pylint: disable=broad-except
    max_logging.log(f"Unexpected error loading checkpoint metadata at {path}: {e}")
  return {}


def save_params_to_path(checkpoint_dir, params, use_ocdbt=True, use_zarr3=True):
"""Save decode params in checkpoint at specified path."""
assert checkpoint_dir, "checkpoint_dir is not defined."
Expand Down Expand Up @@ -1142,11 +1162,20 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator=
grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total))
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

if config:
custom_metadata = {}
if hasattr(config, "scan_layers"):
custom_metadata["scan_layers"] = config.scan_layers
if hasattr(config, "lora") and config.lora and getattr(config.lora, "lora_rank", 0) > 0:
custom_metadata["lora"] = config.lora.model_dump()

match (checkpoint_manager, config, data_iterator):
case (checkpoint_manager, _, _) if isinstance(
checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager)
):
replicator_error_handler(config)
return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
case _:
return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force)
return checkpoint_manager.save(
step, args=Composite(**save_args_composite), force=force, custom_metadata=custom_metadata
)
47 changes: 47 additions & 0 deletions src/maxtext/utils/lora_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,49 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar
)


def sync_lora_metadata(config: pyconfig.HyperParameters) -> None:
  """Reconciles the configured LoRA rank/alpha with the checkpoint's custom metadata.

  Does nothing when `config.lora.lora_restore_path` is unset or when the
  checkpoint metadata carries no "lora" entry. Otherwise, for each of
  `lora_rank` and `lora_alpha`:

  * a configured value that is neither the default (0 / 0.0) nor equal to the
    value stored in the metadata raises a ValueError;
  * after both checks pass, the config is overwritten with the metadata value
    (a key missing from the metadata falls back to the configured value).

  Any non-ValueError exception is logged as a warning and swallowed.

  Args:
    config: Run configuration; `config.lora` is read and updated in place.

  Raises:
    ValueError: If a non-default configured value disagrees with the metadata.
  """
  restore_path = config.lora.lora_restore_path
  if not restore_path:
    return

  try:
    saved_lora = checkpointing.load_checkpoint_metadata(restore_path).get("lora")
    if not saved_lora:
      # The checkpoint recorded nothing about LoRA; leave the config untouched.
      return

    # (field name, default value that means "not set by the user").
    tracked_fields = (("lora_rank", 0), ("lora_alpha", 0.0))

    # Read every saved value up front, defaulting to what is configured now.
    saved_values = {name: saved_lora.get(name, getattr(config.lora, name)) for name, _ in tracked_fields}

    # Validate all fields before mutating the config.
    for name, unset_value in tracked_fields:
      configured = getattr(config.lora, name)
      if configured not in (unset_value, saved_values[name]):
        raise ValueError(
            f"Configured {name} ({configured}) does not match "
            f"checkpoint metadata {name} ({saved_values[name]}) at {restore_path}."
        )

    for name, _ in tracked_fields:
      setattr(config.lora, name, saved_values[name])

    max_logging.log(
        f"Synced LoRA parameters from Orbax metadata at {restore_path}: "
        f"rank={config.lora.lora_rank}, alpha={config.lora.lora_alpha}"
    )
  except ValueError:
    # Mismatches must fail the run, so they bypass the best-effort handler below.
    raise
  except Exception as e:  # pylint: disable=broad-except
    max_logging.log(f"Warning: Failed to load/sync LoRA metadata: {e}")


def apply_lora_to_model(
model: nnx.Module,
mesh: Optional[jax.sharding.Mesh],
Expand Down Expand Up @@ -585,6 +628,8 @@ def _safe_reshard(var, sharding_spec):
def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
"""Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run."""
lora_restore_path = mt_config.lora.lora_restore_path
if not lora_restore_path:
return trainer

train_steps = getattr(trainer, "train_steps", 0)
if train_steps > 0:
Expand All @@ -601,6 +646,8 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) ->
f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules."
)

sync_lora_metadata(mt_config)

abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam)

target_for_restore = jax.tree.map(
Expand Down
10 changes: 10 additions & 0 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common import checkpointing
from maxtext.common.checkpointing import handle_checkpoint_mismatch
from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN
from maxtext.configs import pyconfig
Expand Down Expand Up @@ -864,6 +865,15 @@ def from_pretrained(
}
)
config = pyconfig.HyperParameters(new_config)
# Proactive verification of scan_layers from checkpoint metadata
if config.load_parameters_path:
custom_metadata = checkpointing.load_checkpoint_metadata(config.load_parameters_path)
saved_scan_layers = custom_metadata.get("scan_layers")
if isinstance(saved_scan_layers, bool) and saved_scan_layers != config.scan_layers:
raise ValueError(
f"Configuration mismatch: Your run specifies scan_layers={config.scan_layers}, "
f"but the checkpoint was saved with scan_layers={saved_scan_layers}."
)

if config.pure_nnx:
_create_model, abstract_model = create_nnx_abstract_model(
Expand Down
165 changes: 156 additions & 9 deletions tests/post_training/unit/lora_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
"""Tests for Qwix LoRA utils in lora_utils.py"""
import re
import sys
import tempfile
import unittest
from unittest import mock

from etils import epath
import jax
import jax.numpy as jnp
import optax
import pytest
from flax import nnx
Expand All @@ -26,6 +30,7 @@
pytestmark = [pytest.mark.post_training]

# Now safe to do top-level imports
from maxtext.common import checkpointing
from tunix.sft import peft_trainer
from maxtext.utils import lora_utils
from maxtext.utils import model_creation_utils
Expand Down Expand Up @@ -59,11 +64,12 @@

def _make_config(**overrides):
  """Return a MaxTextConfig object suitable for unit tests.

  Args:
    **overrides: Config values layered on top of `_BASE_CONFIG`; a key given
      here replaces the base value of the same name.

  Returns:
    The object produced by `pyconfig.initialize_pydantic` for the merged settings.
  """
  # Merge into a single dict first: expanding `_BASE_CONFIG` and `overrides`
  # separately in the call would raise TypeError whenever they share a key.
  config_dict = _BASE_CONFIG.copy()
  config_dict.update(overrides)
  # Use initialize_pydantic to get nested models as objects (attribute access)
  return pyconfig.initialize_pydantic(
      [sys.argv[0], get_test_config_path()],
      **config_dict,
  )


Expand Down Expand Up @@ -121,7 +127,12 @@ def test_build_lora_provider(self):
with mock.patch("qwix.LoraProvider") as mock_provider:
lora_utils._build_lora_provider(mock_config)
mock_provider.assert_called_once_with(
module_path="custom/path", rank=8, alpha=16.0, dropout=0.0, weight_qtype="int8", tile_size=32
module_path="custom/path",
rank=8,
alpha=16.0,
dropout=0.0,
weight_qtype="int8",
tile_size=32,
)

def test_prepare_dummy_inputs(self):
Expand Down Expand Up @@ -173,7 +184,13 @@ def test_apply_lora_to_model_adapters_loaded(self):
# If we skip Qwix, it should stay False.
self.assertFalse(lora_utils.is_lora_enabled(result))

def _run_apply_lora_test(self, scan_layers: bool, weight_qtype=None, tile_size=None, mock_multihost: bool = False):
def _run_apply_lora_test(
self,
scan_layers: bool,
weight_qtype=None,
tile_size=None,
mock_multihost: bool = False,
):
"""Helper to run LoRA application test with/without scanned layers and optional QLoRA."""
# Passing nested dict as 'lora' kwarg to _make_config
cfg = _make_config(
Expand Down Expand Up @@ -246,7 +263,12 @@ def test_apply_lora_multihost_mock(self):
def test_restore_lora_from_path(self):
"""Test restoration of LoRA parameters from a path."""
cfg = _make_config(
lora={"enable_lora": True, "lora_restore_path": "some/path", "lora_rank": 4, "lora_alpha": 8.0},
lora={
"enable_lora": True,
"lora_restore_path": "some/path",
"lora_rank": 4,
"lora_alpha": 8.0,
},
scan_layers=False,
)
model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN)
Expand All @@ -271,6 +293,135 @@ def test_restore_lora_from_path(self):
self.assertTrue(kwargs["args"].partial_restore)
mock_update.assert_called_once()

def test_sync_lora_metadata_default_syncs(self):
  """Default (zero) rank/alpha in the config are replaced by the checkpoint metadata values."""
  lora_overrides = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 0,
      "lora_alpha": 0.0,
  }
  cfg = _make_config(lora=lora_overrides)

  # Stand-in for the object returned by the checkpointer's metadata() call.
  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)

  with metadata_patch:
    lora_utils.sync_lora_metadata(cfg)

  self.assertEqual(cfg.lora.lora_rank, 32)
  self.assertEqual(cfg.lora.lora_alpha, 64.0)

def test_sync_lora_metadata_matching_passes(self):
  """Non-default rank/alpha equal to the checkpoint metadata are accepted unchanged."""
  lora_overrides = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 32,
      "lora_alpha": 64.0,
  }
  cfg = _make_config(lora=lora_overrides)

  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)

  with metadata_patch:
    # Must complete without raising ValueError.
    lora_utils.sync_lora_metadata(cfg)

  self.assertEqual(cfg.lora.lora_rank, 32)
  self.assertEqual(cfg.lora.lora_alpha, 64.0)

def test_sync_lora_metadata_rank_mismatch_fails(self):
  """A configured rank that differs from the checkpoint metadata rank raises ValueError."""
  lora_overrides = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 8,
      "lora_alpha": 64.0,
  }
  cfg = _make_config(lora=lora_overrides)

  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)
  expect_rank_error = self.assertRaisesRegex(ValueError, "Configured lora_rank .* does not match")

  with metadata_patch, expect_rank_error:
    lora_utils.sync_lora_metadata(cfg)

def test_sync_lora_metadata_alpha_mismatch_fails(self):
  """A configured alpha that differs from the checkpoint metadata alpha raises ValueError."""
  lora_overrides = {
      "enable_lora": True,
      "lora_restore_path": "dummy/path",
      "lora_rank": 32,
      "lora_alpha": 16.0,
  }
  cfg = _make_config(lora=lora_overrides)

  fake_metadata = mock.MagicMock()
  fake_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}
  metadata_patch = mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=fake_metadata)
  expect_alpha_error = self.assertRaisesRegex(ValueError, "Configured lora_alpha .* does not match")

  with metadata_patch, expect_alpha_error:
    lora_utils.sync_lora_metadata(cfg)

def test_save_checkpoint_passes_metadata(self):
  """Test that save_checkpoint builds custom metadata and passes it to CheckpointManager.save.

  save_checkpoint records `scan_layers` whenever the config exposes that
  attribute, and the dumped LoRA settings when `lora_rank` > 0, so both
  entries are expected in the metadata handed to the manager.
  """
  cfg = _make_config(
      lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0},
      enable_checkpointing=True,
  )
  mock_manager = mock.MagicMock()
  mock_state = mock.MagicMock()

  with mock.patch("jax.block_until_ready"):
    checkpointing.save_checkpoint(mock_manager, step=10, state=mock_state, config=cfg)

  mock_manager.save.assert_called_once()
  _, kwargs = mock_manager.save.call_args
  self.assertIn("custom_metadata", kwargs)
  expected_metadata = {
      "scan_layers": cfg.scan_layers,
      "lora": cfg.lora.model_dump(),
  }
  self.assertEqual(kwargs["custom_metadata"], expected_metadata)

def test_save_and_restore_metadata_integration(self):
  """Round-trips LoRA metadata through a real Orbax CheckpointManager on a temporary directory."""
  save_cfg = _make_config(
      lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0},
      enable_checkpointing=True,
  )

  with tempfile.TemporaryDirectory() as tmpdir:
    manager_kwargs = {
        "enable_checkpointing": True,
        "use_async": False,
        "save_interval_steps": 1,
        "use_ocdbt": False,
        "use_zarr3": False,
    }
    manager = checkpointing.create_orbax_checkpoint_manager(tmpdir, **manager_kwargs)

    # Save a minimal state through the save_checkpoint wrapper.
    tiny_state = {"weight": jnp.array([1.0, 2.0])}
    checkpointing.save_checkpoint(manager, step=0, state=tiny_state, config=save_cfg)
    manager.wait_until_finished()

    # The step directory should now hold a metadata file on disk.
    step_dir = epath.Path(tmpdir) / "0"
    metadata_file = step_dir / "_CHECKPOINT_METADATA"
    self.assertTrue(metadata_file.exists())

    # Sync into a config that still has the default rank/alpha.
    restore_lora = {
        "enable_lora": True,
        "lora_restore_path": str(step_dir),
        "lora_rank": 0,
        "lora_alpha": 0.0,
    }
    restore_cfg = _make_config(lora=restore_lora)
    lora_utils.sync_lora_metadata(restore_cfg)

    # The saved values should have been copied into the restoring config.
    self.assertEqual(restore_cfg.lora.lora_rank, 8)
    self.assertEqual(restore_cfg.lora.lora_alpha, 16.0)

def test_gemma4_lora_path_matching(self):
"""Test that the Gemma4 LoRA regex correctly matches all expected parameter paths."""
mock_config = mock.MagicMock(spec=pyconfig.HyperParameters)
Expand Down Expand Up @@ -309,10 +460,6 @@ def test_gemma4_lora_path_matching(self):
"decoder/layers_remainder/layers/0/mlp/shared_experts/wi_0/kernel",
"decoder/layers_remainder/layers/0/mlp/shared_experts/wi_1/kernel",
"decoder/layers_remainder/layers/0/mlp/shared_experts/wo/kernel",
# No scanned_blocks/layers_remainder prefix (e.g. fallback or direct structure)
"decoder/layers/0/self_attention/query/kernel",
"decoder/layers/0/mlp/wi_0/kernel",
"decoder/layers/layers/0/mlp/shared_experts/wi_0/kernel",
]

for path in matching_paths:
Expand Down
Loading
Loading