From 59e5ece354448d5196cbea61dd86fd60868e841a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Apr 2025 16:02:45 -0400 Subject: [PATCH 1/3] Fix and test LM head --- fast_llm/engine/multi_stage/fsdp.py | 27 ++- fast_llm/engine/multi_stage/stage.py | 16 +- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/functional/triton/normalization.py | 5 +- fast_llm/layers/common/auxiliary_loss.py | 2 +- fast_llm/layers/common/normalization.py | 12 +- fast_llm/layers/language_model/head.py | 21 +- fast_llm/layers/transformer/attention.py | 2 +- fast_llm/layers/transformer/config.py | 4 +- fast_llm/models/gpt/model.py | 42 ++-- fast_llm/utils.py | 14 +- tests/layers/__init__.py | 0 tests/layers/test_lm_head.py | 221 ++++++++++++++++++++ 13 files changed, 314 insertions(+), 57 deletions(-) create mode 100644 tests/layers/__init__.py create mode 100644 tests/layers/test_lm_head.py diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index d45566fc8..e9c84aa30 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -167,7 +167,8 @@ def setup( grad_shard: torch.Tensor | None, weight_buffer: torch.Tensor | None, grad_buffer: torch.Tensor | None, - sequence_tensor_parallel: bool = False, + sequence_tensor_parallel: bool, + device: torch.device | None, ) -> None: assert not self._is_setup self._is_setup = True @@ -176,11 +177,19 @@ def setup( # Validate and set the shards and buffers if self._mode.on_device: - self._weight_shard = self._weight_shard_meta.validate(weight_shard) + self._weight_shard = ( + torch.empty_like(self._weight_shard_meta, device=device) + if weight_shard is None + else self._weight_shard_meta.validate(weight_shard) + ) else: Assert.none(weight_shard) if self._mode.support_forward: - self._weight_buffer = self._weight_buffer_meta.validate(weight_buffer) + self._weight_buffer = ( + torch.empty_like(self._weight_buffer_meta, device=device) + if weight_buffer is None + else self._weight_buffer_meta.validate(weight_buffer) + ) # Pre-compute the local shard for restore ops. self._weight_buffer_local_shard = self._weight_buffer[ self._fsdp_dim.rank * self._shard_size : (self._fsdp_dim.rank + 1) * self._shard_size @@ -189,8 +198,16 @@ def setup( Assert.none(weight_buffer) if self._mode.support_backward: - self._grad_shard = self._grad_shard_meta.validate(grad_shard) - self._grad_buffer = self._grad_buffer_meta.validate(grad_buffer) + self._grad_shard = ( + torch.empty_like(self._grad_shard_meta, device=device) + if grad_shard is None + else self._grad_shard_meta.validate(grad_shard) + ) + self._grad_buffer = ( + torch.empty_like(self._grad_buffer_meta, device=device) + if grad_buffer is None + else self._grad_buffer_meta.validate(grad_buffer) + ) # Pre-compute the local shard for reduce ops. self._grad_buffer_local_shard = self._grad_buffer[ self._fsdp_dim.rank * self._shard_size : (self._fsdp_dim.rank + 1) * self._shard_size diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 7ccd740ee..568820819 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -38,13 +38,13 @@ def setup( # noqa self, *, distributed: Distributed, - weight_shards: list[torch.Tensor | None] | None, - grad_shards: list[torch.Tensor | None] | None, - weight_buffers: list[torch.Tensor | None] | None, - grad_buffers: list[torch.Tensor | None] | None, + weight_shards: list[torch.Tensor | None] | None = None, + grad_shards: list[torch.Tensor | None] | None = None, + weight_buffers: list[torch.Tensor | None] | None = None, + grad_buffers: list[torch.Tensor | None] | None = None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, - weight_buffer_shared_with: list["Stage"], + weight_buffer_shared_with: list["Stage"] = (), ) -> None: super().setup( distributed=distributed, @@ -92,7 +92,11 @@ def forward_meta(self, input_: TensorMeta, kwargs: dict) -> TensorMeta: return input_ def forward( - self, input_: torch.Tensor, kwargs: dict, losses: dict[str, list[torch.Tensor]], metrics: dict | None = None + self, + input_: torch.Tensor, + kwargs: dict, + losses: dict[str, list[torch.Tensor]] | None = None, + metrics: dict | None = None, ) -> tuple[torch.Tensor | None, tuple[torch.Tensor | None, torch.Tensor | None]]: assert self._is_restored assert self._mode.support_forward diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index da7eb7d88..fd50f55c5 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -6,7 +6,7 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import check_parallel_match -from fast_llm.engine.base_model.base_model import BaseModel +from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed @@ -29,7 +29,7 @@ def __init__( self, *, config: StageConfig, - base_model: BaseModel, + base_model: BaseModel | list[Layer], distributed_config: DistributedConfig, begin: int, end: int, @@ -153,6 +153,7 @@ def setup( weight_buffer=weight_buffer, grad_buffer=grad_buffer, sequence_tensor_parallel=self._distributed_config.sequence_tensor_parallel, + device=self._distributed.device, ) if self._mode.support_forward: diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 9c937434a..96d1663f7 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -176,6 +176,9 @@ def triton_normalization_forward( training: bool, zero_centered: bool, ) -> tuple[torch.Tensor, list[typing.Any]] | None: + # Note: Converting input automatically to training dtype to match Apex behaviour, + # needed for full precision residual. + # TODO: Review this? assert weight.shape == input_.shape[-1:] if bias is not None: assert weight.shape == bias.shape @@ -183,7 +186,7 @@ def triton_normalization_forward( n_rows = input_.shape[:-1].numel() n_cols = weight.numel() - output = torch.empty_like(input_) + output = torch.empty_like(input_, dtype=weight.dtype) inv_var = torch.empty(n_rows, dtype=torch.float32, device="cuda") block_size = triton.next_power_of_2(n_cols) diff --git a/fast_llm/layers/common/auxiliary_loss.py b/fast_llm/layers/common/auxiliary_loss.py index b5b0c15e1..44c2d2088 100644 --- a/fast_llm/layers/common/auxiliary_loss.py +++ b/fast_llm/layers/common/auxiliary_loss.py @@ -16,7 +16,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None, ...]: def calculate_z_loss(logits: torch.Tensor, logits_scale_factor: float = 1.0) -> torch.Tensor: if logits_scale_factor != 1.0: logits *= logits_scale_factor - return torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) + return torch.mean(torch.logsumexp(logits, dim=-1) ** 2) def z_loss( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 04123014e..848abb974 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -3,7 +3,7 @@ from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton.normalization import rms_norm, triton_normalization_autograd +from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import NormalizationImplementation from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ from fast_llm.utils import Assert @@ -141,6 +141,9 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, class LayerNorm(torch.nn.Module): """ A layer normalization layer, supporting multiple implementations. + Note: Converting input automatically to training dtype to match Apex behaviour, + needed for full precision residual. + TODO: Review this? """ def __init__( @@ -214,12 +217,15 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: return FusedLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.nn.functional.layer_norm(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return torch.layer_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self.bias, self._eps) class RMSNorm(torch.nn.Module): """ A RMS normalization layer. + Note: Converting input automatically to training dtype to match Apex behaviour, + needed for full precision residual. + TODO: Review this? """ def __init__( @@ -276,4 +282,4 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: return FusedRMSNorm.apply(input_, self.normalized_shape, self.weight, self._eps) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return rms_norm(input_, self.weight, self._eps) + return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c2974415d..c84b5f534 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -67,7 +67,7 @@ def __init__( # >0: multi-token prediction (MTP) Assert.geq(prediction_distance, 0) self._prediction_distance = prediction_distance - self.is_last_head = self._prediction_distance == config.prediction_heads - 1 + self._is_last_head = self._prediction_distance == config.prediction_heads - 1 self._init_output_weights(hidden_dim, config) @@ -114,7 +114,7 @@ def forward( tensor_name="Loss", reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa ) - if not self.is_last_head: + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case) @@ -123,10 +123,10 @@ def forward( # TODO: Drop autograd entirely. # TODO: Skip cross-entropy backward if not needed. language_model_loss = self._forward(input_, kwargs, losses) - if language_model_loss is not None: + if losses is not None and language_model_loss is not None: losses[self._loss_name].append(language_model_loss) # TODO: Return the model output when needed. - if self.is_last_head: + if self._is_last_head: # Last head should return the loss for backward. return language_model_loss else: @@ -147,14 +147,13 @@ def _forward_backward( if target is not None: if self._config.distillation_model is None: # MTP: Shift the labels - target = ( - target[self._prediction_distance : self._prediction_distance + input_.size(0),] - if kwargs[TransformerKwargs.sequence_first] - else target[ - :, - self._prediction_distance : self._prediction_distance + input_.size(1), - ] + target_slice = slice( + self._prediction_distance, + self._prediction_distance + + input_.size(1 - kwargs[TransformerKwargs.sequence_first]) + * (self._tensor_space.distributed_config.tensor_parallel if self._parallel_embeddings else 1), ) + target = target[target_slice] if kwargs[TransformerKwargs.sequence_first] else target[:, target_slice] target = target.flatten() else: # Target is reference model logits. diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7ae55c5c..0b442f661 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -84,7 +84,7 @@ def __init__( super().__init__() self._config = config self._tensor_space = tensor_space - Assert.in_range_incl(layer_index, 1, self._config.num_layers) + Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) self._layer_index = layer_index self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel self._debug_transformer = self._config.debug_transformer diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index cf409e773..225ed5509 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -674,11 +674,11 @@ def _validate(self) -> None: if self.init_method_std_qkv is None: self.init_method_std_qkv = self.init_method_std if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / (2 * self.num_layers) ** 0.5 + self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 if self.init_method_std_mlp_1 is None: self.init_method_std_mlp_1 = self.init_method_std if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / (2 * self.num_layers) ** 0.5 + self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 if self.init_method_max_qkv is None: self.init_method_max_qkv = self.init_method_max if self.init_method_min_qkv is None: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index a7ec58d67..f0aaf90b0 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -72,34 +72,30 @@ def __init__( self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) def get_output_layers(self) -> list[Layer]: - return [ - layer - for i in range(self._config.prediction_heads) - for layer in [ - TransformerLayer( - self._config.transformer, - self._tensor_space, - # TODO MTP: which index? - layer_index=self._config.transformer.num_layers, - # The last layer only returns the transformer output. - # The previous layers return a stack of shared_hidden and transformer_output. - return_input=i < self._config.prediction_heads - 1, - ), + layers = [] + for i in range(self._config.prediction_heads): + if i > 0: + layers.append( + TransformerLayer( + self._config.transformer, + self._tensor_space, + # TODO MTP: which index? + layer_index=max(self._config.transformer.num_layers, 1), + # The last layer only returns the transformer output. + # The previous layers return a stack of shared_hidden and transformer_output. + return_input=i < self._config.prediction_heads - 1, + ) + ) + layers.append( LanguageModelHead( self._config, self._tensor_space, prediction_distance=i, - ), - ] - ] + ) + ) + return layers def get_layers(self) -> list[Layer]: - if self._config.transformer.num_layers == 0: - Assert.eq(self._config.prediction_heads, 1) - return [ - LanguageModelEmbedding(self._config, self._tensor_space), - LanguageModelHead(self._config, self._tensor_space, 0), - ] return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ @@ -108,7 +104,7 @@ def get_layers(self) -> list[Layer]: self._tensor_space, layer_index=i + 1, ) - for i in range(self._config.transformer.num_layers - 1) + for i in range(self._config.transformer.num_layers) ], *self.get_output_layers(), ] diff --git a/fast_llm/utils.py b/fast_llm/utils.py index a8c5eac61..51e0eee59 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -144,7 +144,17 @@ def multiple(x, y): @staticmethod def rms_close(x, y, threshold): rms = rms_diff(x, y).item() - assert rms <= threshold, f"Rms diff too big ({rms} > {threshold}) between tensors {x} and {y}" + assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" + + @staticmethod + def rms_close_relative(x, y, threshold, min_threshold=0): + import torch + + Assert.eq(x.shape, y.shape) + scale = (torch.sum(x**2 + y**2) / (2 * x.numel())) ** 0.5 + threshold = max(threshold * scale, min_threshold) + rms = rms_diff(x, y).item() + assert rms <= threshold, f"Rms diff too big ({rms:.3e} > {threshold:.3e}) between tensors {x} and {y}" @staticmethod def all_equal(x, y): @@ -156,7 +166,7 @@ def all_equal(x, y): neq = x != y if neq.any().item(): # noqa - index = torch.where(neq) # noqa + index = None if x.numel() == 1 else torch.where(neq) # noqa raise AssertionError( f"Tensors have {index[0].numel()} different entries out of " f"{x.numel()}: {x[index]} != {y[index]} at index {torch.stack(index, -1)}" diff --git a/tests/layers/__init__.py b/tests/layers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py new file mode 100644 index 000000000..bb9c0c831 --- /dev/null +++ b/tests/layers/test_lm_head.py @@ -0,0 +1,221 @@ +import typing + +import pytest +import torch + +from fast_llm.config import UpdateType +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.multi_stage.config import StageConfig +from fast_llm.engine.multi_stage.stage import Stage +from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.layers.common.config import NormalizationType +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT +from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead +from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.models.gpt.config import GPTBaseModelConfig +from fast_llm.models.gpt.model import GPTBaseModel +from fast_llm.utils import Assert + + +def _lm_head( + input_: torch.Tensor, + target: torch.Tensor, + *, + # config:LanguageModelBaseConfig, + rms_weight: torch.Tensor, + logit_weight: torch.Tensor, + grad_output: float = 1.0, + logit_scale_factor: float = 1.0, + logit_z_loss=0.0, +): + hidden = torch.rms_norm( + input_.to(rms_weight.dtype), + input_.shape[-1:], + rms_weight, + 1e-5, + ) + logits = torch.nn.functional.linear(hidden, logit_weight) + if logit_scale_factor != 1.0: + logits *= logit_scale_factor + z_loss = torch.mean(torch.logsumexp(logits, dim=-1) ** 2) if logit_z_loss > 0 else None + loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten()) + loss.backward(torch.full_like(loss, grad_output)) + return loss, z_loss + + +SEQUENCE_LENGTH = 200 +BATCH_SIZE = 4 +HIDDEN_SIZE = 256 +VOCAB_SIZE = 500 + + +@pytest.mark.slow +@pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) +@pytest.mark.parametrize( + ("config_dict", "distributed_config_dict"), + ( + ({}, {}), + ({}, {"training_dtype": DataType.bfloat16}), + ({"transformer": {"full_precision_residual": True}}, {"training_dtype": DataType.bfloat16}), + ({"sequence_first": True}, {}), + ({"logit_z_loss": 1e-3}, {}), + ({"logits_scale_factor": 5.0}, {}), + ({"tie_word_embeddings": False}, {}), + ({"prediction_heads": 2}, {}), + ), +) +def test_lm_head( + cross_entropy_impl: CrossEntropyImpl, + config_dict: dict[str, typing.Any], + distributed_config_dict: dict[str, typing.Any], +): + config = GPTBaseModelConfig.from_dict( + { + "transformer": { + "normalization": {"type": NormalizationType.rms_norm}, + "hidden_size": HIDDEN_SIZE, + "num_layers": 0, + }, + "vocab_size": VOCAB_SIZE, + "cross_entropy_impl": cross_entropy_impl, + }, + config_dict, + update_type=UpdateType.update, + ) + distributed_config = DistributedConfig.from_dict(distributed_config_dict) + distributed = Distributed(distributed_config) + tensor_space = TensorSpace(distributed_config) + config.setup_tensor_space(tensor_space) + tensor_space.setup(distributed) + model = GPTBaseModel(config, distributed_config) + model.setup(distributed) + + sequence_first = config.sequence_first or ( + config.cross_entropy_splits is not None and config.cross_entropy_splits > 1 + ) + target = torch.randint( + 0, + VOCAB_SIZE, + ( + (SEQUENCE_LENGTH + config.prediction_heads - 1, BATCH_SIZE) + if sequence_first + else (BATCH_SIZE, SEQUENCE_LENGTH + config.prediction_heads - 1) + ), + dtype=torch.int64, + device=distributed.device, + ) + input_ = torch.randn( + (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) if sequence_first else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), + dtype=( + distributed_config.optimization_dtype.torch + if config.transformer.full_precision_residual + else distributed_config.training_dtype.torch + ), + device=distributed.device, + requires_grad=True, + ) + kwargs = { + TransformerKwargs.sequence_first: sequence_first, + LanguageModelKwargs.labels: target, + TransformerKwargs.grad_output: 1.0, + } + if config.tie_word_embeddings or config.prediction_heads > 1: + logit_weight = ( + torch.empty( + VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed_config.training_dtype.torch, device=distributed.device + ) + .normal_(config.transformer.init_method_std) + .requires_grad_(True) + ) + kwargs[WORD_EMBEDDINGS_WEIGHT if config.tie_word_embeddings else OUTPUT_WEIGHTS] = logit_weight + else: + logit_weight = None + + for prediction_distance, layer_index in enumerate(model.model_head_indices): + # Prepare the LM head + head: LanguageModelHead = model[layer_index] + Assert.custom(isinstance, head, LanguageModelHead) + Assert.eq(head._prediction_distance, prediction_distance) + stage = Stage( + config=StageConfig(), + base_model=[head], + distributed_config=distributed_config, + begin=0, + end=1, + index=0, + ) + stage.setup(distributed=distributed) + stage.initialize_weights() + stage.restore_parameters() + stage.reset_gradients() + + # Get reference outputs and grads + if logit_weight is None: + logit_weight = head.output_weights + else: + logit_weight.grad_buffer = torch.full_like(logit_weight, float("nan")) + logit_weight.param_grad_is_zero = True + + ref_input = input_.detach().requires_grad_() + ref_rms_weight = head.final_norm.weight.detach().requires_grad_() + ref_logit_weight = logit_weight.detach().requires_grad_() + + ref_loss, ref_z_loss = _lm_head( + ref_input, + ( + target[prediction_distance : prediction_distance + SEQUENCE_LENGTH] + if sequence_first + else target[:, prediction_distance : prediction_distance + SEQUENCE_LENGTH] + ), + rms_weight=ref_rms_weight, + logit_weight=ref_logit_weight, + logit_scale_factor=config.logits_scale_factor, + logit_z_loss=config.logit_z_loss, + ) + + # Prepare LM head inputs + if head._is_last_head: + head_input = input_ + output_grad = ref_input.new_full((), float("nan")) + else: + shared_hidden = torch.randn_like(input_) + head_input = torch.stack((shared_hidden, input_.detach())).requires_grad_() + output_grad = torch.randn_like(shared_hidden) + + loss_name = f"language_model_loss_{prediction_distance}" if prediction_distance > 0 else "language_model_loss" + Assert.eq(head._loss_name, loss_name) + loss_keys = {loss_name} + if ref_z_loss is not None: + loss_keys.add("z_loss") + losses = {key: [] for key in loss_keys} + output, context = stage.forward(head_input, kwargs, losses) + stage.backward(output_grad, context) + + threshold = 1e-5 if distributed_config.training_dtype == DataType.float32 else 5e-3 + min_threshold = ( + 1e-5 if distributed_config.training_dtype == DataType.float32 else 1e-4 + ) * config.logits_scale_factor + + Assert.eq(losses.keys(), loss_keys) + Assert.eq(len(losses[loss_name]), 1) + if ref_z_loss is not None: + Assert.eq(len(losses["z_loss"]), 1) + Assert.rms_close_relative(losses["z_loss"][0], ref_z_loss, threshold, min_threshold) + + Assert.rms_close_relative(losses[loss_name][0], ref_loss, threshold, min_threshold) + + if head._is_last_head: + Assert.all_equal(output, losses[loss_name][0]) + input_grad = head_input.grad + else: + Assert.all_equal(output, shared_hidden) + shared_hidden_grad, input_grad = head_input.grad.unbind() + Assert.all_equal(shared_hidden_grad, output_grad) + + Assert.rms_close_relative(input_grad, ref_input.grad, threshold, min_threshold) + Assert.rms_close_relative(head.final_norm.weight.grad_buffer, ref_rms_weight.grad, threshold, min_threshold) + Assert.rms_close_relative(logit_weight.grad_buffer, ref_logit_weight.grad, threshold, min_threshold) From 7d17b12d3bd460af246e218b799d4d490bab4be8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 22 Apr 2025 17:20:35 -0400 Subject: [PATCH 2/3] fixes --- fast_llm/layers/language_model/head.py | 6 ++++-- tests/layers/test_lm_head.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c84b5f534..913871e1d 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -50,7 +50,9 @@ def __init__( self._group_size = tensor_space.distributed_config.tensor_parallel self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings - self._sequence_parallel_logits = self._sequence_parallel and not self._parallel_embeddings + self._sequence_parallel_logits = ( + tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings + ) self._cross_entropy_splits = config.cross_entropy_splits if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings @@ -151,7 +153,7 @@ def _forward_backward( self._prediction_distance, self._prediction_distance + input_.size(1 - kwargs[TransformerKwargs.sequence_first]) - * (self._tensor_space.distributed_config.tensor_parallel if self._parallel_embeddings else 1), + * (self._tensor_space.distributed_config.tensor_parallel if self._sequence_parallel else 1), ) target = target[target_slice] if kwargs[TransformerKwargs.sequence_first] else target[:, target_slice] target = target.flatten() diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index bb9c0c831..79101f340 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -19,6 +19,7 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.model import GPTBaseModel from fast_llm.utils import Assert +from tests.common import requires_cuda def _lm_head( @@ -53,6 +54,7 @@ def _lm_head( VOCAB_SIZE = 500 +@requires_cuda @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( From 824846a50dce8aa79c3c284eab3a28dde713cc95 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Apr 2025 16:36:15 -0400 Subject: [PATCH 3/3] comments --- fast_llm/engine/multi_stage/stage.py | 3 ++- fast_llm/layers/language_model/head.py | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 568820819..675e878b3 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -1,3 +1,4 @@ +import collections import logging import typing @@ -44,7 +45,7 @@ def setup( # noqa grad_buffers: list[torch.Tensor | None] | None = None, mode: StageMode = StageMode.training, is_tied_weight_copy: bool = False, - weight_buffer_shared_with: list["Stage"] = (), + weight_buffer_shared_with: collections.abc.Sequence["Stage"] = (), ) -> None: super().setup( distributed=distributed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 913871e1d..3b476f6a3 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -149,12 +149,12 @@ def _forward_backward( if target is not None: if self._config.distillation_model is None: # MTP: Shift the labels - target_slice = slice( - self._prediction_distance, - self._prediction_distance - + input_.size(1 - kwargs[TransformerKwargs.sequence_first]) - * (self._tensor_space.distributed_config.tensor_parallel if self._sequence_parallel else 1), + target_sequence_length = ( + target.size(1 - kwargs[TransformerKwargs.sequence_first]) + 1 - self._config.prediction_heads ) + if TransformerKwargs.sequence_q_dim in kwargs: + Assert.eq(target_sequence_length, kwargs[TransformerKwargs.sequence_q_dim].size) + target_slice = slice(self._prediction_distance, self._prediction_distance + target_sequence_length) target = target[target_slice] if kwargs[TransformerKwargs.sequence_first] else target[:, target_slice] target = target.flatten() else: