From dbb6162d13c49f9233ddce25576d348b8c5fea5e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 1 Apr 2025 00:52:49 -0400 Subject: [PATCH 1/9] Reference model prototype --- fast_llm/engine/base_model/base_model.py | 6 +- fast_llm/engine/base_model/config.py | 11 ++++ .../{huggingface => inference}/__init__.py | 0 .../{huggingface => inference}/config.py | 0 .../model.py => inference/huggingface.py} | 40 +++---------- fast_llm/engine/inference/runner.py | 58 +++++++++++++++++++ fast_llm/engine/multi_stage/config.py | 2 +- fast_llm/engine/multi_stage/multi_stage.py | 1 + fast_llm/engine/training/config.py | 10 ++++ fast_llm/engine/training/trainer.py | 25 ++++++-- .../layers/language_model/preprocessing.py | 8 ++- fast_llm/layers/transformer/preprocessing.py | 19 +++--- fast_llm/models/custom/trainer.py | 2 - fast_llm/models/gpt/huggingface.py | 12 ++-- fast_llm/models/gpt/model.py | 49 +++++++--------- fast_llm/models/gpt/trainer.py | 23 +++++++- 16 files changed, 180 insertions(+), 86 deletions(-) rename fast_llm/engine/{huggingface => inference}/__init__.py (100%) rename fast_llm/engine/{huggingface => inference}/config.py (100%) rename fast_llm/engine/{huggingface/model.py => inference/huggingface.py} (58%) create mode 100644 fast_llm/engine/inference/runner.py diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 7233c1838..76da0f9b0 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -6,7 +6,7 @@ import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig, Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: @abc.abstractmethod def loss_defs(self) -> list[LossDef]: pass + + def add_preprocessor(self, preprocessor: Preprocessor): + # TODO: Generalize preprocessors. + raise NotImplementedError() diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 409c2c83b..891fdbb0a 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -1,3 +1,4 @@ +import abc import typing from fast_llm.config import Config, config_class @@ -40,3 +41,13 @@ class BaseModelConfig(BaseModelArchitectureConfig): def get_architecture(self) -> BaseModelArchitectureConfig: return self.architecture_class.from_dict(self, strict=False) + + +class Preprocessor(abc.ABC): + @abc.abstractmethod + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + pass + + @abc.abstractmethod + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + pass diff --git a/fast_llm/engine/huggingface/__init__.py b/fast_llm/engine/inference/__init__.py similarity index 100% rename from fast_llm/engine/huggingface/__init__.py rename to fast_llm/engine/inference/__init__.py diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/inference/config.py similarity index 100% rename from fast_llm/engine/huggingface/config.py rename to fast_llm/engine/inference/config.py diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/inference/huggingface.py similarity index 58% rename from fast_llm/engine/huggingface/model.py rename to fast_llm/engine/inference/huggingface.py index 499f0af12..9a213691d 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/inference/huggingface.py @@ -4,20 +4,16 @@ import transformers.modeling_outputs -from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import HuggingfaceModelConfig +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule class HuggingfacePreTrainedModel(transformers.PreTrainedModel): config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig - model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel + runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner config: HuggingfaceModelConfig # base_model_prefix = "" # _no_split_modules = None @@ -25,36 +21,18 @@ class HuggingfacePreTrainedModel(transformers.PreTrainedModel): # _tied_weights_keys = [] def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs): - assert self.model_class.config_class is config.model_config_class + assert self.runner_class.model_class.config_class is config.model_config_class assert config.fast_llm_config is fast_llm_model.config assert isinstance(config, self.config_class) + super().__init__(config, **kwargs) - self._fast_llm_config = config.fast_llm_config - self._fast_llm_model = fast_llm_model + + self._inference_runner = self.runner_class(fast_llm_model) # Transformers needs to be able to inspect the base model. self.fast_llm_base_model = self._fast_llm_model.base_model - self._distributed_config = self._fast_llm_config.distributed # TODO: Support distributed models? - assert self._distributed_config.world_size == 1 - self._schedule_config = ScheduleConfig() - # We only need a basic schedule and don't care about dimensions. - # TODO: Sort things out. - with NoAutoValidate(): - self._batch_config = BatchConfig() - self._batch_config.setup(self._distributed_config) - self._batch_config.validate() - self._runner = ScheduleRunner( - config=self._schedule_config, multi_stage=self._fast_llm_model, distributed_config=self._distributed_config - ) - self._runner.setup(self._fast_llm_model.distributed) - # TODO: Random state? (Distributed.set_step) - self._schedule = Schedule( - multi_stage=self._fast_llm_model, - batch_config=self._batch_config, - schedule_config=self._schedule_config, - distributed_config=self._distributed_config, - phase=PhaseType.inference, - ) + assert fast_llm_model.config.distributed.world_size == 1 + with transformers.modeling_utils.no_init_weights(): self.post_init() diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py new file mode 100644 index 000000000..89426616b --- /dev/null +++ b/fast_llm/engine/inference/runner.py @@ -0,0 +1,58 @@ +import abc +import typing + +from fast_llm.config import NoAutoValidate +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule + + +class InferenceRunner(abc.ABC): + model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel + + def __init__(self, fast_llm_model: FastLLMModel): + assert isinstance(fast_llm_model, self.model_class) + self._fast_llm_model = fast_llm_model + # We only need a basic schedule and don't care about dimensions. + self._schedule_config = ScheduleConfig() + # TODO: Sort things out. + with NoAutoValidate(): + self._batch_config = BatchConfig() + self._batch_config.setup(self._fast_llm_model.config.distributed) + self._batch_config.validate() + self._runner = ScheduleRunner( + config=self._schedule_config, + multi_stage=self._fast_llm_model, + distributed_config=self._fast_llm_model.config.distributed, + ) + # TODO: Random state? (Distributed.set_step) + self._schedule = Schedule( + multi_stage=self._fast_llm_model, + batch_config=self._batch_config, + schedule_config=self._schedule_config, + distributed_config=self._fast_llm_model.config.distributed, + phase=PhaseType.inference, + ) + + @property + def fast_llm_model(self) -> FastLLMModel: + return self._fast_llm_model + + def setup(self): + self._runner.setup(self._fast_llm_model.distributed) + + def forward( + self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False + ) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]: + # TODO: Return an actual model output. + reduced_losses, update_successful, metrics = self._runner.run_step( + iter(((input_, kwargs),)), + self._schedule, + iteration=iteration, + return_metrics=return_metrics, + preprocessed=True, + ) + assert update_successful + return reduced_losses, metrics diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index fd7ba6453..e6de074f4 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -29,7 +29,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel + from fast_llm.engine.inference.model import HuggingfacePreTrainedModel from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 238bd865e..53e9c0846 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -442,6 +442,7 @@ def is_parameter_on_device(self, parameter_name: str) -> bool: @property def distributed(self) -> Distributed: + assert self._is_setup return self._distributed def invalidate_buffers(self) -> None: diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index dac4e5533..fcc03a484 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -21,6 +21,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.training.trainer import Trainer @@ -364,6 +365,11 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) + reference_models: dict[str, PretrainedFastLLMModelConfig] = Field( + default_factory=dict, + desc="Additional models used during training, ex. for knowledge distillation.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.training.export.setup(self.model) @@ -379,6 +385,10 @@ def _setup(self): def get_trainer_class(cls) -> type["Trainer"]: raise NotImplementedError + @classmethod + def get_inference_runner_class(cls) -> type["InferenceRunner"]: + raise NotImplementedError + def _get_runnable(self) -> typing.Callable[[], None]: from fast_llm.engine.distributed.distributed import Distributed diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index f2ed4a38f..1e71794ef 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -11,10 +11,12 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.inference.runner import InferenceRunner +from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.runner import ScheduleRunner @@ -29,7 +31,6 @@ class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): config_class: typing.ClassVar[type[TrainerConfig]] = TrainerConfig - model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel # TODO: Generalize data, schedule, logging, etc. _is_setup: bool = False _distributed: Distributed @@ -43,10 +44,21 @@ def __init__(self, config: TrainerConfig): super().__init__(config) self._data = self._get_data() log_main_rank("Creating model...") - self._multi_stage = self.model_class( + self._multi_stage = self._config.model.get_model_class()( self._config.model, optimizer_state_names=self._config.optimizer.state_names(), ) + self._reference_models = { + name: self._config.get_inference_runner_class()( + reference_config.model.get_model_class()(reference_config.model) + ) + for name, reference_config in self._config.reference_models.items() + } + for name, inference_runner in self._reference_models.items(): + self._multi_stage.base_model.add_preprocessor( + self._get_reference_model_preprocessor(name, inference_runner) + ) + phase: PhaseType self._runner = ScheduleRunner( config=self._config.schedule, @@ -103,6 +115,9 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the model. with torch.no_grad(): self._multi_stage.setup(distributed) + for reference_model in self._reference_models.values(): + reference_model.fast_llm_model.setup(distributed, StageMode.inference) + reference_model.setup() # Setup the optimizer. param_groups, grads_for_norm = self._multi_stage.get_param_groups(ParamGroup) @@ -116,7 +131,6 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the schedules. with torch.no_grad(): self._runner.setup(distributed, self._optimizer) - # Setup the datasets. log_main_rank("Preparing datasets...") self._data.setup( @@ -526,3 +540,6 @@ def _get_last_checkpoint(self) -> int | None: def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: # TODO: Do in model, automate/generalize, get other stats pass + + def _get_reference_model_preprocessor(self, name: str, inference_runner: InferenceRunner) -> Preprocessor: + raise NotImplementedError() diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index e6f480aa1..7e95bb5cc 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.transformer.config import TransformerKwargs @@ -12,7 +13,7 @@ logger = logging.getLogger(__name__) -class PositionEmbeddingPreprocessor: +class PositionEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor @@ -29,7 +30,7 @@ def __init__( self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - def create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -39,7 +40,8 @@ def create_tensors(self, sequence_length: int) -> None: 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths)) is not None: diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cbafe6c97..c45a1d2d4 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -4,6 +4,7 @@ import torch +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.rotary import convert_rotary_complex_to_real from fast_llm.layers.transformer.config import ( @@ -129,7 +130,7 @@ def get_rotary_frequencies( return frequencies -class RotaryEmbeddingPreprocessor: +class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor @@ -149,7 +150,7 @@ def __init__( self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) - def create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -161,7 +162,8 @@ def create_tensors(self, sequence_length: int) -> None: device=self._tensor_space.distributed.device, ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k @@ -189,7 +191,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: ) -class BackupAttentionPreprocessor: +class BackupAttentionPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor @@ -208,7 +210,7 @@ def __init__( assert not self._config.do_use_flash_attention(self._distributed_config) self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - def create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -228,7 +230,8 @@ def create_tensors(self, sequence_length: int) -> None: device=self._tensor_space.distributed.device, ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size kwargs[TransformerKwargs.attention_mask] = self._mask[ @@ -264,14 +267,14 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: ) -class FlashAttnVarlenPreprocessor: +class FlashAttnVarlenPreprocessor(Preprocessor): def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 diff --git a/fast_llm/models/custom/trainer.py b/fast_llm/models/custom/trainer.py index 2c1b3b5f2..eba51235e 100644 --- a/fast_llm/models/custom/trainer.py +++ b/fast_llm/models/custom/trainer.py @@ -2,14 +2,12 @@ from fast_llm.models.custom.config import CustomTrainerConfig from fast_llm.models.custom.data import CustomData -from fast_llm.models.custom.model import CustomModel from fast_llm.models.gpt.trainer import GPTTrainer class CustomTrainer[ConfigType: CustomTrainerConfig](GPTTrainer[ConfigType]): # TODO: Implement changes in the training loop (or tflops computation), if any (typically none). config_class: typing.ClassVar[type[CustomTrainerConfig]] = CustomTrainerConfig - model_class: typing.ClassVar[type[CustomModel]] = CustomModel def _get_data(self): # TODO: Adjust signature if needed. diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index e4db9b07e..9627c68c4 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -1,13 +1,15 @@ import logging import random +import typing import torch import transformers.modeling_outputs from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig -from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel +from fast_llm.engine.inference.config import HuggingfaceModelConfig +from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTModel @@ -24,7 +26,7 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): config_class = HuggingfaceGPTModelConfig config: HuggingfaceGPTModelConfig - model_class = GPTModel + runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner _fast_llm_model: GPTModel # base_model_prefix = "" # _no_split_modules = None @@ -70,7 +72,7 @@ def forward( batch = self._fast_llm_model.base_model.preprocess( GPTBatch(input_ids), phase=PhaseType.inference, iteration=iteration ) - ((_, kwargs),) = batch + ((input_, kwargs),) = batch if past_key_values is not None: # The transformers will use the past keys and values to this list. @@ -81,7 +83,7 @@ def forward( # The transformers will save the present keys and values to this list. kwargs[TransformerKwargs.presents] = [] - _, _, _ = self._runner.run_step(iter((batch,)), self._schedule, iteration=iteration, preprocessed=True) + self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. logits = kwargs["logits"] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index d29fc28da..dda5e7c28 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -5,9 +5,11 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames @@ -58,18 +60,17 @@ def __init__( for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space) + self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - self._config.transformer.rotary, self._tensor_space + self._preprocessors.append( + RotaryEmbeddingPreprocessor(self._config.transformer.rotary, self._tensor_space) ) if not self._use_flash_attention: - self._backup_attention_preprocessor = BackupAttentionPreprocessor( - self._config.transformer, self._tensor_space - ) + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) else: - self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space) + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) def get_layers(self) -> list[Layer]: return [ @@ -179,12 +180,8 @@ def preprocess_meta( kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( hidden_dims[:2], tensor_name="labels", dtype=torch.int64 ) - if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor.preprocess_meta(kwargs) - if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor.preprocess_meta(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess_meta(kwargs) + for preprocessor in self._preprocessors: + preprocessor.preprocess_meta(kwargs) preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -207,7 +204,6 @@ def preprocess( _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size sequence_first = common_kwargs[TransformerKwargs.sequence_first] - sequence_length = common_kwargs[TransformerKwargs.sequence_length] batch.token_ids = batch.token_ids.to( device=self._tensor_space.distributed.device, @@ -218,13 +214,6 @@ def preprocess( # Move the sequence dimension first to make sequence parallel ops more efficient. batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() - if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor.create_tensors(sequence_length) - if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor.create_tensors(sequence_length) - if not self._use_flash_attention: - self._backup_attention_preprocessor.create_tensors(sequence_length) - preprocessed = [] presents = None for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta): @@ -236,8 +225,6 @@ def preprocess( tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths - if self._use_flash_attention: - self._flash_varlen_preprocessor.preprocess(kwargs_meta) # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. @@ -272,12 +259,8 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels - if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor.preprocess(kwargs) - if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor.preprocess(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess(kwargs) + for preprocessor in self._preprocessors: + preprocessor.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) return preprocessed @@ -329,7 +312,15 @@ def loss_defs(self) -> list[LossDef]: LossDef(name=LanguageModelLossNames.z_loss, formatted_name="logit z loss", count=1) return loss_defs + def add_preprocessor(self, preprocessor: Preprocessor): + assert not self._is_setup + self._preprocessors.append(preprocessor) + class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel + + +class GPTInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[GPTModel]] = GPTModel diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 376d8b840..173c6af63 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -2,17 +2,33 @@ import typing from fast_llm.data.data.gpt.data import GPTData +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig -from fast_llm.models.gpt.model import GPTModel +from fast_llm.models.gpt.model import GPTInferenceRunner logger = logging.getLogger(__name__) +class GPTReferenceModelPreprocessor(Preprocessor): + def __init__(self, name: str, inference_runner: GPTInferenceRunner): + self._name = name + self._inference_runner = inference_runner + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + pass + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + # TODO: Fix random state/iteration. + preprocess_kwargs = kwargs.copy() + self._inference_runner.forward(batch, preprocess_kwargs, iteration=1) + # TODO: Improve. + kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"] + + class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig - model_class: typing.ClassVar[type[GPTModel]] = GPTModel def _get_data(self) -> GPTData: return GPTData( @@ -71,3 +87,6 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, hardware_flops = flops_per_iteration + 7 / 6 * attn_flops ratio = elapsed_time_per_iteration * self._config.model.distributed.world_size * 1e12 return model_tflops / ratio, hardware_flops / ratio + + def _get_reference_model_preprocessor(self, name: str, inference_runner: GPTInferenceRunner) -> Preprocessor: + return GPTReferenceModelPreprocessor(name, inference_runner) From ed8ec43014d0669976054f3120014a5e16cedab2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 19:49:34 -0400 Subject: [PATCH 2/9] Generalize preprocessor --- fast_llm/engine/base_model/base_model.py | 6 ++- fast_llm/engine/base_model/config.py | 11 +++++ .../layers/language_model/preprocessing.py | 8 ++-- fast_llm/layers/transformer/preprocessing.py | 19 ++++---- fast_llm/models/gpt/model.py | 44 +++++++------------ 5 files changed, 47 insertions(+), 41 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index 7233c1838..76da0f9b0 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -6,7 +6,7 @@ import torch.nn from fast_llm.config import Configurable -from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig +from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig, Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -135,3 +135,7 @@ def get_tied_weights(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]]]: @abc.abstractmethod def loss_defs(self) -> list[LossDef]: pass + + def add_preprocessor(self, preprocessor: Preprocessor): + # TODO: Generalize preprocessors. + raise NotImplementedError() diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 409c2c83b..891fdbb0a 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -1,3 +1,4 @@ +import abc import typing from fast_llm.config import Config, config_class @@ -40,3 +41,13 @@ class BaseModelConfig(BaseModelArchitectureConfig): def get_architecture(self) -> BaseModelArchitectureConfig: return self.architecture_class.from_dict(self, strict=False) + + +class Preprocessor(abc.ABC): + @abc.abstractmethod + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + pass + + @abc.abstractmethod + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + pass diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index e6f480aa1..7e95bb5cc 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -3,6 +3,7 @@ import torch +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.layers.transformer.config import TransformerKwargs @@ -12,7 +13,7 @@ logger = logging.getLogger(__name__) -class PositionEmbeddingPreprocessor: +class PositionEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor @@ -29,7 +30,7 @@ def __init__( self._distributed_config = self._tensor_space.distributed_config self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - def create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -39,7 +40,8 @@ def create_tensors(self, sequence_length: int) -> None: 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths)) is not None: diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cbafe6c97..c45a1d2d4 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -4,6 +4,7 @@ import torch +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.rotary import convert_rotary_complex_to_real from fast_llm.layers.transformer.config import ( @@ -129,7 +130,7 @@ def get_rotary_frequencies( return frequencies -class RotaryEmbeddingPreprocessor: +class RotaryEmbeddingPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor @@ -149,7 +150,7 @@ def __init__( self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) - def create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -161,7 +162,8 @@ def create_tensors(self, sequence_length: int) -> None: device=self._tensor_space.distributed.device, ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k @@ -189,7 +191,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: ) -class BackupAttentionPreprocessor: +class BackupAttentionPreprocessor(Preprocessor): _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor @@ -208,7 +210,7 @@ def __init__( assert not self._config.do_use_flash_attention(self._distributed_config) self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - def create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -228,7 +230,8 @@ def create_tensors(self, sequence_length: int) -> None: device=self._tensor_space.distributed.device, ) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[TransformerKwargs.sequence_length]) sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size kwargs[TransformerKwargs.attention_mask] = self._mask[ @@ -264,14 +267,14 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: ) -class FlashAttnVarlenPreprocessor: +class FlashAttnVarlenPreprocessor(Preprocessor): def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - def preprocess(self, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index e878530cf..a118944e5 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -5,6 +5,7 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -58,18 +59,17 @@ def __init__( for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) param.init_parameter = get_init_megatron(param, self._config.transformer) # Noqa + self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor = PositionEmbeddingPreprocessor(self._config, self._tensor_space) + self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor = RotaryEmbeddingPreprocessor( - self._config.transformer.rotary, self._tensor_space + self._preprocessors.append( + RotaryEmbeddingPreprocessor(self._config.transformer.rotary, self._tensor_space) ) if not self._use_flash_attention: - self._backup_attention_preprocessor = BackupAttentionPreprocessor( - self._config.transformer, self._tensor_space - ) + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) else: - self._flash_varlen_preprocessor = FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space) + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) def get_output_layers(self) -> list[Layer]: return [ @@ -207,12 +207,8 @@ def preprocess_meta( kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( hidden_dims[:2], tensor_name="labels", dtype=torch.int64 ) - if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor.preprocess_meta(kwargs) - if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor.preprocess_meta(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess_meta(kwargs) + for preprocessor in self._preprocessors: + preprocessor.preprocess_meta(kwargs) preprocessed_meta.append((tokens, kwargs)) return preprocessed_meta @@ -235,7 +231,6 @@ def preprocess( _, common_kwargs = preprocessed_meta[0] sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size sequence_first = common_kwargs[TransformerKwargs.sequence_first] - sequence_length = common_kwargs[TransformerKwargs.sequence_length] batch.token_ids = batch.token_ids.to( device=self._tensor_space.distributed.device, @@ -246,13 +241,6 @@ def preprocess( # Move the sequence dimension first to make sequence parallel ops more efficient. batch.token_ids = batch.token_ids.transpose(0, 1).contiguous() - if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor.create_tensors(sequence_length) - if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor.create_tensors(sequence_length) - if not self._use_flash_attention: - self._backup_attention_preprocessor.create_tensors(sequence_length) - preprocessed = [] presents = None for i, (tokens_meta, kwargs_meta) in enumerate(preprocessed_meta): @@ -264,8 +252,6 @@ def preprocess( tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths - if self._use_flash_attention: - self._flash_varlen_preprocessor.preprocess(kwargs_meta) # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. @@ -300,12 +286,8 @@ def preprocess( else: labels[i, start : end + 1] = -100 kwargs[LanguageModelKwargs.labels] = labels - if self._config.use_absolute_position_embeddings: - self._position_embedding_preprocessor.preprocess(kwargs) - if self._config.transformer.rotary.enabled: - self._rotary_embedding_preprocessor.preprocess(kwargs) - if not self._use_flash_attention: - self._backup_attention_preprocessor.preprocess(kwargs) + for preprocessor in self._preprocessors: + preprocessor.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) return preprocessed @@ -379,6 +361,10 @@ def loss_defs(self) -> list[LossDef]: ) return loss_defs + def add_preprocessor(self, preprocessor: Preprocessor): + assert not self._is_setup + self._preprocessors.append(preprocessor) + class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig From 8fe5973b8cb5341383e2ff7f006b5a8fba27339c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 20:21:25 -0400 Subject: [PATCH 3/9] fixes --- fast_llm/engine/base_model/config.py | 1 - fast_llm/layers/transformer/preprocessing.py | 4 +++- fast_llm/models/gpt/model.py | 6 +++--- tests/test_attention.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 891fdbb0a..565aad223 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -44,7 +44,6 @@ def get_architecture(self) -> BaseModelArchitectureConfig: class Preprocessor(abc.ABC): - @abc.abstractmethod def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index c45a1d2d4..2415a2f91 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -284,7 +284,9 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - sequence_lengths = kwargs.get(TransformerKwargs.sequence_lengths) + if TransformerKwargs.sequence_lengths not in kwargs: + return + sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size if sequence_q < kwargs[TransformerKwargs.sequence_length]: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a118944e5..7e5c5d33b 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -66,10 +66,10 @@ def __init__( self._preprocessors.append( RotaryEmbeddingPreprocessor(self._config.transformer.rotary, self._tensor_space) ) - if not self._use_flash_attention: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - else: + if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + else: + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) def get_output_layers(self) -> list[Layer]: return [ diff --git a/tests/test_attention.py b/tests/test_attention.py index fc1eb80b7..87b0d3e59 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -84,6 +84,6 @@ def test_varlen_preprocessor(): TransformerKwargs.sequence_length: sequence_length, TransformerKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(kwargs) + varlen_preprocessor.preprocess(None, kwargs) Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) From 191308cfe6f9c1538ec55efa4023b2e07f95b223 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 21:10:15 -0400 Subject: [PATCH 4/9] fixes --- fast_llm/engine/inference/huggingface.py | 7 +++++-- fast_llm/engine/inference/runner.py | 2 +- fast_llm/engine/multi_stage/multi_stage.py | 11 +++++++++-- fast_llm/models/gpt/huggingface.py | 9 ++++----- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 9a213691d..30e03b907 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -28,8 +28,11 @@ def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, super().__init__(config, **kwargs) self._inference_runner = self.runner_class(fast_llm_model) + if not fast_llm_model.is_setup: + fast_llm_model.setup(mode=StageMode.inference) + self._inference_runner.setup() # Transformers needs to be able to inspect the base model. - self.fast_llm_base_model = self._fast_llm_model.base_model + self.fast_llm_base_model = fast_llm_model.base_model # TODO: Support distributed models? assert fast_llm_model.config.distributed.world_size == 1 @@ -57,7 +60,7 @@ def from_pretrained( config_updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained( + fast_llm_model = cls.runner_class.model_class.from_pretrained( pretrained_model_name_or_path, config_updates=config_updates, mode=mode ) config = cls.config_class(fast_llm_model.config) diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py index 89426616b..b83a5332a 100644 --- a/fast_llm/engine/inference/runner.py +++ b/fast_llm/engine/inference/runner.py @@ -48,7 +48,7 @@ def forward( ) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]: # TODO: Return an actual model output. reduced_losses, update_successful, metrics = self._runner.run_step( - iter(((input_, kwargs),)), + iter((((input_, kwargs),),)), self._schedule, iteration=iteration, return_metrics=return_metrics, diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 53e9c0846..c28e2537e 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -209,12 +209,15 @@ def __init__( "Bfloat16 gradient accumulation and reduction is not recommended. (use --full_precision_gradients=1)" ) - def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) -> None: + def setup(self, distributed: Distributed | None = None, mode: StageMode = StageMode.training) -> None: # TODO: More checks? stage: Stage - assert distributed.config is self._config.distributed assert not self._is_setup self._is_setup = True + if distributed is None: + distributed = Distributed(self._config.distributed) + else: + assert distributed.config is self._config.distributed self._distributed = distributed self._mode = mode self._base_model.setup(distributed) @@ -381,6 +384,10 @@ def get_shard(self, name: str) -> torch.Tensor: raise KeyError(f"Unknown shard name {name}") return self._shards[name] + @property + def is_setup(self) -> bool: + return self._is_setup + @property def support_forward(self) -> bool: assert self._is_setup diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 9627c68c4..0da4acbb4 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,10 +9,9 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel -from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTModelConfig -from fast_llm.models.gpt.model import GPTModel +from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner logger = logging.getLogger(__name__) @@ -26,8 +25,8 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): config_class = HuggingfaceGPTModelConfig config: HuggingfaceGPTModelConfig - runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner - _fast_llm_model: GPTModel + runner_class: typing.ClassVar[type[GPTInferenceRunner]] = GPTInferenceRunner + fast_llm_base_model: GPTBaseModel # base_model_prefix = "" # _no_split_modules = None # _supports_cache_class = False @@ -69,7 +68,7 @@ def forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) - batch = self._fast_llm_model.base_model.preprocess( + batch = self.fast_llm_base_model.preprocess( GPTBatch(input_ids), phase=PhaseType.inference, iteration=iteration ) ((input_, kwargs),) = batch From 4e6eaeeca980b1dae535bfa29db937fcb0fd46c0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 22:40:00 -0400 Subject: [PATCH 5/9] fix --- fast_llm/engine/config_utils/tensor_space.py | 3 +- fast_llm/engine/distributed/config.py | 134 +++++++++++-------- fast_llm/engine/distributed/distributed.py | 8 ++ fast_llm/engine/multi_stage/multi_stage.py | 2 +- fast_llm/engine/multi_stage/stage_base.py | 2 +- fast_llm/engine/schedule/runner.py | 2 +- fast_llm/engine/training/config.py | 14 ++ fast_llm/models/custom/config.py | 2 +- fast_llm/models/gpt/config.py | 10 +- fast_llm/models/gpt/model.py | 2 +- 10 files changed, 115 insertions(+), 64 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 8bc86b733..0384fdacd 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -119,8 +119,9 @@ def __init__(self, distributed_config: DistributedConfig): self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) def setup(self, distributed: "Distributed") -> None: - assert distributed.config is self._distributed_config assert not self._is_setup + if distributed.config is not self._distributed_config: + distributed.config.compare(self._distributed_config, ValueError) self._is_setup = True self._distributed = distributed diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 1b3e73bb6..1330ab2cf 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -251,6 +251,12 @@ class DistributedConfig(Config): desc="Ensure the initialization is the same for any distributed configuration.", hint=FieldHint.testing, ) + reference_config: "DistributedConfig|None" = Field( + default=None, + init=False, + desc="Pointer to the distributed config this one is an identical copy of.", + hint=FieldHint.derived, + ) def _validate(self) -> None: if self.world_size is None: @@ -281,76 +287,90 @@ def _validate(self) -> None: if self.tensor_parallel == 1: self.sequence_tensor_parallel = False - self.distributed_dims = {} + if self.reference_config is not None: + self.reference_config.validate() + if self.reference_config.reference_config is not None: + self.reference_config = self.reference_config.reference_config + assert self.reference_config.reference_config is None + self.compare(self.reference_config, ValueError) + self.distributed_dims = self.reference_config.distributed_dims + else: + self.distributed_dims = {} - self.add_distributed_dim( - DistributedDim(name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.data, - size=self.data_parallel, - rank=self.data_rank, - id_=f"x_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.pipeline, - size=self.pipeline_parallel, - rank=self.pipeline_rank, - id_=f"x_{self.data_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.data, + size=self.data_parallel, + rank=self.data_rank, + id_=f"x_{self.pipeline_rank}_{self.tensor_rank}", + parent=DistributedDimNames.world, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor, - size=self.tensor_parallel, - rank=self.tensor_rank, - id_=f"x_{self.data_rank}_{self.pipeline_rank}", - parent=DistributedDimNames.world, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.pipeline, + size=self.pipeline_parallel, + rank=self.pipeline_rank, + id_=f"x_{self.data_rank}_{self.tensor_rank}", + parent=DistributedDimNames.world, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.sequence_data, - size=self.sequence_data_parallel, - rank=self.sequence_data_rank, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.data, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.tensor, + size=self.tensor_parallel, + rank=self.tensor_rank, + id_=f"x_{self.data_rank}_{self.pipeline_rank}", + parent=DistributedDimNames.world, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.batch_data, - size=self.batch_data_parallel, - rank=self.batch_data_rank, - id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.data, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.sequence_data, + size=self.sequence_data_parallel, + rank=self.sequence_data_rank, + id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", + parent=DistributedDimNames.data, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor_and_sequence_data, - size=self.sequence_data_parallel * self.tensor_parallel, - rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}", - parent=( - DistributedDimNames.tensor - if self.sequence_data_parallel == 1 - else DistributedDimNames.sequence_data if self.tensor_parallel == 1 else DistributedDimNames.world - ), + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.batch_data, + size=self.batch_data_parallel, + rank=self.batch_data_rank, + id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", + parent=DistributedDimNames.data, + ) + ) + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.tensor_and_sequence_data, + size=self.sequence_data_parallel * self.tensor_parallel, + rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, + id_=f"{self.batch_data_rank}_{self.pipeline_rank}", + parent=( + DistributedDimNames.tensor + if self.sequence_data_parallel == 1 + else ( + DistributedDimNames.sequence_data + if self.tensor_parallel == 1 + else DistributedDimNames.world + ) + ), + ) ) - ) super()._validate() Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) - def add_distributed_dim(self, distributed_dim: DistributedDim) -> None: + def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: if distributed_dim.name in self.distributed_dims: Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name]) else: diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index a612d9cf9..42ec97f2e 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -33,6 +33,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): def __init__(self, config: DistributedConfig, use_cpu: bool = False): super().__init__(config) + assert self._config.reference_config is None self._use_cpu = use_cpu if self._use_cpu: @@ -148,6 +149,13 @@ def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None: distributed_dim.setup(group) return group + def check_config(self, config: DistributedConfig) -> None: + # Allows using this `Distributed` on a model with a distributed config that is a copy of `self._config` + if config.reference_config is None: + Assert.is_(config, self._config) + else: + Assert.is_(config.reference_config, self._config) + def set_step(self, step: int, phase: PhaseType) -> None: """ Reseed pytorch for a given training step. diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index c28e2537e..21d0fe557 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -217,7 +217,7 @@ def setup(self, distributed: Distributed | None = None, mode: StageMode = StageM if distributed is None: distributed = Distributed(self._config.distributed) else: - assert distributed.config is self._config.distributed + distributed.check_config(self._config.distributed) self._distributed = distributed self._mode = mode self._base_model.setup(distributed) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 0f83c862d..da7eb7d88 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -128,7 +128,7 @@ def setup( mode: StageMode = StageMode.training, ) -> None: assert not self._is_setup - assert distributed.config is self._distributed_config + distributed.check_config(self._distributed_config) self._mode = mode self._is_setup = True self._distributed = distributed diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 1d4b04c1c..4a5425eed 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -97,7 +97,7 @@ def __init__( def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: assert not self._is_setup - assert distributed.config is self._distributed_config + distributed.check_config(self._distributed_config) self._is_setup = True self._optimizer = optimizer assert self._multi_stage.support_forward diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index fcc03a484..226bde105 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -14,6 +14,7 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -373,6 +374,9 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) + self.model.validate() + for reference_model in self.reference_models.values(): + _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) super()._validate() if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() @@ -402,3 +406,13 @@ def runnable(): trainer.run() return runnable + + +def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelConfig, distributed: DistributedConfig): + old_setup = pretrained._setup + + def new_setup(): + pretrained.model.distributed.reference_config = distributed + old_setup() + + pretrained._setup = new_setup diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index b86722f82..1c82b59e4 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -61,8 +61,8 @@ class PretrainedCustomModelConfig(PretrainedGPTModelConfig): @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 5a21368fa..e9ab01297 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -10,7 +10,7 @@ if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM - from fast_llm.models.gpt.model import GPTModel + from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel from fast_llm.models.gpt.trainer import GPTTrainer @@ -129,6 +129,8 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) + # TODO: Use dynamic model type? + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate(default_factory=PretrainedGPTModelConfig) def _validate(self) -> None: if self.batch.sequence_length is None: @@ -143,3 +145,9 @@ def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer return GPTTrainer + + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + return GPTInferenceRunner diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9a5dcd5e9..c672b2165 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -116,7 +116,7 @@ def get_layers(self) -> list[Layer]: def setup(self, distributed: Distributed) -> None: assert not self._is_setup - assert distributed.config is self._tensor_space.distributed_config + distributed.check_config(self._tensor_space.distributed_config) self._tensor_space.setup(distributed) self._is_setup = True From e2fe55892e79ce23d1af13d9035c38f7e1ea2148 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 23:10:03 -0400 Subject: [PATCH 6/9] fix --- fast_llm/engine/training/config.py | 2 +- fast_llm/engine/training/trainer.py | 17 ++++++++++++++++- fast_llm/models/gpt/config.py | 2 +- fast_llm/models/gpt/trainer.py | 2 ++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 226bde105..fc7c56e01 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -368,7 +368,7 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): ) reference_models: dict[str, PretrainedFastLLMModelConfig] = Field( default_factory=dict, - desc="Additional models used during training, ex. for knowledge distillation.", + desc="Auxiliary models used during training, ex. for knowledge distillation.", hint=FieldHint.feature, ) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index f167b26e0..b284b3a72 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -114,8 +114,10 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the model. with torch.no_grad(): + log_main_rank("Setting up model...") self._multi_stage.setup(distributed) - for reference_model in self._reference_models.values(): + for name, reference_model in self._reference_models.items(): + log_main_rank(f"Setting up `{name}` reference model...") reference_model.fast_llm_model.setup(distributed, StageMode.inference) reference_model.setup() @@ -452,6 +454,19 @@ def _prepare_training_state(self) -> None: log_main_rank(lambda: f"Loading checkpoint from iteration {last_iteration}...") self._load_checkpoint(self._config.training.checkpoint, last_iteration) + for name, reference_model in self._reference_models.items(): + pretrained = self._config.reference_models[name].pretrained + if pretrained.path is not None and pretrained.model_weights: + log_main_rank(f"Loading weights for `{name}` reference model from {pretrained.path}") + reference_model.fast_llm_model.load_checkpoint(pretrained) + else: + log_main_rank( + f"No pretrained checkpoint specified for `{name}` reference model," + f" using a freshly initialized model...", + log_fn=logger.warning, + ) + reference_model.fast_llm_model.initialize_weights() + Assert.eq(self._completed_steps, last_iteration or 0) assert self._multi_stage._is_loaded # noqa diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index e9ab01297..d1200a429 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -130,7 +130,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) # TODO: Use dynamic model type? - reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate(default_factory=PretrainedGPTModelConfig) + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() def _validate(self) -> None: if self.batch.sequence_length is None: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 173c6af63..f9e21d1e8 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -5,6 +5,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTTrainerConfig from fast_llm.models.gpt.model import GPTInferenceRunner @@ -22,6 +23,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: # TODO: Fix random state/iteration. preprocess_kwargs = kwargs.copy() + del preprocess_kwargs[LanguageModelKwargs.labels] self._inference_runner.forward(batch, preprocess_kwargs, iteration=1) # TODO: Improve. kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"] From 52d160d991aca3371afda9b8f5375abd278575dd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 23:12:50 -0400 Subject: [PATCH 7/9] fix --- fast_llm/engine/training/trainer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index b284b3a72..33209b959 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -48,15 +48,14 @@ def __init__(self, config: TrainerConfig): self._config.model, optimizer_state_names=self._config.optimizer.state_names(), ) - self._reference_models = { - name: self._config.get_inference_runner_class()( + self._reference_models = {} + for name, reference_config in self._config.reference_models.items(): + log_main_rank(f"Creating `{name} reference model...") + self._reference_models[name] = self._config.get_inference_runner_class()( reference_config.model.get_model_class()(reference_config.model) ) - for name, reference_config in self._config.reference_models.items() - } - for name, inference_runner in self._reference_models.items(): self._multi_stage.base_model.add_preprocessor( - self._get_reference_model_preprocessor(name, inference_runner) + self._get_reference_model_preprocessor(name, self._reference_models[name]) ) phase: PhaseType From 35cc58bf1e5f24f1a571ffa0751513649dae74a7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 23:33:08 -0400 Subject: [PATCH 8/9] fix --- fast_llm/engine/training/config.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index fc7c56e01..43f084286 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -412,6 +412,12 @@ def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelC old_setup = pretrained._setup def new_setup(): + # Make sure the distributed config isn't set + # TODO!!!!!!!!!!!!!: Uncomment after #205 + # pretrained.model.distributed.validate() + # Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) + pretrained.model.distributed = distributed.to_copy() + # Allow sharing the `Distributed` instance. pretrained.model.distributed.reference_config = distributed old_setup() From 015232fb9eea40fd47171d81e2c9884d97066e7a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 3 Apr 2025 23:54:29 -0400 Subject: [PATCH 9/9] misc --- fast_llm/engine/training/config.py | 21 +++++++++++++++++++-- fast_llm/models/gpt/config.py | 3 +++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 43f084286..4f5164b0b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -5,7 +5,16 @@ import subprocess import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldUpdate, + NoAutoValidate, + check_field, + config_class, + skip_valid_if_none, +) from fast_llm.data.data.config import DataConfig from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, @@ -375,6 +384,13 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): def _validate(self) -> None: self.training.export.setup(self.model) self.model.validate() + if self.reference_models: + # TODO: Add support. + Assert.eq(self.model.distributed.pipeline_parallel, 1) + # TODO: Check if these work. + Assert.eq(self.model.distributed.tensor_parallel, 1) + Assert.eq(self.model.distributed.sequence_data_parallel, 1) + for reference_model in self.reference_models.values(): _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) super()._validate() @@ -416,7 +432,8 @@ def new_setup(): # TODO!!!!!!!!!!!!!: Uncomment after #205 # pretrained.model.distributed.validate() # Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) - pretrained.model.distributed = distributed.to_copy() + with NoAutoValidate(): + pretrained.model.distributed = distributed.to_copy() # Allow sharing the `Distributed` instance. pretrained.model.distributed.reference_config = distributed old_setup() diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index d1200a429..19c8e6ac6 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -7,6 +7,7 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM @@ -138,6 +139,8 @@ def _validate(self) -> None: self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) + for reference_model in self.reference_models.values(): + Assert.none(reference_model.model.base_model.cross_entropy_splits) super()._validate() @classmethod