feat(lora): save/restore LoRA config in checkpoint metadata #4269
RexBearIU wants to merge 1 commit into
shralex
left a comment
Thanks Jackie! A significant thing missing in this PR is using the metadata file on checkpoint restore path.
Hi @shralex, thank you for the feedback! I have fully addressed your comments with the following changes:
Please let me know if you would like any other enhancements!
Added the logic to reuse the metadata on checkpoint restore.
)
return trainer

sync_lora_metadata(mt_config)
Let's move this down to after we have verified that LoRA is enabled.
Done! I have moved sync_lora_metadata(config) down in to_huggingface.py so that it is called after we verify that LoRA is indeed enabled in the model configuration.
def restore_lora_from_path(trainer: Any, mt_config: pyconfig.HyperParameters) -> Any:
    """Restores LoRA parameter weights from an external Orbax checkpoint for a fresh run."""
    lora_restore_path = mt_config.lora.lora_restore_path
Can we add a check here:
if not lora_restore_path:
    return trainer  # No restore requested; exit cleanly without error
(Otherwise we're relying on callers to call this function only when this path is set.)
Done! Added the guard check at the beginning of restore_lora_from_path so it returns the trainer early and exits cleanly if lora_restore_path is not set.
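For illustration, a minimal sketch of the guard discussed above. The `load_fn` parameter is a hypothetical stand-in for the actual Orbax restore logic, which is not shown in this thread; only the early return mirrors the suggested change.

```python
from types import SimpleNamespace
from typing import Any, Callable


def restore_lora_from_path(
    trainer: Any,
    mt_config: Any,
    load_fn: Callable[[Any, str], Any],  # hypothetical: stands in for the real restore
) -> Any:
    """Restore LoRA weights only when a restore path is configured."""
    lora_restore_path = mt_config.lora.lora_restore_path
    if not lora_restore_path:
        # No restore requested: return the trainer unchanged instead of failing.
        return trainer
    return load_fn(trainer, lora_restore_path)
```

With this shape, callers can invoke the function unconditionally; an empty or unset path is a no-op.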
shralex
left a comment
This version reverts Xibin's previous version where sync_lora_metadata was in lora_utils. We should move it back there and use it not just on checkpoint conversion but also before model creation.
save_args_composite["iter"] = GrainCheckpointSave(item=grain_iters_to_save)

custom_metadata = None
if config and config.lora.lora_rank > 0:
Let's check that config contains "lora" before accessing config.lora.lora_rank:
if config and hasattr(config, "lora") and config.lora:
    lora_rank = getattr(config.lora, "lora_rank", 0)
    if lora_rank > 0 and hasattr(config.lora, "model_dump"):
        custom_metadata = {"lora": config.lora.model_dump()}
Done! Added checks to ensure config has the lora attribute and is not None before attempting to access lora_rank or model_dump.
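As a self-contained sketch of the defensive check above: `FakeLoraConfig` and `build_custom_metadata` are hypothetical names introduced only for this example, and the dataclass merely imitates a config object that exposes `model_dump()`.

```python
from dataclasses import asdict, dataclass
from types import SimpleNamespace
from typing import Any, Optional


@dataclass
class FakeLoraConfig:
    """Hypothetical stand-in for the real LoRA config block."""

    lora_rank: int = 0
    lora_alpha: float = 0.0

    def model_dump(self) -> dict:
        return asdict(self)


def build_custom_metadata(config: Any) -> Optional[dict]:
    """Return {"lora": ...} only when a LoRA block with a positive rank exists."""
    lora = getattr(config, "lora", None) if config else None
    if not lora:
        return None  # no config, or no LoRA block on the config
    lora_rank = getattr(lora, "lora_rank", 0)
    if lora_rank > 0 and hasattr(lora, "model_dump"):
        return {"lora": lora.model_dump()}
    return None
```

Returning `None` for every non-LoRA case keeps the save path unchanged for runs that never configured LoRA.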
replicator_error_handler(config)
return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force)
return checkpoint_manager.save(
    step, args=Composite(state=checkpoint_args), force=force, custom_metadata=custom_metadata
EmergencyCheckpointManager and EmergencyReplicatorCheckpointManager do not accept a custom_metadata argument. Let's leave this argument out here and open a bug to add support for it.
Done! Omitted passing the custom_metadata argument when calling .save() on EmergencyCheckpointManager or EmergencyReplicatorCheckpointManager.
I've filed bug b/529671188 for the Orbax team to add custom_metadata support to EmergencyCheckpointManager and EmergencyReplicatorCheckpointManager.
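One way to picture the resulting save path is sketched below. `StandardManager`, `EmergencyManager`, and `save_step` are hypothetical stand-ins written for this example; they are not the Orbax classes, and the real `save()` signatures may differ.

```python
from typing import Any, Optional


class StandardManager:
    """Hypothetical manager whose save() accepts custom_metadata."""

    def save(self, step: int, *, force: bool = False, custom_metadata: Optional[dict] = None) -> dict:
        return {"step": step, "force": force, "custom_metadata": custom_metadata}


class EmergencyManager:
    """Hypothetical manager whose save() has no custom_metadata parameter."""

    def save(self, step: int, *, force: bool = False) -> dict:
        return {"step": step, "force": force}


def save_step(manager: Any, step: int, custom_metadata: Optional[dict] = None, force: bool = False) -> dict:
    """Pass custom_metadata only to managers known to accept it."""
    if isinstance(manager, EmergencyManager):
        # Metadata is dropped here until the manager supports it.
        return manager.save(step, force=force)
    return manager.save(step, force=force, custom_metadata=custom_metadata)
```

The trade-off is that checkpoints written through the emergency path carry no LoRA metadata, so conversion from those checkpoints would still need the values supplied some other way.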
Done. Moved |
Co-authored-by: Xibin Liu <xibin@google.com>
Description
This PR implements native serialization of LoRA configuration parameters (lora_rank, lora_alpha) in standard Orbax _CHECKPOINT_METADATA files, and automatically restores them during checkpoint-to-Hugging Face conversion.
Why is this change being made?
Previously, users had to manually supply matching lora.lora_rank and lora.lora_alpha parameters when converting MaxText checkpoints to Hugging Face format. Storing them in Orbax metadata removes that manual step and the risk of mismatched values (resolves @igorts-git's request in #3970).
Key Implementation Details
- In save_checkpoint (checkpointing.py), we save the active config.lora block under the "lora" key in Orbax's custom_metadata when a LoRA rank is specified.
- In main (to_huggingface.py), sync_lora_metadata reads the custom metadata from lora_restore_path via ocp.StandardCheckpointer and overrides the active config parameters during conversion.
- Updated hf_checkpoint_conversion_test.py to move dynamically loaded inline imports to top-level imports, and removed the json import since the JSON string is written directly.
BUGS: #3970
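The override step described above can be sketched as follows. This is an assumption-laden illustration, not the PR's code: `apply_lora_metadata` and `LORA_KEYS` are hypothetical names, the metadata is passed in as an already-loaded dict, and reading it from the checkpoint via Orbax is deliberately left out.

```python
from types import SimpleNamespace
from typing import Any, Mapping, Optional

# Fields named in the PR description; the real config block may carry more.
LORA_KEYS = ("lora_rank", "lora_alpha")


def apply_lora_metadata(config: Any, custom_metadata: Optional[Mapping[str, Any]]) -> Any:
    """Override the config's LoRA fields with values saved under the "lora" key."""
    saved = (custom_metadata or {}).get("lora")
    if not saved:
        return config  # checkpoint has no LoRA metadata; keep user-supplied values
    for key in LORA_KEYS:
        if key in saved:
            setattr(config.lora, key, saved[key])
    return config
```

Letting the saved values win means a stale command-line rank cannot silently disagree with the checkpoint being converted.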
Tests
We have verified the implementation at both the suite level and the individual unit-test level:
- Added SyncLoRAMetadataTest in tests/unit/hf_checkpoint_conversion_test.py to verify the auto-resolving mechanism during Hugging Face conversion.
- Ran python tests/unit/hf_checkpoint_conversion_test.py
All tests pass.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.