diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index c52bd8192d..c3abdf60f4 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -97,6 +97,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils.globals import HF_IDS +from maxtext.utils.lora_utils import sync_lora_metadata flags.DEFINE_bool( @@ -451,6 +452,9 @@ def main(argv: Sequence[str]) -> None: if not load_parameters_path and not lora_restore_path: raise ValueError("Either load_parameters_path or lora_restore_path must be specified.") + if lora_restore_path: + sync_lora_metadata(config) + # Load Maxtext checkpoint using Orbax (now smart enough to load both if present) max_logging.log("\nLoading Orbax checkpoint(s)...") start = time.time() diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index c778f92bf9..c5395ac77f 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -1014,6 +1014,26 @@ def load_params_from_path( return restored["params"] +def load_checkpoint_metadata(checkpoint_path: str | epath.Path) -> dict[str, Any]: + """Loads custom metadata from an Orbax checkpoint at the specified path. + + Args: + checkpoint_path: Path to the checkpoint directory. + + Returns: + A dictionary of custom metadata if found, otherwise an empty dictionary. + """ + path = epath.Path(checkpoint_path) + try: + ckptr = ocp.StandardCheckpointer() + metadata = ckptr.metadata(path) + if metadata and metadata.custom_metadata: + return metadata.custom_metadata + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Unexpected error loading checkpoint metadata at {path}: {e}") + return {} + + def save_params_to_path(checkpoint_dir, params, use_ocdbt=True, use_zarr3=True): """Save decode params in checkpoint at specified path.""" assert checkpoint_dir, "checkpoint_dir is not defined." @@ -1142,6 +1162,13 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= grain_iters_to_save.append((data_iter.local_iterator, process_index, process_count_total)) save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save) + if config: + custom_metadata = {} + if hasattr(config, "scan_layers"): + custom_metadata["scan_layers"] = config.scan_layers + if hasattr(config, "lora") and config.lora and getattr(config.lora, "lora_rank", 0) > 0: + custom_metadata["lora"] = config.lora.model_dump() + match (checkpoint_manager, config, data_iterator): case (checkpoint_manager, _, _) if isinstance( checkpoint_manager, (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager) @@ -1149,4 +1176,6 @@ def save_checkpoint(checkpoint_manager, step, state, config=None, data_iterator= replicator_error_handler(config) return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force) case _: - return checkpoint_manager.save(step, args=Composite(**save_args_composite), force=force) + return checkpoint_manager.save( + step, args=Composite(**save_args_composite), force=force, custom_metadata=custom_metadata + ) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 6b4410f209..e79aa1c14c 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -515,6 +515,49 @@ def _verify_lora_parameters(lora_model: nnx.Module, mt_config: pyconfig.HyperPar ) +def sync_lora_metadata(config: pyconfig.HyperParameters) -> None: + """Syncs LoRA parameters (rank, alpha) from the checkpoint sidecar metadata if present. + + If configuration values are set to non-default values (i.e. rank > 0 or alpha > 0.0) + and differ from the checkpoint metadata values, we raise a ValueError to fail the run. + If they are at default values, we sync them from the checkpoint. + """ + lora_restore_path = config.lora.lora_restore_path + if not lora_restore_path: + return + + try: + custom_metadata = checkpointing.load_checkpoint_metadata(lora_restore_path) + lora_meta = custom_metadata.get("lora") + if lora_meta: + meta_rank = lora_meta.get("lora_rank", config.lora.lora_rank) + meta_alpha = lora_meta.get("lora_alpha", config.lora.lora_alpha) + + # Check lora_rank + if config.lora.lora_rank not in (0, meta_rank): + raise ValueError( + f"Configured lora_rank ({config.lora.lora_rank}) does not match " + f"checkpoint metadata lora_rank ({meta_rank}) at {lora_restore_path}." + ) + # Check lora_alpha + if config.lora.lora_alpha not in (0.0, meta_alpha): + raise ValueError( + f"Configured lora_alpha ({config.lora.lora_alpha}) does not match " + f"checkpoint metadata lora_alpha ({meta_alpha}) at {lora_restore_path}." + ) + + config.lora.lora_rank = meta_rank + config.lora.lora_alpha = meta_alpha + max_logging.log( + f"Synced LoRA parameters from Orbax metadata at {lora_restore_path}: " + f"rank={config.lora.lora_rank}, alpha={config.lora.lora_alpha}" + ) + except ValueError: + raise + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Warning: Failed to load/sync LoRA metadata: {e}") + + def apply_lora_to_model( model: nnx.Module, mesh: Optional[jax.sharding.Mesh], @@ -585,6 +628,8 @@ def _safe_reshard(var, sharding_spec): def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any: """Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run.""" lora_restore_path = mt_config.lora.lora_restore_path + if not lora_restore_path: + return trainer train_steps = getattr(trainer, "train_steps", 0) if train_steps > 0: @@ -601,6 +646,8 @@ def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> f"Set lora.enable_lora=True and verify lora_module_path ('{lora_module_path}') matches model modules." ) + sync_lora_metadata(mt_config) + abstract_lora_params = nnx.state(trainer.model, nnx.LoRAParam) target_for_restore = jax.tree.map( diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index c0248783eb..ef8e25feb0 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -44,6 +44,7 @@ import jax import jax.numpy as jnp from jax.sharding import Mesh +from maxtext.common import checkpointing from maxtext.common.checkpointing import handle_checkpoint_mismatch from maxtext.common.common_types import MODEL_MODE_AUTOREGRESSIVE, MODEL_MODE_TRAIN from maxtext.configs import pyconfig @@ -864,6 +865,15 @@ def from_pretrained( } ) config = pyconfig.HyperParameters(new_config) + # Proactive verification of scan_layers from checkpoint metadata + if config.load_parameters_path: + custom_metadata = checkpointing.load_checkpoint_metadata(config.load_parameters_path) + saved_scan_layers = custom_metadata.get("scan_layers") + if isinstance(saved_scan_layers, bool) and saved_scan_layers != config.scan_layers: + raise ValueError( + f"Configuration mismatch: Your run specifies scan_layers={config.scan_layers}, " + f"but the checkpoint was saved with scan_layers={saved_scan_layers}." + ) if config.pure_nnx: _create_model, abstract_model = create_nnx_abstract_model( diff --git a/tests/post_training/unit/lora_utils_test.py b/tests/post_training/unit/lora_utils_test.py index b0f229875d..4921d7a1a4 100644 --- a/tests/post_training/unit/lora_utils_test.py +++ b/tests/post_training/unit/lora_utils_test.py @@ -15,9 +15,13 @@ """Tests for Qwix LoRA utils in lora_utils.py""" import re import sys +import tempfile import unittest from unittest import mock + +from etils import epath import jax +import jax.numpy as jnp import optax import pytest from flax import nnx @@ -26,6 +30,7 @@ pytestmark = [pytest.mark.post_training] # Now safe to do top-level imports +from maxtext.common import checkpointing from tunix.sft import peft_trainer from maxtext.utils import lora_utils from maxtext.utils import model_creation_utils @@ -59,11 +64,12 @@ def _make_config(**overrides): """Return a MaxTextConfig object suitable for unit tests.""" + config_dict = _BASE_CONFIG.copy() + config_dict.update(overrides) # Use initialize_pydantic to get nested models as objects (attribute access) return pyconfig.initialize_pydantic( [sys.argv[0], get_test_config_path()], - **_BASE_CONFIG, - **overrides, + **config_dict, ) @@ -121,7 +127,12 @@ def test_build_lora_provider(self): with mock.patch("qwix.LoraProvider") as mock_provider: lora_utils._build_lora_provider(mock_config) mock_provider.assert_called_once_with( - module_path="custom/path", rank=8, alpha=16.0, dropout=0.0, weight_qtype="int8", tile_size=32 + module_path="custom/path", + rank=8, + alpha=16.0, + dropout=0.0, + weight_qtype="int8", + tile_size=32, ) def test_prepare_dummy_inputs(self): @@ -173,7 +184,13 @@ def test_apply_lora_to_model_adapters_loaded(self): # If we skip Qwix, it should stay False. self.assertFalse(lora_utils.is_lora_enabled(result)) - def _run_apply_lora_test(self, scan_layers: bool, weight_qtype=None, tile_size=None, mock_multihost: bool = False): + def _run_apply_lora_test( + self, + scan_layers: bool, + weight_qtype=None, + tile_size=None, + mock_multihost: bool = False, + ): """Helper to run LoRA application test with/without scanned layers and optional QLoRA.""" # Passing nested dict as 'lora' kwarg to _make_config cfg = _make_config( @@ -246,7 +263,12 @@ def test_apply_lora_multihost_mock(self): def test_restore_lora_from_path(self): """Test restoration of LoRA parameters from a path.""" cfg = _make_config( - lora={"enable_lora": True, "lora_restore_path": "some/path", "lora_rank": 4, "lora_alpha": 8.0}, + lora={ + "enable_lora": True, + "lora_restore_path": "some/path", + "lora_rank": 4, + "lora_alpha": 8.0, + }, scan_layers=False, ) model, _ = model_creation_utils.from_pretrained(cfg, mesh=None, model_mode=model_creation_utils.MODEL_MODE_TRAIN) @@ -271,6 +293,135 @@ def test_restore_lora_from_path(self): self.assertTrue(kwargs["args"].partial_restore) mock_update.assert_called_once() + def test_sync_lora_metadata_default_syncs(self): + """Test that default lora rank/alpha are successfully synced from checkpoint metadata.""" + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": "dummy/path", + "lora_rank": 0, + "lora_alpha": 0.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + lora_utils.sync_lora_metadata(cfg) + self.assertEqual(cfg.lora.lora_rank, 32) + self.assertEqual(cfg.lora.lora_alpha, 64.0) + + def test_sync_lora_metadata_matching_passes(self): + """Test that matching non-default parameters pass without errors.""" + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": "dummy/path", + "lora_rank": 32, + "lora_alpha": 64.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + # Should not raise ValueError + lora_utils.sync_lora_metadata(cfg) + self.assertEqual(cfg.lora.lora_rank, 32) + self.assertEqual(cfg.lora.lora_alpha, 64.0) + + def test_sync_lora_metadata_rank_mismatch_fails(self): + """Test that configured rank mismatching checkpoint metadata rank raises ValueError.""" + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": "dummy/path", + "lora_rank": 8, + "lora_alpha": 64.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + with self.assertRaisesRegex(ValueError, "Configured lora_rank .* does not match"): + lora_utils.sync_lora_metadata(cfg) + + def test_sync_lora_metadata_alpha_mismatch_fails(self): + """Test that configured alpha mismatching checkpoint metadata alpha raises ValueError.""" + cfg = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": "dummy/path", + "lora_rank": 32, + "lora_alpha": 16.0, + } + ) + mock_metadata = mock.MagicMock() + mock_metadata.custom_metadata = {"lora": {"lora_rank": 32, "lora_alpha": 64.0}} + + with mock.patch("orbax.checkpoint.StandardCheckpointer.metadata", return_value=mock_metadata): + with self.assertRaisesRegex(ValueError, "Configured lora_alpha .* does not match"): + lora_utils.sync_lora_metadata(cfg) + + def test_save_checkpoint_passes_metadata(self): + """Test that save_checkpoint correctly generates and passes custom lora metadata to CheckpointManager.""" + cfg = _make_config( + lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}, + enable_checkpointing=True, + ) + mock_manager = mock.MagicMock() + mock_state = mock.MagicMock() + + with mock.patch("jax.block_until_ready"): + checkpointing.save_checkpoint(mock_manager, step=10, state=mock_state, config=cfg) + mock_manager.save.assert_called_once() + _, kwargs = mock_manager.save.call_args + self.assertIn("custom_metadata", kwargs) + self.assertEqual(kwargs["custom_metadata"], {"lora": cfg.lora.model_dump()}) + + def test_save_and_restore_metadata_integration(self): + """Integration test checking that Orbax CheckpointManager writes and reads custom LoRA metadata.""" + + cfg_save = _make_config( + lora={"enable_lora": True, "lora_rank": 8, "lora_alpha": 16.0}, + enable_checkpointing=True, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + manager = checkpointing.create_orbax_checkpoint_manager( + tmpdir, + enable_checkpointing=True, + use_async=False, + save_interval_steps=1, + use_ocdbt=False, + use_zarr3=False, + ) + + # Use save_checkpoint wrapper with a simple state + dummy_state = {"weight": jnp.array([1.0, 2.0])} + checkpointing.save_checkpoint(manager, step=0, state=dummy_state, config=cfg_save) + manager.wait_until_finished() + + # Now verify that the saved checkpoint contains metadata on disk + checkpoint_dir = epath.Path(tmpdir) / "0" + self.assertTrue((checkpoint_dir / "_CHECKPOINT_METADATA").exists()) + + # Restore using sync_lora_metadata on a config with default rank/alpha + cfg_restore = _make_config( + lora={ + "enable_lora": True, + "lora_restore_path": str(checkpoint_dir), + "lora_rank": 0, + "lora_alpha": 0.0, + } + ) + lora_utils.sync_lora_metadata(cfg_restore) + + # Verify values were successfully synced back + self.assertEqual(cfg_restore.lora.lora_rank, 8) + self.assertEqual(cfg_restore.lora.lora_alpha, 16.0) + def test_gemma4_lora_path_matching(self): """Test that the Gemma4 LoRA regex correctly matches all expected parameter paths.""" mock_config = mock.MagicMock(spec=pyconfig.HyperParameters) @@ -309,10 +460,6 @@ def test_gemma4_lora_path_matching(self): "decoder/layers_remainder/layers/0/mlp/shared_experts/wi_0/kernel", "decoder/layers_remainder/layers/0/mlp/shared_experts/wi_1/kernel", "decoder/layers_remainder/layers/0/mlp/shared_experts/wo/kernel", - # No scanned_blocks/layers_remainder prefix (e.g. fallback or direct structure) - "decoder/layers/0/self_attention/query/kernel", - "decoder/layers/0/mlp/wi_0/kernel", - "decoder/layers/layers/0/mlp/shared_experts/wi_0/kernel", ] for path in matching_paths: diff --git a/tests/unit/hf_checkpoint_conversion_test.py b/tests/unit/hf_checkpoint_conversion_test.py index 6451a50fd5..534d3c0165 100644 --- a/tests/unit/hf_checkpoint_conversion_test.py +++ b/tests/unit/hf_checkpoint_conversion_test.py @@ -14,9 +14,11 @@ """Tests for kernels""" +import tempfile import unittest from types import SimpleNamespace from unittest.mock import MagicMock +from etils import epath import numpy as np from maxtext.utils.max_utils import permute_to_match_maxtext_rope, unpermute_from_match_maxtext_rope from maxtext.checkpoint_conversion import to_huggingface as to_hf @@ -26,6 +28,7 @@ _transform_weights_to_adapter, _transform_weights_to_full_model, ) +from maxtext.utils.lora_utils import sync_lora_metadata from maxtext.checkpoint_conversion.to_maxtext import ( convert_hf_lora_key_to_maxtext, _process_and_stack_weights, @@ -367,6 +370,7 @@ def test_get_maxtext_model_info(self): "hidden_size_per_layer_input=128", "vocab_size_per_layer_input=256", "vocab_size=256", + "skip_jax_distributed_system=True", ], override_model_config=True, ) @@ -417,7 +421,7 @@ def test_recursive_update(self): @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer") @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path") @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices") - def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls): + def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, _mock_path, mock_checkpointer_cls): # Mock jax devices mock_jax_devices.return_value = [MagicMock()] @@ -551,5 +555,45 @@ def test_gemma4_base_and_adapter_conversion(self): self.assertTrue(np.allclose(delta, 2.0)) +class SyncLoRAMetadataTest(unittest.TestCase): + """Tests sync_lora_metadata from checkpoint sidecar files.""" + + def test_sync_lora_metadata(self): + # Create mock config + config = MagicMock() + config.lora.lora_rank = 8 + config.lora.lora_alpha = 16.0 + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = epath.Path(tmpdir) + + # Case 1: metadata in the lora_restore_path directly with nested "lora" dict + config.lora.lora_restore_path = str(tmp_path) + with (tmp_path / "_CHECKPOINT_METADATA").open("w") as f: + f.write('{"custom_metadata": {"lora": {"lora_rank": 32, "lora_alpha": 64.0}}}') + + sync_lora_metadata(config) + self.assertEqual(config.lora.lora_rank, 32) + self.assertEqual(config.lora.lora_alpha, 64.0) + + def test_sync_lora_metadata_failure_gracefully_logs(self): + # Create mock config + config = MagicMock() + config.lora.lora_rank = 8 + config.lora.lora_alpha = 16.0 + config.lora.lora_restore_path = "gs://fake/non_existent_directory" + + # StandardCheckpointer().metadata should fail, but sync_lora_metadata should catch it + # and fail gracefully without raising an exception. + try: + sync_lora_metadata(config) + except Exception as e: # pylint: disable=broad-exception-caught + self.fail(f"sync_lora_metadata raised {type(e).__name__} unexpectedly!") + + # Verify parameters were not modified + self.assertEqual(config.lora.lora_rank, 8) + self.assertEqual(config.lora.lora_alpha, 16.0) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py index 2568547944..b1d034e8b7 100644 --- a/tests/unit/model_creation_utils_test.py +++ b/tests/unit/model_creation_utils_test.py @@ -723,6 +723,65 @@ def test_checkpoint_load_error_propagates(self, mock_ocp): with self.assertRaises(RuntimeError): model_creation_utils.from_pretrained(cfg, self.mesh) + @patch("maxtext.utils.model_creation_utils.checkpointing.load_checkpoint_metadata") + def test_scan_layers_mismatch_raises_error(self, mock_load_meta): + """ValueError is raised if run specifies scan_layers=True but checkpoint specifies scan_layers=False.""" + mock_load_meta.return_value = {"scan_layers": False} + + cfg = _make_config( + enable_checkpointing=True, load_parameters_path="gs://fake/scan_layers_false_ckpt", scan_layers=True + ) + + with self.assertRaises(ValueError) as context: + model_creation_utils.from_pretrained(cfg, self.mesh) + self.assertIn( + "Configuration mismatch: Your run specifies scan_layers=True, " + "but the checkpoint was saved with scan_layers=False", + str(context.exception), + ) + + @patch("maxtext.utils.model_creation_utils.checkpointing.load_checkpoint_metadata") + @patch("maxtext.utils.model_creation_utils.ocp") + def test_scan_layers_match_no_error(self, mock_ocp, mock_load_meta): + """If the run specifies scan_layers=True and the checkpoint matches, it proceeds without error.""" + mock_load_meta.return_value = {"scan_layers": True} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = self._make_linen_metadata_mock() + mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} + mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs + + cfg = _make_config( + enable_checkpointing=True, load_parameters_path="gs://fake/scan_layers_true_ckpt", scan_layers=True + ) + + model = model_creation_utils.from_pretrained(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + + @patch("maxtext.utils.model_creation_utils.checkpointing.load_checkpoint_metadata") + @patch("maxtext.utils.model_creation_utils.ocp") + def test_scan_layers_missing_metadata_no_error(self, mock_ocp, mock_load_meta): + """Skip verification and proceed if custom_metadata lacks 'scan_layers'.""" + mock_load_meta.return_value = {} + + mock_ckptr = MagicMock() + mock_ckptr.metadata.return_value = self._make_linen_metadata_mock() + mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item + mock_ocp.Checkpointer.return_value = mock_ckptr + mock_ocp.PyTreeCheckpointHandler.return_value = MagicMock() + mock_ocp.checkpoint_utils.construct_restore_args.return_value = {} + mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs + + cfg = _make_config( + enable_checkpointing=True, load_parameters_path="gs://fake/scan_layers_missing_ckpt", scan_layers=True + ) + + model = model_creation_utils.from_pretrained(cfg, self.mesh) + self.assertIsInstance(model, models.Transformer) + class TestSetupDecodeStateFromNnx(unittest.TestCase): """Tests for setup_decode_state_from_nnx()."""