From 8daecd392c5385a4d726854fd5cee477bb35e13c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Oct 2024 18:06:29 -0400 Subject: [PATCH 1/3] Simplified checkpoint loading --- fast_llm/engine/config_utils/checkpoint.py | 11 + fast_llm/engine/multi_stage/conversion.py | 102 ++++- fast_llm/engine/multi_stage/fast_llm_model.py | 423 ++++++++---------- fast_llm/engine/multi_stage/multi_stage.py | 13 +- fast_llm/engine/multi_stage/stage_base.py | 11 +- fast_llm/engine/training/config.py | 9 + fast_llm/engine/training/trainer.py | 9 +- 7 files changed, 312 insertions(+), 266 deletions(-) diff --git a/fast_llm/engine/config_utils/checkpoint.py b/fast_llm/engine/config_utils/checkpoint.py index dc77fd5e0..fc390e01e 100644 --- a/fast_llm/engine/config_utils/checkpoint.py +++ b/fast_llm/engine/config_utils/checkpoint.py @@ -137,6 +137,11 @@ class CheckpointLoadMetadataConfig(CheckpointPathConfigBase, CheckpointConfigBas hint=FieldHint.core, ) + def _validate(self): + super()._validate() + if self.format == CheckpointFormat.distributed: + assert self.load_config.load_architecture + @property def compare_log_fn(self): return ValueError if self.load_config.load_architecture else logger.warning @@ -145,3 +150,9 @@ def compare_log_fn(self): @config_class() class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase): _abstract = False + + def _validate(self): + super()._validate() + if self.format == CheckpointFormat.external: + # TODO: Support optimizer? + assert not self.optimizer_state diff --git a/fast_llm/engine/multi_stage/conversion.py b/fast_llm/engine/multi_stage/conversion.py index 4acb9d928..becd4462d 100644 --- a/fast_llm/engine/multi_stage/conversion.py +++ b/fast_llm/engine/multi_stage/conversion.py @@ -7,6 +7,7 @@ import safetensors import torch +import yaml from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.tensor import SafeTensorSlice @@ -140,25 +141,80 @@ def import_weight( class ModelConverter(abc.ABC): base_file_name: typing.ClassVar[str] + @classmethod + @abc.abstractmethod + def get_key(cls, parameter_name: str, shard_name: str) -> str: + pass + + @classmethod + @abc.abstractmethod + def load_key(cls, key: str) -> tuple[str, str]: + pass + @abc.abstractmethod def convert_state_dict( - self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool + self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: pass + @abc.abstractmethod + def load_weights( + self, + directory: pathlib.Path | str, + device, + shard_names: list[str], + ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]: + pass + + +def _import_safetensors_metadata(metadata): + return {key: yaml.safe_load(value) for key, value in metadata.items()} + class TrivialConverter(ModelConverter): base_file_name = "state_dict" + @classmethod + def get_key(cls, parameter_name: str, shard_name: str) -> str: + return f"{parameter_name}/{shard_name}" + + @classmethod + def load_key(cls, key: str) -> tuple[str, str]: + parameter_name, shard_name = key.split("/", 1) + return parameter_name, shard_name + def convert_state_dict( - self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool + self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: - out_state_dict = {} - for key in list(state_dict): - name, shard_name = key - out_state_dict[f"{name}/{shard_name}"] = state_dict.pop(key) + out_state_dict = state_dict.copy() + state_dict.clear() return out_state_dict + def load_weights( + self, + directory: pathlib.Path | str, + device, + shard_names: list[str], + ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]: + index_path = directory / f"state_dict.safetensors.index.json" + logger.info(f"Loading index from {index_path}") + file_names = set(json.load(index_path.open("r"))["weight_map"].values()) + for file_name in file_names: + logger.info(f"Loading from {directory / file_name}") + with safetensors.safe_open( + directory / file_name, + framework="pt", + device=str(device), + ) as f: + metadata = _import_safetensors_metadata(f.metadata()) + Assert.eq(metadata["state_shard_names"][: len(shard_names)], shard_names) + for key in f.keys(): + parameter_name, shard_name = key.split("/", 1) + if shard_name in shard_names: + yield parameter_name, shard_name, f.get_slice(key) + + # return metadata["metadata"] + class ExternalModelConverter(ModelConverter): base_file_name = "model" @@ -197,12 +253,6 @@ def load_config(cls, directory: pathlib.Path | str) -> dict[str, typing.Any]: def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]): pass - @abc.abstractmethod - def load_weights( - self, directory: pathlib.Path | str, device - ) -> typing.Iterator[tuple[str, torch.Tensor | SafeTensorSlice]]: - pass - @classmethod def export_config(cls, config: BaseModelArchitectureConfig) -> dict[str, typing.Any]: exported_config = {} @@ -237,21 +287,20 @@ def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = Fa return cls(cls.import_config(config, architecture_only=architecture_only)) def convert_state_dict( - self, state_dict: dict[tuple[str, str], torch.Tensor | SafeTensorSlice], export: bool + self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: out_state_dict = {} weight_converters = self._export_converters if export else self._import_converters - for state_dict_name, shard_name in list(state_dict): - assert shard_name == "weights" + for state_dict_name in list(state_dict): try: if state_dict_name not in weight_converters: continue weight_converter: WeightConverter = weight_converters[state_dict_name] in_names = weight_converter.fast_llm_name if export else weight_converter.export_name - if not all((name, shard_name) in state_dict for name in in_names): + if not all(name in state_dict for name in in_names): continue - in_weights = tuple(state_dict.pop((name, shard_name)) for name in in_names) + in_weights = tuple(state_dict.pop(name) for name in in_names) out_names = weight_converter.export_name if export else weight_converter.fast_llm_name out_weights = ( weight_converter.export_weight(in_weights) @@ -302,6 +351,15 @@ def from_config(cls, config: dict[str, typing.Any], architecture_only: bool = Fa class HuggingfaceModelConverter(ExternalModelConverter, abc.ABC): model_type: str | None = None + @classmethod + def get_key(cls, parameter_name: str, shard_name: str) -> str: + Assert.eq(shard_name, "weights") + return parameter_name + + @classmethod + def load_key(cls, key: str) -> tuple[str, str]: + return key, "weights" + @classmethod @abc.abstractmethod def _create_config_converters(cls) -> list[ParamConverter]: @@ -323,10 +381,14 @@ def save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any transformers.CONFIG_MAPPING[config["model_type"]].from_dict(config).save_pretrained(directory) def load_weights( - self, directory: pathlib.Path | str, device - ) -> typing.Iterator[tuple[str, torch.Tensor | SafeTensorSlice]]: + self, + directory: pathlib.Path | str, + device, + shard_names: list[str], + ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]: import transformers + Assert.eq(shard_names, ["weights"]) if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME} elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): @@ -356,7 +418,7 @@ def load_weights( if path.suffix == ".safetensors": with safetensors.safe_open(path, framework="pt", device=str(device)) as f: for key in f.keys(): - yield key, f.get_slice(key) + yield key, "weights", f.get_slice(key) elif path.suffix == ".bin": # TODO: Prevent unsafe by default yield from torch.load(path) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index edcbdc0e2..77c30f196 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -1,7 +1,5 @@ -import json import logging import math -import pathlib import typing import safetensors.torch @@ -22,9 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode from fast_llm.engine.multi_stage.conversion import ModelConverter, TrivialConverter from fast_llm.engine.multi_stage.multi_stage import MultiStageModel -from fast_llm.engine.multi_stage.stage import Stage from fast_llm.functional.triton.pointwise import triton_fill -from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -44,10 +40,6 @@ def _export_safetensors_metadata(metadata): } -def _import_safetensors_metadata(metadata): - return {key: yaml.safe_load(value) for key, value in metadata.items()} - - class FastLLMModel(MultiStageModel): _is_setup: bool = False _is_loaded: bool = False @@ -85,12 +77,12 @@ def distributed(self): def save_checkpoint( self, - checkpoint_config: CheckpointSaveConfig, + config: CheckpointSaveConfig, metadata: dict | None = None, ): # TODO: Handle barriers, ok file, mkdir, etc. here - num_shards = len(self._state_shard_names) if checkpoint_config.optimizer_state else 1 + num_shards = len(self._state_shard_names) if config.optimizer_state else 1 metadata = { "checkpoint_type": CheckpointFormat.distributed.value, "checkpoint_version": str(CHECKPOINT_VERSION), @@ -100,14 +92,14 @@ def save_checkpoint( } # TODO: Simplify branching. - if checkpoint_config.format == CheckpointFormat.external: + if config.format == CheckpointFormat.external: # TODO: Support optimizer? - assert not checkpoint_config.optimizer_state - converter_class = self._base_model_config.get_converter_class(checkpoint_config.model_type) + assert not config.optimizer_state + converter_class = self._base_model_config.get_converter_class(config.model_type) exported_config = converter_class.export_config(self._base_model_config) - converter_class.save_config(checkpoint_config.path, exported_config) + converter_class.save_config(config.path, exported_config) self._save_state_dict( - checkpoint_config, + config, converter_class(self._base_model_config), { "fast_llm_metadata": metadata, @@ -115,22 +107,22 @@ def save_checkpoint( "format": "pt", }, ) - elif checkpoint_config.format == CheckpointFormat.state_dict: - self._save_state_dict(checkpoint_config, TrivialConverter(), metadata) - elif checkpoint_config.format == CheckpointFormat.distributed: + elif config.format == CheckpointFormat.state_dict: + self._save_state_dict(config, TrivialConverter(), metadata) + elif config.format == CheckpointFormat.distributed: if self._distributed_config.rank == 0: - yaml.safe_dump(metadata, (checkpoint_config.path / "metadata.yaml").open("w")) + yaml.safe_dump(metadata, (config.path / "metadata.yaml").open("w")) safetensors.torch.save_file( tensors={"state_shard": self._state_shard[:num_shards]}, - filename=checkpoint_config.path / f"rank_{self._distributed_config.rank}.safetensors", + filename=config.path / f"rank_{self._distributed_config.rank}.safetensors", metadata=_export_safetensors_metadata(metadata), ) else: - raise NotImplementedError(checkpoint_config.format) + raise NotImplementedError(config.format) - def _save_state_dict(self, checkpoint_config: CheckpointSaveConfig, converter: ModelConverter, metadata: dict): + def _save_state_dict(self, config: CheckpointSaveConfig, converter: ModelConverter, metadata: dict): with StateDictSaver( - checkpoint_config, + config, distributed=self._distributed, metadata=metadata, base_file_name=converter.base_file_name, @@ -141,50 +133,127 @@ def _save_state_dict(self, checkpoint_config: CheckpointSaveConfig, converter: M # If converting a tensor requires another one that is not yet available (e.g. for concatenation), # it will remain in `fast_llm_state_dict` until that tensor is available. fast_llm_state_dict = {} - for i, shard_name in enumerate( - self._state_shard_names if checkpoint_config.optimizer_state else self._state_shard_names[:1] + for parameter_name, shard_name, tensor in self.get_state_tensor_iterator( + self._state_shard_names if config.optimizer_state else self._state_shard_names[:1], config.data_type ): - shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) - for stage, shard in zip(self._stages_on_device.values(), shard_split): - for name, tensor in stage._export_shard(shard, dtype=checkpoint_config.data_type): # noqa - assert name not in fast_llm_state_dict - fast_llm_state_dict[(name, shard_name)] = tensor - for exported_name, exported_tensor in converter.convert_state_dict( - fast_llm_state_dict, True - ).items(): - context.add_tensor(exported_name, exported_tensor) - assert not fast_llm_state_dict, list(fast_llm_state_dict) - - def load_pretrained_checkpoint(self, pretrained_config: CheckpointLoadConfig): - if pretrained_config.format == CheckpointFormat.distributed: + if shard_name not in fast_llm_state_dict: + fast_llm_state_dict[shard_name] = {} + shard_state_dict = fast_llm_state_dict[shard_name] + assert parameter_name not in shard_state_dict + shard_state_dict[parameter_name] = tensor + for exported_name, exported_tensor in converter.convert_state_dict(shard_state_dict, True).items(): + context.add_tensor(converter.get_key(exported_name, shard_name), exported_tensor) + for shard_name, shard_state_dict in fast_llm_state_dict.items(): + assert not shard_state_dict, (shard_name, list(fast_llm_state_dict)) + + def load_checkpoint(self, config: CheckpointLoadConfig): + # TODO: Simplify branching. + # TODO: Test with more distributed configs. + # TODO: Safety checks + # TODO: Handle barriers, ok file, etc. here + metadata = self.config_class.load_pretrained_metadata(config) + if config.format == CheckpointFormat.distributed: # TODO: Check if same format. - self._load_distributed_checkpoint(pretrained_config) - elif pretrained_config.format == CheckpointFormat.state_dict: - self._load_state_dict_checkpoint(pretrained_config) - elif pretrained_config.format == CheckpointFormat.external: - self._import_checkpoint(pretrained_config) + self._load_distributed_checkpoint(config, metadata) + elif config.format == CheckpointFormat.state_dict: + self._load_state_dict(config, TrivialConverter()) + elif config.format == CheckpointFormat.external: + # TODO: Support optimizer. + assert not config.optimizer_state + converter_class = self.base_model.architecture_cls().get_converter_class(config.model_type) + converter = converter_class.from_config(converter_class.load_config(config.path)) + self._base_model_config.compare_architecture(converter.config, config.compare_log_fn) + self._load_state_dict(config, converter) else: - raise NotImplementedError(pretrained_config.format) - - def load_distributed_checkpoint_same_format(self, directory: pathlib.Path): - # TODO: Handle barriers, ok file, etc. here + raise NotImplementedError(config.format) + return metadata.get("metadata") + + def _load_state_dict(self, config: CheckpointLoadConfig, converter: ModelConverter): + num_shards = len(self._state_shard_names) if config.optimizer_state else 1 + with self._SafeLoadContext(self, num_shards=num_shards) as context: + state_dict = {} + for parameter_name, shard_name, tensor in converter.load_weights( + config.path, self._distributed.device, self._state_shard_names[:num_shards] + ): + if shard_name not in state_dict: + state_dict[shard_name] = {} + shard_state_dict = state_dict[shard_name] + assert parameter_name not in shard_state_dict + shard_state_dict[parameter_name] = tensor + for parameter_name, fast_llm_tensor in converter.convert_state_dict(shard_state_dict, False).items(): + stage_index = self._parameter_stages[parameter_name] + if stage_index not in self._stage_shard_indices: + # Tensor is not on this device. + return 0 + stage_shard = self._state_shard[self._state_shard_names.index(shard_name)].split( + self._stage_shard_sizes, 0 + )[self._stage_shard_indices[stage_index]] + loaded = multi_stage._stages[stage_index]._import_state_tensor( + stage_shard, parameter_name, fast_llm_tensor + ) # noqa + context.mark_as_loaded(loaded, (parameter_name, shard_name)) + + for shard_name, shard_state_dict in state_dict.items(): + assert not shard_state_dict, (shard_name, list(state_dict)) + + self._finalize_load(reset_optimizer=not config.optimizer_state) + + def _load_distributed_checkpoint(self, config: CheckpointLoadConfig, metadata: dict): # TODO: More safety checks - # TODO: Integrate to load_checkpoint. - pretrained_config = CheckpointLoadConfig(path=directory, format=CheckpointFormat.distributed) - metadata = self.config_class.load_pretrained_metadata(pretrained_config) - with self._LoadContext(self, safe=False, load_optimizer=True, reset_pads=False) as context: - Assert.eq( - metadata["state_shard_names"][: context.num_shards], - list(self._state_shard_names[: context.num_shards]), - ) + loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm}) + loaded_config = self.config_class.from_metadata(loaded_config_dict, metadata) + num_shards = self._num_state_shards if config.optimizer_state else 1 + Assert.eq(metadata["state_shard_names"][:num_shards], list(self._state_shard_names[:num_shards])) + + if ( + loaded_config.to_serialized(verbose=None) == self._fast_llm_config.to_serialized(verbose=None) + and config.optimizer_state + ): + logger.info("Checkpoint format matches, using fast load") + # TODO: Add version without optimizer state? with safetensors.safe_open( - directory / f"rank_{self._distributed_config.rank}.safetensors", + config.path / f"rank_{self._distributed_config.rank}.safetensors", framework="pt", device=str(self._distributed.device), ) as f: # TODO: Does this copy twice? - self._state_shard[: context.num_shards].copy_(f.get_slice("state_shard")[: context.num_shards]) - return metadata["metadata"] + self._state_shard[:num_shards].copy_(f.get_slice("state_shard")[:num_shards]) + else: + logger.info("Checkpoint format doesn't match, using safe load") + self._base_model_config.compare_architecture(loaded_config.base_model, config.compare_log_fn) + with self._SafeLoadContext(self, num_shards=num_shards) as context: + for rank in range(loaded_config.distributed.world_size): + loaded_model = self.__class__( + loaded_config.to_copy({("distributed", "rank"): rank}), + optimizer_state_names=self._state_shard_names[1:num_shards], + verbose=False, + ) + path = config.path / f"rank_{rank}.safetensors" + logger.info(f"Loading from {path}") + # TODO: skip shards without overlap. + with safetensors.safe_open(path, framework="pt", device=str(self._distributed.device)) as f: + # TODO: Use self_shard + loaded_shard = f.get_slice("state_shard")[:num_shards] + loaded_model._state_shard_meta.validate(loaded_shard) + + # TODO: Improve num shard selection. + self_shard_split = self._state_shard[: loaded_shard.size(0)].split(self._stage_shard_sizes, 1) + loaded_shard_split = loaded_shard.split(loaded_model._stage_shard_sizes, 1) + + counter = torch.zeros(1, dtype=torch.int64, device=self._distributed.device) + for loaded_shard_index, loaded_stage in enumerate(loaded_model._stages_on_device.values()): + loaded_shards = ( + loaded_shard_split[loaded_shard_index].to(self._distributed.device).unbind(0) + ) + for self_shard_index, self_stage in enumerate(self._stages_on_device.values()): + self_stage._copy_shard_overlaps( # noqa + loaded_stage, + self_shard_split[self_shard_index].unbind(0), + loaded_shards, + counter, + ) + context.mark_as_loaded(counter.item()) + self._finalize_load(reset_optimizer=not config.optimizer_state) @classmethod def from_pretrained( @@ -224,90 +293,88 @@ def from_pretrained( if mode.on_device: if pretrained_config.model_weights: - model.load_pretrained_checkpoint(pretrained_config) + model.load_checkpoint(pretrained_config) else: model.initialize_weights() return model def initialize_weights(self): assert self._is_setup - with self._LoadContext(self, safe=False, load_optimizer=False, reset_pads=True): - assert self._is_setup - for stage in self._stages: - stage.initialize_weights() - for name, tied_parameter in self._tied_parameters.items(): - if tied_parameter.group is not None: - broadcast(self._stages[tied_parameter.main_stage].weight_shard, 0, tied_parameter.group) - - def _reset_shard_pads(self, optimizer: bool = False): - counter = 0 - for shard in self._state_shard if optimizer else self._state_shard[:1]: - shard_split = shard.split(self._stage_shard_sizes, 0) - for stage, stage_shard in zip(self._stages_on_device.values(), shard_split): - counter += stage.reset_shard_pad(stage_shard) - return counter - - class _LoadContext: + for stage in self._stages: + stage.initialize_weights() + for name, tied_parameter in self._tied_parameters.items(): + if tied_parameter.group is not None: + broadcast(self._stages[tied_parameter.main_stage].weight_shard, 0, tied_parameter.group) + self._finalize_load(reset_optimizer=True) + + def _finalize_load(self, reset_optimizer: bool = True): + if reset_optimizer: + triton_fill(self._state_shard[1:], 0.0) + if self._mode.support_forward: + self.invalidate_buffers() + self._is_loaded = True + + class _SafeLoadContext: # TODO: Improve - def __init__(self, model: "FastLLMModel", *, safe: bool, load_optimizer: bool, reset_pads: bool): - assert model._is_setup - self.multi_stage = model - self.safe = safe - self.load_optimizer = load_optimizer - self.num_shards = len(self.multi_stage._state_shard_names) if self.load_optimizer else 1 - self.self_shard = self.multi_stage._state_shard[: self.num_shards] - self.reset_pads = reset_pads - self.shard_names = self.multi_stage._state_shard_names[: self.num_shards] + def __init__(self, model: "FastLLMModel", *, num_shards: int): + self._model = model + self._num_shards = num_shards + self._self_shard = self._model._state_shard[: self._num_shards] def __enter__(self): - if self.safe: - self.loaded = 0 - self.loaded_parameters = {} - # Track the number of loaded entries. - # Use nan to mark non-loaded entries. - triton_fill(self.self_shard, math.nan) - if self.reset_pads: - self.loaded += self.multi_stage._reset_shard_pads(self.load_optimizer) + self._loaded = 0 + self._loaded_parameters = {} + # Track the number of loaded entries. + # Use nan to mark non-loaded entries. + triton_fill(self._self_shard, math.nan) + # Reset and count shard pads + for shard in self._model._state_shard[: self._num_shards]: + shard_split = shard.split(self._model._stage_shard_sizes, 0) + for stage, stage_shard in zip(self._model._stages_on_device.values(), shard_split): + self._loaded += stage.reset_shard_pad(stage_shard) return self def __exit__(self, exc_type, exc_val, exc_tb): if not exc_type: - if self.safe: - self._validate() - if not self.load_optimizer: - triton_fill(self.multi_stage._state_shard[1:], 0.0) - if self.multi_stage._mode.support_forward: - self.multi_stage.invalidate_buffers() - self.multi_stage._is_loaded = True + self._validate() + + def mark_as_loaded(self, count: int, parameter: tuple[str, str] | None = None): + self._loaded += count + if parameter is not None: + parameter_name, shard_name = parameter + if shard_name not in self._loaded_parameters: + self._loaded_parameters[shard_name] = {} + Assert.not_incl(parameter_name, self._loaded_parameters[shard_name]) + self._loaded_parameters[shard_name][parameter_name] = count def _validate(self): errors = [] self._check_counter(errors) self._check_missing(errors) - if self.loaded_parameters: + if self._loaded_parameters: self._check_parameters(errors) if errors: for error in errors: logger.error(error) raise RuntimeError("Model loading validation failed. See logs for details.") - logger.info(f"{self.loaded:,} state entries loaded successfully") + logger.info(f"{self._loaded:,} state entries loaded successfully") def _check_counter(self, errors: list[str]): - to_load = self.self_shard.numel() - if self.loaded != to_load: + to_load = self._self_shard.numel() + if self._loaded != to_load: # Ensure the right amount of weights is loaded. - errors.append(f"Loaded a total of {self.loaded:,}, state entries, expected {to_load:,}") + errors.append(f"Loaded a total of {self._loaded:,}, state entries, expected {to_load:,}") def _check_missing(self, errors: list[str]): # Ensure the loaded weights have a 1-1 mapping by looking for nans. - missing = self.self_shard.new_zeros([], dtype=torch.int64) + missing = self._self_shard.new_zeros([], dtype=torch.int64) # Count nans in slices of 100M parameters to limit memory usage. # TODO: Find better solution (triton kernel?) - for shard_slice in self.self_shard.flatten().split(100000000): + for shard_slice in self._self_shard.flatten().split(100000000): missing += shard_slice.isnan().sum() local_missing = missing.item() - if self.multi_stage._distributed.world_group is not None: - all_reduce(missing, group=self.multi_stage._distributed.world_group) + if self._model._distributed.world_group is not None: + all_reduce(missing, group=self._model._distributed.world_group) global_missing = missing.item() if global_missing: errors.append( @@ -315,9 +382,9 @@ def _check_missing(self, errors: list[str]): ) # Determine where the missing values are coming from. global_total, local_total = 0, 0 - for shard_name, shard_ in zip(self.shard_names, self.self_shard): - shard_split = shard_.split(self.multi_stage._stage_shard_sizes, 0) - for stage, shard in zip(self.multi_stage._stages_on_device.values(), shard_split): + for shard_name, shard_ in zip(self._model._state_shard_names[: self._num_shards], self._self_shard): + shard_split = shard_.split(self._model._stage_shard_sizes, 0) + for stage, shard in zip(self._model._stages_on_device.values(), shard_split): buffer = stage._reconstruct_from_shard(shard) for i, parameter in enumerate(stage._split_buffer(buffer)): missing_for_param = parameter.isnan().sum().item() @@ -351,26 +418,26 @@ def _check_missing(self, errors: list[str]): ) def _check_parameters(self, errors: list[str]): - loaded_shard_names = set(self.loaded_parameters) - shard_names = set(self.shard_names) + loaded_shard_names = set(self._loaded_parameters) + shard_names = set(self._model._state_shard_names[: self._num_shards]) if loaded_shard_names != shard_names: errors.append(f"Incorrect loaded shards: {loaded_shard_names}!={shard_names}") for shard_name in shard_names & loaded_shard_names: counter_per_parameter = { - parameter_name: self.loaded_parameters[shard_name].pop(parameter_name, None) - for parameter_name in self.multi_stage._parameter_stages + parameter_name: self._loaded_parameters[shard_name].pop(parameter_name, None) + for parameter_name in self._model._parameter_stages } - for parameter_name, count in self.loaded_parameters[shard_name].items(): + for parameter_name, count in self._loaded_parameters[shard_name].items(): errors.append( f'Loaded unknown parameter "{parameter_name}" for shard "{shard_name}" (count={count})' ) for parameter_name, counter in counter_per_parameter.items(): - if self.multi_stage._parameter_stages[parameter_name] in self.multi_stage._stages_on_device: + if self._model._parameter_stages[parameter_name] in self._model._stages_on_device: if counter is None: errors.append(f'Missing parameter "{parameter_name}" for shard "{shard_name}"') elif counter is not None and counter > 0: errors.append(f'Loaded off-device parameter : "{parameter_name}" for shard "{shard_name}"') - distributed = self.multi_stage._distributed + distributed = self._model._distributed if distributed.world_group is not None: counter_tensor = torch.tensor( [counter or 0 for counter in counter_per_parameter.values()], dtype=torch.int64 @@ -382,7 +449,7 @@ def _check_parameters(self, errors: list[str]): } for parameter_name, counter in counter_per_parameter.items(): parameter_size = ( - self.multi_stage._stages[self.multi_stage._parameter_stages[parameter_name]] + self._model._stages[self._model._parameter_stages[parameter_name]] .get_parameter_meta(parameter_name) .global_shape.numel() ) @@ -390,119 +457,3 @@ def _check_parameters(self, errors: list[str]): errors.append( f'Global counter mismatch for parameter "{parameter_name}" and shard "{shard_name}": {counter} != {parameter_size}' ) - - def load_state_unsafe(self, state: torch.Tensor, loaded_model: "MultiStageModel"): - """ - Load a state from a checkpoint saved in another distributed configuration. - """ - self_stage: Stage - loaded_stage: Stage - loaded_model._state_shard_meta.validate(state) - multi_stage = self.multi_stage - - # TODO: Improve num shard selection. - self_shard_split = multi_stage._state_shard[: state.size(0)].split(multi_stage._stage_shard_sizes, 1) - loaded_shard_split = state.split(loaded_model._stage_shard_sizes, 1) - - counter = torch.zeros(1, dtype=torch.int64, device=multi_stage._distributed.device) - for loaded_shard_index, loaded_stage in enumerate(loaded_model._stages_on_device.values()): - loaded_shards = loaded_shard_split[loaded_shard_index].to(multi_stage._distributed.device).unbind(0) - for self_shard_index, self_stage in enumerate(multi_stage._stages_on_device.values()): - self_stage._copy_shard_overlaps( # noqa - loaded_stage, - self_shard_split[self_shard_index].unbind(0), - loaded_shards, - counter, - ) - self.loaded += counter.item() - - def import_state_tensor(self, shard_name: str, parameter_name: str, tensor: torch.Tensor | SafeTensorSlice): - multi_stage = self.multi_stage - stage_index = multi_stage._parameter_stages[parameter_name] - if stage_index not in multi_stage._stage_shard_indices: - # Tensor is not on this device. - return 0 - stage_shard = multi_stage._state_shard[multi_stage._state_shard_names.index(shard_name)].split( - multi_stage._stage_shard_sizes, 0 - )[multi_stage._stage_shard_indices[stage_index]] - loaded = multi_stage._stages[stage_index]._import_state_tensor(stage_shard, parameter_name, tensor) # noqa - self.loaded += loaded - if shard_name not in self.loaded_parameters: - self.loaded_parameters[shard_name] = {} - self.loaded_parameters[shard_name][parameter_name] = loaded - - def _load_distributed_checkpoint(self, pretrained_config: CheckpointLoadConfig): - # TODO: More safety checks - metadata = self.config_class.load_pretrained_metadata(pretrained_config) - loaded_pretrained_config = pretrained_config.to_copy({"load_config": ModelConfigType.fast_llm}) - loaded_config = self.config_class.from_metadata( - loaded_pretrained_config, - metadata, - ) - with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True - ) as context: - Assert.eq(metadata["state_shard_names"][: context.num_shards], list(context.shard_names)) - - for rank in range(loaded_config.distributed.world_size): - loaded_multi_stage = self.__class__( - loaded_config.to_copy({("distributed", "rank"): rank}), - optimizer_state_names=context.shard_names[1:], - verbose=False, - ) - path = pretrained_config.path / f"rank_{rank}.safetensors" - logger.info(f"Loading from {path}") - # TODO: skip shards without overlap. - with safetensors.safe_open(path, framework="pt", device=str(self._distributed.device)) as f: - # TODO: Use self_shard - context.load_state_unsafe(f.get_slice("state_shard")[: context.num_shards], loaded_multi_stage) - - return metadata["metadata"] - - def _load_state_dict_checkpoint(self, pretrained_config: CheckpointLoadConfig): - # TODO: Make into a special case of _import_state_tensor? - # TODO: Verify more distributed configs. - # TODO: More safety checks - with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True - ) as context: - index_path = pretrained_config.path / f"state_dict.safetensors.index.json" - logger.info(f"Loading index from {index_path}") - file_names = set(json.load(index_path.open("r"))["weight_map"].values()) - for file_name in file_names: - logger.info(f"Loading from {pretrained_config.path / file_name}") - with safetensors.safe_open( - pretrained_config.path / file_name, - framework="pt", - device=str(self._distributed.device), - ) as f: - metadata = _import_safetensors_metadata(f.metadata()) - Assert.eq(metadata["state_shard_names"][: context.num_shards], list(context.shard_names)) - for key in f.keys(): - parameter_name, shard_name = key.split("/", 1) - if shard_name in context.shard_names: - context.import_state_tensor(shard_name, parameter_name, f.get_slice(key)) - - return metadata["metadata"] - - def _import_checkpoint(self, pretrained_config: CheckpointLoadConfig): - # TODO: Support optimizer? - assert not pretrained_config.optimizer_state - # TODO: Verify more distributed configs. - # TODO: Safety checks - - converter_class = self.base_model.architecture_cls().get_converter_class(pretrained_config.model_type) - converter = converter_class.from_config(converter_class.load_config(pretrained_config.path)) - self._base_model_config.compare_architecture(converter.config, pretrained_config.compare_log_fn) - - state_dict = {} - with self._LoadContext( - self, safe=True, load_optimizer=pretrained_config.optimizer_state, reset_pads=True - ) as context: - for name, tensor in converter.load_weights(pretrained_config.path, self._distributed.device): - assert name not in state_dict - state_dict[(name, "weights")] = tensor - for parameter_name, fast_llm_tensor in converter.convert_state_dict(state_dict, False).items(): - context.import_state_tensor("weights", parameter_name, fast_llm_tensor) - - assert not state_dict, list(state_dict) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 6ffedc2e0..a1a50b960 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -7,6 +7,7 @@ from torch._C._distributed_c10d import ProcessGroup from fast_llm.engine.base_model.base_model import BaseModel +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames @@ -122,7 +123,8 @@ def __init__( stage_shard_dtype = get_unique([stage.weight_shard_meta.dtype for stage in self._stages_on_device.values()]) self._state_shard_names = ("weights",) + optimizer_state_names - self._num_shards = len(self._state_shard_names) + 1 + self._num_state_shards = len(self._state_shard_names) + self._num_shards = self._num_state_shards + 1 shard_dim = TensorDim("flat_shard", sum(self._stage_shard_sizes)) self._weight_shard_meta = TensorMeta.from_dims( @@ -131,7 +133,7 @@ def __init__( dtype=stage_shard_dtype, ) self._state_shard_meta = TensorMeta.from_dims( - (TensorDim("state_shards", self._num_shards - 1), shard_dim), + (TensorDim("state_shards", self._num_state_shards), shard_dim), tensor_name=f"multi_stage_state_shard", dtype=stage_shard_dtype, ) @@ -358,6 +360,13 @@ def train(self, mode: bool = True): stage.train(mode) self._training = mode + def get_state_tensor_iterator(self, shard_names: list[str], data_type: DataType | None = None): + for i, shard_name in enumerate(shard_names): + shard_split = self._state_shard[i].split(self._stage_shard_sizes, 0) + for stage, shard in zip(self._stages_on_device.values(), shard_split): + for name, tensor in stage._export_shard(shard, data_type=data_type): # noqa + yield name, shard_name, tensor + def _split_into_stages(self): # Create stages (greedy split, could do better). stage_splits = [0] diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 437aa7b08..2111cb66e 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -283,6 +283,9 @@ def initialize_weights(self): elif self._mode.on_device: meta.init_parameter(parameter, self._distributed) + if self.mode.on_device: + self.reset_shard_pad(self._weight_shard) + if self._config.debug_param_init: log_generator("CPU generator after reset", torch.random.default_generator) log_generator("PP init generator after reset", self._distributed.pp_init_generator) @@ -449,10 +452,10 @@ def _import_state_tensor(self, shard: torch.Tensor, parameter_name: str, tensor: shard[begin:end].copy_(tensor_shard) return end - begin - def _export_shard(self, shard: torch.Tensor, dtype: DataType | None = None): - if dtype is not None: - shard = shard.to(dtype=dtype.torch) - tensors = self._split_buffer(self._reconstruct_from_shard(shard.to(dtype=dtype))) + def _export_shard(self, shard: torch.Tensor, data_type: DataType | None = None): + if data_type is not None: + shard = shard.to(dtype=data_type.torch) + tensors = self._split_buffer(self._reconstruct_from_shard(shard.to(dtype=data_type))) for name, param_index in self._parameter_index.items(): yield name, self._parameter_metas[param_index].local_to_global( tensors[param_index], distributed=self._distributed diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 655f29677..ec454ac96 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -10,6 +10,7 @@ from fast_llm.engine.config_utils.checkpoint import ( CheckpointConfigBase, CheckpointFormat, + CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveConfigBase, CheckpointStateConfigBase, @@ -218,6 +219,14 @@ def get_save_config(self, path: pathlib.Path): optimizer_state=True, ) + def get_load_config(self, path: pathlib.Path): + return CheckpointLoadConfig( + path=path, + format=CheckpointFormat.distributed, + model_weights=True, + optimizer_state=True, + ) + @config_class() class ExportConfig(CheckpointBaseConfig, CheckpointConfigBase, CheckpointStateConfigBase, CheckpointSaveConfigBase): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index ff8a5f029..e9f94e4aa 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -19,7 +19,7 @@ from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.runner import ScheduleRunner from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.engine.training.config import CheckpointBaseConfig, TrainerConfig +from fast_llm.engine.training.config import CheckpointBaseConfig, CheckpointConfig, TrainerConfig from fast_llm.engine.training.wandb import Wandb from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage from fast_llm.utils import Assert @@ -383,7 +383,7 @@ def _prepare_training_state(self): f" ({'loading' if self._config.pretrained.optimizer_state else 'resetting'}" f" optimizer state)..." ) - self._multi_stage.load_pretrained_checkpoint(self._config.pretrained) + self._multi_stage.load_checkpoint(self._config.pretrained) else: log_main_rank(f"Initializing training state from scratch...") self._multi_stage.initialize_weights() @@ -435,12 +435,13 @@ def _save_checkpoint(self, config: CheckpointBaseConfig, metrics: dict[PhaseType config.callback.run() - def _load_checkpoint(self, config: CheckpointBaseConfig, iteration: int): + def _load_checkpoint(self, config: CheckpointConfig, iteration: int): checkpoint_directory = self._run.experiment_directory / config.directory_name / str(iteration) Assert.custom(pathlib.Path.is_file, checkpoint_directory / "ok") # TODO v0.2: Use config.get_load_config to make it generic # TODO v0.2: Detect format instead of hard-coding - metadata = self._multi_stage.load_distributed_checkpoint_same_format(checkpoint_directory) + + metadata = self._multi_stage.load_checkpoint(config.get_load_config(checkpoint_directory)) self._optimizer.load(metadata["optimizer"]) if "schedules" in metadata: # Backward compatibility. From c231a02727e14bf8d6ce246b94142620b628c8e9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Oct 2024 11:31:54 -0400 Subject: [PATCH 2/3] fixes --- fast_llm/engine/multi_stage/conversion.py | 4 ++-- fast_llm/engine/multi_stage/fast_llm_model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/multi_stage/conversion.py b/fast_llm/engine/multi_stage/conversion.py index becd4462d..dbfea9290 100644 --- a/fast_llm/engine/multi_stage/conversion.py +++ b/fast_llm/engine/multi_stage/conversion.py @@ -207,7 +207,7 @@ def load_weights( device=str(device), ) as f: metadata = _import_safetensors_metadata(f.metadata()) - Assert.eq(metadata["state_shard_names"][: len(shard_names)], shard_names) + Assert.eq(metadata["state_shard_names"][: len(shard_names)], list(shard_names)) for key in f.keys(): parameter_name, shard_name = key.split("/", 1) if shard_name in shard_names: @@ -388,7 +388,7 @@ def load_weights( ) -> typing.Iterator[tuple[str, str, torch.Tensor | SafeTensorSlice]]: import transformers - Assert.eq(shard_names, ["weights"]) + Assert.eq(shard_names, ("weights",)) if (directory / transformers.utils.SAFE_WEIGHTS_NAME).is_file(): paths = {directory / transformers.utils.SAFE_WEIGHTS_NAME} elif (directory / transformers.utils.SAFE_WEIGHTS_INDEX_NAME).is_file(): diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 77c30f196..152f91ecd 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -188,7 +188,7 @@ def _load_state_dict(self, config: CheckpointLoadConfig, converter: ModelConvert stage_shard = self._state_shard[self._state_shard_names.index(shard_name)].split( self._stage_shard_sizes, 0 )[self._stage_shard_indices[stage_index]] - loaded = multi_stage._stages[stage_index]._import_state_tensor( + loaded = self._stages[stage_index]._import_state_tensor( stage_shard, parameter_name, fast_llm_tensor ) # noqa context.mark_as_loaded(loaded, (parameter_name, shard_name)) From 3239793eb66af38777bc1748b85c83539a8a4eb7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 23 Oct 2024 11:35:42 -0400 Subject: [PATCH 3/3] cleanup --- fast_llm/engine/multi_stage/conversion.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/fast_llm/engine/multi_stage/conversion.py b/fast_llm/engine/multi_stage/conversion.py index dbfea9290..57fea315e 100644 --- a/fast_llm/engine/multi_stage/conversion.py +++ b/fast_llm/engine/multi_stage/conversion.py @@ -146,11 +146,6 @@ class ModelConverter(abc.ABC): def get_key(cls, parameter_name: str, shard_name: str) -> str: pass - @classmethod - @abc.abstractmethod - def load_key(cls, key: str) -> tuple[str, str]: - pass - @abc.abstractmethod def convert_state_dict( self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool @@ -178,11 +173,6 @@ class TrivialConverter(ModelConverter): def get_key(cls, parameter_name: str, shard_name: str) -> str: return f"{parameter_name}/{shard_name}" - @classmethod - def load_key(cls, key: str) -> tuple[str, str]: - parameter_name, shard_name = key.split("/", 1) - return parameter_name, shard_name - def convert_state_dict( self, state_dict: dict[str, torch.Tensor | SafeTensorSlice], export: bool ) -> dict[str, torch.Tensor | SafeTensorSlice]: @@ -356,10 +346,6 @@ def get_key(cls, parameter_name: str, shard_name: str) -> str: Assert.eq(shard_name, "weights") return parameter_name - @classmethod - def load_key(cls, key: str) -> tuple[str, str]: - return key, "weights" - @classmethod @abc.abstractmethod def _create_config_converters(cls) -> list[ParamConverter]: