diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index d572816b2..dc105bce3 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -1,4 +1,3 @@ -import contextlib import dataclasses import pytest @@ -14,7 +13,7 @@ from fast_llm.layers.attention.attention import Attention, _flash_available from fast_llm.layers.attention.config import AttentionConfig from fast_llm.utils import Assert -from tests.utils.utils import get_stage +from tests.utils.utils import get_stage, no_tf32 _HEADS = 4 _KV_HEADS = 2 @@ -277,16 +276,6 @@ def _run_per_seq_reference( return torch.cat(out_refs, dim=0) -@contextlib.contextmanager -def _no_tf32(): - prev = torch.backends.cuda.matmul.allow_tf32 - torch.backends.cuda.matmul.allow_tf32 = False - try: - yield - finally: - torch.backends.cuda.matmul.allow_tf32 = prev - - def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: num_tokens = sum(lengths) hidden_dim = TensorDim("hidden", config.hidden_size) @@ -409,5 +398,5 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: [pytest.param(config, lengths, id=f"{config.name}-{lengths}") for config, lengths in _attention_test_cases], ) def test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: - with _no_tf32(): + with no_tf32(): _test_attention(config, lengths) diff --git a/tests/layers/test_decoder_block.py b/tests/layers/test_decoder_block.py index 7cdcf7dfa..2e8aece94 100644 --- a/tests/layers/test_decoder_block.py +++ b/tests/layers/test_decoder_block.py @@ -10,7 +10,8 @@ from fast_llm.layers.common.normalization.normalization import NoNormalization from fast_llm.layers.decoder.block import DecoderBlock from fast_llm.layers.decoder.config import DecoderBlockConfig -from tests.utils.utils import get_stage +from fast_llm.utils import Assert +from tests.utils.utils import get_stage, no_tf32 _NUM_TOKENS = 16 _HIDDEN_SIZE = 64 @@ -100,7 +101,7 @@ def _rms_norm(x: torch.Tensor, norm_module) -> torch.Tensor: ("post_mixer_norm", {"post_mixer_norm": True}), ("post_mlp_norm", {"post_mlp_norm": True}), ("both_post_norms", {"post_mixer_norm": True, "post_mlp_norm": True}), - ("output_scale", {"output_scale": 2.5}), + ("output_scale", {"output_scale": 0.8}), # `{"type": "none"}` disables the position-specific pre-norm. Gemma 4's MoE block path uses # this to skip the pre-MLP norm (the routed branch owns its own pre/post norms). ("pre_mixer_norm_disabled", {"pre_mixer_normalization": {"type": "none"}}), @@ -150,8 +151,7 @@ def test_post_norms(test_config: PostNormTestConfig): } block.preprocess(kwargs) - with torch.no_grad(): + with torch.no_grad(), no_tf32(): output = block(input_, kwargs) - - expected = test_config.expected_output(block, input_, kwargs) - torch.testing.assert_close(output, expected, rtol=_TOLERANCE, atol=_TOLERANCE) + expected = test_config.expected_output(block, input_, kwargs) + Assert.rms_close_relative(output, expected, _TOLERANCE, 1e-7) diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index 53b351e10..f72909e1f 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -10,7 +10,7 @@ from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP from fast_llm.utils import Assert -from tests.utils.utils import get_stage +from tests.utils.utils import get_stage, no_tf32 _NUM_TOKENS = 128 _HIDDEN_SIZE = 128 @@ -119,8 +119,7 @@ def test_hybrid_moe_mlp(config: HybridMoEMLPTestConfig) -> None: token_dim = TensorDim("tokens", _NUM_TOKENS) kwargs = {BlockKwargs.hidden_token_dim: token_dim} - with torch.no_grad(): + with torch.no_grad(), no_tf32(): output = hybrid(input_, kwargs) - - expected = config.expected_output(hybrid, input_, kwargs) + expected = config.expected_output(hybrid, input_, kwargs) Assert.rms_close_relative(output, expected, 1e-5, 1e-7) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index da293e1df..e98e66ff9 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,3 +1,4 @@ +import contextlib import logging import typing @@ -18,6 +19,21 @@ requires_triton = pytest.mark.skipif(not triton_available, reason="Triton is not available") +@contextlib.contextmanager +def no_tf32(): + # TF32 (PyTorch's CUDA matmul default) has ~1e-3 mantissa precision, which amplifies + # sub-FP32 input perturbations (e.g. Triton-vs-`torch.rms_norm` differing by ~1e-7) into + # ~1e-5 output drift through matmuls. Disable for tests comparing block forward to an eager + # reference; otherwise mixing tile/batch sizes can produce ~1e-4 swings unrelated to the + # feature under test. + prev = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = prev + + @pytest.fixture(scope="session") def result_path(): return TEST_RESULTS_PATH