diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f3633a76a..065eb94d8 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -187,7 +187,7 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) + self._load_yaml_data(loaded_yaml_data) if not self._truncate_documents: del loaded_yaml_data["unshuffled_tokens"] diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 1905fca75..2be1e4879 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -6,13 +6,16 @@ import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelConfig, Preprocessor +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.engine.inference.runner import InferenceRunner + class Module(torch.nn.Module, abc.ABC): """ """ @@ -80,6 +83,7 @@ def get_layers(self) -> list[Layer]: class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig + _is_setup: bool = False def __init__( self, @@ -96,12 +100,18 @@ def __init__( # Rename to the parameter full name value.tensor_name = key - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass + # Reference models + # TODO: Add basic handling (preprocessor) in this class. + self._reference_models: dict[str, "InferenceRunner"] = {} - @abc.abstractmethod def setup(self, distributed: Distributed) -> None: + assert not self._is_setup + distributed.check_config(self._tensor_space.distributed_config) + self._tensor_space.setup(distributed) + self._is_setup = True + + @abc.abstractmethod + def get_layers(self) -> list[Layer]: pass @abc.abstractmethod @@ -132,6 +142,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: def loss_defs(self) -> list[LossDef]: pass - def add_preprocessor(self, preprocessor: Preprocessor): - # TODO: Generalize preprocessors. - raise NotImplementedError() + def add_reference_model(self, name: str, inference_runner: "InferenceRunner") -> None: + assert name not in self._reference_models + assert not self._is_setup + self._reference_models[name] = inference_runner diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 0384fdacd..5020bc650 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -27,6 +27,9 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + @property def name(self) -> str: return self._name diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 69bf3695a..e2d04f80f 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -30,7 +30,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.model import HuggingfacePreTrainedModel + from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 8b4cadc31..1e990e9c8 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -23,7 +23,6 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -386,7 +385,7 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): - _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) + self._add_reference_distributed_to_pretrained(reference_model) super()._validate() if self.reference_models: # TODO: Add support. @@ -396,6 +395,8 @@ def _validate(self) -> None: Assert.eq(self.model.distributed.sequence_data_parallel, 1) if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() + for reference_model in self.reference_models.values(): + assert reference_model.model.distributed.reference_config is self.model.distributed def _setup(self): super()._setup() @@ -423,18 +424,17 @@ def runnable(): return runnable + def _add_reference_distributed_to_pretrained(self, pretrained: PretrainedFastLLMModelConfig): + old_setup = pretrained._setup -def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelConfig, distributed: DistributedConfig): - old_setup = pretrained._setup - - def new_setup(): - # Make sure the distributed config isn't set - pretrained.model.distributed.validate() - Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) - with NoAutoValidate(): - pretrained.model.distributed = distributed.to_copy() - # Allow sharing the `Distributed` instance. - pretrained.model.distributed.reference_config = distributed - old_setup() + def new_setup(): + # Make sure the distributed config isn't set + pretrained.model.distributed.validate() + Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) + with NoAutoValidate(): + pretrained.model.distributed = self.model.distributed.to_copy() + # Allow sharing the `Distributed` instance. + pretrained.model.distributed.reference_config = self.model.distributed + old_setup() - object.__setattr__(pretrained, "_setup", new_setup) + object.__setattr__(pretrained, "_setup", new_setup) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 66f1ad869..abd8f9dc0 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -12,11 +12,9 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer @@ -55,9 +53,7 @@ def __init__(self, config: TrainerConfig): self._reference_models[name] = self._config.get_inference_runner_class()( reference_config.model.get_model_class()(reference_config.model) ) - self._multi_stage.base_model.add_preprocessor( - self._get_reference_model_preprocessor(name, self._reference_models[name]) - ) + self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) phase: PhaseType self._runner = ScheduleRunner( @@ -562,6 +558,3 @@ def _get_last_checkpoint(self) -> int | None: def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: # TODO: Do in model, automate/generalize, get other stats pass - - def _get_reference_model_preprocessor(self, name: str, inference_runner: InferenceRunner) -> Preprocessor: - raise NotImplementedError() diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 64893a3e1..e5afac160 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -188,29 +188,6 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) - # TODO: Support distinct preprocessing - reference_model.model.base_model.transformer.rotary.compare( - self.model.base_model.transformer.rotary, - NotImplementedError, - ) - Assert.eq( - reference_model.model.base_model.use_absolute_position_embeddings, - self.model.base_model.use_absolute_position_embeddings, - ) - if reference_model.model.base_model.use_absolute_position_embeddings: - assert self.model.base_model.use_absolute_position_embeddings - Assert.geq( - reference_model.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length - ) - use_flash = reference_model.model.base_model.transformer.do_use_flash_attention( - reference_model.model.distributed - ) - Assert.eq(use_flash, self.model.base_model.transformer.do_use_flash_attention(self.model.distributed)) - if use_flash: - Assert.eq( - reference_model.model.base_model.transformer.window_size, - self.model.base_model.transformer.window_size, - ) @classmethod def _from_dict( diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index c428bc011..bd733692e 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -407,7 +407,11 @@ def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]: (export_value,) = export_values - if export_value is None or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default": + if ( + export_value is None + or export_value is MISSING + or (rope_type := export_value[self._HUGGINGFACE_NAMES[0]]) == "default" + ): return (RotaryEmbeddingType.default,) + (DEFAULT,) * 7 elif rope_type == RotaryEmbeddingType.llama3: return ("llama3", *[export_value.get(key, DEFAULT) for key in self._HUGGINGFACE_NAMES[1:]]) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index f0aaf90b0..80c9caa23 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -8,7 +8,6 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType -from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames @@ -41,7 +40,6 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): """ config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - _is_setup: bool = False _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _mask: torch.Tensor @@ -59,6 +57,7 @@ def __init__( for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) @@ -103,18 +102,15 @@ def get_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, layer_index=i + 1, + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) for i in range(self._config.transformer.num_layers) ], *self.get_output_layers(), ] - def setup(self, distributed: Distributed) -> None: - assert not self._is_setup - distributed.check_config(self._tensor_space.distributed_config) - self._tensor_space.setup(distributed) - self._is_setup = True - def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -128,7 +124,7 @@ def preprocess_meta( else: micro_batch_size, sequence_length = batch_meta.shape if phase != PhaseType.inference: - sequence_length -= 1 + sequence_length -= self._config.prediction_heads micro_sequence_length = sequence_length batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) @@ -182,12 +178,20 @@ def preprocess_meta( TransformerKwargs.sequence_q_dim: sequence_q_dim, } - preprocessed_meta = [] - for sequence_k_past in range( + sequence_k_pasts = range( sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, sequence_length, micro_sequence_length, - ): + ) + reference_preprocessed_metas = {} + for name, reference_model in self._reference_models.items(): + reference_preprocessed_metas[name] = reference_model.fast_llm_model.base_model.preprocess_meta( + batch_meta, PhaseType.inference + ) + Assert.eq(len(reference_preprocessed_metas[name]), len(sequence_k_pasts)) + + preprocessed_meta = [] + for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) @@ -205,6 +209,19 @@ def preprocess_meta( ) for preprocessor in self._preprocessors: preprocessor.preprocess_meta(kwargs) + reference_kwargs = {} + for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): + reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] + for key in ( + TransformerKwargs.sequence_first, + TransformerKwargs.sequence_length, + TransformerKwargs.sequence_q_dim, + TransformerKwargs.sequence_k_dim, + ): + Assert.eq(reference_kwargs_[key], kwargs[key]) + reference_kwargs[name] = reference_kwargs_ + kwargs["reference_models"] = reference_kwargs + preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -234,13 +251,30 @@ def preprocess( dtype=torch.int64, non_blocking=True, ) + + reference_logits = [{} for _ in preprocessed_meta] + for name, reference_model in self._reference_models.items(): + reference_preprocessed_meta = [ + (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta + ] + + reference_batch = reference_model.fast_llm_model.base_model.preprocess( + batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration + ) + + # TODO: Do things work with >1? + Assert.eq(len(reference_batch), len(preprocessed_meta), 1) + for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): + reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) + reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] + if sequence_first: # Move the sequence dimension first to make sequence parallel ops more efficient. batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() preprocessed = [] presents = None - for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta): + for i, (_, kwargs_meta) in enumerate(preprocessed_meta): sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] @@ -285,6 +319,8 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels + kwargs.update(reference_logits[i]) + for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) @@ -360,10 +396,6 @@ def loss_defs(self) -> list[LossDef]: ) return loss_defs - def add_preprocessor(self, preprocessor: Preprocessor): - assert not self._is_setup - self._preprocessors.append(preprocessor) - class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index a1c0c8bb7..3bdb05c3a 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -3,33 +3,13 @@ from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.gpt.config import GPTSamplingParameters -from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTTrainerConfig -from fast_llm.models.gpt.model import GPTInferenceRunner logger = logging.getLogger(__name__) -class GPTReferenceModelPreprocessor(Preprocessor): - def __init__(self, name: str, inference_runner: GPTInferenceRunner): - self._name = name - self._inference_runner = inference_runner - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - pass - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - # TODO: Fix random state/iteration. - preprocess_kwargs = kwargs.copy() - del preprocess_kwargs[LanguageModelKwargs.labels] - self._inference_runner.forward(batch, preprocess_kwargs, iteration=1) - # TODO: Improve. - kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"] - - class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig @@ -109,6 +89,3 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, hardware_flops = flops_per_iteration + 7 / 6 * attn_flops ratio = elapsed_time_per_iteration * self._config.model.distributed.world_size * 1e12 return model_tflops / ratio, hardware_flops / ratio - - def _get_reference_model_preprocessor(self, name: str, inference_runner: GPTInferenceRunner) -> Preprocessor: - return GPTReferenceModelPreprocessor(name, inference_runner)