diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 83741fbae..7c18c7a45 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -3,7 +3,6 @@ from fast_llm.config import Config, config_class if typing.TYPE_CHECKING: - from fast_llm.engine.checkpoint.external import ExternalStateDictConverter from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -29,10 +28,6 @@ def compare_architecture( ): return self.get_architecture().compare(model_config.get_architecture(), log_fn) - @classmethod - def get_converter_class(cls, model_type: str | None = None) -> type["ExternalStateDictConverter"]: - raise NotImplementedError() - @config_class() class BaseModelConfig(BaseModelArchitectureConfig): diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index fc390e01e..738eb0750 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -1,14 +1,20 @@ # TODO: Use packaging.version? (Safer but extra requirement) +import abc import enum import logging import pathlib import typing import warnings +import yaml + from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel + logger = logging.getLogger(__name__) # TODO: Use packaging.version? (Safer but extra requirement) @@ -16,13 +22,29 @@ KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1") -class CheckpointFormat(str, enum.Enum): +def export_safetensors_metadata(metadata): + """ + Safetensor only accepts string entries, so we convert to string explicitly. + We use yaml rather than json because json requires explicit quotation marks on strings, which breaks things. + (ex. "format": "pt" becomes '"pt"' which breaks huggingface models.) 
+ We avoid using safe_dump for scalars because it adds junk ("\n...\n") at the end of the string + (decoding is unaffected.) + """ + return { + key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value) + for key, value in metadata.items() + } + + +def import_safetensors_metadata(metadata): + return {key: yaml.safe_load(value) for key, value in metadata.items()} + + +class CheckpointFormat(str): # Distributed checkpoint for fast checkpointing and resuming. distributed = "distributed" # Model state dict, for safe long-term storage in Fast-LLM format. state_dict = "state_dict" - # A checkpoint format external to Fast-LLM. - external = "external" class ModelConfigType(str, enum.Enum): @@ -57,16 +79,11 @@ class CheckpointPathConfigBase(Config): @config_class() class CheckpointConfigBase(Config): _abstract = True - format: CheckpointFormat = Field( + format: str = Field( default=CheckpointFormat.distributed, desc="Format of the checkpoint.", hint=FieldHint.core, ) - model_type: str | None = Field( - default=None, - desc="Model type for external models (ex. Huggingace model name).", - hint=FieldHint.feature, - ) @classmethod def _from_dict( @@ -76,10 +93,17 @@ def _from_dict( flat: bool = False, ): # TODO v0.2: Remove. - if default.get("format", None) == "huggingface": - warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.") - default["format"] = CheckpointFormat.external.value cls._handle_renamed_field(default, "imported_type", "model_type") + if "model_type" in default: + warnings.warn( + "`CheckpointConfigBase.model_type` is deprecated." + " Instead, use the model name directly as the checkpoint format." 
+ ) + if default.get("format", None) in ("huggingface", "external"): + default["format"] = default.get("model_type") + if default["format"] is None: + default["format"] = "auto" + del default["model_type"] return super()._from_dict(default, strict, flat) @@ -151,8 +175,24 @@ def compare_log_fn(self): class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): _abstract = False - def _validate(self): - super()._validate() - if self.format == CheckpointFormat.external: - # TODO: Support optimizer? - assert not self.optimizer_state + +class Converter(abc.ABC): + # TODO: Rename? (Checkpointer? Saver?) + + def __init__(self, model: "FastLLMModel"): + self._model = model + + # TODO: save_metadata? + + @classmethod + @abc.abstractmethod + def load_metadata(cls, config: CheckpointLoadMetadataConfig): + pass + + @abc.abstractmethod + def save(self, config: CheckpointSaveConfig, metadata: dict): + pass + + @abc.abstractmethod + def load(self, config: CheckpointLoadConfig, metadata: dict): + pass diff --git a/fast_llm/engine/checkpoint/distributed.py b/fast_llm/engine/checkpoint/distributed.py new file mode 100644 index 000000000..588fffe51 --- /dev/null +++ b/fast_llm/engine/checkpoint/distributed.py @@ -0,0 +1,93 @@ +import logging + +import safetensors.torch +import torch +import yaml + +from fast_llm.engine.checkpoint.config import ( + CheckpointLoadConfig, + CheckpointLoadMetadataConfig, + CheckpointSaveConfig, + Converter, + ModelConfigType, + export_safetensors_metadata, +) +from fast_llm.engine.checkpoint.safe_load import SafeLoad +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +class DistributedConverter(Converter): + + @classmethod + def load_metadata(cls, config: CheckpointLoadMetadataConfig): + return yaml.safe_load((config.path / "metadata.yaml").open("r")) + + def save(self, config: CheckpointSaveConfig, metadata: dict): + if self._model.distributed_config.rank == 0: + yaml.safe_dump(metadata, (config.path 
/ "metadata.yaml").open("w")) + num_shards = len(self._model.state_shard_names) if config.optimizer_state else 1 + safetensors.torch.save_file( + tensors={"state_shard": self._model.state_shard[:num_shards]}, + filename=config.path / f"rank_{self._model.distributed_config.rank}.safetensors", + metadata=export_safetensors_metadata(metadata), + ) + + def load(self, config: CheckpointLoadConfig, metadata: dict): + # TODO: More safety checks + loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) + loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata) + num_shards = self._model.num_state_shards if config.optimizer_state else 1 + Assert.eq(metadata["state_shard_names"][:num_shards], list(self._model.state_shard_names[:num_shards])) + + if ( + loaded_config.to_serialized(verbose=None) == self._model.fast_llm_config.to_serialized(verbose=None) + and config.optimizer_state + ): + logger.info("Checkpoint format matches, using fast load") + # TODO: Add version without optimizer state? + with safetensors.safe_open( + config.path / f"rank_{self._model.distributed_config.rank}.safetensors", + framework="pt", + device=str(self._model.distributed.device), + ) as f: + # TODO: Does this copy twice? + self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) + else: + logger.info("Checkpoint format doesn't match, using safe load") + self._model.base_model_config.compare_architecture(loaded_config.base_model, config.compare_log_fn) + with SafeLoad(self._model, num_shards=num_shards) as context: + for rank in range(loaded_config.distributed.world_size): + loaded_model = self._model.__class__( + loaded_config.to_copy({("distributed", "rank"): rank}), + optimizer_state_names=self._model.state_shard_names[1:num_shards], + verbose=False, + ) + path = config.path / f"rank_{rank}.safetensors" + logger.info(f"Loading from {path}") + # TODO: skip shards without overlap. 
+ with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f: + # TODO: Use self_shard + loaded_shard = f.get_slice("state_shard")[:num_shards] + loaded_model.state_shard_meta.validate(loaded_shard) + + # TODO: Improve num shard selection. + self_shard_split = self._model.state_shard[: loaded_shard.size(0)].split( + self._model.stage_shard_sizes, 1 + ) + loaded_shard_split = loaded_shard.split(loaded_model.stage_shard_sizes, 1) + + counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device) + for loaded_shard_index, loaded_stage in enumerate(loaded_model.stages_on_device.values()): + loaded_shards = ( + loaded_shard_split[loaded_shard_index].to(self._model.distributed.device).unbind(0) + ) + for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()): + self_stage._copy_shard_overlaps( # noqa + loaded_stage, + self_shard_split[self_shard_index].unbind(0), + loaded_shards, + counter, + ) + context.mark_as_loaded(counter.item()) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index d808e8922..88ffedfbb 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -9,7 +9,14 @@ import torch from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.engine.checkpoint.config import ( + CHECKPOINT_VERSION, + CheckpointLoadConfig, + CheckpointLoadMetadataConfig, + CheckpointSaveConfig, +) from fast_llm.engine.checkpoint.state_dict import StateDictConverter +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert @@ -139,13 +146,12 @@ def import_weight( class ExternalStateDictConverter(StateDictConverter): - base_file_name = "model" _base_model_cls: type[BaseModelConfig] _config_converters: list[ParamConverter] - def __init__(self, config: BaseModelArchitectureConfig): - 
self.config = config - Assert.custom(isinstance, config, self._base_model_cls.architecture_cls) + def __init__(self, model: "FastLLMModel"): + super().__init__(model) + Assert.custom(isinstance, self._model.base_model_config, self._base_model_cls.architecture_cls) weight_converters = self._create_weight_converters() self._export_converters = { weight_converter.fast_llm_name[0]: weight_converter @@ -166,17 +172,7 @@ def _create_weight_converters(self) -> list[WeightConverter]: pass @classmethod - @abc.abstractmethod - def load_config(cls, directory: pathlib.Path | str) -> dict[str, typing.Any]: - pass - - @classmethod - @abc.abstractmethod - def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): - pass - - @classmethod - def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]: + def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]: exported_config = {} for converter in cls._get_config_converters(): value = converter.export_param( @@ -190,7 +186,7 @@ def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing. 
return exported_config # Noqa @classmethod - def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa + def _import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa kwargs = {} for converter in cls._get_config_converters(): value = converter.import_param( @@ -204,11 +200,7 @@ def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = config_class = cls._base_model_cls.architecture_cls if architecture_only else cls._base_model_cls return config_class.from_dict({}, kwargs) - @classmethod - def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): - return cls(cls.import_config(config, architecture_only=architecture_only)) - - def convert_state_dict( + def _convert_state_dict( self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: out_state_dict = {} @@ -262,19 +254,56 @@ class AutoStateDictConverter(ExternalStateDictConverter, abc.ABC): converter_map: dict[str, type[ExternalStateDictConverter]] @classmethod - def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): - return cls.converter_map[config["model_type"]].import_config(config, architecture_only) + def get_converter_class(cls, format: str): + if format in cls.converter_map: + return cls.converter_map[format] + elif format == "auto": + return cls + else: + raise NotImplementedError(format) + + # TODO: load_metadata??? @classmethod - def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): - return cls.converter_map[config["model_type"]].from_config(config, architecture_only) + def _import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): + # TODO: ??? 
+ return cls.converter_map[config["model_type"]]._import_config(config, architecture_only) class HuggingfaceStateDictConverter(ExternalStateDictConverter, abc.ABC): model_type: str | None = None + base_file_name = "model" + + @classmethod + def load_metadata(cls, config: CheckpointLoadMetadataConfig): + imported_model_config = cls._import_config(cls._load_config(config.path), True) + return { + # TODO: Avoid `to_serialized`? + "fast_llm_config": {"base_model": imported_model_config.to_serialized()}, + # TODO: Handle "auto"? + "checkpoint_type": config.format, + "checkpoint_version": CHECKPOINT_VERSION, + } + + def save(self, config: CheckpointSaveConfig, metadata: dict): + huggingface_config = self._export_config(self._model.base_model_config) + self._save_config(config.path, huggingface_config) + metadata = { + "fast_llm_metadata": metadata, + "model_config": huggingface_config, + "format": "pt", + } + super().save(config, metadata) + + def load(self, config: CheckpointLoadConfig, metadata: dict): + assert not config.optimizer_state + self._model.base_model_config.compare_architecture( + self._base_model_cls.from_dict(metadata["fast_llm_config"]["base_model"]), config.compare_log_fn + ) + super().load(config, metadata) @classmethod - def get_key(cls, parameter_name: str, shard_name: str) -> str: + def _get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name @@ -284,7 +313,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: return [ConstantExportParamConverter(None, "model_type", cls.model_type)] @classmethod - def load_config(cls, directory: pathlib.Path | str): + def _load_config(cls, directory: pathlib.Path | str): import transformers config = transformers.AutoConfig.from_pretrained(directory).to_dict() @@ -293,12 +322,12 @@ def load_config(cls, directory: pathlib.Path | str): return config @classmethod - def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): + 
def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): import transformers transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory) - def load_weights( + def _load_weights( self, directory: pathlib.Path | str, device, diff --git a/fast_llm/engine/checkpoint/safe_load.py b/fast_llm/engine/checkpoint/safe_load.py new file mode 100644 index 000000000..bef2c16b3 --- /dev/null +++ b/fast_llm/engine/checkpoint/safe_load.py @@ -0,0 +1,162 @@ +import logging +import math + +import torch +from torch.distributed import all_reduce + +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.functional.triton.pointwise import triton_fill +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +class SafeLoad: + """ + A context with multiple safety checks to ensure the model state is loaded correctly: + * Pre-filling the state with nans and verify that no such value remains at the end. + This ensures that all values are set at least once. + * Keep a counter for the number of tensor values set, and validate against the expected number. + This ensures that all values are set at most once when the nan check succeeds. + * Optionally keep track of all set parameters and shards, and ensure that each is set exactly once. + + In case of failure, it will attempt to find out as precisely as possible where the problem comes from. + """ + + def __init__(self, model: "FastLLMModel", *, num_shards: int): + self._model = model + self._distributed = self._model.distributed + self._num_shards = num_shards + self._self_shard = self._model.state_shard[: self._num_shards] + + def __enter__(self): + self._loaded = 0 + self._loaded_parameters = {} + # Track the number of loaded entries. + # Use nan to mark non-loaded entries. 
+ triton_fill(self._self_shard, math.nan) + # Reset and count shard pads + for shard in self._model.state_shard[: self._num_shards]: + shard_split = shard.split(self._model.stage_shard_sizes, 0) + for stage, stage_shard in zip(self._model.stages_on_device.values(), shard_split): + self._loaded += stage.reset_shard_pad(stage_shard) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + self._validate() + + def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None): + self._loaded += count + if parameter is not None: + parameter_name, shard_name = parameter + if shard_name not in self._loaded_parameters: + self._loaded_parameters[shard_name] = {} + Assert.not_incl(parameter_name, self._loaded_parameters[shard_name]) + self._loaded_parameters[shard_name][parameter_name] = count + + def _validate(self): + errors = [] + self._check_counter(errors) + self._check_missing(errors) + if self._loaded_parameters: + self._check_parameters(errors) + if errors: + for error in errors: + logger.error(error) + raise RuntimeError("Model loading validation failed. See logs for details.") + logger.info(f"{self._loaded:,} state entries loaded successfully") + + def _check_counter(self, errors: list[str]): + to_load = self._self_shard.numel() + if self._loaded != to_load: + # Ensure the right amount of weights is loaded. + errors.append(f"Loaded a total of {self._loaded:,}, state entries, expected {to_load:,}") + + def _check_missing(self, errors: list[str]): + # Ensure the loaded weights have a 1-1 mapping by looking for nans. + missing = self._self_shard.new_zeros([], dtype=torch.int64) + # Count nans in slices of 100M parameters to limit memory usage. + # TODO: Find better solution (triton kernel?) 
+ for shard_slice in self._self_shard.flatten().split(100000000): + missing += shard_slice.isnan().sum() + local_missing = missing.item() + if self._distributed.world_group is not None: + all_reduce(missing, group=self._distributed.world_group) + global_missing = missing.item() + if global_missing: + errors.append(f"{global_missing:,} state entries failed to load or corrupted (local={local_missing:,}).") + # Determine where the missing values are coming from. + global_total, local_total = 0, 0 + for shard_name, shard_ in zip(self._model.state_shard_names[: self._num_shards], self._self_shard): + shard_split = shard_.split(self._model.stage_shard_sizes, 0) + for stage, shard in zip(self._model.stages_on_device.values(), shard_split): + buffer = stage._reconstruct_from_shard(shard) + for i, parameter in enumerate(stage._split_buffer(buffer)): + missing_for_param = parameter.isnan().sum().item() + if missing_for_param > 0: + global_total += missing_for_param + local_values = stage._split_shard(shard)[i] + local_missing_for_param = local_values.isnan().sum().item() + local_total += local_missing_for_param + errors.append( + f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {stage.parameter_names[i]} in stage {stage.index}, shard {shard_name}" + f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" + ) + missing_for_pad = buffer[-stage._global_pad :].isnan().sum().item() + if missing_for_pad > 0: + global_total += missing_for_pad + local_missing_for_pad = ( + shard[-stage._shard_pad :].isnan().sum().item() if stage._shard_pad > 0 else 0 + ) + local_total += local_missing_for_pad + errors.append( + f"{missing_for_pad:,} values missing out of {stage._global_pad:,} for padding in stage {stage.index}, shard {shard_name}" + f" (locally {local_missing_for_pad:,} out of {stage._shard_pad:,})" + ) + if global_total != global_missing: + errors.append( + f"Incorrect global breakdown of missing state entries (expected 
{global_missing:,}, got {global_total:,})" + ) + if local_total != local_missing: + errors.append( + f"Incorrect local breakdown of missing state entries (expected {local_missing:,}, got {local_total:,})" + ) + + def _check_parameters(self, errors: list[str]): + loaded_shard_names = set(self._loaded_parameters) + shard_names = set(self._model.state_shard_names[: self._num_shards]) + if loaded_shard_names != shard_names: + errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}") + for shard_name in shard_names & loaded_shard_names: + counter_per_parameter = { + parameter_name: self._loaded_parameters[shard_name].pop(parameter_name, None) + for parameter_name in self._model.parameter_names + } + for parameter_name, count in self._loaded_parameters[shard_name].items(): + errors.append(f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})') + for parameter_name, counter in counter_per_parameter.items(): + if self._model.is_parameter_on_device(parameter_name): + if counter is None: + errors.append(f'Missing parameter "{parameter_name}" for shard "{shard_name}"') + elif counter is not None and counter > 0: + errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"') + if self._distributed.world_group is not None: + counter_tensor = torch.tensor( + [counter or 0 for counter in counter_per_parameter.values()], dtype=torch.int64 + ).to(self._distributed.device) + all_reduce(counter_tensor, group=self._distributed.world_group) + counter_per_parameter = { + parameter_name: counter + for parameter_name, counter in zip(counter_per_parameter, counter_tensor.tolist()) + } + for parameter_name, counter in counter_per_parameter.items(): + parameter_size = ( + self._model.get_parameter_stage(parameter_name) + .get_parameter_meta(parameter_name) + .global_shape.numel() + ) + if counter != parameter_size: + errors.append( + f'Global counter mismatch for parameter "{parameter_name}" and shard 
"{shard_name}": {counter} != {parameter_size}' + ) diff --git a/fast_llm/engine/checkpoint/state_dict.py b/fast_llm/engine/checkpoint/state_dict.py index 048d13027..874e0ac73 100644 --- a/fast_llm/engine/checkpoint/state_dict.py +++ b/fast_llm/engine/checkpoint/state_dict.py @@ -7,10 +7,17 @@ import safetensors import safetensors.torch import torch -import yaml from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.checkpoint.config import CheckpointSaveConfig +from fast_llm.engine.checkpoint.config import ( + CheckpointLoadConfig, + CheckpointLoadMetadataConfig, + CheckpointSaveConfig, + Converter, + export_safetensors_metadata, + import_safetensors_metadata, +) +from fast_llm.engine.checkpoint.safe_load import SafeLoad from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert @@ -18,40 +25,74 @@ logger = logging.getLogger(__name__) -def _import_safetensors_metadata(metadata): - return {key: yaml.safe_load(value) for key, value in metadata.items()} - - -def _export_safetensors_metadata(metadata): - """ - Safetensor only accepts string entries, so we convert to string explicitly. - We use yaml rather than json because json requires explicit quotation marks on strings, which breaks things. - (ex. "format": "pt" becomes '"pt"' which breaks huggingface models.) - We avoid using safe_dump for scalars because it adds junk ("\n...\n") at the end of the string - (decoding is unaffected.) 
- """ - return { - key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value) - for key, value in metadata.items() - } - - -class StateDictConverter(abc.ABC): +class StateDictConverter(Converter): base_file_name: typing.ClassVar[str] + def save(self, config: CheckpointSaveConfig, metadata: dict): + num_shards = len(self._model.state_shard_names) if config.optimizer_state else 1 + with StateDictSaver( + config, + distributed=self._model.distributed, + metadata=metadata, + base_file_name=self.base_file_name, + ) as context: + # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from + # `state_dict` that are ready for conversion, + # and return a dict containing the converted tensors(s). + # If converting a tensor requires another one that is not yet available (e.g. for concatenation), + # it will remain in `state_dict` until that tensor is available. + state_dict = {} + for parameter_name, shard_name, tensor in self._model.get_state_tensor_iterator( + self._model.state_shard_names[:num_shards], config.data_type + ): + if shard_name not in state_dict: + state_dict[shard_name] = {} + shard_state_dict = state_dict[shard_name] + assert parameter_name not in shard_state_dict + shard_state_dict[parameter_name] = tensor + for exported_name, exported_tensor in self._convert_state_dict(shard_state_dict, True).items(): + context.add_tensor(self._get_key(exported_name, shard_name), exported_tensor) + + for shard_name, shard_state_dict in state_dict.items(): + assert not shard_state_dict, (shard_name, list(state_dict)) + + def load(self, config: CheckpointLoadConfig, metadata: dict): + num_shards = len(self._model.state_shard_names) if config.optimizer_state else 1 + with SafeLoad(self._model, num_shards=num_shards) as context: + # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from + # `state_dict` that are ready for conversion, + # and return a dict containing the converted tensors(s). 
+ # If converting a tensor requires another one that is not yet available (e.g. for concatenation), + # it will remain in `state_dict` until that tensor is available. + state_dict = {} + for parameter_name, shard_name, tensor in self._load_weights( + config.path, self._model.distributed.device, self._model.state_shard_names[:num_shards] + ): + if shard_name not in state_dict: + state_dict[shard_name] = {} + shard_state_dict = state_dict[shard_name] + assert parameter_name not in shard_state_dict + shard_state_dict[parameter_name] = tensor + for parameter_name, fast_llm_tensor in self._convert_state_dict(shard_state_dict, False).items(): + loaded = self._model.import_state_tensor(parameter_name, shard_name, fast_llm_tensor) + context.mark_as_loaded(loaded, (parameter_name, shard_name)) + + for shard_name, shard_state_dict in state_dict.items(): + assert not shard_state_dict, (shard_name, list(state_dict)) + @classmethod @abc.abstractmethod - def get_key(cls, parameter_name: str, shard_name: str) -> str: + def _get_key(cls, parameter_name: str, shard_name: str) -> str: pass @abc.abstractmethod - def convert_state_dict( + def _convert_state_dict( self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: pass @abc.abstractmethod - def load_weights( + def _load_weights( self, directory: pathlib.Path | str, device, @@ -64,17 +105,21 @@ class TrivialConverter(StateDictConverter): base_file_name = "state_dict" @classmethod - def get_key(cls, parameter_name: str, shard_name: str) -> str: + def load_metadata(cls, config: CheckpointLoadMetadataConfig): + return json.load((config.path / f"state_dict.safetensors.index.json").open("r"))["metadata"] + + @classmethod + def _get_key(cls, parameter_name: str, shard_name: str) -> str: return f"{parameter_name}/{shard_name}" - def convert_state_dict( + def _convert_state_dict( self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, 
torch.Tensor | SafeTensorSlice]: out_state_dict = state_dict.copy() state_dict.clear() return out_state_dict - def load_weights( + def _load_weights( self, directory: pathlib.Path | str, device, @@ -90,7 +135,7 @@ def load_weights( framework="pt", device=str(device), ) as f: - metadata = _import_safetensors_metadata(f.metadata()) + metadata = import_safetensors_metadata(f.metadata()) Assert.eq(metadata["state_shard_names"][: len(shard_names)], list(shard_names)) for key in f.keys(): parameter_name, shard_name = key.split("/", 1) @@ -172,7 +217,7 @@ def _save_next_file(self): safetensors.torch.save_file( tensors=self.tensors, filename=self._config.path / file_name, - metadata=_export_safetensors_metadata(self._metadata), + metadata=export_safetensors_metadata(self._metadata), ) for name_ in self.tensors: assert name_ not in self.index diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/huggingface/config.py index 4d92c9222..97823089d 100644 --- a/fast_llm/engine/huggingface/config.py +++ b/fast_llm/engine/huggingface/config.py @@ -68,7 +68,7 @@ def _get_config_dict( path=pathlib.Path(pretrained_model_name_or_path), format=CheckpointFormat.state_dict, ) - metadata = cls.model_config_class.load_pretrained_metadata(pretrained) + metadata = cls.model_config_class.load_metadata(pretrained) updates = {} torch_dtype = kwargs.pop("torch_dtype", None) if torch_dtype is not None: diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index f9235f069..88e79df25 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -1,16 +1,15 @@ import enum -import json import logging import typing from fast_llm.config import Config, Field, FieldHint, NoAutoValidate, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import ( - CHECKPOINT_VERSION, KNOWN_CHECKPOINT_VERSIONS, CheckpointFormat, 
CheckpointLoadConfig, CheckpointLoadMetadataConfig, + Converter, ) from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert @@ -195,6 +194,23 @@ class FastLLMModelConfig(Config): default_factory=DistributedConfig, desc="Distributed configuration.", hint=FieldHint.core ) + @classmethod + def get_supported_checkpoint_formats(cls): + return CheckpointFormat.distributed, CheckpointFormat.state_dict + + @classmethod + def get_converter_class(cls, format: str) -> type["Converter"]: + if format == CheckpointFormat.distributed: + from fast_llm.engine.checkpoint.distributed import DistributedConverter + + return DistributedConverter + elif format == CheckpointFormat.state_dict: + from fast_llm.engine.checkpoint.state_dict import TrivialConverter + + return TrivialConverter + else: + raise NotImplementedError(format) + @classmethod def get_model_class(cls) -> type["FastLLMModel"]: raise NotImplementedError @@ -216,7 +232,7 @@ def from_pretrained( ): # TODO: Add *updates? 
assert pretrained.path is not None - metadata = cls.load_pretrained_metadata(pretrained) + metadata = cls.load_metadata(pretrained) return cls.from_metadata(pretrained, metadata, default) @classmethod @@ -301,24 +317,9 @@ def _from_metadata_v0( return config @classmethod - def load_pretrained_metadata(cls, pretrained: CheckpointLoadMetadataConfig): - import yaml - - base_model_config_cls = cls.get_base_model_config_cls() - if pretrained.format == CheckpointFormat.distributed: - return yaml.safe_load((pretrained.path / "metadata.yaml").open("r")) - elif pretrained.format == CheckpointFormat.state_dict: - return json.load((pretrained.path / f"state_dict.safetensors.index.json").open("r"))["metadata"] - elif pretrained.format == CheckpointFormat.external: - converter_class = base_model_config_cls.get_converter_class(pretrained.model_type) - imported_model_config = converter_class.import_config(converter_class.load_config(pretrained.path), True) - return { - "fast_llm_config": {"base_model": imported_model_config.to_serialized()}, - "checkpoint_type": CheckpointFormat.external.value, - "checkpoint_version": CHECKPOINT_VERSION, - } - else: - raise NotImplementedError(pretrained.format) + def load_metadata(cls, config: CheckpointLoadMetadataConfig): + converter_class = cls.get_converter_class(config.format) + return converter_class.load_metadata(config) @config_class() diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index b8ffde0c5..259671f59 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,21 +1,8 @@ import logging -import math import typing -import safetensors.torch -import torch -import yaml - -from fast_llm.core.distributed import all_reduce, broadcast -from fast_llm.engine.base_model.base_model import BaseModel -from fast_llm.engine.checkpoint.config import ( - CHECKPOINT_VERSION, - CheckpointFormat, - CheckpointLoadConfig, - 
CheckpointSaveConfig, - ModelConfigType, -) -from fast_llm.engine.checkpoint.state_dict import StateDictConverter, StateDictSaver, TrivialConverter +from fast_llm.core.distributed import broadcast +from fast_llm.engine.checkpoint.config import CHECKPOINT_VERSION, CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.multi_stage import MultiStageModel @@ -25,234 +12,36 @@ logger = logging.getLogger(__name__) -def _export_safetensors_metadata(metadata): - """ - Safetensor only accepts string entries, so we convert to string explicitly. - We use yaml rather than json because json requires explicit quotation marks on strings, which breaks things. - (ex. "format": "pt" becomes '"pt"' which breaks huggingface models.) - We avoid using safe_dump for scalars because it adds junk ("\n...\n") at the end of the string - (decoding is unaffected.) - """ - return { - key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value) - for key, value in metadata.items() - } - - class FastLLMModel(MultiStageModel): - _is_setup: bool = False _is_loaded: bool = False - _distributed: Distributed - config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig - base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel - - def __init__( - self, - config: FastLLMModelConfig, - *, - optimizer_state_names: tuple[str, ...] = (), - verbose: bool = True, - # A filter to create only a subset of the stages. Used for model conversion. 
- stage_filter: set | None = None, - ): - self._base_model_config = config.base_model - self._fast_llm_config = config - super().__init__( - base_model=self.base_model_class(config.base_model, config.distributed), - multi_stage_config=config.multi_stage, - distributed_config=config.distributed, - optimizer_state_names=optimizer_state_names, - verbose=verbose, - stage_filter=stage_filter, - ) - - @property - def fast_llm_config(self): - return self._fast_llm_config - - @property - def distributed(self): - return self._distributed def save_checkpoint( self, config: CheckpointSaveConfig, - metadata: dict | None = None, + extra_metadata: dict | None = None, ): # TODO: Handle barriers, ok file, mkdir, etc. here - - num_shards = len(self._state_shard_names) if config.optimizer_state else 1 - metadata = { - "checkpoint_type": CheckpointFormat.distributed.value, + num_shards = self.num_state_shards if config.optimizer_state else 1 + fast_llm_metadata = { + "checkpoint_type": config.format, "checkpoint_version": str(CHECKPOINT_VERSION), "fast_llm_config": self._fast_llm_config.to_serialized(), "state_shard_names": list(self._state_shard_names[:num_shards]), - "metadata": {} if metadata is None else metadata, + "metadata": {} if extra_metadata is None else extra_metadata, } - - # TODO: Simplify branching. - if config.format == CheckpointFormat.external: - # TODO: Support optimizer? 
- assert not config.optimizer_state - converter_class = self._base_model_config.get_converter_class(config.model_type) - exported_config = converter_class.export_config(self._base_model_config) - converter_class.save_config(config.path, exported_config) - self._save_state_dict( - config, - converter_class(self._base_model_config), - { - "fast_llm_metadata": metadata, - "model_config": exported_config, - "format": "pt", - }, - ) - elif config.format == CheckpointFormat.state_dict: - self._save_state_dict(config, TrivialConverter(), metadata) - elif config.format == CheckpointFormat.distributed: - if self._distributed_config.rank == 0: - yaml.safe_dump(metadata, (config.path / "metadata.yaml").open("w")) - safetensors.torch.save_file( - tensors={"state_shard": self._state_shard[:num_shards]}, - filename=config.path / f"rank_{self._distributed_config.rank}.safetensors", - metadata=_export_safetensors_metadata(metadata), - ) - else: - raise NotImplementedError(config.format) - - def _save_state_dict(self, config: CheckpointSaveConfig, converter: StateDictConverter, metadata: dict): - with StateDictSaver( - config, - distributed=self._distributed, - metadata=metadata, - base_file_name=converter.base_file_name, - ) as context: - # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from - # `fast_llm_state_dict` that are ready for conversion, - # and return a dict containing the converted tensors(s). - # If converting a tensor requires another one that is not yet available (e.g. for concatenation), - # it will remain in `fast_llm_state_dict` until that tensor is available. 
- fast_llm_state_dict = {} - for parameter_name, shard_name, tensor in self.get_state_tensor_iterator( - self._state_shard_names if config.optimizer_state else self._state_shard_names[:1], config.data_type - ): - if shard_name not in fast_llm_state_dict: - fast_llm_state_dict[shard_name] = {} - shard_state_dict = fast_llm_state_dict[shard_name] - assert parameter_name not in shard_state_dict - shard_state_dict[parameter_name] = tensor - for exported_name, exported_tensor in converter.convert_state_dict(shard_state_dict, True).items(): - context.add_tensor(converter.get_key(exported_name, shard_name), exported_tensor) - for shard_name, shard_state_dict in fast_llm_state_dict.items(): - assert not shard_state_dict, (shard_name, list(fast_llm_state_dict)) + converter = self._fast_llm_config.get_converter_class(config.format)(self) + converter.save(config, fast_llm_metadata) def load_checkpoint(self, config: CheckpointLoadConfig): # TODO: Simplify branching. # TODO: Test with more distributed configs. # TODO: Safety checks # TODO: Handle barriers, ok file, etc. here - metadata = self.config_class.load_pretrained_metadata(config) - if config.format == CheckpointFormat.distributed: - # TODO: Check if same format. - self._load_distributed_checkpoint(config, metadata) - elif config.format == CheckpointFormat.state_dict: - self._load_state_dict(config, TrivialConverter()) - elif config.format == CheckpointFormat.external: - # TODO: Support optimizer. 
- assert not config.optimizer_state - converter_class = self.base_model.architecture_cls().get_converter_class(config.model_type) - converter = converter_class.from_config(converter_class.load_config(config.path)) - self._base_model_config.compare_architecture(converter.config, config.compare_log_fn) - self._load_state_dict(config, converter) - else: - raise NotImplementedError(config.format) - return metadata.get("metadata") - - def _load_state_dict(self, config: CheckpointLoadConfig, converter: StateDictConverter): - num_shards = len(self._state_shard_names) if config.optimizer_state else 1 - with self._SafeLoadContext(self, num_shards=num_shards) as context: - state_dict = {} - for parameter_name, shard_name, tensor in converter.load_weights( - config.path, self._distributed.device, self._state_shard_names[:num_shards] - ): - if shard_name not in state_dict: - state_dict[shard_name] = {} - shard_state_dict = state_dict[shard_name] - assert parameter_name not in shard_state_dict - shard_state_dict[parameter_name] = tensor - for parameter_name, fast_llm_tensor in converter.convert_state_dict(shard_state_dict, False).items(): - stage_index = self._parameter_stages[parameter_name] - if stage_index not in self._stage_shard_indices: - # Tensor is not on this device. 
- return 0 - stage_shard = self._state_shard[self._state_shard_names.index(shard_name)].split( - self._stage_shard_sizes, 0 - )[self._stage_shard_indices[stage_index]] - loaded = self._stages[stage_index]._import_state_tensor( - stage_shard, parameter_name, fast_llm_tensor - ) # noqa - context.mark_as_loaded(loaded, (parameter_name, shard_name)) - - for shard_name, shard_state_dict in state_dict.items(): - assert not shard_state_dict, (shard_name, list(state_dict)) - - self._finalize_load(reset_optimizer=not config.optimizer_state) - - def _load_distributed_checkpoint(self, config: CheckpointLoadConfig, metadata: dict): - # TODO: More safety checks - loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) - loaded_config = self.config_class.from_metadata(loaded_config_dict, metadata) - num_shards = self._num_state_shards if config.optimizer_state else 1 - Assert.eq(metadata["state_shard_names"][:num_shards], list(self._state_shard_names[:num_shards])) - - if ( - loaded_config.to_serialized(verbose=None) == self._fast_llm_config.to_serialized(verbose=None) - and config.optimizer_state - ): - logger.info("Checkpoint format matches, using fast load") - # TODO: Add version without optimizer state? - with safetensors.safe_open( - config.path / f"rank_{self._distributed_config.rank}.safetensors", - framework="pt", - device=str(self._distributed.device), - ) as f: - # TODO: Does this copy twice? 
- self._state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) - else: - logger.info("Checkpoint format doesn't match, using safe load") - self._base_model_config.compare_architecture(loaded_config.base_model, config.compare_log_fn) - with self._SafeLoadContext(self, num_shards=num_shards) as context: - for rank in range(loaded_config.distributed.world_size): - loaded_model = self.__class__( - loaded_config.to_copy({("distributed", "rank"): rank}), - optimizer_state_names=self._state_shard_names[1:num_shards], - verbose=False, - ) - path = config.path / f"rank_{rank}.safetensors" - logger.info(f"Loading from {path}") - # TODO: skip shards without overlap. - with safetensors.safe_open(path, framework="pt", device=str(self._distributed.device)) as f: - # TODO: Use self_shard - loaded_shard = f.get_slice("state_shard")[:num_shards] - loaded_model._state_shard_meta.validate(loaded_shard) - - # TODO: Improve num shard selection. - self_shard_split = self._state_shard[: loaded_shard.size(0)].split(self._stage_shard_sizes, 1) - loaded_shard_split = loaded_shard.split(loaded_model._stage_shard_sizes, 1) - - counter = torch.zeros(1, dtype=torch.int64, device=self._distributed.device) - for loaded_shard_index, loaded_stage in enumerate(loaded_model._stages_on_device.values()): - loaded_shards = ( - loaded_shard_split[loaded_shard_index].to(self._distributed.device).unbind(0) - ) - for self_shard_index, self_stage in enumerate(self._stages_on_device.values()): - self_stage._copy_shard_overlaps( # noqa - loaded_stage, - self_shard_split[self_shard_index].unbind(0), - loaded_shards, - counter, - ) - context.mark_as_loaded(counter.item()) + fast_llm_metadata = self.config_class.load_metadata(config) + converter = self._fast_llm_config.get_converter_class(config.format)(self) + converter.load(config, fast_llm_metadata) self._finalize_load(reset_optimizer=not config.optimizer_state) + return fast_llm_metadata.get("metadata") @classmethod def from_pretrained( @@ 
-267,7 +56,7 @@ def from_pretrained( use_cpu: bool = False, stage_filter: set | None = None, ): - metadata = cls.config_class.load_pretrained_metadata(pretrained_config) + metadata = cls.config_class.load_metadata(pretrained_config) config = cls.config_class.from_metadata(pretrained_config, metadata, default_config, config_updates) if mode.support_training: if "state_shard_names" in metadata: @@ -312,147 +101,3 @@ def _finalize_load(self, reset_optimizer: bool = True): if self._mode.support_forward: self.invalidate_buffers() self._is_loaded = True - - class _SafeLoadContext: - # TODO: Improve - def __init__(self, model: "FastLLMModel", *, num_shards: int): - self._model = model - self._num_shards = num_shards - self._self_shard = self._model._state_shard[: self._num_shards] - - def __enter__(self): - self._loaded = 0 - self._loaded_parameters = {} - # Track the number of loaded entries. - # Use nan to mark non-loaded entries. - triton_fill(self._self_shard, math.nan) - # Reset and count shard pads - for shard in self._model._state_shard[: self._num_shards]: - shard_split = shard.split(self._model._stage_shard_sizes, 0) - for stage, stage_shard in zip(self._model._stages_on_device.values(), shard_split): - self._loaded += stage.reset_shard_pad(stage_shard) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_type: - self._validate() - - def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None): - self._loaded += count - if parameter is not None: - parameter_name, shard_name = parameter - if shard_name not in self._loaded_parameters: - self._loaded_parameters[shard_name] = {} - Assert.not_incl(parameter_name, self._loaded_parameters[shard_name]) - self._loaded_parameters[shard_name][parameter_name] = count - - def _validate(self): - errors = [] - self._check_counter(errors) - self._check_missing(errors) - if self._loaded_parameters: - self._check_parameters(errors) - if errors: - for error in errors: - logger.error(error) 
- raise RuntimeError("Model loading validation failed. See logs for details.") - logger.info(f"{self._loaded:,} state entries loaded successfully") - - def _check_counter(self, errors: list[str]): - to_load = self._self_shard.numel() - if self._loaded != to_load: - # Ensure the right amount of weights is loaded. - errors.append(f"Loaded a total of {self._loaded:,}, state entries, expected {to_load:,}") - - def _check_missing(self, errors: list[str]): - # Ensure the loaded weights have a 1-1 mapping by looking for nans. - missing = self._self_shard.new_zeros([], dtype=torch.int64) - # Count nans in slices of 100M parameters to limit memory usage. - # TODO: Find better solution (triton kernel?) - for shard_slice in self._self_shard.flatten().split(100000000): - missing += shard_slice.isnan().sum() - local_missing = missing.item() - if self._model._distributed.world_group is not None: - all_reduce(missing, group=self._model._distributed.world_group) - global_missing = missing.item() - if global_missing: - errors.append( - f"{global_missing:,} state entries failed to load or corrupted (local={local_missing:,})." - ) - # Determine where the missing values are coming from. 
- global_total, local_total = 0, 0 - for shard_name, shard_ in zip(self._model._state_shard_names[: self._num_shards], self._self_shard): - shard_split = shard_.split(self._model._stage_shard_sizes, 0) - for stage, shard in zip(self._model._stages_on_device.values(), shard_split): - buffer = stage._reconstruct_from_shard(shard) - for i, parameter in enumerate(stage._split_buffer(buffer)): - missing_for_param = parameter.isnan().sum().item() - if missing_for_param > 0: - global_total += missing_for_param - local_values = stage._split_shard(shard)[i] - local_missing_for_param = local_values.isnan().sum().item() - local_total += local_missing_for_param - errors.append( - f"{missing_for_param:,} values missing out of {parameter.numel():,} for parameter {stage.parameter_names[i]} in stage {stage.index}, shard {shard_name}" - f" (locally {local_missing_for_param:,} out of {local_values.numel():,})" - ) - missing_for_pad = buffer[-stage._global_pad :].isnan().sum().item() - if missing_for_pad > 0: - global_total += missing_for_pad - local_missing_for_pad = ( - shard[-stage._shard_pad :].isnan().sum().item() if stage._shard_pad > 0 else 0 - ) - local_total += local_missing_for_pad - errors.append( - f"{missing_for_pad:,} values missing out of {stage._global_pad:,} for padding in stage {stage.index}, shard {shard_name}" - f" (locally {local_missing_for_pad:,} out of {stage._shard_pad:,})" - ) - if global_total != global_missing: - errors.append( - f"Incorrect global breakdown of missing state entries (expected {global_missing:,}, got {global_total:,})" - ) - if local_total != local_missing: - errors.append( - f"Incorrect local breakdown of missing state entries (expected {local_missing:,}, got {local_total:,})" - ) - - def _check_parameters(self, errors: list[str]): - loaded_shard_names = set(self._loaded_parameters) - shard_names = set(self._model._state_shard_names[: self._num_shards]) - if loaded_shard_names != shard_names: - errors.append(f"Incorrect loaded shards: 
{loaded_shard_names}!={shard_names}") - for shard_name in shard_names & loaded_shard_names: - counter_per_parameter = { - parameter_name: self._loaded_parameters[shard_name].pop(parameter_name, None) - for parameter_name in self._model._parameter_stages - } - for parameter_name, count in self._loaded_parameters[shard_name].items(): - errors.append( - f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})' - ) - for parameter_name, counter in counter_per_parameter.items(): - if self._model._parameter_stages[parameter_name] in self._model._stages_on_device: - if counter is None: - errors.append(f'Missing parameter "{parameter_name}" for shard "{shard_name}"') - elif counter is not None and counter > 0: - errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"') - distributed = self._model._distributed - if distributed.world_group is not None: - counter_tensor = torch.tensor( - [counter or 0 for counter in counter_per_parameter.values()], dtype=torch.int64 - ).to(distributed.device) - all_reduce(counter_tensor, group=distributed.world_group) - counter_per_parameter = { - parameter_name: counter - for parameter_name, counter in zip(counter_per_parameter, counter_tensor.tolist()) - } - for parameter_name, counter in counter_per_parameter.items(): - parameter_size = ( - self._model._stages[self._model._parameter_stages[parameter_name]] - .get_parameter_meta(parameter_name) - .global_shape.numel() - ) - if counter != parameter_size: - errors.append( - f'Global counter mismatch for parameter "{parameter_name}" and shard "{shard_name}": {counter} != {parameter_size}' - ) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index a1a50b960..42ad40880 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -1,5 +1,6 @@ import dataclasses import logging +import typing import warnings import numpy as np @@ -10,12 
+11,12 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import MultiStageConfig, StageMode +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.stage import Stage from fast_llm.engine.optimizer.config import ParamGroup -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta from fast_llm.utils import Assert, get_unique logger = logging.getLogger(__name__) @@ -29,22 +30,24 @@ class MultiStageModel: _optimizer_shard: torch.Tensor _distributed: Distributed _mode: StageMode + config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig + base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel def __init__( self, + config: FastLLMModelConfig, *, - base_model: BaseModel, - multi_stage_config: MultiStageConfig, - distributed_config: DistributedConfig, - optimizer_state_names: tuple[str, ...], + optimizer_state_names: tuple[str, ...] = (), verbose: bool = True, # A filter to create only a subset of the stages. Used for model conversion. 
stage_filter: set | None = None, ): super().__init__() - self._multi_stage_config = multi_stage_config - self._distributed_config = distributed_config - self._base_model = base_model + self._fast_llm_config = config + self._base_model_config = self._fast_llm_config.base_model + self._multi_stage_config = self._fast_llm_config.multi_stage + self._distributed_config = self._fast_llm_config.distributed + self._base_model = self.base_model_class(self._base_model_config, self._distributed_config) self._training = None self._verbose = verbose self._stage_filter = stage_filter @@ -123,8 +126,6 @@ def __init__( stage_shard_dtype = get_unique([stage.weight_shard_meta.dtype for stage in self._stages_on_device.values()]) self._state_shard_names = ("weights",) + optimizer_state_names - self._num_state_shards = len(self._state_shard_names) - self._num_shards = self._num_state_shards + 1 shard_dim = TensorDim("flat_shard", sum(self._stage_shard_sizes)) self._weight_shard_meta = TensorMeta.from_dims( @@ -133,12 +134,12 @@ def __init__( dtype=stage_shard_dtype, ) self._state_shard_meta = TensorMeta.from_dims( - (TensorDim("state_shards", self._num_state_shards), shard_dim), + (TensorDim("state_shards", self.num_state_shards), shard_dim), tensor_name=f"multi_stage_state_shard", dtype=stage_shard_dtype, ) self._full_shards_meta = TensorMeta.from_dims( - (TensorDim("shards", self._num_shards), shard_dim), + (TensorDim("shards", self.num_shards), shard_dim), tensor_name=f"multi_stage_state_shard", dtype=stage_shard_dtype, ) @@ -227,7 +228,7 @@ def setup(self, distributed: Distributed, mode: StageMode = StageMode.training): allocated += (mem := num_shards * self._full_shards_meta.memory_usage // self._full_shards_meta.size(0)) if self._verbose: log_model_parallel_main_rank( - f">>> Allocating {self._num_shards} x {len(self._stage_shard_sizes)}" + f">>> Allocating {self.num_shards} x {len(self._stage_shard_sizes)}" f" shards ({mem / 2 ** 20:,.2f} MiB)" ) shards = 
torch.empty_like(self._full_shards_meta[:num_shards], device=self._distributed.device) @@ -311,6 +312,10 @@ def get_param_groups(self, param_group_cls: type[ParamGroup] = ParamGroup): return param_groups, grads_for_norm + @property + def state_shard_meta(self): + return self._state_shard_meta + @property def support_forward(self): assert self._is_setup @@ -334,10 +339,38 @@ def base_model(self): def stages(self): return self._stages + @property + def state_shard(self): + return self._state_shard + + @property + def num_shards(self): + return len(self._state_shard_names) + 1 + + @property + def num_state_shards(self): + return len(self._state_shard_names) + + @property + def stages_on_device(self): + return self._stages_on_device + + @property + def fast_llm_config(self): + return self._fast_llm_config + + @property + def base_model_config(self): + return self._base_model_config + @property def multi_stage_config(self): return self._multi_stage_config + @property + def distributed_config(self): + return self._distributed_config + @property def tied_parameters(self): return self._tied_parameters @@ -350,6 +383,28 @@ def weight_buffer_indices(self): def grad_buffer_indices(self): return self._grad_buffer_indices + @property + def state_shard_names(self): + return self._state_shard_names + + @property + def stage_shard_sizes(self): + return self._stage_shard_sizes + + @property + def parameter_names(self): + return list(self._parameter_stages) + + def get_parameter_stage(self, parameter_name: str): + return self._stages[self._parameter_stages[parameter_name]] + + def is_parameter_on_device(self, parameter_name: str): + return self._parameter_stages[parameter_name] in self._stages_on_device + + @property + def distributed(self): + return self._distributed + def invalidate_buffers(self): for stage in self._stages_on_device.values(): stage.invalidate_buffer() @@ -367,6 +422,20 @@ def get_state_tensor_iterator(self, shard_names: list[str], data_type: DataType for name, 
tensor in stage._export_shard(shard, data_type=data_type): # noqa yield name, shard_name, tensor + def import_state_tensor(self, parameter_name: str, shard_name: str, tensor: torch.Tensor | SafeTensorSlice): + """ + Given a global parameter tensor, set the associated slice of a local parameter shard. + Return the size of the local slice. + """ + if not self.is_parameter_on_device(parameter_name): + # Parameter is not on device, nothing to do. + return 0 + stage_index = self._stage_shard_indices[self._parameter_stages[parameter_name]] + stage_shard = self._state_shard[self._state_shard_names.index(shard_name)].split(self._stage_shard_sizes, 0)[ + stage_index + ] + return self.get_parameter_stage(parameter_name).import_state_tensor(parameter_name, stage_shard, tensor) + def _split_into_stages(self): # Create stages (greedy split, could do better). stage_splits = [0] diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 3ed5a2ee1..c0f899144 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -442,8 +442,11 @@ def _copy_shard_overlaps( shard[begin:end][overlap_mask] = loaded_shard[overlap_index_map_masked] counter += overlap_count - def _import_state_tensor(self, shard: torch.Tensor, parameter_name: str, tensor: torch.Tensor | SafeTensorSlice): - # See FastLLMModel._import_state_tensor. + def import_state_tensor(self, parameter_name: str, shard: torch.Tensor, tensor: torch.Tensor | SafeTensorSlice): + """ + Given a global parameter tensor, set the associated slice of a local parameter shard. + Return the size of the local slice. 
+ """ Assert.eq(shard.shape, (self._shard_size,)) parameter_index = self._parameter_index[parameter_name] tensor_shard = self._parameter_global_to_shard(tensor, parameter_index) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 59fbcddf2..e27862ee0 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -2,13 +2,26 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.config import DataConfig +from fast_llm.engine.checkpoint.config import Converter from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds if typing.TYPE_CHECKING: - from fast_llm.engine.checkpoint.external import ExternalStateDictConverter + pass + + +class HuggingfaceModelType: + """ + An enum for the huggingface models with conversion support. 
+ """ + + auto = "auto" + starcoder2 = "starcoder2" + llama = "llama" + mistral = "mistral" + mixtral = "mixtral" @config_class() @@ -27,12 +40,6 @@ def _from_dict( assert default.pop("transposed_mlp_weight") return super()._from_dict(default, strict, flat) - @classmethod - def get_converter_class(cls, model_type: str | None = None) -> type["ExternalStateDictConverter"]: - from fast_llm.models.gpt.conversion import AutoGPTConverter - - return AutoGPTConverter if model_type is None else AutoGPTConverter.converter_map[model_type] - @config_class() class GPTBaseModelConfig(LanguageModelBaseConfig, GPTArchitectureConfig): @@ -79,6 +86,21 @@ def get_huggingface_model_class(cls): return HuggingfaceGPTModelForCausalLM + @classmethod + def get_supported_checkpoint_formats(cls): + return super().get_supported_checkpoint_formats() + tuple( + name for name in HuggingfaceModelType.__dict__ if not name.startswith("_") + ) + + @classmethod + def get_converter_class(cls, format: str) -> type["Converter"]: + try: + return super().get_converter_class(format) + except NotImplementedError: + from fast_llm.models.gpt.conversion import AutoGPTConverter + + return AutoGPTConverter.get_converter_class(format) + @config_class() class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @@ -104,14 +126,3 @@ def get_trainer_class(cls): from fast_llm.models.gpt.trainer import GPTTrainer return GPTTrainer - - -class HuggingfaceModelType: - """ - An enum for the huggingface models with conversion support. 
- """ - - starcoder2 = "starcoder2" - llama = "llama" - mistral = "mistral" - mixtral = "mixtral" diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 31d66c772..6f48bdec0 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -21,6 +21,7 @@ from fast_llm.layers.common.config import NormalizationType from fast_llm.layers.transformer.config import RoutingType from fast_llm.models.gpt.config import GPTArchitectureConfig, GPTBaseModelConfig, HuggingfaceModelType +from fast_llm.models.gpt.model import GPTModel from fast_llm.tensor import SafeTensorSlice if typing.TYPE_CHECKING: @@ -90,7 +91,7 @@ def import_weight( class CommonHuggingfaceConverter(HuggingfaceStateDictConverter): - config: GPTArchitectureConfig + _model: GPTModel _base_model_cls = GPTBaseModelConfig """ Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral) @@ -125,12 +126,12 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] - num_layers = self.config.transformer.num_layers - norm_bias: bool = self.config.transformer.normalization.type == NormalizationType.layer_norm - linear_bias: bool = self.config.transformer.add_linear_biases + num_layers = self._model.base_model_config.transformer.num_layers + norm_bias: bool = self._model.base_model_config.transformer.normalization.type == NormalizationType.layer_norm + linear_bias: bool = self._model.base_model_config.transformer.add_linear_biases # Embedding and output - if self.config.tie_word_embeddings: + if self._model.base_model_config.tie_word_embeddings: converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) converters.append(IgnoreWeightConverter((), "lm_head.weight")) else: @@ -188,7 +189,7 @@ def _get_weight_and_bias_converters( cls( tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), tuple(f"{prefix}.weight" 
for prefix in hf_prefix), - self.config, + self._model.base_model_config, ) ] if use_bias: @@ -196,7 +197,7 @@ def _get_weight_and_bias_converters( cls( tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), tuple(f"{prefix}.bias" for prefix in hf_prefix), - self.config, + self._model.base_model_config, ) ) return converters @@ -216,7 +217,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): - linear_bias: bool = self.config.transformer.add_linear_biases + linear_bias: bool = self._model.base_model_config.transformer.add_linear_biases return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias @@ -252,7 +253,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): - linear_bias: bool = self.config.transformer.add_linear_biases + linear_bias: bool = self._model.base_model_config.transformer.add_linear_biases return [ *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", @@ -286,7 +287,9 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): (f"{hf_prefix}.mlp.gate_proj.weight", f"{hf_prefix}.mlp.up_proj.weight"), ), MLPLayer2Converter( - f"{fast_llm_prefix}.mlp.layer_2.weight", f"{hf_prefix}.mlp.down_proj.weight", self.config + f"{fast_llm_prefix}.mlp.layer_2.weight", + f"{hf_prefix}.mlp.down_proj.weight", + self._model.base_model_config, ), ] @@ -305,7 +308,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): - num_experts = self.config.transformer.num_experts + num_experts = self._model.base_model_config.transformer.num_experts return [ WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.block_sparse_moe.gate.weight"), SplitWeightConverter( @@ -319,12 +322,13 @@ def _get_mlp_converters(self, fast_llm_prefix: str, 
hf_prefix: str): MLPLayer2Converter( f"{fast_llm_prefix}.mlp.layer_2.weight", tuple(f"{hf_prefix}.block_sparse_moe.experts.{i}.w2.weight" for i in range(num_experts)), - self.config, + self._model.base_model_config, ), ] class AutoGPTConverter(AutoStateDictConverter, HuggingfaceStateDictConverter, abc.ABC): + converter_map = { HuggingfaceModelType.starcoder2: Starcoder2HuggingfaceConverter, HuggingfaceModelType.llama: LlamaHuggingfaceConverter, diff --git a/fast_llm/tools/convert.py b/fast_llm/tools/convert.py index 6b8c7bcf6..0b39323ec 100644 --- a/fast_llm/tools/convert.py +++ b/fast_llm/tools/convert.py @@ -3,12 +3,12 @@ import json import logging import math -import pathlib import typing +import warnings from fast_llm.config import Field, config_class -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig +from fast_llm.engine.checkpoint.external import HuggingfaceStateDictConverter from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.functional.config import TritonConfig @@ -23,19 +23,49 @@ @config_class() class ConversionConfig(RunnableConfig): - input_type: CheckpointFormat = Field() - output_type: CheckpointFormat = Field() - input_path: pathlib.Path = Field() - output_path: pathlib.Path = Field() - model_type: str | None = Field(default=None) + input: CheckpointLoadConfig = Field(default_factory=CheckpointLoadConfig) + output: CheckpointSaveConfig = Field(default_factory=CheckpointSaveConfig) use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) - target_params_per_file: int = Field(default=2**32) - dtype: DataType | None = Field( - default=None, - ) layers_per_step: int | None = Field(default=None) + @classmethod + def _from_dict( + 
cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ): + # TODO v0.2: Remove. + if "input" not in default: + default["input"] = {} + if "output" not in default: + default["output"] = {} + if "input_type" in default: + warnings.warn("`input_type` is deprecated. Use `input.format` instead.") + default["input"]["format"] = default.pop("input_type") + if "input_path" in default: + warnings.warn("`input_path` is deprecated. Use `input.path` instead.") + default["input"]["path"] = default.pop("input_path") + if "output_type" in default: + warnings.warn("`output_type` is deprecated. Use `output.format` instead.") + default["output"]["format"] = default.pop("output_type") + if "output_path" in default: + warnings.warn("`output_path` is deprecated. Use `output.path` instead.") + default["output"]["path"] = default.pop("output_path") + if "model_type" in default: + warnings.warn("`model_type` is deprecated. Use `output.format` instead.") + # Will be handled in CheckpointConfigBase.from_dict + default["input"]["model_type"] = default["model_type"] + default["output"]["model_type"] = default.pop("model_type") + if "target_params_per_file" in default: + warnings.warn("`target_params_per_file` is deprecated. Use `output.parameters_per_file` instead.") + default["output"]["parameters_per_file"] = default.pop("target_params_per_file") + if "dtype" in default: + warnings.warn("`dtype` is deprecated. 
Use `output.data_type` instead.") + default["output"]["data_type"] = default.pop("dtype") + return super()._from_dict(default, strict, flat) + @classmethod def _get_parser(cls): parser = super()._get_parser() @@ -52,33 +82,20 @@ def _get_runnable(self, parsed: argparse.Namespace) -> typing.Callable[[], None] def _convert_model_partial( self, model_class: type["FastLLMModel"], - output_path: pathlib.Path, + output: CheckpointSaveConfig, stage_filter: set | None = None, ): - logger.info(f"Loading {self.input_type} checkpoint from {self.input_path}...") + logger.info(f"Loading {self.input.format} checkpoint from {self.input.path}...") model = model_class.from_pretrained( - CheckpointLoadConfig( - path=self.input_path, - format=self.input_type, - model_type=self.model_type, - ), + self.input, mode=StageMode.weights, use_cpu=self.use_cpu, stage_filter=stage_filter, ) - logger.info(f"Saving {self.output_type} checkpoint to {output_path}...") - output_path.mkdir(parents=True, exist_ok=self.exist_ok) - model.save_checkpoint( - CheckpointSaveConfig( - path=output_path, - format=self.output_type, - model_type=self.model_type, - optimizer_state=False, - parameters_per_file=self.target_params_per_file, - data_type=self.dtype, - ) - ) - (output_path / "ok").open("w") + logger.info(f"Saving {output.format} checkpoint to {output.path}...") + output.path.mkdir(parents=True, exist_ok=self.exist_ok) + model.save_checkpoint(output) + (output.path / "ok").open("w") logger.info(f"Done!") def run(self, model_config_class: type["FastLLMModelConfig"] | str): @@ -89,28 +106,24 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): if self.use_cpu: TritonConfig.TRITON_ENABLED = False # Skip on exist_ok=False if the model has already been processed - if not self.exist_ok and (self.output_path / "ok").exists(): + if not self.exist_ok and (self.output.path / "ok").exists(): logger.info( - f"Output path {self.output_path} already exists and has been processed. 
Skipping model conversion..." + f"Output path {self.output.path} already exists and has been processed. Skipping model conversion..." ) return if isinstance(model_config_class, str): model_config_class = model_registry[model_config_class] model_class = model_config_class.get_model_class() if self.layers_per_step is None: - self._convert_model_partial(model_class, self.output_path) + self._convert_model_partial(model_class, self.output) else: + converter_class = model_config_class.get_converter_class(self.output.format) # TODO: Support other types? - assert self.output_type == CheckpointFormat.external + assert issubclass(converter_class, HuggingfaceStateDictConverter) logger.info(f">>> Loading model config") # Create a dummy version to determine the stage split. model = model_class.from_pretrained( - CheckpointLoadConfig( - path=self.input_path, - format=self.input_type, - model_type=self.model_type, - model_weights=False, - ), + self.input.to_copy({"model_weights": False}), mode=StageMode.off_device, use_cpu=self.use_cpu, ) @@ -120,9 +133,11 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): for step_begin in range(0, num_stages, stages_per_step): step_end = min(step_begin + stages_per_step, num_stages) logger.info(f">>> Converting stages {step_begin} to {step_end-1} of {num_stages}") - step_path = self.output_path / str(step_begin) + step_path = self.output.path / str(step_begin) step_paths.append(step_path) - self._convert_model_partial(model_class, step_path, set(range(step_begin, step_end))) + self._convert_model_partial( + model_class, self.output.to_copy({"path": step_path}), set(range(step_begin, step_end)) + ) logger.info(f">>> Aggregating conversion steps") # Combine weight maps and rename data files to avoid duplications. 
@@ -146,18 +161,18 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): new_file_name = f"model_{file_count}.safetensors" file_count += 1 rename_map[file_name] = new_file_name - global_rename_map[step_path / file_name] = self.output_path / new_file_name + global_rename_map[step_path / file_name] = self.output.path / new_file_name Assert.not_incl(name, weight_map) weight_map[name] = new_file_name # Save the combined index - path = self.output_path / index_filename + path = self.output.path / index_filename # Save the index. json.dump(index, path.open("w"), indent=4) # Copy the config - (step_paths[0] / config_filename).rename(self.output_path / config_filename) + (step_paths[0] / config_filename).rename(self.output.path / config_filename) # Move the data files for old_file_name, new_file_name in global_rename_map.items(): @@ -171,7 +186,7 @@ def run(self, model_config_class: type["FastLLMModelConfig"] | str): step_path.rmdir() # All good! - (self.output_path / "ok").open("w") + (self.output.path / "ok").open("w") logger.info(f">>> All done!") diff --git a/setup.cfg b/setup.cfg index 55816ff49..6c4f70fff 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [metadata] -name = llm +name = fast_llm [options] packages = find_namespace: diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 56323970d..2b3b63d0a 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -7,7 +7,12 @@ import transformers import yaml -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadConfig, ModelConfigType +from fast_llm.engine.checkpoint.config import ( + CheckpointFormat, + CheckpointLoadConfig, + CheckpointSaveConfig, + ModelConfigType, +) from fast_llm.engine.multi_stage.config import StageMode from fast_llm.models.auto import model_registry from fast_llm.tools.convert import ConversionConfig @@ -75,11 +80,11 @@ def test_resume(): def _run_conversion(config: ConversionConfig): - if config.output_path.is_dir() 
and not REUSE_RESULTS: - shutil.rmtree(config.output_path) - if not config.output_path.is_dir(): + if config.output.path.is_dir() and not REUSE_RESULTS: + shutil.rmtree(config.output.path) + if not config.output.path.is_dir(): if FORCE_REUSE_RESULTS: - raise RuntimeError(config.output_path) + raise RuntimeError(config.output.path) config.run(TEST_MODEL_CONFIG_CLS) @@ -91,10 +96,14 @@ def _run_conversion(config: ConversionConfig): def test_convert_distributed_to_state_dict(): _run_conversion( ConversionConfig( - input_type=CheckpointFormat.distributed, - input_path=_CKPT_PATH, - output_type=CheckpointFormat.state_dict, - output_path=_CONVERT_PATH / "state_dict_0", + input=CheckpointLoadConfig( + path=_CKPT_PATH, + format=CheckpointFormat.distributed, + ), + output=CheckpointSaveConfig( + path=_CONVERT_PATH / "state_dict_0", + format=CheckpointFormat.state_dict, + ), ) ) @@ -105,11 +114,14 @@ def test_convert_state_dict_to_huggingface(): pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( ConversionConfig( - input_type=CheckpointFormat.state_dict, - input_path=_CONVERT_PATH / "state_dict_0", - output_type=CheckpointFormat.external, - output_path=_CONVERT_PATH / "huggingface_0", - model_type=HUGGINGFACE_MODEL_TYPE, + input=CheckpointLoadConfig( + path=_CONVERT_PATH / "state_dict_0", + format=CheckpointFormat.state_dict, + ), + output=CheckpointSaveConfig( + path=_CONVERT_PATH / "huggingface_0", + format=HUGGINGFACE_MODEL_TYPE, + ), ) ) @@ -118,10 +130,14 @@ def test_convert_state_dict_to_huggingface(): def test_convert_huggingface_to_distributed(): _run_conversion( ConversionConfig( - input_type=CheckpointFormat.external, - input_path=_CONVERT_PATH / "huggingface_0", - output_type=CheckpointFormat.distributed, - output_path=_CONVERT_PATH / "distributed_0", + input=CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_0", + format=HUGGINGFACE_MODEL_TYPE, + ), + output=CheckpointSaveConfig( + path=_CONVERT_PATH / "distributed_0", + 
format=CheckpointFormat.distributed, + ), ) ) @@ -132,11 +148,14 @@ def test_convert_distributed_to_huggingface(): pytest.skip(f"Conversion not supported for {TEST_MODEL}") _run_conversion( ConversionConfig( - input_type=CheckpointFormat.distributed, - input_path=_CKPT_PATH, - output_type=CheckpointFormat.external, - output_path=_CONVERT_PATH / "huggingface_1", - model_type=HUGGINGFACE_MODEL_TYPE, + input=CheckpointLoadConfig( + path=_CKPT_PATH, + format=CheckpointFormat.distributed, + ), + output=CheckpointSaveConfig( + path=_CONVERT_PATH / "huggingface_1", + format=HUGGINGFACE_MODEL_TYPE, + ), ) ) @@ -145,10 +164,14 @@ def test_convert_distributed_to_huggingface(): def test_convert_huggingface_to_state_dict(): _run_conversion( ConversionConfig( - input_type=CheckpointFormat.external, - input_path=_CONVERT_PATH / "huggingface_1", - output_type=CheckpointFormat.state_dict, - output_path=_CONVERT_PATH / "state_dict_1", + input=CheckpointLoadConfig( + path=_CONVERT_PATH / "huggingface_1", + format=HUGGINGFACE_MODEL_TYPE, + ), + output=CheckpointSaveConfig( + path=_CONVERT_PATH / "state_dict_1", + format=CheckpointFormat.state_dict, + ), ) ) @@ -157,10 +180,14 @@ def test_convert_huggingface_to_state_dict(): def test_convert_state_dict_to_distributed(): _run_conversion( ConversionConfig( - input_type=CheckpointFormat.state_dict, - input_path=_CONVERT_PATH / "state_dict_1", - output_type=CheckpointFormat.distributed, - output_path=_CONVERT_PATH / "distributed_1", + input=CheckpointLoadConfig( + path=_CONVERT_PATH / "state_dict_1", + format=CheckpointFormat.state_dict, + ), + output=CheckpointSaveConfig( + path=_CONVERT_PATH / "distributed_1", + format=CheckpointFormat.distributed, + ), ) ) @@ -269,11 +296,11 @@ def test_load_converted_huggingface_checkpoint(): ) pretrained_config_0 = CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointFormat.external, + format=HUGGINGFACE_MODEL_TYPE, ) pretrained_config_1 = CheckpointLoadConfig( 
path=_CONVERT_PATH / "huggingface_1", - format=CheckpointFormat.external, + format=HUGGINGFACE_MODEL_TYPE, ) config = TEST_MODEL_CONFIG_CLS.from_pretrained(pretrained_config_ref) model = TEST_MODEL_CLS.from_pretrained(pretrained_config_0, mode=StageMode.weights) @@ -302,7 +329,7 @@ def test_run_converted_model(): model_from_hf = TEST_MODEL_HF_CLS.from_pretrained( CheckpointLoadConfig( path=_CONVERT_PATH / "huggingface_0", - format=CheckpointFormat.external, + format=HUGGINGFACE_MODEL_TYPE, ) ) errors = [] @@ -467,7 +494,7 @@ def test_load_pretrained_huggingface_in_dp2(): "training.checkpoint.interval=1", "training.train_iters=1", f"pretrained.path={_CONVERT_PATH / 'huggingface_0'}", - f"pretrained.format=external", + f"pretrained.format={HUGGINGFACE_MODEL_TYPE}", "schedule.skip_step=True", ], num_gpus=2, diff --git a/tools/push_model.py b/tools/push_model.py index 161ac8554..c4667c618 100644 --- a/tools/push_model.py +++ b/tools/push_model.py @@ -7,7 +7,7 @@ import subprocess from fast_llm.config import Field, config_class -from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.config_utils.runnable import RunnableConfig try: @@ -148,11 +148,14 @@ def run(self) -> None: checkpoint_path_hf = checkpoint_path.with_name(checkpoint_path.name + "_hf") # Block until the conversion is done ConversionConfig( - input_type=CheckpointFormat.distributed, - output_type=CheckpointFormat.external, - input_path=checkpoint_path, - output_path=checkpoint_path_hf, - model_type=self.model_type, + input=CheckpointLoadConfig( + path=checkpoint_path, + format=CheckpointFormat.distributed, + ), + output=CheckpointSaveConfig( + path=checkpoint_path_hf, + format=self.model_type, + ), use_cpu=self.use_cpu, exist_ok=False, # skip if already processed layers_per_step=(