Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ def __init__(
self.valid = valid


class FieldUpdate(dict):
"""
Specify some entries in the field that should be updated from the base class.
Useful for changing the default or description in a derived class.
Processed in `__init_subclass__`.
"""


def check_field(fn, *args, **kwargs):
"""
Helper function to define a condition that a config field should satisfy,
Expand Down Expand Up @@ -270,7 +278,7 @@ class Config:
__class_validated__: typing.ClassVar[bool] = True
_abstract: typing.ClassVar[bool] = False
_validated: bool = Field(init=False, repr=False)
_unknown_fields: dict[str] = Field(init=False, repr=False)
_unknown_fields: dict[str, typing.Any] = Field(init=False, repr=False)

def __post_init__(self):
"""
Expand Down Expand Up @@ -621,7 +629,7 @@ def _get_class_name(cls):
@classmethod
def from_dict(
cls,
default: typing.Union["Config", dict[str]],
default: typing.Union["Config", dict[str, typing.Any]],
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
):
Expand All @@ -646,7 +654,7 @@ def from_dict(
@classmethod
def from_flat_dict(
cls,
default: dict[str],
default: dict[str, typing.Any],
strict: bool = True,
):
# TODO v0.2: Remove flat format
Expand All @@ -655,7 +663,7 @@ def from_flat_dict(
@classmethod
def _from_dict(
cls,
default: dict[str],
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
):
Expand Down Expand Up @@ -768,7 +776,7 @@ def _from_dict_dict(cls, value, type_, strict: bool):
return {key: cls._from_dict_nested(value_, args[1], strict) for key, value_ in value.items()}

@classmethod
def _handle_renamed_field(cls, default: dict[str], old_name: str, new_name: str):
def _handle_renamed_field(cls, default: dict[str, typing.Any], old_name: str, new_name: str):
if old_name in default:
warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.")
default[new_name] = default.pop(old_name)
Expand Down Expand Up @@ -803,11 +811,51 @@ def _check_abstract(cls):
if not cls.__class_validated__:
raise RuntimeError(f"{cls.__name__} hasn't been validated. Make sure to use the @config_class decorator.")

def __init_subclass__(cls, **kwargs):
def __init_subclass__(cls):
"""
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
assert (
cls.__class_validated__
), f"Parent class of config class {cls.__name__} has not been validated. Make sure to use the @config_class decorator."
for base_class in cls.__mro__:
if issubclass(base_class, Config):
assert cls.__class_validated__, (
f"Parent class {get_type_name(base_class)} of config class {get_type_name(cls)} has not been validated."
f" Make sure to use the @config_class decorator."
)
cls.__class_validated__ = False
for name in list(cls.__dict__):
value = getattr(cls, name)
if isinstance(value, FieldUpdate):
# In case of multiple inheritance, the base class field may not appear in `cls.__dataclass_fields__`.
# so we iterate over superclasses following mro and use the first match.
base_class_field = None
for base_class in cls.__mro__:
base_class_fields = getattr(base_class, "__dataclass_fields__", {})
if name in base_class_fields:
base_class_field = base_class_fields[name]
break
if base_class_field is None:
raise RuntimeError(f"Trying to update the non-existent field {name} in class {get_type_name(cls)}")
setattr(
cls,
name,
Field(
desc=value.pop("desc", base_class_field.desc),
doc=value.pop("doc", base_class_field.doc),
hint=value.pop("hint", base_class_field.hint),
valid=value.pop("valid", base_class_field.valid),
default=value.pop("default", base_class_field.default),
default_factory=value.pop("default_factory", base_class_field.default_factory),
repr=value.pop("repr", base_class_field.repr),
hash=value.pop("hash", base_class_field.hash),
compare=value.pop("compare", base_class_field.compare),
metadata=value.pop("metadata", base_class_field.metadata),
kw_only=value.pop("kw_only", base_class_field.kw_only),
),
)
if name in cls.__annotations__:
# TODO: Generalize to other type hints.
if isinstance(cls.__annotations__[name], type) and isinstance(base_class_field.type, type):
Assert.custom(issubclass, cls.__annotations__[name], base_class_field.type)
else:
# dataclasses expects an annotation, so we use the one from the base class.
cls.__annotations__[name] = base_class_field.type
4 changes: 2 additions & 2 deletions fast_llm/engine/base_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

if typing.TYPE_CHECKING:
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.multi_stage.conversion import ModelConverter
from fast_llm.engine.multi_stage.conversion import ExternalModelConverter


@config_class()
Expand All @@ -30,7 +30,7 @@ def compare_architecture(
return self.get_architecture().compare(model_config.get_architecture(), log_fn)

@classmethod
def get_converter_class(cls, model_type: str | None = None) -> type["ModelConverter"]:
def get_converter_class(cls, model_type: str | None = None) -> type["ExternalModelConverter"]:
raise NotImplementedError()


Expand Down
147 changes: 147 additions & 0 deletions fast_llm/engine/config_utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# TODO: Use packaging.version? (Safer but extra requirement)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure where to ask this question:
where will we be configuring the layers-per-step option?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a hack for the single process converter, it doesn't apply here because we already have the whole model loaded. I'm not entirely sure about memory usage but it should be OK because we reconstruct one layer at the time.

import enum
import logging
import pathlib
import typing
import warnings

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)

# TODO: Use packaging.version? (Safer but extra requirement)
CHECKPOINT_VERSION = "0.1"
KNOWN_CHECKPOINT_VERSIONS = ("0", "0.1")


class CheckpointFormat(str, enum.Enum):
# Distributed checkpoint for fast checkpointing and resuming.
distributed = "distributed"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about we call this fast or short_term?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also an option, let's discuss it later.

# Model state dict, for safe long-term storage in Fast-LLM format.
state_dict = "state_dict"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about we call this long_term instead?

Copy link
Copy Markdown
Collaborator Author

@jlamypoirier jlamypoirier Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm ok with renaming (in a future PR), but I would prefer a name more related to Fast-LLM, since it's meant to be the "standard" Fast-LLM checkpoint format. (Maybe just "fast_llm"?)

# A checkpoint format external to Fast-LLM.
external = "external"


class ModelConfigType(str, enum.Enum):
none = "none"
architecture = "architecture"
model = "model"
fast_llm = "fast_llm"

@property
def load_architecture(self):
return self != ModelConfigType.none

@property
def load_base_model(self):
return self in (ModelConfigType.model, ModelConfigType.fast_llm)

@property
def load_fast_llm(self):
return self == ModelConfigType.fast_llm


@config_class()
class CheckpointPathConfigBase(Config):
_abstract = True
path: pathlib.Path | None = Field(
default=None,
desc="Location of the checkpoint.",
hint=FieldHint.core,
)


@config_class()
class CheckpointConfigBase(Config):
_abstract = True
format: CheckpointFormat = Field(
default=CheckpointFormat.distributed,
desc="Format of the checkpoint.",
hint=FieldHint.core,
)
model_type: str | None = Field(
default=None,
desc="Model type for external models (ex. Huggingace model name).",
hint=FieldHint.feature,
)

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
):
# TODO v0.2: Remove.
if default.get("format", None) == "huggingface":
warnings.warn(f"`huggingface` checkpoint format has been renamed to `external`.")
default["format"] = CheckpointFormat.external.value
Comment on lines +79 to +81
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will there be other external checkpoint formats or why this change?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, in theory the conversion mechanism can be used for any kind of checkpoint format, I had renamed anything else already but not this one because of backward compatibility.

I'm thinking of getting rid of it altogether in the next PR though, and just use the "model_type" as the format for external formats.

cls._handle_renamed_field(default, "imported_type", "model_type")
return super()._from_dict(default, strict, flat)


@config_class()
class CheckpointStateConfigBase(Config):
_abstract = True
model_weights: bool = Field(default=True, desc="Save/load the model weights.", hint=FieldHint.feature)
optimizer_state: bool = Field(default=False, desc="Save/load the optimizer state.", hint=FieldHint.feature)

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
):
cls._handle_renamed_field(default, "load_weights", "model_weights")
cls._handle_renamed_field(default, "load_optimizer", "optimizer_state")
return super()._from_dict(default, strict, flat)


@config_class()
class CheckpointSaveConfigBase(Config):
_abstract = True
parameters_per_file: int = Field(
default=2**32,
desc="Limit the number of parameters saved in each file.",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 2**20),
)
data_type: DataType | None = Field(
default=None,
desc="Data type to save the checkpoint.",
hint=FieldHint.feature,
)


@config_class()
class CheckpointSaveMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase):
_abstract = False


@config_class()
class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateConfigBase, CheckpointSaveConfigBase):
_abstract = False


@config_class()
class CheckpointLoadMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBase):
_abstract = False

load_config: ModelConfigType = Field(
default=ModelConfigType.architecture,
desc="Configuration to save/load.",
hint=FieldHint.core,
)

@property
def compare_log_fn(self):
return ValueError if self.load_config.load_architecture else logger.warning


@config_class()
class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
_abstract = False
Loading