diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 8bc86b733..0384fdacd 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -119,8 +119,9 @@ def __init__(self, distributed_config: DistributedConfig): self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) def setup(self, distributed: "Distributed") -> None: - assert distributed.config is self._distributed_config assert not self._is_setup + if distributed.config is not self._distributed_config: + distributed.config.compare(self._distributed_config, ValueError) self._is_setup = True self._distributed = distributed diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 1b3e73bb6..1330ab2cf 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -251,6 +251,12 @@ class DistributedConfig(Config): desc="Ensure the initialization is the same for any distributed configuration.", hint=FieldHint.testing, ) + reference_config: "DistributedConfig|None" = Field( + default=None, + init=False, + desc="Pointer to the distributed config this one is an identical copy of.", + hint=FieldHint.derived, + ) def _validate(self) -> None: if self.world_size is None: @@ -281,76 +287,90 @@ def _validate(self) -> None: if self.tensor_parallel == 1: self.sequence_tensor_parallel = False - self.distributed_dims = {} + if self.reference_config is not None: + self.reference_config.validate() + if self.reference_config.reference_config is not None: + self.reference_config = self.reference_config.reference_config + assert self.reference_config.reference_config is None + self.compare(self.reference_config, ValueError) + self.distributed_dims = self.reference_config.distributed_dims + else: + self.distributed_dims = {} - self.add_distributed_dim( - DistributedDim(name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.data, - size=self.data_parallel, - rank=self.data_rank, - id_=f"x_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.pipeline, - size=self.pipeline_parallel, - rank=self.pipeline_rank, - id_=f"x_{self.data_rank}_{self.tensor_rank}", - parent=DistributedDimNames.world, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.data, + size=self.data_parallel, + rank=self.data_rank, + id_=f"x_{self.pipeline_rank}_{self.tensor_rank}", + parent=DistributedDimNames.world, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor, - size=self.tensor_parallel, - rank=self.tensor_rank, - id_=f"x_{self.data_rank}_{self.pipeline_rank}", - parent=DistributedDimNames.world, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.pipeline, + size=self.pipeline_parallel, + rank=self.pipeline_rank, + id_=f"x_{self.data_rank}_{self.tensor_rank}", + parent=DistributedDimNames.world, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.sequence_data, - size=self.sequence_data_parallel, - rank=self.sequence_data_rank, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.data, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.tensor, + size=self.tensor_parallel, + rank=self.tensor_rank, + id_=f"x_{self.data_rank}_{self.pipeline_rank}", + parent=DistributedDimNames.world, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.batch_data, - size=self.batch_data_parallel, - rank=self.batch_data_rank, - id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", - parent=DistributedDimNames.data, + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.sequence_data, + size=self.sequence_data_parallel, + rank=self.sequence_data_rank, + id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", + parent=DistributedDimNames.data, + ) ) - ) - self.add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor_and_sequence_data, - size=self.sequence_data_parallel * self.tensor_parallel, - rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, - id_=f"{self.batch_data_rank}_{self.pipeline_rank}", - parent=( - DistributedDimNames.tensor - if self.sequence_data_parallel == 1 - else DistributedDimNames.sequence_data if self.tensor_parallel == 1 else DistributedDimNames.world - ), + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.batch_data, + size=self.batch_data_parallel, + rank=self.batch_data_rank, + id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}", + parent=DistributedDimNames.data, + ) + ) + self._add_distributed_dim( + DistributedDim( + name=DistributedDimNames.tensor_and_sequence_data, + size=self.sequence_data_parallel * self.tensor_parallel, + rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, + id_=f"{self.batch_data_rank}_{self.pipeline_rank}", + parent=( + DistributedDimNames.tensor + if self.sequence_data_parallel == 1 + else ( + DistributedDimNames.sequence_data + if self.tensor_parallel == 1 + else DistributedDimNames.world + ) + ), + ) ) - ) super()._validate() Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) - def add_distributed_dim(self, distributed_dim: DistributedDim) -> None: + def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: if distributed_dim.name in self.distributed_dims: Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name]) else: diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index a612d9cf9..42ec97f2e 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -33,6 +33,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): def __init__(self, config: DistributedConfig, use_cpu: bool = False): super().__init__(config) + assert self._config.reference_config is None self._use_cpu = use_cpu if self._use_cpu: @@ -148,6 +149,13 @@ def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None: distributed_dim.setup(group) return group + def check_config(self, config: DistributedConfig) -> None: + # Allows using this `Distributed` on a model with a distributed config that is a copy of `self._config` + if config.reference_config is None: + Assert.is_(config, self._config) + else: + Assert.is_(config.reference_config, self._config) + def set_step(self, step: int, phase: PhaseType) -> None: """ Reseed pytorch for a given training step. diff --git a/fast_llm/engine/huggingface/__init__.py b/fast_llm/engine/inference/__init__.py similarity index 100% rename from fast_llm/engine/huggingface/__init__.py rename to fast_llm/engine/inference/__init__.py diff --git a/fast_llm/engine/huggingface/config.py b/fast_llm/engine/inference/config.py similarity index 100% rename from fast_llm/engine/huggingface/config.py rename to fast_llm/engine/inference/config.py diff --git a/fast_llm/engine/huggingface/model.py b/fast_llm/engine/inference/huggingface.py similarity index 54% rename from fast_llm/engine/huggingface/model.py rename to fast_llm/engine/inference/huggingface.py index 499f0af12..30e03b907 100644 --- a/fast_llm/engine/huggingface/model.py +++ b/fast_llm/engine/inference/huggingface.py @@ -4,20 +4,16 @@ import transformers.modeling_outputs -from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat -from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig +from fast_llm.engine.inference.config import HuggingfaceModelConfig +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel -from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule class HuggingfacePreTrainedModel(transformers.PreTrainedModel): config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig - model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel + runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner config: HuggingfaceModelConfig # base_model_prefix = "" # _no_split_modules = None @@ -25,36 +21,21 @@ class HuggingfacePreTrainedModel(transformers.PreTrainedModel): # _tied_weights_keys = [] def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs): - assert self.model_class.config_class is config.model_config_class + assert self.runner_class.model_class.config_class is config.model_config_class assert config.fast_llm_config is fast_llm_model.config assert isinstance(config, self.config_class) + super().__init__(config, **kwargs) - self._fast_llm_config = config.fast_llm_config - self._fast_llm_model = fast_llm_model + + self._inference_runner = self.runner_class(fast_llm_model) + if not fast_llm_model.is_setup: + fast_llm_model.setup(mode=StageMode.inference) + self._inference_runner.setup() # Transformers needs to be able to inspect the base model. - self.fast_llm_base_model = self._fast_llm_model.base_model - self._distributed_config = self._fast_llm_config.distributed + self.fast_llm_base_model = fast_llm_model.base_model # TODO: Support distributed models? - assert self._distributed_config.world_size == 1 - self._schedule_config = ScheduleConfig() - # We only need a basic schedule and don't care about dimensions. - # TODO: Sort things out. - with NoAutoValidate(): - self._batch_config = BatchConfig() - self._batch_config.setup(self._distributed_config) - self._batch_config.validate() - self._runner = ScheduleRunner( - config=self._schedule_config, multi_stage=self._fast_llm_model, distributed_config=self._distributed_config - ) - self._runner.setup(self._fast_llm_model.distributed) - # TODO: Random state? (Distributed.set_step) - self._schedule = Schedule( - multi_stage=self._fast_llm_model, - batch_config=self._batch_config, - schedule_config=self._schedule_config, - distributed_config=self._distributed_config, - phase=PhaseType.inference, - ) + assert fast_llm_model.config.distributed.world_size == 1 + with transformers.modeling_utils.no_init_weights(): self.post_init() @@ -79,7 +60,7 @@ def from_pretrained( config_updates[("distributed", "training_dtype")] = torch_dtype # Create the model - fast_llm_model = cls.model_class.from_pretrained( + fast_llm_model = cls.runner_class.model_class.from_pretrained( pretrained_model_name_or_path, config_updates=config_updates, mode=mode ) config = cls.config_class(fast_llm_model.config) diff --git a/fast_llm/engine/inference/runner.py b/fast_llm/engine/inference/runner.py new file mode 100644 index 000000000..b83a5332a --- /dev/null +++ b/fast_llm/engine/inference/runner.py @@ -0,0 +1,58 @@ +import abc +import typing + +from fast_llm.config import NoAutoValidate +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig +from fast_llm.engine.schedule.runner import ScheduleRunner +from fast_llm.engine.schedule.schedule import Schedule + + +class InferenceRunner(abc.ABC): + model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel + + def __init__(self, fast_llm_model: FastLLMModel): + assert isinstance(fast_llm_model, self.model_class) + self._fast_llm_model = fast_llm_model + # We only need a basic schedule and don't care about dimensions. + self._schedule_config = ScheduleConfig() + # TODO: Sort things out. + with NoAutoValidate(): + self._batch_config = BatchConfig() + self._batch_config.setup(self._fast_llm_model.config.distributed) + self._batch_config.validate() + self._runner = ScheduleRunner( + config=self._schedule_config, + multi_stage=self._fast_llm_model, + distributed_config=self._fast_llm_model.config.distributed, + ) + # TODO: Random state? (Distributed.set_step) + self._schedule = Schedule( + multi_stage=self._fast_llm_model, + batch_config=self._batch_config, + schedule_config=self._schedule_config, + distributed_config=self._fast_llm_model.config.distributed, + phase=PhaseType.inference, + ) + + @property + def fast_llm_model(self) -> FastLLMModel: + return self._fast_llm_model + + def setup(self): + self._runner.setup(self._fast_llm_model.distributed) + + def forward( + self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False + ) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]: + # TODO: Return an actual model output. + reduced_losses, update_successful, metrics = self._runner.run_step( + iter((((input_, kwargs),),)), + self._schedule, + iteration=iteration, + return_metrics=return_metrics, + preprocessed=True, + ) + assert update_successful + return reduced_losses, metrics diff --git a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index fd7ba6453..e6de074f4 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -29,7 +29,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel + from fast_llm.engine.inference.model import HuggingfacePreTrainedModel from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 238bd865e..21d0fe557 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -209,12 +209,15 @@ def __init__( "Bfloat16 gradient accumulation and reduction is not recommended. (use --full_precision_gradients=1)" ) - def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) -> None: + def setup(self, distributed: Distributed | None = None, mode: StageMode = StageMode.training) -> None: # TODO: More checks? stage: Stage - assert distributed.config is self._config.distributed assert not self._is_setup self._is_setup = True + if distributed is None: + distributed = Distributed(self._config.distributed) + else: + distributed.check_config(self._config.distributed) self._distributed = distributed self._mode = mode self._base_model.setup(distributed) @@ -381,6 +384,10 @@ def get_shard(self, name: str) -> torch.Tensor: raise KeyError(f"Unknown shard name {name}") return self._shards[name] + @property + def is_setup(self) -> bool: + return self._is_setup + @property def support_forward(self) -> bool: assert self._is_setup @@ -442,6 +449,7 @@ def is_parameter_on_device(self, parameter_name: str) -> bool: @property def distributed(self) -> Distributed: + assert self._is_setup return self._distributed def invalidate_buffers(self) -> None: diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 0f83c862d..da7eb7d88 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -128,7 +128,7 @@ def setup( mode: StageMode = StageMode.training, ) -> None: assert not self._is_setup - assert distributed.config is self._distributed_config + distributed.check_config(self._distributed_config) self._mode = mode self._is_setup = True self._distributed = distributed diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 1d4b04c1c..4a5425eed 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -97,7 +97,7 @@ def __init__( def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None: assert not self._is_setup - assert distributed.config is self._distributed_config + distributed.check_config(self._distributed_config) self._is_setup = True self._optimizer = optimizer assert self._multi_stage.support_forward diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index dac4e5533..4f5164b0b 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -5,7 +5,16 @@ import subprocess import typing -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import ( + Config, + Field, + FieldHint, + FieldUpdate, + NoAutoValidate, + check_field, + config_class, + skip_valid_if_none, +) from fast_llm.data.data.config import DataConfig from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, @@ -14,6 +23,7 @@ DistributedCheckpointFormat, ) from fast_llm.engine.config_utils.run import ExperimentConfig +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.multi_stage.config import PretrainedFastLLMModelConfig from fast_llm.engine.optimizer.config import OptimizerConfig from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig @@ -21,6 +31,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.training.trainer import Trainer @@ -364,9 +375,24 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): desc="Configuration for the training optimizer and learning rate schedule.", hint=FieldHint.core, ) + reference_models: dict[str, PretrainedFastLLMModelConfig] = Field( + default_factory=dict, + desc="Auxiliary models used during training, ex. for knowledge distillation.", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.training.export.setup(self.model) + self.model.validate() + if self.reference_models: + # TODO: Add support. + Assert.eq(self.model.distributed.pipeline_parallel, 1) + # TODO: Check if these work. + Assert.eq(self.model.distributed.tensor_parallel, 1) + Assert.eq(self.model.distributed.sequence_data_parallel, 1) + + for reference_model in self.reference_models.values(): + _add_reference_distributed_to_pretrained(reference_model, self.model.distributed) super()._validate() if self.run.experiment_dir is None: assert not self.training.checkpoint.enabled() @@ -379,6 +405,10 @@ def _setup(self): def get_trainer_class(cls) -> type["Trainer"]: raise NotImplementedError + @classmethod + def get_inference_runner_class(cls) -> type["InferenceRunner"]: + raise NotImplementedError + def _get_runnable(self) -> typing.Callable[[], None]: from fast_llm.engine.distributed.distributed import Distributed @@ -392,3 +422,20 @@ def runnable(): trainer.run() return runnable + + +def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelConfig, distributed: DistributedConfig): + old_setup = pretrained._setup + + def new_setup(): + # Make sure the distributed config isn't set + # TODO!!!!!!!!!!!!!: Uncomment after #205 + # pretrained.model.distributed.validate() + # Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"}) + with NoAutoValidate(): + pretrained.model.distributed = distributed.to_copy() + # Allow sharing the `Distributed` instance. + pretrained.model.distributed.reference_config = distributed + old_setup() + + pretrained._setup = new_setup diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index a009fef39..33209b959 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -11,10 +11,12 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.engine.inference.runner import InferenceRunner +from fast_llm.engine.multi_stage.config import StageMode from fast_llm.engine.optimizer.config import ParamGroup from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.runner import ScheduleRunner @@ -29,7 +31,6 @@ class Trainer[ConfigType: TrainerConfig](Configurable[ConfigType], abc.ABC): config_class: typing.ClassVar[type[TrainerConfig]] = TrainerConfig - model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel # TODO: Generalize data, schedule, logging, etc. _is_setup: bool = False _distributed: Distributed @@ -43,10 +44,20 @@ def __init__(self, config: TrainerConfig): super().__init__(config) self._data = self._get_data() log_main_rank("Creating model...") - self._multi_stage = self.model_class( + self._multi_stage = self._config.model.get_model_class()( self._config.model, optimizer_state_names=self._config.optimizer.state_names(), ) + self._reference_models = {} + for name, reference_config in self._config.reference_models.items(): + log_main_rank(f"Creating `{name} reference model...") + self._reference_models[name] = self._config.get_inference_runner_class()( + reference_config.model.get_model_class()(reference_config.model) + ) + self._multi_stage.base_model.add_preprocessor( + self._get_reference_model_preprocessor(name, self._reference_models[name]) + ) + phase: PhaseType self._runner = ScheduleRunner( config=self._config.schedule, @@ -102,7 +113,12 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the model. with torch.no_grad(): + log_main_rank("Setting up model...") self._multi_stage.setup(distributed) + for name, reference_model in self._reference_models.items(): + log_main_rank(f"Setting up `{name}` reference model...") + reference_model.fast_llm_model.setup(distributed, StageMode.inference) + reference_model.setup() # Setup the optimizer. param_groups, grads_for_norm = self._multi_stage.get_param_groups(ParamGroup) @@ -116,7 +132,6 @@ def setup(self, distributed: Distributed, run: Run) -> None: # Setup the schedules. with torch.no_grad(): self._runner.setup(distributed, self._optimizer) - # Setup the datasets. log_main_rank("Preparing datasets...") self._data.setup( @@ -438,6 +453,19 @@ def _prepare_training_state(self) -> None: log_main_rank(lambda: f"Loading checkpoint from iteration {last_iteration}...") self._load_checkpoint(self._config.training.checkpoint, last_iteration) + for name, reference_model in self._reference_models.items(): + pretrained = self._config.reference_models[name].pretrained + if pretrained.path is not None and pretrained.model_weights: + log_main_rank(f"Loading weights for `{name}` reference model from {pretrained.path}") + reference_model.fast_llm_model.load_checkpoint(pretrained) + else: + log_main_rank( + f"No pretrained checkpoint specified for `{name}` reference model," + f" using a freshly initialized model...", + log_fn=logger.warning, + ) + reference_model.fast_llm_model.initialize_weights() + Assert.eq(self._completed_steps, last_iteration or 0) assert self._multi_stage._is_loaded # noqa @@ -527,3 +555,6 @@ def _get_last_checkpoint(self) -> int | None: def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, int]: # TODO: Do in model, automate/generalize, get other stats pass + + def _get_reference_model_preprocessor(self, name: str, inference_runner: InferenceRunner) -> Preprocessor: + raise NotImplementedError() diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index b86722f82..1c82b59e4 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -61,8 +61,8 @@ class PretrainedCustomModelConfig(PretrainedGPTModelConfig): @config_class() class CustomTrainerConfig(PretrainedCustomModelConfig, GPTTrainerConfig): # TODO: Add custom trainer config parameters, if any (typically none). - data: CustomDataConfig = FieldUpdate(default_factory=CustomDataConfig) + reference_models: dict[str, PretrainedCustomModelConfig] = FieldUpdate(default_factory=PretrainedCustomModelConfig) @classmethod def get_trainer_class(cls) -> type["CustomTrainer"]: diff --git a/fast_llm/models/custom/trainer.py b/fast_llm/models/custom/trainer.py index 2c1b3b5f2..eba51235e 100644 --- a/fast_llm/models/custom/trainer.py +++ b/fast_llm/models/custom/trainer.py @@ -2,14 +2,12 @@ from fast_llm.models.custom.config import CustomTrainerConfig from fast_llm.models.custom.data import CustomData -from fast_llm.models.custom.model import CustomModel from fast_llm.models.gpt.trainer import GPTTrainer class CustomTrainer[ConfigType: CustomTrainerConfig](GPTTrainer[ConfigType]): # TODO: Implement changes in the training loop (or tflops computation), if any (typically none). config_class: typing.ClassVar[type[CustomTrainerConfig]] = CustomTrainerConfig - model_class: typing.ClassVar[type[CustomModel]] = CustomModel def _get_data(self): # TODO: Adjust signature if needed. diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 5a21368fa..19c8e6ac6 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -7,10 +7,11 @@ from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.language_model.config import LanguageModelArchitectureConfig, LanguageModelBaseConfig from fast_llm.models.gpt.megatron import set_megatron_distributed_seeds +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM - from fast_llm.models.gpt.model import GPTModel + from fast_llm.models.gpt.model import GPTInferenceRunner, GPTModel from fast_llm.models.gpt.trainer import GPTTrainer @@ -129,6 +130,8 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) + # TODO: Use dynamic model type? + reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() def _validate(self) -> None: if self.batch.sequence_length is None: @@ -136,6 +139,8 @@ def _validate(self) -> None: self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) + for reference_model in self.reference_models.values(): + Assert.none(reference_model.model.base_model.cross_entropy_splits) super()._validate() @classmethod @@ -143,3 +148,9 @@ def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer return GPTTrainer + + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + return GPTInferenceRunner diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index e4db9b07e..0da4acbb4 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -1,16 +1,17 @@ import logging import random +import typing import torch import transformers.modeling_outputs from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.distributed.config import PhaseType -from fast_llm.engine.huggingface.config import HuggingfaceModelConfig -from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel +from fast_llm.engine.inference.config import HuggingfaceModelConfig +from fast_llm.engine.inference.huggingface import HuggingfacePreTrainedModel from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTModelConfig -from fast_llm.models.gpt.model import GPTModel +from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner logger = logging.getLogger(__name__) @@ -24,8 +25,8 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): config_class = HuggingfaceGPTModelConfig config: HuggingfaceGPTModelConfig - model_class = GPTModel - _fast_llm_model: GPTModel + runner_class: typing.ClassVar[type[GPTInferenceRunner]] = GPTInferenceRunner + fast_llm_base_model: GPTBaseModel # base_model_prefix = "" # _no_split_modules = None # _supports_cache_class = False @@ -67,10 +68,10 @@ def forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) - batch = self._fast_llm_model.base_model.preprocess( + batch = self.fast_llm_base_model.preprocess( GPTBatch(input_ids), phase=PhaseType.inference, iteration=iteration ) - ((_, kwargs),) = batch + ((input_, kwargs),) = batch if past_key_values is not None: # The transformers will use the past keys and values to this list. @@ -81,7 +82,7 @@ def forward( # The transformers will save the present keys and values to this list. kwargs[TransformerKwargs.presents] = [] - _, _, _ = self._runner.run_step(iter((batch,)), self._schedule, iteration=iteration, preprocessed=True) + self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. logits = kwargs["logits"] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 7e5c5d33b..c672b2165 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -9,6 +9,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames @@ -115,7 +116,7 @@ def get_layers(self) -> list[Layer]: def setup(self, distributed: Distributed) -> None: assert not self._is_setup - assert distributed.config is self._tensor_space.distributed_config + distributed.check_config(self._tensor_space.distributed_config) self._tensor_space.setup(distributed) self._is_setup = True @@ -369,3 +370,7 @@ def add_preprocessor(self, preprocessor: Preprocessor): class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel + + +class GPTInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[GPTModel]] = GPTModel diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 376d8b840..f9e21d1e8 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -2,17 +2,35 @@ import typing from fast_llm.data.data.gpt.data import GPTData +from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer +from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTTrainerConfig -from fast_llm.models.gpt.model import GPTModel +from fast_llm.models.gpt.model import GPTInferenceRunner logger = logging.getLogger(__name__) +class GPTReferenceModelPreprocessor(Preprocessor): + def __init__(self, name: str, inference_runner: GPTInferenceRunner): + self._name = name + self._inference_runner = inference_runner + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + pass + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + # TODO: Fix random state/iteration. + preprocess_kwargs = kwargs.copy() + del preprocess_kwargs[LanguageModelKwargs.labels] + self._inference_runner.forward(batch, preprocess_kwargs, iteration=1) + # TODO: Improve. + kwargs[f"{self._name}_logits"] = preprocess_kwargs["logits"] + + class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig - model_class: typing.ClassVar[type[GPTModel]] = GPTModel def _get_data(self) -> GPTData: return GPTData( @@ -71,3 +89,6 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration) -> tuple[int, hardware_flops = flops_per_iteration + 7 / 6 * attn_flops ratio = elapsed_time_per_iteration * self._config.model.distributed.world_size * 1e12 return model_tflops / ratio, hardware_flops / ratio + + def _get_reference_model_preprocessor(self, name: str, inference_runner: GPTInferenceRunner) -> Preprocessor: + return GPTReferenceModelPreprocessor(name, inference_runner)