diff --git a/fast_llm/config.py b/fast_llm/config.py index 815b6b00a..b4ade6b99 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -153,6 +153,14 @@ def __init__( self.valid = valid +class FieldUpdate(dict): + """ + Specify some entries in the field that should be updated from the base class. + Useful for changing the default or description in a derived class. + Processed in `__init_subclass__`. + """ + + def check_field(fn, *args, **kwargs): """ Helper function to define a condition that a config field should satisfy, @@ -270,7 +278,7 @@ class Config: __class_validated__: typing.ClassVar[bool] = True _abstract: typing.ClassVar[bool] = False _validated: bool = Field(init=False, repr=False) - _unknown_fields: dict[str] = Field(init=False, repr=False) + _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) def __post_init__(self): """ @@ -621,7 +629,7 @@ def _get_class_name(cls): @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str]], + default: typing.Union["Config", dict[str, typing.Any]], *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True, ): @@ -646,7 +654,7 @@ def from_dict( @classmethod def from_flat_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, ): # TODO v0.2: Remove flat format @@ -655,7 +663,7 @@ def from_flat_dict( @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): @@ -768,7 +776,7 @@ def _from_dict_dict(cls, value, type_, strict: bool): return {key: cls._from_dict_nested(value_, args[1], strict) for key, value_ in value.items()} @classmethod - def _handle_renamed_field(cls, default: dict[str], old_name: str, new_name: str): + def _handle_renamed_field(cls, default: dict[str, typing.Any], old_name: str, new_name: str): if old_name in default: warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.") default[new_name] = default.pop(old_name) @@ -803,11 +811,51 @@ def _check_abstract(cls): if not cls.__class_validated__: raise RuntimeError(f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator.") - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. """ - assert ( - cls.__class_validated__ - ), f"Parent class of config class {cls.__name__} has not been validated. Make sure to use the @config_class decorator." + for base_class in cls.__mro__: + if issubclass(base_class, Config): + assert cls.__class_validated__, ( + f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated." + f" Make sure to use the @config_class decorator." + ) cls.__class_validated__ = False + for name in list(cls.__dict__): + value = getattr(cls, name) + if isinstance(value, FieldUpdate): + # In case of multiple inheritance, the base class field may not appear in `cls.__dataclass_fields__`. + # so we iterate over superclasses following mro and use the first match. + base_class_field = None + for base_class in cls.__mro__: + base_class_fields = getattr(base_class, "__dataclass_fields__", {}) + if name in base_class_fields: + base_class_field = base_class_fields[name] + break + if base_class_field is None: + raise RuntimeError(f"Trying to update the non-existent field {name} in class {get_type_name(cls)}") + setattr( + cls, + name, + Field( + desc=value.pop("desc", base_class_field.desc), + doc=value.pop("doc", base_class_field.doc), + hint=value.pop("hint", base_class_field.hint), + valid=value.pop("valid", base_class_field.valid), + default=value.pop("default", base_class_field.default), + default_factory=value.pop("default_factory", base_class_field.default_factory), + repr=value.pop("repr", base_class_field.repr), + hash=value.pop("hash", base_class_field.hash), + compare=value.pop("compare", base_class_field.compare), + metadata=value.pop("metadata", base_class_field.metadata), + kw_only=value.pop("kw_only", base_class_field.kw_only), + ), + ) + if name in cls.__annotations__: + # TODO: Generalize to other type hints. + if isinstance(cls.__annotations__[name], type) and isinstance(base_class_field.type, type): + Assert.custom(issubclass, cls.__annotations__[name], base_class_field.type) + else: + # dataclasses expects an annotation, so we use the one from the base class. + cls.__annotations__[name] = base_class_field.type diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 66d4db581..433f1c92b 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -4,7 +4,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace - from fast_llm.engine.multi_stage.conversion import ModelConverter + from fast_llm.engine.multi_stage.conversion import ExternalModelConverter @config_class() @@ -30,7 +30,7 @@ def compare_architecture( return self.get_architecture().compare(model_config.get_architecture(), log_fn) @classmethod - def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]: + def get_converter_class(cls, model_type: str | None = None) -> type["ExternalModelConverter"]: raise NotImplementedError() diff --git a/fast_llm/engine/config_utils/checkpoint.py b/fast_llm/engine/config_utils/checkpoint.py new file mode 100644 index 000000000..dc77fd5e0 --- /dev/null +++ b/fast_llm/engine/config_utils/checkpoint.py @@ -0,0 +1,147 @@ +# TODO: Use packaging.version? (Safer but extra requirement) +import enum +import logging +import pathlib +import typing +import warnings + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + +# TODO: Use packaging.version? (Safer but extra requirement) +CHECKPOINT_VERSION = "0.1" +KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1") + + +class CheckpointFormat(str, enum.Enum): + # Distributed checkpoint for fast checkpointing and resuming. + distributed = "distributed" + # Model state dict, for safe long-term storage in Fast-LLM format. + state_dict = "state_dict" + # A checkpoint format external to Fast-LLM. + external = "external" + + +class ModelConfigType(str, enum.Enum): + none = "none" + architecture = "architecture" + model = "model" + fast_llm = "fast_llm" + + @property + def load_architecture(self): + return self != ModelConfigType.none + + @property + def load_base_model(self): + return self in (ModelConfigType.model, ModelConfigType.fast_llm) + + @property + def load_fast_llm(self): + return self == ModelConfigType.fast_llm + + +@config_class() +class CheckpointPathConfigBase(Config): + _abstract = True + path: pathlib.Path | None = Field( + default=None, + desc="Location of the checkpoint.", + hint=FieldHint.core, + ) + + +@config_class() +class CheckpointConfigBase(Config): + _abstract = True + format: CheckpointFormat = Field( + default=CheckpointFormat.distributed, + desc="Format of the checkpoint.", + hint=FieldHint.core, + ) + model_type: str | None = Field( + default=None, + desc="Model type for external models (ex. Huggingace model name).", + hint=FieldHint.feature, + ) + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ): + # TODO v0.2: Remove. + if default.get("format", None) == "huggingface": + warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.") + default["format"] = CheckpointFormat.external.value + cls._handle_renamed_field(default, "imported_type", "model_type") + return super()._from_dict(default, strict, flat) + + +@config_class() +class CheckpointStateConfigBase(Config): + _abstract = True + model_weights: bool = Field(default=True, desc="Save/load the model weights.", hint=FieldHint.feature) + optimizer_state: bool = Field(default=False, desc="Save/load the optimizer state.", hint=FieldHint.feature) + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ): + cls._handle_renamed_field(default, "load_weights", "model_weights") + cls._handle_renamed_field(default, "load_optimizer", "optimizer_state") + return super()._from_dict(default, strict, flat) + + +@config_class() +class CheckpointSaveConfigBase(Config): + _abstract = True + parameters_per_file: int = Field( + default=2**32, + desc="Limit the number of parameters saved in each file.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 2**20), + ) + data_type: DataType | None = Field( + default=None, + desc="Data type to save the checkpoint.", + hint=FieldHint.feature, + ) + + +@config_class() +class CheckpointSaveMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase): + _abstract = False + + +@config_class() +class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateConfigBase, CheckpointSaveConfigBase): + _abstract = False + + +@config_class() +class CheckpointLoadMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase): + _abstract = False + + load_config: ModelConfigType = Field( + default=ModelConfigType.architecture, + desc="Configuration to save/load.", + hint=FieldHint.core, + ) + + @property + def compare_log_fn(self): + return ValueError if self.load_config.load_architecture else logger.warning + + +@config_class() +class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): + _abstract = False diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 740b83171..cd32b7d51 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -1,7 +1,6 @@ import logging import os import pathlib -import shutil import typing import warnings @@ -128,7 +127,6 @@ class Run: """ _experiment_dir: pathlib.Path | None - _checkpoint_dir: pathlib.Path | None def __init__( self, @@ -154,10 +152,7 @@ def __init__( if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() self.dataset_cache_dir = self._experiment_directory / "dataset_cache" - self._checkpoint_dir = self._experiment_directory / "checkpoints" - self._export_dir = self._experiment_directory / "export" if self._is_main_rank: - self._checkpoint_dir.mkdir(exist_ok=True, parents=True) (self._experiment_directory / "runs").mkdir(exist_ok=True, parents=True) run = len(list((self._experiment_directory / "runs").iterdir())) (self._experiment_directory / "runs" / str(run)).mkdir() @@ -166,12 +161,12 @@ def __init__( else: run = 0 # Make sure all the workers agree on the run. This also acts as a barrier. - self.index = self._broadcast_int(run) + self.index = self.broadcast_int(run) run_dir = self._experiment_directory / "runs" / str(self.index) self._artifact_dir = run_dir / "artifacts" / str(self._distributed_config.rank) log_dir = run_dir / "logs" else: - _experiment_directory, self._checkpoint_dir, self._artifact_dir, log_dir = None, None, None, None + _experiment_directory, self._artifact_dir, log_dir = None, None, None self.dataset_cache_dir = None self.index = None @@ -207,99 +202,18 @@ def save_logged_tensors(self, iteration: int | str): torch.save(tensor_stats, self.open_artifact(f"tensor_logs_{iteration}.pt", mode="wb")) TensorLogs.reset(self._config.tensor_logs) - def get_save_checkpoint_context(self, iteration: int, export: bool = False, keep: int | None = None): - return self._SaveCheckpointContext(self, iteration, export, keep) - - def get_load_checkpoint_context(self, iteration: int): - return self._LoadCheckpointContext(self, iteration) - def barrier(self, value: int | str = 1): from fast_llm.core.distributed import safe_barrier safe_barrier(self._distributed.world_group, value) - def _broadcast_int(self, value: int): + def broadcast_int(self, value: int): import torch from fast_llm.core.distributed import broadcast_scalar return broadcast_scalar(value, dtype=torch.int64, src=_MAIN_RANK, group=self._distributed.world_group) - class _CheckpointContext: - def __init__(self, run: "Run", iteration: int): - self._run = run - self._iteration = iteration - assert self._run._checkpoint_dir is not None - self._directory = self._run._checkpoint_dir / str(self._iteration) - - @property - def directory(self): - return self._directory - - class _SaveCheckpointContext(_CheckpointContext): - def __init__(self, run: "Run", iteration: int, export: bool = False, keep: int | None = None): - super().__init__(run, iteration) - self._export = export - self._keep = keep - if self._export: - self._link_directory = self._directory - self._directory = self._run._export_dir / str(self._iteration) - - def __enter__(self): - assert self._run._is_running - if self._run._is_main_rank: - logger.info(f"Saving checkpoint at iteration {self._iteration}") - self._directory.mkdir(parents=True) - if self._export: - (self._run._checkpoint_dir / str(self._iteration)).symlink_to(self._directory) - # Barrier to ensure the directory is created correctly (and didn't exist before). - self._run.barrier(f"save {self._iteration} enter") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_type: - self._run.barrier(f"save {self._iteration} exit") - if self._run._is_main_rank: - # Prevent corrupted checkpoint. - (self._directory / "ok").open("w") - logger.info(f"Checkpoint saved to {self._directory}") - self._run._delete_old_checkpoints(self._keep) - - class _LoadCheckpointContext(_CheckpointContext): - def __enter__(self): - assert self._run._is_running - Assert.custom(pathlib.Path.is_file, self._directory / "ok") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_type: - self._run.barrier(f"load {self._iteration} exit") - - def _delete_old_checkpoints(self, keep: int | None): - assert self._is_running - if keep is None: - return - checkpoints = sorted(int(path.name) for path in self._checkpoint_dir.iterdir()) - for checkpoint in checkpoints[:-keep]: - path = self._checkpoint_dir / str(checkpoint) - logger.info(f"Deleting checkpoint at {path}") - try: - shutil.rmtree(path, ignore_errors=True) - except OSError as e: - logger.warning(f"Could not remove checkpoint directory: {e.args}") - - def get_last_checkpoint(self): - assert self._is_running - if self._checkpoint_dir is None: - return None - if self._is_main_rank: - checkpoints = [int(path.name) for path in self._checkpoint_dir.iterdir()] - iteration = max(checkpoints) if checkpoints else -1 - else: - iteration = -1 - iteration = self._broadcast_int(iteration) - return iteration if iteration >= 0 else None - def open_artifact(self, name: str, mode: str | None = "w", verbose=True): assert self._is_running if self._artifact_dir is None: diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 1adff8bdd..b53263e4a 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -5,7 +5,8 @@ import transformers -from fast_llm.engine.multi_stage.config import CheckpointType, FastLLMModelConfig, PretrainedConfig +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadMetadataConfig +from fast_llm.engine.multi_stage.config import FastLLMModelConfig logger = logging.getLogger(__name__) @@ -35,7 +36,9 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = transformers.configuration_utils.CONFIG_NAME = _backup @classmethod - def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | PretrainedConfig, **kwargs): + def _get_config_dict( + cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadMetadataConfig, **kwargs + ): # TODO: Support download from hub/url # Unused arguments, remove to avoid warnings. @@ -55,15 +58,15 @@ def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | Pre # Get the pretrained config. if "pretrained" in kwargs: - assert isinstance(kwargs["pretrained"], PretrainedConfig) + assert isinstance(kwargs["pretrained"], CheckpointLoadMetadataConfig) assert kwargs["pretrained"].path == pretrained_model_name_or_path pretrained = kwargs.pop("pretrained") - elif isinstance(pretrained_model_name_or_path, PretrainedConfig): + elif isinstance(pretrained_model_name_or_path, CheckpointLoadMetadataConfig): pretrained = pretrained_model_name_or_path else: - pretrained = PretrainedConfig( + pretrained = CheckpointLoadMetadataConfig( path=pathlib.Path(pretrained_model_name_or_path), - format=CheckpointType.state_dict, + format=CheckpointFormat.state_dict, ) metadata = cls.model_config_class.load_pretrained_metadata(pretrained) updates = {} diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/huggingface/model.py index fa46d0e40..45ab60cca 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/huggingface/model.py @@ -5,9 +5,10 @@ import transformers.modeling_outputs from fast_llm.config import NoAutoValidate +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.huggingface.config import HuggingfaceModelConfig -from fast_llm.engine.multi_stage.config import CheckpointType, PretrainedCheckpointConfig, PretrainedConfig, StageMode +from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner @@ -58,16 +59,16 @@ def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, @classmethod def from_pretrained( cls, - pretrained_model_name_or_path: str | os.PathLike | PretrainedCheckpointConfig, + pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadConfig, *, mode: StageMode = StageMode.inference, **kwargs, ): # Pretrained config. - if not isinstance(pretrained_model_name_or_path, PretrainedConfig): - pretrained_model_name_or_path = PretrainedCheckpointConfig( + if not isinstance(pretrained_model_name_or_path, CheckpointLoadConfig): + pretrained_model_name_or_path = CheckpointLoadConfig( path=pathlib.Path(pretrained_model_name_or_path), - format=CheckpointType.state_dict, + format=CheckpointFormat.state_dict, ) config_updates = {} diff --git a/fast_llm/engine/multi_stage/checkpoint.py b/fast_llm/engine/multi_stage/checkpoint.py new file mode 100644 index 000000000..f25aa26e5 --- /dev/null +++ b/fast_llm/engine/multi_stage/checkpoint.py @@ -0,0 +1,110 @@ +import json +import logging + +import safetensors.torch +import torch +import yaml + +from fast_llm.core.distributed import safe_barrier +from fast_llm.engine.config_utils.checkpoint import CheckpointSaveConfig +from fast_llm.engine.distributed.distributed import Distributed + +logger = logging.getLogger(__name__) + + +def _export_safetensors_metadata(metadata): + """ + Safetensor only accepts string entries, so we convert to string explicitly. + We use yaml rather than json because json requires explicit quotation marks on strings, which breaks things. + (ex. "format": "pt" becomes '"pt"' which breaks huggingface models.) + We avoid using safe_dump for scalars because it adds junk ("\n...\n") at the end of the string + (decoding is unaffected.) + """ + return { + key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value) + for key, value in metadata.items() + } + + +class StateDictSaver: + def __init__( + self, + config: CheckpointSaveConfig, + *, + distributed: Distributed, + metadata, + base_file_name: str, + ): + self._config = config + self._metadata = metadata + self.distributed = distributed + self._distributed_config = distributed.config + self.base_file_name = ( + base_file_name + if self._distributed_config.pipeline_parallel == 1 + else f"{base_file_name}_{self._distributed_config.pipeline_rank}" + ) + # All ranks reconstruct the pipeline-parallel state (for simplicity), but only one saves it. + self._do_save = self._distributed_config.data_rank == self._distributed_config.tensor_rank == 0 + + def add_tensor(self, name: str, tensor: torch.Tensor): + assert name not in self.tensors + self.tensors[name] = tensor + self.param_count += tensor.numel() + if self.param_count >= self._config.parameters_per_file: + self._save_next_file() + + def __enter__(self): + self.file_count = 0 + self.param_count = 0 + self.tensors = {} + self.index = {} + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.tensors: + # Save the last file. + self._save_next_file() + + if self._do_save and self._distributed_config.pipeline_parallel != 1: + # Combine the indexes from all pipeline ranks. + logger.info(f"Merging pipeline-parallel indexes.") + json.dump( + self.index, (self._config.path / f"index_{self._distributed_config.pipeline_rank}.json").open("w") + ) + safe_barrier(self.distributed.pipeline_group, "save state dict") + if self._distributed_config.pipeline_rank == 0: + self.index = {} + for rank in range(self._distributed_config.pipeline_parallel): + file_name = self._config.path / f"index_{rank}.json" + local_index = json.load(file_name.open("r")) + for key, value in local_index.items(): + assert key not in self.index, key + self.index[key] = value + file_name.unlink() + + if self._distributed_config.rank == 0: + path = self._config.path / f"{self.base_file_name}.safetensors.index.json" + logger.info(f"Saving index to {path}") + # Save the index. + json.dump( + {"metadata": self._metadata, "weight_map": self.index}, + path.open("w"), + indent=4, + ) + + def _save_next_file(self): + file_name = f"{self.base_file_name}_{self.file_count}.safetensors" + if self._do_save: + logger.info(f"Saving tensors to {self._config.path / file_name}") + safetensors.torch.save_file( + tensors=self.tensors, + filename=self._config.path / file_name, + metadata=_export_safetensors_metadata(self._metadata), + ) + for name_ in self.tensors: + assert name_ not in self.index + self.index[name_] = file_name + self.file_count += 1 + self.param_count = 0 + self.tensors = {} diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e1a756c08..2f04bab4a 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -1,12 +1,17 @@ import enum import json import logging -import pathlib import typing from fast_llm.config import Config, Field, FieldHint, NoAutoValidate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.checkpoint import ( + CHECKPOINT_VERSION, + KNOWN_CHECKPOINT_VERSIONS, + CheckpointFormat, + CheckpointLoadConfig, + CheckpointLoadMetadataConfig, +) from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert @@ -174,98 +179,6 @@ def _validate(self): # TODO: Does this matter? Best value? SHARD_PAD_TO_MULTIPLE = 32 -# TODO: Use packaging.version? (Safer but extra requirement) -CHECKPOINT_VERSION = "0.1" -_KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1") - - -class CheckpointType(str, enum.Enum): - distributed = "distributed" - # TODO: Rename - huggingface = "huggingface" - # Model state dict, mostly for debug. - state_dict = "state_dict" - - -@config_class() -class PretrainedConfig(Config): - path: pathlib.Path | None = Field( - default=None, - desc="Path to the checkpoint.", - hint=FieldHint.core, - ) - format: CheckpointType = Field( - default=CheckpointType.distributed, - desc="Format of the checkpoint.", - hint=FieldHint.core, - ) - imported_type: str | None = Field( - default=None, - desc="Model type for external models (ex. Huggingace type).", - hint=FieldHint.feature, - ) - override_architecture: bool = Field( - default=False, - desc="Ignore the base model architecture from the pretrained checkpoint and use the provided one instead." - " May have unintended consequences.", - hint=FieldHint.feature, - ) - load_full_base_model_config: bool = Field( - default=False, - desc="Load the non-architecture model config from the checkpoint.", - hint=FieldHint.feature, - ) - load_full_fast_llm_config: bool = Field( - default=False, - desc="Load the distributed and multi-stage config from the checkpoint.", - hint=FieldHint.feature, - ) - - def _validate(self): - if self.load_full_fast_llm_config: - self.load_full_base_model_config = True - super()._validate() - - @property - def compare_log_fn(self): - return logger.warning if self.override_architecture else ValueError - - -@config_class() -class PretrainedCheckpointConfig(PretrainedConfig): - # Load weights from path (if applicable), - # otherwise reinitialize them (i.e. load the config only.) - load_weights: bool = Field(default=True, desc="Load model weights from the checkpoint.", hint=FieldHint.feature) - load_optimizer: bool = Field( - default=False, desc="Load the optimizer state from the checkpoint.", hint=FieldHint.feature - ) - - -@config_class() -class CheckpointConfig(Config): - # TODO: Merge/match with PretrainedConfig? - checkpoint_path: pathlib.Path = Field(desc="Path to the checkpoint.", hint=FieldHint.core) - checkpoint_type: CheckpointType = Field( - default=CheckpointType.distributed, desc="Format of the checkpoint.", hint=FieldHint.core - ) - exported_model_type: str | None = Field( - default=None, desc="Model type for external models (ex. Huggingace type).", hint=FieldHint.feature - ) - save_optimizer: bool = Field( - default=True, desc="Save the optimizer state from the checkpoint.", hint=FieldHint.feature - ) - target_params_per_file: int = Field( - default=2**32, - desc="Limit the number of parameters saved in each file.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 2**20), - ) - dtype: DataType | None = Field( - default=None, - desc="Data type to save the checkpoint.", - hint=FieldHint.feature, - ) - @config_class() class FastLLMModelConfig(Config): @@ -298,7 +211,7 @@ def get_base_model_config_cls(cls) -> type[BaseModelConfig]: @classmethod def from_pretrained( cls, - pretrained: PretrainedConfig, + pretrained: CheckpointLoadMetadataConfig, default: "FastLLMModelConfig" = None, ): # TODO: Add *updates? @@ -309,7 +222,7 @@ def from_pretrained( @classmethod def from_metadata( cls, - pretrained: PretrainedConfig, + pretrained: CheckpointLoadMetadataConfig, metadata: dict, default: "FastLLMModelConfig" = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -318,24 +231,24 @@ def from_metadata( # TODO: Standardize to *updates? if "checkpoint_type" in metadata: # TODO python 3.12: Assert.incl(metadata["checkpoint_type"], CheckpointType) - CheckpointType(metadata["checkpoint_type"]) + CheckpointFormat(metadata["checkpoint_type"]) version = metadata["checkpoint_version"] - if version not in _KNOWN_CHECKPOINT_VERSIONS: + if version not in KNOWN_CHECKPOINT_VERSIONS: raise ValueError(f"Unrecognised checkpoint version: {version}") if version == "0": return cls._from_metadata_v0(pretrained, metadata, default, updates) pretrained_config = cls.from_dict(metadata["fast_llm_config"]) - if pretrained.override_architecture: + if not pretrained.load_config.load_architecture: assert default is not None config = default.to_copy() config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn) - elif pretrained.load_full_fast_llm_config: + elif pretrained.load_config.load_fast_llm: config = pretrained_config else: with NoAutoValidate(): config = cls() if default is None else default.to_copy() - if pretrained.load_full_base_model_config: + if pretrained.load_config.load_base_model: config.base_model = pretrained_config.base_model else: config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture()) @@ -348,7 +261,7 @@ def from_metadata( @classmethod def _from_metadata_v0( cls, - pretrained: PretrainedConfig, + pretrained: CheckpointLoadMetadataConfig, metadata: dict, default: "FastLLMModelConfig" = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -361,22 +274,22 @@ def _from_metadata_v0( with NoAutoValidate(): if default is None: - assert not pretrained.override_architecture + assert pretrained.load_config.load_architecture config = cls(base_model=base_model_config_cls()) else: config = default.to_copy() - if pretrained.override_architecture: + if pretrained.load_config.load_architecture: config.validate() architecture_config.compare_architecture(default.base_model, pretrained.compare_log_fn) else: - if pretrained.load_full_base_model_config: + if pretrained.load_config.load_base_model: # Replace the whole config config.base_model = base_model_config_cls.from_flat_dict(metadata["model_config"]) else: # Replace the architecture parts of the config. config.base_model = config.base_model.to_copy(architecture_config) - if pretrained.load_full_fast_llm_config: + if pretrained.load_config.load_fast_llm: config.multi_stage = MultiStageConfig.from_flat_dict(metadata["multi_stage_config"]) config.distributed = DistributedConfig.from_flat_dict( metadata["distributed_config"], @@ -388,20 +301,20 @@ def _from_metadata_v0( return config @classmethod - def load_pretrained_metadata(cls, pretrained): + def load_pretrained_metadata(cls, pretrained: CheckpointLoadMetadataConfig): import yaml base_model_config_cls = cls.get_base_model_config_cls() - if pretrained.format == CheckpointType.distributed: + if pretrained.format == CheckpointFormat.distributed: return yaml.safe_load((pretrained.path / "metadata.yaml").open("r")) - elif pretrained.format == CheckpointType.state_dict: + elif pretrained.format == CheckpointFormat.state_dict: return json.load((pretrained.path / f"state_dict.safetensors.index.json").open("r"))["metadata"] - elif pretrained.format == CheckpointType.huggingface: - converter_class = base_model_config_cls.get_converter_class(pretrained.imported_type) + elif pretrained.format == CheckpointFormat.external: + converter_class = base_model_config_cls.get_converter_class(pretrained.model_type) imported_model_config = converter_class.import_config(converter_class.load_config(pretrained.path), True) return { "fast_llm_config": {"base_model": imported_model_config.to_serialized()}, - "checkpoint_type": CheckpointType.huggingface.value, + "checkpoint_type": CheckpointFormat.external.value, "checkpoint_version": CHECKPOINT_VERSION, } else: @@ -415,8 +328,8 @@ class PretrainedFastLLMModelConfig(Config): model: FastLLMModelConfig = Field( default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core ) - pretrained: PretrainedCheckpointConfig = Field( - default_factory=PretrainedCheckpointConfig, + pretrained: CheckpointLoadConfig = Field( + default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/multi_stage/conversion.py b/fast_llm/engine/multi_stage/conversion.py index 65ab7486e..4acb9d928 100644 --- a/fast_llm/engine/multi_stage/conversion.py +++ b/fast_llm/engine/multi_stage/conversion.py @@ -138,6 +138,30 @@ def import_weight( class ModelConverter(abc.ABC): + base_file_name: typing.ClassVar[str] + + @abc.abstractmethod + def convert_state_dict( + self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool + ) -> dict[str, torch.Tensor | SafeTensorSlice]: + pass + + +class TrivialConverter(ModelConverter): + base_file_name = "state_dict" + + def convert_state_dict( + self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool + ) -> dict[str, torch.Tensor | SafeTensorSlice]: + out_state_dict = {} + for key in list(state_dict): + name, shard_name = key + out_state_dict[f"{name}/{shard_name}"] = state_dict.pop(key) + return out_state_dict + + +class ExternalModelConverter(ModelConverter): + base_file_name = "model" _base_model_cls: type[BaseModelConfig] _config_converters: list[ParamConverter] @@ -165,12 +189,12 @@ def _create_weight_converters(self) -> list[WeightConverter]: @classmethod @abc.abstractmethod - def load_config(cls, directory: pathlib.Path | str) -> dict[str]: + def load_config(cls, directory: pathlib.Path | str) -> dict[str, typing.Any]: pass @classmethod @abc.abstractmethod - def save_config(cls, directory: pathlib.Path | str, config: dict[str]): + def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): pass @abc.abstractmethod @@ -180,7 +204,7 @@ def load_weights( pass @classmethod - def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str]: + def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]: exported_config = {} for converter in cls._get_config_converters(): value = converter.export_param( @@ -194,7 +218,7 @@ def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str]: return exported_config # Noqa @classmethod - def import_config(cls, config: dict[str], architecture_only: bool = False): # noqa + def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa kwargs = {} for converter in cls._get_config_converters(): value = converter.import_param( @@ -209,24 +233,25 @@ def import_config(cls, config: dict[str], architecture_only: bool = False): # n return config_class.from_dict({}, kwargs) @classmethod - def from_config(cls, config: dict[str], architecture_only: bool = False): + def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): return cls(cls.import_config(config, architecture_only=architecture_only)) def convert_state_dict( - self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool + self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: out_state_dict = {} weight_converters = self._export_converters if export else self._import_converters - for state_dict_name in list(state_dict): + for state_dict_name, shard_name in list(state_dict): + assert shard_name == "weights" try: if state_dict_name not in weight_converters: continue weight_converter: WeightConverter = weight_converters[state_dict_name] in_names = weight_converter.fast_llm_name if export else weight_converter.export_name - if not all(name in state_dict for name in in_names): + if not all((name, shard_name) in state_dict for name in in_names): continue - in_weights = tuple(state_dict.pop(name) for name in in_names) + in_weights = tuple(state_dict.pop((name, shard_name)) for name in in_names) out_names = weight_converter.export_name if export else weight_converter.fast_llm_name out_weights = ( weight_converter.export_weight(in_weights) @@ -262,19 +287,19 @@ def _get_fast_llm_attribute(config: BaseModelArchitectureConfig, name: str | tup return val -class AutoModelConverter(ModelConverter, abc.ABC): - converter_map: dict[str, type[ModelConverter]] +class AutoModelConverter(ExternalModelConverter, abc.ABC): + converter_map: dict[str, type[ExternalModelConverter]] @classmethod - def import_config(cls, config: dict[str], architecture_only: bool = False): + def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): return cls.converter_map[config["model_type"]].import_config(config, architecture_only) @classmethod - def from_config(cls, config: dict[str], architecture_only: bool = False): + def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): return cls.converter_map[config["model_type"]].from_config(config, architecture_only) -class HuggingfaceModelConverter(ModelConverter, abc.ABC): +class HuggingfaceModelConverter(ExternalModelConverter, abc.ABC): model_type: str | None = None @classmethod @@ -292,7 +317,7 @@ def load_config(cls, directory: pathlib.Path | str): return config @classmethod - def save_config(cls, directory: pathlib.Path | str, config: dict[str]): + def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): import transformers transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 1e4deaa8f..edcbdc0e2 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -8,17 +8,19 @@ import torch import yaml -from fast_llm.core.distributed import all_reduce, broadcast, safe_barrier +from fast_llm.core.distributed import all_reduce, broadcast from fast_llm.engine.base_model.base_model import BaseModel -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import ( +from fast_llm.engine.config_utils.checkpoint import ( CHECKPOINT_VERSION, - CheckpointConfig, - CheckpointType, - FastLLMModelConfig, - PretrainedCheckpointConfig, - StageMode, + CheckpointFormat, + CheckpointLoadConfig, + CheckpointSaveConfig, + ModelConfigType, ) +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.checkpoint import StateDictSaver +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode +from fast_llm.engine.multi_stage.conversion import ModelConverter, TrivialConverter from fast_llm.engine.multi_stage.multi_stage import MultiStageModel from fast_llm.engine.multi_stage.stage import Stage from fast_llm.functional.triton.pointwise import triton_fill @@ -83,28 +85,83 @@ def distributed(self): def save_checkpoint( self, - checkpoint_config: CheckpointConfig, + checkpoint_config: CheckpointSaveConfig, metadata: dict | None = None, ): - if metadata is None: - metadata = {} - if checkpoint_config.checkpoint_type == CheckpointType.distributed: - self._save_distributed_checkpoint(checkpoint_config, metadata) - elif checkpoint_config.checkpoint_type == CheckpointType.state_dict: - self._save_state_dict_checkpoint(checkpoint_config, metadata) - elif checkpoint_config.checkpoint_type == CheckpointType.huggingface: - assert checkpoint_config.exported_model_type is not None - self._export_checkpoint(checkpoint_config, metadata) + # TODO: Handle barriers, ok file, mkdir, etc. here + + num_shards = len(self._state_shard_names) if checkpoint_config.optimizer_state else 1 + metadata = { + "checkpoint_type": CheckpointFormat.distributed.value, + "checkpoint_version": str(CHECKPOINT_VERSION), + "fast_llm_config": self._fast_llm_config.to_serialized(), + "state_shard_names": list(self._state_shard_names[:num_shards]), + "metadata": {} if metadata is None else metadata, + } + + # TODO: Simplify branching. + if checkpoint_config.format == CheckpointFormat.external: + # TODO: Support optimizer? + assert not checkpoint_config.optimizer_state + converter_class = self._base_model_config.get_converter_class(checkpoint_config.model_type) + exported_config = converter_class.export_config(self._base_model_config) + converter_class.save_config(checkpoint_config.path, exported_config) + self._save_state_dict( + checkpoint_config, + converter_class(self._base_model_config), + { + "fast_llm_metadata": metadata, + "model_config": exported_config, + "format": "pt", + }, + ) + elif checkpoint_config.format == CheckpointFormat.state_dict: + self._save_state_dict(checkpoint_config, TrivialConverter(), metadata) + elif checkpoint_config.format == CheckpointFormat.distributed: + if self._distributed_config.rank == 0: + yaml.safe_dump(metadata, (checkpoint_config.path / "metadata.yaml").open("w")) + safetensors.torch.save_file( + tensors={"state_shard": self._state_shard[:num_shards]}, + filename=checkpoint_config.path / f"rank_{self._distributed_config.rank}.safetensors", + metadata=_export_safetensors_metadata(metadata), + ) else: - raise NotImplementedError(checkpoint_config.checkpoint_type) + raise NotImplementedError(checkpoint_config.format) + + def _save_state_dict(self, checkpoint_config: CheckpointSaveConfig, converter: ModelConverter, metadata: dict): + with StateDictSaver( + checkpoint_config, + distributed=self._distributed, + metadata=metadata, + base_file_name=converter.base_file_name, + ) as context: + # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from + # `fast_llm_state_dict` that are ready for conversion, + # and return a dict containing the converted tensors(s). + # If converting a tensor requires another one that is not yet available (e.g. for concatenation), + # it will remain in `fast_llm_state_dict` until that tensor is available. + fast_llm_state_dict = {} + for i, shard_name in enumerate( + self._state_shard_names if checkpoint_config.optimizer_state else self._state_shard_names[:1] + ): + shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) + for stage, shard in zip(self._stages_on_device.values(), shard_split): + for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.data_type): # noqa + assert name not in fast_llm_state_dict + fast_llm_state_dict[(name, shard_name)] = tensor + for exported_name, exported_tensor in converter.convert_state_dict( + fast_llm_state_dict, True + ).items(): + context.add_tensor(exported_name, exported_tensor) + assert not fast_llm_state_dict, list(fast_llm_state_dict) - def load_pretrained_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): - if pretrained_config.format == CheckpointType.distributed: + def load_pretrained_checkpoint(self, pretrained_config: CheckpointLoadConfig): + if pretrained_config.format == CheckpointFormat.distributed: # TODO: Check if same format. self._load_distributed_checkpoint(pretrained_config) - elif pretrained_config.format == CheckpointType.state_dict: + elif pretrained_config.format == CheckpointFormat.state_dict: self._load_state_dict_checkpoint(pretrained_config) - elif pretrained_config.format == CheckpointType.huggingface: + elif pretrained_config.format == CheckpointFormat.external: self._import_checkpoint(pretrained_config) else: raise NotImplementedError(pretrained_config.format) @@ -113,7 +170,7 @@ def load_distributed_checkpoint_same_format(self, directory: pathlib.Path): # TODO: Handle barriers, ok file, etc. here # TODO: More safety checks # TODO: Integrate to load_checkpoint. - pretrained_config = PretrainedCheckpointConfig(path=directory, format=CheckpointType.distributed) + pretrained_config = CheckpointLoadConfig(path=directory, format=CheckpointFormat.distributed) metadata = self.config_class.load_pretrained_metadata(pretrained_config) with self._LoadContext(self, safe=False, load_optimizer=True, reset_pads=False) as context: Assert.eq( @@ -132,7 +189,7 @@ def load_distributed_checkpoint_same_format(self, directory: pathlib.Path): @classmethod def from_pretrained( cls, - pretrained_config: PretrainedCheckpointConfig, + pretrained_config: CheckpointLoadConfig, default_config: FastLLMModelConfig = None, *, config_updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -166,7 +223,7 @@ def from_pretrained( model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode) if mode.on_device: - if pretrained_config.load_weights: + if pretrained_config.model_weights: model.load_pretrained_checkpoint(pretrained_config) else: model.initialize_weights() @@ -190,200 +247,6 @@ def _reset_shard_pads(self, optimizer: bool = False): counter += stage.reset_shard_pad(stage_shard) return counter - class _SaveContext: - def __init__( - self, - model: "FastLLMModel", - metadata, - *, - directory: pathlib.Path, - save_optimizer: bool, - ): - assert model._is_setup - assert model._is_loaded - self.save_optimizer = save_optimizer - self.num_shards = len(model._state_shard_names) if self.save_optimizer else 1 - self.self_shard = model._state_shard[: self.num_shards] - self.shard_names = model._state_shard_names[: self.num_shards] - self.directory = directory - self.metadata = { - "checkpoint_type": CheckpointType.distributed.value, - "checkpoint_version": str(CHECKPOINT_VERSION), - "fast_llm_config": model.fast_llm_config.to_serialized(), - "state_shard_names": list(model._state_shard_names[: self.num_shards]), - "metadata": metadata, - } - - def __enter__(self): - # Is a context for future-proofing. - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return - - class _SaveStateDictContext(_SaveContext): - def __init__( - self, - model: "FastLLMModel", - metadata, - *, - directory: pathlib.Path, - save_optimizer: bool, - base_file_name: str, - target_params_per_file, - ): - super().__init__(model, metadata, directory=directory, save_optimizer=save_optimizer) - self.distributed_config = model._distributed_config - self.distributed = model._distributed - self.base_file_name = ( - base_file_name - if self.distributed_config.pipeline_parallel == 1 - else f"{base_file_name}_{self.distributed_config.pipeline_rank}" - ) - self.target_params_per_file = target_params_per_file - # All ranks reconstruct the pipeline-parallel state (for simplicity), but only one saves it. - self.do_save = self.distributed_config.data_rank == self.distributed_config.tensor_rank == 0 - - def add_tensor(self, name: str, tensor: torch.Tensor): - assert name not in self.tensors - self.tensors[name] = tensor - self.param_count += tensor.numel() - if self.param_count >= self.target_params_per_file: - self._save_next_file() - - def __enter__(self): - self.file_count = 0 - self.param_count = 0 - self.tensors = {} - self.index = {} - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.tensors: - # Save the last file. - self._save_next_file() - - if self.do_save and self.distributed_config.pipeline_parallel != 1: - # Combine the indexes from all pipeline ranks. - logger.info(f"Merging pipeline-parallel indexes.") - json.dump( - self.index, (self.directory / f"index_{self.distributed_config.pipeline_rank}.json").open("w") - ) - safe_barrier(self.distributed.pipeline_group, "save state dict") - if self.distributed_config.pipeline_rank == 0: - self.index = {} - for rank in range(self.distributed_config.pipeline_parallel): - file_name = self.directory / f"index_{rank}.json" - local_index = json.load(file_name.open("r")) - for key, value in local_index.items(): - assert key not in self.index, key - self.index[key] = value - file_name.unlink() - - if self.distributed_config.rank == 0: - path = self.directory / f"{self.base_file_name}.safetensors.index.json" - logger.info(f"Saving index to {path}") - # Save the index. - json.dump( - {"metadata": self.metadata, "weight_map": self.index}, - path.open("w"), - indent=4, - ) - - def _save_next_file(self): - file_name = f"{self.base_file_name}_{self.file_count}.safetensors" - if self.do_save: - logger.info(f"Saving tensors to {self.directory/file_name}") - safetensors.torch.save_file( - tensors=self.tensors, - filename=self.directory / file_name, - metadata=_export_safetensors_metadata(self.metadata), - ) - for name_ in self.tensors: - assert name_ not in self.index - self.index[name_] = file_name - self.file_count += 1 - self.param_count = 0 - self.tensors = {} - - def _save_distributed_checkpoint(self, checkpoint_config: CheckpointConfig, metadata: dict): - # TODO: Non blocking? - # TODO: CPU memory? - # TODO: Handle barriers, ok file, mkdir, etc. here - with self._SaveContext( - self, - metadata, - directory=checkpoint_config.checkpoint_path, - save_optimizer=checkpoint_config.save_optimizer, - ) as context: - if self._distributed_config.rank == 0: - yaml.safe_dump(context.metadata, (context.directory / "metadata.yaml").open("w")) - safetensors.torch.save_file( - tensors={"state_shard": self._state_shard[: context.num_shards]}, - filename=context.directory / f"rank_{self._distributed_config.rank}.safetensors", - metadata=_export_safetensors_metadata(context.metadata), - ) - - def _save_state_dict_checkpoint(self, checkpoint_config: CheckpointConfig, metadata: dict): - # TODO: Make into a special case of _export_checkpoint? - # TODO: Handle barriers, ok file, mkdir, etc. here - with self._SaveStateDictContext( - self, - metadata, - directory=checkpoint_config.checkpoint_path, - save_optimizer=checkpoint_config.save_optimizer, - base_file_name="state_dict", - target_params_per_file=checkpoint_config.target_params_per_file, - ) as context: - for i, shard_name in enumerate( - self._state_shard_names if checkpoint_config.save_optimizer else self._state_shard_names[:1] - ): - shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) - for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.dtype): # noqa - context.add_tensor(f"{name}/{shard_name}", tensor) - - def _export_checkpoint(self, checkpoint_config: CheckpointConfig, metadata: dict): - # TODO: Handle barriers, ok file, mkdir, etc. here - # TODO: Support optimizer? - assert not checkpoint_config.save_optimizer - with self._SaveStateDictContext( - self, - metadata, - directory=checkpoint_config.checkpoint_path, - save_optimizer=False, - base_file_name="model", - target_params_per_file=checkpoint_config.target_params_per_file, - ) as context: - converter_class = self._base_model_config.get_converter_class(checkpoint_config.exported_model_type) - exported_config = converter_class.export_config(self._base_model_config) - converter_class.save_config(checkpoint_config.checkpoint_path, exported_config) - context.metadata = { - "fast_llm_metadata": context.metadata, - "model_config": exported_config, - "format": "pt", - } - converter = converter_class(self._base_model_config) - # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from - # `fast_llm_state_dict` that are ready for conversion, - # and return a dict containing the converted tensors(s). - # If converting a tensor requires another one that is not yet available (e.g. for concatenation), - # it will remain in `fast_llm_state_dict` until that tensor is available. - fast_llm_state_dict = {} - for i, shard_name in enumerate( - self._state_shard_names if checkpoint_config.save_optimizer else self._state_shard_names[:1] - ): - shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) - for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.dtype): # noqa - assert name not in fast_llm_state_dict - fast_llm_state_dict[name] = tensor - for exported_name, exported_tensor in converter.convert_state_dict( - fast_llm_state_dict, True - ).items(): - context.add_tensor(exported_name, exported_tensor) - assert not fast_llm_state_dict, list(fast_llm_state_dict) - class _LoadContext: # TODO: Improve def __init__(self, model: "FastLLMModel", *, safe: bool, load_optimizer: bool, reset_pads: bool): @@ -568,21 +431,16 @@ def import_state_tensor(self, shard_name: str, parameter_name: str, tensor: torc self.loaded_parameters[shard_name] = {} self.loaded_parameters[shard_name][parameter_name] = loaded - def _load_distributed_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): + def _load_distributed_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: More safety checks metadata = self.config_class.load_pretrained_metadata(pretrained_config) - loaded_pretrained_config = pretrained_config.to_copy( - { - "load_full_base_model_config": True, - "load_full_fast_llm_config": True, - }, - ) + loaded_pretrained_config = pretrained_config.to_copy({"load_config": ModelConfigType.fast_llm}) loaded_config = self.config_class.from_metadata( loaded_pretrained_config, metadata, ) with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True ) as context: Assert.eq(metadata["state_shard_names"][: context.num_shards], list(context.shard_names)) @@ -601,12 +459,12 @@ def _load_distributed_checkpoint(self, pretrained_config: PretrainedCheckpointCo return metadata["metadata"] - def _load_state_dict_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): + def _load_state_dict_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: Make into a special case of _import_state_tensor? # TODO: Verify more distributed configs. # TODO: More safety checks with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True ) as context: index_path = pretrained_config.path / f"state_dict.safetensors.index.json" logger.info(f"Loading index from {index_path}") @@ -627,23 +485,23 @@ def _load_state_dict_checkpoint(self, pretrained_config: PretrainedCheckpointCon return metadata["metadata"] - def _import_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): + def _import_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: Support optimizer? - assert not pretrained_config.load_optimizer + assert not pretrained_config.optimizer_state # TODO: Verify more distributed configs. # TODO: Safety checks - converter_class = self.base_model.architecture_cls().get_converter_class(pretrained_config.imported_type) + converter_class = self.base_model.architecture_cls().get_converter_class(pretrained_config.model_type) converter = converter_class.from_config(converter_class.load_config(pretrained_config.path)) self._base_model_config.compare_architecture(converter.config, pretrained_config.compare_log_fn) state_dict = {} with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True ) as context: for name, tensor in converter.load_weights(pretrained_config.path, self._distributed.device): assert name not in state_dict - state_dict[name] = tensor + state_dict[(name, "weights")] = tensor for parameter_name, fast_llm_tensor in converter.convert_state_dict(state_dict, False).items(): context.import_state_tensor("weights", parameter_name, fast_llm_tensor) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 6f2fce854..655f29677 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -1,11 +1,19 @@ import argparse import os +import pathlib import shlex import subprocess import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.data.config import AbstractDataConfig +from fast_llm.engine.config_utils.checkpoint import ( + CheckpointConfigBase, + CheckpointFormat, + CheckpointSaveConfig, + CheckpointSaveConfigBase, + CheckpointStateConfigBase, +) from fast_llm.engine.config_utils.run import ExperimentConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig @@ -17,48 +25,85 @@ from fast_llm.engine.training.trainer import Trainer -def get_interval_config_class(desc: str, offset_desc: str | None = None): - # Intervals are a common pattern, so we standardize them with this helper. - @config_class() - class IntervalConfig(Config): - interval: int | None = Field( - default=None, - desc=f"The number of training iterations between each {desc}. Setting to None will disable.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - offset: int = Field( - default=0, - desc=f"Offset for the first {offset_desc or desc}.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), +@config_class() +class IntervalConfig(Config): + # Intervals are a common pattern, so we standardize them with this base class. + interval: int | None = Field( + default=None, + desc="The number of training iterations between each interval. Setting to None will disable.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + offset: int = Field( + default=0, + desc="Offset for the first interval.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + + def _validate(self): + if self.interval: + self.offset %= self.interval + super()._validate() + + def enabled(self, iteration: int | None = None): + return self.interval and (iteration is None or (iteration - self.offset) % self.interval == 0) + + def is_sub_interval(self, other: "IntervalConfig"): + if not self.enabled(): + return True + elif not other.enabled(): + return False + return self.interval % other.interval == 0 and (other.offset % other.interval) == ( + self.offset % other.interval ) - def enabled(self, iteration: int | None = None): - return self.interval and (iteration is None or (iteration - self.offset) % self.interval == 0) + def assert_sub_interval(self, other: "IntervalConfig"): + assert self.is_sub_interval(other), f"{self} is not a sub-interval of {other}" - def is_sub_interval(self, other: "IntervalConfig"): - if not self.enabled(): - return True - elif not other.enabled(): - return False - return self.interval % other.interval == 0 and (other.offset % other.interval) == ( - self.offset % other.interval - ) + def get_count(self, iteration): + # Number of times this interval was enabled after a given iteration. + return (iteration - self.offset) // self.interval + 1 if self.enabled() else 0 - def assert_sub_interval(self, other: "IntervalConfig"): - assert self.is_sub_interval(other), f"{self} is not a sub-interval of {other}" - return IntervalConfig +def _validate_script(value): + if isinstance(value, str): + value = shlex.split(value) + Assert.geq(len(value), 1) + return value @config_class() -class WandbAlertConfig( - get_interval_config_class( - "Wandb status post (alert). Must be a multiple of the logging interval", - "Wandb status post (alert). Must be compatible with the logging offset", +class CallbackConfig(Config): + script: list[str] | None = Field( + default=None, + desc="Shell script to run.", + hint=FieldHint.feature, + valid=skip_valid_if_none(_validate_script), + ) + environment: dict[str, str] = Field( + default_factory=dict, + desc="Environment variables to add to the script.", + hint=FieldHint.feature, + ) + + def run(self): + if self.script is not None: + environment = os.environ.copy() + environment.update(self.environment) + subprocess.Popen(self.script, env=environment) + + +@config_class() +class WandbAlertConfig(IntervalConfig): + interval = FieldUpdate( + desc="The number of training iterations between each Wandb status post (alert)." + " Setting to None will disable iteration-based wandb alerts." + " Must be a sub-interval of the logging interval." + ) + offset = FieldUpdate( + desc="Offset for the first Wandb status post (alert)." " Must be compatible with the logging offset.", ) -): status_updates: bool | None = Field( default=None, desc="Post wandb status updates on status changes (run begin/end). " @@ -73,15 +118,21 @@ def _validate(self): @config_class() -class MetricsLogsConfig(get_interval_config_class("metric logs")): - pass +class MetricsLogsConfig(IntervalConfig): + interval = FieldUpdate( + default=100, + desc="The number of training iterations between each metric logs." + " Setting to None will disable metric logging.", + ) + offset = FieldUpdate(desc="Offset for the first metric logs.") @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( default_factory=WandbAlertConfig, - desc="Configuration for Wandb alerts. The alerts may be posted by email and/or slack depending on the Wandb account configuration.", + desc="Configuration for Wandb alerts." + " The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, ) group_name: str = Field(default="default", desc="A group name for Wandb", hint=FieldHint.feature) @@ -90,7 +141,12 @@ class WandbConfig(Config): @config_class() -class ValidationConfig(get_interval_config_class("validation")): +class ValidationConfig(IntervalConfig): + interval = FieldUpdate( + desc="The number of training iterations between each validation phase." + " Setting to None will disable validation." + ) + offset = FieldUpdate(desc="Offset for the first validation phase.") iterations: int | None = Field( default=None, desc="Number of iterations for each validation phase. Setting to None will disable.", @@ -98,63 +154,96 @@ class ValidationConfig(get_interval_config_class("validation")): valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - def get_completed_iterations(self, training_iterations: int, completed_validations: int = 0): + def get_iteration_count(self, training_iterations: int, extra_validations: int = 0): # Number of completed validation iterations - return ( - (training_iterations // self.interval + completed_validations) * self.iterations if self.enabled() else 0 - ) + return (self.get_count(training_iterations) + extra_validations) * self.iterations if self.enabled() else 0 @config_class() -class CheckpointConfig(get_interval_config_class("checkpoint")): +class CheckpointBaseConfig(IntervalConfig): + _abstract = True + save_name: typing.ClassVar[str] = "save" + directory_name: typing.ClassVar[str] = "save" + callback: CallbackConfig = Field( + default_factory=CallbackConfig, + desc="Callback (shell script).", + hint=FieldHint.core, + ) keep: int | None = Field( - default=5, - desc="The maximum number of checkpoints to keep. When exceeding this value, checkpoints are deleted starting from the older ones.", + default=None, + desc="The maximum number of saves to keep. When exceeding this value, checkpoints are deleted starting from the older ones.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + keep_every: int | None = Field( + default=None, + desc="Keep every nth saves, i.e. Exclude it from the checkpoint count and deletion in `keep`.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + def get_save_config(self, path: pathlib.Path): + raise NotImplementedError() -def _validate_script(value): - if isinstance(value, str): - value = shlex.split(value) - Assert.geq(len(value), 1) - return value + def to_delete(self, iterations: list[int]): + if not self.keep: + return [] + # Ignore checkpoints that aren't supposed to be there. + iterations = [iteration for iteration in iterations if self.enabled(iteration)] + # Ignore excluded checkpoints. + if self.keep_every: + iterations = [iteration for iteration in iterations if self.get_count(iteration) % self.keep_every != 0] + # Exclude the last `keep`. + return iterations[: -self.keep] @config_class() -class CallbackConfig(Config): - script: list[str] | None = Field( - default=None, - desc="Shell script to run after.", - hint=FieldHint.feature, - valid=skip_valid_if_none(_validate_script), - ) - environment: dict[str, str] = Field( - default_factory=dict, - desc="Environment variables to add to the script.", - hint=FieldHint.feature, +class CheckpointConfig(CheckpointBaseConfig): + _abstract = False + save_name: typing.ClassVar[str] = "checkpoint" + # TODO v0.2: Rename to `checkpoint` so we don't need this extra variable? + directory_name = "checkpoints" + interval = FieldUpdate( + desc="The number of training iterations between each checkpoint." " Setting to None will disable checkpoints." ) - - def run(self): - if self.script is not None: - environment = os.environ.copy() - environment.update(self.environment) - subprocess.Popen(self.script, env=environment) + offset = FieldUpdate(desc="Offset for the first checkpoint.") + callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after checkpoint.") + keep: int | None = FieldUpdate(default=5) + + def get_save_config(self, path: pathlib.Path): + return CheckpointSaveConfig( + path=path, + format=CheckpointFormat.distributed, + model_weights=True, + optimizer_state=True, + ) @config_class() -class ExportConfig(get_interval_config_class("export")): - callback: CallbackConfig = Field( - default_factory=CallbackConfig, - desc="Callback (shell script) to run after export.", - hint=FieldHint.core, +class ExportConfig(CheckpointBaseConfig, CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): + _abstract = False + save_name: typing.ClassVar[str] = "export" + directory_name = "export" + interval = FieldUpdate( + desc="The number of training iterations between each export." " Setting to None will disable exports." ) + offset = FieldUpdate(desc="Offset for the first export.") + callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") + + def get_save_config(self, path: pathlib.Path): + return CheckpointSaveConfig.from_dict(self, {"path": path}, strict=False) @config_class() -class ShutdownConfig(get_interval_config_class("automated shutdown")): - pass +class ShutdownConfig(IntervalConfig): + interval = FieldUpdate( + desc="The number of training iterations between each automated shutdown." + " Setting to None will disable automated shutdowns." + " Must be a sub-interval of the checkpoint interval." + ) + offset = FieldUpdate( + desc="Offset for the first automated shutdown." " Must be compatible with the checkpoint offset." + ) @config_class() @@ -201,7 +290,6 @@ class TrainingConfig(Config): def _validate(self): super()._validate() - self.export.assert_sub_interval(self.checkpoint) self.shutdown.assert_sub_interval(self.checkpoint) self.wandb.alert.assert_sub_interval(self.logs) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a17eca121..ff8a5f029 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -1,6 +1,8 @@ import abc import logging import math +import pathlib +import shutil import time import typing @@ -12,13 +14,12 @@ from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import CheckpointConfig, CheckpointType from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.engine.training.config import TrainerConfig +from fast_llm.engine.training.config import CheckpointBaseConfig, TrainerConfig from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage from fast_llm.utils import Assert @@ -56,8 +57,10 @@ def __init__(self, config: TrainerConfig): ) steps_per_split = { PhaseType.training: self._config.training.train_iters, - PhaseType.validation: self._config.training.validation.get_completed_iterations( - self._config.training.train_iters, 1 + PhaseType.validation: self._config.training.validation.get_iteration_count( + self._config.training.train_iters, + # There may be an extra validation after the last training step. + not self._config.training.validation.enabled(self._config.training.train_iters), ), PhaseType.test: self._config.training.test_iters, } @@ -130,7 +133,7 @@ def _consumed_tokens(self): @property def _completed_validation_steps(self) -> int: # Number of validation steps performed before the current step - return self._config.training.validation.get_completed_iterations(self._completed_steps - 1) + return self._config.training.validation.get_iteration_count(self._completed_steps - 1) def run(self): assert self._is_setup @@ -311,10 +314,10 @@ def _train(self): self._wandb.log_metrics(self._completed_steps, metrics) if self._config.training.checkpoint.enabled(None if stop else self._completed_steps): - self._save_checkpoint( - metrics, - export=self._config.training.export.enabled(None if done else self._completed_steps), - ) + self._save_checkpoint(self._config.training.checkpoint, metrics) + + if self._config.training.export.enabled(None if done else self._completed_steps): + self._save_checkpoint(self._config.training.export, metrics) return done, metrics @@ -362,13 +365,22 @@ def _evaluate( return metrics + def _get_data_iterator(self, phase, completed_steps: int = 0, prefetch_factor: int | None = None): + return self._data.get_iterator( + self._config.batch, + phase, + consumed_samples=completed_steps * self._config.batch.batch_size, + num_workers=self._config.training.num_workers, + prefetch_factor=prefetch_factor, + ) + def _prepare_training_state(self): # Setup the training state. - if (last_iteration := self._run.get_last_checkpoint()) is None: - if (path := self._config.pretrained.path) is not None and self._config.pretrained.load_weights: + if (last_iteration := self._get_last_checkpoint()) is None: + if (path := self._config.pretrained.path) is not None and self._config.pretrained.model_weights: log_main_rank( f"Initializing training state from pretrained checkpoint at {path}" - f" ({'loading' if self._config.pretrained.load_optimizer else 'resetting'}" + f" ({'loading' if self._config.pretrained.optimizer_state else 'resetting'}" f" optimizer state)..." ) self._multi_stage.load_pretrained_checkpoint(self._config.pretrained) @@ -379,44 +391,76 @@ def _prepare_training_state(self): self._completed_steps = 0 else: log_main_rank(lambda: f"Loading checkpoint from iteration {last_iteration}...") - with self._run.get_load_checkpoint_context(last_iteration) as context: - metadata = self._multi_stage.load_distributed_checkpoint_same_format(context.directory) - self._optimizer.load(metadata["optimizer"]) - if "schedules" in metadata: - # Backward compatibility. - self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] - else: - self._completed_steps = metadata["completed_steps"] + self._load_checkpoint(self._config.training.checkpoint, last_iteration) Assert.eq(self._completed_steps, last_iteration or 0) assert self._multi_stage._is_loaded # noqa - def _get_data_iterator(self, phase, completed_steps: int = 0, prefetch_factor: int | None = None): - return self._data.get_iterator( - self._config.batch, - phase, - consumed_samples=completed_steps * self._config.batch.batch_size, - num_workers=self._config.training.num_workers, - prefetch_factor=prefetch_factor, - ) - - def _save_checkpoint(self, metrics: dict[PhaseType, dict[str, float | int]] | None, export: bool = False): - assert self._is_setup - with self._run.get_save_checkpoint_context( - self._completed_steps, export, self._config.training.checkpoint.keep - ) as checkpoint: - metadata = { - "optimizer": self._optimizer.save(), - "completed_steps": self._completed_steps, - } - if metrics is not None: - metadata["metrics"] = {key.value: value for key, value in metrics.items()} - self._multi_stage.save_checkpoint( - CheckpointConfig(checkpoint_type=CheckpointType.distributed, checkpoint_path=checkpoint.directory), - metadata, - ) - if export and self._run.is_main_rank: # noqa - self._config.training.export.callback.run() + def _save_checkpoint(self, config: CheckpointBaseConfig, metrics: dict[PhaseType, dict[str, float | int]] | None): + # TODO v0.2: Move barrier, ok file to FastLLMModel + checkpoint_base_directory = self._run.experiment_directory / config.directory_name + checkpoint_directory = checkpoint_base_directory / str(self._completed_steps) + + # Create the checkpoint + if self._run.is_main_rank: + logger.info(f"Saving {config.save_name} at iteration {self._completed_steps}") + checkpoint_directory.mkdir(exist_ok=False, parents=True) + # Barrier to ensure the directory is created correctly (and didn't exist before). + self._run.barrier(f"{config.save_name} {self._completed_steps} enter") + + metadata = { + "optimizer": self._optimizer.save(), + "completed_steps": self._completed_steps, + } + if metrics is not None: + metadata["metrics"] = {key.value: value for key, value in metrics.items()} + self._multi_stage.save_checkpoint(config.get_save_config(checkpoint_directory), metadata) + + # Barrier to ensure everyone is done. + self._run.barrier(f"{config.save_name} {self._completed_steps} exit") + # Mark the checkpoint as complete. + if self._run.is_main_rank: + (checkpoint_directory / "ok").open("w") + logger.info(f"Saved {config.save_name} to {checkpoint_directory}") + + to_delete = config.to_delete(sorted(int(path.name) for path in checkpoint_base_directory.iterdir())) + + for iteration in to_delete: + path = checkpoint_base_directory / str(iteration) + logger.info(f"Deleting {config.save_name} at {path}") + try: + shutil.rmtree(path, ignore_errors=True) + except OSError as e: + logger.warning(f"Could not remove {config.save_name} directory: {e.args}") + + config.callback.run() + + def _load_checkpoint(self, config: CheckpointBaseConfig, iteration: int): + checkpoint_directory = self._run.experiment_directory / config.directory_name / str(iteration) + Assert.custom(pathlib.Path.is_file, checkpoint_directory / "ok") + # TODO v0.2: Use config.get_load_config to make it generic + # TODO v0.2: Detect format instead of hard-coding + metadata = self._multi_stage.load_distributed_checkpoint_same_format(checkpoint_directory) + self._optimizer.load(metadata["optimizer"]) + if "schedules" in metadata: + # Backward compatibility. + self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] + else: + self._completed_steps = metadata["completed_steps"] + # TODO v0.2: Move barrier, ok file to FastLLMModel + self._run.barrier(f"load {config.save_name} {iteration} exit") + + def _get_last_checkpoint(self): + if self._run.experiment_directory is None: + return None + checkpoint_base_directory = self._run.experiment_directory / self._config.training.checkpoint.directory_name + if self._run.is_main_rank and checkpoint_base_directory.is_dir(): + checkpoints = [int(path.name) for path in checkpoint_base_directory.iterdir()] + iteration = max(checkpoints) if checkpoints else -1 + else: + iteration = -1 + iteration = self._run.broadcast_int(iteration) + return iteration if iteration >= 0 else None @abc.abstractmethod def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 74684dfe8..1c896e3b1 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -54,10 +54,11 @@ class NormalizationArchitectureConfig(BaseModelArchitectureConfig): @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): + # TODO v0.2: Remove. cls._handle_renamed_field(default, "normalization_type", "type") cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") @@ -106,7 +107,7 @@ def get_layer(self, hidden_dim: "TensorDim"): @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3976c69bc..4e3058995 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,4 +1,6 @@ -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +import typing + +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames @@ -84,7 +86,7 @@ def use_absolute_position_embeddings(self): @classmethod def from_flat_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, ): # The backward compatibility fix in `NormalizationArchitectureConfig` @@ -109,9 +111,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_cls = LanguageModelArchitectureConfig - transformer: TransformerConfig = Field( - default_factory=TransformerConfig, desc="Configuration for the transformer.", hint=FieldHint.core - ) + transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 8109f2052..e6c133f14 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -3,7 +3,7 @@ import math import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace @@ -262,11 +262,7 @@ def use_rotary_position_embeddings(self): @config_class() class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): - normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, - desc="Configuration for the normalization layers.", - hint=FieldHint.core, - ) + normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) init_method_std: float = Field( diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 1f6bb0168..82e004319 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -1,4 +1,4 @@ -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import FieldUpdate, config_class from fast_llm.data.config import DataConfig from fast_llm.models.gpt.config import ( GPTArchitectureConfig, @@ -30,7 +30,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig, CustomArchitectureConfig): @config_class() class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). - base_model: CustomBaseModelConfig = Field(default_factory=CustomBaseModelConfig) + base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) @classmethod def get_model_class(cls): @@ -47,18 +47,14 @@ def get_huggingface_model_class(cls): @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = Field(default_factory=CustomModelConfig) + model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = Field( - default_factory=CustomDataConfig, - desc="Configuration for the dataset and model-independent preprocessing.", - hint=FieldHint.core, - ) + data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) @classmethod def get_trainer_class(cls): diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 4725f1e86..f13488e3e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,6 +1,6 @@ import typing -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.config import DataConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -8,7 +8,7 @@ from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds if typing.TYPE_CHECKING: - from fast_llm.engine.multi_stage.conversion import ModelConverter + from fast_llm.engine.multi_stage.conversion import ExternalModelConverter @config_class() @@ -28,7 +28,7 @@ def _from_dict( return super()._from_dict(default, strict, flat) @classmethod - def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]: + def get_converter_class(cls, model_type: str | None = None) -> type["ExternalModelConverter"]: from fast_llm.models.gpt.conversion import AutoGPTConverter return AutoGPTConverter if model_type is None else AutoGPTConverter.converter_map[model_type] @@ -65,7 +65,7 @@ def _from_dict( @config_class() class GPTModelConfig(FastLLMModelConfig): _abstract = False - base_model: GPTBaseModelConfig = Field(default_factory=GPTBaseModelConfig) + base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) @classmethod def get_model_class(cls): @@ -83,17 +83,13 @@ def get_huggingface_model_class(cls): @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = Field(default_factory=GPTModelConfig) + model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: DataConfig = Field( - default_factory=DataConfig, - desc="Configuration for the dataset and model-independent preprocessing.", - hint=FieldHint.core, - ) + data: DataConfig = FieldUpdate(default_factory=DataConfig) def _setup(self): super()._setup() diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index d2305ade7..1ef3c0494 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -7,15 +7,10 @@ import typing from fast_llm.config import Field, config_class +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.multi_stage.config import ( - CheckpointConfig, - CheckpointType, - FastLLMModelConfig, - PretrainedCheckpointConfig, - StageMode, -) +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.functional.config import TritonConfig from fast_llm.models.auto import model_registry from fast_llm.utils import Assert @@ -28,8 +23,8 @@ @config_class() class ConversionConfig(RunnableConfig): - input_type: CheckpointType = Field() - output_type: CheckpointType = Field() + input_type: CheckpointFormat = Field() + output_type: CheckpointFormat = Field() input_path: pathlib.Path = Field() output_path: pathlib.Path = Field() model_type: str | None = Field(default=None) @@ -62,10 +57,10 @@ def _convert_model_partial( ): logger.info(f"Loading {self.input_type} checkpoint from {self.input_path}...") model = model_class.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=self.input_path, format=self.input_type, - imported_type=self.model_type, + model_type=self.model_type, ), mode=StageMode.weights, use_cpu=self.use_cpu, @@ -74,13 +69,13 @@ def _convert_model_partial( logger.info(f"Saving {self.output_type} checkpoint to {output_path}...") output_path.mkdir(parents=True, exist_ok=self.exist_ok) model.save_checkpoint( - CheckpointConfig( - checkpoint_path=output_path, - checkpoint_type=self.output_type, - exported_model_type=self.model_type, - save_optimizer=False, - target_params_per_file=self.target_params_per_file, - dtype=self.dtype, + CheckpointSaveConfig( + path=output_path, + format=self.output_type, + model_type=self.model_type, + optimizer_state=False, + parameters_per_file=self.target_params_per_file, + data_type=self.dtype, ) ) (output_path / "ok").open("w") @@ -106,15 +101,15 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): self._convert_model_partial(model_class, self.output_path) else: # TODO: Support other types? - assert self.output_type == CheckpointType.huggingface + assert self.output_type == CheckpointFormat.external logger.info(f">>> Loading model config") # Create a dummy version to determine the stage split. model = model_class.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=self.input_path, format=self.input_type, - imported_type=self.model_type, - load_pretrained_weights=False, + model_type=self.model_type, + model_weights=False, ), mode=StageMode.off_device, use_cpu=self.use_cpu, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 50ba105fa..5fd272a47 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -7,7 +7,8 @@ import transformers import yaml -from fast_llm.engine.multi_stage.config import CheckpointType, PretrainedCheckpointConfig, StageMode +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig, ModelConfigType +from fast_llm.engine.multi_stage.config import StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig from tests.common import ( @@ -88,9 +89,9 @@ def _run_conversion(config: ConversionConfig): def test_convert_distributed_to_state_dict(): _run_conversion( ConversionConfig( - input_type=CheckpointType.distributed, + input_type=CheckpointFormat.distributed, input_path=_CKPT_PATH, - output_type=CheckpointType.state_dict, + output_type=CheckpointFormat.state_dict, output_path=_CONVERT_PATH / "state_dict_0", ) ) @@ -102,9 +103,9 @@ def test_convert_state_dict_to_huggingface(): pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( ConversionConfig( - input_type=CheckpointType.state_dict, + input_type=CheckpointFormat.state_dict, input_path=_CONVERT_PATH / "state_dict_0", - output_type=CheckpointType.huggingface, + output_type=CheckpointFormat.external, output_path=_CONVERT_PATH / "huggingface_0", model_type=HUGGINGFACE_MODEL_TYPE, ) @@ -115,9 +116,9 @@ def test_convert_state_dict_to_huggingface(): def test_convert_huggingface_to_distributed(): _run_conversion( ConversionConfig( - input_type=CheckpointType.huggingface, + input_type=CheckpointFormat.external, input_path=_CONVERT_PATH / "huggingface_0", - output_type=CheckpointType.distributed, + output_type=CheckpointFormat.distributed, output_path=_CONVERT_PATH / "distributed_0", ) ) @@ -129,9 +130,9 @@ def test_convert_distributed_to_huggingface(): pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( ConversionConfig( - input_type=CheckpointType.distributed, + input_type=CheckpointFormat.distributed, input_path=_CKPT_PATH, - output_type=CheckpointType.huggingface, + output_type=CheckpointFormat.external, output_path=_CONVERT_PATH / "huggingface_1", model_type=HUGGINGFACE_MODEL_TYPE, ) @@ -142,9 +143,9 @@ def test_convert_distributed_to_huggingface(): def test_convert_huggingface_to_state_dict(): _run_conversion( ConversionConfig( - input_type=CheckpointType.huggingface, + input_type=CheckpointFormat.external, input_path=_CONVERT_PATH / "huggingface_1", - output_type=CheckpointType.state_dict, + output_type=CheckpointFormat.state_dict, output_path=_CONVERT_PATH / "state_dict_1", ) ) @@ -154,9 +155,9 @@ def test_convert_huggingface_to_state_dict(): def test_convert_state_dict_to_distributed(): _run_conversion( ConversionConfig( - input_type=CheckpointType.state_dict, + input_type=CheckpointFormat.state_dict, input_path=_CONVERT_PATH / "state_dict_1", - output_type=CheckpointType.distributed, + output_type=CheckpointFormat.distributed, output_path=_CONVERT_PATH / "distributed_1", ) ) @@ -206,12 +207,11 @@ def test_load_pretrained_distributed_checkpoint(): config = TEST_ARCHITECTURE_CONFIG_CLS.from_dict( yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r")), strict=False ) - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, - load_optimizer=True, - load_full_base_model_config=True, - load_full_fast_llm_config=True, + format=CheckpointFormat.distributed, + optimizer_state=True, + load_config=ModelConfigType.fast_llm, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config, model._base_model_config) @@ -223,14 +223,14 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig(path=_CKPT_PATH, format=CheckpointType.distributed) - pretrained_config_0 = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointFormat.distributed) + pretrained_config_0 = CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_0", - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) - pretrained_config_1 = PretrainedCheckpointConfig( + pretrained_config_1 = CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_1", - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) @@ -245,13 +245,9 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_state_dict", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_state_dict_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig(path=_CKPT_PATH, format=CheckpointType.distributed) - pretrained_config_0 = PretrainedCheckpointConfig( - path=_CONVERT_PATH / "state_dict_0", format=CheckpointType.state_dict - ) - pretrained_config_1 = PretrainedCheckpointConfig( - path=_CONVERT_PATH / "state_dict_1", format=CheckpointType.state_dict - ) + pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointFormat.distributed) + pretrained_config_0 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_0", format=CheckpointFormat.state_dict) + pretrained_config_1 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_1", format=CheckpointFormat.state_dict) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) @@ -265,17 +261,17 @@ def test_load_converted_state_dict_checkpoint(): @pytest.mark.depends(on=["test_converted_state_dict", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_huggingface_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) - pretrained_config_0 = PretrainedCheckpointConfig( + pretrained_config_0 = CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointType.huggingface, + format=CheckpointFormat.external, ) - pretrained_config_1 = PretrainedCheckpointConfig( + pretrained_config_1 = CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", - format=CheckpointType.huggingface, + format=CheckpointFormat.external, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) @@ -291,9 +287,9 @@ def test_load_converted_huggingface_checkpoint(): @pytest.mark.depends(on=["test_load_converted_state_dict_checkpoint", "test_load_converted_huggingface_checkpoint"]) def test_run_converted_model(): model_ref = TEST_MODEL_HF_CLS.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) ) test_input = torch.randint( @@ -302,9 +298,9 @@ def test_run_converted_model(): output_ref = model_ref(test_input) model_from_state_dict = TEST_MODEL_HF_CLS.from_pretrained(_CONVERT_PATH / "state_dict_0") model_from_hf = TEST_MODEL_HF_CLS.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointType.huggingface, + format=CheckpointFormat.external, ) ) errors = [] @@ -364,15 +360,15 @@ def test_load_pretrained_distributed_with_config(): @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1" - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, - load_full_fast_llm_config=True, + format=CheckpointFormat.distributed, + load_config=ModelConfigType.fast_llm, ) - pretrained_config_test = PretrainedCheckpointConfig( + pretrained_config_test = CheckpointLoadConfig( path=test_ckpt_path, - format=CheckpointType.distributed, - load_full_fast_llm_config=True, + format=CheckpointFormat.distributed, + load_config=ModelConfigType.fast_llm, ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) @@ -406,15 +402,14 @@ def test_load_pretrained_in_dp2_match_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_distributed_checkpoint_dp2(): # This also tests conversion which uses `FastLLMModel.from_checkpoint` - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, - load_full_base_model_config=True, - load_full_fast_llm_config=True, + format=CheckpointFormat.distributed, + load_config=ModelConfigType.fast_llm, ) - pretrained_config_test = PretrainedCheckpointConfig( + pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1", - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) @@ -466,7 +461,7 @@ def test_load_pretrained_huggingface_in_dp2(): "training.checkpoint.interval=1", "training.train_iters=1", f"pretrained.path={_CONVERT_PATH / 'huggingface_0'}", - f"pretrained.format=huggingface", + f"pretrained.format=external", "schedule.skip_step=True", ], num_gpus=2, diff --git a/tools/push_model.py b/tools/push_model.py index 78d607d2d..8a0f3747c 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -6,6 +6,8 @@ import shutil import subprocess +from fast_llm.config import Field, config_class +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig try: @@ -25,8 +27,6 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.config import Config, config_class, Field # isort:skip -from fast_llm.engine.multi_stage.config import CheckpointType # isort:skip from fast_llm.tools.convert import ConversionConfig # isort:skip @@ -148,8 +148,8 @@ def run(self) -> None: checkpoint_path_hf = checkpoint_path.with_name(checkpoint_path.name + "_hf") # Block until the conversion is done ConversionConfig( - input_type=CheckpointType.distributed, - output_type=CheckpointType.huggingface, + input_type=CheckpointFormat.distributed, + output_type=CheckpointFormat.external, input_path=checkpoint_path, output_path=checkpoint_path_hf, model_type=self.model_type,