Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions tests/layers/test_attention.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import dataclasses

import pytest
Expand All @@ -14,7 +13,7 @@
from fast_llm.layers.attention.attention import Attention, _flash_available
from fast_llm.layers.attention.config import AttentionConfig
from fast_llm.utils import Assert
from tests.utils.utils import get_stage
from tests.utils.utils import get_stage, no_tf32

_HEADS = 4
_KV_HEADS = 2
Expand Down Expand Up @@ -277,16 +276,6 @@ def _run_per_seq_reference(
return torch.cat(out_refs, dim=0)


@contextlib.contextmanager
def _no_tf32():
    """Disable TF32 for CUDA matmuls inside the ``with`` body.

    The value of ``torch.backends.cuda.matmul.allow_tf32`` seen on entry is
    written back on exit, including when the body raises.
    """
    matmul_backend = torch.backends.cuda.matmul
    saved_flag = matmul_backend.allow_tf32
    matmul_backend.allow_tf32 = False
    try:
        yield
    finally:
        # Always put the caller's setting back, even on error.
        matmul_backend.allow_tf32 = saved_flag


def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None:
num_tokens = sum(lengths)
hidden_dim = TensorDim("hidden", config.hidden_size)
Expand Down Expand Up @@ -409,5 +398,5 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None:
[pytest.param(config, lengths, id=f"{config.name}-{lengths}") for config, lengths in _attention_test_cases],
)
def test_attention(config: AttentionTestConfig, lengths: list[int]) -> None:
with _no_tf32():
with no_tf32():
_test_attention(config, lengths)
12 changes: 6 additions & 6 deletions tests/layers/test_decoder_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from fast_llm.layers.common.normalization.normalization import NoNormalization
from fast_llm.layers.decoder.block import DecoderBlock
from fast_llm.layers.decoder.config import DecoderBlockConfig
from tests.utils.utils import get_stage
from fast_llm.utils import Assert
from tests.utils.utils import get_stage, no_tf32

_NUM_TOKENS = 16
_HIDDEN_SIZE = 64
Expand Down Expand Up @@ -100,7 +101,7 @@ def _rms_norm(x: torch.Tensor, norm_module) -> torch.Tensor:
("post_mixer_norm", {"post_mixer_norm": True}),
("post_mlp_norm", {"post_mlp_norm": True}),
("both_post_norms", {"post_mixer_norm": True, "post_mlp_norm": True}),
("output_scale", {"output_scale": 2.5}),
("output_scale", {"output_scale": 0.8}),
# `{"type": "none"}` disables the position-specific pre-norm. Gemma 4's MoE block path uses
# this to skip the pre-MLP norm (the routed branch owns its own pre/post norms).
("pre_mixer_norm_disabled", {"pre_mixer_normalization": {"type": "none"}}),
Expand Down Expand Up @@ -150,8 +151,7 @@ def test_post_norms(test_config: PostNormTestConfig):
}
block.preprocess(kwargs)

with torch.no_grad():
with torch.no_grad(), no_tf32():
output = block(input_, kwargs)

expected = test_config.expected_output(block, input_, kwargs)
torch.testing.assert_close(output, expected, rtol=_TOLERANCE, atol=_TOLERANCE)
expected = test_config.expected_output(block, input_, kwargs)
Assert.rms_close_relative(output, expected, _TOLERANCE, 1e-7)
7 changes: 3 additions & 4 deletions tests/layers/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig
from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP
from fast_llm.utils import Assert
from tests.utils.utils import get_stage
from tests.utils.utils import get_stage, no_tf32

_NUM_TOKENS = 128
_HIDDEN_SIZE = 128
Expand Down Expand Up @@ -119,8 +119,7 @@ def test_hybrid_moe_mlp(config: HybridMoEMLPTestConfig) -> None:
token_dim = TensorDim("tokens", _NUM_TOKENS)
kwargs = {BlockKwargs.hidden_token_dim: token_dim}

with torch.no_grad():
with torch.no_grad(), no_tf32():
output = hybrid(input_, kwargs)

expected = config.expected_output(hybrid, input_, kwargs)
expected = config.expected_output(hybrid, input_, kwargs)
Assert.rms_close_relative(output, expected, 1e-5, 1e-7)
16 changes: 16 additions & 0 deletions tests/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import logging
import typing

Expand All @@ -18,6 +19,21 @@
requires_triton = pytest.mark.skipif(not triton_available, reason="Triton is not available")


@contextlib.contextmanager
def no_tf32():
    """Run the enclosed block with TF32 CUDA matmuls disabled.

    TF32 keeps only ~1e-3 of mantissa precision, so sub-FP32 differences in
    matmul inputs (e.g. a Triton kernel and `torch.rms_norm` disagreeing by
    ~1e-7) grow into ~1e-5 output drift. Tests that compare a block's forward
    pass against an eager reference should run under this context; otherwise
    differing tile/batch sizes can produce ~1e-4 swings unrelated to the
    feature under test.

    The value of ``torch.backends.cuda.matmul.allow_tf32`` seen on entry is
    written back on exit, including when the body raises.
    """
    matmul_backend = torch.backends.cuda.matmul
    saved_flag = matmul_backend.allow_tf32
    matmul_backend.allow_tf32 = False
    try:
        yield
    finally:
        # Always put the caller's setting back, even on error.
        matmul_backend.allow_tf32 = saved_flag


@pytest.fixture(scope="session")
def result_path():
    """Session-scoped fixture that returns the module-level `TEST_RESULTS_PATH`."""
    return TEST_RESULTS_PATH
Expand Down
Loading