-
Notifications
You must be signed in to change notification settings - Fork 43
Better checkpoints #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cf63b8d
ed45212
a6fe0e5
958889c
ffa5630
5b508ee
b3a97c0
48cfd9b
e1e3fde
3e5b33a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| # TODO: Use packaging.version? (Safer but extra requirement) | ||
| import enum | ||
| import logging | ||
| import pathlib | ||
| import typing | ||
| import warnings | ||
|
|
||
| from fast_llm.config import Config, Field, FieldHint, check_field, config_class | ||
| from fast_llm.engine.config_utils.data_type import DataType | ||
| from fast_llm.utils import Assert | ||
|
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# TODO: Use packaging.version? (Safer but extra requirement)
# Current checkpoint version string, and every version string this module knows about.
# NOTE(review): presumably CHECKPOINT_VERSION is what gets written to new checkpoints — confirm against the callers.
CHECKPOINT_VERSION = "0.1"
KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1")
|
|
||
|
|
||
class CheckpointFormat(str, enum.Enum):
    """Enumeration of the supported checkpoint formats.

    Members are also `str` instances, so they compare equal to their string values.
    """

    # Distributed checkpoint for fast checkpointing and resuming.
    # NOTE(review): a reviewer proposed renaming this member; the author replied
    # "Also an option, let's discuss it later."
    distributed = "distributed"

    # Model state dict, for safe long-term storage in Fast-LLM format.
    # NOTE(review): a reviewer proposed renaming this member; the author is OK with renaming
    # in a future PR but prefers a name related to Fast-LLM, since this is meant to be the
    # "standard" Fast-LLM checkpoint format (maybe just "fast_llm").
    state_dict = "state_dict"

    # A checkpoint format external to Fast-LLM.
    external = "external"
|
|
||
|
|
||
class ModelConfigType(str, enum.Enum):
    """Selects how much of a model configuration should be loaded.

    The `load_*` properties below translate a member into yes/no flags. As written,
    the levels are cumulative: `none` < `architecture` < `model` < `fast_llm`.
    """

    none = "none"
    architecture = "architecture"
    model = "model"
    fast_llm = "fast_llm"

    @property
    def load_architecture(self) -> bool:
        """True for every member except `none`."""
        return self != ModelConfigType.none

    @property
    def load_base_model(self) -> bool:
        """True for `model` and `fast_llm`."""
        return self in (ModelConfigType.model, ModelConfigType.fast_llm)

    @property
    def load_fast_llm(self) -> bool:
        """True only for `fast_llm`."""
        return self == ModelConfigType.fast_llm
|
|
||
|
|
||
@config_class()
class CheckpointPathConfigBase(Config):
    """Abstract config mixin holding the checkpoint location."""

    _abstract = True
    # Optional path; defaults to None when no location is configured.
    path: pathlib.Path | None = Field(
        default=None,
        desc="Location of the checkpoint.",
        hint=FieldHint.core,
    )
|
|
||
|
|
||
@config_class()
class CheckpointConfigBase(Config):
    """Abstract config mixin describing the checkpoint format."""

    _abstract = True
    format: CheckpointFormat = Field(
        default=CheckpointFormat.distributed,
        desc="Format of the checkpoint.",
        hint=FieldHint.core,
    )
    model_type: str | None = Field(
        default=None,
        desc="Model type for external models (ex. Hugging Face model name).",
        hint=FieldHint.feature,
    )

    @classmethod
    def _from_dict(
        cls,
        default: dict[str, typing.Any],
        strict: bool = True,
        flat: bool = False,
    ):
        """Build the config from a dict, translating legacy names first.

        `default` is updated in place: the legacy `huggingface` format value is
        replaced by `external` (with a warning), and the legacy `imported_type`
        key is handed to `_handle_renamed_field` before delegating to the parent
        implementation.
        """
        # TODO v0.2: Remove.
        if default.get("format", None) == "huggingface":
            warnings.warn("`huggingface` checkpoint format has been renamed to `external`.")
            default["format"] = CheckpointFormat.external.value
        # NOTE(review): a reviewer asked whether there will be other external formats. Per the
        # author, the conversion mechanism can in theory be used for any kind of checkpoint
        # format; this value was kept until now for backward compatibility, and it may be
        # dropped in a following PR in favor of using `model_type` as the format.

        # Presumably maps the old `imported_type` key onto `model_type` — confirm in `Config`.
        cls._handle_renamed_field(default, "imported_type", "model_type")
        return super()._from_dict(default, strict, flat)
|
|
||
|
|
||
@config_class()
class CheckpointStateConfigBase(Config):
    """Abstract config mixin selecting which parts of the state are saved/loaded."""

    _abstract = True
    model_weights: bool = Field(
        default=True,
        desc="Save/load the model weights.",
        hint=FieldHint.feature,
    )
    optimizer_state: bool = Field(
        default=False,
        desc="Save/load the optimizer state.",
        hint=FieldHint.feature,
    )

    @classmethod
    def _from_dict(
        cls,
        default: dict[str, typing.Any],
        strict: bool = True,
        flat: bool = False,
    ):
        """Build the config from a dict after handling the renamed legacy keys.

        The legacy keys `load_weights` and `load_optimizer` are handed to
        `_handle_renamed_field` (presumably mapped onto `model_weights` and
        `optimizer_state` — confirm in `Config`) before delegating to the parent.
        """
        cls._handle_renamed_field(default, "load_weights", "model_weights")
        cls._handle_renamed_field(default, "load_optimizer", "optimizer_state")
        return super()._from_dict(default, strict, flat)
|
|
||
|
|
||
@config_class()
class CheckpointSaveConfigBase(Config):
    """Abstract config mixin with options that only apply when saving a checkpoint."""

    _abstract = True
    # Defaults to 2**32; validated with `Assert.geq` against 2**20.
    parameters_per_file: int = Field(
        default=2**32,
        desc="Limit the number of parameters saved in each file.",
        hint=FieldHint.feature,
        valid=check_field(Assert.geq, 2**20),
    )
    # Optional; None is the default.
    # NOTE(review): presumably None means "keep the model's own data type" — confirm.
    data_type: DataType | None = Field(
        default=None,
        desc="Data type to save the checkpoint.",
        hint=FieldHint.feature,
    )
|
|
||
|
|
||
@config_class()
class CheckpointSaveMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase):
    """Concrete config combining the checkpoint path and format settings."""

    _abstract = False
|
|
||
|
|
||
@config_class()
class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateConfigBase, CheckpointSaveConfigBase):
    """Concrete save config: path and format, state selection, and save options."""

    _abstract = False
|
|
||
|
|
||
@config_class()
class CheckpointLoadMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase):
    """Concrete config combining the checkpoint path and format with a `load_config` level."""

    _abstract = False

    load_config: ModelConfigType = Field(
        default=ModelConfigType.architecture,
        desc="Configuration to save/load.",
        hint=FieldHint.core,
    )

    @property
    def compare_log_fn(self):
        """Return how config mismatches should be reported.

        Returns the `ValueError` class when `load_config.load_architecture` is
        true, and `logger.warning` otherwise.
        NOTE(review): presumably the caller raises when handed an exception class
        and merely logs when handed a logging function — confirm at the call site.
        """
        return ValueError if self.load_config.load_architecture else logger.warning
|
|
||
|
|
||
@config_class()
class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
    """Concrete load config: load metadata settings plus the state selection."""

    _abstract = False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure where to ask this question:
where will we be configuring the layers-per-step option?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a hack for the single-process converter; it doesn't apply here because we already have the whole model loaded. I'm not entirely sure about memory usage, but it should be OK because we reconstruct one layer at a time.