Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
82eed2b
TP mamba
jlamypoirier Jul 21, 2025
4e310c7
TP mamba
jlamypoirier Jul 22, 2025
3cc4118
fix
jlamypoirier Jul 22, 2025
9f7f75c
fix
jlamypoirier Jul 22, 2025
4054e04
fixes
jlamypoirier Jul 23, 2025
0014cc6
fix
jlamypoirier Jul 23, 2025
47ad548
fixes
jlamypoirier Jul 23, 2025
6a074fa
fixes
jlamypoirier Jul 23, 2025
d66651f
Update external
jlamypoirier Jul 23, 2025
50083ba
SSM debugging
jlamypoirier Jul 24, 2025
5006328
Merge branch 'main' into tp_mamba
jlamypoirier Jul 24, 2025
13176bd
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
7b32699
stuff
jlamypoirier Jul 24, 2025
73f591f
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
1feccc8
stuff
jlamypoirier Jul 24, 2025
e528b50
misc
jlamypoirier Jul 24, 2025
b49c42f
misc
jlamypoirier Jul 24, 2025
bb4dcd9
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
c1b7f44
misc
jlamypoirier Jul 24, 2025
31f5d41
misc
jlamypoirier Jul 24, 2025
051bb07
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 24, 2025
0a9ff25
misc
jlamypoirier Jul 24, 2025
e7d9636
Parallel discrete mamba 2
jlamypoirier Jul 24, 2025
c14b764
Mamba 2, misc
jlamypoirier Jul 25, 2025
b605bd2
doc
jlamypoirier Jul 25, 2025
5eea938
fix
jlamypoirier Jul 28, 2025
0a3e2a7
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
2e6d082
fixes
jlamypoirier Jul 28, 2025
b6c8613
misc
jlamypoirier Jul 28, 2025
f0c04cf
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Jul 28, 2025
acdfab1
Merge branch 'debug_mamba' into tp_mamba
jlamypoirier Jul 28, 2025
e536af9
Concatenated dim
jlamypoirier Jul 28, 2025
017f5cc
fixes
jlamypoirier Jul 28, 2025
93e4c94
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 28, 2025
c41efc2
doc
jlamypoirier Jul 28, 2025
0b8bd5d
cleanup
jlamypoirier Jul 28, 2025
02f8af5
Block interface
jlamypoirier Jul 29, 2025
6bf06d6
fix
jlamypoirier Jul 29, 2025
2ddc3a7
fix
jlamypoirier Jul 29, 2025
c0f1597
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Jul 29, 2025
b2f4476
Merge branch 'tp_mamba' into block_interface
jlamypoirier Jul 29, 2025
ce70b16
fixes
jlamypoirier Jul 29, 2025
a9f733d
fix
jlamypoirier Jul 29, 2025
cef7c15
fix
jlamypoirier Jul 30, 2025
a5eb076
stuff
jlamypoirier Jul 31, 2025
ab484ac
Revert "stuff"
jlamypoirier Jul 31, 2025
b68d360
stuff
jlamypoirier Jul 31, 2025
82c9dbd
misc
jlamypoirier Jul 31, 2025
9fbb9ff
misc
jlamypoirier Jul 31, 2025
44df195
misc
jlamypoirier Jul 31, 2025
3bb03cb
misc
jlamypoirier Jul 31, 2025
98bae95
misc
jlamypoirier Jul 31, 2025
fd731ef
fixes
jlamypoirier Aug 1, 2025
f483321
fixes
jlamypoirier Aug 1, 2025
5a0eabc
Merge remote-tracking branch 'origin/main' into debug_mamba
jlamypoirier Aug 8, 2025
dd288df
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 8, 2025
defd6e0
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 8, 2025
8abf258
fixes
jlamypoirier Aug 8, 2025
c16c00f
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 8, 2025
07c9211
stuff
jlamypoirier Aug 8, 2025
be99372
Merge branch 'main' into debug_mamba
jlamypoirier Aug 12, 2025
a505f3a
Merge branch 'debug_mamba' into concatenated_dim
jlamypoirier Aug 12, 2025
0cc859a
Merge remote-tracking branch 'origin/main' into concatenated_dim
jlamypoirier Aug 12, 2025
bd4ff0d
doc
jlamypoirier Aug 12, 2025
fd3307d
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Aug 12, 2025
0e2e124
stuff
jlamypoirier Aug 12, 2025
0a5e458
Remove tensor space, fixes
jlamypoirier Aug 14, 2025
797bd73
stuff
jlamypoirier Aug 14, 2025
c0a3782
stuff
jlamypoirier Aug 15, 2025
e60ded4
stuff
jlamypoirier Aug 15, 2025
1483bcc
stuff
jlamypoirier Aug 15, 2025
4deb501
misc
jlamypoirier Aug 15, 2025
fc809e0
Misc, tests pass
jlamypoirier Aug 15, 2025
cdb6710
misc
jlamypoirier Aug 20, 2025
9ce72e0
Move files
jlamypoirier Aug 20, 2025
065b34f
misc
jlamypoirier Aug 20, 2025
4510b7b
misc
jlamypoirier Aug 20, 2025
9a2a7a2
Pr comments
jlamypoirier Aug 21, 2025
8c382a9
Cleanup
jlamypoirier Aug 21, 2025
019e43d
Cleanup
jlamypoirier Aug 21, 2025
3e0f3e5
Cleanup
jlamypoirier Aug 21, 2025
90a3c98
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
39960ce
Cleanup
jlamypoirier Aug 21, 2025
1abdd19
fixes
jlamypoirier Aug 21, 2025
7c24292
fixes
jlamypoirier Aug 21, 2025
af2964b
fixes
jlamypoirier Aug 21, 2025
0e62f7d
Merge branch 'tp_mamba' into block_interface
jlamypoirier Aug 21, 2025
654aeeb
Fix merge
jlamypoirier Aug 21, 2025
3f4a8ba
fix
jlamypoirier Aug 27, 2025
9741ba0
stuff
jlamypoirier Aug 27, 2025
be69677
fixes
jlamypoirier Aug 27, 2025
82a70aa
Simplify bias options
jlamypoirier Aug 27, 2025
188587e
Merge branch 'main' into concatenated_dim
jlamypoirier Sep 17, 2025
e111509
Merge branch 'concatenated_dim' into tp_mamba
jlamypoirier Sep 17, 2025
95e0231
Merge branch 'tp_mamba' into block_interface
jlamypoirier Sep 17, 2025
e076c7a
Merge remote-tracking branch 'origin/main' into block_interface
jlamypoirier Sep 18, 2025
2315ac4
Merge branch 'block_interface' into block_interface_weight
jlamypoirier Sep 18, 2025
79356f7
Merge remote-tracking branch 'origin/main' into block_interface_weight
jlamypoirier Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions fast_llm/engine/config_utils/parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import typing

from fast_llm.config import Config, Field, config_class
from fast_llm.engine.config_utils.initialization import Initializer
from fast_llm.engine.config_utils.tensor_dim import TensorDim

if typing.TYPE_CHECKING:
from fast_llm.tensor import ParameterMeta


@config_class()
class ParameterConfig(Config):
# TODO: Initialization, lr_scale

def _validate(self) -> None:
pass

def get_parameter(
self,
dims: tuple[TensorDim, ...],
default_initializer: Initializer,
lr_scale: float | None,
weight_decay: bool = True,
allow_sequence_tensor_parallel: bool = True,
) -> "ParameterMeta":
from fast_llm.tensor import ParameterMeta

return ParameterMeta.from_dims(
dims,
init_method=default_initializer,
lr_scale=lr_scale,
weight_decay=weight_decay,
allow_sequence_tensor_parallel=allow_sequence_tensor_parallel,
)


@config_class()
class OptionalParameterConfig(ParameterConfig):
enabled: bool | None = Field(
default=None,
)
# TODO: Initialization, lr_scale

def _validate(self) -> None:
pass

def get_parameter(
self,
dims: tuple[TensorDim, ...],
default_initializer: Initializer,
lr_scale: float | None,
weight_decay: bool = True,
allow_sequence_tensor_parallel: bool = True,
default_enabled: bool = False,
) -> "ParameterMeta|None":
from fast_llm.tensor import ParameterMeta

if (self.enabled is None and default_enabled) or self.enabled:
return ParameterMeta.from_dims(
dims,
init_method=default_initializer,
lr_scale=lr_scale,
weight_decay=weight_decay,
allow_sequence_tensor_parallel=allow_sequence_tensor_parallel,
)
else:
return None
2 changes: 0 additions & 2 deletions fast_llm/engine/multi_stage/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
def _accumulate_grad_hook(buffer: torch.nn.Parameter, meta: ParameterMeta) -> typing.Callable[[tuple, tuple], None]:
def hook(grad_inputs, grad_outputs): # noqa
if buffer.grad is not None:
if not meta.auto_grad_accumulation:
raise RuntimeError(f"Unexpected grad for parameter {meta.tensor_name}")
accumulate_gradient(buffer, buffer.grad)
buffer.grad = None

Expand Down
25 changes: 11 additions & 14 deletions fast_llm/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

from fast_llm.core.distributed import set_generator
from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_
from fast_llm.engine.config_utils.initialization import init_normal_
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.config import BlockConfig, BlockDimNames
from fast_llm.layers.block.peft import TransformerSubLayerName
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.utils import combine_lr_scales, div

try:
Expand Down Expand Up @@ -112,21 +111,20 @@ def __init__(
)

# TODO: Merge the query and key-value computations? (harder with sequence parallel.)
self.query = OutputParallelLinear(
self.query = self._config.query_layer.get_layer(
hidden_dim,
query_dim,
bias=self._config.add_qkv_bias,
weight_init_method=init_method_qkv,
bias_init_method=init_zeros_,
default_weight_initializer=init_method_qkv,
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
)
self.key_value = OutputParallelLinear(
# TODO: Use value config.
self.key_value = self._config.query_layer.get_layer(
hidden_dim,
key_value_dim,
bias=self._config.add_qkv_bias,
weight_init_method=init_method_qkv,
bias_init_method=init_zeros_,
default_weight_initializer=init_method_qkv,
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
)
Expand All @@ -136,12 +134,11 @@ def __init__(
self._rotary = self._config.rotary.get_layer(kv_channels_dim)

# Output.
self.dense = InputParallelLinear(
self.dense = self._config.dense_layer.get_layer(
dense_dim,
hidden_dim,
bias=self._config.add_dense_bias,
weight_init_method=init_method_std_attn_proj,
bias_init_method=init_zeros_,
default_weight_initializer=init_method_std_attn_proj,
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
)
Expand Down
38 changes: 19 additions & 19 deletions fast_llm/layers/attention/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import TritonConfig
from fast_llm.layers.attention.rotary.config import RotaryConfig
from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockKwargs
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.common.linear.config import AffineLinearConfig
from fast_llm.utils import Assert, div

logger = logging.getLogger(__name__)
Expand All @@ -32,6 +33,23 @@ class AttentionConfig(Config):
# TODO: Make mixer class dynamic.
_abstract = False

query_layer: AffineLinearConfig = Field(
desc="Configuration for the query layer.",
hint=FieldHint.architecture,
)
key_layer: AffineLinearConfig = Field(
desc="Configuration for the key layer.",
hint=FieldHint.architecture,
)
# TODO: Use
value_layer: AffineLinearConfig = Field(
desc="Configuration for the value layer.",
hint=FieldHint.architecture,
)
dense_layer: AffineLinearConfig = Field(
desc="Initialization configuration for the dense layer.",
hint=FieldHint.feature,
)
# TODO: Review names
rotary: RotaryConfig = Field(
desc="Configuration for the rotary positional embeddings.",
Expand Down Expand Up @@ -162,24 +180,6 @@ def projection_size(self):
def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)

@property
def add_qkv_bias(self) -> bool:
# TODO: Make this work without inheritance.
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.nowhere:
return False
return True

@property
def add_dense_bias(self) -> bool:
# TODO: Make this work without inheritance.
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.everywhere:
return True
return False


@config_class()
# TODO: Use composition instead
Expand Down
13 changes: 2 additions & 11 deletions fast_llm/layers/block/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import enum

from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.layers.block.mlp.config import MLPConfig
Expand Down Expand Up @@ -34,12 +32,6 @@ class BlockKwargs:
grad_output = "grad_output"


class AddLinearBiasChoices(str, enum.Enum):
nowhere = "nowhere"
everywhere = "everywhere"
only_attn_qkv = "only_attn_qkv"


@config_class()
# TODO: Use composition instead
class BlockConfig(MLPConfig, BaseModelConfig):
Expand Down Expand Up @@ -76,12 +68,11 @@ class BlockConfig(MLPConfig, BaseModelConfig):
desc="Log the memory usage after each operation in a transformer layer..",
hint=FieldHint.logging,
)
add_linear_biases: bool | AddLinearBiasChoices = Field(
add_linear_biases: bool = Field(
default=True,
desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.",
desc="Add biases to linear layers. May be overridden for individual layers.",
hint=FieldHint.architecture,
)

# TODO: Move these, not specific to a single block.
num_layers: int = Field(
default=12,
Expand Down
30 changes: 18 additions & 12 deletions fast_llm/layers/block/mlp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.functional.config import ActivationType, MLPRecomputeLevel
from fast_llm.layers.common.linear.config import AffineLinearConfig, LinearConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
Expand All @@ -21,8 +22,24 @@ class RoutingType(str, enum.Enum):

@config_class()
class MLPConfig(Config):
# TODO: Review names # TODO: Separate MoE?
# TODO: Review names
# TODO: Separate MoE?
_abstract = False
# TODO: Configure experts, gate/up separately?
layer_1: AffineLinearConfig = Field(
desc="Configuration for the first MLP layer.",
hint=FieldHint.architecture,
)
# TODO: Separate gate and up
layer_2: AffineLinearConfig = Field(
desc="Configuration for the second MLP layer.",
hint=FieldHint.architecture,
)
router: LinearConfig = Field(
# TODO: Improve default?
desc="Configuration for the MoE router.",
hint=FieldHint.feature,
)
ffn_hidden_size: int = Field(
default=None,
desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.",
Expand Down Expand Up @@ -143,17 +160,6 @@ class MLPConfig(Config):
hint=FieldHint.optional,
)

@property
def add_mlp_bias(self) -> bool:
from fast_llm.layers.block.config import AddLinearBiasChoices

# TODO: Make this work without inheritance.
if isinstance(self.add_linear_biases, bool):
return self.add_linear_biases
if self.add_linear_biases == AddLinearBiasChoices.everywhere:
return True
return False

def _validate(self) -> None:
with self._set_implicit_default():
if self.activation_type is None:
Expand Down
7 changes: 2 additions & 5 deletions fast_llm/layers/block/mlp/mixture_of_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType
from fast_llm.layers.block.mlp.mlp import MLPBase
from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss
from fast_llm.layers.common.linear import Linear
from fast_llm.utils import Assert, combine_lr_scales

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -47,12 +46,10 @@ def __init__(
# TODO: Implement?
assert not config.add_linear_biases, "Biases not supported for MoE."
super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale)

self.router = Linear(
self.router = self._config.router.get_layer(
self._hidden_dim,
TensorDim("router_experts", self._config.num_unshared_experts),
bias=False,
weight_init_method=init_normal_(
default_weight_initializer=init_normal_(
std=self._block_config.init_method_std,
min_val=self._block_config.init_method_min,
max_val=self._block_config.init_method_max,
Expand Down
20 changes: 9 additions & 11 deletions fast_llm/layers/block/mlp/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_
from fast_llm.engine.config_utils.initialization import init_normal_
from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.config import TritonConfig
Expand All @@ -11,7 +11,6 @@
from fast_llm.layers.block.config import BlockConfig
from fast_llm.layers.block.mlp.config import MLPConfig
from fast_llm.layers.block.peft import TransformerSubLayerName
from fast_llm.layers.common.linear import LinearBase
from fast_llm.utils import Assert, combine_lr_scales


Expand Down Expand Up @@ -46,21 +45,20 @@ def __init__(
lr_scale = combine_lr_scales(self._lr_scale, self._config.mlp_lr_scale)

# So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size)
self.layer_1 = LinearBase(
self.layer_1 = self._config.layer_1.get_layer(
hidden_dim,
intermediate_1_dim,
bias=self._config.add_mlp_bias,
weight_init_method=init_method_1,
bias_init_method=init_zeros_,
default_weight_initializer=init_method_1,
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
lr_scale=lr_scale,
)
self.layer_2 = LinearBase(
self.layer_2 = self._config.layer_1.get_layer(
intermediate_2_dim,
hidden_dim,
bias=self._config.add_mlp_bias,
weight_init_method=init_method_2,
bias_init_method=init_zeros_,
auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1,
default_weight_initializer=init_method_2,
default_add_bias=self._block_config.add_linear_biases,
sequence_parallel=self._sequence_parallel,
transposed_weight=True,
lr_scale=lr_scale,
)
Expand Down
2 changes: 1 addition & 1 deletion fast_llm/layers/block/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fast_llm.utils import div

if typing.TYPE_CHECKING:
from fast_llm.layers.common.linear import LinearBase, LinearLike
from fast_llm.layers.common.linear.linear import LinearBase, LinearLike


class TransformerSubLayerName(str, enum.Enum):
Expand Down
Empty file.
Loading