diff --git a/Megatron-LM b/Megatron-LM index 75b0d9787..f02b413f7 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 +Subproject commit f02b413f793af05ade3893bccd8aef6d644d3edf diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 0620beaea..35a324db0 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers - - # A simple renaming example, for the word embeddings. - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - # We usually want to loop dynamically over layers - for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) - return converters + converters = [] + # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. + num_layers = self._model.config.base_model.transformer.num_layers + + # A simple renaming example, for the word embeddings. + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + + # We usually want to loop dynamically over layers + for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. + converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) + return converters ``` And that's it! 
We're ready to use the new checkpoint format in Fast-LLM. diff --git a/fast_llm/config.py b/fast_llm/config.py index c534b11f3..3352f3570 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1028,6 +1028,28 @@ def __init__(self, config: ConfigType, *args, **kwargs): # Handle multiple inheritance. super().__init__(*args, **kwargs) + def __init_subclass__(cls): + # Automatically set `config_class` based on the bound type. + # Make sure `ConfigType` is bound and respects class hierarchy. + try: + config_class = None + for base in types.get_original_bases(cls): + if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable): + for arg in base.__args__: + if arg.__name__ == "ConfigType": + if config_class is None: + config_class = arg.__bound__ + else: + assert arg.__bound__ is config_class + assert config_class is not None + except Exception as e: + raise TypeError( + f"Could not determine the configuration class for the configurable class {cls.__name__}: {e.args}. " + "Please make sure to declare in the format " + f"`class {cls.__name__}[ConfigType: ConfigClass](BaseConfigurable[ConfigType])`.] 
" + ) + cls.config_class = config_class + @property def config(self) -> ConfigType: return self._config diff --git a/fast_llm/data/preparator/config.py b/fast_llm/data/preparator/config.py index 7f6376c7d..160fccafc 100644 --- a/fast_llm/data/preparator/config.py +++ b/fast_llm/data/preparator/config.py @@ -19,8 +19,6 @@ def _get_runnable(self) -> typing.Callable[[], None]: class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig - @abc.abstractmethod def run(self) -> None: raise NotImplementedError diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 7c878b264..d2aaee5e2 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -30,25 +30,6 @@ class SourceSchemaConfig(Config): pass -@config_class(dynamic_type={SourceSchemaConfig: "prompt_completion"}) -class PromptCompletionConfig(SourceSchemaConfig): - prompt_column: str = Field( - default="prompt", - desc="Field of the dataset to use.", - hint=FieldHint.optional, - ) - completion_column: str = Field( - default="completion", - desc="Field of the dataset to use.", - hint=FieldHint.optional, - ) - delimiter: str = Field( - default="", - desc="Delimiter between prompt and completion.", - hint=FieldHint.optional, - ) - - @config_class(dynamic_type={SourceSchemaConfig: "text_column"}) class TextColumnConfig(SourceSchemaConfig): input_column: str = Field( diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b70413482..33c40bf8f 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -24,11 +24,7 @@ from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import 
DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import ( - GPTMemmapDatasetPreparatorConfig, - PromptCompletionConfig, - TextColumnConfig, -) +from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -37,8 +33,6 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): - config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig - _tokenizer: Tokenizer _data_type: DataType _text_column: str @@ -54,30 +48,6 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[ "num_tokens": num_tokens, } - def _tokenize_prompt_completion_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - """ - Tokenize prompt and completion columns separately, then concatenate. - Returns input_ids, token_spans (prompt len), and num_tokens. 
- """ - prompt_col = self._config.dataset.source_schema.prompt_column - completion_col = self._config.dataset.source_schema.completion_column - delimiter = self._config.dataset.source_schema.delimiter - input_ids = [] - token_spans = [] - for prompt, completion in zip(batch[prompt_col], batch[completion_col]): - prompt_tokens = self._tokenizer.tokenize(prompt, begin=True, end=False) - completion_tokens = self._tokenizer.tokenize(f"{delimiter}{completion}", begin=False, end=True) - combined = prompt_tokens + completion_tokens - input_ids.append(np.array(combined, dtype=self._data_type.numpy)) - token_spans.append(np.array((0, len(prompt_tokens) - 1), dtype=np.int32).reshape(-1, 2)) - - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "token_spans": token_spans, - "num_tokens": num_tokens, - } - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids, token_spans = map( list, @@ -171,7 +141,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names: + if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( np.array(item["input_ids"], dtype=self._data_type.numpy), @@ -317,46 +287,37 @@ def run(self) -> None: ) # Set data column and loss masking spans column based on source schema - source_schema = self._config.dataset.source_schema - if isinstance(source_schema, TextColumnConfig): - self._text_column = source_schema.input_column - self._loss_masking_spans_column = source_schema.loss_masking_spans_column - elif isinstance(source_schema, PromptCompletionConfig): - Assert.incl(source_schema.prompt_column, dataset.column_names) - Assert.incl(source_schema.completion_column, 
dataset.column_names) - tokenize_fn = self._tokenize_prompt_completion_batch + if isinstance(self._config.dataset.source_schema, TextColumnConfig): + self._text_column = self._config.dataset.source_schema.input_column + self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column else: raise ValueError( f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." ) - # TODO: Add a new schema for preference datasets then drop class vars _loss_masking_spans_column & _text_column - if isinstance(source_schema, TextColumnConfig): - if self._text_column not in dataset.column_names: - raise ValueError(f"Dataset does not have field '{self._text_column}'.") + if self._text_column not in dataset.column_names: + raise ValueError(f"Dataset does not have field '{self._text_column}'.") - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( - self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None - ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") - if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): - raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") - - # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans - elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: - if self._config.dataset.chosen_text not in dataset.column_names: - raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") - if self._config.dataset.rejected_text not in dataset.column_names: - raise ValueError( - 
f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'." - ) - tokenize_fn = self._tokenize_preference_batch_with_spans - else: - tokenize_fn = self._tokenize_batch + if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( + self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None + ): + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): + raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") + + # route tokenize function + if self._loss_masking_spans_column is not None: + if self._loss_masking_spans_column not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") + tokenize_fn = self._tokenize_batch_with_spans + elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + if self._config.dataset.chosen_text not in dataset.column_names: + raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") + if self._config.dataset.rejected_text not in dataset.column_names: + raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.") + tokenize_fn = self._tokenize_preference_batch_with_spans + else: + tokenize_fn = self._tokenize_batch # Tokenize the dataset in parallel tokenized_dataset = dataset.map( @@ -368,6 +329,7 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) shards = [ diff --git a/fast_llm/engine/base_model/base_model.py 
b/fast_llm/engine/base_model/base_model.py index df603a910..832225803 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -7,7 +7,6 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta @@ -20,11 +19,18 @@ class Module(torch.nn.Module, abc.ABC): """ """ - def forward(self, input_, kwargs): - """ - Run a forward pass for the module, with autograd support. - """ - raise NotImplementedError() + _is_setup: bool = False + _distributed: Distributed + + def __init__(self, distributed_config: DistributedConfig): + self._distributed_config = distributed_config + super().__init__() + + def setup(self, distributed: Distributed) -> None: + assert not self._is_setup + distributed.check_config(self._distributed_config) + self._distributed = distributed + self._is_setup = True class Layer(Module): @@ -39,9 +45,9 @@ def forward( class Sequential(Layer): - def __init__(self, layers: list[Layer]): - super().__init__() - self.layers = torch.nn.ModuleList(layers) + def __init__(self, distributed_config: DistributedConfig): + super().__init__(distributed_config) + self.layers = torch.nn.ModuleList(self.get_layers()) def __getitem__(self, item): return self.layers[item] @@ -59,6 +65,15 @@ def forward( input_ = layer(input_, kwargs, losses, metrics) return input_ + @abc.abstractmethod + def get_layers(self) -> list[Layer]: + pass + + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + for layer in self.layers: + layer.setup(distributed) + @dataclasses.dataclass() class LossDef: @@ -71,29 +86,14 @@ class LossDef: dtype: torch.dtype = torch.float32 -class SequentialLayers(Sequential, abc.ABC): - # 
Small class defined to fix the MRO of BaseModel.__init__ - def __init__(self): - super().__init__(self.get_layers()) - - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass - - -class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): - config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig - _is_setup: bool = False +class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): def __init__( self, config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space: TensorSpace = TensorSpace(distributed_config) - config.setup_tensor_space(self._tensor_space) - - super().__init__(config) + super().__init__(config, distributed_config) for key, value in self.named_parameters(): Assert.custom(isinstance, value, ParameterMeta) @@ -104,12 +104,6 @@ def __init__( # TODO: Add basic handling (preprocessor) in this class. self._reference_models: dict[str, "InferenceRunner"] = {} - def setup(self, distributed: Distributed) -> None: - assert not self._is_setup - distributed.check_config(self._tensor_space.distributed_config) - self._tensor_space.setup(distributed) - self._is_setup = True - @abc.abstractmethod def get_layers(self) -> list[Layer]: pass diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 4be42e069..22abb021b 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import compare_nested, log if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace + import torch @config_class() @@ -18,9 +18,6 @@ class BaseModelConfig(Config): _abstract = True - def setup_tensor_space(self, tensor_space: "TensorSpace") -> None: - raise NotImplementedError() - def compare_architecture( self, model_config: typing.Self, @@ -64,5 +61,5 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass 
@abc.abstractmethod - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: pass diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py new file mode 100644 index 000000000..b60070562 --- /dev/null +++ b/fast_llm/engine/config_utils/initialization.py @@ -0,0 +1,57 @@ +import abc +import typing + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.tensor import ParameterMeta + + +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[["ParameterMeta", "torch.Tensor", "torch.Generator"], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: + return self._init_method(meta, tensor, generator) + + +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) + + +init_zeros_ = init_fill_(0.0) +init_ones_ = init_fill_(1.0) + + +def init_normal_( + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor = tensor.normal_(mean, std, generator=generator) + if min_val is not None or max_val is not None: + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + + +def init_uniform_centered_(scale: 
float, mean: float = 0.0) -> LambdaInitializer: + def init_(meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: # noqa + tensor.uniform_(mean - scale, mean + scale, generator=generator) + + return LambdaInitializer(init_) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_dim.py similarity index 57% rename from fast_llm/engine/config_utils/tensor_space.py rename to fast_llm/engine/config_utils/tensor_dim.py index 66176ee0f..f67916a66 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_dim.py @@ -2,29 +2,18 @@ import math import typing -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim +from fast_llm.engine.distributed.config import DistributedDim from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: import torch from fast_llm.core.distributed import ProcessGroup - from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) class TensorDim: - """ - Describes a simple, atomic dimension of a tensor and its size. - The dimension may be parallelized along a distributed dimension `parallel_dim`, - in which case its actual (local) `size` will differ from its `global_size`. - - TensorDim's are used to represent the metadata of tensors through `TensorMeta`. - - This class also serves as a base for more complex tensor dimensions. - """ - def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): # TODO: Handle None for unknown sizes? self._name = name @@ -72,25 +61,10 @@ def parallel_group(self) -> "ProcessGroup|None": return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - """ - Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, - but the local size remains the same. 
- - Used in`TensorMeta.replace_tensor_parallel_dim`. - """ assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - """ - Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. - If the dimension is parallelized, this amounts to gathering along dimension `dim` - and parallel dimension `parallel_dim`, otherwise return the input tensor. - The method needs to be called my all members of the parallel group using their appropriate local slice. - - Used in`TensorMeta.local_to_global`, - which iterates over the tensor dimensions to fully reconstruct the global tensor. - """ if self.is_parallel: from fast_llm.core.ops import gather_op @@ -101,14 +75,6 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": - """ - Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`. - Unlike `local_to_global`, this method does not need to be called from a distributed setting. - Instead, entries from other ranks are populated with `fill_value`. - - Used in`TensorMeta.local_to_global_partial`, - which iterates over the tensor dimensions to fully reconstruct the global tensor. - """ if self.is_parallel: output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) @@ -117,14 +83,6 @@ def local_to_global_partial( return tensor def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": - """ - Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. 
- If the dimension is parallel, this amounts to taking the `rank`th chunk of size `size` along dimension `dim` - and parallel dimension `self.parallel_dim`, otherwise return the input tensor. - - Used in`TensorMeta.local_to_global`, - which iterates over the tensor dimensions to fully reconstruct the local tensor. - """ return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] if self.parallel_dim is not None and self.parallel_dim.size > 1 @@ -133,20 +91,11 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F class CompositeTensorDim(TensorDim): - """ - A composite tensor dimension that represent multiple dimensions flattened into ones. - Typically happens for flattened view or higher-dimensional tensors, or tensors that can be expanded as such. - If one of the composed dimensions -- other than the first one -- is parallelized, - this is **not** equivalent to an atomic `TensorDim` of the same size, - as the relation between local and global tensors is different. - - At most one of the sub-dimensions may be parallelized. TODO: Allow for more than one? - """ - def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): if tensor_dim.parallel_dim is not None: + # TODO: Allow more than one parallel subdim? assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim self._parallel_dim_index = dim @@ -159,19 +108,12 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - """ - Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, - but the local size remains the same. 
- """ assert self._parallel_dim_index is not None dims = list(self._tensor_dims) dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) return CompositeTensorDim(self.name, tuple(dims)) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - """ - Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. - """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global(tensor, dim + i) @@ -181,10 +123,6 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": - """ - Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, - populating other ranks with `fill_value`. - """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global_partial(tensor, dim + i) @@ -192,9 +130,6 @@ def local_to_global_partial( return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": - """ - Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. - """ tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -202,12 +137,6 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F class ConcatenatedTensorDim(TensorDim): - """ - A complex tensor dimension that results from concatenating tensors. 
- - All sub-dimensions should have the same `parallel_dim` (may be None). TODO: Allow for more complex scenarios? - """ - def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = tensor_dims[0].parallel_dim for dim, tensor_dim in enumerate(tensor_dims[1:]): @@ -222,19 +151,12 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - """ - Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, - but the local size remains the same. - """ assert self.is_parallel return ConcatenatedTensorDim( self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - """ - Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. - """ import torch return ( @@ -256,10 +178,6 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": - """ - Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, - populating other ranks with `fill_value`. - """ import torch return ( @@ -279,9 +197,6 @@ def local_to_global_partial( ) def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": - """ - Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. 
- """ if self.is_parallel and expand: raise NotImplementedError() import torch @@ -303,49 +218,4 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F ) -class DefaultDimNames: - # Scalar - scalar = "scalar" - - -class TensorSpace: - _is_setup: bool = False - _distributed: "Distributed" - - def __init__(self, distributed_config: DistributedConfig): - self._distributed_config = distributed_config - self._tensor_dims: dict[str, TensorDim] = {} - self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) - - def setup(self, distributed: "Distributed") -> None: - assert not self._is_setup - if distributed.config is not self._distributed_config: - distributed.config.compare(self._distributed_config, ValueError) - self._is_setup = True - self._distributed = distributed - - @property - def distributed_config(self) -> DistributedConfig: - return self._distributed_config - - @property - def distributed(self) -> "Distributed": - assert self._is_setup - return self._distributed - - def add_tensor_dim(self, tensor_dim: TensorDim) -> None: - if tensor_dim.name in self._tensor_dims: - Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) - else: - if tensor_dim.parallel_dim is not None: - assert ( - tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims - ), tensor_dim.parallel_dim.name - Assert.eq( - tensor_dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, - ) - self._tensor_dims[tensor_dim.name] = tensor_dim - - def __getitem__(self, name: str) -> TensorDim: - return self._tensor_dims[name] +scalar_dim = TensorDim("scalar", 1) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index f17a8f452..dc41539c0 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -1,6 +1,5 @@ import datetime import logging -import typing import torch import torch.distributed @@ -146,8 
+145,6 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): TODO: Clarify cpu support. """ - config_class: typing.ClassVar[type[DistributedConfig]] = DistributedConfig - def __init__(self, config: DistributedConfig, use_cpu: bool = False): super().__init__(config) assert self._config.reference_config is None diff --git a/fast_llm/engine/evaluation/evaluator.py b/fast_llm/engine/evaluation/evaluator.py index 3bdc2407f..6b8f8db00 100644 --- a/fast_llm/engine/evaluation/evaluator.py +++ b/fast_llm/engine/evaluation/evaluator.py @@ -44,8 +44,6 @@ class EvaluatorSamplingParameters: class Evaluator[ConfigType: EvaluatorConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[EvaluatorConfig]] = EvaluatorConfig - _is_setup: bool = False def __init__( @@ -96,8 +94,6 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: class LossEvaluator[ConfigType: LossEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[LossEvaluatorConfig]] = LossEvaluatorConfig - def setup( self, distributed: Distributed, diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 162ceaf60..9040b11b4 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -25,8 +25,6 @@ class LmEvalEvaluator[ConfigType: LmEvalEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[LmEvalEvaluatorConfig]] = LmEvalEvaluatorConfig - _hf_model: "HuggingfaceBaseModelForCausalLM" = None _flm_wrapper: "FastLLMLmEvalWrapper" = None diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 8f4dffedf..439d1da2e 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -16,7 +16,7 @@ from fast_llm.engine.distributed.distributed import Distributed 
from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 56bae90fe..09ee788e6 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -14,7 +14,6 @@ class FastLLMModel[ConfigType: FastLLMModelConfig](MultiStageModel[ConfigType]): - config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig _is_loaded: bool = False def save_checkpoint( diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index be15cd37a..cb0a02a67 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -9,7 +9,7 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode @@ -320,27 +320,31 @@ def import_state_tensor( return end - begin def export_shard( - self, shard: torch.Tensor, distributed: Distributed, data_type: DataType | None = None + self, shard: torch.Tensor, data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: if data_type is not None: shard = shard.to(dtype=data_type.torch) tensors = 
self.split_buffer(self.reconstruct_from_shard(shard)) for name, meta in self._parameter_metas.items(): - yield name, meta.local_to_global(tensors[name], distributed=distributed)[0] + yield name, meta.local_to_global(tensors[name])[0] def log_shard(self, name, shard, *, distributed: Distributed, level, global_: bool) -> None: # if global_ is None: # global_ = self._config.debug_global_tensors parameters = self.split_buffer(self.reconstruct_from_shard(shard)) if global_ else self.split_shard(shard) for parameter_name, parameter in parameters.items(): + meta = self.get_parameter_meta(parameter_name) log_distributed_tensor( name, parameter, level=level, - distributed=distributed, global_=global_, - duplicate_groups=(distributed.data_group,), - meta=self.get_parameter_meta(parameter_name), + # Assuming all tensors are either duplicated of parallel in the TP direction. + duplicate_groups=( + distributed.data_group, + distributed.tensor_group, + ), + meta=meta, ) def restore_parameters(self) -> None: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 1f734268b..d939bda2b 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -12,7 +12,7 @@ from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode @@ -26,7 +26,6 @@ class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]): - config_class: typing.ClassVar[type[FastLLMModelConfig]] 
= FastLLMModelConfig base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel _is_setup: bool = False _flat_shard: torch.Tensor diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 87eac31c4..35547cd87 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.config import StageConfig, StageMode from fast_llm.engine.multi_stage.stage_base import StageBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage, log_tensor from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient @@ -30,7 +30,7 @@ def hook(grad_inputs, grad_outputs): # noqa return hook -class Stage(StageBase): +class Stage[ConfigType: StageConfig](StageBase[ConfigType]): _is_restored: bool _training: bool | None = None # TODO: Handle all buffer sharing in multi_stage @@ -123,7 +123,7 @@ def forward( # Last layer does not provide output if output is not None: meta = self._meta_outputs[i] - output_global, _ = meta.local_to_global(output.detach(), distributed=self._distributed) + output_global, _ = meta.local_to_global(output.detach()) kwargs["hidden_states"][self._layer_range[i]] = { "layer_type": type(layer).__name__, "tensor": output_global, @@ -216,11 +216,13 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] if (nms := kwargs.get("micro_batch_splits", 1)) > 1: name = f"{name}, ms={kwargs.get('micro_batch_split',0)}/{nms}" + # Assuming all tensors are either duplicated of parallel in the TP direction. 
log_distributed_tensor( name, output, level=self._config.debug_layer_outputs, - distributed=self._distributed, + # Assuming all tensors are either duplicated of parallel in the TP direction. + duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_outputs[i], ) @@ -250,8 +252,9 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any name, input_, level=self._config.debug_layer_gradients, - distributed=self._distributed, grad_fn=lambda grad: grad / self._fsdp_size, + # Assuming all tensors are either duplicated of parallel in the TP direction. + duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_inputs[i], ) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 3218a1963..ded24e538 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -20,8 +20,7 @@ logger = logging.getLogger(__name__) -class StageBase(Configurable[StageConfig]): - config_class: typing.ClassVar[type[StageConfig]] = StageConfig +class StageBase[ConfigType: StageConfig](Configurable[ConfigType]): _distributed: Distributed _mode: StageMode @@ -315,7 +314,7 @@ def _export_shard( self, shards: tuple[torch.Tensor], data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: for fsdp, shard in zip(self._fsdps, shards, strict=True): - yield from fsdp.export_shard(shard, self._distributed, data_type) + yield from fsdp.export_shard(shard, data_type) def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index b367eff81..21ecbe476 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -56,15 +56,14 @@ def done(self): def __repr__(self): return ( - 
f"BatchContext(iteration={self.iteration}," - f" batch_len={len(self.batch)}," + f"BatchContext(batch_len={len(self.batch)}," f" inputs={list(self.inputs)}," f" contexts={list(self.contexts)}," f" losses={ {key: len(value) for key, value in self.losses.items()}}," ) -class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): +class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ConfigType]): _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream @@ -192,11 +191,7 @@ def run_step( # Run the steps according to the schedule for step in schedule: - try: - self._train_step(context, step) - except Exception as e: - # Add debugging information to errors, which may be cryptic outside the forward pass. - raise RuntimeError(f"Schedule step {step} failed, context = {context}: {e}") + self._train_step(context, step) # Make sure we used all the data. This also ensures the generator terminates and prevents a memory leak. try: @@ -400,7 +395,7 @@ def _recv(self, context: BatchContext, step: Step) -> None: step.recv_event.wait() self._record_event(context, EventType.compute_wait_pipe, step) - def _forward(self, context: BatchContext, step: Step) -> torch.Tensor | None: + def _forward(self, context: BatchContext, step: Step) -> None: output, grad_context = self._stages[step.stage].forward( self._get_forward_input(context, step), context.batch[step.data_index], diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index ec3c4cebe..c9b272366 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -43,8 +43,6 @@ class TrainingEvaluator[ConfigType: TrainingEvaluatorConfig](Evaluator[ConfigType]): - config_class: typing.ClassVar[type[TrainingEvaluatorConfig]] = TrainingEvaluatorConfig - evaluator: Evaluator def __init__( @@ -114,7 +112,6 @@ def get_sampling_parameters(self) -> EvaluatorSamplingParameters | None: class Trainer[ConfigType: 
TrainerConfig](Configurable[ConfigType], abc.ABC): - config_class: typing.ClassVar[type[TrainerConfig]] = TrainerConfig # TODO: Generalize data, schedule, logging, etc. _is_setup: bool = False _distributed: Distributed diff --git a/fast_llm/layers/transformer/__init__.py b/fast_llm/layers/attention/__init__.py similarity index 100% rename from fast_llm/layers/transformer/__init__.py rename to fast_llm/layers/attention/__init__.py diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/attention/attention.py similarity index 58% rename from fast_llm/layers/transformer/attention.py rename to fast_llm/layers/attention/attention.py index c59b191af..8a4c490c9 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -4,18 +4,16 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockConfig, BlockDimNames +from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import ( - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerSubLayerName, -) -from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import get_lr_scale +from 
fast_llm.utils import combine_lr_scales, div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,35 +48,52 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(Mixer): +class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): """ A self-attention layer. """ - _mixer_name: typing.ClassVar[str] = "attn" - - _QUERY_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_heads, - TransformerDimNames.kv_channels, - ) - _KV_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.head_groups, - TransformerDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - TransformerDimNames.batch, - TransformerDimNames.sequence_q, - TransformerDimNames.composite_dense, - ) - - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): - super().__init__(tensor_space, block_index, config.debug_transformer) - self._config = config - self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) + def __init__( + self, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + ): + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( + DistributedDimNames.sequence_data + ) + head_group_dim = TensorDim( + "head_groups", self._config.head_groups, self._parallel_dim if self._config.head_groups > 1 else None + ) + group_heads_dim = TensorDim( + "group_heads", + 
div(self._config.num_attention_heads, self._config.head_groups), + None if self._config.head_groups > 1 else self._parallel_dim, + ) + self._local_head_groups = head_group_dim.size + self._local_heads_per_group = group_heads_dim.size + self._local_heads = self._local_head_groups * self._local_heads_per_group + + kv_channels_dim = TensorDim("kv_channels", self._config.kv_channels) + query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, kv_channels_dim)) + key_value_dim = ConcatenatedTensorDim( + "key_value", + ( + CompositeTensorDim("key", (head_group_dim, kv_channels_dim)), + CompositeTensorDim("value", (head_group_dim, kv_channels_dim)), + ), + ) + dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, kv_channels_dim)) + + self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) init_method_qkv = init_normal_( std=self._config.init_method_std_qkv, @@ -91,57 +106,69 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size - self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size - self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - - hidden_dim = self._tensor_space[TransformerDimNames.hidden] - - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales( + self._lr_scale, + self._config.attention_lr_scale, + ) # TODO: Merge the query and key-value computations? 
(harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_query], - bias=self._config.add_attn_qkv_bias, + query_dim, + bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[TransformerDimNames.composite_key_value], - bias=self._config.add_attn_qkv_bias, + key_value_dim, + bias=self._config.add_qkv_bias, weight_init_method=init_method_qkv, - bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.build() + self._rotary = self._config.rotary.get_layer(kv_channels_dim) # Output. self.dense = InputParallelLinear( - self._tensor_space[TransformerDimNames.composite_dense], + dense_dim, hidden_dim, - bias=self._config.add_attn_dense_bias, + bias=self._config.add_dense_bias, weight_init_method=init_method_std_attn_proj, - bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_, + bias_init_method=init_zeros_, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) # PEFT. 
- self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query) - self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) - self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + self.query = self._block_config.peft.apply_linear(self.query, TransformerSubLayerName.query) + self.key_value = self._block_config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) + self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + + if self._debug.enabled: + self._query_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), + kv_channels_dim, + ) + self._kv_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + head_group_dim, + kv_channels_dim, + ) + self._context_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + dense_dim, + ) def _attn_fused( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor @@ -151,16 +178,18 @@ def _attn_fused( sk = key.size(1) if self._local_head_groups == 1: - query = query.view(b, sq * self._local_heads, self._kv_channels) + query = query.view(b, sq * self._local_heads, self._config.kv_channels) key = key.transpose(-1, -2) else: query = ( - query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._kv_channels)) + query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.kv_channels)) .transpose(1, 2) - .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._kv_channels) + .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.kv_channels) + ) + key = key.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).movedim(1, 3).flatten(0, 1) + value = ( + value.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).transpose(1, 
2).flatten(0, 1) ) - key = key.unflatten(-1, (self._local_head_groups, self._kv_channels)).movedim(1, 3).flatten(0, 1) - value = value.unflatten(-1, (self._local_head_groups, self._kv_channels)).transpose(1, 2).flatten(0, 1) attn_weights = torch.empty( (b * self._local_head_groups, sq * self._local_heads_per_group, sk), device=query.device, dtype=query.dtype @@ -177,7 +206,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._tensor_space.distributed.tp_generator): + with set_generator(self._distributed.tp_generator): attn_weights = torch.dropout(attn_weights, self._config.attention_dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value @@ -187,7 +216,7 @@ def _attn_fused( return attn_output.view(b, sq, -1) else: return ( - attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._kv_channels) + attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.kv_channels) .transpose(1, 2) .flatten(2) ) @@ -199,18 +228,16 @@ def _query_key_value_forward( handle = None - if self._head_groups == 1 and self._sequence_parallel: - key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.tensor_group, dim=0, async_op=True - ) + if self._config.head_groups == 1 and self._sequence_parallel: + key_value, handle = gather_op(key_value, group=self._parallel_dim.group, dim=0, async_op=True) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: if handle: # TODO: This is probably unnecessary. 
handle.wait() # sequence dim may not be zero, but this needs to be handled after `handle.wait()` key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.sequence_data_group, dim=0, async_op=True + key_value, group=self._sequence_data_parallel_dim.group, dim=0, async_op=True ) query, query_context = self.query.forward_only(input_) @@ -218,8 +245,8 @@ def _query_key_value_forward( if handle: handle.wait() - if self._tensor_space.distributed.sequence_data_group and not sequence_first: - key_value = swap_mult_dim(key_value, self._tensor_space.distributed_config.sequence_data_parallel, 0, 1) + if self._sequence_data_parallel_dim.group and not sequence_first: + key_value = swap_mult_dim(key_value, self._sequence_parallel, 0, 1) context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first} return query, key_value, context @@ -228,15 +255,12 @@ def _query_key_value_backward( self, query_grad: torch.Tensor, key_value_grad: torch.Tensor, context: dict ) -> torch.Tensor: # TODO: De-allocate qkv grads quicker. - handle = None - - if self._tensor_space.distributed.sequence_data_group: - key_value_grad, handle = reduce_scatter_op( - key_value_grad, - group=self._tensor_space.distributed.sequence_data_group, - dim=1 - context["sequence_first"], - async_op=True, - ) + key_value_grad, handle = reduce_scatter_op( + key_value_grad, + group=self._sequence_data_parallel_dim.group, + dim=1 - context["sequence_first"], + async_op=True, + ) # TODO: Overlap with both. 
input_grad = self.query.backward(query_grad, context.pop("query")) @@ -244,7 +268,7 @@ def _query_key_value_backward( if handle: handle.wait() - if self._head_groups == 1 and (group := self._tensor_space.distributed.tensor_group): + if self._config.head_groups == 1 and (group := self._parallel_dim.group): if self._sequence_parallel: key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0) else: @@ -263,29 +287,35 @@ def _decide_window_size(self) -> int | None: return window_size - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[TransformerKwargs.sequence_first] + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + sequence_first = kwargs[AttentionKwargs.sequence_first] query, key_value = self._query_key_value(input_, sequence_first) # TODO: Move the rest to function. - if (past_key_values := kwargs.get(TransformerKwargs.past_key_values)) is not None: + if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: assert sequence_first # Clear the lists so tensors can be de-allocated key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - if (presents := kwargs.get(TransformerKwargs.presents)) is not None: + if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences # don't propagate to this one. presents.append(present := key_value.detach().requires_grad_()) # Manually add the gradients from later micro-sequences. 
key_value = AttachGrad.apply(key_value, present) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: key_value = ( - key_value[: kwargs[TransformerKwargs.sequence_k_dim].size] + key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first - else key_value[:, : kwargs[TransformerKwargs.sequence_k_dim].size] + else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] ) if sequence_first: @@ -293,28 +323,23 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query = query.transpose(0, 1).contiguous() key_value = key_value.transpose(0, 1).contiguous() - key, value = key_value.split(self._local_head_groups * self._kv_channels, dim=-1) + key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1) - query = query.view(*query.shape[:2], self._local_heads, self._kv_channels) - key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) - value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) + query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels) + key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels) + value = value.view(*value.shape[:2], self._local_head_groups, self._config.kv_channels) - if self._debug_level: - self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key_rotary_input", - self._KV_DIMS, - kwargs, - ) + if self._debug.enabled: + self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) + self._debug(key, "key_rotary_input", self._KV_DIMS, kwargs) query, key = self._rotary(query, key, kwargs) window_size = self._decide_window_size() if self._use_flash_attention: assert _flash_available - with set_generator(self._tensor_space.distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(TransformerKwargs.cu_seqlens_q, None)) is not None: + with 
set_generator(self._distributed.tp_generator): + if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) key = key.view(-1, key.size(-2), key.size(-1)) @@ -324,9 +349,9 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key, value, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(TransformerKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(TransformerKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(TransformerKwargs.max_seqlen_k), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), dropout_p=self._config.attention_dropout if self.training else 0.0, window_size=(-1, -1) if window_size is None else (window_size - 1, 0), causal=True, @@ -349,25 +374,15 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ query.flatten(-2), key.flatten(-2), value.flatten(-2), - kwargs[TransformerKwargs.attention_mask], - kwargs[TransformerKwargs.attention_mask_value], + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], ) - if self._debug_level: - self._debug_log(query, "query", self._QUERY_DIMS, kwargs) - self._debug_log( - key, - "key", - self._KV_DIMS, - kwargs, - ) - self._debug_log( - value, - "value", - self._KV_DIMS, - kwargs, - ) - self._debug_log(input_, "context", self._CONTEXT_DIMS, kwargs) + if self._debug.enabled: + self._debug(query, "query", self._query_dims, kwargs) + self._debug(key, "key", self._kv_dims, kwargs) + self._debug(value, "value", self._kv_dims, kwargs) + self._debug(input_, "context", self._context_dims, kwargs) if sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
diff --git a/fast_llm/layers/attention/block.py b/fast_llm/layers/attention/block.py new file mode 100644 index 000000000..3396a2997 --- /dev/null +++ b/fast_llm/layers/attention/block.py @@ -0,0 +1,22 @@ +import functools +import logging +import typing + +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionConfig, TransformerConfig +from fast_llm.layers.block.block import Block + +logger = logging.getLogger(__name__) + + +class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" + + @functools.cached_property + def _mixer_class(self) -> type[Attention]: + return Attention + + @property + def _mixer_config(self) -> AttentionConfig: + return self._config diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py new file mode 100644 index 000000000..e5c638adc --- /dev/null +++ b/fast_llm/layers/attention/config.py @@ -0,0 +1,198 @@ +import functools +import logging +import warnings + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.functional.config import TritonConfig +from fast_llm.layers.attention.rotary.config import RotaryConfig +from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs +from fast_llm.utils import Assert, div + +logger = logging.getLogger(__name__) + + +class AttentionKwargs(BlockKwargs): + rotary_freq_q = "rotary_freq_q" + rotary_freq_k = "rotary_freq_k" + attention_mask = "attention_mask" + attention_mask_value = "attention_mask_value" + cu_seqlens_q = "cu_seqlens_q" + cu_seqlens_k = "cu_seqlens_k" + max_seqlen_q = "max_seqlen_q" + max_seqlen_k = "max_seqlen_k" + # TODO: Review these + presents = "presents" + past_key_values = 
@config_class()
class AttentionConfig(Config):
    """
    Configuration of the attention mixer: head layout, rotary embeddings, dropout, sliding window,
    learning-rate scaling and weight-initialization overrides.

    NOTE(review): `_validate` and the bias properties read `hidden_size`, `num_layers`,
    `init_method_std`, `init_method_min`, `init_method_max` and `add_linear_biases`, none of which
    are declared on this class. They are presumably supplied by a sibling base class
    (see `TransformerConfig` below and the "without inheritance" TODOs) -- confirm before using
    this class on its own.
    """

    # TODO: Make mixer class dynamic.
    _abstract = False

    # TODO: Review names
    rotary: RotaryConfig = Field(
        desc="Configuration for the rotary positional embeddings.",
        hint=FieldHint.architecture,
    )
    num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture)
    head_groups: int = Field(
        default=1,
        desc="Number of head group for grouped query attention.",
        doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.",
        hint=FieldHint.architecture,
        valid=check_field(Assert.gt, 0),
    )
    # Left as `None` here; resolved to `hidden_size // num_attention_heads` in `_validate`.
    kv_channels: int = Field(
        default=None,
        desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads",
        hint=FieldHint.architecture,
        valid=check_field(Assert.gt, 0),
    )
    attention_dropout: float = Field(
        default=0.0,
        desc="Dropout applied to the attention intermediate states.",
        hint=FieldHint.feature,
        valid=check_field(Assert.geq, 0),
    )
    # Use flash attention if possible (fp16 or bf16)
    use_flash_attention: bool = Field(
        default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional
    )
    window_size: int | None = Field(
        default=None,
        desc="Size of the attention sliding window. Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.",
        hint=FieldHint.feature,
        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
    )
    max_window_layers: int | None = Field(
        default=None,
        desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.",
        hint=FieldHint.optional,
        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
    )
    attention_lr_scale: float | None = Field(
        default=None,
        desc="Custom learning rate scale for the Attention projection weights.",
        doc="Can be used in muP to scale the Attention learning rate by 1/width_factor",
        hint=FieldHint.feature,
        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
    )
    attention_softmax_scale_power: float = Field(
        default=0.5,
        desc="The scaling power to apply to kv_channel in the attention calculation. "
        " Under Standard Parameterization (SP): default to 0.5. "
        " Under muP (if scaling kv_channels size): use 1. "
        " Under muP (if scaling number of heads instead of kv_channels): use 0.5.",
        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
    )
    # TODO: Review initialization
    # The `init_method_*` overrides below default to `None` and are filled in by `_validate`.
    init_method_std_qkv: float = Field(
        default=None,
        desc="Scale for the query, key and value weight initialization. Default: init_method_std",
        hint=FieldHint.optional,
        valid=check_field(Assert.geq, 0),
    )
    init_method_max_qkv: float | None = Field(
        default=None,
        desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')",
        hint=FieldHint.optional,
    )
    init_method_min_qkv: float | None = Field(
        default=None,
        desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')",
        hint=FieldHint.optional,
    )
    init_method_std_attn_proj: float = Field(
        default=None,
        desc="Scale for the attention projection weight initialization. Default: init_method_std",
        hint=FieldHint.optional,
        valid=check_field(Assert.geq, 0),
    )
    init_method_max_attn_proj: float | None = Field(
        default=None,
        desc="Max value for clamping initialized weights for attention projection. Default: float('inf')",
        hint=FieldHint.optional,
    )
    init_method_min_attn_proj: float | None = Field(
        default=None,
        desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        """
        Fill in the implicit defaults, check the clamping bounds, then run the parent validation.

        Unset fields are derived as follows:
        * `kv_channels` <- `hidden_size / num_attention_heads` (via `div`),
        * `init_method_std_qkv` <- `init_method_std`,
        * `init_method_std_attn_proj` <- `init_method_std / sqrt(max(2 * num_layers, 1))`,
        * the qkv / attention-projection min and max <- `init_method_min` / `init_method_max`.

        After the parent validation, warns if Triton is disabled and asserts that
        `num_attention_heads` is a multiple of `head_groups`.
        """
        with self._set_implicit_default():
            # TODO: Make this work without inheritance.
            if self.kv_channels is None:
                self.kv_channels = div(self.hidden_size, self.num_attention_heads)
            # TODO: Review initialization
            if self.init_method_std_qkv is None:
                self.init_method_std_qkv = self.init_method_std
            if self.init_method_std_attn_proj is None:
                self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5
            if self.init_method_max_qkv is None:
                self.init_method_max_qkv = self.init_method_max
            if self.init_method_min_qkv is None:
                self.init_method_min_qkv = self.init_method_min
            if self.init_method_max_attn_proj is None:
                self.init_method_max_attn_proj = self.init_method_max
            if self.init_method_min_attn_proj is None:
                self.init_method_min_attn_proj = self.init_method_min
            # Guard on the fields that are actually compared. The previous guard tested the qkv
            # bounds, so explicit qkv bounds combined with unset global bounds ended up comparing
            # `None` with `None`.
            if self.init_method_min is not None and self.init_method_max is not None:
                Assert.leq(self.init_method_min, self.init_method_max)
            if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None:
                Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv)
            if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None:
                Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj)

        super()._validate()

        if not TritonConfig.TRITON_ENABLED:
            warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")

        Assert.multiple(self.num_attention_heads, self.head_groups)

    @functools.cached_property
    def projection_size(self):
        """Return `num_attention_heads * kv_channels`; only available once the config is validated."""
        assert self._validated
        return self.num_attention_heads * self.kv_channels

    def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
        """Return whether flash attention is both requested and usable (training dtype is fp16 or bf16)."""
        return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)

    @property
    def add_qkv_bias(self) -> bool:
        """Whether the query/key/value projections get a bias: a plain bool is used as is, otherwise every choice but `nowhere` enables it."""
        # TODO: Make this work without inheritance.
        if isinstance(self.add_linear_biases, bool):
            return self.add_linear_biases
        if self.add_linear_biases == AddLinearBiasChoices.nowhere:
            return False
        return True

    @property
    def add_dense_bias(self) -> bool:
        """Whether the dense (output) projection gets a bias: a plain bool is used as is, otherwise only `everywhere` enables it."""
        # TODO: Make this work without inheritance.
        if isinstance(self.add_linear_biases, bool):
            return self.add_linear_biases
        if self.add_linear_biases == AddLinearBiasChoices.everywhere:
            return True
        return False


@config_class()
# TODO: Use composition instead
class TransformerConfig(AttentionConfig, BlockConfig):
    """Concrete transformer configuration combining the attention and block configurations."""

    _abstract = False

    def _validate(self) -> None:
        """Resolve `init_method_std` and check the global clamping bounds before the parents derive their defaults from them."""
        with self._set_implicit_default():
            # Kept here for initialization order.
            # TODO: Review initialization
            if self.init_method_std is None:
                self.init_method_std = self.hidden_size**-0.5
            if self.init_method_min is not None and self.init_method_max is not None:
                Assert.leq(self.init_method_min, self.init_method_max)

        super()._validate()
AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) class BackupAttentionPreprocessor(Preprocessor): - _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -38,7 +32,7 @@ def _create_tensors(self, sequence_length: int) -> None: self._mask = torch.ones( (sequence_length, sequence_length), dtype=torch.bool, - device=self._tensor_space.distributed.device, + device=device, ).tril_() if self._config.window_size is not None: @@ -47,57 +41,56 @@ def _create_tensors(self, sequence_length: int) -> None: [], torch.finfo(self._distributed_config.training_dtype.torch).min, dtype=self._distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, + device=device, ) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - kwargs[TransformerKwargs.attention_mask] = self._mask[ + def preprocess(self, batch: torch.Tensor, kwargs: 
dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + kwargs[AttentionKwargs.attention_mask] = self._mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths, None)) is not None: + if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) for sample_lens in sequence_lengths ] ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) - kwargs[TransformerKwargs.attention_mask] = ( - kwargs[TransformerKwargs.attention_mask] + document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) + kwargs[AttentionKwargs.attention_mask] = ( + kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] ) - kwargs[TransformerKwargs.attention_mask_value] = self._mask_value + kwargs[AttentionKwargs.attention_mask_value] = self._mask_value def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( + kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( ( - self._scalar_dim, - self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], - self._scalar_dim, - kwargs[TransformerKwargs.sequence_k_dim], + scalar_dim, + scalar_dim, + kwargs[AttentionKwargs.sequence_q_dim], + scalar_dim, + kwargs[AttentionKwargs.sequence_k_dim], ), - tensor_name=TransformerKwargs.attention_mask, + tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) - kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( - (self._scalar_dim,), - tensor_name=TransformerKwargs.attention_mask_value, - 
dtype=self._tensor_space.distributed_config.training_dtype.torch, + kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( + (scalar_dim,), + tensor_name=AttentionKwargs.attention_mask_value, + dtype=self._distributed_config.training_dtype.torch, ) class FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -107,12 +100,12 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. 
""" - if TransformerKwargs.sequence_lengths not in kwargs: + if AttentionKwargs.sequence_lengths not in kwargs: return - sequence_lengths = kwargs[TransformerKwargs.sequence_lengths] - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if sequence_q < kwargs[TransformerKwargs.sequence_length]: + sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size + if sequence_q < kwargs[AttentionKwargs.sequence_length]: cumsums = [torch.cumsum(x, dim=0) for x in sequence_lengths] # The first and last documents in a microsequence need to be handled separately. Include all tokens from other documents # in the microsequence. We need to consider all keys computed so far from the first sample. We also store the offsets @@ -146,17 +139,17 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: else: seqlens_q = torch.cat(sequence_lengths) seqlens_k = torch.cat(sequence_lengths) - kwargs[TransformerKwargs.cu_seqlens_q] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), ) ) - kwargs[TransformerKwargs.cu_seqlens_k] = torch.cat( + kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), ) ) - kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() - 
kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() + kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() + kwargs[AttentionKwargs.max_seqlen_k] = seqlens_k.max() diff --git a/fast_llm/layers/transformer/rotary/__init__.py b/fast_llm/layers/attention/rotary/__init__.py similarity index 100% rename from fast_llm/layers/transformer/rotary/__init__.py rename to fast_llm/layers/attention/rotary/__init__.py diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/attention/rotary/config.py similarity index 88% rename from fast_llm/layers/transformer/rotary/config.py rename to fast_llm/layers/attention/rotary/config.py index 748f2af28..4ebd6c5dc 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -5,12 +5,12 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary @config_class(registry=True) @@ -29,8 +29,8 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, tensor_space: TensorSpace | None = None) -> "Rotary": - return self._get_configurable_class()(self, tensor_space) + def get_layer(self, kv_channels_dim: TensorDim) -> "Rotary": + return self._get_configurable_class()(self, kv_channels_dim) @classmethod @abc.abstractmethod @@ -44,7 +44,7 @@ class NoRotaryConfig(RotaryConfig): @classmethod def _get_configurable_class(self) -> "type[NoRotary]": - from 
fast_llm.layers.transformer.rotary.rotary import NoRotary + from fast_llm.layers.attention.rotary.rotary import NoRotary return NoRotary @@ -75,7 +75,7 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") def _get_configurable_class(self) -> "type[DefaultRotary]": - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary return DefaultRotary @@ -97,7 +97,7 @@ def _validate(self) -> None: Assert.gt(self.high_frequency_factor, self.low_frequency_factor) def _get_configurable_class(self) -> "type[Llama3Rotary]": - from fast_llm.layers.transformer.rotary.rotary import Llama3Rotary + from fast_llm.layers.attention.rotary.rotary import Llama3Rotary return Llama3Rotary @@ -137,6 +137,6 @@ def _validate(self) -> None: super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": - from fast_llm.layers.transformer.rotary.rotary import YarnRotary + from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py similarity index 71% rename from fast_llm/layers/transformer/rotary/rotary.py rename to fast_llm/layers/attention/rotary/rotary.py index 17b18a1ca..53b24c9bb 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -6,10 +6,10 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.rotary.config import ( +from fast_llm.layers.attention.config import 
AttentionKwargs +from fast_llm.layers.attention.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, @@ -41,14 +41,14 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) -class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): +class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module, Preprocessor): def __init__( self, config: ConfigType, - # The tensor space is only needed for preprocessing, so we make it optional. - tensor_space: TensorSpace | None = None, + kv_channels_dim: TensorDim, ): super().__init__(config) + self._kv_channels_dim = kv_channels_dim @abc.abstractmethod def forward( @@ -57,7 +57,7 @@ def forward( pass -class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): +class NoRotary[ConfigType: NoRotaryConfig](Rotary[ConfigType]): def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -70,60 +70,47 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass -class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): +class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: ConfigType, - tensor_space: TensorSpace | None = None, - ): - super().__init__(config, tensor_space) - self._tensor_space = tensor_space - if self._tensor_space is not None: - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = 
kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) + sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ + :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None - kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( - self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, + kwargs[AttentionKwargs.sequence_q_dim], + scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_q, + tensor_name=AttentionKwargs.rotary_freq_q, ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( + kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( - self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, + kwargs[AttentionKwargs.sequence_q_dim], + scalar_dim, self._kv_channels_dim, ), - tensor_name=TransformerKwargs.rotary_freq_k, + tensor_name=AttentionKwargs.rotary_freq_k, ) def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings - query = rotary_fn(query, kwargs[TransformerKwargs.rotary_freq_q]) - key = rotary_fn(key, 
kwargs[TransformerKwargs.rotary_freq_k]) + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -131,10 +118,10 @@ def _create_tensors(self, sequence_length: int) -> None: self._rotary_embedding_frequencies = self._get_frequencies( sequence_length, self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, + device=device, ) - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, # `a = theta ** - (2 * (channel // 2) / kv_channels)`, @@ -149,12 +136,12 @@ def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda" ).contiguous() return frequencies - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: return self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) -class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: +class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[ConfigType]): + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / 
self._config.low_frequency_factor high_frequency_wavelength = self._config.original_context_length / self._config.high_frequency_factor @@ -173,17 +160,17 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: return torch.stack(new_scales) -class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): +class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[ConfigType]): """ Yarn scaling: https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 [original paper](https://arxiv.org/abs/2309.00071) """ - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, device) * self._config.attention_factor - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) # TODO: max_position_embeddings or original_context_length? 
# see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 diff --git a/fast_llm/layers/block/__init__.py b/fast_llm/layers/block/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py new file mode 100644 index 000000000..f90fce698 --- /dev/null +++ b/fast_llm/layers/block/block.py @@ -0,0 +1,273 @@ +import abc +import functools +import logging +import typing + +import torch + +from fast_llm.config import Config, Configurable +from fast_llm.core.distributed import set_generator +from fast_llm.engine.base_model.base_model import Layer, Module +from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage +from fast_llm.tensor import TensorMeta + +logger = logging.getLogger(__name__) + + +class DebugLayer: + # TODO: Move elsewhere? 
+ def __init__(self, name: str, debug_level: int = 0, debug_memory: bool = False): + self._name = name + self._debug_level = debug_level + self._debug_memory = debug_memory + + def _get_meta( + self, tensor: torch.Tensor, name: str, dims: tuple[TensorDim | str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim for dim in kwargs[BlockKwargs.hidden_dims] + (kwargs[BlockKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + ( + dim + if isinstance(dim, TensorDim) + else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i)) + ) + for i, dim in enumerate(dims) + ), + tensor_name=f"{self._name} {name}", + dtype=tensor.dtype, + ) + + @functools.cached_property + def enabled(self) -> bool: + return self._debug_level > 0 or self._debug_memory + + def __call__[ + T + ]( + self, + tensor: torch.Tensor | None, + name: str, + dims: tuple[TensorDim | str, ...], + kwargs: dict[str, typing.Any], + scale: float = 1.0, + global_: bool = True, + log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, + ) -> None: + # TODO: Local vs global? + if self._debug_memory: + log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) + if self._debug_level > 0 and tensor is not None: + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dims, kwargs), + global_=global_, + log_fn=log_fn, + scale=scale, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dims, kwargs), + global_=global_, + log_fn=log_fn, + scale=scale, + ) + + +class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): + """ + Base class for blocks, mixer and MLP modules. 
+ """ + + def __init__( + self, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + # TODO: Review `hidden_dim` and `block_index` + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + ): + super().__init__(config, distributed_config) + self._block_config = block_config + self._hidden_dim = hidden_dim + self._block_index = block_index + self._name = name + self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel + self._debug = DebugLayer( + self._name, + self._block_config.debug_transformer, + self._block_config.debug_transformer_memory, + ) + self._lr_scale = lr_scale + + +class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): + """ + Base class for mixer and MLP modules. + """ + + @abc.abstractmethod + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + pass + + +class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): + """ + A transformer-like decoder base block with abstract mixer. + """ + + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + return_input: bool = False, + ): + super().__init__( + config, + config, + distributed_config, + hidden_dim, + block_index, + name, + lr_scale, + ) + # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
+ self._return_input: bool = return_input + # Note, layer_lr_scale does not impact the norms + # TODO: add a separate norm_lr_scale + self.norm_1 = self._config.peft.apply_other( + self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) + ) + self.norm_2 = self._config.peft.apply_other( + self._config.normalization.get_layer(self._hidden_dim, self._lr_scale) + ) + + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. + setattr( + self, + self._mixer_module_name, + self._mixer_class( + self._mixer_config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} mixer", + self._lr_scale, + ), + ) + + # TODO: Use dynamic type. + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + from fast_llm.layers.block.mlp.mlp import MLP + + self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( + self._config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} MLP", + self._lr_scale, + ) + + @functools.cached_property + @abc.abstractmethod + def _mixer_class(self) -> type[BlockLayer]: + pass + + @property + @abc.abstractmethod + def _mixer_config(self) -> Config: + pass + + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + getattr(self, self._mixer_module_name).setup(distributed) + self.mlp.setup(distributed) + + @torch.compile + def _bias_dropout_add( + self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor + ) -> torch.Tensor: + if bias is not None: + input_ = input_ + bias + return residual + torch.dropout(input_, self._config.hidden_dropout, self.training) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + dims = 
kwargs[BlockKwargs.hidden_dims] + if self._return_input: + dims = (TensorDim("stacked_input_output", 2),) + dims + return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) + generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator + if self._debug.enabled: + self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) + fw_input = input_ + hidden_states = self.norm_1(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "mixer output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + input_ = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(input_, "mixer residual", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states = self.norm_2(input_) + if self._debug.enabled: + self._debug(hidden_states, "norm 2", kwargs[BlockKwargs.hidden_dims], kwargs) + hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) + if self._debug.enabled: + self._debug( + hidden_states if bias is None else hidden_states + bias, + "MLP output", + kwargs[BlockKwargs.hidden_dims], + kwargs, + ) + with set_generator(generator): + hidden_states = self._bias_dropout_add(hidden_states, bias, input_) + if self._debug.enabled: + self._debug(None, "MLP residual", kwargs[BlockKwargs.hidden_dims], kwargs) + if self._return_input: + hidden_states = torch.stack((fw_input, hidden_states), dim=0) + return hidden_states diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py new file mode 100644 index 000000000..29acaadf0 --- /dev/null +++ b/fast_llm/layers/block/config.py @@ -0,0 +1,121 @@ +import enum + +from fast_llm.config import Field, FieldHint, 
check_field, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.layers.block.mlp.config import MLPConfig +from fast_llm.layers.block.peft import TransformerPeftConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.utils import Assert + +# TODO: Generalize these beyond language models? (Ex. vision) + + +class BlockDimNames: + # A set of common tensor dim names packed into a namespace. + # Input dimensions (variable) + # TODO: Does batch belong here? + batch = "batch" + # TODO: Distinguish micro-sequence? + sequence_q = "sequence_q" + sequence_q_tp = "sequence_q_tp" + sequence_k = "sequence_k" + hidden = "hidden" + + +class BlockKwargs: + sequence_first = "sequence_first" + hidden_dims = "hidden_dims" + sequence_q_dim = "sequence_q_dim" + sequence_k_dim = "sequence_k_dim" + # TODO: These are confusing + sequence_length = "sequence_length" + sequence_lengths = "sequence_lengths" + # TODO: Belongs elsewhere? 
@config_class()
# TODO: Use composition instead
class BlockConfig(MLPConfig, BaseModelConfig):
    # Configuration of a model block: normalization, PEFT, residual dropout and precision,
    # debug logging and the linear-bias policy, plus model-wide sizes and default weight
    # initialization. The MLP fields come from `MLPConfig` through inheritance; that class
    # reads `hidden_size`, `num_layers`, `add_linear_biases` and the `init_method_*` fields
    # declared here.

    # TODO: Review names
    # Sub-configuration for the normalization layers (declared without an explicit default).
    normalization: NormalizationConfig = Field(
        desc="Configuration for the normalization layers architecture.",
        hint=FieldHint.architecture,
    )
    # Sub-configuration for parameter-efficient fine-tuning (declared without an explicit default).
    peft: TransformerPeftConfig = Field(
        desc="Configuration for the parameter-efficient fine tuning.",
        hint=FieldHint.architecture,
    )
    # TODO: Review names
    # Residual dropout; validated to be >= 0.
    hidden_dropout: float = Field(
        default=0.0,
        desc="Dropout applied to the residual connections.",
        hint=FieldHint.feature,
        valid=check_field(Assert.geq, 0),
    )
    full_precision_residual: bool = Field(
        default=False,
        desc="Store the residuals for the transformer in full precision (`optimization_dtype`).",
        hint=FieldHint.stability,
    )
    # Integer setting (not a flag), validated to be >= 0; the default 0 is falsy.
    debug_transformer: int = Field(
        default=0,
        desc="Log the output of each operation in a transformer layer.",
        hint=FieldHint.logging,
        valid=check_field(Assert.geq, 0),
    )
    debug_transformer_memory: bool = Field(
        default=False,
        desc="Log the memory usage after each operation in a transformer layer..",
        hint=FieldHint.logging,
    )
    # Either a plain bool (all / none) or an `AddLinearBiasChoices` member for finer control.
    add_linear_biases: bool | AddLinearBiasChoices = Field(
        default=True,
        desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.",
        hint=FieldHint.architecture,
    )

    # TODO: Move these, not specific to a single block.
    # Validated to be >= 0, so a model with zero blocks is accepted.
    num_layers: int = Field(
        default=12,
        desc="Number of blocks in the model.",
        hint=FieldHint.architecture,
        valid=check_field(Assert.geq, 0),
    )
    # Validated to be strictly positive.
    hidden_size: int = Field(
        default=1024,
        desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.",
        hint=FieldHint.architecture,
        valid=check_field(Assert.gt, 0),
    )
    # Optional list with one entry per layer; individual entries may be None.
    per_layer_lr_scale: list[float | None] | None = Field(
        default=None,
        desc="Custom learning rate scale for each layer.",
        doc="May be used to freeze some layers by setting their scale to zero.",
        hint=FieldHint.feature,
    )

    # TODO: Review initialization
    # Annotated `float` but defaults to None. The description names `hidden_size**-0.5` as the
    # fallback; the code that applies it is not part of this class.
    # NOTE(review): confirm where the None default is resolved before validation.
    init_method_std: float = Field(
        default=None,
        desc="Default scale for weight initialization. Default: hidden_size**-0.5",
        hint=FieldHint.optional,
        valid=check_field(Assert.geq, 0),
    )
    # Optional clamping bounds for initialized weights; None leaves the bound unset here.
    init_method_max: float | None = Field(
        default=None,
        desc="Max value for clamping initialized weights. Default: float('inf')",
        hint=FieldHint.optional,
    )
    init_method_min: float | None = Field(
        default=None,
        desc="Min value for clamping initialized weights. Default: -float('inf')",
        hint=FieldHint.optional,
    )
Default: 4 * hidden_size.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + num_experts: int = Field( + default=1, + desc="Number of MLP experts in a Mixture of Expert (MoE) model", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + num_shared_experts: int = Field( + default=0, + desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_unshared_experts: int = Field( + init=False, + desc="Number of MLP experts excluding shared ones", + hint=FieldHint.architecture, + valid=check_field(Assert.geq, 0), + ) + num_experts_per_token: int = Field( + default=1, + desc="Active experts for each token in a MoE model.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + expert_routing_type: RoutingType = Field( + default=RoutingType.topk, + desc="The routing method, i.e., the method used to assign experts to tokens.", + hint=FieldHint.architecture, + ) + gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) + activation_type: ActivationType = Field( + default=None, + desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", + hint=FieldHint.core, + ) + # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto + mlp_recompute_level: MLPRecomputeLevel = Field( + default=MLPRecomputeLevel.none, + desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. 
This provides a trade-off between memory and speed.", + hint=FieldHint.performance, + ) + expert_auxiliary_loss_coefficient: float = Field( + default=0.01, + desc="Scale of the load balancing auxiliary loss for topk routing.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + expert_z_loss_coefficient: float = Field( + default=0.0, + desc="Regularize the router during training by applying Z-loss to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + moe_jitter_eps: float = Field( + default=0.0, + desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) + mlp_lr_scale: float | None | tuple[float | None, ...] = Field( + default=None, + desc="Custom learning rate scale for each expert.", + doc="May be used to freeze some experts by setting their scale to zero.", + hint=FieldHint.feature, + ) + router_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate for the MoE router weight.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + dropless_moe: bool = Field( + default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert + ) + dropless_dynamic_shape: bool = Field( + default=False, + desc="Use a dynamic shape for dropless MLP instead of the worst-case value." + " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", + hint=FieldHint.expert, + ) + # TODO: Review initialization + init_method_std_mlp_1: float = Field( + default=None, + desc="Scale for the MLP first layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_1: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP first layer. 
Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_1: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + init_method_std_mlp_2: float = Field( + default=None, + desc="Scale for the MLP second layer weight initialization. Default: init_method_std", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + init_method_max_mlp_2: float | None = Field( + default=None, + desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", + hint=FieldHint.optional, + ) + init_method_min_mlp_2: float | None = Field( + default=None, + desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", + hint=FieldHint.optional, + ) + + @property + def add_mlp_bias(self) -> bool: + from fast_llm.layers.block.config import AddLinearBiasChoices + + # TODO: Make this work without inheritance. + if isinstance(self.add_linear_biases, bool): + return self.add_linear_biases + if self.add_linear_biases == AddLinearBiasChoices.everywhere: + return True + return False + + def _validate(self) -> None: + with self._set_implicit_default(): + if self.activation_type is None: + self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # TODO: Make this work without inheritance. 
+ if self.ffn_hidden_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + # TODO: Review initialization + if self.init_method_std_mlp_1 is None: + self.init_method_std_mlp_1 = self.init_method_std + if self.init_method_std_mlp_2 is None: + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 + if self.init_method_max_mlp_1 is None: + self.init_method_max_mlp_1 = self.init_method_max + if self.init_method_min_mlp_1 is None: + self.init_method_min_mlp_1 = self.init_method_min + if self.init_method_max_mlp_2 is None: + self.init_method_max_mlp_2 = self.init_method_max + if self.init_method_min_mlp_2 is None: + self.init_method_min_mlp_2 = self.init_method_min + if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: + Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) + if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: + Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + + self.num_unshared_experts = self.num_experts - self.num_shared_experts + + super()._validate() + + Assert.leq(self.num_shared_experts, self.num_experts) + Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) + + if isinstance(self.mlp_lr_scale, tuple): + Assert.eq(len(self.mlp_lr_scale), self.num_experts) + for scale in self.mlp_lr_scale: + if scale is not None: + Assert.geq(scale, 0) + elif self.mlp_lr_scale is not None: + Assert.geq(self.mlp_lr_scale, 0) diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py similarity index 52% rename from fast_llm/layers/transformer/mixture_of_experts.py rename to fast_llm/layers/block/mlp/mixture_of_experts.py index 4fd2844d5..4f7cf2dc4 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -1,32 +1,25 @@ import logging -import typing import warnings import torch from 
fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType +from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerConfig, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) -from fast_llm.layers.transformer.mlp import MLPBase -from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta, init_normal_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales logger = logging.getLogger(__name__) -class MixtureOfExpertMLP(MLPBase): +class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -40,84 +33,98 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): + def __init__( + self, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + 
block_index: int, + name: str, + lr_scale: float | None, + ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, block_index) - self._config = config - self._tensor_space = tensor_space - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - self._num_experts = config.num_experts - self._experts_per_token = config.num_experts_per_token - self._num_shared_experts = config.num_shared_experts - self._num_unshared_experts = config.num_unshared_experts - - self._routing_type = config.expert_routing_type - self._load_balancing_factor = config.expert_auxiliary_loss_coefficient - self._z_loss_factor = config.expert_z_loss_coefficient - self._moe_jitter_eps = config.moe_jitter_eps - - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self.router = Linear( - tensor_space[TransformerDimNames.hidden], - tensor_space[TransformerDimNames.unshared_experts], + self._hidden_dim, + TensorDim("router_experts", self._config.num_unshared_experts), bias=False, weight_init_method=init_normal_( - std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max + std=self._block_config.init_method_std, + min_val=self._block_config.init_method_min, + max_val=self._block_config.init_method_max, ), - lr_scale=router_lr_scale, + lr_scale=combine_lr_scales(self._config.router_lr_scale, self._lr_scale), ) - dropless_moe = config.dropless_moe - if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: + dropless_moe = self._config.dropless_moe + if dropless_moe and self._sequence_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to 
looped implementation." ) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped - self._dynamic_shape = config.dropless_dynamic_shape + + if self._debug.enabled: + self._top_expert_dim = TensorDim("top_experts", self._config.num_experts_per_token) + + def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: + intermediate_1_dim, intermediate_2_dim = super()._get_intermediate_dims() + experts_dim = TensorDim("experts", self._config.num_experts) + return ( + CompositeTensorDim("moe_intermediate_1", (experts_dim, intermediate_1_dim)), + CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), + ) def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - if self._debug_mode: - self._debug_log(logits, "Router logits", TransformerDimNames.experts, kwargs) + if self._debug.enabled: + self._debug( + logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + ) # Apply z_loss if applicable - if self._z_loss_factor > 0.0: + if self._config.expert_z_loss_coefficient > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.expert_z_loss_coefficient, self.training, grad_scale=kwargs.get("grad_output"), losses=losses, - loss_name=TransformerLossNames.router_z_loss, + loss_name=MLPLossNames.router_z_loss, ) # Apply input_jitter if applicable: - if self.training and self._moe_jitter_eps > 0.0: - with set_generator(self._tensor_space.distributed.pp_generator): + if self.training and self._config.moe_jitter_eps > 0.0: + with set_generator(self._distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing - if self._routing_type == RoutingType.topk: - scores, top_experts = self._topk_routing(logits, kwargs.get(TransformerKwargs.grad_output), losses) - if self._num_shared_experts > 0: + if 
self._config.expert_routing_type == RoutingType.topk: + scores, top_experts = self._topk_routing(logits, kwargs.get(BlockKwargs.grad_output), losses) + if self._config.num_shared_experts > 0: scores, top_experts = self._add_shared_experts(top_experts, scores) - elif self._routing_type == RoutingType.sinkhorn: + elif self._config.expert_routing_type == RoutingType.sinkhorn: scores, top_experts = self._sinkhorn_routing(logits) else: - raise NotImplementedError(self._routing_type) + raise NotImplementedError(self._config.expert_routing_type) - if self._debug_mode: + if self._debug.enabled: # To log all ranks set `global_=False` - self._debug_log(scores, "Router scores", TransformerDimNames.top_experts, kwargs) - self._debug_log(top_experts, "Router top experts", TransformerDimNames.top_experts, kwargs) + self._debug( + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + ) + self._debug( + top_experts, + "Router top experts", + kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), + kwargs, + ) return self._mlp_forward(hidden_states, scores, top_experts).view_as(input_), None # noqa @@ -125,7 +132,9 @@ def _forward_dropless( self, hidden_states: torch.Tensor, scores: torch.Tensor, top_experts: torch.Tensor ) -> torch.Tensor: # Compute token counts and the sparse mapping (dense_row, top_index) -> sparse_row. 
- sparse_map = get_sparse_map(top_experts, self._num_experts, dynamic_shape=self._dynamic_shape) + sparse_map = get_sparse_map( + top_experts, self._config.num_experts, dynamic_shape=self._config.dropless_dynamic_shape + ) # Sparse MLP return mlp_autograd( @@ -135,12 +144,12 @@ def _forward_dropless( None, self.layer_2.weight, None, - gated=self._gated, - activation_type=self._activation_type, - group=self._intermediate_dim.parallel_group, + gated=self._config.gated, + activation_type=self._config.activation_type, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, - recompute_level=self._recompute_level, + recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=True, sparse_map=sparse_map, ) @@ -154,18 +163,20 @@ def _forward_looped( top_experts, self.layer_1.weight, self.layer_2.weight, - self._num_experts, - self._gated, - self._activation_type, - self._intermediate_dim.parallel_group, + self._config.num_experts, + self._config.gated, + self._config.activation_type, + self._parallel_dim.group, self._sequence_parallel, self.training, - self._recompute_level, + self._config.mlp_recompute_level, ) @torch.compile def _apply_input_jitter(self, logits: torch.Tensor) -> torch.Tensor: - return logits * torch.empty_like(logits).uniform_(1.0 - self._moe_jitter_eps, 1.0 + self._moe_jitter_eps) + return logits * torch.empty_like(logits).uniform_( + 1.0 - self._config.moe_jitter_eps, 1.0 + self._config.moe_jitter_eps + ) def _topk_routing( self, @@ -173,11 +184,11 @@ def _topk_routing( grad_scale: float | None = None, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - top_logits, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + top_logits, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) if losses is not None or (self.training and grad_scale is not None): probs = 
torch.softmax(logits, dim=-1, dtype=torch.float32) - mask = torch.nn.functional.one_hot(top_experts, num_classes=self._num_unshared_experts).sum(dim=1) + mask = torch.nn.functional.one_hot(top_experts, num_classes=self._config.num_unshared_experts).sum(dim=1) # Auxiliary loss, corresponding to the sum of probabilities for the top experts. # In the optimal case (uniform distribution), loss = experts_per_token / num_experts. # In the worst case (whole distribution in the top experts), loss = 1. @@ -185,10 +196,12 @@ def _topk_routing( probs.flatten(0, -2).mean(dim=0) * mask.flatten(0, -2).mean(dim=0, dtype=torch.float32) ) if losses is not None: - losses[TransformerLossNames.load_balancing_loss].append(aux_loss.detach()) + losses[MLPLossNames.load_balancing_loss].append(aux_loss.detach()) if self.training and grad_scale is not None: scores = AuxiliaryLoss.apply( - scores, aux_loss, self._num_unshared_experts * self._load_balancing_factor * grad_scale + scores, + aux_loss, + self._config.num_unshared_experts * self._config.expert_auxiliary_loss_coefficient * grad_scale, ) return scores, top_experts @@ -197,69 +210,33 @@ def _add_shared_experts( ) -> tuple[torch.Tensor, torch.Tensor]: # Add the shared experts (last ones) to the top experts. shared_experts = torch.arange( - self._num_unshared_experts, self._num_experts, device=top_experts.device, dtype=top_experts.dtype + self._config.num_unshared_experts, + self._config.num_experts, + device=top_experts.device, + dtype=top_experts.dtype, )[None].repeat(top_experts.size(0), 1) top_experts = torch.cat((shared_experts, top_experts), dim=1) # Add scores of 1 to scores for shared experts. 
- scores = torch.cat((scores.new_ones(scores.size(0), self._num_shared_experts), scores), dim=1) + scores = torch.cat((scores.new_ones(scores.size(0), self._config.num_shared_experts), scores), dim=1) return scores, top_experts def _sinkhorn_routing(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if self.training: - _, top_experts = torch.topk(sinkhorn(logits), k=self._experts_per_token, dim=-1) + _, top_experts = torch.topk(sinkhorn(logits), k=self._config.num_experts_per_token, dim=-1) logits = self._sinkhorn_activation(logits) scores = torch.gather(logits, -1, top_experts) else: logits = self._sinkhorn_activation(logits) - scores, top_experts = torch.topk(logits, k=self._experts_per_token, dim=-1) + scores, top_experts = torch.topk(logits, k=self._config.num_experts_per_token, dim=-1) return scores, top_experts def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: return ( torch.sigmoid(logits) - if self._experts_per_token == 1 + if self._config.num_experts_per_token == 1 else torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) ) - def _debug_log( - self, - tensor: torch.Tensor | None, - name: str, - dim_name: str, - kwargs: dict[str, typing.Any], - global_: bool = True, - ) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self._name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - meta = self._get_meta(tensor, name, dim_name, kwargs) - log_distributed_tensor( - "", - tensor.view_as(meta), - level=self._config.debug_transformer, - meta=meta, - distributed=self._tensor_space.distributed, - global_=global_, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_name, kwargs), - distributed=self._tensor_space.distributed, - grad_fn=lambda tensor_: tensor_.view_as(meta), - global_=global_, - ) - 
- def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: - return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), - tensor_name=f"{self._name} {name}", - dtype=tensor.dtype, - ) - def sinkhorn(cost: torch.Tensor, tolerance: float = 1e-5, eps=1e-9) -> torch.Tensor: """Sinkhorn based MoE routing function""" diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py new file mode 100644 index 000000000..c3a714a42 --- /dev/null +++ b/fast_llm/layers/block/mlp/mlp.py @@ -0,0 +1,120 @@ +import typing + +import torch + +from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd +from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.block.mlp.config import MLPConfig +from fast_llm.layers.block.peft import TransformerSubLayerName +from fast_llm.layers.common.linear import LinearBase +from fast_llm.utils import Assert, combine_lr_scales + + +class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): + def __init__( + self, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + ): + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() + + init_method_1 = init_normal_( + 
std=self._config.init_method_std_mlp_1, + min_val=self._config.init_method_min_mlp_1, + max_val=self._config.init_method_max_mlp_1, + ) + init_method_2 = init_normal_( + std=self._config.init_method_std_mlp_2, + min_val=self._config.init_method_min_mlp_2, + max_val=self._config.init_method_max_mlp_2, + ) + + self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + + lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale) + + # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) + self.layer_1 = LinearBase( + hidden_dim, + intermediate_1_dim, + bias=self._config.add_mlp_bias, + weight_init_method=init_method_1, + bias_init_method=init_zeros_, + lr_scale=lr_scale, + ) + self.layer_2 = LinearBase( + intermediate_2_dim, + hidden_dim, + bias=self._config.add_mlp_bias, + weight_init_method=init_method_2, + bias_init_method=init_zeros_, + auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1, + transposed_weight=True, + lr_scale=lr_scale, + ) + + # PEFT. 
+ self.layer_1 = self._block_config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) + self.layer_2 = self._block_config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + + def _get_intermediate_dims(self): + intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) + if self._config.gated: + TensorDim("gate_and_up", 2) + intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim)) + else: + intermediate_1_dim = intermediate_2_dim + return intermediate_1_dim, intermediate_2_dim + + +class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): + def __init__( + self, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + ): + Assert.eq(config.num_experts, 1) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return ( + mlp_autograd( + input_, + None, + self.layer_1.weight, + self.layer_1.bias, + self.layer_2.weight, + None if self._parallel_dim.group else self.layer_2.bias, + gated=self._config.gated, + activation_type=self._config.activation_type, + group=self._parallel_dim.group, + sequence_parallel=self._sequence_parallel, + training=self.training, + recompute_level=self._config.mlp_recompute_level, + transposed_layer_2_weight=self.layer_2.transposed_weight, + ), + self.layer_2.bias if self._parallel_dim.group else None, + ) diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py new file mode 100644 index 000000000..b51d352bc --- /dev/null +++ b/fast_llm/layers/block/peft.py @@ -0,0 +1,87 @@ +""" +TODO: Generalize beyond transformers. 
@config_class(dynamic_type={TransformerPeftConfig: "lora"})
class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig):
    # LoRA configuration for transformer layers, registered under the dynamic type name "lora".
    layers: list[TransformerSubLayerName] = Field(
        default=(TransformerSubLayerName.query, TransformerSubLayerName.value_),
        desc="The layers on which to apply LoRA.",
        hint=FieldHint.feature,
    )

    def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike":
        """
        Apply LoRA to `linear` when `layer_type` is selected (or unspecified).

        For the `key` / `value_` layer types only the first / second half of the output
        channels is targeted; the channel bounds are passed on to the parent implementation.
        """
        enabled = layer_type is None or self.layers is None or layer_type in self.layers
        out_channel_begin = out_channel_end = None
        if enabled and layer_type == TransformerSubLayerName.key:
            # First half of the output channels.
            out_channel_end = div(linear._out_dim.global_size, 2)
        elif enabled and layer_type == TransformerSubLayerName.value_:
            # Second half of the output channels.
            out_channel_begin = div(linear._out_dim.global_size, 2)
        return super().apply_linear(linear, enabled, out_channel_begin, out_channel_end)

    def _validate(self) -> None:
        """
        Reject layer selections that are unsupported or mutually exclusive.
        """
        super()._validate()
        selected = self.layers
        mlp_layers = (TransformerSubLayerName.mlp_1, TransformerSubLayerName.mlp_2)
        if any(name in selected for name in mlp_layers):
            # TODO: Add MLP support.
            raise NotImplementedError("LoRA not supported for MLP.")
        if TransformerSubLayerName.dense in selected:
            # TODO: Support InputParallelLinear (different output format).
            raise NotImplementedError("LoRA not supported for attention dense layer.")
        exclusive_layers = (
            TransformerSubLayerName.key_value,
            TransformerSubLayerName.key,
            TransformerSubLayerName.value_,
        )
        if len([name for name in exclusive_layers if name in selected]) > 1:
            raise ValueError(
                f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive."
            )
a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -7,23 +7,8 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch - - from fast_llm.engine.config_utils.tensor_space import TensorDim - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm - - -@config_class() -class LLMBlockConfig(BaseModelConfig): - _abstract = False - - per_layer_lr_scale: list[float] | None = Field( - default=None, - desc="Custom learning rate scale for each layer.", - doc="May be used to freeze some layers by setting their scale to zero.", - hint=FieldHint.feature, - ) + from fast_llm.engine.config_utils.tensor_dim import TensorDim + from fast_llm.layers.common.normalization.normalization import Normalization class NormalizationImplementation(str, enum.Enum): @@ -42,10 +27,18 @@ class NormalizationImplementation(str, enum.Enum): class NormalizationConfig(BaseModelConfig): pass + @property @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": + def module_class(self) -> type["Normalization"]: pass + def get_layer( + self, + hidden_dim: "TensorDim", + lr_scale: float | None = None, + ) -> "Normalization": + return self.module_class(self, hidden_dim, lr_scale) + @classmethod def _from_dict( cls, @@ -63,8 +56,11 @@ def _from_dict( class NoNormalizationConfig(NormalizationConfig): _abstract = False - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None) -> "torch.nn.Module": - return torch.nn.Identity() + @property + def module_class(self) -> type["Normalization"]: + from fast_llm.layers.common.normalization.normalization import NoNormalization + + return NoNormalization @config_class() @@ -98,21 +94,6 @@ class LayerNormalizationBaseConfig(NormalizationConfig): valid=check_field(Assert.geq, 0), ) - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> 
"LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_centered_ - - kwargs = { - "hidden_dim": hidden_dim, - "eps": self.epsilon, - "implementation": self.implementation, - "zero_centered": self.zero_centered, - "lr_scale": lr_scale, - } - if self.initialization_range: - mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) - return self.module_class(**kwargs) - @property @abc.abstractmethod def module_class(self): @@ -139,9 +120,9 @@ class LayerNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization import LayerNorm + from fast_llm.layers.common.normalization.normalization import LayerNormalization - return LayerNorm + return LayerNormalization @config_class(dynamic_type={NormalizationConfig: "rms_norm"}) @@ -150,56 +131,6 @@ class RMSNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization import RMSNorm - - return RMSNorm - - -@config_class() -class PeftConfig(BaseModelConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - pass - - -@config_class() -class NoPeftConfig(PeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - return linear - - -@config_class() -class LoRAConfig(PeftConfig): - _abstract = False - - rank: int = Field( - default=8, - desc="The LoRA rank, i.e. 
the size of the intermediate dimension.", - hint=FieldHint.stability, - ) - alpha: float = Field( - default=8.0, - desc="The LoRA scaling parameter.", - hint=FieldHint.stability, - ) - dropout: float = Field( - default=0.0, - desc="Dropout rate for LoRA.", - hint=FieldHint.stability, - ) + from fast_llm.layers.common.normalization.normalization import RMSNormalization - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - from fast_llm.layers.common.peft import lora_linear - - # TODO: Init method? - return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) + return RMSNormalization diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization/normalization.py similarity index 69% rename from fast_llm/layers/common/normalization.py rename to fast_llm/layers/common/normalization/normalization.py index bccc1d627..a7eba72c8 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -1,11 +1,21 @@ +import abc + import torch +from fast_llm.config import Configurable +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.config import NormalizationImplementation -from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ +from fast_llm.layers.common.normalization.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + NormalizationConfig, + NormalizationImplementation, + RMSNormalizationConfig, +) +from fast_llm.tensor import 
ParameterMeta, accumulate_gradient from fast_llm.utils import Assert try: @@ -138,7 +148,24 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, return grad_input, None, None, None -class LayerNorm(torch.nn.Module): +class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config) + self._hidden_dim = hidden_dim + self._lr_scale = lr_scale + assert not self._hidden_dim.is_parallel + + @abc.abstractmethod + def forward(self, input_: torch.Tensor) -> torch.Tensor: + pass + + +class NoNormalization[ConfigType: NoNormalizationConfig](Normalization[ConfigType]): + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + +class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[ConfigType]): """ A layer normalization layer, supporting multiple implementations. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -146,25 +173,17 @@ class LayerNorm(torch.nn.Module): TODO: Review this? 
""" - def __init__( - self, - hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - bias_init_method=init_zeros_, - zero_centered: bool = False, - lr_scale: float | None = None, - ): - super().__init__() - assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if _fast_normalization_available and hidden_dim.size in _PERSIST_LN_SIZES and not self._zero_centered: + if ( + _fast_normalization_available + and hidden_dim.size in _PERSIST_LN_SIZES + and not self._config.zero_centered + ): implementation = NormalizationImplementation.fast - elif TritonConfig.TRITON_ENABLED or self._zero_centered: + elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton elif _fused_normalization_available: @@ -173,7 +192,7 @@ def __init__( else: log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -186,44 +205,49 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ + if self.config.initialization_range: + mean = 0 if self.zero_centered else 1 + weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) + else: + 
weight_init_method = init_zeros_ if self._config.zero_centered else init_ones_ self.weight = ParameterMeta.from_dims( (hidden_dim,), init_method=weight_init_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), - init_method=bias_init_method, + init_method=init_zeros_, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: return triton_normalization_autograd( - input_, self.weight, self.bias, self._eps, self.training, self._zero_centered + input_, self.weight, self.bias, self._config.epsilon, self.training, self._config.zero_centered ) def _forward_fast(self, input_: torch.Tensor) -> torch.Tensor: - return FastLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FastLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FusedLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.layer_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self.bias, self._eps) + return torch.layer_norm( + input_.to(self.weight.dtype), self._normalized_shape, self.weight, self.bias, self._config.epsilon + 
) -class RMSNorm(torch.nn.Module): +class RMSNormalization[ConfigType: RMSNormalizationConfig](Normalization[ConfigType], torch.nn.Module): """ A RMS normalization layer. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -231,22 +255,12 @@ class RMSNorm(torch.nn.Module): TODO: Review this? """ - def __init__( - self, - hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - zero_centered: bool = False, - lr_scale: float | None = None, - ): - super().__init__() + def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if TritonConfig.TRITON_ENABLED or self._zero_centered: + if TritonConfig.TRITON_ENABLED or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") @@ -254,7 +268,7 @@ def __init__( else: log_main_rank("Fused RMS norm unavailable, using backup implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -265,8 +279,11 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ + if self.config.initialization_range: + mean = 0 if self.zero_centered else 1 + weight_init_method = init_uniform_centered_(self.config.initialization_range, mean=mean) + else: + weight_init_method = init_zeros_ if 
self._config.zero_centered else init_ones_ self.weight = ParameterMeta.from_dims( (hidden_dim,), @@ -275,16 +292,18 @@ def __init__( auto_grad_accumulation=True, lr_scale=lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: - return triton_normalization_autograd(input_, self.weight, None, self._eps, self.training, self._zero_centered) + return triton_normalization_autograd( + input_, self.weight, None, self._config.epsilon, self.training, self._config.zero_centered + ) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedRMSNorm.apply(input_, self.normalized_shape, self.weight, self._eps) + return FusedRMSNorm.apply(input_, self._normalized_shape, self.weight, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps) + return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) diff --git a/fast_llm/layers/common/peft/__init__.py b/fast_llm/layers/common/peft/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py new file mode 100644 index 000000000..64a2ca57a --- /dev/null +++ b/fast_llm/layers/common/peft/config.py @@ -0,0 +1,91 @@ +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.layers.common.normalization.normalization import 
Normalization + from fast_llm.tensor import ParameterMeta + + +@config_class() +class PeftConfig(BaseModelConfig): + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + return self.apply_other(module) + + def apply_normalization(self, module: "Normalization") -> "Normalization": + return self.apply_other(module) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + for parameter in module.parameters(): + self.apply_weight(parameter) + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter + + +@config_class() +class NoPeftConfig(PeftConfig): + _abstract = False + + +@config_class() +class LoRAConfig(PeftConfig): + _abstract = False + + rank: int = Field( + default=8, + desc="The LoRA rank, i.e. the size of the intermediate dimension.", + hint=FieldHint.stability, + ) + alpha: float = Field( + default=8.0, + desc="The LoRA scaling parameter.", + hint=FieldHint.stability, + ) + dropout: float = Field( + default=0.0, + desc="Dropout rate for LoRA.", + hint=FieldHint.stability, + ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + if not enabled: + return self.apply_other(module) + + from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear + from fast_llm.layers.common.peft.lora import lora_linear + + if isinstance(module, InputParallelLinear): + # TODO: Support InputParallelLinear (different output format). + raise NotImplementedError("LoRA not supported for InputParallelLinear.") + elif isinstance(module, OutputParallelLinear): + assert out_channel_begin is None and out_channel_end is None + + # TODO: Init method? 
+ return lora_linear(module, self.rank, self.alpha, self.dropout, out_channel_begin, out_channel_end) + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + if self.freeze_others: + parameter.requires_grad = False + return parameter diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft/lora.py similarity index 68% rename from fast_llm/layers/common/peft.py rename to fast_llm/layers/common/peft/lora.py index 08f3e535b..9e0ca0dd0 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft/lora.py @@ -2,27 +2,25 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear import Linear, LinearBase def lora_linear( - layer: LinearBase, - init_method_0, - init_method_1, + module: LinearBase, rank: int, alpha: float, dropout: float = 0.0, out_channel_begin: int | None = None, out_channel_end: int | None = None, ): - layer.weight.requires_grad = False - in_dim = layer._in_dim + module.weight.requires_grad = False + in_dim = module._in_dim assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: in_dim = TensorDim(in_dim.name, in_dim.global_size) - out_dim = layer._out_dim + out_dim = module._out_dim assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." 
if out_dim.parallel_dim is not None: out_dim = TensorDim(out_dim.name, out_dim.global_size) @@ -36,27 +34,27 @@ def lora_linear( middle_dim = TensorDim("lora_middle", rank) - layer.lora_0 = Linear( + module.lora_0 = Linear( in_dim, middle_dim, bias=False, - weight_init_method=init_method_0, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) - layer.lora_1 = Linear( + module.lora_1 = Linear( middle_dim, out_dim, bias=False, - weight_init_method=init_method_1, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) # TODO: Implement proper backward pass. - layer.lora_0.weight.auto_grad_accumulation = True - layer.lora_1.weight.auto_grad_accumulation = True + module.lora_0.weight.auto_grad_accumulation = True + module.lora_1.weight.auto_grad_accumulation = True - old_forward = layer._forward + old_forward = module._forward def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # TODO: torch compile? 
@@ -66,8 +64,8 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor if isinstance(output, tuple): layer_out, tp_bias = output[0] assert tp_bias is None - lora_out = (alpha / rank) * layer.lora_1( - layer.lora_0(torch.dropout(input_, dropout, layer.training) if dropout > 0.0 else input_) + lora_out = (alpha / rank) * module.lora_1( + module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) ) if out_channel_begin is None: output = output + lora_out @@ -83,8 +81,8 @@ def backward( output.backward(grad_output) return input_.grad - layer._forward = wrap_forward_backward(forward_only, backward) - layer.forward_only = forward_only - layer.backward = backward + module._forward = wrap_forward_backward(forward_only, backward) + module.forward_only = forward_only + module.backward = backward - return layer + return module diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e2e97f1a..df6969cfc 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,23 +2,13 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig +from fast_llm.layers.block.config import BlockKwargs from fast_llm.utils import Assert -class LanguageModelDimNames: - # Embedding dimensions - position_embed = "position_embed" - vocab = "vocab" - vocab_tp = "vocab_tp" 
- # Misc - scalar = "scalar" - - class LanguageModelLossNames: language_model_loss = "language_model_loss" z_loss = "z_loss" @@ -33,7 +23,7 @@ def multi_token_prediction_loss(index: int) -> str: return f"language_model_loss_{index}" -class LanguageModelKwargs: +class LanguageModelKwargs(BlockKwargs): position_ids = "position_ids" # TODO: These are generic labels = "labels" @@ -46,6 +36,7 @@ class LanguageModelKwargs: @config_class() class LanguageModelBaseConfig(BaseModelConfig): + # TODO: block transformer: TransformerConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, @@ -235,16 +226,6 @@ def _validate(self) -> None: len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 ) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Embedding dimensions - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings)) - # TODO: Need both? - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @property def num_absolute_position_embeddings(self) -> int: # TODO: Rename from max embeddings. 
diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index f6f43d199..fd4e8412e 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -2,61 +2,67 @@ import torch -from fast_llm.config import Configurable from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), together with optional absolute position embeddings and dropout. """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig - # Ensure the layer is on its own stage. 
layer_count: float = 1000.0 def __init__( self, - config: LanguageModelBaseConfig, - tensor_space: TensorSpace, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? + block_index: int, + name: str, ): - super().__init__(config) - self._distributed_config = tensor_space.distributed_config - self._tensor_space = tensor_space + super().__init__( + config, + config.transformer, + distributed_config, + hidden_dim, + block_index, + name, + # TODO: Add lr scale? + None, + ) self._residual_dtype = ( self._distributed_config.optimization_dtype if config.transformer.full_precision_residual else self._distributed_config.training_dtype ).torch - self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._dropout_p = config.transformer.hidden_dropout - self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - - hidden_dim = tensor_space[TransformerDimNames.hidden] - vocab_dim = tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] + self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_embeddings else None + ) if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size self.word_embeddings_weight = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), + (vocab_dim, self._hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, 
min_val=config.init_method_min_embed, @@ -64,9 +70,9 @@ def __init__( ), lr_scale=config.embeddings_lr_scale, ) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), + (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, @@ -85,21 +91,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) - group = self._tensor_space.distributed.tensor_group + Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) + group = self._parallel_dim.group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -108,16 +114,14 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: 
embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator + self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._dropout_p, self.training) + embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( @@ -129,7 +133,7 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims], + kwargs[LanguageModelKwargs.hidden_dims], tensor_name="Embedding output", dtype=self._residual_dtype, ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 210cad644..d0c0eb8f9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,31 +1,25 @@ import logging -import typing import torch from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from 
fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.language_model.config import ( - LanguageModelBaseConfig, - LanguageModelDimNames, - LanguageModelKwargs, - LanguageModelLossNames, -) +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.logging import log_distributed_tensor -from fast_llm.tensor import ParameterMeta, TensorMeta, init_normal_ +from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique logger = logging.getLogger(__name__) @@ -33,63 +27,56 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig - def __init__( self, - config: LanguageModelBaseConfig, - tensor_space: TensorSpace, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? 
+ block_index: int, + name: str, prediction_distance: int, ): - super().__init__(config) - self._debug_transformer = config.transformer.debug_transformer - self._tie_word_embeddings = config.tie_word_embeddings - self._tensor_space = tensor_space - - self._group_size = tensor_space.distributed_config.tensor_parallel - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings + super().__init__( + config, + config.transformer, + distributed_config, + hidden_dim, + block_index, + name, + # TODO: Add lr scale? + None, ) - self._cross_entropy_splits = config.cross_entropy_splits - if self._cross_entropy_splits is not None and self._sequence_parallel: - assert not self._parallel_embeddings + self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] + self._sequence_parallel_logits = self._sequence_parallel and not self._config.parallel_embeddings + if self._config.cross_entropy_splits is not None and self._sequence_parallel: + assert not self._parallel_logits self._loss_coefficient = ( - config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 + self._config.prediction_loss_coefficient[prediction_distance] + if self._config.prediction_loss_coefficient + else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = config.logits_scale_factor - self._language_model_loss_factor = config.language_model_loss_factor - 
self._distillation_loss_factor = config.distillation_loss_factor - self._z_loss_factor = config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self._is_last_head = self._prediction_distance == config.prediction_heads - 1 - - self._init_output_weights(hidden_dim, config) + self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - self._use_dpo_loss = config.enable_dpo - if self._use_dpo_loss: - self.dpo_beta = config.dpo_beta - else: - self._cross_entropy_impl = config.cross_entropy_impl - self._distillation_loss_implementation = config.distillation_loss_implementation + if not self._config.enable_dpo: + self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_embeddings: + if self._parallel_logits: self._cross_entropy_impl = CrossEntropyImpl.fused elif TritonConfig.TRITON_ENABLED: self._cross_entropy_impl = CrossEntropyImpl.triton @@ -98,39 +85,40 @@ def __init__( self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) + self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + + self._vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_logits else None + ) + # Only the first head defines the output weights + if self._prediction_distance == 0 and not self._config.tie_word_embeddings: + # untie embedding weights + self.output_weights = ParameterMeta.from_dims( + (self._vocab_dim, hidden_dim), + init_method=init_normal_( + std=self._config.init_method_std_embed, + min_val=self._config.init_method_min_embed, + max_val=self._config.init_method_max_embed, + ), + lr_scale=self._config.output_lr_scale, + ) + # PEFT. 
self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) if hasattr(self, "output_weights"): self.output_weights = self._config.transformer.peft.apply_weight(self.output_weights) - def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: - # Only the first head defines the output weights - if self._tie_word_embeddings or self._prediction_distance > 0: - return - # untie embedding weights - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] - self.output_weights = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - lr_scale=config.output_lr_scale, - ) - def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): if self._is_last_head: - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, + return TensorMeta.from_dims( + (scalar_dim,), tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + reductions=( + (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), + ), ) else: return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") @@ -168,21 +156,23 @@ def _forward_backward( if "output_hidden_states" in kwargs and kwargs["output_hidden_states"]: # The last hidden layer output is returned normalized in the HF Transformers-style output, at least for LLama style models. # So, if needed, we gather the data after normalization and set it as the output of the previous layer. 
- dims = list(kwargs[TransformerKwargs.hidden_dims]) - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) + dims = list(kwargs[LanguageModelKwargs.hidden_dims]) + sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor + BlockDimNames.sequence_q_tp, + dims[sequence_index].global_size, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) ) meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) - hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed) + hidden_state, _ = meta.local_to_global(ln_output.detach()) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state - grad_output = kwargs[TransformerKwargs.grad_output] / ( - self._group_size if self._sequence_parallel_logits else 1 + grad_output = kwargs[LanguageModelKwargs.grad_output] / ( + self._parallel_dim.size if self._sequence_parallel_logits else 1 ) output_weights = self._get_output_weights(kwargs) @@ -200,7 +190,7 @@ def _get_targets( self, kwargs: dict ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None: # Loss mask for distillation. (Labels are already masked.) 
- if self._use_dpo_loss: + if self._config.enable_dpo: dpo_target = kwargs.get(LanguageModelKwargs.labels) lm_target = None distillation_target = None @@ -216,23 +206,23 @@ def _get_targets( if loss_mask is not None: loss_mask = loss_mask.flatten() - if self._config.distillation_model is None or self._language_model_loss_factor > 0.0: + if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: lm_target = kwargs.get(LanguageModelKwargs.labels) if lm_target is not None: # MTP: Shift the labels lm_target_sequence_length = ( - lm_target.size(1 - kwargs[TransformerKwargs.sequence_first]) + lm_target.size(1 - kwargs[LanguageModelKwargs.sequence_first]) + 1 - self._config.prediction_heads ) - if TransformerKwargs.sequence_q_dim in kwargs: - Assert.eq(lm_target_sequence_length, kwargs[TransformerKwargs.sequence_q_dim].size) + if LanguageModelKwargs.sequence_q_dim in kwargs: + Assert.eq(lm_target_sequence_length, kwargs[LanguageModelKwargs.sequence_q_dim].size) lm_target_slice = slice( self._prediction_distance, self._prediction_distance + lm_target_sequence_length ) lm_target = ( lm_target[lm_target_slice] - if kwargs[TransformerKwargs.sequence_first] + if kwargs[LanguageModelKwargs.sequence_first] else lm_target[:, lm_target_slice] ).flatten() else: @@ -240,17 +230,14 @@ def _get_targets( targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: - targets = [ - None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) - for target in targets - ] + targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] if not any(target is not None for target in targets): # Simplify so we don't have to check every time. 
targets = None return targets def _get_output_weights(self, kwargs: dict) -> torch.Tensor: - if self._tie_word_embeddings: + if self._config.tie_word_embeddings: return kwargs[WORD_EMBEDDINGS_WEIGHT] if self._prediction_distance > 0: return kwargs[OUTPUT_WEIGHTS] @@ -265,7 +252,7 @@ def _logits_cross_entropy_forward_backward_split( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._cross_entropy_splits is None or targets is None: + if self._config.cross_entropy_splits is None or targets is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( input_, targets, weight, grad_output, kwargs, losses ) @@ -275,18 +262,19 @@ def _logits_cross_entropy_forward_backward_split( return None, None else: loss = None - # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length - grad_output /= self._cross_entropy_splits + # TODO MTP: allow a cross_entropy_splits that is not a divisor of the sequence length + grad_output /= self._config.cross_entropy_splits logit_input = input_.flatten(0, -2) if self.training: logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None split_size = div( - get_unique(target.size(0) for target in targets if target is not None), self._cross_entropy_splits + get_unique(target.size(0) for target in targets if target is not None), + self._config.cross_entropy_splits, ) tensors_split = [ - [None] * self._cross_entropy_splits if tensor is None else tensor.split(split_size) + [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) for tensor in [logit_input, *targets, logit_input_grad] ] for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): @@ -302,12 +290,14 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._cross_entropy_splits or 1) * 
(self._group_size if self._sequence_parallel_logits else 1) + loss_count = (self._config.cross_entropy_splits or 1) * ( + self._parallel_dim.size if self._sequence_parallel_logits else 1 + ) if loss_count != 1: loss.div_(loss_count) if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._tensor_space.distributed.tensor_group) + all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -319,56 +309,37 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + group = self._parallel_dim.group if self._parallel_logits else None logits, context = output_parallel_linear_forward( input_=input_, weight=weight, bias=None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - sequence_parallel=self._sequence_parallel and self._parallel_embeddings, + group=group, + sequence_parallel=self._sequence_parallel and self._parallel_logits, ) - if self._z_loss_factor > 0.0: + if self._config.logit_z_loss > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.logit_z_loss, self.training, grad_output, losses, LanguageModelLossNames.z_loss, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, ) - if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ] - dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] - sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) - dims[sequence_index] = ( - TensorDim( - TransformerDimNames.sequence_q_tp, dims[sequence_index].global_size, DistributedDimNames.tensor - ) - if self._sequence_parallel_logits - else TensorDim(TransformerDimNames.sequence_q, 
dims[sequence_index].global_size) - ) - - dim_names = ( - [TransformerDimNames.sequence_q_tp, LanguageModelDimNames.vocab] - if self._sequence_parallel_logits - else [TransformerDimNames.sequence_q, LanguageModelDimNames.vocab_tp] - ) - - dim_names.insert(int(kwargs[TransformerKwargs.sequence_first]), TransformerDimNames.batch) - log_distributed_tensor( - "", - logits, - level=self._debug_transformer, - meta=TensorMeta.from_dims(tuple(dims), tensor_name="transformer logits", dtype=logits.dtype), - distributed=self._tensor_space.distributed, - scale=self._logits_scale_factor, + if self._debug.enabled and self._config.cross_entropy_splits is None: + sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q + batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] + dims = ( + (sequence_dim, batch_dim, self._vocab_dim) + if kwargs[LanguageModelKwargs.sequence_first] + else (batch_dim, sequence_dim, self._vocab_dim) ) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._config.logits_scale_factor) if targets is None: - return logits * self._logits_scale_factor, None + return logits * self._config.logits_scale_factor, None dpo_target, lm_target, distillation_target, loss_mask = targets if dpo_target is not None: @@ -378,7 +349,7 @@ def _logits_cross_entropy_forward_backward( kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], - self.dpo_beta, + self._config.dpo_beta, grad_output * self._loss_coefficient, ) else: @@ -389,44 +360,46 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._language_model_loss_factor, + group=group, + grad_output=grad_output * self._loss_coefficient * 
self._config.language_model_loss_factor, implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) - lm_loss = lm_loss * self._language_model_loss_factor + lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._distillation_loss_factor > 0.0: - if self._distillation_loss_implementation == DistillationLossImpl.reverse_kl: + if distillation_target is not None and self._config.distillation_loss_factor > 0.0: + if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * self._loss_coefficient * self._distillation_loss_factor, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - logits_scale_factor=self._logits_scale_factor, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + group=group, + logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits ), ) - elif self._distillation_loss_implementation == DistillationLossImpl.cross_entropy: + elif self._config.distillation_loss_implementation == DistillationLossImpl.cross_entropy: distillation_loss, distillation_grad = cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._distillation_loss_factor, + group=group, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, 
implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.logits, ) else: - raise ValueError(f"Invalid distillation loss implementation: {self._distillation_loss_implementation}") - distillation_loss = distillation_loss * self._distillation_loss_factor + raise ValueError( + f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" + ) + distillation_loss = distillation_loss * self._config.distillation_loss_factor else: distillation_loss, distillation_grad = None, None diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index c8d53a789..5ba31c0d0 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -4,9 +4,9 @@ import torch from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs -from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -14,57 +14,48 @@ class PositionEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: LanguageModelBaseConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config assert config.use_absolute_position_embeddings - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - 
self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) - self._position_ids = torch.arange( - 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 - ) + self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(TransformerKwargs.sequence_lengths)) is not None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[LanguageModelKwargs.sequence_length], batch.device) + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: position_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(self._tensor_space.distributed.device, dtype=torch.int64) + ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[LanguageModelKwargs.sequence_first]: position_ids = position_ids.transpose(0, 1) kwargs[LanguageModelKwargs.position_ids] = position_ids else: kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ sequence_k - sequence_q : sequence_k - 
].unsqueeze(int(kwargs[TransformerKwargs.sequence_first])) + ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: # Position embeddings will be broadcast. - sequence_q_dim = kwargs[TransformerKwargs.sequence_q_dim] + sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( ( - (sequence_q_dim, self._scalar_dim) - if kwargs[TransformerKwargs.sequence_first] - else (self._scalar_dim, sequence_q_dim) + (sequence_q_dim, scalar_dim) + if kwargs[LanguageModelKwargs.sequence_first] + else (scalar_dim, sequence_q_dim) ), tensor_name=LanguageModelKwargs.position_ids, dtype=torch.int64, @@ -72,18 +63,16 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py new file mode 100644 index 000000000..22d01a5cb --- /dev/null +++ 
b/fast_llm/layers/ssm/block.py @@ -0,0 +1,38 @@ +import functools + +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import Block, BlockLayer +from fast_llm.layers.block.config import BlockConfig +from fast_llm.layers.ssm.config import SSMConfig + + +# TODO: Sort out configs. +class SSMBlock[ConfigType: BlockConfig](Block[ConfigType]): + """ + A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 + """ + + def __init__( + self, + config: ConfigType, + ssm_config: SSMConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + mixer_class: type[BlockLayer], + return_input: bool = False, + ): + self._ssm_config = ssm_config + self._mixer_class = mixer_class + super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale, return_input) + + @functools.cached_property + def _mixer_class(self) -> type[BlockLayer]: + return self._mixer_class + + @property + def _mixer_config(self) -> SSMConfig: + return self._ssm_config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index fb178e7d5..8917feaf6 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,14 +1,12 @@ import enum import typing -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.tensor import Initializer + from fast_llm.engine.config_utils.initialization import Initializer class SSMBlockType(enum.StrEnum): @@ 
-23,9 +21,9 @@ class SSMBlockType(enum.StrEnum): def get_mixer_class(self): if self == SSMBlockType.mamba: - from fast_llm.layers.ssm.mamba_layer import MambaLayer + from fast_llm.layers.ssm.mamba import Mamba - return MambaLayer + return Mamba elif self == SSMBlockType.mamba2: from fast_llm.layers.ssm.mamba2 import Mamba2 @@ -43,40 +41,34 @@ class DTInitType(enum.StrEnum): random = "random" def get_init_method(self, scale: float) -> "Initializer": - from fast_llm.tensor import init_fill_, init_uniform_centered_ + from fast_llm.engine.config_utils.initialization import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) @config_class() -class SSMConfig(LLMBlockConfig): +class SSMConfig(Config): _abstract = False - # Normalization - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) - # Model dimensions # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, - desc="Expansion factor for Mamba blocks.", + desc="Expansion factor.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, - desc="State size for Mamba blocks.", + desc="State size.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, - desc="Conv kernel dimension for Mamba blocks.", + desc="Conv kernel dimension.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) @@ -89,19 +81,19 @@ class SSMConfig(LLMBlockConfig): # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, - desc="Number of QK heads for Mamba2 blocks.", + desc="Number of QK heads.", hint=FieldHint.architecture, ) # heads [DiscreteMamba2]# TODO: Remove? 
(redundant) n_v_heads: int = Field( default=32, - desc="Number of V heads for Mamba2 blocks.", + desc="Number of V heads.", hint=FieldHint.architecture, ) # c_size [MambaLayer, Mamba2, DiscreteMamba2]? d_inner: None | int = Field( default=None, - desc="Inner dimension for Mamba2 blocks.", + desc="Inner dimension.", hint=FieldHint.core, ) # xb_size [Mamba2] @@ -187,7 +179,3 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) - - def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: - # Handled in the model. - pass diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 47a94214a..f9462a942 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,21 +4,17 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import ( - CompositeTensorDim, - ConcatenatedTensorDim, - DefaultDimNames, - TensorDim, - TensorSpace, -) -from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_, 
init_zeros_ -from fast_llm.utils import div, get_lr_scale +from fast_llm.layers.ssm.mamba import init_kaiming_ +from fast_llm.tensor import ParameterMeta +from fast_llm.utils import combine_lr_scales, div logger = logging.getLogger(__name__) @@ -39,31 +35,31 @@ _causal_conv1d_available = False -class DiscreteMamba2(Mixer): - """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" +class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): + """ + This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py + """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" def __init__( self, - config: SSMConfig, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, block_index: int, - tensor_space: TensorSpace, - transformer_config: TransformerConfig, + name: str, + lr_scale: float | None, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self._config: SSMConfig = config - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - - hidden_dim = tensor_space[TransformerDimNames.hidden] + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) head_groups_dim = TensorDim( "head_groups", self._config.n_qk_heads, - self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) 
heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) @@ -86,13 +82,15 @@ def __init__( # local_bc_size = local_head_groups * state self._local_bc_size = bc_dim.size + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) + # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, inner_projection_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -106,7 +104,7 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( convolution_dim, - tensor_space[DefaultDimNames.scalar], + scalar_dim, convolution_kernel_dim, ), init_method=init_uniform_centered_( @@ -135,22 +133,27 @@ def __init__( lr_scale=lr_scale, ) - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - sequence_length = kwargs[TransformerKwargs.sequence_q_dim].global_size + sequence_length = kwargs[BlockKwargs.sequence_q_dim].global_size # Pad input to nearest multiple of chunklen padded_length = (1 + (sequence_length - 1) // self._config.chunk_size) * self._config.chunk_size if padded_length != sequence_length: - assert not kwargs[TransformerKwargs.sequence_first] and input_.size(1) == sequence_length + assert not kwargs[BlockKwargs.sequence_first] and input_.size(1) == sequence_length input_ = torch.nn.functional.pad(input_, (0, 0, 0, padded_length - sequence_length)) - # inner_projection : (batch/local_or_padded_sequence, local_sequence/batch, hidden) - # -> (batch/padded_sequence, sequence/batch, local_inner_projection) + # -> 
(batch/padded_sequence, sequence/batch, local_inner_projection inner_projection = self.in_proj(input_) # Standardize to (batch, padded_sequence, local_inner_projection) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) xBC, z, A_log = torch.split( @@ -162,7 +165,6 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ ], dim=-1, ) - # Convolutional layer # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) xBC = self.convolutional_forward(xBC, padded_length) @@ -203,7 +205,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # y: (batch, padded_sequence, local_heads, head_size) -> (batch, sequence, local_heads * head_size) y = ((y + Du).flatten(2, 3) * torch.nn.functional.silu(z))[:, :sequence_length] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: # TODO: Is contiguous needed? 
y = y.transpose(0, 1).contiguous() # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py deleted file mode 100644 index 986606634..000000000 --- a/fast_llm/layers/ssm/llamba_block.py +++ /dev/null @@ -1,37 +0,0 @@ -import typing - -from fast_llm.layers.transformer.transformer import BaseBlock, Mixer - -if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace - from fast_llm.layers.ssm.config import SSMConfig - from fast_llm.layers.transformer.config import TransformerConfig - - -class SSMBlock(BaseBlock): - """ - A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 - """ - - _name = "Llamba block" - - def __init__( - self, - transformer_config: "TransformerConfig", - ssm_config: "SSMConfig", - tensor_space: "TensorSpace", - mixer_cls: type[Mixer], - block_index: int, - return_input: bool = False, - ): - self._ssm_config = ssm_config - self._mixer_cls = mixer_cls - super().__init__(transformer_config, tensor_space, block_index, return_input) - - def _create_mixer(self) -> Mixer: - return self._mixer_cls( - self._ssm_config, - tensor_space=self._tensor_space, - block_index=self._block_index, - transformer_config=self._config, - ) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba.py similarity index 78% rename from fast_llm/layers/ssm/mamba_layer.py rename to fast_llm/layers/ssm/mamba.py index 061921b3d..453c14af6 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba.py @@ -4,20 +4,16 @@ import torch -from fast_llm.engine.config_utils.tensor_space import ( - CompositeTensorDim, - ConcatenatedTensorDim, - DefaultDimNames, - TensorDim, - TensorSpace, -) +from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, 
TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_ -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.tensor import ParameterMeta +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -58,26 +54,25 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class MambaLayer(Mixer): +class Mamba[ConfigType: SSMConfig](BlockLayer[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( self, - config: SSMConfig, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, block_index: int, - tensor_space: TensorSpace, - transformer_config: TransformerConfig, + name: str, + lr_scale: float | None, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" - self._config = config + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" # TODO: It's not silu? 
Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # Tensor dims: - hidden_dim = tensor_space[TransformerDimNames.hidden] heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) state_dim = TensorDim("state", self._config.state_size) inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) @@ -86,6 +81,8 @@ def __init__( inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) + # TODO: Backward compatibility? self.in_proj = Linear( hidden_dim, @@ -94,17 +91,15 @@ def __init__( weight_init_method=init_kaiming_(hidden_dim.size), lr_scale=lr_scale, ) - self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space[DefaultDimNames.scalar], + scalar_dim, convolution_kernel_dim, ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) - self.x_proj = Linear( inner_dim, x_projection_dim, @@ -113,27 +108,23 @@ def __init__( lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True - # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (inner_dim, dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) - self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), lr_scale=lr_scale, ) - self.A_log = ParameterMeta.from_dims( (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), 
lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( (inner_dim,), @@ -141,7 +132,6 @@ def __init__( init_method=init_ones_, lr_scale=lr_scale, ) - self.out_proj = Linear( inner_dim, hidden_dim, @@ -151,9 +141,15 @@ def __init__( ) self.out_proj.weight.auto_grad_accumulation = True - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[TransformerKwargs.sequence_first] else (0, 2, 1)) + in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[BlockKwargs.sequence_first] else (0, 2, 1)) # In the backward pass we write dx and dz next to each other to avoid torch.cat # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s @@ -172,6 +168,10 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None + + +def init_kaiming_(d_in: float) -> LambdaInitializer: + return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 7151da394..2659e415f 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,22 +3,17 @@ import torch -from fast_llm.engine.config_utils.tensor_space import ( - CompositeTensorDim, - ConcatenatedTensorDim, - DefaultDimNames, - TensorDim, - TensorSpace, -) -from fast_llm.engine.distributed.config import DistributedDimNames 
+from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType +from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ +from fast_llm.tensor import ParameterMeta +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -37,7 +32,7 @@ logger = logging.getLogger(__name__) -class Mamba2(Mixer): +class Mamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -46,27 +41,24 @@ class Mamba2(Mixer): def __init__( self, - config: SSMConfig, - tensor_space: TensorSpace, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, block_index: int, - transformer_config: TransformerConfig, + name: str, + lr_scale: float | None, ): - super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) - self._config: SSMConfig = 
config + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) - hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] state_dim = TensorDim("state", self._config.state_size) head_groups_dim = TensorDim( - "head_groups", - num_head_groups, - self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor), + "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) ) group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) @@ -89,12 +81,14 @@ def __init__( self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size - conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim + + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) + self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space[DefaultDimNames.scalar], + scalar_dim, convolution_kernel_dim, ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), @@ -109,16 +103,15 @@ def __init__( hidden_dim, inner_projection_dim, bias=config.add_bias_linear, - weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.dt_in_proj = Linear( hidden_dim, dt_rank_dim, bias=config.add_bias_linear, - 
weight_init_method=init_kaiming_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(block_config.hidden_size), lr_scale=lr_scale, ) self.dt_proj = OutputParallelLinear( @@ -158,20 +151,27 @@ def __init__( sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - if self._debug_level: + + if self._debug.enabled: self._xz_dims = ( - TransformerDimNames.batch, + BlockDimNames.batch, inner_dim, - TransformerDimNames.sequence_q, + BlockDimNames.sequence_q, ) self._bc_dims = ( - TransformerDimNames.batch, + BlockDimNames.batch, heads_dim, state_dim, - TransformerDimNames.sequence_q, + BlockDimNames.sequence_q, ) - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict[str, typing.Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available assert _causal_conv1d_available @@ -180,7 +180,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ inner_projection = self.in_proj(input_) dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias # Standardize to (batch, sequence, local_inner_projection) - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) dt = dt.transpose(0, 1) @@ -225,12 +225,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # dt: (batch, sequence, heads * state) -> (batch, heads * state, sequence) dt = dt.transpose(1, 2) - if self._debug_level: - self._debug_log(z, "z", self._xz_dims, kwargs) - self._debug_log(x, "x", self._xz_dims, kwargs) - self._debug_log(b, "b", self._bc_dims, kwargs) - self._debug_log(c, "c", self._bc_dims, kwargs) - self._debug_log(dt, "dt", self._xz_dims, kwargs) + if self._debug.enabled: + self._debug(z, "z", self._xz_dims, 
kwargs) + self._debug(x, "x", self._xz_dims, kwargs) + self._debug(b, "b", self._bc_dims, kwargs) + self._debug(c, "c", self._bc_dims, kwargs) + self._debug(dt, "dt", self._xz_dims, kwargs) y = selective_scan_fn( x, @@ -244,12 +244,12 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ delta_softplus=True, ) - if self._debug_level: - self._debug_log(y, "y", self._xz_dims, kwargs) + if self._debug.enabled: + self._debug(y, "y", self._xz_dims, kwargs) # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] - if kwargs[TransformerKwargs.sequence_first]: + if kwargs[BlockKwargs.sequence_first]: # TODO: Is contiguous needed? y = y.transpose(0, 1).contiguous() # (batch/sequence, sequence/batch, local_heads * state) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py deleted file mode 100644 index f6eaf5890..000000000 --- a/fast_llm/layers/transformer/config.py +++ /dev/null @@ -1,664 +0,0 @@ -import abc -import enum -import functools -import logging -import math -import typing -import warnings - -from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.functional.config import ActivationType, MLPRecomputeLevel, TritonConfig -from fast_llm.layers.common.config import LLMBlockConfig, LoRAConfig, NoPeftConfig, NormalizationConfig, PeftConfig -from fast_llm.layers.transformer.rotary.config import RotaryConfig -from fast_llm.utils import Assert, div - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta - -logger = logging.getLogger(__name__) - 
- -class RoutingType(str, enum.Enum): - topk = "aux_loss" - sinkhorn = "sinkhorn" - - -class TransformerDimNames: - # A set of common tensor dim names packed into a namespace. - # Input dimensions (variable) - # TODO: Does batch belong here? - batch = "batch" - # TODO: Distinguish micro-sequence? - sequence_q = "sequence_q" - sequence_q_tp = "sequence_q_tp" - sequence_k = "sequence_k" - hidden = "hidden" - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - -class TransformerKwargs: - rotary_freq_q = "rotary_freq_q" - rotary_freq_k = "rotary_freq_k" - attention_mask = "attention_mask" - attention_mask_value = "attention_mask_value" - sequence_lengths = "sequence_lengths" - cu_seqlens_q = "cu_seqlens_q" - cu_seqlens_k = "cu_seqlens_k" - max_seqlen_q = "max_seqlen_q" - max_seqlen_k = "max_seqlen_k" - # TODO: Review these - presents = "presents" - past_key_values = "past_key_values" - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" - sequence_q_dim = "sequence_q_dim" - sequence_k_dim = "sequence_k_dim" - sequence_length = "sequence_length" - # TODO: Move - grad_output = "grad_output" - - -class TransformerLossNames: - load_balancing_loss = "load_balancing_loss" - router_z_loss = "router_z_loss" - - -class 
AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" - - -class TransformerSubLayerName(str, enum.Enum): - # TODO: Use this to replace AddLinearBiasChoices. - query = "query" - key = "key" - value_ = "value" - key_value = "key_value" - dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" - - -@config_class(registry=True) -class TransformerPeftConfig(PeftConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - pass - - @abc.abstractmethod - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - pass - - @abc.abstractmethod - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
- return TransformerNoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={TransformerPeftConfig: "none"}) -class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - return super().apply_linear(linear) - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - return parameter - - -@config_class(dynamic_type={TransformerPeftConfig: "lora"}) -class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): - layers: list[TransformerSubLayerName] = Field( - default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), - desc="The layers on which to apply LoRA.", - hint=FieldHint.feature, - ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False - return linear - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.freeze_others: - parameter.requires_grad = False - 
return parameter - - def _validate(self) -> None: - super()._validate() - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." - ) - - -@config_class() -class TransformerConfig(LLMBlockConfig): - _abstract = False - normalization: NormalizationConfig = Field( - desc="Configuration for the normalization layers architecture.", - hint=FieldHint.architecture, - ) - rotary: RotaryConfig = Field( - desc="Configuration for the rotary positional embeddings.", - hint=FieldHint.architecture, - ) - peft: TransformerPeftConfig = Field( - desc="Configuration for the parameter-efficient fine tuning.", - hint=FieldHint.architecture, - ) - num_layers: int = Field( - default=12, - desc="Number of layers in the transformer.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - hidden_size: int = Field( - default=1024, - desc="Size of the transformer's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_attention_heads: int = Field(default=8, desc="Number of attention heads.", hint=FieldHint.architecture) - head_groups: int = Field( - default=1, - desc="Number of head group for grouped query attention.", - doc="Set to 1 for multi-query attention, `num_attention_heads` for multi-head.", - hint=FieldHint.architecture, - 
valid=check_field(Assert.gt, 0), - ) - add_linear_biases: bool | AddLinearBiasChoices = Field( - default=True, - desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", - hint=FieldHint.architecture, - ) - ffn_hidden_size: int = Field( - default=None, - desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - kv_channels: int = Field( - default=None, - desc="Number of key and value channels, i.e., hidden dimension of each attention head. Default: hidden_size // num_attention_heads", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - gated: bool = Field(default=False, desc="Enable gated MLP.", hint=FieldHint.architecture) - num_experts: int = Field( - default=1, - desc="Number of MLP experts in a Mixture of Expert (MoE) model", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - num_shared_experts: int = Field( - default=0, - desc="Number of MLP experts that are shared between all tokens, i.e., always enabled.", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - num_unshared_experts: int = Field( - init=False, - desc="Number of MLP experts excluding shared ones", - hint=FieldHint.architecture, - valid=check_field(Assert.geq, 0), - ) - num_experts_per_token: int = Field( - default=1, - desc="Active experts for each token in a MoE model.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - expert_routing_type: RoutingType = Field( - default=RoutingType.topk, - desc="The routing method, i.e., the method used to assign experts to tokens.", - hint=FieldHint.architecture, - ) - activation_type: ActivationType = Field( - default=None, - desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.", - hint=FieldHint.core, - ) - # Default: hidden_size**-0.5 - # TODO: Allow custom initialization (InitializationConfig?) 
- init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_qkv: float = Field( - default=None, - desc="Scale for the query, key and value weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_qkv: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_qkv: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_attn_proj: float = Field( - default=None, - desc="Scale for the attention projection weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_attn_proj: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for attention projection. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_attn_proj: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. 
Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_mlp_2: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, - ) - attention_dropout: float = Field( - default=0.0, - desc="Dropout applied to the attention intermediate states.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - hidden_dropout: float = Field( - default=0.0, - desc="Dropout applied to the residual connections.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) - # Use flash attention if possible (fp16 or bf16) - use_flash_attention: bool = Field( - default=True, desc="Enable Flash Attention if possible.", hint=FieldHint.optional - ) - window_size: int | None = Field( - default=None, - desc="Size of the attention sliding window. 
Warning: this parameter is not part of the architecture and must be redefined when loading a pretrained model.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - max_window_layers: int | None = Field( - default=None, - desc="The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.", - hint=FieldHint.optional, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto - mlp_recompute_level: MLPRecomputeLevel = Field( - default=MLPRecomputeLevel.none, - desc="Set which of the MLP intermediate activations will be recomputed during the backward passes. This provides a trade-off between memory and speed.", - hint=FieldHint.performance, - ) - debug_transformer: int = Field( - default=0, - desc="Log the output of each operation in a transformer layer.", - hint=FieldHint.logging, - valid=check_field(Assert.geq, 0), - ) - debug_transformer_memory: bool = Field( - default=False, - desc="Log the memory usage after each operation in a transformer layer..", - hint=FieldHint.logging, - ) - # Use random inits instead of constant values, useful for debugging. - random_bias_init: bool = Field( - default=False, - desc="Initialize the biases using the initialization method of their respective weights instead of setting them to zero. 
Used to test for issues that may not be visible when the biases are zero.", - hint=FieldHint.testing, - ) - expert_auxiliary_loss_coefficient: float = Field( - default=0.01, - desc="Scale of the load balancing auxiliary loss for topk routing.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - expert_z_loss_coefficient: float = Field( - default=0.0, - desc="Regularize the router during training by applying Z-loss to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - moe_jitter_eps: float = Field( - default=0.0, - desc="Regularize the router during training by applying a random multiplicative noise `uniform(1-eps, 1+eps)` to the logits.", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - mlp_lr_scale: float | None | list[float | None] = Field( - default=None, - desc="Custom learning rate scale for each expert.", - doc="May be used to freeze some experts by setting their scale to zero.", - hint=FieldHint.feature, - ) - router_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate for the MoE router weight.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_lr_scale: float | None = Field( - default=None, - desc="Custom learning rate scale for the Attention projection weights.", - doc="Can be used in muP to scale the Attention learning rate by 1/width_factor", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - attention_softmax_scale_power: float = Field( - default=0.5, - desc="The scaling power to apply to kv_channel in the attention calculation. " - " Under Standard Parameterization (SP): default to 0.5. " - " Under muP (if scaling kv_channels size): use 1. 
" - " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) - dropless_moe: bool = Field( - default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert - ) - dropless_dynamic_shape: bool = Field( - default=False, - desc="Use a dynamic shape for dropless MLP instead of the worst-case value." - " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", - hint=FieldHint.expert, - ) - - def _validate(self) -> None: - with self._set_implicit_default(): - if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - if self.activation_type is None: - self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu - if self.init_method_std is None: - self.init_method_std = self.hidden_size**-0.5 - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = 
self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min is not None and self.init_method_max is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) - self.num_unshared_experts = self.num_experts - self.num_shared_experts - - super()._validate() - - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - - Assert.leq(self.num_shared_experts, self.num_experts) - Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) - Assert.multiple(self.num_attention_heads, self.head_groups) - Assert.geq(self.attention_dropout, 0) - Assert.geq(self.hidden_dropout, 0) - - if isinstance(self.mlp_lr_scale, list): - Assert.eq(len(self.mlp_lr_scale), self.num_experts) - for scale in self.mlp_lr_scale: - if scale is not None: - Assert.geq(scale, 0) - elif self.mlp_lr_scale is not None: - Assert.geq(self.mlp_lr_scale, 0) - - @functools.cached_property - def projection_size(self): - assert self._validated - return self.num_attention_heads * 
self.kv_channels - - @property - def add_mlp_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - - @property - def add_attn_qkv_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True - - @property - def add_attn_dense_bias(self) -> bool: - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - # TODO v0.x: Remove backward compatibility. - cls._handle_renamed_field( - default, - "use_rotary_embeddings", - ("rotary", "type"), - lambda x: "default" if x else "none", - ) - cls._handle_renamed_field(default, "rotary_embedding_scale", ("rotary", "theta"), lambda x: math.exp(-x)) - cls._handle_renamed_field(default, "triton_rotary", ("rotary", "triton")) - return super()._from_dict(default, strict, flat) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.hidden, self.hidden_size)) - - # Self-attention dimensions - tensor_space.add_tensor_dim( - head_groups := TensorDim( - TransformerDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - TransformerDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(TransformerDimNames.key_and_value, 2)) - 
tensor_space.add_tensor_dim(kv_channels := TensorDim(TransformerDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(TransformerDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(TransformerDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(TransformerDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(TransformerDimNames.composite_expert_mlp, (experts, mlp))) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(TransformerDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := TensorDim(TransformerDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(TransformerDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim( - TransformerDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp) - ) - ) - - def 
do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py deleted file mode 100644 index 101d97ef3..000000000 --- a/fast_llm/layers/transformer/mlp.py +++ /dev/null @@ -1,101 +0,0 @@ -import typing -from abc import ABC - -import torch - -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd -from fast_llm.layers.common.linear import LinearBase -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerName -from fast_llm.tensor import init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale - - -class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - super().__init__() - self._name = name - self._block_index = block_index - - init_method_1 = init_normal_( - std=config.init_method_std_mlp_1, - min_val=config.init_method_min_mlp_1, - max_val=config.init_method_max_mlp_1, - ) - init_method_2 = init_normal_( - std=config.init_method_std_mlp_2, - min_val=config.init_method_min_mlp_2, - max_val=config.init_method_max_mlp_2, - ) - - hidden_dim = tensor_space[TransformerDimNames.hidden] - self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._recompute_level = config.mlp_recompute_level - - self._gated = config.gated - self._activation_type = config.activation_type - self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else 
torch_mlp_activation - - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale - lr_scale = get_lr_scale(lr_scale, layer_lr_scale) - - # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) - self.layer_1 = LinearBase( - hidden_dim, - tensor_space[TransformerDimNames.composite_gated_expert_mlp], - bias=config.add_mlp_bias, - weight_init_method=init_method_1, - bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, - lr_scale=lr_scale, - ) - self.layer_2 = LinearBase( - self._intermediate_dim, - hidden_dim, - bias=config.add_mlp_bias, - weight_init_method=init_method_2, - bias_init_method=init_method_2 if config.random_bias_init else init_zeros_, - auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1, - transposed_weight=True, - lr_scale=lr_scale, - ) - - # PEFT. - self.layer_1 = config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) - - -class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): - Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, block_index) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - parallel_group = self._intermediate_dim.parallel_group - return ( - mlp_autograd( - input_, - None, - self.layer_1.weight, - self.layer_1.bias, - self.layer_2.weight, - None if parallel_group else self.layer_2.bias, - gated=self._gated, - activation_type=self._activation_type, - group=parallel_group, - sequence_parallel=self._sequence_parallel, - 
training=self.training, - recompute_level=self._recompute_level, - transposed_layer_2_weight=self.layer_2.transposed_weight, - ), - self.layer_2.bias if parallel_group else None, - ) diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py deleted file mode 100644 index c357411b6..000000000 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ /dev/null @@ -1,68 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.tensor import TensorMeta - - -class RotaryEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__( - self, - config: DefaultRotaryConfig, - tensor_space: TensorSpace, - ): - self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[TransformerKwargs.sequence_length]) - sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size - kwargs[TransformerKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[TransformerKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[TransformerKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - 
kwargs[TransformerKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=TransformerKwargs.rotary_freq_q, - ) - kwargs[TransformerKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[TransformerKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=TransformerKwargs.rotary_freq_k, - ) - - def _create_tensors(self, sequence_length: int) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._rotary_embedding_frequencies = self._config.get_frequencies( - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py deleted file mode 100644 index c7becd948..000000000 --- a/fast_llm/layers/transformer/transformer.py +++ /dev/null @@ -1,215 +0,0 @@ -import abc -import logging -import typing - -import torch - -from fast_llm.core.distributed import set_generator -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.mlp import MLP -from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage -from fast_llm.tensor import TensorMeta -from fast_llm.utils import Assert - -logger = logging.getLogger(__name__) - - -class Mixer(torch.nn.Module, abc.ABC): - """ - Base class for mixer modules. 
- """ - - _mixer_name: typing.ClassVar[str] - - def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): - super().__init__() - self._tensor_space = tensor_space - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._block_index = block_index - self._debug_level = debug_level - - @abc.abstractmethod - def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Mixer module forward. Returns the output hidden states and an optional bias, - in case its addition can be made more efficient in `_bias_dropout_add`. - """ - - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = { - dim.name: dim - for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) - } - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] - for dim_name in dim_names - ), - tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_level, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_level, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - - -class BaseBlock(Layer, abc.ABC): - """ - A transformer-like decoder base block with abstract mixer. 
- """ - - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "mixer" - - def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False - ): - super().__init__() - self._config: TransformerConfig = config - self._tensor_space: TensorSpace = tensor_space - self._dropout_p: float = self._config.hidden_dropout - # For multi-token prediction, return a stack of shared_hidden and transformer_output. - self._return_input: bool = return_input - - self._block_index = block_index - self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space[TransformerDimNames.hidden] - - # TODO: add a separate norm_lr_scale - lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - self.norm_1 = self._config.normalization.get_layer(hidden_dim, lr_scale) - self.norm_2 = self._config.normalization.get_layer(hidden_dim, lr_scale) - - # The mixer needs to be created here for backward-compatible weight ordering. - setattr(self, self._mixer_module_name, self._create_mixer()) - - self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index - ) - - # PEFT. 
- self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) - - @abc.abstractmethod - def _create_mixer(self) -> Mixer: - pass - - @torch.compile - def _bias_dropout_add( - self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor - ) -> torch.Tensor: - if bias is not None: - input_ = input_ + bias - return residual + torch.dropout(input_, self._dropout_p, self.training) - - @property - def name(self) -> str: - return f"{self._name} {self._block_index}" - - def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): - dims = kwargs[TransformerKwargs.hidden_dims] - if self._return_input: - dims = (TensorDim("stacked_input_output", 2),) + dims - return TensorMeta.from_dims(dims, tensor_name=f"{self.name} {name}", dtype=tensor.dtype) - - def _debug_log(self, tensor: torch.Tensor | None, name: str, kwargs: dict[str, typing.Any], *, bias=None) -> None: - if self._config.debug_transformer_memory: - log_pipeline_parallel_main_rank(lambda: log_memory_usage(f"{self.name} {name}", str)) - if self._config.debug_transformer and tensor is not None: - # TODO: Local vs global - log_distributed_tensor( - "", - tensor if bias is None else tensor + bias, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name, kwargs), - distributed=self._tensor_space.distributed, - ) - log_distributed_grad( - "", - tensor, - level=self._config.debug_transformer, - meta=self._get_meta(tensor, name + " grad", kwargs), - distributed=self._tensor_space.distributed, - ) - - def forward( - self, - input_: torch.Tensor, - kwargs: dict[str, typing.Any], - losses: dict[str, typing.Any] | None = None, - metrics: dict[str, typing.Any] | None = None, - ) -> torch.Tensor: - if isinstance(input_, TensorMeta): - return self._get_meta(input_, "output", kwargs) - generator = ( - self._tensor_space.distributed.tp_generator - if self._tensor_space.distributed_config.sequence_tensor_parallel - else 
self._tensor_space.distributed.pp_generator - ) - if self._debug_mode: - self._debug_log(None, "Begin", kwargs) - fw_input = input_ - hidden_states = self.norm_1(input_) - if self._debug_mode: - self._debug_log(hidden_states, "Norm 1", kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) - if self._debug_mode: - self._debug_log(hidden_states, f"{self._mixer_module_name} output", kwargs, bias=bias) - with set_generator(generator): - input_ = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug_mode: - self._debug_log(input_, f"{self._mixer_module_name} residual", kwargs) - hidden_states = self.norm_2(input_) - if self._debug_mode: - self._debug_log(hidden_states, "Norm 2", kwargs) - hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) - if self._debug_mode: - self._debug_log(hidden_states, "MLP output", kwargs, bias=bias) - with set_generator(generator): - hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - if self._debug_mode: - self._debug_log(None, "MLP residual", kwargs, bias=bias) - if self._return_input: - hidden_states = torch.stack((fw_input, hidden_states), dim=0) - return hidden_states - - -class TransformerBlock(BaseBlock): - _name = "Transformer layer" - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "self_attn" - - def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False - ): - super().__init__(config, tensor_space, block_index, return_input) - - def _create_mixer(self) -> Mixer: - from fast_llm.layers.transformer.attention import Attention - - return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 6d555a0bb..024d7d79c 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -14,7 +14,6 @@ if typing.TYPE_CHECKING: from fast_llm.core.distributed import ProcessGroup - from 
fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -254,7 +253,6 @@ def log_distributed_tensor[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] = (), global_: bool = True, log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, @@ -263,7 +261,7 @@ def log_distributed_tensor[ if level <= 0: return if global_: - tensor, is_first_rank = meta.local_to_global(tensor, distributed=distributed) + tensor, is_first_rank = meta.local_to_global(tensor) storage = False is_first_rank = is_first_rank and all(group.rank() == 0 for group in duplicate_groups if group) if not is_first_rank: @@ -289,7 +287,6 @@ def log_distributed_grad[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] = (), grad_fn: typing.Callable[[torch.Tensor], torch.Tensor] | None = None, global_: bool = True, @@ -305,7 +302,6 @@ def log_distributed_grad[ scale=scale, level=level, storage=storage, - distributed=distributed, duplicate_groups=duplicate_groups, global_=global_, log_fn=log_fn, diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index 534d813ff..3afd88ce1 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -3,43 +3,33 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import Layer, LossDef +from fast_llm.engine.base_model.base_model import LossDef from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerBlock -from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig +from 
fast_llm.models.custom.config import CustomBaseModelConfig from fast_llm.models.custom.head import CustomHead -from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.tensor import TensorMeta class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): - config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - def __init__( self, - config: CustomBaseModelConfig, + config: ConfigType, distributed_config: DistributedConfig, ): # TODO: Implement / update. super().__init__(config, distributed_config) - def get_layers(self) -> list[Layer]: - # TODO: Adjust as needed. - return [ - LanguageModelEmbedding(self._config, self._tensor_space), - *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - ) - for i in range(self._config.transformer.num_layers) - ], - CustomHead(self._config, self._tensor_space), - ] + def _get_head(self, prediction_distance): + return CustomHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) def preprocess_meta( self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType @@ -66,5 +56,4 @@ def loss_defs(self) -> list[LossDef]: class CustomModel[ConfigType: CustomBaseModelConfig](GPTModel[ConfigType]): - config_class: typing.ClassVar[type[CustomModelConfig]] = CustomModelConfig base_model_class: typing.ClassVar[type[CustomBaseModel]] = CustomBaseModel diff --git a/fast_llm/models/custom/trainer.py b/fast_llm/models/custom/trainer.py index eba51235e..587adad3e 100644 --- a/fast_llm/models/custom/trainer.py +++ b/fast_llm/models/custom/trainer.py @@ -1,5 +1,3 @@ -import typing - from fast_llm.models.custom.config import CustomTrainerConfig from fast_llm.models.custom.data import CustomData from 
fast_llm.models.gpt.trainer import GPTTrainer @@ -7,8 +5,6 @@ class CustomTrainer[ConfigType: CustomTrainerConfig](GPTTrainer[ConfigType]): # TODO: Implement changes in the training loop (or tflops computation), if any (typically none). - config_class: typing.ClassVar[type[CustomTrainerConfig]] = CustomTrainerConfig - def _get_data(self): # TODO: Adjust signature if needed. return CustomData( diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index d8425786d..36975dea1 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,10 +24,11 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import LayerNormalizationConfig -from fast_llm.layers.transformer.config import RoutingType, TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex +from fast_llm.layers.block.mlp.config import RoutingType +from fast_llm.layers.common.normalization.config import LayerNormalizationConfig from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, @@ -198,19 +199,19 @@ def _create_transformer_layer_converters( ( f"{fast_llm_layer_name}.self_attn.query", f"{hf_layer_name}.self_attn.q_proj", - transformer_config.add_attn_qkv_bias, + 
transformer_config.add_qkv_bias, QueryWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.key_value", (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), - transformer_config.add_attn_qkv_bias, + transformer_config.add_qkv_bias, KeyValueWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.dense", f"{hf_layer_name}.self_attn.o_proj", - transformer_config.add_attn_dense_bias, + transformer_config.add_dense_bias, WeightConverter, ), # Norm diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index cf7da3872..2f99ae4c3 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner @@ -86,12 +86,12 @@ def forward( if past_key_values is not None: # The transformers will use the past keys and values to this list. - kwargs[TransformerKwargs.past_key_values] = past_key_values + kwargs[AttentionKwargs.past_key_values] = past_key_values # TODO: preprocess needs to know about the past. raise NotImplementedError() if use_cache: # The transformers will save the present keys and values to this list. 
- kwargs[TransformerKwargs.presents] = [] + kwargs[AttentionKwargs.presents] = [] if output_hidden_states: kwargs["output_hidden_states"] = True @@ -117,11 +117,11 @@ def forward( outputs = (logits,) if use_cache: - outputs += (kwargs[TransformerKwargs.presents],) + outputs += (kwargs[AttentionKwargs.presents],) return outputs return transformers.modeling_outputs.CausalLMOutputWithPast( logits=logits, hidden_states=hidden_states, - past_key_values=kwargs[TransformerKwargs.presents], + past_key_values=kwargs[AttentionKwargs.presents], ) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 20ed8e828..5d3130549 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,7 +1,7 @@ import typing -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -94,7 +94,7 @@ def _init_attention_megatron( raise NotImplementedError(meta.tensor_name) if isinstance(config.rotary, DefaultRotaryConfig) and config.rotary.complex_format: - from fast_llm.layers.transformer.rotary.config import convert_rotary_real_to_complex + from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex # Megatron uses (2, kv_channels/2) for the complex split; we use (kv_channels/2, 2). # TODO: Avoid unnecessarily changing the value and dense tensors. 
diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 49a5dcbd3..b13c77724 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -6,22 +6,19 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.attention.block import TransformerBlock +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor +from fast_llm.layers.block.config import BlockDimNames +from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.config import ( - RoutingType, - TransformerDimNames, - TransformerKwargs, - TransformerLossNames, -) -from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, 
TensorMeta @@ -35,13 +32,12 @@ class GPTBaseModel[ConfigType: GPTBaseModelConfig](BaseModel[ConfigType]): A transformer-based language model generalizing the GPT model architecture. """ - config_class: typing.ClassVar[type[GPTBaseModelConfig]] = GPTBaseModelConfig - def __init__( self, config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): + self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) super().__init__(config, distributed_config) self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: @@ -51,59 +47,87 @@ def __init__( # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: - self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._distributed_config)) # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. - self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) + self._preprocessors.append( + self._config.transformer.rotary.get_layer(TensorDim("kv_channels", self._config.transformer.kv_channels)) + ) if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._distributed_config)) if self._config.enable_dpo: # TODO better way to pass in? 
- self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._distributed_config)) - def get_output_layers(self) -> list[Layer]: + def _get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, + self._get_block( # TODO MTP: which index? - block_index=max(self._config.transformer.num_layers + i, 1), + max(self._config.transformer.num_layers + i, 1), + f"MPT head {i} block", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=i < self._config.prediction_heads - 1, + i < self._config.prediction_heads - 1, ) ) - layers.append( - LanguageModelHead( - self._config, - self._tensor_space, - prediction_distance=i, - ) - ) + layers.append(self._get_head(i)) return layers def get_layers(self) -> list[Layer]: return [ - LanguageModelEmbedding(self._config, self._tensor_space), + self._get_embeddings(), *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, + self._get_block( + i + 1, + f"Block {i + 1}", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) for i in range(self._config.transformer.num_layers) ], - *self.get_output_layers(), + *self._get_output_layers(), ] + def _get_block( + self, + block_index: int, + name: str, + return_input: bool = False, + ): + lr_scale = ( + None + if self._config.transformer.per_layer_lr_scale is None + else self._config.transformer.per_layer_lr_scale[block_index] + ) + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + lr_scale, + return_input, + ) + + def _get_embeddings(self): + return LanguageModelEmbedding(self._config, self._distributed_config, self._hidden_dim, 0, "Embeddings") + + def _get_head(self, prediction_distance): + return LanguageModelHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) + def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -122,8 +146,8 @@ def preprocess_meta( micro_sequence_length = sequence_length truncate_documents = True - batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(TransformerDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) + batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -132,19 +156,17 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? 
sequence_q_dim = TensorDim( - TransformerDimNames.sequence_q, + BlockDimNames.sequence_q, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), + self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - TransformerDimNames.sequence_q_tp, + BlockDimNames.sequence_q_tp, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim( - DistributedDimNames.tensor_and_sequence_data - ), + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_sequence_data), ) - if self._tensor_space.distributed_config.sequence_tensor_parallel + if self._distributed_config.sequence_tensor_parallel else sequence_q_dim ) @@ -155,24 +177,23 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( - (hidden_sequence_q_dim, batch_dim, hidden_dim) + (hidden_sequence_q_dim, batch_dim, self._hidden_dim) if sequence_first - else (batch_dim, hidden_sequence_q_dim, hidden_dim) + else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) ) common_kwargs = { LanguageModelKwargs.phase: phase, - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.hidden_dims: hidden_dims, - TransformerKwargs.sequence_length: sequence_length, - TransformerKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.hidden_dims: hidden_dims, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_q_dim: sequence_q_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } sequence_k_pasts = range( - sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, + sequence_q_dim.size * self._distributed_config.sequence_data_rank, sequence_length, micro_sequence_length, ) @@ -186,7 +207,7 @@ def 
preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(BlockDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -194,7 +215,7 @@ def preprocess_meta( kwargs = { **common_kwargs, - TransformerKwargs.sequence_k_dim: sequence_k_dim, + AttentionKwargs.sequence_k_dim: sequence_k_dim, } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( @@ -206,10 +227,10 @@ def preprocess_meta( for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - TransformerKwargs.sequence_first, - TransformerKwargs.sequence_length, - TransformerKwargs.sequence_q_dim, - TransformerKwargs.sequence_k_dim, + AttentionKwargs.sequence_first, + AttentionKwargs.sequence_length, + AttentionKwargs.sequence_q_dim, + AttentionKwargs.sequence_k_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -235,12 +256,12 @@ def preprocess( preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[TransformerKwargs.sequence_q_dim].size - sequence_first = common_kwargs[TransformerKwargs.sequence_first] + sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size + sequence_first = common_kwargs[AttentionKwargs.sequence_first] prediction_heads: int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( - device=self._tensor_space.distributed.device, + device=self._distributed.device, dtype=torch.int64, non_blocking=True, ) @@ -268,14 +289,14 @@ def preprocess( preprocessed = [] presents = None for i, (_, kwargs_meta) in 
enumerate(preprocessed_meta): - sequence_k = kwargs_meta[TransformerKwargs.sequence_k_dim].size + sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size if sequence_first: tokens = batch.token_ids[sequence_k - sequence_q : sequence_k] else: # TODO: Avoid multiple contiguous calls? tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: - kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths if batch.chosen_spans is not None: kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans if batch.rejected_spans is not None: @@ -287,8 +308,8 @@ def preprocess( presents = None if i == len(preprocessed_meta) - 1 else [] kwargs = { **kwargs_meta, - TransformerKwargs.past_key_values: pasts, - TransformerKwargs.presents: presents, + AttentionKwargs.past_key_values: pasts, + AttentionKwargs.presents: presents, } if phase != PhaseType.inference: sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels @@ -374,7 +395,7 @@ def loss_defs(self) -> list[LossDef]: ): loss_defs.append( LossDef( - name=TransformerLossNames.load_balancing_loss, + name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", count=self._config.transformer.num_layers, ) @@ -382,7 +403,7 @@ def loss_defs(self) -> list[LossDef]: if self._config.transformer.expert_z_loss_coefficient: loss_defs.append( LossDef( - name=TransformerLossNames.router_z_loss, + name=MLPLossNames.router_z_loss, formatted_name="router z loss", count=self._config.transformer.num_layers, ) @@ -414,7 +435,6 @@ def loss_defs(self) -> list[LossDef]: class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): - config_class: typing.ClassVar[type[GPTModelConfig]] = GPTModelConfig base_model_class: typing.ClassVar[type[GPTBaseModel]] = GPTBaseModel def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, 
sequence_length) -> tuple[int, int]: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54508e8e1..7f2e83ab4 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -10,8 +10,6 @@ class GPTTrainer[ConfigType: GPTTrainerConfig](Trainer[ConfigType]): - config_class: typing.ClassVar[type[GPTTrainerConfig]] = GPTTrainerConfig - def _get_data(self) -> GPTData: return GPTData( config=self._config.data, diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index f632ab6c7..ef4325552 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,7 +6,6 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig @@ -47,14 +46,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): # TODO: Support combination of different SSM block types. ssm_block_type: SSMBlockType | None = Field(init=False) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - """ - Setup the tensor space for the model. 
- """ - super().setup_tensor_space(tensor_space) - if self.ssm_block_type is not None: - self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) - def _validate(self): with self._set_implicit_default(None): if self.ssm.dt_rank == "auto" or self.ssm.dt_rank is None: diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index b5e77e0f0..e9b18b848 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -21,7 +21,7 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import RMSNormalizationConfig +from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 3ba6b1a62..9b79e74a3 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,12 +1,8 @@ import logging import typing -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.attention.block import TransformerBlock +from fast_llm.layers.ssm.block import SSMBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -20,89 
+16,47 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. """ - config_class: typing.ClassVar[type[HybridSSMBaseModelConfig]] = HybridSSMBaseModelConfig - _is_setup: bool = False - - def __init__( + def _get_block( self, - config: HybridSSMBaseModelConfig, - distributed_config: DistributedConfig, + block_index: int, + name: str, + return_input: bool = False, ): - super().__init__(config, distributed_config) - - def get_output_layers(self) -> list[Layer]: - """ - Get the output layers of the model. - This includes the language model head and any additional heads specified in the configuration. - """ - layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] - - if self._config.prediction_heads > 1: + if block_index > self._config.transformer.num_layers: + # MTP block block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] - for i in range(1, self._config.prediction_heads): - if block_type == SSMBlockType.transformer: - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=len(self._config.hybrid_block_layout), - return_input=i != self._config.prediction_heads - 1, - ) - ) - else: - layers.append( - SSMBlock( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - ) - layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) - - return layers - - def get_layers(self) -> list[Layer]: - """ - Create a list of layers for the model, interleaving Transformer and Mamba blocks - according to the block pattern. 
- """ - layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] - - # Create blocks according to pattern - for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == SSMBlockType.transformer: - # Transformer block - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - else: - layers.append( - SSMBlock( - transformer_config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - - # Add the output layers - layers += self.get_output_layers() - - return layers + else: + # Decoder block + block_type = self._config.hybrid_block_layout[block_index - 1] + + lr_scale = ( + None + if self._config.transformer.per_layer_lr_scale is None + else self._config.transformer.per_layer_lr_scale[block_index] + ) + + if block_type == SSMBlockType.transformer: + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + lr_scale, + return_input, + ) + else: + return SSMBlock( + self._config.transformer, + self._config.ssm, + self._distributed_config, + self._hidden_dim, + block_index, + name, + lr_scale, + self._config.ssm_block_type.get_mixer_class(), + return_input, + ) class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): @@ -110,7 +64,6 @@ class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): A hybrid model that combines Transformer and SSM blocks. 
""" - config_class: typing.ClassVar[type[HybridSSMModelConfig]] = HybridSSMModelConfig base_model_class: typing.ClassVar[type[HybridSSMBaseModel]] = HybridSSMBaseModel diff --git a/fast_llm/models/ssm/trainer.py b/fast_llm/models/ssm/trainer.py index efa7b704f..39f589384 100644 --- a/fast_llm/models/ssm/trainer.py +++ b/fast_llm/models/ssm/trainer.py @@ -6,5 +6,4 @@ class HybridSSMTrainer[ConfigType: HybridSSMTrainerConfig](GPTTrainer[ConfigType]): - config_class: typing.ClassVar[type[HybridSSMTrainerConfig]] = HybridSSMTrainerConfig model_class: typing.ClassVar[type[HybridSSMModel]] = HybridSSMModel diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index c17df9d0c..b6180c190 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,14 +1,13 @@ -import abc import functools import logging -import math import typing import torch from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy @@ -139,30 +138,11 @@ def from_dims( **kwargs, ) - @classmethod - def from_tensor_space( - cls, - dim_names: tuple[str, ...], - tensor_space: TensorSpace, - *, - tensor_name: str = "", - dtype: torch.dtype = torch.float32, - reductions: tuple[tuple[str, ReduceOp], ...] = (), - **kwargs: typing.Any, - ) -> typing.Self: - dims = tuple(tensor_space[dim_name] for dim_name in dim_names) - if reductions: - # kwarg not available for ParameterMeta, so we only provide if necessary. 
- kwargs["reductions"] = tuple( - (tensor_space.distributed_config.get_distributed_dim(name), op) for name, op in reductions - ) - return cls.from_dims(dims, tensor_name=tensor_name, dtype=dtype, **kwargs) - @property def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor) -> tuple[torch.Tensor, ...]: """ Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. Returns a view of the input tensor (or the input tensor itself) when possible. @@ -172,7 +152,7 @@ def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication - is_first_rank, modified = distributed.config.tensor_rank == 0, False + is_first_rank, modified = True, False for dim, tensor_dim in enumerate(self.dims): if tensor_dim.is_parallel: @@ -240,12 +220,8 @@ def validate(self, tensor: torch.Tensor, device: torch.device | None = None) -> return validate_tensor(tensor, self, device) def replace_tensor_parallel_dim(self, distributed_dim: DistributedDim) -> "TensorMeta": - """ - Replace the tensor-parallel `DistributedDim` in `meta`, preserving the local size. - Requires for advanced tensor manipulations, - ex. turn tensor-parallel slices of a tensor into slices of a different tensor-parallel size. - Note: This will turn `ParameterMeta` into `TensorMeta` - """ + # Replace the tensor-parallel `DistributedDim` in `meta`. 
+ # Note: This will turn `ParameterMeta` into `TensorMeta` if not self.is_tensor_parallel: return self dims = list(self.dims) @@ -365,70 +341,3 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_copy(grad, param.grad_buffer) # noqa else: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa - - -class Initializer(abc.ABC): - @abc.abstractmethod - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - pass - - requires_global_initialization = False - - -class LambdaInitializer(Initializer): - def __init__( - self, - init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], - requires_global_initialization: bool = False, - ) -> None: - self._init_method = init_method - self.requires_global_initialization = requires_global_initialization - - def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: - return self._init_method(meta, tensor, generator) - - -def init_fill_(value: float) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor.fill_(value) - - return LambdaInitializer(init_) - - -init_zeros_ = init_fill_(0.0) -init_ones_ = init_fill_(1.0) - - -def init_normal_( - mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.normal_(mean, std, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) - - -def init_uniform_( - low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None -) -> LambdaInitializer: - def init_(meta: 
ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - tensor = tensor.uniform_(low, high, generator=generator) - if min_val is not None or max_val is not None: - tensor.clamp_(min=min_val, max=max_val) - - return LambdaInitializer(init_) - - -def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: - return init_uniform_( - mean - high, - mean + high, - min_val=None if max_val is None else mean - max_val, - max_val=None if max_val is None else mean + max_val, - ) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 58285d408..51249c3fa 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -348,22 +348,29 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) -def get_lr_scale( - lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None -) -> float | None | tuple[float | None, ...]: - """ - Combine module and layer lr_scale. - If one is None, return the other. - """ - if lr_scale is None: - return layer_lr_scale - if layer_lr_scale is None: - return lr_scale - if isinstance(lr_scale, float): - return lr_scale * layer_lr_scale - if isinstance(lr_scale, tuple): - return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) - raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") +def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): + # Remove `None` entries. + lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None) + if not lr_scales: + # Everything is None + return None + tuple_length = None + # Check if we have tuples, and determine the length. + for lr_scale in lr_scales: + if isinstance(lr_scale, tuple): + if tuple_length is None: + tuple_length = len(lr_scale) + else: + assert len(lr_scale) == tuple_length + if tuple_length is None: + # No tuple: simple product. + return math.prod(lr_scales) + else: + # Tuple(s): use recursion. 
+ return tuple( + combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) + for i in range(tuple_length) + ) class Interrupter: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e61f72244..5a9065454 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -23,8 +23,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.functional.triton.rotary import triton_rotary_ from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import ( +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.rotary.rotary import ( apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build() + .get_layer(None) ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,9 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - DefaultRotaryConfig(triton=True).build()._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True) + .get_layer(None) + ._get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..380ab0550 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -6,10 +6,10 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import 
CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import TransformerKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from tests.utils.utils import get_base_model, get_stage, requires_cuda @@ -198,8 +198,8 @@ def test_lm_head( else: loss_mask = None kwargs = { - TransformerKwargs.sequence_first: sequence_first, - TransformerKwargs.grad_output: 1.0, + AttentionKwargs.sequence_first: sequence_first, + AttentionKwargs.grad_output: 1.0, } if config.distillation_model is None: target = torch.randint( diff --git a/tests/test_attention.py b/tests/test_attention.py index dd36b840a..9564a931f 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -2,12 +2,13 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs -from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionKwargs, TransformerConfig +from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor +from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert @@ -30,19 +31,6 @@ def test_decide_window_size(): assert attention._decide_window_size() == 
512 -def test_attention_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - Attention(transformer_conf, tensor_space, 1) - - def test_varlen_preprocessor(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: @@ -63,27 +51,24 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - transformer_cfg = TransformerConfig( + transformer_config = TransformerConfig( num_layers=2, num_attention_heads=2, hidden_size=16, use_flash_attention=True, ) - distributed_cfg = DistributedConfig(training_dtype="bfloat16") - distributed = Distributed(distributed_cfg, use_cpu=True) - tensor_space = TensorSpace(distributed_config=distributed_cfg) - tensor_space.setup(distributed) - transformer_cfg.setup_tensor_space(tensor_space) - varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) + distributed_config = DistributedConfig(training_dtype="bfloat16") + distributed = Distributed(distributed_config, use_cpu=True) + varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_config, distributed_config=distributed_config) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - TransformerKwargs.sequence_q_dim: TensorDim(TransformerDimNames.sequence_k, micro_sequence_length), - TransformerKwargs.sequence_k_dim: TensorDim( - TransformerDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_k_dim: TensorDim( + BlockDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), - TransformerKwargs.sequence_length: sequence_length, - 
TransformerKwargs.sequence_lengths: sequence_lengths, + AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(None, kwargs) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) - Assert.all_equal(kwargs[TransformerKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) + varlen_preprocessor.preprocess(torch.empty(1, device="cpu"), kwargs) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) + Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_mlp.py b/tests/test_mlp.py deleted file mode 100644 index bcfbaf693..000000000 --- a/tests/test_mlp.py +++ /dev/null @@ -1,33 +0,0 @@ -from fast_llm.layers.transformer.mlp import MLP -from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace - - -def test_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MLP(transformer_conf, tensor_space, "name") - - -def test_moe_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - num_experts=2, - add_linear_biases=False - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MixtureOfExpertMLP(transformer_conf, tensor_space, "name") diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 
2f125717e..56356cf7a 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.transformer import TransformerBlock +from fast_llm.layers.attention.block import TransformerBlock +from fast_llm.layers.ssm.block import SSMBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup diff --git a/tests/test_ssms.py b/tests/test_ssms.py deleted file mode 100644 index 694faa55b..000000000 --- a/tests/test_ssms.py +++ /dev/null @@ -1,82 +0,0 @@ -import pathlib - -import pytest -import torch - -from fast_llm.config import NoAutoValidate -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.transformer.config import TransformerKwargs -from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.models.ssm.config import LLambaHuggingfaceCheckpointFormat -from fast_llm.models.ssm.model import HybridSSMModel - - -@pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") -@pytest.mark.slow -def test_load_from_llamba_checkpoint(): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. 
- """ - import cartesia_pytorch.Llamba.llamba - - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 - - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat - - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - - hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") - parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) - hf_logits = hf_model(x)["logits"].cpu() - del hf_model - torch.cuda.empty_cache() - - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - - # model = GPTModel.from_pretrained(checkpoint_config) - assert model.config.base_model.vocab_size == vocab_size - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(DistributedConfig.from_dict({})) - batch_config.validate() - schedule_runner = ScheduleRunner( - config=schedule_config, - multi_stage=model, - distributed_config=model.distributed.config, - ) - schedule = Schedule( - multi_stage=model, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=model.distributed.config, - phase=PhaseType.inference, - ) - schedule_runner.setup(model.distributed, optimizer=None) - - common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] - - 
schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) - - logits = input_data[0][1]["logits"].cpu() - assert torch.allclose(logits, hf_logits, atol=1e-2) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 836b6b79d..42e588911 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -29,7 +29,6 @@ def set_testing_global_variables(): num_gpus = len(gpus) gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - # TODO: This might help with some issues, but slows down testing significantly. # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") # os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 88303a0f4..0dc3462eb 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -13,7 +13,6 @@ from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.config_utils.logging import configure_logging -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage @@ -33,12 +32,8 @@ def result_path(): def get_base_model(config: FastLLMModelConfig): # Create a base model (and distributed). # Using a full model config so we have the model type and distributed config in the same argument. 
- distributed = Distributed(config.distributed) - tensor_space = TensorSpace(config.distributed) - config.base_model.setup_tensor_space(tensor_space) - tensor_space.setup(distributed) base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) - base_model.setup(distributed) + base_model.setup(distributed := Distributed(config.distributed)) return base_model, distributed