From cf63b8dfc06fb8bb9e6f8919aa9a974db7d64cda Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 17 Oct 2024 08:52:04 -0400 Subject: [PATCH 1/8] Better checkpoints --- examples/example_config.yaml | 12 +- fast_llm/config.py | 31 ++++ fast_llm/engine/config_utils/checkpoint.py | 117 ++++++++++++++ fast_llm/engine/huggingface/config.py | 11 +- fast_llm/engine/huggingface/model.py | 9 +- fast_llm/engine/multi_stage/config.py | 135 +++------------- fast_llm/engine/multi_stage/fast_llm_model.py | 89 +++++----- fast_llm/engine/training/config.py | 153 ++++++++++++------ fast_llm/engine/training/trainer.py | 16 +- fast_llm/layers/language_model/config.py | 6 +- fast_llm/layers/transformer/config.py | 8 +- fast_llm/models/custom/config.py | 12 +- fast_llm/models/gpt/config.py | 12 +- fast_llm/tools/convert.py | 35 ++-- tests/test_checkpoint.py | 67 ++++---- tools/push_model.py | 4 +- 16 files changed, 408 insertions(+), 309 deletions(-) create mode 100644 fast_llm/engine/config_utils/checkpoint.py diff --git a/examples/example_config.yaml b/examples/example_config.yaml index c23d7c7b1..4c4e6d383 100644 --- a/examples/example_config.yaml +++ b/examples/example_config.yaml @@ -27,8 +27,8 @@ model: distributed_timeout: 60.0 training_dtype: float32 pretrained: - pretrained_checkpoint_path: null - pretrained_checkpoint_type: distributed + path: null + format: distributed batch: micro_batch_size: 1 depth_first_micro_batches: 1 @@ -42,13 +42,13 @@ data: - 969.0 - 30.0 - 1.0 - dataset_source: list - data_path: + format: list + path: - fkgtiu data_sample_warn_time_ms: 1000.0 profiling: - profile_cuda: false - profile_ranks: [] + cuda: false + ranks: [] optimizer: weight_decay: 0.01 initial_loss_scale: 65536.0 diff --git a/fast_llm/config.py b/fast_llm/config.py index 815b6b00a..869bf5391 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -153,6 +153,14 @@ def __init__( self.valid = valid +class FieldUpdate(dict): + """ + Specify some entries in the field that should be updated from the base class. + Useful for changing the default or description in a derived class. + Processed in `__init_subclass__`. + """ + + def check_field(fn, *args, **kwargs): """ Helper function to define a condition that a config field should satisfy, @@ -811,3 +819,26 @@ def __init_subclass__(cls, **kwargs): cls.__class_validated__ ), f"Parent class of config class {cls.__name__} has not been validated. Make sure to use the @config_class decorator." cls.__class_validated__ = False + for key, value in cls.__dict__: + if isinstance(value, FieldUpdate): + base_class_field = cls.get_field(key) + cls.__dict__[key] = Field( + desc=kwargs.pop("desc", base_class_field.desc), + doc=kwargs.pop("doc", base_class_field.doc), + hint=kwargs.pop("hint", base_class_field.hint), + valid=kwargs.pop("valid", base_class_field.valid), + default=kwargs.pop("default", base_class_field.default), + default_factory=kwargs.pop("default_factory", base_class_field.default_factory), + repr=kwargs.pop("repr", base_class_field.repr), + hash=kwargs.pop("hash", base_class_field.hash), + compare=kwargs.pop("compare", base_class_field.compare), + metadata=kwargs.pop("metadata", base_class_field.metadata), + kw_only=kwargs.pop("kw_only", base_class_field.kw_only), + ) + if key in cls.__annotations__: + # TODO: Generalize to other type hints. + if isinstance(cls.__annotations__[key], type) and isinstance(base_class_field.type, type): + Assert.custom(issubclass, cls.__annotations__[key], base_class_field.type) + else: + # dataclasses expects an annotation, so we use the one from the base class. + cls.__annotations__[key] = base_class_field.type diff --git a/fast_llm/engine/config_utils/checkpoint.py b/fast_llm/engine/config_utils/checkpoint.py new file mode 100644 index 000000000..1a279f3c9 --- /dev/null +++ b/fast_llm/engine/config_utils/checkpoint.py @@ -0,0 +1,117 @@ +# TODO: Use packaging.version? (Safer but extra requirement) +import enum +import logging +import pathlib + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + +# TODO: Use packaging.version? (Safer but extra requirement) +CHECKPOINT_VERSION = "0.1" +KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1") + + +class CheckpointType(str, enum.Enum): + # Distributed checkpoint for fast checkpointing and resuming. + distributed = "distributed" + # Model state dict, for safe long-term storage in Fast-LLM format. + state_dict = "state_dict" + # A checkpoint format external to Fast-LLM. + external = "external" + + +@config_class() +class CheckpointPathConfigBase(Config): + _abstract = True + path: pathlib.Path | None = Field( + default=None, + desc="Location of the checkpoint.", + hint=FieldHint.core, + ) + + +@config_class() +class CheckpointConfigBase(Config): + _abstract = True + format: CheckpointType = Field( + default=CheckpointType.distributed, + desc="Format of the checkpoint.", + hint=FieldHint.core, + ) + model_type: str | None = Field( + default=None, + desc="Model type for external models (ex. Huggingace model name).", + hint=FieldHint.feature, + ) + architecture_config: bool = Field( + default=False, + desc="Save/load the model architecture configuration.", + hint=FieldHint.feature, + ) + base_model_config: bool = Field( + default=False, + desc="Save/load the full base model configuration, including the non-architecture fields.", + hint=FieldHint.feature, + ) + fast_llm_config: bool = Field( + default=False, + desc="Save/load the full fast-llm model configuration, including the distributed and multi-stage configurations.", + hint=FieldHint.feature, + ) + + @property + def compare_log_fn(self): + return logger.warning if self.architecture_config else ValueError + + def _validate(self): + if self.fast_llm_config: + self.base_model_config = True + if self.base_model_config: + self.architecture_config = True + super()._validate() + + +@config_class() +class CheckpointStateConfigBase(Config): + _abstract = True + model_weights: bool = Field(default=True, desc="Save/load the model weights.", hint=FieldHint.feature) + optimizer_state: bool = Field(default=False, desc="Save/load the optimizer state.", hint=FieldHint.feature) + + +@config_class() +class CheckpointSaveConfigBase(Config): + _abstract = True + parameters_per_file: int = Field( + default=2**32, + desc="Limit the number of parameters saved in each file.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 2**20), + ) + data_type: DataType | None = Field( + default=None, + desc="Data type to save the checkpoint.", + hint=FieldHint.feature, + ) + + +@config_class() +class CheckpointMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase): + _abstract = False + + +@config_class() +class CheckpointSaveConfig(CheckpointMetadataConfig, CheckpointStateConfigBase, CheckpointSaveConfigBase): + _abstract = False + + +@config_class() +class CheckpointLoadConfig(CheckpointMetadataConfig, CheckpointStateConfigBase): + _abstract = False + + +# @config_class() +# class TrainingExportConfig(CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): +# _abstract=False diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 1adff8bdd..d47506204 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -5,7 +5,8 @@ import transformers -from fast_llm.engine.multi_stage.config import CheckpointType, FastLLMModelConfig, PretrainedConfig +from fast_llm.engine.config_utils.checkpoint import CheckpointMetadataConfig, CheckpointType +from fast_llm.engine.multi_stage.config import FastLLMModelConfig logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = transformers.configuration_utils.CONFIG_NAME = _backup @classmethod - def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | PretrainedConfig, **kwargs): + def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointMetadataConfig, **kwargs): # TODO: Support download from hub/url # Unused arguments, remove to avoid warnings. @@ -55,13 +56,13 @@ def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | Pre # Get the pretrained config. if "pretrained" in kwargs: - assert isinstance(kwargs["pretrained"], PretrainedConfig) + assert isinstance(kwargs["pretrained"], CheckpointMetadataConfig) assert kwargs["pretrained"].path == pretrained_model_name_or_path pretrained = kwargs.pop("pretrained") - elif isinstance(pretrained_model_name_or_path, PretrainedConfig): + elif isinstance(pretrained_model_name_or_path, CheckpointMetadataConfig): pretrained = pretrained_model_name_or_path else: - pretrained = PretrainedConfig( + pretrained = CheckpointMetadataConfig( path=pathlib.Path(pretrained_model_name_or_path), format=CheckpointType.state_dict, ) diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/huggingface/model.py index fa46d0e40..e177ca796 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/huggingface/model.py @@ -5,9 +5,10 @@ import transformers.modeling_outputs from fast_llm.config import NoAutoValidate +from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.huggingface.config import HuggingfaceModelConfig -from fast_llm.engine.multi_stage.config import CheckpointType, PretrainedCheckpointConfig, PretrainedConfig, StageMode +from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig from fast_llm.engine.schedule.runner import ScheduleRunner @@ -58,14 +59,14 @@ def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, @classmethod def from_pretrained( cls, - pretrained_model_name_or_path: str | os.PathLike | PretrainedCheckpointConfig, + pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadConfig, *, mode: StageMode = StageMode.inference, **kwargs, ): # Pretrained config. - if not isinstance(pretrained_model_name_or_path, PretrainedConfig): - pretrained_model_name_or_path = PretrainedCheckpointConfig( + if not isinstance(pretrained_model_name_or_path, CheckpointLoadConfig): + pretrained_model_name_or_path = CheckpointLoadConfig( path=pathlib.Path(pretrained_model_name_or_path), format=CheckpointType.state_dict, ) diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index e1a756c08..53abfbcaa 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -1,12 +1,17 @@ import enum import json import logging -import pathlib import typing from fast_llm.config import Config, Field, FieldHint, NoAutoValidate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.checkpoint import ( + CHECKPOINT_VERSION, + KNOWN_CHECKPOINT_VERSIONS, + CheckpointLoadConfig, + CheckpointMetadataConfig, + CheckpointType, +) from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert @@ -174,98 +179,6 @@ def _validate(self): # TODO: Does this matter? Best value? SHARD_PAD_TO_MULTIPLE = 32 -# TODO: Use packaging.version? (Safer but extra requirement) -CHECKPOINT_VERSION = "0.1" -_KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1") - - -class CheckpointType(str, enum.Enum): - distributed = "distributed" - # TODO: Rename - huggingface = "huggingface" - # Model state dict, mostly for debug. - state_dict = "state_dict" - - -@config_class() -class PretrainedConfig(Config): - path: pathlib.Path | None = Field( - default=None, - desc="Path to the checkpoint.", - hint=FieldHint.core, - ) - format: CheckpointType = Field( - default=CheckpointType.distributed, - desc="Format of the checkpoint.", - hint=FieldHint.core, - ) - imported_type: str | None = Field( - default=None, - desc="Model type for external models (ex. Huggingace type).", - hint=FieldHint.feature, - ) - override_architecture: bool = Field( - default=False, - desc="Ignore the base model architecture from the pretrained checkpoint and use the provided one instead." - " May have unintended consequences.", - hint=FieldHint.feature, - ) - load_full_base_model_config: bool = Field( - default=False, - desc="Load the non-architecture model config from the checkpoint.", - hint=FieldHint.feature, - ) - load_full_fast_llm_config: bool = Field( - default=False, - desc="Load the distributed and multi-stage config from the checkpoint.", - hint=FieldHint.feature, - ) - - def _validate(self): - if self.load_full_fast_llm_config: - self.load_full_base_model_config = True - super()._validate() - - @property - def compare_log_fn(self): - return logger.warning if self.override_architecture else ValueError - - -@config_class() -class PretrainedCheckpointConfig(PretrainedConfig): - # Load weights from path (if applicable), - # otherwise reinitialize them (i.e. load the config only.) - load_weights: bool = Field(default=True, desc="Load model weights from the checkpoint.", hint=FieldHint.feature) - load_optimizer: bool = Field( - default=False, desc="Load the optimizer state from the checkpoint.", hint=FieldHint.feature - ) - - -@config_class() -class CheckpointConfig(Config): - # TODO: Merge/match with PretrainedConfig? - checkpoint_path: pathlib.Path = Field(desc="Path to the checkpoint.", hint=FieldHint.core) - checkpoint_type: CheckpointType = Field( - default=CheckpointType.distributed, desc="Format of the checkpoint.", hint=FieldHint.core - ) - exported_model_type: str | None = Field( - default=None, desc="Model type for external models (ex. Huggingace type).", hint=FieldHint.feature - ) - save_optimizer: bool = Field( - default=True, desc="Save the optimizer state from the checkpoint.", hint=FieldHint.feature - ) - target_params_per_file: int = Field( - default=2**32, - desc="Limit the number of parameters saved in each file.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 2**20), - ) - dtype: DataType | None = Field( - default=None, - desc="Data type to save the checkpoint.", - hint=FieldHint.feature, - ) - @config_class() class FastLLMModelConfig(Config): @@ -298,7 +211,7 @@ def get_base_model_config_cls(cls) -> type[BaseModelConfig]: @classmethod def from_pretrained( cls, - pretrained: PretrainedConfig, + pretrained: CheckpointMetadataConfig, default: "FastLLMModelConfig" = None, ): # TODO: Add *updates? @@ -309,7 +222,7 @@ def from_pretrained( @classmethod def from_metadata( cls, - pretrained: PretrainedConfig, + pretrained: CheckpointMetadataConfig, metadata: dict, default: "FastLLMModelConfig" = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -320,22 +233,22 @@ def from_metadata( # TODO python 3.12: Assert.incl(metadata["checkpoint_type"], CheckpointType) CheckpointType(metadata["checkpoint_type"]) version = metadata["checkpoint_version"] - if version not in _KNOWN_CHECKPOINT_VERSIONS: + if version not in KNOWN_CHECKPOINT_VERSIONS: raise ValueError(f"Unrecognised checkpoint version: {version}") if version == "0": return cls._from_metadata_v0(pretrained, metadata, default, updates) pretrained_config = cls.from_dict(metadata["fast_llm_config"]) - if pretrained.override_architecture: + if pretrained.architecture_config: assert default is not None config = default.to_copy() config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn) - elif pretrained.load_full_fast_llm_config: + elif pretrained.fast_llm_config: config = pretrained_config else: with NoAutoValidate(): config = cls() if default is None else default.to_copy() - if pretrained.load_full_base_model_config: + if pretrained.base_model_config: config.base_model = pretrained_config.base_model else: config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture()) @@ -348,7 +261,7 @@ def from_metadata( @classmethod def _from_metadata_v0( cls, - pretrained: PretrainedConfig, + pretrained: CheckpointMetadataConfig, metadata: dict, default: "FastLLMModelConfig" = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -361,22 +274,22 @@ def _from_metadata_v0( with NoAutoValidate(): if default is None: - assert not pretrained.override_architecture + assert not pretrained.architecture_config config = cls(base_model=base_model_config_cls()) else: config = default.to_copy() - if pretrained.override_architecture: + if pretrained.architecture_config: config.validate() architecture_config.compare_architecture(default.base_model, pretrained.compare_log_fn) else: - if pretrained.load_full_base_model_config: + if pretrained.base_model_config: # Replace the whole config config.base_model = base_model_config_cls.from_flat_dict(metadata["model_config"]) else: # Replace the architecture parts of the config. config.base_model = config.base_model.to_copy(architecture_config) - if pretrained.load_full_fast_llm_config: + if pretrained.fast_llm_config: config.multi_stage = MultiStageConfig.from_flat_dict(metadata["multi_stage_config"]) config.distributed = DistributedConfig.from_flat_dict( metadata["distributed_config"], @@ -388,7 +301,7 @@ def _from_metadata_v0( return config @classmethod - def load_pretrained_metadata(cls, pretrained): + def load_pretrained_metadata(cls, pretrained: CheckpointMetadataConfig): import yaml base_model_config_cls = cls.get_base_model_config_cls() @@ -396,12 +309,12 @@ def load_pretrained_metadata(cls, pretrained): return yaml.safe_load((pretrained.path / "metadata.yaml").open("r")) elif pretrained.format == CheckpointType.state_dict: return json.load((pretrained.path / f"state_dict.safetensors.index.json").open("r"))["metadata"] - elif pretrained.format == CheckpointType.huggingface: - converter_class = base_model_config_cls.get_converter_class(pretrained.imported_type) + elif pretrained.format == CheckpointType.external: + converter_class = base_model_config_cls.get_converter_class(pretrained.model_type) imported_model_config = converter_class.import_config(converter_class.load_config(pretrained.path), True) return { "fast_llm_config": {"base_model": imported_model_config.to_serialized()}, - "checkpoint_type": CheckpointType.huggingface.value, + "checkpoint_type": CheckpointType.external.value, "checkpoint_version": CHECKPOINT_VERSION, } else: @@ -415,8 +328,8 @@ class PretrainedFastLLMModelConfig(Config): model: FastLLMModelConfig = Field( default_factory=FastLLMModelConfig, desc="Configuration for the Fast-LLM model.", hint=FieldHint.core ) - pretrained: PretrainedCheckpointConfig = Field( - default_factory=PretrainedCheckpointConfig, + pretrained: CheckpointLoadConfig = Field( + default_factory=CheckpointLoadConfig, desc="Configuration for loading the configuration and state of a pretrained model.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 1e4deaa8f..806c5d950 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -10,15 +10,14 @@ from fast_llm.core.distributed import all_reduce, broadcast, safe_barrier from fast_llm.engine.base_model.base_model import BaseModel -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import ( +from fast_llm.engine.config_utils.checkpoint import ( CHECKPOINT_VERSION, - CheckpointConfig, + CheckpointLoadConfig, + CheckpointSaveConfig, CheckpointType, - FastLLMModelConfig, - PretrainedCheckpointConfig, - StageMode, ) +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import MultiStageModel from fast_llm.engine.multi_stage.stage import Stage from fast_llm.functional.triton.pointwise import triton_fill @@ -83,28 +82,28 @@ def distributed(self): def save_checkpoint( self, - checkpoint_config: CheckpointConfig, + checkpoint_config: CheckpointSaveConfig, metadata: dict | None = None, ): if metadata is None: metadata = {} - if checkpoint_config.checkpoint_type == CheckpointType.distributed: + if checkpoint_config.format == CheckpointType.distributed: self._save_distributed_checkpoint(checkpoint_config, metadata) - elif checkpoint_config.checkpoint_type == CheckpointType.state_dict: + elif checkpoint_config.format == CheckpointType.state_dict: self._save_state_dict_checkpoint(checkpoint_config, metadata) - elif checkpoint_config.checkpoint_type == CheckpointType.huggingface: - assert checkpoint_config.exported_model_type is not None + elif checkpoint_config.format == CheckpointType.external: + assert checkpoint_config.model_type is not None self._export_checkpoint(checkpoint_config, metadata) else: - raise NotImplementedError(checkpoint_config.checkpoint_type) + raise NotImplementedError(checkpoint_config.format) - def load_pretrained_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): + def load_pretrained_checkpoint(self, pretrained_config: CheckpointLoadConfig): if pretrained_config.format == CheckpointType.distributed: # TODO: Check if same format. self._load_distributed_checkpoint(pretrained_config) elif pretrained_config.format == CheckpointType.state_dict: self._load_state_dict_checkpoint(pretrained_config) - elif pretrained_config.format == CheckpointType.huggingface: + elif pretrained_config.format == CheckpointType.external: self._import_checkpoint(pretrained_config) else: raise NotImplementedError(pretrained_config.format) @@ -113,7 +112,7 @@ def load_distributed_checkpoint_same_format(self, directory: pathlib.Path): # TODO: Handle barriers, ok file, etc. here # TODO: More safety checks # TODO: Integrate to load_checkpoint. - pretrained_config = PretrainedCheckpointConfig(path=directory, format=CheckpointType.distributed) + pretrained_config = CheckpointLoadConfig(path=directory, format=CheckpointType.distributed) metadata = self.config_class.load_pretrained_metadata(pretrained_config) with self._LoadContext(self, safe=False, load_optimizer=True, reset_pads=False) as context: Assert.eq( @@ -132,7 +131,7 @@ def load_distributed_checkpoint_same_format(self, directory: pathlib.Path): @classmethod def from_pretrained( cls, - pretrained_config: PretrainedCheckpointConfig, + pretrained_config: CheckpointLoadConfig, default_config: FastLLMModelConfig = None, *, config_updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -166,7 +165,7 @@ def from_pretrained( model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode) if mode.on_device: - if pretrained_config.load_weights: + if pretrained_config.model_weights: model.load_pretrained_checkpoint(pretrained_config) else: model.initialize_weights() @@ -306,15 +305,15 @@ def _save_next_file(self): self.param_count = 0 self.tensors = {} - def _save_distributed_checkpoint(self, checkpoint_config: CheckpointConfig, metadata: dict): + def _save_distributed_checkpoint(self, checkpoint_config: CheckpointSaveConfig, metadata: dict): # TODO: Non blocking? # TODO: CPU memory? # TODO: Handle barriers, ok file, mkdir, etc. here with self._SaveContext( self, metadata, - directory=checkpoint_config.checkpoint_path, - save_optimizer=checkpoint_config.save_optimizer, + directory=checkpoint_config.path, + save_optimizer=checkpoint_config.optimizer_state, ) as context: if self._distributed_config.rank == 0: yaml.safe_dump(context.metadata, (context.directory / "metadata.yaml").open("w")) @@ -324,40 +323,40 @@ def _save_distributed_checkpoint(self, checkpoint_config: CheckpointConfig, meta metadata=_export_safetensors_metadata(context.metadata), ) - def _save_state_dict_checkpoint(self, checkpoint_config: CheckpointConfig, metadata: dict): + def _save_state_dict_checkpoint(self, checkpoint_config: CheckpointSaveConfig, metadata: dict): # TODO: Make into a special case of _export_checkpoint? # TODO: Handle barriers, ok file, mkdir, etc. here with self._SaveStateDictContext( self, metadata, - directory=checkpoint_config.checkpoint_path, - save_optimizer=checkpoint_config.save_optimizer, + directory=checkpoint_config.path, + save_optimizer=checkpoint_config.optimizer_state, base_file_name="state_dict", - target_params_per_file=checkpoint_config.target_params_per_file, + target_params_per_file=checkpoint_config.parameters_per_file, ) as context: for i, shard_name in enumerate( - self._state_shard_names if checkpoint_config.save_optimizer else self._state_shard_names[:1] + self._state_shard_names if checkpoint_config.optimizer_state else self._state_shard_names[:1] ): shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.dtype): # noqa + for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.data_type): # noqa context.add_tensor(f"{name}/{shard_name}", tensor) - def _export_checkpoint(self, checkpoint_config: CheckpointConfig, metadata: dict): + def _export_checkpoint(self, checkpoint_config: CheckpointSaveConfig, metadata: dict): # TODO: Handle barriers, ok file, mkdir, etc. here # TODO: Support optimizer? - assert not checkpoint_config.save_optimizer + assert not checkpoint_config.optimizer_state with self._SaveStateDictContext( self, metadata, - directory=checkpoint_config.checkpoint_path, + directory=checkpoint_config.path, save_optimizer=False, base_file_name="model", - target_params_per_file=checkpoint_config.target_params_per_file, + target_params_per_file=checkpoint_config.parameters_per_file, ) as context: - converter_class = self._base_model_config.get_converter_class(checkpoint_config.exported_model_type) + converter_class = self._base_model_config.get_converter_class(checkpoint_config.model_type) exported_config = converter_class.export_config(self._base_model_config) - converter_class.save_config(checkpoint_config.checkpoint_path, exported_config) + converter_class.save_config(checkpoint_config.path, exported_config) context.metadata = { "fast_llm_metadata": context.metadata, "model_config": exported_config, @@ -371,11 +370,11 @@ def _export_checkpoint(self, checkpoint_config: CheckpointConfig, metadata: dict # it will remain in `fast_llm_state_dict` until that tensor is available. fast_llm_state_dict = {} for i, shard_name in enumerate( - self._state_shard_names if checkpoint_config.save_optimizer else self._state_shard_names[:1] + self._state_shard_names if checkpoint_config.optimizer_state else self._state_shard_names[:1] ): shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.dtype): # noqa + for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.data_type): # noqa assert name not in fast_llm_state_dict fast_llm_state_dict[name] = tensor for exported_name, exported_tensor in converter.convert_state_dict( @@ -568,13 +567,13 @@ def import_state_tensor(self, shard_name: str, parameter_name: str, tensor: torc self.loaded_parameters[shard_name] = {} self.loaded_parameters[shard_name][parameter_name] = loaded - def _load_distributed_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): + def _load_distributed_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: More safety checks metadata = self.config_class.load_pretrained_metadata(pretrained_config) loaded_pretrained_config = pretrained_config.to_copy( { - "load_full_base_model_config": True, - "load_full_fast_llm_config": True, + "base_model_config": True, + "fast_llm_config": True, }, ) loaded_config = self.config_class.from_metadata( @@ -582,7 +581,7 @@ def _load_distributed_checkpoint(self, pretrained_config: PretrainedCheckpointCo metadata, ) with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True ) as context: Assert.eq(metadata["state_shard_names"][: context.num_shards], list(context.shard_names)) @@ -601,12 +600,12 @@ def _load_distributed_checkpoint(self, pretrained_config: PretrainedCheckpointCo return metadata["metadata"] - def _load_state_dict_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): + def _load_state_dict_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: Make into a special case of _import_state_tensor? # TODO: Verify more distributed configs. # TODO: More safety checks with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True ) as context: index_path = pretrained_config.path / f"state_dict.safetensors.index.json" logger.info(f"Loading index from {index_path}") @@ -627,21 +626,21 @@ def _load_state_dict_checkpoint(self, pretrained_config: PretrainedCheckpointCon return metadata["metadata"] - def _import_checkpoint(self, pretrained_config: PretrainedCheckpointConfig): + def _import_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: Support optimizer? - assert not pretrained_config.load_optimizer + assert not pretrained_config.optimizer_state # TODO: Verify more distributed configs. # TODO: Safety checks - converter_class = self.base_model.architecture_cls().get_converter_class(pretrained_config.imported_type) + converter_class = self.base_model.architecture_cls().get_converter_class(pretrained_config.model_type) converter = converter_class.from_config(converter_class.load_config(pretrained_config.path)) self._base_model_config.compare_architecture(converter.config, pretrained_config.compare_log_fn) state_dict = {} with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.load_optimizer, reset_pads=True + self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True ) as context: - for name, tensor in converter.load_weights(pretrained_config.path, self._distributed.device): + for name, tensor in converter.model_weights(pretrained_config.path, self._distributed.device): assert name not in state_dict state_dict[name] = tensor for parameter_name, fast_llm_tensor in converter.convert_state_dict(state_dict, False).items(): diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 6f2fce854..c4a7b94a8 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -1,11 +1,19 @@ import argparse import os +import pathlib import shlex import subprocess import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.data.config import AbstractDataConfig +from fast_llm.engine.config_utils.checkpoint import ( + CheckpointConfigBase, + CheckpointSaveConfig, + CheckpointSaveConfigBase, + CheckpointStateConfigBase, + CheckpointType, +) from fast_llm.engine.config_utils.run import ExperimentConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig @@ -17,48 +25,56 @@ from fast_llm.engine.training.trainer import Trainer -def get_interval_config_class(desc: str, offset_desc: str | None = None): - # Intervals are a common pattern, so we standardize them with this helper. - @config_class() - class IntervalConfig(Config): - interval: int | None = Field( - default=None, - desc=f"The number of training iterations between each {desc}. Setting to None will disable.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.gt, 0)), - ) - offset: int = Field( - default=0, - desc=f"Offset for the first {offset_desc or desc}.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) +@config_class() +class IntervalConfig(Config): + # Intervals are a common pattern, so we standardize them with this base class. + interval: int | None = Field( + default=None, + desc="The number of training iterations between each interval. Setting to None will disable.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + offset: int = Field( + default=0, + desc="Offset for the first interval.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) - def enabled(self, iteration: int | None = None): - return self.interval and (iteration is None or (iteration - self.offset) % self.interval == 0) + def _validate(self): + self.offset %= self.interval + super()._validate() + + def enabled(self, iteration: int | None = None): + return self.interval and (iteration is None or (iteration - self.offset) % self.interval == 0) - def is_sub_interval(self, other: "IntervalConfig"): - if not self.enabled(): - return True - elif not other.enabled(): - return False - return self.interval % other.interval == 0 and (other.offset % other.interval) == ( - self.offset % other.interval - ) + def is_sub_interval(self, other: "IntervalConfig"): + if not self.enabled(): + return True + elif not other.enabled(): + return False + return self.interval % other.interval == 0 and (other.offset % other.interval) == ( + self.offset % other.interval + ) - def assert_sub_interval(self, other: "IntervalConfig"): - assert self.is_sub_interval(other), f"{self} is not a sub-interval of {other}" + def assert_sub_interval(self, other: "IntervalConfig"): + assert self.is_sub_interval(other), f"{self} is not a sub-interval of {other}" - return IntervalConfig + def get_count(self, iteration): + # Number of times this interval was enabled after a given iteration. + return (iteration - self.offset) // self.interval + 1 if self.enabled() else 0 @config_class() -class WandbAlertConfig( - get_interval_config_class( - "Wandb status post (alert). Must be a multiple of the logging interval", - "Wandb status post (alert). Must be compatible with the logging offset", +class WandbAlertConfig(IntervalConfig): + interval = FieldUpdate( + desc="The number of training iterations between each Wandb status post (alert)." + " Setting to None will disable iteration-based wandb alerts." + " Must be a sub-interval of the logging interval." + ) + offset = FieldUpdate( + desc="Offset for the first Wandb status post (alert)." " Must be compatible with the logging offset.", ) -): status_updates: bool | None = Field( default=None, desc="Post wandb status updates on status changes (run begin/end). " @@ -73,15 +89,21 @@ def _validate(self): @config_class() -class MetricsLogsConfig(get_interval_config_class("metric logs")): - pass +class MetricsLogsConfig(IntervalConfig): + interval = FieldUpdate( + default=100, + desc="The number of training iterations between each metric logs." + " Setting to None will disable metric logging.", + ) + offset = FieldUpdate(desc="Offset for the first metric logs.") @config_class() class WandbConfig(Config): alert: WandbAlertConfig = Field( default_factory=WandbAlertConfig, - desc="Configuration for Wandb alerts. The alerts may be posted by email and/or slack depending on the Wandb account configuration.", + desc="Configuration for Wandb alerts." + " The alerts may be posted by email and/or slack depending on the Wandb account configuration.", hint=FieldHint.core, ) group_name: str = Field(default="default", desc="A group name for Wandb", hint=FieldHint.feature) @@ -90,7 +112,12 @@ class WandbConfig(Config): @config_class() -class ValidationConfig(get_interval_config_class("validation")): +class ValidationConfig(IntervalConfig): + interval = FieldUpdate( + desc="The number of training iterations between each validation phase." + " Setting to None will disable validation." + ) + offset = FieldUpdate(desc="Offset for the first validation phase.") iterations: int | None = Field( default=None, desc="Number of iterations for each validation phase. Setting to None will disable.", @@ -98,21 +125,38 @@ class ValidationConfig(get_interval_config_class("validation")): valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) - def get_completed_iterations(self, training_iterations: int, completed_validations: int = 0): + def get_iteration_count(self, training_iterations: int, extra_validations: int = 0): # Number of completed validation iterations - return ( - (training_iterations // self.interval + completed_validations) * self.iterations if self.enabled() else 0 - ) + return (self.get_count(training_iterations) + extra_validations) * self.iterations if self.enabled() else 0 @config_class() -class CheckpointConfig(get_interval_config_class("checkpoint")): +class CheckpointConfig(IntervalConfig): + interval = FieldUpdate( + desc="The number of training iterations between each checkpoint." " Setting to None will disable checkpoints." + ) + offset = FieldUpdate(desc="Offset for the first checkpoint.") keep: int | None = Field( default=5, desc="The maximum number of checkpoints to keep. When exceeding this value, checkpoints are deleted starting from the older ones.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + keep_every: int | None = Field( + default=None, + desc="Keep every nth checkpoint, i.e. Exclude it from the checkpoint count and deletion in `keep`.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.gt, 0)), + ) + + def get_save_config(self, path: pathlib.Path): + return CheckpointSaveConfig( + path=path, + format=CheckpointType.distributed, + fast_llm_config=True, + model_weights=True, + optimizer_state=True, + ) def _validate_script(value): @@ -144,17 +188,31 @@ def run(self): @config_class() -class ExportConfig(get_interval_config_class("export")): +class ExportConfig(IntervalConfig, CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): + interval = FieldUpdate( + desc="The number of training iterations between each export." " Setting to None will disable exports." + ) + offset = FieldUpdate(desc="Offset for the first export.") callback: CallbackConfig = Field( default_factory=CallbackConfig, desc="Callback (shell script) to run after export.", hint=FieldHint.core, ) + def get_save_config(self, path: pathlib.Path): + return CheckpointSaveConfig.from_dict(self, {"path": path}, strict=False) + @config_class() -class ShutdownConfig(get_interval_config_class("automated shutdown")): - pass +class ShutdownConfig(IntervalConfig): + interval = FieldUpdate( + desc="The number of training iterations between each automated shutdown." + " Setting to None will disable automated shutdowns." + " Must be a sub-interval of the checkpoint interval." + ) + offset = FieldUpdate( + desc="Offset for the first automated shutdown." " Must be compatible with the checkpoint offset." + ) @config_class() @@ -201,7 +259,6 @@ class TrainingConfig(Config): def _validate(self): super()._validate() - self.export.assert_sub_interval(self.checkpoint) self.shutdown.assert_sub_interval(self.checkpoint) self.wandb.alert.assert_sub_interval(self.logs) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a17eca121..cdc9b7e7f 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -9,10 +9,10 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.config import AbstractData from fast_llm.data.data import Data +from fast_llm.engine.config_utils.checkpoint import CheckpointSaveConfig, CheckpointType from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import CheckpointConfig, CheckpointType from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer @@ -56,8 +56,10 @@ def __init__(self, config: TrainerConfig): ) steps_per_split = { PhaseType.training: self._config.training.train_iters, - PhaseType.validation: self._config.training.validation.get_completed_iterations( - self._config.training.train_iters, 1 + PhaseType.validation: self._config.training.validation.get_iteration_count( + self._config.training.train_iters, + # There may be an extra validation after the last training step. + not self._config.training.validation.enabled(self._config.training.train_iters), ), PhaseType.test: self._config.training.test_iters, } @@ -130,7 +132,7 @@ def _consumed_tokens(self): @property def _completed_validation_steps(self) -> int: # Number of validation steps performed before the current step - return self._config.training.validation.get_completed_iterations(self._completed_steps - 1) + return self._config.training.validation.get_iteration_count(self._completed_steps - 1) def run(self): assert self._is_setup @@ -365,10 +367,10 @@ def _evaluate( def _prepare_training_state(self): # Setup the training state. if (last_iteration := self._run.get_last_checkpoint()) is None: - if (path := self._config.pretrained.path) is not None and self._config.pretrained.load_weights: + if (path := self._config.pretrained.path) is not None and self._config.pretrained.model_weights: log_main_rank( f"Initializing training state from pretrained checkpoint at {path}" - f" ({'loading' if self._config.pretrained.load_optimizer else 'resetting'}" + f" ({'loading' if self._config.pretrained.optimizer_state else 'resetting'}" f" optimizer state)..." ) self._multi_stage.load_pretrained_checkpoint(self._config.pretrained) @@ -412,7 +414,7 @@ def _save_checkpoint(self, metrics: dict[PhaseType, dict[str, float | int]] | No if metrics is not None: metadata["metrics"] = {key.value: value for key, value in metrics.items()} self._multi_stage.save_checkpoint( - CheckpointConfig(checkpoint_type=CheckpointType.distributed, checkpoint_path=checkpoint.directory), + CheckpointSaveConfig(path=checkpoint.directory, format=CheckpointType.distributed), metadata, ) if export and self._run.is_main_rank: # noqa diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 3976c69bc..d8afb2fca 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,4 +1,4 @@ -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames @@ -109,9 +109,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): architecture_cls = LanguageModelArchitectureConfig - transformer: TransformerConfig = Field( - default_factory=TransformerConfig, desc="Configuration for the transformer.", hint=FieldHint.core - ) + transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig) init_method_std_embed: float = Field( default=None, desc="Initialization scale for the vocabulary embedding and output weights (logits).", diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 8109f2052..e6c133f14 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -3,7 +3,7 @@ import math import warnings -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace @@ -262,11 +262,7 @@ def use_rotary_position_embeddings(self): @config_class() class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig): - normalization: NormalizationConfig = Field( - default_factory=NormalizationConfig, - desc="Configuration for the normalization layers.", - hint=FieldHint.core, - ) + normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig) # Default: hidden_size**-0.5 # TODO: Allow custom initialization (InitializationConfig?) init_method_std: float = Field( diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 1f6bb0168..82e004319 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -1,4 +1,4 @@ -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import FieldUpdate, config_class from fast_llm.data.config import DataConfig from fast_llm.models.gpt.config import ( GPTArchitectureConfig, @@ -30,7 +30,7 @@ class CustomBaseModelConfig(GPTBaseModelConfig, CustomArchitectureConfig): @config_class() class CustomModelConfig(GPTModelConfig): # TODO: Add custom model config parameters, if any (typically none). - base_model: CustomBaseModelConfig = Field(default_factory=CustomBaseModelConfig) + base_model: CustomBaseModelConfig = FieldUpdate(default_factory=CustomBaseModelConfig) @classmethod def get_model_class(cls): @@ -47,18 +47,14 @@ def get_huggingface_model_class(cls): @config_class() class PretrainedCustomModelConfig(PretrainedGPTModelConfig): - model: CustomModelConfig = Field(default_factory=CustomModelConfig) + model: CustomModelConfig = FieldUpdate(default_factory=CustomModelConfig) @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = Field( - default_factory=CustomDataConfig, - desc="Configuration for the dataset and model-independent preprocessing.", - hint=FieldHint.core, - ) + data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) @classmethod def get_trainer_class(cls): diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 4725f1e86..b61ea8cd3 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,6 +1,6 @@ import typing -from fast_llm.config import Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.config import DataConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -65,7 +65,7 @@ def _from_dict( @config_class() class GPTModelConfig(FastLLMModelConfig): _abstract = False - base_model: GPTBaseModelConfig = Field(default_factory=GPTBaseModelConfig) + base_model: GPTBaseModelConfig = FieldUpdate(default_factory=GPTBaseModelConfig) @classmethod def get_model_class(cls): @@ -83,17 +83,13 @@ def get_huggingface_model_class(cls): @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): _abstract = False - model: GPTModelConfig = Field(default_factory=GPTModelConfig) + model: GPTModelConfig = FieldUpdate(default_factory=GPTModelConfig) @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: DataConfig = Field( - default_factory=DataConfig, - desc="Configuration for the dataset and model-independent preprocessing.", - hint=FieldHint.core, - ) + data: DataConfig = FieldUpdate(default_factory=DataConfig) def _setup(self): super()._setup() diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index d2305ade7..3e932f0ba 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -7,15 +7,10 @@ import typing from fast_llm.config import Field, config_class +from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.multi_stage.config import ( - CheckpointConfig, - CheckpointType, - FastLLMModelConfig, - PretrainedCheckpointConfig, - StageMode, -) +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.functional.config import TritonConfig from fast_llm.models.auto import model_registry from fast_llm.utils import Assert @@ -62,10 +57,10 @@ def _convert_model_partial( ): logger.info(f"Loading {self.input_type} checkpoint from {self.input_path}...") model = model_class.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=self.input_path, format=self.input_type, - imported_type=self.model_type, + model_type=self.model_type, ), mode=StageMode.weights, use_cpu=self.use_cpu, @@ -74,13 +69,13 @@ def _convert_model_partial( logger.info(f"Saving {self.output_type} checkpoint to {output_path}...") output_path.mkdir(parents=True, exist_ok=self.exist_ok) model.save_checkpoint( - CheckpointConfig( - checkpoint_path=output_path, - checkpoint_type=self.output_type, - exported_model_type=self.model_type, - save_optimizer=False, - target_params_per_file=self.target_params_per_file, - dtype=self.dtype, + CheckpointSaveConfig( + path=output_path, + format=self.output_type, + model_type=self.model_type, + optimizer_state=False, + parameters_per_file=self.target_params_per_file, + data_type=self.dtype, ) ) (output_path / "ok").open("w") @@ -106,15 +101,15 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): self._convert_model_partial(model_class, self.output_path) else: # TODO: Support other types? - assert self.output_type == CheckpointType.huggingface + assert self.output_type == CheckpointType.external logger.info(f">>> Loading model config") # Create a dummy version to determine the stage split. model = model_class.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=self.input_path, format=self.input_type, - imported_type=self.model_type, - load_pretrained_weights=False, + model_type=self.model_type, + model_weights=False, ), mode=StageMode.off_device, use_cpu=self.use_cpu, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 50ba105fa..4293ab2e1 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -7,7 +7,8 @@ import transformers import yaml -from fast_llm.engine.multi_stage.config import CheckpointType, PretrainedCheckpointConfig, StageMode +from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType +from fast_llm.engine.multi_stage.config import StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig from tests.common import ( @@ -104,7 +105,7 @@ def test_convert_state_dict_to_huggingface(): ConversionConfig( input_type=CheckpointType.state_dict, input_path=_CONVERT_PATH / "state_dict_0", - output_type=CheckpointType.huggingface, + output_type=CheckpointType.external, output_path=_CONVERT_PATH / "huggingface_0", model_type=HUGGINGFACE_MODEL_TYPE, ) @@ -115,7 +116,7 @@ def test_convert_state_dict_to_huggingface(): def test_convert_huggingface_to_distributed(): _run_conversion( ConversionConfig( - input_type=CheckpointType.huggingface, + input_type=CheckpointType.external, input_path=_CONVERT_PATH / "huggingface_0", output_type=CheckpointType.distributed, output_path=_CONVERT_PATH / "distributed_0", @@ -131,7 +132,7 @@ def test_convert_distributed_to_huggingface(): ConversionConfig( input_type=CheckpointType.distributed, input_path=_CKPT_PATH, - output_type=CheckpointType.huggingface, + output_type=CheckpointType.external, output_path=_CONVERT_PATH / "huggingface_1", model_type=HUGGINGFACE_MODEL_TYPE, ) @@ -142,7 +143,7 @@ def test_convert_distributed_to_huggingface(): def test_convert_huggingface_to_state_dict(): _run_conversion( ConversionConfig( - input_type=CheckpointType.huggingface, + input_type=CheckpointType.external, input_path=_CONVERT_PATH / "huggingface_1", output_type=CheckpointType.state_dict, output_path=_CONVERT_PATH / "state_dict_1", @@ -206,12 +207,12 @@ def test_load_pretrained_distributed_checkpoint(): config = TEST_ARCHITECTURE_CONFIG_CLS.from_dict( yaml.safe_load((_CKPT_PATH / ".." / ".." / "config.yaml").open("r")), strict=False ) - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, - load_optimizer=True, - load_full_base_model_config=True, - load_full_fast_llm_config=True, + optimizer_state=True, + base_model_config=True, + fast_llm_config=True, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config, model._base_model_config) @@ -223,12 +224,12 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig(path=_CKPT_PATH, format=CheckpointType.distributed) - pretrained_config_0 = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointType.distributed) + pretrained_config_0 = CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_0", format=CheckpointType.distributed, ) - pretrained_config_1 = PretrainedCheckpointConfig( + pretrained_config_1 = CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_1", format=CheckpointType.distributed, ) @@ -245,13 +246,9 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_state_dict", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_state_dict_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig(path=_CKPT_PATH, format=CheckpointType.distributed) - pretrained_config_0 = PretrainedCheckpointConfig( - path=_CONVERT_PATH / "state_dict_0", format=CheckpointType.state_dict - ) - pretrained_config_1 = PretrainedCheckpointConfig( - path=_CONVERT_PATH / "state_dict_1", format=CheckpointType.state_dict - ) + pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointType.distributed) + pretrained_config_0 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_0", format=CheckpointType.state_dict) + pretrained_config_1 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_1", format=CheckpointType.state_dict) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) @@ -265,17 +262,17 @@ def test_load_converted_state_dict_checkpoint(): @pytest.mark.depends(on=["test_converted_state_dict", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_huggingface_checkpoint(): - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, ) - pretrained_config_0 = PretrainedCheckpointConfig( + pretrained_config_0 = CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointType.huggingface, + format=CheckpointType.external, ) - pretrained_config_1 = PretrainedCheckpointConfig( + pretrained_config_1 = CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", - format=CheckpointType.huggingface, + format=CheckpointType.external, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) @@ -291,7 +288,7 @@ def test_load_converted_huggingface_checkpoint(): @pytest.mark.depends(on=["test_load_converted_state_dict_checkpoint", "test_load_converted_huggingface_checkpoint"]) def test_run_converted_model(): model_ref = TEST_MODEL_HF_CLS.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, ) @@ -302,9 +299,9 @@ def test_run_converted_model(): output_ref = model_ref(test_input) model_from_state_dict = TEST_MODEL_HF_CLS.from_pretrained(_CONVERT_PATH / "state_dict_0") model_from_hf = TEST_MODEL_HF_CLS.from_pretrained( - PretrainedCheckpointConfig( + CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointType.huggingface, + format=CheckpointType.external, ) ) errors = [] @@ -364,15 +361,15 @@ def test_load_pretrained_distributed_with_config(): @pytest.mark.depends(on=["test_load_pretrained_distributed_in_dp2"]) def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1" - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, - load_full_fast_llm_config=True, + fast_llm_config=True, ) - pretrained_config_test = PretrainedCheckpointConfig( + pretrained_config_test = CheckpointLoadConfig( path=test_ckpt_path, format=CheckpointType.distributed, - load_full_fast_llm_config=True, + fast_llm_config=True, ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) @@ -406,13 +403,13 @@ def test_load_pretrained_in_dp2_match_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_in_dp2_match_checkpoint"]) def test_load_distributed_checkpoint_dp2(): # This also tests conversion which uses `FastLLMModel.from_checkpoint` - pretrained_config_ref = PretrainedCheckpointConfig( + pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, - load_full_base_model_config=True, - load_full_fast_llm_config=True, + base_model_config=True, + fast_llm_config=True, ) - pretrained_config_test = PretrainedCheckpointConfig( + pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1", format=CheckpointType.distributed, ) diff --git a/tools/push_model.py b/tools/push_model.py index 57012c0c2..3045204d2 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -6,6 +6,7 @@ import shutil import subprocess +from fast_llm.engine.config_utils.checkpoint import CheckpointType from fast_llm.engine.config_utils.runnable import RunnableConfig try: @@ -26,7 +27,6 @@ from fast_llm.config import Config, config_class, Field # isort:skip -from fast_llm.engine.multi_stage.config import CheckpointType # isort:skip from fast_llm.tools.convert import ConversionConfig # isort:skip @@ -149,7 +149,7 @@ def run(self) -> None: # Block until the conversion is done ConversionConfig( input_type=CheckpointType.distributed, - output_type=CheckpointType.huggingface, + output_type=CheckpointType.external, input_path=checkpoint_path, output_path=checkpoint_path_hf, model_type=self.model_type, From ed4521209f66781088f09db1eef393dafc9d679b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 17 Oct 2024 11:23:53 -0400 Subject: [PATCH 2/8] Simplify save checkpoint --- fast_llm/engine/config_utils/run.py | 46 ------------ fast_llm/engine/training/config.py | 106 +++++++++++++++++----------- fast_llm/engine/training/trainer.py | 65 +++++++++++------ 3 files changed, 107 insertions(+), 110 deletions(-) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 740b83171..ec9dadb1d 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -1,7 +1,6 @@ import logging import os import pathlib -import shutil import typing import warnings @@ -207,9 +206,6 @@ def save_logged_tensors(self, iteration: int | str): torch.save(tensor_stats, self.open_artifact(f"tensor_logs_{iteration}.pt", mode="wb")) TensorLogs.reset(self._config.tensor_logs) - def get_save_checkpoint_context(self, iteration: int, export: bool = False, keep: int | None = None): - return self._SaveCheckpointContext(self, iteration, export, keep) - def get_load_checkpoint_context(self, iteration: int): return self._LoadCheckpointContext(self, iteration) @@ -236,35 +232,6 @@ def __init__(self, run: "Run", iteration: int): def directory(self): return self._directory - class _SaveCheckpointContext(_CheckpointContext): - def __init__(self, run: "Run", iteration: int, export: bool = False, keep: int | None = None): - super().__init__(run, iteration) - self._export = export - self._keep = keep - if self._export: - self._link_directory = self._directory - self._directory = self._run._export_dir / str(self._iteration) - - def __enter__(self): - assert self._run._is_running - if self._run._is_main_rank: - logger.info(f"Saving checkpoint at iteration {self._iteration}") - self._directory.mkdir(parents=True) - if self._export: - (self._run._checkpoint_dir / str(self._iteration)).symlink_to(self._directory) - # Barrier to ensure the directory is created correctly (and didn't exist before). - self._run.barrier(f"save {self._iteration} enter") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_type: - self._run.barrier(f"save {self._iteration} exit") - if self._run._is_main_rank: - # Prevent corrupted checkpoint. - (self._directory / "ok").open("w") - logger.info(f"Checkpoint saved to {self._directory}") - self._run._delete_old_checkpoints(self._keep) - class _LoadCheckpointContext(_CheckpointContext): def __enter__(self): assert self._run._is_running @@ -275,19 +242,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: self._run.barrier(f"load {self._iteration} exit") - def _delete_old_checkpoints(self, keep: int | None): - assert self._is_running - if keep is None: - return - checkpoints = sorted(int(path.name) for path in self._checkpoint_dir.iterdir()) - for checkpoint in checkpoints[:-keep]: - path = self._checkpoint_dir / str(checkpoint) - logger.info(f"Deleting checkpoint at {path}") - try: - shutil.rmtree(path, ignore_errors=True) - except OSError as e: - logger.warning(f"Could not remove checkpoint directory: {e.args}") - def get_last_checkpoint(self): assert self._is_running if self._checkpoint_dir is None: diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index c4a7b94a8..586c72ade 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -65,6 +65,34 @@ def get_count(self, iteration): return (iteration - self.offset) // self.interval + 1 if self.enabled() else 0 +def _validate_script(value): + if isinstance(value, str): + value = shlex.split(value) + Assert.geq(len(value), 1) + return value + + +@config_class() +class CallbackConfig(Config): + script: list[str] | None = Field( + default=None, + desc="Shell script to run.", + hint=FieldHint.feature, + valid=skip_valid_if_none(_validate_script), + ) + environment: dict[str, str] = Field( + default_factory=dict, + desc="Environment variables to add to the script.", + hint=FieldHint.feature, + ) + + def run(self): + if self.script is not None: + environment = os.environ.copy() + environment.update(self.environment) + subprocess.Popen(self.script, env=environment) + + @config_class() class WandbAlertConfig(IntervalConfig): interval = FieldUpdate( @@ -131,24 +159,51 @@ def get_iteration_count(self, training_iterations: int, extra_validations: int = @config_class() -class CheckpointConfig(IntervalConfig): - interval = FieldUpdate( - desc="The number of training iterations between each checkpoint." " Setting to None will disable checkpoints." +class CheckpointBaseConfig(IntervalConfig): + save_name: typing.ClassVar[str] = "save" + callback: CallbackConfig = Field( + default_factory=CallbackConfig, + desc="Callback (shell script).", + hint=FieldHint.core, ) - offset = FieldUpdate(desc="Offset for the first checkpoint.") keep: int | None = Field( - default=5, - desc="The maximum number of checkpoints to keep. When exceeding this value, checkpoints are deleted starting from the older ones.", + default=None, + desc="The maximum number of saves to keep. When exceeding this value, checkpoints are deleted starting from the older ones.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) keep_every: int | None = Field( default=None, - desc="Keep every nth checkpoint, i.e. Exclude it from the checkpoint count and deletion in `keep`.", + desc="Keep every nth saves, i.e. Exclude it from the checkpoint count and deletion in `keep`.", hint=FieldHint.feature, valid=skip_valid_if_none(check_field(Assert.gt, 0)), ) + def get_save_config(self, path: pathlib.Path): + raise NotImplementedError() + + def to_delete(self, iterations: list[int]): + if not self.keep: + return [] + # Ignore checkpoints that aren't supposed to be there. + iterations = [iteration for iteration in iterations if self.enabled(iteration)] + # Ignore excluded checkpoints. + if self.keep_every: + iterations = [iteration for iteration in iterations if self.get_count(iteration) % self.keep_every != 0] + # Exclude the last `keep`. + return iterations[: -self.keep] + + +@config_class() +class CheckpointConfig(CheckpointBaseConfig): + save_name: typing.ClassVar[str] = "checkpoint" + interval = FieldUpdate( + desc="The number of training iterations between each checkpoint." " Setting to None will disable checkpoints." + ) + offset = FieldUpdate(desc="Offset for the first checkpoint.") + callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after checkpoint.") + keep: int | None = FieldUpdate(default=5) + def get_save_config(self, path: pathlib.Path): return CheckpointSaveConfig( path=path, @@ -159,45 +214,14 @@ def get_save_config(self, path: pathlib.Path): ) -def _validate_script(value): - if isinstance(value, str): - value = shlex.split(value) - Assert.geq(len(value), 1) - return value - - @config_class() -class CallbackConfig(Config): - script: list[str] | None = Field( - default=None, - desc="Shell script to run after.", - hint=FieldHint.feature, - valid=skip_valid_if_none(_validate_script), - ) - environment: dict[str, str] = Field( - default_factory=dict, - desc="Environment variables to add to the script.", - hint=FieldHint.feature, - ) - - def run(self): - if self.script is not None: - environment = os.environ.copy() - environment.update(self.environment) - subprocess.Popen(self.script, env=environment) - - -@config_class() -class ExportConfig(IntervalConfig, CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): +class ExportConfig(CheckpointBaseConfig, CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): + save_name: typing.ClassVar[str] = "export" interval = FieldUpdate( desc="The number of training iterations between each export." " Setting to None will disable exports." ) offset = FieldUpdate(desc="Offset for the first export.") - callback: CallbackConfig = Field( - default_factory=CallbackConfig, - desc="Callback (shell script) to run after export.", - hint=FieldHint.core, - ) + callback: CallbackConfig = FieldUpdate(desc="Callback (shell script) to run after export.") def get_save_config(self, path: pathlib.Path): return CheckpointSaveConfig.from_dict(self, {"path": path}, strict=False) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index cdc9b7e7f..149e8cbc8 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -1,6 +1,7 @@ import abc import logging import math +import shutil import time import typing @@ -9,7 +10,6 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.config import AbstractData from fast_llm.data.data import Data -from fast_llm.engine.config_utils.checkpoint import CheckpointSaveConfig, CheckpointType from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -18,7 +18,7 @@ from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.engine.training.config import TrainerConfig +from fast_llm.engine.training.config import CheckpointBaseConfig, TrainerConfig from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage from fast_llm.utils import Assert @@ -313,10 +313,10 @@ def _train(self): self._wandb.log_metrics(self._completed_steps, metrics) if self._config.training.checkpoint.enabled(None if stop else self._completed_steps): - self._save_checkpoint( - metrics, - export=self._config.training.export.enabled(None if done else self._completed_steps), - ) + self._save_checkpoint(self._config.training.checkpoint, metrics) + + if self._config.training.export.enabled(None if done else self._completed_steps): + self._save_checkpoint(self._config.training.export, metrics) return done, metrics @@ -402,23 +402,42 @@ def _get_data_iterator(self, phase, completed_steps: int = 0, prefetch_factor: i prefetch_factor=prefetch_factor, ) - def _save_checkpoint(self, metrics: dict[PhaseType, dict[str, float | int]] | None, export: bool = False): - assert self._is_setup - with self._run.get_save_checkpoint_context( - self._completed_steps, export, self._config.training.checkpoint.keep - ) as checkpoint: - metadata = { - "optimizer": self._optimizer.save(), - "completed_steps": self._completed_steps, - } - if metrics is not None: - metadata["metrics"] = {key.value: value for key, value in metrics.items()} - self._multi_stage.save_checkpoint( - CheckpointSaveConfig(path=checkpoint.directory, format=CheckpointType.distributed), - metadata, - ) - if export and self._run.is_main_rank: # noqa - self._config.training.export.callback.run() + def _save_checkpoint(self, config: CheckpointBaseConfig, metrics: dict[PhaseType, dict[str, float | int]] | None): + checkpoint_base_directory = self._run.experiment_directory / config.save_name + checkpoint_directory = checkpoint_base_directory / str(self._completed_steps) + + # Create the checkpoint + log_main_rank(f"Saving {config.save_name} at iteration {self._completed_steps}") + checkpoint_directory.mkdir(exist_ok=False, parents=True) + # Barrier to ensure the directory is created correctly (and didn't exist before). + self._run.barrier(f"{config.save_name} {self._completed_steps} enter") + + metadata = { + "optimizer": self._optimizer.save(), + "completed_steps": self._completed_steps, + } + if metrics is not None: + metadata["metrics"] = {key.value: value for key, value in metrics.items()} + self._multi_stage.save_checkpoint(config.get_save_config(checkpoint_directory), metadata) + + # Barrier to ensure everyone is done. + self._run.barrier(f"{config.save_name} {self._completed_steps} exit") + # Mark the checkpoint as complete. + (checkpoint_directory / "ok").open("w") + logger.info(f"Saved {config.save_name} to {checkpoint_directory}") + + to_delete = config.to_delete(sorted(int(path.name) for path in checkpoint_directory.iterdir())) + + for iteration in to_delete: + path = checkpoint_base_directory / str(iteration) + logger.info(f"Deleting {config.save_name} at {path}") + try: + shutil.rmtree(path, ignore_errors=True) + except OSError as e: + logger.warning(f"Could not remove {config.save_name} directory: {e.args}") + + if self._run.is_main_rank: # noqa + config.callback.run() @abc.abstractmethod def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: From a6fe0e598f476818c7c3691c0778bd8d1875322b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 18 Oct 2024 07:28:26 -0400 Subject: [PATCH 3/8] Simplify checkpoint loading --- fast_llm/engine/config_utils/run.py | 40 ++--------------------------- fast_llm/engine/training/trainer.py | 38 ++++++++++++++++++++------- tools/push_model.py | 2 +- 3 files changed, 32 insertions(+), 48 deletions(-) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index ec9dadb1d..b088326e9 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -165,7 +165,7 @@ def __init__( else: run = 0 # Make sure all the workers agree on the run. This also acts as a barrier. - self.index = self._broadcast_int(run) + self.index = self.broadcast_int(run) run_dir = self._experiment_directory / "runs" / str(self.index) self._artifact_dir = run_dir / "artifacts" / str(self._distributed_config.rank) log_dir = run_dir / "logs" @@ -206,54 +206,18 @@ def save_logged_tensors(self, iteration: int | str): torch.save(tensor_stats, self.open_artifact(f"tensor_logs_{iteration}.pt", mode="wb")) TensorLogs.reset(self._config.tensor_logs) - def get_load_checkpoint_context(self, iteration: int): - return self._LoadCheckpointContext(self, iteration) - def barrier(self, value: int | str = 1): from fast_llm.core.distributed import safe_barrier safe_barrier(self._distributed.world_group, value) - def _broadcast_int(self, value: int): + def broadcast_int(self, value: int): import torch from fast_llm.core.distributed import broadcast_scalar return broadcast_scalar(value, dtype=torch.int64, src=_MAIN_RANK, group=self._distributed.world_group) - class _CheckpointContext: - def __init__(self, run: "Run", iteration: int): - self._run = run - self._iteration = iteration - assert self._run._checkpoint_dir is not None - self._directory = self._run._checkpoint_dir / str(self._iteration) - - @property - def directory(self): - return self._directory - - class _LoadCheckpointContext(_CheckpointContext): - def __enter__(self): - assert self._run._is_running - Assert.custom(pathlib.Path.is_file, self._directory / "ok") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_type: - self._run.barrier(f"load {self._iteration} exit") - - def get_last_checkpoint(self): - assert self._is_running - if self._checkpoint_dir is None: - return None - if self._is_main_rank: - checkpoints = [int(path.name) for path in self._checkpoint_dir.iterdir()] - iteration = max(checkpoints) if checkpoints else -1 - else: - iteration = -1 - iteration = self._broadcast_int(iteration) - return iteration if iteration >= 0 else None - def open_artifact(self, name: str, mode: str | None = "w", verbose=True): assert self._is_running if self._artifact_dir is None: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 149e8cbc8..eca8a9121 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -1,6 +1,7 @@ import abc import logging import math +import pathlib import shutil import time import typing @@ -366,7 +367,7 @@ def _evaluate( def _prepare_training_state(self): # Setup the training state. - if (last_iteration := self._run.get_last_checkpoint()) is None: + if (last_iteration := self._get_last_checkpoint()) is None: if (path := self._config.pretrained.path) is not None and self._config.pretrained.model_weights: log_main_rank( f"Initializing training state from pretrained checkpoint at {path}" @@ -381,18 +382,25 @@ def _prepare_training_state(self): self._completed_steps = 0 else: log_main_rank(lambda: f"Loading checkpoint from iteration {last_iteration}...") - with self._run.get_load_checkpoint_context(last_iteration) as context: - metadata = self._multi_stage.load_distributed_checkpoint_same_format(context.directory) - self._optimizer.load(metadata["optimizer"]) - if "schedules" in metadata: - # Backward compatibility. - self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] - else: - self._completed_steps = metadata["completed_steps"] + self._load_checkpoint(self._config.training.checkpoint, last_iteration) Assert.eq(self._completed_steps, last_iteration or 0) assert self._multi_stage._is_loaded # noqa + def _load_checkpoint(self, config: CheckpointBaseConfig, iteration: int): + checkpoint_directory = self._run.experiment_directory / config.save_name / str(iteration) + Assert.custom(pathlib.Path.is_file, checkpoint_directory / "ok") + # TODO v0.2: Use config.get_load_config to make it generic + # TODO v0.2: Detect format instead of hard-coding + metadata = self._multi_stage.load_distributed_checkpoint_same_format(checkpoint_directory) + self._optimizer.load(metadata["optimizer"]) + if "schedules" in metadata: + # Backward compatibility. + self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] + else: + self._completed_steps = metadata["completed_steps"] + self._run.barrier(f"load {config.save_name} {iteration} exit") + def _get_data_iterator(self, phase, completed_steps: int = 0, prefetch_factor: int | None = None): return self._data.get_iterator( self._config.batch, @@ -439,6 +447,18 @@ def _save_checkpoint(self, config: CheckpointBaseConfig, metrics: dict[PhaseType if self._run.is_main_rank: # noqa config.callback.run() + def _get_last_checkpoint(self): + if self._run.experiment_directory is None: + return None + checkpoint_base_directory = self._run.experiment_directory / self._config.training.checkpoint.save_name + if self._run.is_main_rank: + checkpoints = [int(path.name) for path in checkpoint_base_directory.iterdir()] + iteration = max(checkpoints) if checkpoints else -1 + else: + iteration = -1 + iteration = self._run.broadcast_int(iteration) + return iteration if iteration >= 0 else None + @abc.abstractmethod def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: # TODO: Do in model, automate/generalize, get other stats diff --git a/tools/push_model.py b/tools/push_model.py index 3045204d2..76c03e045 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -6,6 +6,7 @@ import shutil import subprocess +from fast_llm.config import Field, config_class from fast_llm.engine.config_utils.checkpoint import CheckpointType from fast_llm.engine.config_utils.runnable import RunnableConfig @@ -26,7 +27,6 @@ raise ImportError("Please install huggingface_hub to use this script") from e -from fast_llm.config import Config, config_class, Field # isort:skip from fast_llm.tools.convert import ConversionConfig # isort:skip From 958889c0c56728ec8f87a11f91396cde3cccdbb0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 18 Oct 2024 10:08:51 -0400 Subject: [PATCH 4/8] Fixed, misc, backward compatible --- fast_llm/config.py | 61 +++++++++----- fast_llm/engine/config_utils/checkpoint.py | 67 +++++++++++---- fast_llm/engine/config_utils/run.py | 6 +- fast_llm/engine/multi_stage/config.py | 14 ++-- fast_llm/engine/multi_stage/fast_llm_model.py | 10 +-- fast_llm/engine/training/config.py | 10 ++- fast_llm/engine/training/trainer.py | 81 ++++++++++--------- fast_llm/layers/common/config.py | 1 + tests/test_checkpoint.py | 14 ++-- 9 files changed, 158 insertions(+), 106 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 869bf5391..01b1068a2 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -811,34 +811,51 @@ def _check_abstract(cls): if not cls.__class_validated__: raise RuntimeError(f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator.") - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls): """ We need to postpone validation until the class has been processed by the dataclass wrapper. """ - assert ( - cls.__class_validated__ - ), f"Parent class of config class {cls.__name__} has not been validated. Make sure to use the @config_class decorator." + for base_class in cls.__mro__: + if issubclass(base_class, Config): + assert cls.__class_validated__, ( + f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated." + f" Make sure to use the @config_class decorator." + ) cls.__class_validated__ = False - for key, value in cls.__dict__: + for name in list(cls.__dict__): + value = getattr(cls, name) if isinstance(value, FieldUpdate): - base_class_field = cls.get_field(key) - cls.__dict__[key] = Field( - desc=kwargs.pop("desc", base_class_field.desc), - doc=kwargs.pop("doc", base_class_field.doc), - hint=kwargs.pop("hint", base_class_field.hint), - valid=kwargs.pop("valid", base_class_field.valid), - default=kwargs.pop("default", base_class_field.default), - default_factory=kwargs.pop("default_factory", base_class_field.default_factory), - repr=kwargs.pop("repr", base_class_field.repr), - hash=kwargs.pop("hash", base_class_field.hash), - compare=kwargs.pop("compare", base_class_field.compare), - metadata=kwargs.pop("metadata", base_class_field.metadata), - kw_only=kwargs.pop("kw_only", base_class_field.kw_only), + # In case of multiple inheritance, the base class field may not appear in `cls.__dataclass_fields__`. + # so we iterate over superclasses following mro and use the first match. + base_class_field = None + for base_class in cls.__mro__: + base_class_fields = getattr(base_class, "__dataclass_fields__", {}) + if name in base_class_fields: + base_class_field = base_class_fields[name] + break + if base_class_field is None: + raise RuntimeError(f"Trying to update the non-existent field {name} in class {get_type_name(cls)}") + setattr( + cls, + name, + Field( + desc=value.pop("desc", base_class_field.desc), + doc=value.pop("doc", base_class_field.doc), + hint=value.pop("hint", base_class_field.hint), + valid=value.pop("valid", base_class_field.valid), + default=value.pop("default", base_class_field.default), + default_factory=value.pop("default_factory", base_class_field.default_factory), + repr=value.pop("repr", base_class_field.repr), + hash=value.pop("hash", base_class_field.hash), + compare=value.pop("compare", base_class_field.compare), + metadata=value.pop("metadata", base_class_field.metadata), + kw_only=value.pop("kw_only", base_class_field.kw_only), + ), ) - if key in cls.__annotations__: + if name in cls.__annotations__: # TODO: Generalize to other type hints. - if isinstance(cls.__annotations__[key], type) and isinstance(base_class_field.type, type): - Assert.custom(issubclass, cls.__annotations__[key], base_class_field.type) + if isinstance(cls.__annotations__[name], type) and isinstance(base_class_field.type, type): + Assert.custom(issubclass, cls.__annotations__[name], base_class_field.type) else: # dataclasses expects an annotation, so we use the one from the base class. - cls.__annotations__[key] = base_class_field.type + cls.__annotations__[name] = base_class_field.type diff --git a/fast_llm/engine/config_utils/checkpoint.py b/fast_llm/engine/config_utils/checkpoint.py index 1a279f3c9..a5a7b157d 100644 --- a/fast_llm/engine/config_utils/checkpoint.py +++ b/fast_llm/engine/config_utils/checkpoint.py @@ -2,6 +2,7 @@ import enum import logging import pathlib +import warnings from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.data_type import DataType @@ -23,6 +24,25 @@ class CheckpointType(str, enum.Enum): external = "external" +class LoadConfig(str, enum.Enum): + none = "none" + architecture = "architecture" + model = "model" + fast_llm = "fast_llm" + + @property + def load_architecture(self): + return self != LoadConfig.none + + @property + def load_base_model(self): + return self in (LoadConfig.model, LoadConfig.fast_llm) + + @property + def load_fast_llm(self): + return self == LoadConfig.fast_llm + + @config_class() class CheckpointPathConfigBase(Config): _abstract = True @@ -46,15 +66,10 @@ class CheckpointConfigBase(Config): desc="Model type for external models (ex. Huggingace model name).", hint=FieldHint.feature, ) - architecture_config: bool = Field( - default=False, - desc="Save/load the model architecture configuration.", - hint=FieldHint.feature, - ) - base_model_config: bool = Field( - default=False, - desc="Save/load the full base model configuration, including the non-architecture fields.", - hint=FieldHint.feature, + load_config: LoadConfig = Field( + default=LoadConfig.architecture, + desc="Configuration to save/load.", + hint=FieldHint.core, ) fast_llm_config: bool = Field( default=False, @@ -64,14 +79,21 @@ class CheckpointConfigBase(Config): @property def compare_log_fn(self): - return logger.warning if self.architecture_config else ValueError - - def _validate(self): - if self.fast_llm_config: - self.base_model_config = True - if self.base_model_config: - self.architecture_config = True - super()._validate() + return ValueError if self.load_config.load_architecture else logger.warning + + @classmethod + def _from_dict( + cls, + default: dict[str], + strict: bool = True, + flat: bool = False, + ): + # TODO v0.2: Remove. + if default.get("format", None) == "huggingface": + warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.") + default["format"] = CheckpointType.external.value + cls._handle_renamed_field(default, "imported_type", "model_type") + return super()._from_dict(default, strict, flat) @config_class() @@ -80,6 +102,17 @@ class CheckpointStateConfigBase(Config): model_weights: bool = Field(default=True, desc="Save/load the model weights.", hint=FieldHint.feature) optimizer_state: bool = Field(default=False, desc="Save/load the optimizer state.", hint=FieldHint.feature) + @classmethod + def _from_dict( + cls, + default: dict[str], + strict: bool = True, + flat: bool = False, + ): + cls._handle_renamed_field(default, "load_weights", "model_weights") + cls._handle_renamed_field(default, "load_optimizer", "optimizer_state") + return super()._from_dict(default, strict, flat) + @config_class() class CheckpointSaveConfigBase(Config): diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index b088326e9..cd32b7d51 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -127,7 +127,6 @@ class Run: """ _experiment_dir: pathlib.Path | None - _checkpoint_dir: pathlib.Path | None def __init__( self, @@ -153,10 +152,7 @@ def __init__( if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() self.dataset_cache_dir = self._experiment_directory / "dataset_cache" - self._checkpoint_dir = self._experiment_directory / "checkpoints" - self._export_dir = self._experiment_directory / "export" if self._is_main_rank: - self._checkpoint_dir.mkdir(exist_ok=True, parents=True) (self._experiment_directory / "runs").mkdir(exist_ok=True, parents=True) run = len(list((self._experiment_directory / "runs").iterdir())) (self._experiment_directory / "runs" / str(run)).mkdir() @@ -170,7 +166,7 @@ def __init__( self._artifact_dir = run_dir / "artifacts" / str(self._distributed_config.rank) log_dir = run_dir / "logs" else: - _experiment_directory, self._checkpoint_dir, self._artifact_dir, log_dir = None, None, None, None + _experiment_directory, self._artifact_dir, log_dir = None, None, None self.dataset_cache_dir = None self.index = None diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 53abfbcaa..f491466ff 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -239,16 +239,16 @@ def from_metadata( return cls._from_metadata_v0(pretrained, metadata, default, updates) pretrained_config = cls.from_dict(metadata["fast_llm_config"]) - if pretrained.architecture_config: + if not pretrained.load_config.load_architecture: assert default is not None config = default.to_copy() config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn) - elif pretrained.fast_llm_config: + elif pretrained.load_config.load_fast_llm: config = pretrained_config else: with NoAutoValidate(): config = cls() if default is None else default.to_copy() - if pretrained.base_model_config: + if pretrained.load_config.load_base_model: config.base_model = pretrained_config.base_model else: config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture()) @@ -274,22 +274,22 @@ def _from_metadata_v0( with NoAutoValidate(): if default is None: - assert not pretrained.architecture_config + assert pretrained.load_config.load_architecture config = cls(base_model=base_model_config_cls()) else: config = default.to_copy() - if pretrained.architecture_config: + if pretrained.load_config.load_architecture: config.validate() architecture_config.compare_architecture(default.base_model, pretrained.compare_log_fn) else: - if pretrained.base_model_config: + if pretrained.load_config.load_base_model: # Replace the whole config config.base_model = base_model_config_cls.from_flat_dict(metadata["model_config"]) else: # Replace the architecture parts of the config. config.base_model = config.base_model.to_copy(architecture_config) - if pretrained.fast_llm_config: + if pretrained.load_config.load_fast_llm: config.multi_stage = MultiStageConfig.from_flat_dict(metadata["multi_stage_config"]) config.distributed = DistributedConfig.from_flat_dict( metadata["distributed_config"], diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 806c5d950..a009ea6ef 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -15,6 +15,7 @@ CheckpointLoadConfig, CheckpointSaveConfig, CheckpointType, + LoadConfig, ) from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode @@ -570,12 +571,7 @@ def import_state_tensor(self, shard_name: str, parameter_name: str, tensor: torc def _load_distributed_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: More safety checks metadata = self.config_class.load_pretrained_metadata(pretrained_config) - loaded_pretrained_config = pretrained_config.to_copy( - { - "base_model_config": True, - "fast_llm_config": True, - }, - ) + loaded_pretrained_config = pretrained_config.to_copy({"load_config": LoadConfig.fast_llm}) loaded_config = self.config_class.from_metadata( loaded_pretrained_config, metadata, @@ -640,7 +636,7 @@ def _import_checkpoint(self, pretrained_config: CheckpointLoadConfig): with self._LoadContext( self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True ) as context: - for name, tensor in converter.model_weights(pretrained_config.path, self._distributed.device): + for name, tensor in converter.load_weights(pretrained_config.path, self._distributed.device): assert name not in state_dict state_dict[name] = tensor for parameter_name, fast_llm_tensor in converter.convert_state_dict(state_dict, False).items(): diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 586c72ade..f1f2a3c43 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -42,7 +42,8 @@ class IntervalConfig(Config): ) def _validate(self): - self.offset %= self.interval + if self.interval: + self.offset %= self.interval super()._validate() def enabled(self, iteration: int | None = None): @@ -160,7 +161,9 @@ def get_iteration_count(self, training_iterations: int, extra_validations: int = @config_class() class CheckpointBaseConfig(IntervalConfig): + _abstract = True save_name: typing.ClassVar[str] = "save" + directory_name: typing.ClassVar[str] = "save" callback: CallbackConfig = Field( default_factory=CallbackConfig, desc="Callback (shell script).", @@ -196,7 +199,10 @@ def to_delete(self, iterations: list[int]): @config_class() class CheckpointConfig(CheckpointBaseConfig): + _abstract = False save_name: typing.ClassVar[str] = "checkpoint" + # TODO v0.2: Rename to `checkpoint` so we don't need this extra variable? + directory_name = "checkpoints" interval = FieldUpdate( desc="The number of training iterations between each checkpoint." " Setting to None will disable checkpoints." ) @@ -216,7 +222,9 @@ def get_save_config(self, path: pathlib.Path): @config_class() class ExportConfig(CheckpointBaseConfig, CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): + _abstract = False save_name: typing.ClassVar[str] = "export" + directory_name = "export" interval = FieldUpdate( desc="The number of training iterations between each export." " Setting to None will disable exports." ) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index eca8a9121..ff8a5f029 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -365,6 +365,15 @@ def _evaluate( return metrics + def _get_data_iterator(self, phase, completed_steps: int = 0, prefetch_factor: int | None = None): + return self._data.get_iterator( + self._config.batch, + phase, + consumed_samples=completed_steps * self._config.batch.batch_size, + num_workers=self._config.training.num_workers, + prefetch_factor=prefetch_factor, + ) + def _prepare_training_state(self): # Setup the training state. if (last_iteration := self._get_last_checkpoint()) is None: @@ -387,36 +396,15 @@ def _prepare_training_state(self): Assert.eq(self._completed_steps, last_iteration or 0) assert self._multi_stage._is_loaded # noqa - def _load_checkpoint(self, config: CheckpointBaseConfig, iteration: int): - checkpoint_directory = self._run.experiment_directory / config.save_name / str(iteration) - Assert.custom(pathlib.Path.is_file, checkpoint_directory / "ok") - # TODO v0.2: Use config.get_load_config to make it generic - # TODO v0.2: Detect format instead of hard-coding - metadata = self._multi_stage.load_distributed_checkpoint_same_format(checkpoint_directory) - self._optimizer.load(metadata["optimizer"]) - if "schedules" in metadata: - # Backward compatibility. - self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] - else: - self._completed_steps = metadata["completed_steps"] - self._run.barrier(f"load {config.save_name} {iteration} exit") - - def _get_data_iterator(self, phase, completed_steps: int = 0, prefetch_factor: int | None = None): - return self._data.get_iterator( - self._config.batch, - phase, - consumed_samples=completed_steps * self._config.batch.batch_size, - num_workers=self._config.training.num_workers, - prefetch_factor=prefetch_factor, - ) - def _save_checkpoint(self, config: CheckpointBaseConfig, metrics: dict[PhaseType, dict[str, float | int]] | None): - checkpoint_base_directory = self._run.experiment_directory / config.save_name + # TODO v0.2: Move barrier, ok file to FastLLMModel + checkpoint_base_directory = self._run.experiment_directory / config.directory_name checkpoint_directory = checkpoint_base_directory / str(self._completed_steps) # Create the checkpoint - log_main_rank(f"Saving {config.save_name} at iteration {self._completed_steps}") - checkpoint_directory.mkdir(exist_ok=False, parents=True) + if self._run.is_main_rank: + logger.info(f"Saving {config.save_name} at iteration {self._completed_steps}") + checkpoint_directory.mkdir(exist_ok=False, parents=True) # Barrier to ensure the directory is created correctly (and didn't exist before). self._run.barrier(f"{config.save_name} {self._completed_steps} enter") @@ -431,27 +419,42 @@ def _save_checkpoint(self, config: CheckpointBaseConfig, metrics: dict[PhaseType # Barrier to ensure everyone is done. self._run.barrier(f"{config.save_name} {self._completed_steps} exit") # Mark the checkpoint as complete. - (checkpoint_directory / "ok").open("w") - logger.info(f"Saved {config.save_name} to {checkpoint_directory}") + if self._run.is_main_rank: + (checkpoint_directory / "ok").open("w") + logger.info(f"Saved {config.save_name} to {checkpoint_directory}") - to_delete = config.to_delete(sorted(int(path.name) for path in checkpoint_directory.iterdir())) + to_delete = config.to_delete(sorted(int(path.name) for path in checkpoint_base_directory.iterdir())) - for iteration in to_delete: - path = checkpoint_base_directory / str(iteration) - logger.info(f"Deleting {config.save_name} at {path}") - try: - shutil.rmtree(path, ignore_errors=True) - except OSError as e: - logger.warning(f"Could not remove {config.save_name} directory: {e.args}") + for iteration in to_delete: + path = checkpoint_base_directory / str(iteration) + logger.info(f"Deleting {config.save_name} at {path}") + try: + shutil.rmtree(path, ignore_errors=True) + except OSError as e: + logger.warning(f"Could not remove {config.save_name} directory: {e.args}") - if self._run.is_main_rank: # noqa config.callback.run() + def _load_checkpoint(self, config: CheckpointBaseConfig, iteration: int): + checkpoint_directory = self._run.experiment_directory / config.directory_name / str(iteration) + Assert.custom(pathlib.Path.is_file, checkpoint_directory / "ok") + # TODO v0.2: Use config.get_load_config to make it generic + # TODO v0.2: Detect format instead of hard-coding + metadata = self._multi_stage.load_distributed_checkpoint_same_format(checkpoint_directory) + self._optimizer.load(metadata["optimizer"]) + if "schedules" in metadata: + # Backward compatibility. + self._completed_steps = metadata["schedules"][PhaseType.training.value]["completed_steps"] + else: + self._completed_steps = metadata["completed_steps"] + # TODO v0.2: Move barrier, ok file to FastLLMModel + self._run.barrier(f"load {config.save_name} {iteration} exit") + def _get_last_checkpoint(self): if self._run.experiment_directory is None: return None - checkpoint_base_directory = self._run.experiment_directory / self._config.training.checkpoint.save_name - if self._run.is_main_rank: + checkpoint_base_directory = self._run.experiment_directory / self._config.training.checkpoint.directory_name + if self._run.is_main_rank and checkpoint_base_directory.is_dir(): checkpoints = [int(path.name) for path in checkpoint_base_directory.iterdir()] iteration = max(checkpoints) if checkpoints else -1 else: diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 74684dfe8..93c940e3b 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -58,6 +58,7 @@ def _from_dict( strict: bool = True, flat: bool = False, ): + # TODO v0.2: Remove. cls._handle_renamed_field(default, "normalization_type", "type") cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 4293ab2e1..e18a5e85b 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -7,7 +7,7 @@ import transformers import yaml -from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType +from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType, LoadConfig from fast_llm.engine.multi_stage.config import StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig @@ -211,8 +211,7 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=CheckpointType.distributed, optimizer_state=True, - base_model_config=True, - fast_llm_config=True, + load_config=LoadConfig.fast_llm, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config, model._base_model_config) @@ -364,12 +363,12 @@ def test_load_pretrained_in_dp2_match_checkpoint(): pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, - fast_llm_config=True, + load_config=LoadConfig.fast_llm, ) pretrained_config_test = CheckpointLoadConfig( path=test_ckpt_path, format=CheckpointType.distributed, - fast_llm_config=True, + load_config=LoadConfig.fast_llm, ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) @@ -406,8 +405,7 @@ def test_load_distributed_checkpoint_dp2(): pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, - base_model_config=True, - fast_llm_config=True, + load_config=LoadConfig.fast_llm, ) pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1", @@ -463,7 +461,7 @@ def test_load_pretrained_huggingface_in_dp2(): "training.checkpoint.interval=1", "training.train_iters=1", f"pretrained.path={_CONVERT_PATH / 'huggingface_0'}", - f"pretrained.format=huggingface", + f"pretrained.format=external", "schedule.skip_step=True", ], num_gpus=2, From 5b508ee8e55f7b88feed77c66265095934b7a002 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 18 Oct 2024 13:37:54 -0400 Subject: [PATCH 5/8] Simpler saving --- examples/example_config.yaml | 12 +- fast_llm/engine/base_model/config.py | 4 +- fast_llm/engine/config_utils/checkpoint.py | 39 +-- fast_llm/engine/huggingface/config.py | 12 +- fast_llm/engine/multi_stage/checkpoint.py | 110 +++++++ fast_llm/engine/multi_stage/config.py | 10 +- fast_llm/engine/multi_stage/conversion.py | 39 ++- fast_llm/engine/multi_stage/fast_llm_model.py | 275 +++++------------- fast_llm/models/gpt/config.py | 4 +- tests/test_checkpoint.py | 10 +- 10 files changed, 258 insertions(+), 257 deletions(-) create mode 100644 fast_llm/engine/multi_stage/checkpoint.py diff --git a/examples/example_config.yaml b/examples/example_config.yaml index 4c4e6d383..c23d7c7b1 100644 --- a/examples/example_config.yaml +++ b/examples/example_config.yaml @@ -27,8 +27,8 @@ model: distributed_timeout: 60.0 training_dtype: float32 pretrained: - path: null - format: distributed + pretrained_checkpoint_path: null + pretrained_checkpoint_type: distributed batch: micro_batch_size: 1 depth_first_micro_batches: 1 @@ -42,13 +42,13 @@ data: - 969.0 - 30.0 - 1.0 - format: list - path: + dataset_source: list + data_path: - fkgtiu data_sample_warn_time_ms: 1000.0 profiling: - cuda: false - ranks: [] + profile_cuda: false + profile_ranks: [] optimizer: weight_decay: 0.01 initial_loss_scale: 65536.0 diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 66d4db581..433f1c92b 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -4,7 +4,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace - from fast_llm.engine.multi_stage.conversion import ModelConverter + from fast_llm.engine.multi_stage.conversion import ExternalModelConverter @config_class() @@ -30,7 +30,7 @@ def compare_architecture( return self.get_architecture().compare(model_config.get_architecture(), log_fn) @classmethod - def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]: + def get_converter_class(cls, model_type: str | None = None) -> type["ExternalModelConverter"]: raise NotImplementedError() diff --git a/fast_llm/engine/config_utils/checkpoint.py b/fast_llm/engine/config_utils/checkpoint.py index a5a7b157d..c2c0f3f05 100644 --- a/fast_llm/engine/config_utils/checkpoint.py +++ b/fast_llm/engine/config_utils/checkpoint.py @@ -24,7 +24,7 @@ class CheckpointType(str, enum.Enum): external = "external" -class LoadConfig(str, enum.Enum): +class ModelConfigType(str, enum.Enum): none = "none" architecture = "architecture" model = "model" @@ -32,15 +32,15 @@ class LoadConfig(str, enum.Enum): @property def load_architecture(self): - return self != LoadConfig.none + return self != ModelConfigType.none @property def load_base_model(self): - return self in (LoadConfig.model, LoadConfig.fast_llm) + return self in (ModelConfigType.model, ModelConfigType.fast_llm) @property def load_fast_llm(self): - return self == LoadConfig.fast_llm + return self == ModelConfigType.fast_llm @config_class() @@ -66,21 +66,12 @@ class CheckpointConfigBase(Config): desc="Model type for external models (ex. Huggingace model name).", hint=FieldHint.feature, ) - load_config: LoadConfig = Field( - default=LoadConfig.architecture, - desc="Configuration to save/load.", - hint=FieldHint.core, - ) fast_llm_config: bool = Field( default=False, desc="Save/load the full fast-llm model configuration, including the distributed and multi-stage configurations.", hint=FieldHint.feature, ) - @property - def compare_log_fn(self): - return ValueError if self.load_config.load_architecture else logger.warning - @classmethod def _from_dict( cls, @@ -131,20 +122,30 @@ class CheckpointSaveConfigBase(Config): @config_class() -class CheckpointMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase): +class CheckpointSaveMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase): _abstract = False @config_class() -class CheckpointSaveConfig(CheckpointMetadataConfig, CheckpointStateConfigBase, CheckpointSaveConfigBase): +class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateConfigBase, CheckpointSaveConfigBase): _abstract = False @config_class() -class CheckpointLoadConfig(CheckpointMetadataConfig, CheckpointStateConfigBase): +class CheckpointLoadMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase): _abstract = False + load_config: ModelConfigType = Field( + default=ModelConfigType.architecture, + desc="Configuration to save/load.", + hint=FieldHint.core, + ) + + @property + def compare_log_fn(self): + return ValueError if self.load_config.load_architecture else logger.warning + -# @config_class() -# class TrainingExportConfig(CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): -# _abstract=False +@config_class() +class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): + _abstract = False diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index d47506204..5bd9e273e 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -5,7 +5,7 @@ import transformers -from fast_llm.engine.config_utils.checkpoint import CheckpointMetadataConfig, CheckpointType +from fast_llm.engine.config_utils.checkpoint import CheckpointLoadMetadataConfig, CheckpointType from fast_llm.engine.multi_stage.config import FastLLMModelConfig logger = logging.getLogger(__name__) @@ -36,7 +36,9 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = transformers.configuration_utils.CONFIG_NAME = _backup @classmethod - def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointMetadataConfig, **kwargs): + def _get_config_dict( + cls, pretrained_model_name_or_path: str | os.PathLike | CheckpointLoadMetadataConfig, **kwargs + ): # TODO: Support download from hub/url # Unused arguments, remove to avoid warnings. @@ -56,13 +58,13 @@ def _get_config_dict(cls, pretrained_model_name_or_path: str | os.PathLike | Che # Get the pretrained config. if "pretrained" in kwargs: - assert isinstance(kwargs["pretrained"], CheckpointMetadataConfig) + assert isinstance(kwargs["pretrained"], CheckpointLoadMetadataConfig) assert kwargs["pretrained"].path == pretrained_model_name_or_path pretrained = kwargs.pop("pretrained") - elif isinstance(pretrained_model_name_or_path, CheckpointMetadataConfig): + elif isinstance(pretrained_model_name_or_path, CheckpointLoadMetadataConfig): pretrained = pretrained_model_name_or_path else: - pretrained = CheckpointMetadataConfig( + pretrained = CheckpointLoadMetadataConfig( path=pathlib.Path(pretrained_model_name_or_path), format=CheckpointType.state_dict, ) diff --git a/fast_llm/engine/multi_stage/checkpoint.py b/fast_llm/engine/multi_stage/checkpoint.py new file mode 100644 index 000000000..f25aa26e5 --- /dev/null +++ b/fast_llm/engine/multi_stage/checkpoint.py @@ -0,0 +1,110 @@ +import json +import logging + +import safetensors.torch +import torch +import yaml + +from fast_llm.core.distributed import safe_barrier +from fast_llm.engine.config_utils.checkpoint import CheckpointSaveConfig +from fast_llm.engine.distributed.distributed import Distributed + +logger = logging.getLogger(__name__) + + +def _export_safetensors_metadata(metadata): + """ + Safetensor only accepts string entries, so we convert to string explicitly. + We use yaml rather than json because json requires explicit quotation marks on strings, which breaks things. + (ex. "format": "pt" becomes '"pt"' which breaks huggingface models.) + We avoid using safe_dump for scalars because it adds junk ("\n...\n") at the end of the string + (decoding is unaffected.) + """ + return { + key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value) + for key, value in metadata.items() + } + + +class StateDictSaver: + def __init__( + self, + config: CheckpointSaveConfig, + *, + distributed: Distributed, + metadata, + base_file_name: str, + ): + self._config = config + self._metadata = metadata + self.distributed = distributed + self._distributed_config = distributed.config + self.base_file_name = ( + base_file_name + if self._distributed_config.pipeline_parallel == 1 + else f"{base_file_name}_{self._distributed_config.pipeline_rank}" + ) + # All ranks reconstruct the pipeline-parallel state (for simplicity), but only one saves it. + self._do_save = self._distributed_config.data_rank == self._distributed_config.tensor_rank == 0 + + def add_tensor(self, name: str, tensor: torch.Tensor): + assert name not in self.tensors + self.tensors[name] = tensor + self.param_count += tensor.numel() + if self.param_count >= self._config.parameters_per_file: + self._save_next_file() + + def __enter__(self): + self.file_count = 0 + self.param_count = 0 + self.tensors = {} + self.index = {} + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.tensors: + # Save the last file. + self._save_next_file() + + if self._do_save and self._distributed_config.pipeline_parallel != 1: + # Combine the indexes from all pipeline ranks. + logger.info(f"Merging pipeline-parallel indexes.") + json.dump( + self.index, (self._config.path / f"index_{self._distributed_config.pipeline_rank}.json").open("w") + ) + safe_barrier(self.distributed.pipeline_group, "save state dict") + if self._distributed_config.pipeline_rank == 0: + self.index = {} + for rank in range(self._distributed_config.pipeline_parallel): + file_name = self._config.path / f"index_{rank}.json" + local_index = json.load(file_name.open("r")) + for key, value in local_index.items(): + assert key not in self.index, key + self.index[key] = value + file_name.unlink() + + if self._distributed_config.rank == 0: + path = self._config.path / f"{self.base_file_name}.safetensors.index.json" + logger.info(f"Saving index to {path}") + # Save the index. + json.dump( + {"metadata": self._metadata, "weight_map": self.index}, + path.open("w"), + indent=4, + ) + + def _save_next_file(self): + file_name = f"{self.base_file_name}_{self.file_count}.safetensors" + if self._do_save: + logger.info(f"Saving tensors to {self._config.path / file_name}") + safetensors.torch.save_file( + tensors=self.tensors, + filename=self._config.path / file_name, + metadata=_export_safetensors_metadata(self._metadata), + ) + for name_ in self.tensors: + assert name_ not in self.index + self.index[name_] = file_name + self.file_count += 1 + self.param_count = 0 + self.tensors = {} diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index f491466ff..197dbdeff 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -9,7 +9,7 @@ CHECKPOINT_VERSION, KNOWN_CHECKPOINT_VERSIONS, CheckpointLoadConfig, - CheckpointMetadataConfig, + CheckpointLoadMetadataConfig, CheckpointType, ) from fast_llm.engine.distributed.config import DistributedConfig @@ -211,7 +211,7 @@ def get_base_model_config_cls(cls) -> type[BaseModelConfig]: @classmethod def from_pretrained( cls, - pretrained: CheckpointMetadataConfig, + pretrained: CheckpointLoadMetadataConfig, default: "FastLLMModelConfig" = None, ): # TODO: Add *updates? @@ -222,7 +222,7 @@ def from_pretrained( @classmethod def from_metadata( cls, - pretrained: CheckpointMetadataConfig, + pretrained: CheckpointLoadMetadataConfig, metadata: dict, default: "FastLLMModelConfig" = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -261,7 +261,7 @@ def from_metadata( @classmethod def _from_metadata_v0( cls, - pretrained: CheckpointMetadataConfig, + pretrained: CheckpointLoadMetadataConfig, metadata: dict, default: "FastLLMModelConfig" = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -301,7 +301,7 @@ def _from_metadata_v0( return config @classmethod - def load_pretrained_metadata(cls, pretrained: CheckpointMetadataConfig): + def load_pretrained_metadata(cls, pretrained: CheckpointLoadMetadataConfig): import yaml base_model_config_cls = cls.get_base_model_config_cls() diff --git a/fast_llm/engine/multi_stage/conversion.py b/fast_llm/engine/multi_stage/conversion.py index 65ab7486e..4fa978209 100644 --- a/fast_llm/engine/multi_stage/conversion.py +++ b/fast_llm/engine/multi_stage/conversion.py @@ -138,6 +138,30 @@ def import_weight( class ModelConverter(abc.ABC): + base_file_name: typing.ClassVar[str] + + @abc.abstractmethod + def convert_state_dict( + self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool + ) -> dict[str, torch.Tensor | SafeTensorSlice]: + pass + + +class TrivialConverter(ModelConverter): + base_file_name = "state_dict" + + def convert_state_dict( + self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool + ) -> dict[str, torch.Tensor | SafeTensorSlice]: + out_state_dict = {} + for key in list(state_dict): + name, shard_name = key + out_state_dict[f"{name}/{shard_name}"] = state_dict.pop(key) + return out_state_dict + + +class ExternalModelConverter(ModelConverter): + base_file_name = "model" _base_model_cls: type[BaseModelConfig] _config_converters: list[ParamConverter] @@ -213,20 +237,21 @@ def from_config(cls, config: dict[str], architecture_only: bool = False): return cls(cls.import_config(config, architecture_only=architecture_only)) def convert_state_dict( - self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool + self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: out_state_dict = {} weight_converters = self._export_converters if export else self._import_converters - for state_dict_name in list(state_dict): + for state_dict_name, shard_name in list(state_dict): + assert shard_name == "weights" try: if state_dict_name not in weight_converters: continue weight_converter: WeightConverter = weight_converters[state_dict_name] in_names = weight_converter.fast_llm_name if export else weight_converter.export_name - if not all(name in state_dict for name in in_names): + if not all((name, shard_name) in state_dict for name in in_names): continue - in_weights = tuple(state_dict.pop(name) for name in in_names) + in_weights = tuple(state_dict.pop((name, shard_name)) for name in in_names) out_names = weight_converter.export_name if export else weight_converter.fast_llm_name out_weights = ( weight_converter.export_weight(in_weights) @@ -262,8 +287,8 @@ def _get_fast_llm_attribute(config: BaseModelArchitectureConfig, name: str | tup return val -class AutoModelConverter(ModelConverter, abc.ABC): - converter_map: dict[str, type[ModelConverter]] +class AutoModelConverter(ExternalModelConverter, abc.ABC): + converter_map: dict[str, type[ExternalModelConverter]] @classmethod def import_config(cls, config: dict[str], architecture_only: bool = False): @@ -274,7 +299,7 @@ def from_config(cls, config: dict[str], architecture_only: bool = False): return cls.converter_map[config["model_type"]].from_config(config, architecture_only) -class HuggingfaceModelConverter(ModelConverter, abc.ABC): +class HuggingfaceModelConverter(ExternalModelConverter, abc.ABC): model_type: str | None = None @classmethod diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index a009ea6ef..8dc45f8eb 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -8,17 +8,19 @@ import torch import yaml -from fast_llm.core.distributed import all_reduce, broadcast, safe_barrier +from fast_llm.core.distributed import all_reduce, broadcast from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.checkpoint import ( CHECKPOINT_VERSION, CheckpointLoadConfig, CheckpointSaveConfig, CheckpointType, - LoadConfig, + ModelConfigType, ) from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.checkpoint import StateDictSaver from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode +from fast_llm.engine.multi_stage.conversion import ModelConverter, TrivialConverter from fast_llm.engine.multi_stage.multi_stage import MultiStageModel from fast_llm.engine.multi_stage.stage import Stage from fast_llm.functional.triton.pointwise import triton_fill @@ -86,18 +88,73 @@ def save_checkpoint( checkpoint_config: CheckpointSaveConfig, metadata: dict | None = None, ): - if metadata is None: - metadata = {} - if checkpoint_config.format == CheckpointType.distributed: - self._save_distributed_checkpoint(checkpoint_config, metadata) + # TODO: Handle barriers, ok file, mkdir, etc. here + + num_shards = len(self._state_shard_names) if checkpoint_config.optimizer_state else 1 + metadata = { + "checkpoint_type": CheckpointType.distributed.value, + "checkpoint_version": str(CHECKPOINT_VERSION), + "fast_llm_config": self._fast_llm_config.to_serialized(), + "state_shard_names": list(self._state_shard_names[:num_shards]), + "metadata": {} if metadata is None else metadata, + } + + # TODO: Simplify branching. + if checkpoint_config.format == CheckpointType.external: + # TODO: Support optimizer? + assert not checkpoint_config.optimizer_state + converter_class = self._base_model_config.get_converter_class(checkpoint_config.model_type) + exported_config = converter_class.export_config(self._base_model_config) + converter_class.save_config(checkpoint_config.path, exported_config) + self._save_state_dict( + checkpoint_config, + converter_class(self._base_model_config), + { + "fast_llm_metadata": metadata, + "model_config": exported_config, + "format": "pt", + }, + ) elif checkpoint_config.format == CheckpointType.state_dict: - self._save_state_dict_checkpoint(checkpoint_config, metadata) - elif checkpoint_config.format == CheckpointType.external: - assert checkpoint_config.model_type is not None - self._export_checkpoint(checkpoint_config, metadata) + self._save_state_dict(checkpoint_config, TrivialConverter(), metadata) + elif checkpoint_config.format == CheckpointType.distributed: + if self._distributed_config.rank == 0: + yaml.safe_dump(metadata, (checkpoint_config.path / "metadata.yaml").open("w")) + safetensors.torch.save_file( + tensors={"state_shard": self._state_shard[:num_shards]}, + filename=checkpoint_config.path / f"rank_{self._distributed_config.rank}.safetensors", + metadata=_export_safetensors_metadata(metadata), + ) else: raise NotImplementedError(checkpoint_config.format) + def _save_state_dict(self, checkpoint_config: CheckpointSaveConfig, converter: ModelConverter, metadata: dict): + with StateDictSaver( + checkpoint_config, + distributed=self._distributed, + metadata=metadata, + base_file_name=converter.base_file_name, + ) as context: + # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from + # `fast_llm_state_dict` that are ready for conversion, + # and return a dict containing the converted tensors(s). + # If converting a tensor requires another one that is not yet available (e.g. for concatenation), + # it will remain in `fast_llm_state_dict` until that tensor is available. + fast_llm_state_dict = {} + for i, shard_name in enumerate( + self._state_shard_names if checkpoint_config.optimizer_state else self._state_shard_names[:1] + ): + shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) + for stage, shard in zip(self._stages_on_device.values(), shard_split): + for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.data_type): # noqa + assert name not in fast_llm_state_dict + fast_llm_state_dict[(name, shard_name)] = tensor + for exported_name, exported_tensor in converter.convert_state_dict( + fast_llm_state_dict, True + ).items(): + context.add_tensor(exported_name, exported_tensor) + assert not fast_llm_state_dict, list(fast_llm_state_dict) + def load_pretrained_checkpoint(self, pretrained_config: CheckpointLoadConfig): if pretrained_config.format == CheckpointType.distributed: # TODO: Check if same format. @@ -190,200 +247,6 @@ def _reset_shard_pads(self, optimizer: bool = False): counter += stage.reset_shard_pad(stage_shard) return counter - class _SaveContext: - def __init__( - self, - model: "FastLLMModel", - metadata, - *, - directory: pathlib.Path, - save_optimizer: bool, - ): - assert model._is_setup - assert model._is_loaded - self.save_optimizer = save_optimizer - self.num_shards = len(model._state_shard_names) if self.save_optimizer else 1 - self.self_shard = model._state_shard[: self.num_shards] - self.shard_names = model._state_shard_names[: self.num_shards] - self.directory = directory - self.metadata = { - "checkpoint_type": CheckpointType.distributed.value, - "checkpoint_version": str(CHECKPOINT_VERSION), - "fast_llm_config": model.fast_llm_config.to_serialized(), - "state_shard_names": list(model._state_shard_names[: self.num_shards]), - "metadata": metadata, - } - - def __enter__(self): - # Is a context for future-proofing. - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return - - class _SaveStateDictContext(_SaveContext): - def __init__( - self, - model: "FastLLMModel", - metadata, - *, - directory: pathlib.Path, - save_optimizer: bool, - base_file_name: str, - target_params_per_file, - ): - super().__init__(model, metadata, directory=directory, save_optimizer=save_optimizer) - self.distributed_config = model._distributed_config - self.distributed = model._distributed - self.base_file_name = ( - base_file_name - if self.distributed_config.pipeline_parallel == 1 - else f"{base_file_name}_{self.distributed_config.pipeline_rank}" - ) - self.target_params_per_file = target_params_per_file - # All ranks reconstruct the pipeline-parallel state (for simplicity), but only one saves it. - self.do_save = self.distributed_config.data_rank == self.distributed_config.tensor_rank == 0 - - def add_tensor(self, name: str, tensor: torch.Tensor): - assert name not in self.tensors - self.tensors[name] = tensor - self.param_count += tensor.numel() - if self.param_count >= self.target_params_per_file: - self._save_next_file() - - def __enter__(self): - self.file_count = 0 - self.param_count = 0 - self.tensors = {} - self.index = {} - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.tensors: - # Save the last file. - self._save_next_file() - - if self.do_save and self.distributed_config.pipeline_parallel != 1: - # Combine the indexes from all pipeline ranks. - logger.info(f"Merging pipeline-parallel indexes.") - json.dump( - self.index, (self.directory / f"index_{self.distributed_config.pipeline_rank}.json").open("w") - ) - safe_barrier(self.distributed.pipeline_group, "save state dict") - if self.distributed_config.pipeline_rank == 0: - self.index = {} - for rank in range(self.distributed_config.pipeline_parallel): - file_name = self.directory / f"index_{rank}.json" - local_index = json.load(file_name.open("r")) - for key, value in local_index.items(): - assert key not in self.index, key - self.index[key] = value - file_name.unlink() - - if self.distributed_config.rank == 0: - path = self.directory / f"{self.base_file_name}.safetensors.index.json" - logger.info(f"Saving index to {path}") - # Save the index. - json.dump( - {"metadata": self.metadata, "weight_map": self.index}, - path.open("w"), - indent=4, - ) - - def _save_next_file(self): - file_name = f"{self.base_file_name}_{self.file_count}.safetensors" - if self.do_save: - logger.info(f"Saving tensors to {self.directory/file_name}") - safetensors.torch.save_file( - tensors=self.tensors, - filename=self.directory / file_name, - metadata=_export_safetensors_metadata(self.metadata), - ) - for name_ in self.tensors: - assert name_ not in self.index - self.index[name_] = file_name - self.file_count += 1 - self.param_count = 0 - self.tensors = {} - - def _save_distributed_checkpoint(self, checkpoint_config: CheckpointSaveConfig, metadata: dict): - # TODO: Non blocking? - # TODO: CPU memory? - # TODO: Handle barriers, ok file, mkdir, etc. here - with self._SaveContext( - self, - metadata, - directory=checkpoint_config.path, - save_optimizer=checkpoint_config.optimizer_state, - ) as context: - if self._distributed_config.rank == 0: - yaml.safe_dump(context.metadata, (context.directory / "metadata.yaml").open("w")) - safetensors.torch.save_file( - tensors={"state_shard": self._state_shard[: context.num_shards]}, - filename=context.directory / f"rank_{self._distributed_config.rank}.safetensors", - metadata=_export_safetensors_metadata(context.metadata), - ) - - def _save_state_dict_checkpoint(self, checkpoint_config: CheckpointSaveConfig, metadata: dict): - # TODO: Make into a special case of _export_checkpoint? - # TODO: Handle barriers, ok file, mkdir, etc. here - with self._SaveStateDictContext( - self, - metadata, - directory=checkpoint_config.path, - save_optimizer=checkpoint_config.optimizer_state, - base_file_name="state_dict", - target_params_per_file=checkpoint_config.parameters_per_file, - ) as context: - for i, shard_name in enumerate( - self._state_shard_names if checkpoint_config.optimizer_state else self._state_shard_names[:1] - ): - shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) - for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.data_type): # noqa - context.add_tensor(f"{name}/{shard_name}", tensor) - - def _export_checkpoint(self, checkpoint_config: CheckpointSaveConfig, metadata: dict): - # TODO: Handle barriers, ok file, mkdir, etc. here - # TODO: Support optimizer? - assert not checkpoint_config.optimizer_state - with self._SaveStateDictContext( - self, - metadata, - directory=checkpoint_config.path, - save_optimizer=False, - base_file_name="model", - target_params_per_file=checkpoint_config.parameters_per_file, - ) as context: - converter_class = self._base_model_config.get_converter_class(checkpoint_config.model_type) - exported_config = converter_class.export_config(self._base_model_config) - converter_class.save_config(checkpoint_config.path, exported_config) - context.metadata = { - "fast_llm_metadata": context.metadata, - "model_config": exported_config, - "format": "pt", - } - converter = converter_class(self._base_model_config) - # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from - # `fast_llm_state_dict` that are ready for conversion, - # and return a dict containing the converted tensors(s). - # If converting a tensor requires another one that is not yet available (e.g. for concatenation), - # it will remain in `fast_llm_state_dict` until that tensor is available. - fast_llm_state_dict = {} - for i, shard_name in enumerate( - self._state_shard_names if checkpoint_config.optimizer_state else self._state_shard_names[:1] - ): - shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) - for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.data_type): # noqa - assert name not in fast_llm_state_dict - fast_llm_state_dict[name] = tensor - for exported_name, exported_tensor in converter.convert_state_dict( - fast_llm_state_dict, True - ).items(): - context.add_tensor(exported_name, exported_tensor) - assert not fast_llm_state_dict, list(fast_llm_state_dict) - class _LoadContext: # TODO: Improve def __init__(self, model: "FastLLMModel", *, safe: bool, load_optimizer: bool, reset_pads: bool): @@ -571,7 +434,7 @@ def import_state_tensor(self, shard_name: str, parameter_name: str, tensor: torc def _load_distributed_checkpoint(self, pretrained_config: CheckpointLoadConfig): # TODO: More safety checks metadata = self.config_class.load_pretrained_metadata(pretrained_config) - loaded_pretrained_config = pretrained_config.to_copy({"load_config": LoadConfig.fast_llm}) + loaded_pretrained_config = pretrained_config.to_copy({"load_config": ModelConfigType.fast_llm}) loaded_config = self.config_class.from_metadata( loaded_pretrained_config, metadata, @@ -638,7 +501,7 @@ def _import_checkpoint(self, pretrained_config: CheckpointLoadConfig): ) as context: for name, tensor in converter.load_weights(pretrained_config.path, self._distributed.device): assert name not in state_dict - state_dict[name] = tensor + state_dict[(name, "weights")] = tensor for parameter_name, fast_llm_tensor in converter.convert_state_dict(state_dict, False).items(): context.import_state_tensor("weights", parameter_name, fast_llm_tensor) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index b61ea8cd3..f13488e3e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -8,7 +8,7 @@ from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds if typing.TYPE_CHECKING: - from fast_llm.engine.multi_stage.conversion import ModelConverter + from fast_llm.engine.multi_stage.conversion import ExternalModelConverter @config_class() @@ -28,7 +28,7 @@ def _from_dict( return super()._from_dict(default, strict, flat) @classmethod - def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]: + def get_converter_class(cls, model_type: str | None = None) -> type["ExternalModelConverter"]: from fast_llm.models.gpt.conversion import AutoGPTConverter return AutoGPTConverter if model_type is None else AutoGPTConverter.converter_map[model_type] diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index e18a5e85b..7043170cd 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -7,7 +7,7 @@ import transformers import yaml -from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType, LoadConfig +from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType, ModelConfigType from fast_llm.engine.multi_stage.config import StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig @@ -211,7 +211,7 @@ def test_load_pretrained_distributed_checkpoint(): path=_CKPT_PATH, format=CheckpointType.distributed, optimizer_state=True, - load_config=LoadConfig.fast_llm, + load_config=ModelConfigType.fast_llm, ) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_ref) _compare_configs(config, model._base_model_config) @@ -363,12 +363,12 @@ def test_load_pretrained_in_dp2_match_checkpoint(): pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, - load_config=LoadConfig.fast_llm, + load_config=ModelConfigType.fast_llm, ) pretrained_config_test = CheckpointLoadConfig( path=test_ckpt_path, format=CheckpointType.distributed, - load_config=LoadConfig.fast_llm, + load_config=ModelConfigType.fast_llm, ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) config_test = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_test) @@ -405,7 +405,7 @@ def test_load_distributed_checkpoint_dp2(): pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, format=CheckpointType.distributed, - load_config=LoadConfig.fast_llm, + load_config=ModelConfigType.fast_llm, ) pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1", From b3a97c001845eef478c71402c9119eb07fd99b2b Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 18 Oct 2024 13:43:56 -0400 Subject: [PATCH 6/8] cleanup --- fast_llm/engine/config_utils/checkpoint.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/fast_llm/engine/config_utils/checkpoint.py b/fast_llm/engine/config_utils/checkpoint.py index c2c0f3f05..f46de6088 100644 --- a/fast_llm/engine/config_utils/checkpoint.py +++ b/fast_llm/engine/config_utils/checkpoint.py @@ -66,11 +66,6 @@ class CheckpointConfigBase(Config): desc="Model type for external models (ex. Huggingace model name).", hint=FieldHint.feature, ) - fast_llm_config: bool = Field( - default=False, - desc="Save/load the full fast-llm model configuration, including the distributed and multi-stage configurations.", - hint=FieldHint.feature, - ) @classmethod def _from_dict( From 48cfd9bfb3edc6078c58f028f7f3912ff77e84c6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Oct 2024 08:13:26 -0400 Subject: [PATCH 7/8] fix --- fast_llm/engine/training/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index f1f2a3c43..ba1c6c5bc 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -214,7 +214,6 @@ def get_save_config(self, path: pathlib.Path): return CheckpointSaveConfig( path=path, format=CheckpointType.distributed, - fast_llm_config=True, model_weights=True, optimizer_state=True, ) From 3e5b33a545a8b702769bc28440fa77530142d927 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 21 Oct 2024 13:07:52 -0400 Subject: [PATCH 8/8] tweak --- fast_llm/config.py | 10 ++-- fast_llm/engine/config_utils/checkpoint.py | 13 +++-- fast_llm/engine/huggingface/config.py | 4 +- fast_llm/engine/huggingface/model.py | 4 +- fast_llm/engine/multi_stage/config.py | 12 ++-- fast_llm/engine/multi_stage/conversion.py | 16 ++--- fast_llm/engine/multi_stage/fast_llm_model.py | 18 +++--- fast_llm/engine/training/config.py | 4 +- fast_llm/layers/common/config.py | 4 +- fast_llm/layers/language_model/config.py | 4 +- fast_llm/tools/convert.py | 8 +-- tests/test_checkpoint.py | 58 +++++++++---------- tools/push_model.py | 6 +- 13 files changed, 82 insertions(+), 79 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 01b1068a2..b4ade6b99 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -278,7 +278,7 @@ class Config: __class_validated__: typing.ClassVar[bool] = True _abstract: typing.ClassVar[bool] = False _validated: bool = Field(init=False, repr=False) - _unknown_fields: dict[str] = Field(init=False, repr=False) + _unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False) def __post_init__(self): """ @@ -629,7 +629,7 @@ def _get_class_name(cls): @classmethod def from_dict( cls, - default: typing.Union["Config", dict[str]], + default: typing.Union["Config", dict[str, typing.Any]], *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True, ): @@ -654,7 +654,7 @@ def from_dict( @classmethod def from_flat_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, ): # TODO v0.2: Remove flat format @@ -663,7 +663,7 @@ def from_flat_dict( @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): @@ -776,7 +776,7 @@ def _from_dict_dict(cls, value, type_, strict: bool): return {key: cls._from_dict_nested(value_, args[1], strict) for key, value_ in value.items()} @classmethod - def _handle_renamed_field(cls, default: dict[str], old_name: str, new_name: str): + def _handle_renamed_field(cls, default: dict[str, typing.Any], old_name: str, new_name: str): if old_name in default: warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.") default[new_name] = default.pop(old_name) diff --git a/fast_llm/engine/config_utils/checkpoint.py b/fast_llm/engine/config_utils/checkpoint.py index f46de6088..dc77fd5e0 100644 --- a/fast_llm/engine/config_utils/checkpoint.py +++ b/fast_llm/engine/config_utils/checkpoint.py @@ -2,6 +2,7 @@ import enum import logging import pathlib +import typing import warnings from fast_llm.config import Config, Field, FieldHint, check_field, config_class @@ -15,7 +16,7 @@ KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1") -class CheckpointType(str, enum.Enum): +class CheckpointFormat(str, enum.Enum): # Distributed checkpoint for fast checkpointing and resuming. distributed = "distributed" # Model state dict, for safe long-term storage in Fast-LLM format. @@ -56,8 +57,8 @@ class CheckpointPathConfigBase(Config): @config_class() class CheckpointConfigBase(Config): _abstract = True - format: CheckpointType = Field( - default=CheckpointType.distributed, + format: CheckpointFormat = Field( + default=CheckpointFormat.distributed, desc="Format of the checkpoint.", hint=FieldHint.core, ) @@ -70,14 +71,14 @@ class CheckpointConfigBase(Config): @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): # TODO v0.2: Remove. if default.get("format", None) == "huggingface": warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.") - default["format"] = CheckpointType.external.value + default["format"] = CheckpointFormat.external.value cls._handle_renamed_field(default, "imported_type", "model_type") return super()._from_dict(default, strict, flat) @@ -91,7 +92,7 @@ class CheckpointStateConfigBase(Config): @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 5bd9e273e..b53263e4a 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -5,7 +5,7 @@ import transformers -from fast_llm.engine.config_utils.checkpoint import CheckpointLoadMetadataConfig, CheckpointType +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadMetadataConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def _get_config_dict( else: pretrained = CheckpointLoadMetadataConfig( path=pathlib.Path(pretrained_model_name_or_path), - format=CheckpointType.state_dict, + format=CheckpointFormat.state_dict, ) metadata = cls.model_config_class.load_pretrained_metadata(pretrained) updates = {} diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/huggingface/model.py index e177ca796..45ab60cca 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/huggingface/model.py @@ -5,7 +5,7 @@ import transformers.modeling_outputs from fast_llm.config import NoAutoValidate -from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.huggingface.config import HuggingfaceModelConfig from fast_llm.engine.multi_stage.config import StageMode @@ -68,7 +68,7 @@ def from_pretrained( if not isinstance(pretrained_model_name_or_path, CheckpointLoadConfig): pretrained_model_name_or_path = CheckpointLoadConfig( path=pathlib.Path(pretrained_model_name_or_path), - format=CheckpointType.state_dict, + format=CheckpointFormat.state_dict, ) config_updates = {} diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 197dbdeff..2f04bab4a 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -8,9 +8,9 @@ from fast_llm.engine.config_utils.checkpoint import ( CHECKPOINT_VERSION, KNOWN_CHECKPOINT_VERSIONS, + CheckpointFormat, CheckpointLoadConfig, CheckpointLoadMetadataConfig, - CheckpointType, ) from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert @@ -231,7 +231,7 @@ def from_metadata( # TODO: Standardize to *updates? if "checkpoint_type" in metadata: # TODO python 3.12: Assert.incl(metadata["checkpoint_type"], CheckpointType) - CheckpointType(metadata["checkpoint_type"]) + CheckpointFormat(metadata["checkpoint_type"]) version = metadata["checkpoint_version"] if version not in KNOWN_CHECKPOINT_VERSIONS: raise ValueError(f"Unrecognised checkpoint version: {version}") @@ -305,16 +305,16 @@ def load_pretrained_metadata(cls, pretrained: CheckpointLoadMetadataConfig): import yaml base_model_config_cls = cls.get_base_model_config_cls() - if pretrained.format == CheckpointType.distributed: + if pretrained.format == CheckpointFormat.distributed: return yaml.safe_load((pretrained.path / "metadata.yaml").open("r")) - elif pretrained.format == CheckpointType.state_dict: + elif pretrained.format == CheckpointFormat.state_dict: return json.load((pretrained.path / f"state_dict.safetensors.index.json").open("r"))["metadata"] - elif pretrained.format == CheckpointType.external: + elif pretrained.format == CheckpointFormat.external: converter_class = base_model_config_cls.get_converter_class(pretrained.model_type) imported_model_config = converter_class.import_config(converter_class.load_config(pretrained.path), True) return { "fast_llm_config": {"base_model": imported_model_config.to_serialized()}, - "checkpoint_type": CheckpointType.external.value, + "checkpoint_type": CheckpointFormat.external.value, "checkpoint_version": CHECKPOINT_VERSION, } else: diff --git a/fast_llm/engine/multi_stage/conversion.py b/fast_llm/engine/multi_stage/conversion.py index 4fa978209..4acb9d928 100644 --- a/fast_llm/engine/multi_stage/conversion.py +++ b/fast_llm/engine/multi_stage/conversion.py @@ -189,12 +189,12 @@ def _create_weight_converters(self) -> list[WeightConverter]: @classmethod @abc.abstractmethod - def load_config(cls, directory: pathlib.Path | str) -> dict[str]: + def load_config(cls, directory: pathlib.Path | str) -> dict[str, typing.Any]: pass @classmethod @abc.abstractmethod - def save_config(cls, directory: pathlib.Path | str, config: dict[str]): + def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): pass @abc.abstractmethod @@ -204,7 +204,7 @@ def load_weights( pass @classmethod - def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str]: + def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]: exported_config = {} for converter in cls._get_config_converters(): value = converter.export_param( @@ -218,7 +218,7 @@ def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str]: return exported_config # Noqa @classmethod - def import_config(cls, config: dict[str], architecture_only: bool = False): # noqa + def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa kwargs = {} for converter in cls._get_config_converters(): value = converter.import_param( @@ -233,7 +233,7 @@ def import_config(cls, config: dict[str], architecture_only: bool = False): # n return config_class.from_dict({}, kwargs) @classmethod - def from_config(cls, config: dict[str], architecture_only: bool = False): + def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): return cls(cls.import_config(config, architecture_only=architecture_only)) def convert_state_dict( @@ -291,11 +291,11 @@ class AutoModelConverter(ExternalModelConverter, abc.ABC): converter_map: dict[str, type[ExternalModelConverter]] @classmethod - def import_config(cls, config: dict[str], architecture_only: bool = False): + def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): return cls.converter_map[config["model_type"]].import_config(config, architecture_only) @classmethod - def from_config(cls, config: dict[str], architecture_only: bool = False): + def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): return cls.converter_map[config["model_type"]].from_config(config, architecture_only) @@ -317,7 +317,7 @@ def load_config(cls, directory: pathlib.Path | str): return config @classmethod - def save_config(cls, directory: pathlib.Path | str, config: dict[str]): + def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): import transformers transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 8dc45f8eb..edcbdc0e2 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -12,9 +12,9 @@ from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.checkpoint import ( CHECKPOINT_VERSION, + CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig, - CheckpointType, ModelConfigType, ) from fast_llm.engine.distributed.distributed import Distributed @@ -92,7 +92,7 @@ def save_checkpoint( num_shards = len(self._state_shard_names) if checkpoint_config.optimizer_state else 1 metadata = { - "checkpoint_type": CheckpointType.distributed.value, + "checkpoint_type": CheckpointFormat.distributed.value, "checkpoint_version": str(CHECKPOINT_VERSION), "fast_llm_config": self._fast_llm_config.to_serialized(), "state_shard_names": list(self._state_shard_names[:num_shards]), @@ -100,7 +100,7 @@ def save_checkpoint( } # TODO: Simplify branching. - if checkpoint_config.format == CheckpointType.external: + if checkpoint_config.format == CheckpointFormat.external: # TODO: Support optimizer? assert not checkpoint_config.optimizer_state converter_class = self._base_model_config.get_converter_class(checkpoint_config.model_type) @@ -115,9 +115,9 @@ def save_checkpoint( "format": "pt", }, ) - elif checkpoint_config.format == CheckpointType.state_dict: + elif checkpoint_config.format == CheckpointFormat.state_dict: self._save_state_dict(checkpoint_config, TrivialConverter(), metadata) - elif checkpoint_config.format == CheckpointType.distributed: + elif checkpoint_config.format == CheckpointFormat.distributed: if self._distributed_config.rank == 0: yaml.safe_dump(metadata, (checkpoint_config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( @@ -156,12 +156,12 @@ def _save_state_dict(self, checkpoint_config: CheckpointSaveConfig, converter: M assert not fast_llm_state_dict, list(fast_llm_state_dict) def load_pretrained_checkpoint(self, pretrained_config: CheckpointLoadConfig): - if pretrained_config.format == CheckpointType.distributed: + if pretrained_config.format == CheckpointFormat.distributed: # TODO: Check if same format. self._load_distributed_checkpoint(pretrained_config) - elif pretrained_config.format == CheckpointType.state_dict: + elif pretrained_config.format == CheckpointFormat.state_dict: self._load_state_dict_checkpoint(pretrained_config) - elif pretrained_config.format == CheckpointType.external: + elif pretrained_config.format == CheckpointFormat.external: self._import_checkpoint(pretrained_config) else: raise NotImplementedError(pretrained_config.format) @@ -170,7 +170,7 @@ def load_distributed_checkpoint_same_format(self, directory: pathlib.Path): # TODO: Handle barriers, ok file, etc. here # TODO: More safety checks # TODO: Integrate to load_checkpoint. - pretrained_config = CheckpointLoadConfig(path=directory, format=CheckpointType.distributed) + pretrained_config = CheckpointLoadConfig(path=directory, format=CheckpointFormat.distributed) metadata = self.config_class.load_pretrained_metadata(pretrained_config) with self._LoadContext(self, safe=False, load_optimizer=True, reset_pads=False) as context: Assert.eq( diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index ba1c6c5bc..655f29677 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -9,10 +9,10 @@ from fast_llm.data.config import AbstractDataConfig from fast_llm.engine.config_utils.checkpoint import ( CheckpointConfigBase, + CheckpointFormat, CheckpointSaveConfig, CheckpointSaveConfigBase, CheckpointStateConfigBase, - CheckpointType, ) from fast_llm.engine.config_utils.run import ExperimentConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig @@ -213,7 +213,7 @@ class CheckpointConfig(CheckpointBaseConfig): def get_save_config(self, path: pathlib.Path): return CheckpointSaveConfig( path=path, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, model_weights=True, optimizer_state=True, ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 93c940e3b..1c896e3b1 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -54,7 +54,7 @@ class NormalizationArchitectureConfig(BaseModelArchitectureConfig): @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): @@ -107,7 +107,7 @@ def get_layer(self, hidden_dim: "TensorDim"): @classmethod def _from_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, flat: bool = False, ): diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d8afb2fca..4e3058995 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,3 +1,5 @@ +import typing + from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace @@ -84,7 +86,7 @@ def use_absolute_position_embeddings(self): @classmethod def from_flat_dict( cls, - default: dict[str], + default: dict[str, typing.Any], strict: bool = True, ): # The backward compatibility fix in `NormalizationArchitectureConfig` diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index 3e932f0ba..1ef3c0494 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -7,7 +7,7 @@ import typing from fast_llm.config import Field, config_class -from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointType +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode @@ -23,8 +23,8 @@ @config_class() class ConversionConfig(RunnableConfig): - input_type: CheckpointType = Field() - output_type: CheckpointType = Field() + input_type: CheckpointFormat = Field() + output_type: CheckpointFormat = Field() input_path: pathlib.Path = Field() output_path: pathlib.Path = Field() model_type: str | None = Field(default=None) @@ -101,7 +101,7 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): self._convert_model_partial(model_class, self.output_path) else: # TODO: Support other types? - assert self.output_type == CheckpointType.external + assert self.output_type == CheckpointFormat.external logger.info(f">>> Loading model config") # Create a dummy version to determine the stage split. model = model_class.from_pretrained( diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 7043170cd..5fd272a47 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -7,7 +7,7 @@ import transformers import yaml -from fast_llm.engine.config_utils.checkpoint import CheckpointLoadConfig, CheckpointType, ModelConfigType +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat, CheckpointLoadConfig, ModelConfigType from fast_llm.engine.multi_stage.config import StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig @@ -89,9 +89,9 @@ def _run_conversion(config: ConversionConfig): def test_convert_distributed_to_state_dict(): _run_conversion( ConversionConfig( - input_type=CheckpointType.distributed, + input_type=CheckpointFormat.distributed, input_path=_CKPT_PATH, - output_type=CheckpointType.state_dict, + output_type=CheckpointFormat.state_dict, output_path=_CONVERT_PATH / "state_dict_0", ) ) @@ -103,9 +103,9 @@ def test_convert_state_dict_to_huggingface(): pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( ConversionConfig( - input_type=CheckpointType.state_dict, + input_type=CheckpointFormat.state_dict, input_path=_CONVERT_PATH / "state_dict_0", - output_type=CheckpointType.external, + output_type=CheckpointFormat.external, output_path=_CONVERT_PATH / "huggingface_0", model_type=HUGGINGFACE_MODEL_TYPE, ) @@ -116,9 +116,9 @@ def test_convert_state_dict_to_huggingface(): def test_convert_huggingface_to_distributed(): _run_conversion( ConversionConfig( - input_type=CheckpointType.external, + input_type=CheckpointFormat.external, input_path=_CONVERT_PATH / "huggingface_0", - output_type=CheckpointType.distributed, + output_type=CheckpointFormat.distributed, output_path=_CONVERT_PATH / "distributed_0", ) ) @@ -130,9 +130,9 @@ def test_convert_distributed_to_huggingface(): pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( ConversionConfig( - input_type=CheckpointType.distributed, + input_type=CheckpointFormat.distributed, input_path=_CKPT_PATH, - output_type=CheckpointType.external, + output_type=CheckpointFormat.external, output_path=_CONVERT_PATH / "huggingface_1", model_type=HUGGINGFACE_MODEL_TYPE, ) @@ -143,9 +143,9 @@ def test_convert_distributed_to_huggingface(): def test_convert_huggingface_to_state_dict(): _run_conversion( ConversionConfig( - input_type=CheckpointType.external, + input_type=CheckpointFormat.external, input_path=_CONVERT_PATH / "huggingface_1", - output_type=CheckpointType.state_dict, + output_type=CheckpointFormat.state_dict, output_path=_CONVERT_PATH / "state_dict_1", ) ) @@ -155,9 +155,9 @@ def test_convert_huggingface_to_state_dict(): def test_convert_state_dict_to_distributed(): _run_conversion( ConversionConfig( - input_type=CheckpointType.state_dict, + input_type=CheckpointFormat.state_dict, input_path=_CONVERT_PATH / "state_dict_1", - output_type=CheckpointType.distributed, + output_type=CheckpointFormat.distributed, output_path=_CONVERT_PATH / "distributed_1", ) ) @@ -209,7 +209,7 @@ def test_load_pretrained_distributed_checkpoint(): ) pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, optimizer_state=True, load_config=ModelConfigType.fast_llm, ) @@ -223,14 +223,14 @@ def test_load_pretrained_distributed_checkpoint(): @pytest.mark.depends(on=["test_load_pretrained_distributed_checkpoint"]) def test_load_converted_distributed_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointType.distributed) + pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointFormat.distributed) pretrained_config_0 = CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_0", - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) pretrained_config_1 = CheckpointLoadConfig( path=_CONVERT_PATH / "distributed_1", - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) @@ -245,9 +245,9 @@ def test_load_converted_distributed_checkpoint(): @pytest.mark.depends(on=["test_converted_state_dict", "test_load_pretrained_distributed_checkpoint"]) def test_load_converted_state_dict_checkpoint(): - pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointType.distributed) - pretrained_config_0 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_0", format=CheckpointType.state_dict) - pretrained_config_1 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_1", format=CheckpointType.state_dict) + pretrained_config_ref = CheckpointLoadConfig(path=_CKPT_PATH, format=CheckpointFormat.distributed) + pretrained_config_0 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_0", format=CheckpointFormat.state_dict) + pretrained_config_1 = CheckpointLoadConfig(path=_CONVERT_PATH / "state_dict_1", format=CheckpointFormat.state_dict) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0) config_1 = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_1) @@ -263,15 +263,15 @@ def test_load_converted_state_dict_checkpoint(): def test_load_converted_huggingface_checkpoint(): pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) pretrained_config_0 = CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointType.external, + format=CheckpointFormat.external, ) pretrained_config_1 = CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_1", - format=CheckpointType.external, + format=CheckpointFormat.external, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) @@ -289,7 +289,7 @@ def test_run_converted_model(): model_ref = TEST_MODEL_HF_CLS.from_pretrained( CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) ) test_input = torch.randint( @@ -300,7 +300,7 @@ def test_run_converted_model(): model_from_hf = TEST_MODEL_HF_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointType.external, + format=CheckpointFormat.external, ) ) errors = [] @@ -362,12 +362,12 @@ def test_load_pretrained_in_dp2_match_checkpoint(): test_ckpt_path = TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1" pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, load_config=ModelConfigType.fast_llm, ) pretrained_config_test = CheckpointLoadConfig( path=test_ckpt_path, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, load_config=ModelConfigType.fast_llm, ) config_ref = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) @@ -404,12 +404,12 @@ def test_load_distributed_checkpoint_dp2(): # This also tests conversion which uses `FastLLMModel.from_checkpoint` pretrained_config_ref = CheckpointLoadConfig( path=_CKPT_PATH, - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, load_config=ModelConfigType.fast_llm, ) pretrained_config_test = CheckpointLoadConfig( path=TEST_RESULTS_PATH / f"test_{TEST_MODEL}_load_pretrained_distributed_in_dp2" / "checkpoints" / "1", - format=CheckpointType.distributed, + format=CheckpointFormat.distributed, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_test, mode=StageMode.weights) diff --git a/tools/push_model.py b/tools/push_model.py index d6c99b04e..8a0f3747c 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -7,7 +7,7 @@ import subprocess from fast_llm.config import Field, config_class -from fast_llm.engine.config_utils.checkpoint import CheckpointType +from fast_llm.engine.config_utils.checkpoint import CheckpointFormat from fast_llm.engine.config_utils.runnable import RunnableConfig try: @@ -148,8 +148,8 @@ def run(self) -> None: checkpoint_path_hf = checkpoint_path.with_name(checkpoint_path.name + "_hf") # Block until the conversion is done ConversionConfig( - input_type=CheckpointType.distributed, - output_type=CheckpointType.external, + input_type=CheckpointFormat.distributed, + output_type=CheckpointFormat.external, input_path=checkpoint_path, output_path=checkpoint_path_hf, model_type=self.model_type,