diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py new file mode 100644 index 000000000..aa84408d2 --- /dev/null +++ b/fast_llm/engine/config_utils/parameter.py @@ -0,0 +1,67 @@ +import typing + +from fast_llm.config import Config, Field, config_class +from fast_llm.engine.config_utils.initialization import Initializer +from fast_llm.engine.config_utils.tensor_dim import TensorDim + +if typing.TYPE_CHECKING: + from fast_llm.tensor import ParameterMeta + + +@config_class() +class ParameterConfig(Config): + # TODO: Initialization, lr_scale + + def _validate(self) -> None: + pass + + def get_parameter( + self, + dims: tuple[TensorDim, ...], + default_initializer: Initializer, + lr_scale: float | None, + weight_decay: bool = True, + allow_sequence_tensor_parallel: bool = True, + ) -> "ParameterMeta": + from fast_llm.tensor import ParameterMeta + + return ParameterMeta.from_dims( + dims, + init_method=default_initializer, + lr_scale=lr_scale, + weight_decay=weight_decay, + allow_sequence_tensor_parallel=allow_sequence_tensor_parallel, + ) + + +@config_class() +class OptionalParameterConfig(ParameterConfig): + enabled: bool | None = Field( + default=None, + ) + # TODO: Initialization, lr_scale + + def _validate(self) -> None: + pass + + def get_parameter( + self, + dims: tuple[TensorDim, ...], + default_initializer: Initializer, + lr_scale: float | None, + weight_decay: bool = True, + allow_sequence_tensor_parallel: bool = True, + default_enabled: bool = False, + ) -> "ParameterMeta|None": + from fast_llm.tensor import ParameterMeta + + if (self.enabled is None and default_enabled) or self.enabled: + return ParameterMeta.from_dims( + dims, + init_method=default_initializer, + lr_scale=lr_scale, + weight_decay=weight_decay, + allow_sequence_tensor_parallel=allow_sequence_tensor_parallel, + ) + else: + return None diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 35547cd87..40ef07f67 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -22,8 +22,6 @@ def _accumulate_grad_hook(buffer: torch.nn.Parameter, meta: ParameterMeta) -> typing.Callable[[tuple, tuple], None]: def hook(grad_inputs, grad_outputs): # noqa if buffer.grad is not None: - if not meta.auto_grad_accumulation: - raise RuntimeError(f"Unexpected grad for parameter {meta.tensor_name}") accumulate_gradient(buffer, buffer.grad) buffer.grad = None diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 8a4c490c9..dde6bbf94 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -4,7 +4,7 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward @@ -12,7 +12,6 @@ from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.utils import combine_lr_scales, div try: @@ -112,21 +111,20 @@ def __init__( ) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) - self.query = OutputParallelLinear( + self.query = self._config.query_layer.get_layer( hidden_dim, query_dim, - bias=self._config.add_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_qkv, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.key_value = OutputParallelLinear( + # TODO: Use value config. + self.key_value = self._config.query_layer.get_layer( hidden_dim, key_value_dim, - bias=self._config.add_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_qkv, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -136,12 +134,11 @@ def __init__( self._rotary = self._config.rotary.get_layer(kv_channels_dim) # Output. - self.dense = InputParallelLinear( + self.dense = self._config.dense_layer.get_layer( dense_dim, hidden_dim, - bias=self._config.add_dense_bias, - weight_init_method=init_method_std_attn_proj, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_std_attn_proj, + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index e5c638adc..2c6d4f966 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -7,7 +7,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.common.linear.config import AffineLinearConfig from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) @@ -32,6 +33,23 @@ class AttentionConfig(Config): # TODO: Make mixer class dynamic. _abstract = False + query_layer: AffineLinearConfig = Field( + desc="Configuration for the query layer.", + hint=FieldHint.architecture, + ) + key_layer: AffineLinearConfig = Field( + desc="Configuration for the key layer.", + hint=FieldHint.architecture, + ) + # TODO: Use + value_layer: AffineLinearConfig = Field( + desc="Configuration for the value layer.", + hint=FieldHint.architecture, + ) + dense_layer: AffineLinearConfig = Field( + desc="Initialization configuration for the dense layer.", + hint=FieldHint.feature, + ) # TODO: Review names rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", @@ -162,24 +180,6 @@ def projection_size(self): def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - @property - def add_qkv_bias(self) -> bool: - # TODO: Make this work without inheritance. - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True - - @property - def add_dense_bias(self) -> bool: - # TODO: Make this work without inheritance. - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - @config_class() # TODO: Use composition instead diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 29acaadf0..4d7c9ef7c 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,5 +1,3 @@ -import enum - from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.layers.block.mlp.config import MLPConfig @@ -34,12 +32,6 @@ class BlockKwargs: grad_output = "grad_output" -class AddLinearBiasChoices(str, enum.Enum): - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" - - @config_class() # TODO: Use composition instead class BlockConfig(MLPConfig, BaseModelConfig): @@ -76,12 +68,11 @@ class BlockConfig(MLPConfig, BaseModelConfig): desc="Log the memory usage after each operation in a transformer layer..", hint=FieldHint.logging, ) - add_linear_biases: bool | AddLinearBiasChoices = Field( + add_linear_biases: bool = Field( default=True, - desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", + desc="Add biases to linear layers. May be overridden for individual layers.", hint=FieldHint.architecture, ) - # TODO: Move these, not specific to a single block. num_layers: int = Field( default=12, diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 88ce4af10..186c3007d 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -3,6 +3,7 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -21,8 +22,24 @@ class RoutingType(str, enum.Enum): @config_class() class MLPConfig(Config): - # TODO: Review names # TODO: Separate MoE? + # TODO: Review names + # TODO: Separate MoE? _abstract = False + # TODO: Configure experts, gate/up separately? + layer_1: AffineLinearConfig = Field( + desc="Configuration for the first MLP layer.", + hint=FieldHint.architecture, + ) + # TODO: Separate gate and up + layer_2: AffineLinearConfig = Field( + desc="Configuration for the second MLP layer.", + hint=FieldHint.architecture, + ) + router: LinearConfig = Field( + # TODO: Improve default? + desc="Configuration for the MoE router.", + hint=FieldHint.feature, + ) ffn_hidden_size: int = Field( default=None, desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", @@ -143,17 +160,6 @@ class MLPConfig(Config): hint=FieldHint.optional, ) - @property - def add_mlp_bias(self) -> bool: - from fast_llm.layers.block.config import AddLinearBiasChoices - - # TODO: Make this work without inheritance. - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False - def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index 4f7cf2dc4..9298e872b 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -13,7 +13,6 @@ from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.common.linear import Linear from fast_llm.utils import Assert, combine_lr_scales logger = logging.getLogger(__name__) @@ -47,12 +46,10 @@ def __init__( # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - - self.router = Linear( + self.router = self._config.router.get_layer( self._hidden_dim, TensorDim("router_experts", self._config.num_unshared_experts), - bias=False, - weight_init_method=init_normal_( + default_weight_initializer=init_normal_( std=self._block_config.init_method_std, min_val=self._block_config.init_method_min, max_val=self._block_config.init_method_max, diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index c3a714a42..ba7c45c31 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,7 +2,7 @@ import torch -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig @@ -11,7 +11,6 @@ from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName -from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert, combine_lr_scales @@ -46,21 +45,20 @@ def __init__( lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) - self.layer_1 = LinearBase( + self.layer_1 = self._config.layer_1.get_layer( hidden_dim, intermediate_1_dim, - bias=self._config.add_mlp_bias, - weight_init_method=init_method_1, - bias_init_method=init_zeros_, + default_weight_initializer=init_method_1, + default_add_bias=self._block_config.add_linear_biases, + sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.layer_2 = LinearBase( + self.layer_2 = self._config.layer_1.get_layer( intermediate_2_dim, hidden_dim, - bias=self._config.add_mlp_bias, - weight_init_method=init_method_2, - bias_init_method=init_zeros_, - auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1, + default_weight_initializer=init_method_2, + default_add_bias=self._block_config.add_linear_biases, + sequence_parallel=self._sequence_parallel, transposed_weight=True, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py index b51d352bc..ffa40a255 100644 --- a/fast_llm/layers/block/peft.py +++ b/fast_llm/layers/block/peft.py @@ -10,7 +10,7 @@ from fast_llm.utils import div if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.layers.common.linear.linear import LinearBase, LinearLike class TransformerSubLayerName(str, enum.Enum): diff --git a/fast_llm/layers/common/linear/__init__.py b/fast_llm/layers/common/linear/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py new file mode 100644 index 000000000..e9dbe9229 --- /dev/null +++ b/fast_llm/layers/common/linear/config.py @@ -0,0 +1,178 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.initialization import Initializer, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.functional.config import ActivationType +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.common.linear.convolution import CausalConv1d + from fast_llm.layers.common.linear.linear import LinearBase + + +@config_class() +class LinearBaseConfig(Config): + """ + Configuration for a linear-like layer without bias. + """ + + weight: ParameterConfig = Field( + desc="Initialization configuration for the weight.", + hint=FieldHint.feature, + ) + + +@config_class() +class AffineLinearBaseConfig(LinearBaseConfig): + """ + Configuration for a linear-like layer with optional bias. + """ + + bias: OptionalParameterConfig = Field( + desc="Use bias.", + hint=FieldHint.architecture, + ) + + +@config_class() +class LinearConfig(LinearBaseConfig): + def get_layer( + self, + in_dim: TensorDim, + out_dim: TensorDim, + *, + default_weight_initializer: Initializer, + sequence_parallel: bool = False, + transposed_weight: bool = False, + lr_scale: float | None, + ) -> "LinearBase": + from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear + + weight = self.weight.get_parameter( + (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), + default_initializer=default_weight_initializer, + lr_scale=lr_scale, + ) + if in_dim.parallel_dim is not None: + assert out_dim.parallel_dim is None + return InputParallelLinear( + weight, + None, + transposed_weight=transposed_weight, + parallel_dim=in_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + elif out_dim.parallel_dim is not None: + return OutputParallelLinear( + weight, + None, + transposed_weight=transposed_weight, + parallel_dim=out_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + else: + assert not sequence_parallel + return Linear(weight, None, transposed_weight=transposed_weight) + + +@config_class() +class AffineLinearConfig(AffineLinearBaseConfig, LinearConfig): + def get_layer( + self, + in_dim: TensorDim, + out_dim: TensorDim, + *, + default_weight_initializer: Initializer, + default_bias_initializer: Initializer = init_zeros_, + default_add_bias: bool = True, + sequence_parallel: bool = False, + transposed_weight: bool = False, + lr_scale: float | None, + ) -> "LinearBase": + from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear + + weight = self.weight.get_parameter( + (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), + default_initializer=default_weight_initializer, + lr_scale=lr_scale, + ) + bias = self.bias.get_parameter( + (out_dim,), + default_initializer=default_bias_initializer, + lr_scale=lr_scale, + default_enabled=default_add_bias, + ) + if in_dim.parallel_dim is not None: + assert out_dim.parallel_dim is None + return InputParallelLinear( + weight, + bias, + transposed_weight=transposed_weight, + parallel_dim=in_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + elif out_dim.parallel_dim is not None: + return OutputParallelLinear( + weight, + bias, + transposed_weight=transposed_weight, + parallel_dim=out_dim.parallel_dim, + sequence_parallel=sequence_parallel, + ) + else: + assert not sequence_parallel + return Linear(weight, bias, transposed_weight=transposed_weight) + + +@config_class() +class CausalConv1dConfig(AffineLinearBaseConfig): + """ + Configuration for a 1d causal convolution, as used in mamba layers. + """ + + kernel_size: int = Field( + default=4, + desc="Convolution kernel size.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + activation: ActivationType | None = Field( + default=None, + hint=FieldHint.architecture, + ) + + def get_layer( + self, + in_dim: TensorDim, + *, + default_weight_initializer: Initializer | None = None, + default_bias_initializer: Initializer | None = None, + default_add_bias: bool = True, + default_activation: ActivationType = ActivationType.identity, + lr_scale: float | None, + ) -> "CausalConv1d": + from fast_llm.layers.common.linear.convolution import CausalConv1d + + kernel_dim = TensorDim("convolution_kernel", self.kernel_size) + + if default_weight_initializer is None: + default_weight_initializer = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5) + if default_bias_initializer is None: + default_bias_initializer = init_uniform_centered_((in_dim.global_size * kernel_dim.global_size) ** -0.5) + + weight = self.weight.get_parameter( + (in_dim, scalar_dim, kernel_dim), + default_initializer=default_weight_initializer, + lr_scale=lr_scale, + ) + bias = self.bias.get_parameter( + (in_dim,), + default_initializer=default_bias_initializer, + lr_scale=lr_scale, + default_enabled=default_add_bias, + ) + return CausalConv1d( + weight, bias, activation=default_activation if self.activation is None else self.activation + ) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py new file mode 100644 index 000000000..57fdccfd5 --- /dev/null +++ b/fast_llm/layers/common/linear/convolution.py @@ -0,0 +1,53 @@ +import torch + +from fast_llm.functional.config import ActivationType +from fast_llm.tensor import ParameterMeta + +try: + from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa + + _causal_conv1d_available = True +except (ImportError, RuntimeError): + _causal_conv1d_available = False + + +class CausalConv1d(torch.nn.Module): + """ + TODO: Generalize to other convolutions? + """ + + def __init__( + self, + weight: ParameterMeta, + bias: ParameterMeta | None, + *, + activation: ActivationType = ActivationType.identity, + ): + super().__init__() + self.weight = weight + self.bias = bias + self._activation = activation + self.forward = ( + self._forward_causal_conv1d + if _causal_conv1d_available and self._activation in (ActivationType.identity, ActivationType.silu) + else self._forward_torch + ) + + def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: + return self._activation.activation_fn( + torch.nn.functional.conv1d( + input_, + self.weight, + bias=self.bias, + groups=self.weight.size(0), + padding=self.weight.size(2) - 1, + )[..., : input_.size(1)] + ) + + def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: + return _causal_conv1d_fn( + input_, + self.weight.squeeze(1), + self.bias, + activation=(None if self._activation == ActivationType.identity else self._activation.value), + ) diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear/linear.py similarity index 53% rename from fast_llm/layers/common/linear.py rename to fast_llm/layers/common/linear/linear.py index ca807e67c..631193249 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear/linear.py @@ -3,8 +3,7 @@ import torch -from fast_llm.engine.config_utils.initialization import init_zeros_ -from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, @@ -42,37 +41,15 @@ class LinearBase(LinearLike): def __init__( self, - in_dim: TensorDim, - out_dim: TensorDim, + weight: ParameterMeta, + bias: ParameterMeta | None, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, transposed_weight: bool = False, - auto_bias_grad_accumulation: bool = False, - lr_scale: float | None | tuple[float | None, ...] = None, ): super().__init__() + self.weight = weight + self.bias = bias self._transposed_weight = transposed_weight - self._in_dim = in_dim - self._out_dim = out_dim - self._weight_init_method = weight_init_method - self.weight = ParameterMeta.from_dims( - (self._in_dim, self._out_dim) if self._transposed_weight else (self._out_dim, self._in_dim), - init_method=weight_init_method, - auto_grad_accumulation=False, - lr_scale=lr_scale, - ) - if bias: - self.bias = ParameterMeta.from_dims( - (self._out_dim,), - init_method=bias_init_method, - weight_decay=False, - auto_grad_accumulation=auto_bias_grad_accumulation, - lr_scale=lr_scale, - ) - else: - self.bias = None @property def transposed_weight(self) -> bool: @@ -84,29 +61,6 @@ class Linear(LinearBase): A basic linear layer without tensor parallelism. """ - def __init__( - self, - in_dim: TensorDim, - out_dim: TensorDim, - *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, - transposed_weight: bool = False, - lr_scale: float | None | tuple[float | None, ...] = None, - ): - assert not in_dim.is_parallel - assert not out_dim.is_parallel - super().__init__( - in_dim, - out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, - transposed_weight=transposed_weight, - lr_scale=lr_scale, - ) - def forward_only( self, input_: torch.Tensor ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]]: @@ -123,35 +77,23 @@ class OutputParallelLinear(LinearBase): def __init__( self, - in_dim: TensorDim, - out_dim: TensorDim, + weight: ParameterMeta, + bias: ParameterMeta | None, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, transposed_weight: bool = False, + parallel_dim: DistributedDim, sequence_parallel: bool = False, - lr_scale: float | None | tuple[float | None, ...] = None, ): - assert not in_dim.is_parallel - self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size - self._sequence_parallel = sequence_parallel and self._group_size > 1 - super().__init__( - in_dim, - out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, - transposed_weight=transposed_weight, - lr_scale=lr_scale, - ) + super().__init__(weight, bias, transposed_weight=transposed_weight) + self._parallel_dim = parallel_dim + self._sequence_parallel = sequence_parallel and self._parallel_dim.size > 1 def forward_only(self, input_) -> tuple[torch.Tensor, tuple[typing.Any, ...]]: return output_parallel_linear_forward( input_, weight=self.weight, bias=self.bias, - group=self._out_dim.parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, ) @@ -167,30 +109,16 @@ class InputParallelLinear(LinearBase): def __init__( self, - in_dim: TensorDim, - out_dim: TensorDim, + weight: ParameterMeta, + bias: ParameterMeta | None, *, - bias=True, - weight_init_method, - bias_init_method=init_zeros_, - sequence_parallel: bool = False, transposed_weight: bool = False, - lr_scale: float | None | tuple[float | None, ...] = None, + parallel_dim: DistributedDim, + sequence_parallel: bool = False, ): - assert not out_dim.is_parallel - self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size - self._sequence_parallel = sequence_parallel and self._group_size > 1 - super().__init__( - in_dim, - out_dim, - bias=bias, - weight_init_method=weight_init_method, - bias_init_method=bias_init_method, - transposed_weight=transposed_weight, - # Tensor-parallel bias is computed in _bias_dropout_grad. - auto_bias_grad_accumulation=self._group_size > 1, - lr_scale=lr_scale, - ) + super().__init__(weight, bias, transposed_weight=transposed_weight) + self._parallel_dim = parallel_dim + self._sequence_parallel = sequence_parallel and self._parallel_dim.size > 1 def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: # TODO: Use self._forward instead (broken). @@ -198,13 +126,13 @@ def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | No input_, weight=self.weight, bias=self.bias, - group=self._in_dim.parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, transposed_weight=self._transposed_weight, ) def forward_only(self, input_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None, tuple[typing.Any, ...]]: - group = self._in_dim.parallel_group + group = self._parallel_dim.group output, context = input_parallel_linear_forward( input_, weight=self.weight, diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index a7eba72c8..0dc7b9589 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -215,14 +215,12 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | (hidden_dim,), init_method=weight_init_method, weight_decay=False, - auto_grad_accumulation=implementation == NormalizationImplementation.torch, lr_scale=self._lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=init_zeros_, weight_decay=False, - auto_grad_accumulation=implementation == NormalizationImplementation.torch, lr_scale=self._lr_scale, ) self._normalized_shape = self.weight.shape @@ -289,7 +287,6 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | (hidden_dim,), init_method=weight_init_method, weight_decay=False, - auto_grad_accumulation=True, lr_scale=lr_scale, ) self._normalized_shape = self.weight.shape diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 64a2ca57a..7c7834cbd 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -6,7 +6,7 @@ if typing.TYPE_CHECKING: import torch - from fast_llm.layers.common.linear import LinearBase, LinearLike + from fast_llm.layers.common.linear.linear import LinearBase, LinearLike from fast_llm.layers.common.normalization.normalization import Normalization from fast_llm.tensor import ParameterMeta @@ -73,7 +73,7 @@ def apply_linear( if not enabled: return self.apply_other(module) - from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear + from fast_llm.layers.common.linear.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.common.peft.lora import lora_linear if isinstance(module, InputParallelLinear): diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index 9e0ca0dd0..f84967cab 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -4,7 +4,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.common.linear import Linear, LinearBase +from fast_llm.layers.common.linear.linear import Linear, LinearBase def lora_linear( @@ -50,9 +50,6 @@ def lora_linear( transposed_weight=module.transposed_weight, lr_scale=module.weight.lr_scale, ) - # TODO: Implement proper backward pass. - module.lora_0.weight.auto_grad_accumulation = True - module.lora_1.weight.auto_grad_accumulation = True old_forward = module._forward diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index df6969cfc..bfb240107 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,6 +2,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.parameter import ParameterConfig from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl from fast_llm.layers.attention.config import TransformerConfig from fast_llm.layers.attention.rotary.config import NoRotaryConfig @@ -41,23 +42,38 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) + word_embeddings_layer: ParameterConfig = Field( + desc="Configuration for the word embedding (weight).", + hint=FieldHint.architecture, + ) + position_embeddings_layer: ParameterConfig = Field( + desc="Configuration for the word embedding (weight).", + hint=FieldHint.architecture, + ) + output_layer: ParameterConfig = Field( + desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", + hint=FieldHint.architecture, + ) max_position_embeddings: int = Field( default=2048, desc="Number of absolute position embeddings, if applicable.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # TODO: Move to `word_embeddings_layer`/`output_layer`? vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + # TODO: Move to `position_embeddings_layer.enabled`? use_position_embeddings: bool = Field( default=None, desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", hint=FieldHint.architecture, ) + # TODO: Move to `output_layer`? (dynamic type?) tie_word_embeddings: bool = Field( default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index fd4e8412e..270f2630b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -10,7 +10,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import BlockLayerBase from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs -from fast_llm.tensor import ParameterMeta, TensorMeta +from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" @@ -61,25 +61,25 @@ def __init__( self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size - self.word_embeddings_weight = ParameterMeta.from_dims( + self.word_embeddings_weight = self._config.word_embeddings_layer.get_parameter( (vocab_dim, self._hidden_dim), - init_method=init_normal_( + default_initializer=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), - lr_scale=config.embeddings_lr_scale, + lr_scale=self._config.embeddings_lr_scale, ) if self._config.use_absolute_position_embeddings: - self.position_embeddings_weight = ParameterMeta.from_dims( + self.position_embeddings_weight = self._config.position_embeddings_layer.get_parameter( (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), - init_method=init_normal_( + default_initializer=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, max_val=config.init_method_max_embed, ), allow_sequence_tensor_parallel=not config.parallel_embeddings, - lr_scale=config.embeddings_lr_scale, + lr_scale=self._config.embeddings_lr_scale, ) # PEFT. diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 8917feaf6..fd86f47cf 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,7 +2,8 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.functional.config import ActivationType +from fast_llm.engine.config_utils.parameter import ParameterConfig +from fast_llm.layers.common.linear.config import AffineLinearConfig, CausalConv1dConfig, LinearConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -46,56 +47,172 @@ def get_init_method(self, scale: float) -> "Initializer": return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) -@config_class() -class SSMConfig(Config): - _abstract = False +@config_class(registry=True) +class MixerConfig(Config): + """ + Base config class for all mixers. + TODO: Generalize to include Attention + """ - # Model dimensions - # TODO: Remove (redundant default) - expansion_factor: int = Field( - default=2, - desc="Expansion factor.", + _abstract = True + + +@config_class() +class SSMConfig(MixerConfig): + # Layers + # [Mamba, Mamba2, DiscreteMamba2] + z_layer: AffineLinearConfig = Field( + desc="Configuration for the z layer.", + hint=FieldHint.architecture, + ) + # [Mamba, Mamba2, DiscreteMamba2] + x_layer: AffineLinearConfig = Field( + desc="Configuration for the x layer.", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), ) - # head_size [MambaLayer, Mamba2, DiscreteMamba2] + # [Mamba, Mamba2, DiscreteMamba2] + convolution_layer: CausalConv1dConfig = Field( + desc="Configuration for the convolution layer.", + hint=FieldHint.architecture, + ) + # [Mamba, Mamba2, DiscreteMamba2] + d_weight: ParameterConfig = Field( + desc='Configuration for the D "skip" weight.', + hint=FieldHint.architecture, + ) + # [Mamba, Mamba2, DiscreteMamba2] + output_layer: AffineLinearConfig = Field( + desc="Configuration for the output layer.", + hint=FieldHint.architecture, + ) + + # Model dimensions + # head_size [Mamba, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, desc="State size.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # [MambaLayer, Mamba2, DiscreteMamba2] - conv_kernel_dimension: int = Field( - default=4, - desc="Conv kernel dimension.", + # [Mamba, Mamba2, DiscreteMamba2] + # c_size [Mamba, Mamba2, DiscreteMamba2]? + d_inner: int = Field( + default=None, + desc="Inner dimension.", + hint=FieldHint.core, + ) + + # Learning rate + # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] + mamba_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for Mamba blocks.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + + +@config_class() +class MambaBaseConfig(SSMConfig): + """ + Common configuration for Mamba and Mamba2. + """ + + _abstract = False + + # Layers + dt_layer: AffineLinearConfig = Field( + desc="Configuration for the dt layer.", hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), ) - # [MambaLayer, Mamba2] - dt_rank: None | int = Field( + a_log_weight: ParameterConfig = Field( + desc="Configuration for the a_log layer weight.", + hint=FieldHint.architecture, + ) + + # Model dimensions + # [Mamba, Mamba2] + dt_rank: int = Field( default=None, desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)", hint=FieldHint.architecture, ) - # head_groups [DiscreteMamba2] - n_qk_heads: int = Field( - default=32, - desc="Number of QK heads.", + + # Initialization + # dt_bias_initialization_min [Mamba, Mamba2] + dt_min: float = Field( + default=0.001, + desc="Minimum step size for discretization", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + # dt_bias_initialization_max [Mamba, Mamba2] + dt_max: float = Field( + default=0.1, + desc="Maximum step size for discretization", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + # dt_bias_initialization_floor [Mamba, Mamba2] + dt_init_floor: float = Field( + default=1e-4, + desc="Minimum value for initializing dt", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + + def _validate(self) -> None: + super()._validate() + Assert.geq(self.dt_max, self.dt_min) + + +@config_class(dynamic_type={MixerConfig: "mamba"}) +class MambaConfig(MambaBaseConfig): + """ + Configuration for Mamba. + """ + + # Layers + # TODO: Can be confused with `x_layer` + x_projection_layer: LinearConfig = Field( + desc="Configuration for the x projection layer.", hint=FieldHint.architecture, ) - # heads [DiscreteMamba2]# TODO: Remove? (redundant) - n_v_heads: int = Field( - default=32, - desc="Number of V heads.", + + def _validate(self) -> None: + super()._validate() + Assert.none(self.convolution_layer.activation) + # TODO: (Oleksiy) If bias is used there is a problem in the MambaInnerFn.backward for the bias grads. + # I think this bias is not used in other mamba repos. + assert not self.output_layer.bias.enabled + + +@config_class(dynamic_type={MixerConfig: "mamba_2"}) +class Mamba2Config(MambaBaseConfig): + """ + Configuration for Mamba2. + TODO: Actually a variation of Mamba 2. + """ + + _abstract = False + + # Layers + # [Mamba2, DiscreteMamba2] + b_layer: AffineLinearConfig = Field( + desc="Configuration for the b layer.", hint=FieldHint.architecture, ) - # c_size [MambaLayer, Mamba2, DiscreteMamba2]? - d_inner: None | int = Field( - default=None, - desc="Inner dimension.", - hint=FieldHint.core, + # [Mamba2, DiscreteMamba2] + c_layer: AffineLinearConfig = Field( + desc="Configuration for the c layer.", + hint=FieldHint.architecture, ) + dt_input_layer: AffineLinearConfig = Field( + desc="Configuration for the dt input projection layer.", + hint=FieldHint.architecture, + ) + + # Model dimensions # xb_size [Mamba2] d_xb: int = Field( default=None, @@ -104,38 +221,12 @@ class SSMConfig(Config): ) # Model options - # add_bias_linear [Mamba2, DiscreteMamba2] [hard-coded to False in MambaLayer] - add_bias_linear: bool = Field( - default=False, - desc="Whether to use bias in SSM layers", - hint=FieldHint.architecture, - ) - # activation_type [DiscreteMamba2] [hard-coded to silu in MambaLayer, Mamba2] - activation_type: ActivationType = Field( - default=None, - hint=FieldHint.architecture, - ) # repeat_xb_before_conv [Mamba2] repeat_kv_before_conv: bool = Field( default=True, desc="Whether to repeat x and B before (True) or after (False) the conv1d in Mamba2 blocks.", hint=FieldHint.architecture, ) - # chunk_size [DiscreteMamba2] - chunk_size: int = Field( - default=256, - desc="Chunk size for Mamba2 blocks.", - hint=FieldHint.architecture, - ) - - # Learning rate - # lr_scale [MambaLayer, Mamba2, DiscreteMamba2] - mamba_lr_scale: float | None = Field( - default=None, - desc="Learning rate scale for Mamba blocks.", - hint=FieldHint.feature, - valid=skip_valid_if_none(check_field(Assert.geq, 0)), - ) # Initialization # dt_weight_initialization_method [Mamba2] @@ -151,31 +242,43 @@ class SSMConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - # dt_bias_initialization_min [MambaLayer, Mamba2] - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - # dt_bias_initialization_max [MambaLayer, Mamba2] - dt_max: float = Field( - default=0.1, - desc="Maximum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + + +@config_class(dynamic_type={MixerConfig: "discrete_mamba_2"}) +class DiscreteMamba2Config(SSMConfig): + """ + Configuration for DiscreteMamba2. + """ + + _abstract = False + # Layers + # [Mamba2, DiscreteMamba2] + b_layer: AffineLinearConfig = Field( + desc="Configuration for the b layer.", + hint=FieldHint.architecture, ) - # dt_bias_initialization_floor [MambaLayer, Mamba2] - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), + # [Mamba2, DiscreteMamba2] + c_layer: AffineLinearConfig = Field( + desc="Configuration for the c layer.", + hint=FieldHint.architecture, ) - def _validate(self) -> None: - with self._set_implicit_default(): - if self.activation_type is None: - self.activation_type = ActivationType.silu - super()._validate() - Assert.geq(self.dt_max, self.dt_min) + # Model dimensions + # head_groups [DiscreteMamba2] + n_qk_heads: int = Field( + default=32, + desc="Number of QK heads.", + hint=FieldHint.architecture, + ) + # heads [DiscreteMamba2] + n_v_heads: int = Field( + default=32, + desc="Number of V heads.", + hint=FieldHint.architecture, + ) + # chunk_size [DiscreteMamba2] + chunk_size: int = Field( + default=256, + desc="Chunk size for Mamba2 blocks.", + hint=FieldHint.architecture, + ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index f9462a942..46f26ed4d 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -4,15 +4,13 @@ import einops import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_, init_zeros_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.ssm.mamba import init_kaiming_ +from fast_llm.layers.ssm.config import DiscreteMamba2Config from fast_llm.tensor import ParameterMeta from fast_llm.utils import combine_lr_scales, div @@ -27,15 +25,7 @@ _mamba_available = False -try: - from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa - - _causal_conv1d_available = True -except (ImportError, RuntimeError): - _causal_conv1d_available = False - - -class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): +class DiscreteMamba2[ConfigType: DiscreteMamba2Config](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ @@ -65,7 +55,6 @@ def __init__( heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) - convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) inner_projection_dim = ConcatenatedTensorDim( "inner_projection", @@ -86,49 +75,42 @@ def __init__( # TODO: double check initializations # Projections - self.in_proj = OutputParallelLinear( + + # TODO: Use x_layer, b_layer, c_layer, a_log_layer + self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(block_config.hidden_size), + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - if not config.add_bias_linear: + if self.in_proj.bias is None: + # TODO: Integrate to z_layer config? self.z_bias = ParameterMeta.from_dims( (inner_dim,), weight_decay=False, init_method=init_zeros_, lr_scale=lr_scale, ) - self.conv1d_weight = ParameterMeta.from_dims( - ( - convolution_dim, - scalar_dim, - convolution_kernel_dim, - ), - init_method=init_uniform_centered_( - (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 - ), - lr_scale=lr_scale, - ) - self.conv1d_bias = ParameterMeta.from_dims( - (convolution_dim,), - init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + + self.convolution = self._config.convolution_layer.get_layer( + convolution_dim, + default_activation=ActivationType.silu, lr_scale=lr_scale, ) # D "skip" parameter - self.D = ParameterMeta.from_dims( + self.D = self._config.d_weight.get_parameter( (heads_dim,), - weight_decay=False, - init_method=init_ones_, + default_initializer=init_ones_, lr_scale=lr_scale, + weight_decay=False, ) - self.out_proj = InputParallelLinear( + self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(self._config.d_inner), + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -167,7 +149,7 @@ def forward( ) # Convolutional layer # xbc: (batch, padded_sequence, local_heads * head_size + 2 * local_head_groups * state) - xBC = self.convolutional_forward(xBC, padded_length) + xBC = self.convolution(xBC.transpose(1, 2)).transpose(1, 2) x, B, C = torch.split( xBC, @@ -210,37 +192,8 @@ def forward( y = y.transpose(0, 1).contiguous() # out_proj: (batch/sequence, sequence/batch, local_heads * head_size) # -> (batch/local_sequence, local_sequence/batch, hidden) - a, b = self.out_proj(y) return self.out_proj(y) @torch.compile def _apply_a_log(self, x: torch.Tensor, A_log: torch.Tensor) -> torch.Tensor: return x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1) - - def convolutional_forward(self, xBC, padded_len): - """Convolutional layer forward pass for the full sequence.""" - if _causal_conv1d_available and self._config.activation_type in ( - ActivationType.silu, - ActivationType.identity, - ): - xBC = _causal_conv1d_fn( - xBC.transpose(1, 2), - self.conv1d_weight.squeeze(1), - self.conv1d_bias, - activation=( - None - if self._config.activation_type == ActivationType.identity - else self._config.activation_type.value - ), - ).transpose(1, 2) - else: - xBC = self._config.activation_type.activation_fn( - torch.nn.functional.conv1d( - xBC.transpose(1, 2), - self.conv1d_weight, - bias=self.conv1d_bias, - groups=self.conv1d_weight.shape[0], - padding=self._config.conv_kernel_dimension - 1, - )[..., :padded_len].transpose(1, 2) - ) - return xBC diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 453c14af6..e7bd7674b 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -5,13 +5,12 @@ import torch from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.config import MambaConfig from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, combine_lr_scales, div @@ -33,8 +32,7 @@ def init_A(d_state, d_inner) -> LambdaInitializer: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa - if tensor.numel() != d_state * d_inner: - raise ValueError("_init_A requires not supported for tensor slices.") + Assert.eq(tensor.numel(), d_state * d_inner) torch.log( torch.arange(1, d_state + 1, dtype=torch.float32, device=tensor.device) .unsqueeze(0) @@ -54,7 +52,7 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class Mamba[ConfigType: SSMConfig](BlockLayer[ConfigType]): +class Mamba[ConfigType: MambaConfig](BlockLayer[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( @@ -69,77 +67,69 @@ def __init__( ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for Mamba" - # TODO: It's not silu? - Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) state_dim = TensorDim("state", self._config.state_size) inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) - convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - # TODO: Backward compatibility? - self.in_proj = Linear( + # TODO: Use x_layer + self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - bias=False, - weight_init_method=init_kaiming_(hidden_dim.size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=self._block_config.add_linear_biases, lr_scale=lr_scale, ) - self.conv1d_weight = ParameterMeta.from_dims( - ( - inner_dim, - scalar_dim, - convolution_kernel_dim, - ), - init_method=init_kaiming_(inner_dim.size), + self.convolution = self._config.convolution_layer.get_layer( + inner_dim, + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=False, + default_activation=ActivationType.silu, lr_scale=lr_scale, ) - self.x_proj = Linear( + self.x_proj = self._config.x_projection_layer.get_layer( inner_dim, x_projection_dim, - weight_init_method=init_kaiming_(inner_dim.size), - bias=False, + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), lr_scale=lr_scale, ) - self.x_proj.weight.auto_grad_accumulation = True + # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 - self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, dt_rank_dim), - init_method=init_kaiming_(self._config.dt_rank), - lr_scale=lr_scale, - ) - self.dt_proj_bias = ParameterMeta.from_dims( - (inner_dim,), - init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + self.dt_proj = self._config.dt_layer.get_layer( + dt_rank_dim, + inner_dim, + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_bias_initializer=init_dtprojbias( + self._config.dt_max, self._config.dt_min, self._config.dt_init_floor + ), lr_scale=lr_scale, ) - self.A_log = ParameterMeta.from_dims( + self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - weight_decay=False, - init_method=init_A(self._config.state_size, inner_dim.size), + default_initializer=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, + weight_decay=False, ) # D "skip" parameter - self.D = ParameterMeta.from_dims( + self.D = self._config.d_weight.get_parameter( (inner_dim,), - weight_decay=False, - init_method=init_ones_, + default_initializer=init_ones_, lr_scale=lr_scale, + weight_decay=False, ) - self.out_proj = Linear( + self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. - weight_init_method=init_kaiming_(hidden_dim.size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=False, lr_scale=lr_scale, ) - self.out_proj.weight.auto_grad_accumulation = True def forward( self, @@ -152,26 +142,22 @@ def forward( in_proj = self.in_proj(input_).permute((1, 2, 0) if kwargs[BlockKwargs.sequence_first] else (0, 2, 1)) # In the backward pass we write dx and dz next to each other to avoid torch.cat - # not, if we wanbt to support inference, we would need to imp.lement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s + # If we wanbt to support inference, we would need to implement slow path here, see https://github.com/Zyphra/Zamba2/blob/1b182f40f2257f822cc06dd785df53d67d691a15/mamba_layer.py#L172s out = _mamba_inner_fn( in_proj, - self.conv1d_weight, - None, + self.convolution.weight, + self.convolution.bias, self.x_proj.weight, - self.dt_proj_weight, + self.dt_proj.weight, self.out_proj.weight, self.out_proj.bias, # is None here -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), - delta_bias=self.dt_proj_bias.float(), + delta_bias=None if self.dt_proj.bias is None else self.dt_proj.bias.float(), delta_softplus=True, ) if kwargs[BlockKwargs.sequence_first]: out = out.transpose(0, 1) return out, None - - -def init_kaiming_(d_in: float) -> LambdaInitializer: - return init_normal_(0.0, math.sqrt(2.0 / d_in)) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 2659e415f..90fdb343a 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -3,17 +3,15 @@ import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_ -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.config_utils.initialization import init_normal_, init_ones_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ -from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, combine_lr_scales, div +from fast_llm.layers.ssm.config import Mamba2Config +from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias +from fast_llm.utils import combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -22,17 +20,10 @@ except (ImportError, RuntimeError): _mamba_available = False -try: - from causal_conv1d import causal_conv1d_fn as _causal_conv1d_fn # noqa - - _causal_conv1d_available = True -except (ImportError, RuntimeError): - _causal_conv1d_available = False - logger = logging.getLogger(__name__) -class Mamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): +class Mamba2[ConfigType: Mamba2Config](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ @@ -50,7 +41,6 @@ def __init__( lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) - Assert.eq(self._config.activation_type, ActivationType.silu) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) @@ -66,7 +56,6 @@ def __init__( inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim)) - convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) # DT projection dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) @@ -81,73 +70,61 @@ def __init__( self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size - conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim + convolution_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) - self.conv1d_weight = ParameterMeta.from_dims( - ( - conv1d_dim, - scalar_dim, - convolution_kernel_dim, - ), - init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), - lr_scale=lr_scale, - ) - self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), - init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + self.convolution = self._config.convolution_layer.get_layer( + convolution_dim, + default_activation=ActivationType.silu, lr_scale=lr_scale, ) - self.in_proj = OutputParallelLinear( + # TODO: Use x_layer, b_layer, c_layer + self.in_proj = self._config.z_layer.get_layer( hidden_dim, inner_projection_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(block_config.hidden_size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.dt_in_proj = Linear( + self.dt_in_proj = self._config.dt_input_layer.get_layer( hidden_dim, dt_rank_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(block_config.hidden_size), + default_weight_initializer=init_normal_(0, (2 / hidden_dim.global_size) ** 0.5), + default_add_bias=self._block_config.add_linear_biases, lr_scale=lr_scale, ) - self.dt_proj = OutputParallelLinear( + self.dt_proj = self._config.dt_layer.get_layer( dt_rank_dim, inner_dim, - bias=False, - # Initialize special dt projection to preserve variance at initialization - weight_init_method=self._config.dt_init.get_init_method( + default_weight_initializer=self._config.dt_init.get_init_method( self._config.dt_rank**-0.5 * self._config.dt_scale ), + default_bias_initializer=init_dtprojbias( + self._config.dt_max, self._config.dt_min, self._config.dt_init_floor + ), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - # define bias outside the linear layer since it's also used in the selective_scan_fn - self.dt_proj_bias = ParameterMeta.from_dims( - (inner_dim,), - init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), - lr_scale=lr_scale, - ) - self.A_log = ParameterMeta.from_dims( + self.A_log = self._config.a_log_weight.get_parameter( (inner_dim, state_dim), - init_method=init_A(self._config.state_size, self._config.d_inner), + default_initializer=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, ) - self.D = ParameterMeta.from_dims( + # D "skip" parameter + self.D = self._config.d_weight.get_parameter( (inner_dim,), - weight_decay=False, - init_method=init_ones_, + default_initializer=init_ones_, lr_scale=lr_scale, + weight_decay=False, ) - self.out_proj = InputParallelLinear( + self.out_proj = self._config.output_layer.get_layer( inner_dim, hidden_dim, - bias=config.add_bias_linear, - weight_init_method=init_kaiming_(self._config.d_inner), + default_weight_initializer=init_normal_(0, (2 / self._config.d_inner) ** 0.5), + default_add_bias=self._block_config.add_linear_biases, sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) @@ -173,12 +150,11 @@ def forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - assert _causal_conv1d_available # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) # -> (batch/sequence, sequence/batch, local_inner_projection) inner_projection = self.in_proj(input_) - dt = self.dt_proj(self.dt_in_proj(input_)) + self.dt_proj_bias + dt = self.dt_proj(self.dt_in_proj(input_)) # Standardize to (batch, sequence, local_inner_projection) if kwargs[BlockKwargs.sequence_first]: inner_projection = inner_projection.transpose(0, 1) @@ -198,16 +174,15 @@ def forward( # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) if self._config.repeat_kv_before_conv: - x = ( + x = self.convolution( x.unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") else: - x = _causal_conv1d_fn(x=x, weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu") x = ( - x.unflatten(1, (self._local_head_groups, self._config.state_size)) + self.convolution(x) + .unflatten(1, (self._local_head_groups, self._config.state_size)) .repeat_interleave(self._group_heads, 1, output_size=self._local_heads) .flatten(1, 2) ) @@ -240,7 +215,7 @@ def forward( c, self.D.float(), z, - delta_bias=self.dt_proj_bias.float(), + delta_bias=self.dt_proj.bias.float(), delta_softplus=True, ) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 36975dea1..789201acc 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -199,19 +199,22 @@ def _create_transformer_layer_converters( ( f"{fast_llm_layer_name}.self_attn.query", f"{hf_layer_name}.self_attn.q_proj", - transformer_config.add_qkv_bias, + # TODO: Fix + transformer_config.add_linear_biases, QueryWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.key_value", (f"{hf_layer_name}.self_attn.k_proj", f"{hf_layer_name}.self_attn.v_proj"), - transformer_config.add_qkv_bias, + # TODO: Fix + transformer_config.add_linear_biases, KeyValueWeightConverter, ), ( f"{fast_llm_layer_name}.self_attn.dense", f"{hf_layer_name}.self_attn.o_proj", - transformer_config.add_dense_bias, + # TODO: Fix + transformer_config.add_linear_biases, WeightConverter, ), # Norm @@ -241,13 +244,15 @@ def _create_transformer_layer_converters( converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_1", (), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_2", (), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, cls=IgnoreExportWeightConverter, ) converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] @@ -344,12 +349,17 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig transformer_config: TransformerConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias + f"{fast_llm_prefix}.mlp.layer_1", + f"{hf_prefix}.mlp.c_fc", + # TODO: Fix + transformer_config.add_linear_biases, + Ω, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] @@ -463,13 +473,15 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] @@ -531,13 +543,15 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] @@ -641,13 +655,15 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + # TODO: Fix + transformer_config.add_linear_biases, MLPLayer2Converter, ), ] diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index ef4325552..e47296b71 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -8,7 +8,7 @@ from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig +from fast_llm.layers.ssm.config import MixerConfig, SSMBlockType from fast_llm.models.gpt.config import ( GPTBaseModelConfig, GPTBatchConfig, @@ -29,7 +29,7 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False - ssm: SSMConfig = Field( + ssm: MixerConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) @@ -47,14 +47,13 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): ssm_block_type: SSMBlockType | None = Field(init=False) def _validate(self): - with self._set_implicit_default(None): - if self.ssm.dt_rank == "auto" or self.ssm.dt_rank is None: - self.ssm.dt_rank = math.ceil(self.transformer.hidden_size / 16) with self._set_implicit_default(): - if self.ssm.d_xb is None: + if getattr(self.ssm, "dt_rank", ...) is None: + self.ssm.dt_rank = math.ceil(self.transformer.hidden_size / 16) + if getattr(self.ssm, "d_xb", ...) is None: self.ssm.d_xb = self.transformer.hidden_size - if self.ssm.d_inner is None: - self.ssm.d_inner = int(self.ssm.expansion_factor * self.transformer.hidden_size) + if getattr(self.ssm, "d_inner", ...) is None: + self.ssm.d_inner = int(2 * self.transformer.hidden_size) if self.hybrid_block_layout is None: with self._set_implicit_default(): diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index e9b18b848..5e05364a4 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -224,7 +224,7 @@ def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] num_layers = self._model.config.base_model.transformer.num_layers - ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases for i in range(num_layers): # SSM @@ -389,7 +389,7 @@ def _create_weight_converters(self) -> list[WeightConverter]: converters = [] num_layers = self._model.config.base_model.transformer.num_layers norm_bias: bool = False - ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear + ssm_bias: bool = self._model.config.base_model.transformer.add_linear_biases # Embedding and output if self._model.config.base_model.tie_word_embeddings: diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b6180c190..f56834c8a 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -244,7 +244,6 @@ def __init__( lr_scale: float | None | tuple[float | None, ...] = None, requires_grad: bool = True, allow_sequence_tensor_parallel: bool = True, - auto_grad_accumulation: bool = True, allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) @@ -259,9 +258,6 @@ def __init__( # Almost all parameters are either tensor-parallel or process tensor-sequence-parallel inputs. # Except for position embedding weights self.sequence_tensor_parallel = allow_sequence_tensor_parallel and not self.is_tensor_parallel - # If true, grad accumulation is handled automatically by copying or adding to the grad_buffer. - # Can be disabled to allow for a more efficient implementation that accumulates directly to it. - self.auto_grad_accumulation = auto_grad_accumulation # Disable the check that gradients have been computed for this parameter before the gradient reduction, # to support cases where gradients may not always be computed (ex. MOE layers). self.allow_no_grad = allow_no_grad @@ -281,7 +277,6 @@ def __new__( weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] = None, allow_sequence_tensor_parallel: bool = True, - auto_grad_accumulation: bool = True, allow_no_grad: bool = False, ): return super().__new__( diff --git a/tests/models/distributed_test_model.py b/tests/models/distributed_test_model.py index 564920bd5..890a75077 100644 --- a/tests/models/distributed_test_model.py +++ b/tests/models/distributed_test_model.py @@ -25,6 +25,7 @@ def main(args: list[str] | None = None) -> None: world_size = DistributedConfig.default_world_size rank = DistributedConfig.default_rank group = pool.get_process_group(range(world_size), rank) + safe_barrier(group, "start") for name, config in DISTRIBUTED_TESTING_CONFIGS.items(): if model_testing_config.should_skip(config): diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 031ec6f97..6f4631320 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -80,6 +80,7 @@ def test_resume(run_test_script_for_all_models, compare_results_for_all_models, @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume_frozen(run_test_script_for_all_models, prepare_resume): + # TODO: No more frozen weights? distributed_testing_config = DistributedTestingConfig( name="resume_frozen", compare="checkpoint_and_eval", config_args=_CHECKPOINT_AND_EVAL_ARGS ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e9bdeba97..83f6b50b2 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -475,6 +475,7 @@ def _update_and_add_testing_config( "llamba", model_type="hybrid_ssm", extra_args=[ + "model.base_model.ssm.type=mamba", "model.base_model.hybrid_block_layout=['t','m']", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=16", @@ -503,6 +504,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + "model.base_model.ssm.type=mamba_2", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.d_xb=256", @@ -534,6 +536,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2d']", + "model.base_model.ssm.type=discrete_mamba_2", "model.base_model.ssm.d_inner=512", "model.base_model.ssm.state_size=8", "model.base_model.ssm.n_qk_heads=8",