Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions fast_llm/engine/base_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from fast_llm.config import Config, config_class

if typing.TYPE_CHECKING:
from fast_llm.engine.checkpoint.external import ExternalStateDictConverter
from fast_llm.engine.config_utils.tensor_space import TensorSpace


Expand All @@ -29,10 +28,6 @@ def compare_architecture(
):
return self.get_architecture().compare(model_config.get_architecture(), log_fn)

@classmethod
def get_converter_class(cls, model_type: str | None = None) -> type["ExternalStateDictConverter"]:
raise NotImplementedError()


@config_class()
class BaseModelConfig(BaseModelArchitectureConfig):
Expand Down
74 changes: 57 additions & 17 deletions fast_llm/engine/checkpoint/config.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,50 @@
# TODO: Use packaging.version? (Safer but extra requirement)
import abc
import enum
import logging
import pathlib
import typing
import warnings

import yaml

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

logger = logging.getLogger(__name__)

# TODO: Use packaging.version? (Safer but extra requirement)
CHECKPOINT_VERSION = "0.1"
KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1")


class CheckpointFormat(str, enum.Enum):
def export_safetensors_metadata(metadata):
"""
Safetensor only accepts string entries, so we convert to string explicitly.
We use yaml rather than json because json requires explicit quotation marks on strings, which breaks things.
(ex. "format": "pt" becomes '"pt"' which breaks huggingface models.)
We avoid using safe_dump for scalars because it adds junk ("\n...\n") at the end of the string
(decoding is unaffected.)
"""
return {
key: str(value) if isinstance(value, (str, int, float, bool)) else yaml.safe_dump(value)
for key, value in metadata.items()
}


def import_safetensors_metadata(metadata):
return {key: yaml.safe_load(value) for key, value in metadata.items()}


class CheckpointFormat(str):
# Distributed checkpoint for fast checkpointing and resuming.
distributed = "distributed"
# Model state dict, for safe long-term storage in Fast-LLM format.
state_dict = "state_dict"
# A checkpoint format external to Fast-LLM.
external = "external"


class ModelConfigType(str, enum.Enum):
Expand Down Expand Up @@ -57,16 +79,11 @@ class CheckpointPathConfigBase(Config):
@config_class()
class CheckpointConfigBase(Config):
_abstract = True
format: CheckpointFormat = Field(
format: str = Field(
default=CheckpointFormat.distributed,
desc="Format of the checkpoint.",
hint=FieldHint.core,
)
model_type: str | None = Field(
default=None,
desc="Model type for external models (ex. Huggingace model name).",
hint=FieldHint.feature,
)

@classmethod
def _from_dict(
Expand All @@ -76,10 +93,17 @@ def _from_dict(
flat: bool = False,
):
# TODO v0.2: Remove.
if default.get("format", None) == "huggingface":
warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.")
default["format"] = CheckpointFormat.external.value
cls._handle_renamed_field(default, "imported_type", "model_type")
if "model_type" in default:
warnings.warn(
"`CheckpointConfigBase.model_type` is deprecated."
" Instead, use the model name directly as the checkpoint format."
)
if default.get("format", None) in ("huggingface", "external"):
default["format"] = default.get("model_type")
if default["format"] is None:
default["format"] = "auto"
del default["model_type"]
return super()._from_dict(default, strict, flat)


Expand Down Expand Up @@ -151,8 +175,24 @@ def compare_log_fn(self):
class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
_abstract = False

def _validate(self):
super()._validate()
if self.format == CheckpointFormat.external:
# TODO: Support optimizer?
assert not self.optimizer_state

class Converter(abc.ABC):
# TODO: Rename? (Checkpointer? Saver?)

def __init__(self, model: "FastLLMModel"):
self._model = model

# TODO: save_metadata?

@classmethod
@abc.abstractmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig):
pass

@abc.abstractmethod
def save(self, config: CheckpointSaveConfig, metadata: dict):
pass

@abc.abstractmethod
def load(self, config: CheckpointLoadConfig, metadata: dict):
pass
93 changes: 93 additions & 0 deletions fast_llm/engine/checkpoint/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import logging

import safetensors.torch
import torch
import yaml

from fast_llm.engine.checkpoint.config import (
CheckpointLoadConfig,
CheckpointLoadMetadataConfig,
CheckpointSaveConfig,
Converter,
ModelConfigType,
export_safetensors_metadata,
)
from fast_llm.engine.checkpoint.safe_load import SafeLoad
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


class DistributedConverter(Converter):

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig):
return yaml.safe_load((config.path / "metadata.yaml").open("r"))

def save(self, config: CheckpointSaveConfig, metadata: dict):
if self._model.distributed_config.rank == 0:
yaml.safe_dump(metadata, (config.path / "metadata.yaml").open("w"))
num_shards = len(self._model.state_shard_names) if config.optimizer_state else 1
safetensors.torch.save_file(
tensors={"state_shard": self._model.state_shard[:num_shards]},
filename=config.path / f"rank_{self._model.distributed_config.rank}.safetensors",
metadata=export_safetensors_metadata(metadata),
)

def load(self, config: CheckpointLoadConfig, metadata: dict):
# TODO: More safety checks
loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
num_shards = self._model.num_state_shards if config.optimizer_state else 1
Assert.eq(metadata["state_shard_names"][:num_shards], list(self._model.state_shard_names[:num_shards]))

if (
loaded_config.to_serialized(verbose=None) == self._model.fast_llm_config.to_serialized(verbose=None)
and config.optimizer_state
):
logger.info("Checkpoint format matches, using fast load")
# TODO: Add version without optimizer state?
with safetensors.safe_open(
config.path / f"rank_{self._model.distributed_config.rank}.safetensors",
framework="pt",
device=str(self._model.distributed.device),
) as f:
# TODO: Does this copy twice?
self._model.state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards])
else:
logger.info("Checkpoint format doesn't match, using safe load")
self._model.base_model_config.compare_architecture(loaded_config.base_model, config.compare_log_fn)
with SafeLoad(self._model, num_shards=num_shards) as context:
for rank in range(loaded_config.distributed.world_size):
loaded_model = self._model.__class__(
loaded_config.to_copy({("distributed", "rank"): rank}),
optimizer_state_names=self._model.state_shard_names[1:num_shards],
verbose=False,
)
path = config.path / f"rank_{rank}.safetensors"
logger.info(f"Loading from {path}")
# TODO: skip shards without overlap.
with safetensors.safe_open(path, framework="pt", device=str(self._model.distributed.device)) as f:
# TODO: Use self_shard
loaded_shard = f.get_slice("state_shard")[:num_shards]
loaded_model.state_shard_meta.validate(loaded_shard)

# TODO: Improve num shard selection.
self_shard_split = self._model.state_shard[: loaded_shard.size(0)].split(
self._model.stage_shard_sizes, 1
)
loaded_shard_split = loaded_shard.split(loaded_model.stage_shard_sizes, 1)

counter = torch.zeros(1, dtype=torch.int64, device=self._model.distributed.device)
for loaded_shard_index, loaded_stage in enumerate(loaded_model.stages_on_device.values()):
loaded_shards = (
loaded_shard_split[loaded_shard_index].to(self._model.distributed.device).unbind(0)
)
for self_shard_index, self_stage in enumerate(self._model.stages_on_device.values()):
self_stage._copy_shard_overlaps( # noqa
loaded_stage,
self_shard_split[self_shard_index].unbind(0),
loaded_shards,
counter,
)
context.mark_as_loaded(counter.item())
87 changes: 58 additions & 29 deletions fast_llm/engine/checkpoint/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
import torch

from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
from fast_llm.engine.checkpoint.config import (
CHECKPOINT_VERSION,
CheckpointLoadConfig,
CheckpointLoadMetadataConfig,
CheckpointSaveConfig,
)
from fast_llm.engine.checkpoint.state_dict import StateDictConverter
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.tensor import SafeTensorSlice
from fast_llm.utils import Assert

Expand Down Expand Up @@ -139,13 +146,12 @@ def import_weight(


class ExternalStateDictConverter(StateDictConverter):
base_file_name = "model"
_base_model_cls: type[BaseModelConfig]
_config_converters: list[ParamConverter]

def __init__(self, config: BaseModelArchitectureConfig):
self.config = config
Assert.custom(isinstance, config, self._base_model_cls.architecture_cls)
def __init__(self, model: "FastLLMModel"):
super().__init__(model)
Assert.custom(isinstance, self._model.base_model_config, self._base_model_cls.architecture_cls)
weight_converters = self._create_weight_converters()
self._export_converters = {
weight_converter.fast_llm_name[0]: weight_converter
Expand All @@ -166,17 +172,7 @@ def _create_weight_converters(self) -> list[WeightConverter]:
pass

@classmethod
@abc.abstractmethod
def load_config(cls, directory: pathlib.Path | str) -> dict[str, typing.Any]:
pass

@classmethod
@abc.abstractmethod
def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]):
pass

@classmethod
def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]:
def _export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]:
exported_config = {}
for converter in cls._get_config_converters():
value = converter.export_param(
Expand All @@ -190,7 +186,7 @@ def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.
return exported_config # Noqa

@classmethod
def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa
def _import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False): # noqa
kwargs = {}
for converter in cls._get_config_converters():
value = converter.import_param(
Expand All @@ -204,11 +200,7 @@ def import_config(cls, config: dict[str, typing.Any], architecture_only: bool =
config_class = cls._base_model_cls.architecture_cls if architecture_only else cls._base_model_cls
return config_class.from_dict({}, kwargs)

@classmethod
def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
return cls(cls.import_config(config, architecture_only=architecture_only))

def convert_state_dict(
def _convert_state_dict(
self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
) -> dict[str, torch.Tensor | SafeTensorSlice]:
out_state_dict = {}
Expand Down Expand Up @@ -262,19 +254,56 @@ class AutoStateDictConverter(ExternalStateDictConverter, abc.ABC):
converter_map: dict[str, type[ExternalStateDictConverter]]

@classmethod
def import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
return cls.converter_map[config["model_type"]].import_config(config, architecture_only)
def get_converter_class(cls, format: str):
if format in cls.converter_map:
return cls.converter_map[format]
elif format == "auto":
return cls
else:
raise NotImplementedError(format)

# TODO: load_metadata???

@classmethod
def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
return cls.converter_map[config["model_type"]].from_config(config, architecture_only)
def _import_config(cls, config: dict[str, typing.Any], architecture_only: bool = False):
# TODO: ???
return cls.converter_map[config["model_type"]]._import_config(config, architecture_only)


class HuggingfaceStateDictConverter(ExternalStateDictConverter, abc.ABC):
model_type: str | None = None
base_file_name = "model"

@classmethod
def load_metadata(cls, config: CheckpointLoadMetadataConfig):
imported_model_config = cls._import_config(cls._load_config(config.path), True)
return {
# TODO: Avoid `to_serialized`?
"fast_llm_config": {"base_model": imported_model_config.to_serialized()},
# TODO: Handle "auto"?
"checkpoint_type": config.format,
"checkpoint_version": CHECKPOINT_VERSION,
}

def save(self, config: CheckpointSaveConfig, metadata: dict):
huggingface_config = self._export_config(self._model.base_model_config)
self._save_config(config.path, huggingface_config)
metadata = {
"fast_llm_metadata": metadata,
"model_config": huggingface_config,
"format": "pt",
}
super().save(config, metadata)

def load(self, config: CheckpointLoadConfig, metadata: dict):
assert not config.optimizer_state
self._model.base_model_config.compare_architecture(
self._base_model_cls.from_dict(metadata["fast_llm_config"]["base_model"]), config.compare_log_fn
)
super().load(config, metadata)

@classmethod
def get_key(cls, parameter_name: str, shard_name: str) -> str:
def _get_key(cls, parameter_name: str, shard_name: str) -> str:
Assert.eq(shard_name, "weights")
return parameter_name

Expand All @@ -284,7 +313,7 @@ def _create_config_converters(cls) -> list[ParamConverter]:
return [ConstantExportParamConverter(None, "model_type", cls.model_type)]

@classmethod
def load_config(cls, directory: pathlib.Path | str):
def _load_config(cls, directory: pathlib.Path | str):
import transformers

config = transformers.AutoConfig.from_pretrained(directory).to_dict()
Expand All @@ -293,12 +322,12 @@ def load_config(cls, directory: pathlib.Path | str):
return config

@classmethod
def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]):
def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]):
import transformers

transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory)

def load_weights(
def _load_weights(
self,
directory: pathlib.Path | str,
device,
Expand Down
Loading