Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions fast_llm/engine/config_utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ class CheckpointLoadMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBas
hint=FieldHint.core,
)

def _validate(self):
    """Validate this config, then enforce a format-specific requirement.

    Runs the parent class validation first. When ``format`` is
    ``CheckpointFormat.distributed``, ``load_config.load_architecture`` must
    be truthy.

    Raises:
        AssertionError: if the distributed format is selected while
            ``load_config.load_architecture`` is falsy.
    """
    super()._validate()
    # Only the distributed format carries the extra requirement.
    if self.format != CheckpointFormat.distributed:
        return
    assert self.load_config.load_architecture

@property
def compare_log_fn(self):
    """Return the callable selected by ``load_config.load_architecture``.

    Returns:
        The ``ValueError`` class when ``load_config.load_architecture`` is
        truthy, otherwise ``logger.warning``.
        NOTE(review): presumably the caller invokes the result with a message
        describing a mismatch -- confirm at the call site.
    """
    if self.load_config.load_architecture:
        return ValueError
    return logger.warning
Expand All @@ -145,3 +150,9 @@ def compare_log_fn(self):
@config_class()
class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
    """Concrete checkpoint-loading config combining the metadata and state bases."""

    # Marks this config class as concrete (non-abstract).
    _abstract = False

    def _validate(self):
        """Validate this config, then enforce a format-specific restriction.

        Runs the parent validation first. When ``format`` is
        ``CheckpointFormat.external``, ``optimizer_state`` must be falsy.

        Raises:
            AssertionError: if the external format is selected while
                ``optimizer_state`` is truthy.
        """
        super()._validate()
        # The restriction applies to the external format only.
        if self.format != CheckpointFormat.external:
            return
        # TODO: Support optimizer?
        assert not self.optimizer_state
88 changes: 68 additions & 20 deletions fast_llm/engine/multi_stage/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import safetensors
import torch
import yaml

from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
from fast_llm.tensor import SafeTensorSlice
Expand Down Expand Up @@ -140,25 +141,70 @@ def import_weight(
class ModelConverter(abc.ABC):
    """Abstract interface for converting and loading checkpoint state dicts.

    Subclasses set ``base_file_name`` and implement key naming, state-dict
    conversion and weight loading.
    """

    # Set by subclasses ("state_dict", "model" in this file).
    # NOTE(review): presumably the stem of the checkpoint file names -- confirm.
    base_file_name: typing.ClassVar[str]

    @classmethod
    @abc.abstractmethod
    def get_key(cls, parameter_name: str, shard_name: str) -> str:
        """Return the state-dict key used for ``parameter_name`` in ``shard_name``."""

    @abc.abstractmethod
    def convert_state_dict(
        self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
    ) -> dict[str, torch.Tensor | SafeTensorSlice]:
        """Convert entries of ``state_dict`` and return the converted entries.

        Args:
            state_dict: Mapping from string keys (see ``get_key``) to tensors
                or tensor slices. The implementations in this file remove the
                entries they convert from this dict.
            export: Selects the conversion direction (export when true,
                import otherwise).

        Returns:
            A dict holding the converted entries.
        """

    @abc.abstractmethod
    def load_weights(
        self,
        directory: pathlib.Path | str,
        device,
        shard_names: list[str],
    ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
        """Iterate over the weights stored in ``directory``.

        Args:
            directory: Directory containing the checkpoint files.
            device: Device on which the tensors are opened.
            shard_names: Names of the shards to load.

        Yields:
            ``(parameter_name, shard_name, tensor_or_slice)`` tuples.
        """


def _import_safetensors_metadata(metadata):
return {key: yaml.safe_load(value) for key, value in metadata.items()}


class TrivialConverter(ModelConverter):
    """Converter that passes tensors through unchanged.

    State-dict keys have the form ``"{parameter_name}/{shard_name}"``.
    """

    base_file_name = "state_dict"

    @classmethod
    def get_key(cls, parameter_name: str, shard_name: str) -> str:
        """Return the state-dict key ``"{parameter_name}/{shard_name}"``."""
        return f"{parameter_name}/{shard_name}"

    def convert_state_dict(
        self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
    ) -> dict[str, torch.Tensor | SafeTensorSlice]:
        """Move every entry of ``state_dict`` into a new dict, unchanged.

        Args:
            state_dict: Mapping from string keys to tensors or tensor slices.
                It is emptied by this call.
            export: Unused; the result is the same in both directions.

        Returns:
            A new dict holding all the entries that were in ``state_dict``.
        """
        # Keys are already in their final form (see ``get_key``), so a shallow
        # copy is enough. The input is cleared so that all entries count as
        # consumed, as other converters do by popping what they convert.
        out_state_dict = state_dict.copy()
        state_dict.clear()
        return out_state_dict

    def load_weights(
        self,
        directory: pathlib.Path | str,
        device,
        shard_names: list[str],
    ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
        """Iterate over the tensor slices of the requested shards.

        Reads ``state_dict.safetensors.index.json`` in ``directory`` to find
        the safetensors files, then opens each file in turn.

        Args:
            directory: Directory containing the index and safetensors files.
            device: Device passed (as a string) to ``safetensors.safe_open``.
            shard_names: Shards to load; entries belonging to other shards
                are skipped.

        Yields:
            ``(parameter_name, shard_name, slice)`` for every key of the form
            ``"{parameter_name}/{shard_name}"`` whose shard is requested.
        """
        # The annotation allows ``str``; normalise so the ``/`` operator works.
        directory = pathlib.Path(directory)
        index_path = directory / "state_dict.safetensors.index.json"
        logger.info(f"Loading index from {index_path}")
        # Close the index file as soon as it has been parsed.
        with index_path.open("r") as index_file:
            file_names = set(json.load(index_file)["weight_map"].values())
        for file_name in file_names:
            logger.info(f"Loading from {directory / file_name}")
            with safetensors.safe_open(
                directory / file_name,
                framework="pt",
                device=str(device),
            ) as f:
                metadata = _import_safetensors_metadata(f.metadata())
                # The file's leading shard names must match the requested ones.
                Assert.eq(metadata["state_shard_names"][: len(shard_names)], list(shard_names))
                for key in f.keys():
                    parameter_name, shard_name = key.split("/", 1)
                    if shard_name in shard_names:
                        yield parameter_name, shard_name, f.get_slice(key)

        # return metadata["metadata"]


class ExternalModelConverter(ModelConverter):
base_file_name = "model"
Expand Down Expand Up @@ -197,12 +243,6 @@ def load_config(cls, directory: pathlib.Path | str) -> dict[str, typing.Any]:
def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]):
pass

@abc.abstractmethod
def load_weights(
    self,
    directory: pathlib.Path | str,
    device,
) -> typing.Iterator[tuple[str, torch.Tensor | SafeTensorSlice]]:
    """Iterate over the weights stored in ``directory``.

    Args:
        directory: Directory containing the checkpoint files.
        device: Device on which the tensors are opened.

    Yields:
        ``(name, tensor_or_slice)`` pairs.

    NOTE(review): ``ModelConverter.load_weights`` in this file also takes
    ``shard_names`` and yields 3-tuples -- confirm whether this two-argument
    declaration is still needed.
    """

@classmethod
def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]:
exported_config = {}
Expand Down Expand Up @@ -237,21 +277,20 @@ def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = Fa
return cls(cls.import_config(config, architecture_only=architecture_only))

def convert_state_dict(
self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool
self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool
) -> dict[str, torch.Tensor | SafeTensorSlice]:
out_state_dict = {}
weight_converters = self._export_converters if export else self._import_converters

for state_dict_name, shard_name in list(state_dict):
assert shard_name == "weights"
for state_dict_name in list(state_dict):
try:
if state_dict_name not in weight_converters:
continue
weight_converter: WeightConverter = weight_converters[state_dict_name]
in_names = weight_converter.fast_llm_name if export else weight_converter.export_name
if not all((name, shard_name) in state_dict for name in in_names):
if not all(name in state_dict for name in in_names):
continue
in_weights = tuple(state_dict.pop((name, shard_name)) for name in in_names)
in_weights = tuple(state_dict.pop(name) for name in in_names)
out_names = weight_converter.export_name if export else weight_converter.fast_llm_name
out_weights = (
weight_converter.export_weight(in_weights)
Expand Down Expand Up @@ -302,6 +341,11 @@ def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = Fa
class HuggingfaceModelConverter(ExternalModelConverter, abc.ABC):
model_type: str | None = None

@classmethod
def get_key(cls, parameter_name: str, shard_name: str) -> str:
    """Return the state-dict key for a parameter: the parameter name itself.

    Args:
        parameter_name: Name of the parameter; returned unchanged.
        shard_name: Must equal ``"weights"``; checked with ``Assert.eq``.
            NOTE(review): presumably ``Assert.eq`` raises on mismatch --
            confirm.

    Returns:
        ``parameter_name``.
    """
    # Keys carry no shard suffix here, so only the "weights" shard is valid.
    Assert.eq(shard_name, "weights")
    return parameter_name

@classmethod
@abc.abstractmethod
def _create_config_converters(cls) -> list[ParamConverter]:
Expand All @@ -323,10 +367,14 @@ def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any
transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory)

def load_weights(
self, directory: pathlib.Path | str, device
) -> typing.Iterator[tuple[str, torch.Tensor | SafeTensorSlice]]:
self,
directory: pathlib.Path | str,
device,
shard_names: list[str],
) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]:
import transformers

Assert.eq(shard_names, ("weights",))
if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file():
paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME}
elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file():
Expand Down Expand Up @@ -356,7 +404,7 @@ def load_weights(
if path.suffix == ".safetensors":
with safetensors.safe_open(path, framework="pt", device=str(device)) as f:
for key in f.keys():
yield key, f.get_slice(key)
yield key, "weights", f.get_slice(key)
elif path.suffix == ".bin":
# TODO: Prevent unsafe by default
yield from torch.load(path)
Expand Down
Loading