From 37dabca0e8c26a2bf38c1b20a28cd9970516bab1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 May 2026 14:03:17 -0400 Subject: [PATCH 1/3] Fix flaky layer tests by disabling TF32 against eager reference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `test_post_norms` and `test_hybrid_moe_mlp` compare the block forward (which uses Triton RMSNorm) to an eager reference built from `torch.rms_norm`. The two norm implementations agree to ~1e-7 (FP32 noise floor), but PyTorch's default TF32 matmul has only ~1e-3 mantissa precision and amplifies that tiny input perturbation into ~1e-5 output drift through the block's matmuls — flakily crossing the 1e-5 tolerance. `test_attention` already worked around this with a local `_no_tf32` context manager. Promote it to `tests.utils.utils.no_tf32` and use it in `test_post_norms` / `test_hybrid_moe_mlp` too. Also: - `test_post_norms` now uses `Assert.rms_close_relative` instead of `torch.testing.assert_close`, matching the codebase convention (the RMS-based assertion is robust to single-element outliers from FP rounding). - The `output_scale` parametrization uses 0.8 (a realistic Gemma 4 layer_scalar value) instead of 2.5. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/layers/test_attention.py | 15 ++------------- tests/layers/test_decoder_block.py | 12 ++++++------ tests/layers/test_mlp.py | 7 +++---- tests/utils/utils.py | 16 ++++++++++++++++ 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index d572816b2..dc105bce3 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -1,4 +1,3 @@ -import contextlib import dataclasses import pytest @@ -14,7 +13,7 @@ from fast_llm.layers.attention.attention import Attention, _flash_available from fast_llm.layers.attention.config import AttentionConfig from fast_llm.utils import Assert -from tests.utils.utils import get_stage +from tests.utils.utils import get_stage, no_tf32 _HEADS = 4 _KV_HEADS = 2 @@ -277,16 +276,6 @@ def _run_per_seq_reference( return torch.cat(out_refs, dim=0) -@contextlib.contextmanager -def _no_tf32(): - prev = torch.backends.cuda.matmul.allow_tf32 - torch.backends.cuda.matmul.allow_tf32 = False - try: - yield - finally: - torch.backends.cuda.matmul.allow_tf32 = prev - - def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: num_tokens = sum(lengths) hidden_dim = TensorDim("hidden", config.hidden_size) @@ -409,5 +398,5 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: [pytest.param(config, lengths, id=f"{config.name}-{lengths}") for config, lengths in _attention_test_cases], ) def test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: - with _no_tf32(): + with no_tf32(): _test_attention(config, lengths) diff --git a/tests/layers/test_decoder_block.py b/tests/layers/test_decoder_block.py index 7cdcf7dfa..2e8aece94 100644 --- a/tests/layers/test_decoder_block.py +++ b/tests/layers/test_decoder_block.py @@ -10,7 +10,8 @@ from fast_llm.layers.common.normalization.normalization import NoNormalization from fast_llm.layers.decoder.block import DecoderBlock from fast_llm.layers.decoder.config import DecoderBlockConfig -from tests.utils.utils import get_stage +from fast_llm.utils import Assert +from tests.utils.utils import get_stage, no_tf32 _NUM_TOKENS = 16 _HIDDEN_SIZE = 64 @@ -100,7 +101,7 @@ def _rms_norm(x: torch.Tensor, norm_module) -> torch.Tensor: ("post_mixer_norm", {"post_mixer_norm": True}), ("post_mlp_norm", {"post_mlp_norm": True}), ("both_post_norms", {"post_mixer_norm": True, "post_mlp_norm": True}), - ("output_scale", {"output_scale": 2.5}), + ("output_scale", {"output_scale": 0.8}), # `{"type": "none"}` disables the position-specific pre-norm. Gemma 4's MoE block path uses # this to skip the pre-MLP norm (the routed branch owns its own pre/post norms). ("pre_mixer_norm_disabled", {"pre_mixer_normalization": {"type": "none"}}), @@ -150,8 +151,7 @@ def test_post_norms(test_config: PostNormTestConfig): } block.preprocess(kwargs) - with torch.no_grad(): + with torch.no_grad(), no_tf32(): output = block(input_, kwargs) - - expected = test_config.expected_output(block, input_, kwargs) - torch.testing.assert_close(output, expected, rtol=_TOLERANCE, atol=_TOLERANCE) + expected = test_config.expected_output(block, input_, kwargs) + Assert.rms_close_relative(output, expected, _TOLERANCE, 1e-7) diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index 53b351e10..f72909e1f 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -10,7 +10,7 @@ from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP from fast_llm.utils import Assert -from tests.utils.utils import get_stage +from tests.utils.utils import get_stage, no_tf32 _NUM_TOKENS = 128 _HIDDEN_SIZE = 128 @@ -119,8 +119,7 @@ def test_hybrid_moe_mlp(config: HybridMoEMLPTestConfig) -> None: token_dim = TensorDim("tokens", _NUM_TOKENS) kwargs = {BlockKwargs.hidden_token_dim: token_dim} - with torch.no_grad(): + with torch.no_grad(), no_tf32(): output = hybrid(input_, kwargs) - - expected = config.expected_output(hybrid, input_, kwargs) + expected = config.expected_output(hybrid, input_, kwargs) Assert.rms_close_relative(output, expected, 1e-5, 1e-7) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index da293e1df..e98e66ff9 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,3 +1,4 @@ +import contextlib import logging import typing @@ -18,6 +19,21 @@ requires_triton = pytest.mark.skipif(not triton_available, reason="Triton is not available") +@contextlib.contextmanager +def no_tf32(): + # TF32 (PyTorch's CUDA matmul default) has ~1e-3 mantissa precision, which amplifies + # sub-FP32 input perturbations (e.g. Triton-vs-`torch.rms_norm` differing by ~1e-7) into + # ~1e-5 output drift through matmuls. Disable for tests comparing block forward to an eager + # reference; otherwise mixing tile/batch sizes can produce ~1e-4 swings unrelated to the + # feature under test. + prev = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = prev + + @pytest.fixture(scope="session") def result_path(): return TEST_RESULTS_PATH From a79e6d0a7a4d960855f53c487cd51c4c4ce058fb Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 19 May 2026 12:56:50 -0400 Subject: [PATCH 2/3] Cover backward and decouple MoE routing in the layer-assembly tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous commit absorbed forward-side FP32 noise by disabling TF32 in the eager-reference comparison, but two issues remained: 1. `test_hybrid_moe_mlp[wrapper_norms]` / `[all_norms]` still flaked. With a wrapper pre-norm, the Triton vs `torch.rms_norm` divergence (~1e-7) propagates to the router and occasionally flips top-k decisions for near-boundary tokens — pure routing instability, unrelated to the wiring under test. 2. The tests only exercised the forward path; nothing verified that the block's custom-autograd backward agrees with the eager reference's. Add a `_RouterBridge` `torch.autograd.Function` for `test_hybrid_moe_mlp` that captures the real router output and incoming gradient (for cross-side comparison) and substitutes deterministic mock data in both directions, so downstream `torch.topk` + softmax + expert dispatch is deterministic regardless of the FP perturbation. The router's parameters still receive a gradient through their normal backward path via the injected `mock_grad`. Extend both tests to compare gradients in addition to outputs: - run modules in `.train()` mode so the codebase's custom-autograd kernels retain backward context (with default `dropout=0`/`jitter_eps=0`, train mode is functionally identical to eval here) - `stage.reset_gradients()` before each forward+backward, so each call starts with `param_grad_is_zero=True` - use `loss.backward(torch.ones_like(loss))` + `input.grad.clone()` — `torch.autograd.grad` produces wrong gradients for these modules (suspected interaction with `wrap_forward_backward`'s context handling) Loosen `test_hybrid_moe_mlp`'s threshold to 1e-4 for the wrapper-norm configs only — observed worst case ~5e-5 from the wrapper pre/post-norm Triton-vs-`torch.rms_norm` divergence propagated through the matmuls. All other configs are bit-exact or in the 1e-7 range. Verified: 16/16 stability runs of just these two test files, plus full `tests/` suite with no regressions traceable to this PR. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/layers/test_decoder_block.py | 66 ++++++++++-------- tests/layers/test_mlp.py | 104 ++++++++++++++++++++++++----- 2 files changed, 128 insertions(+), 42 deletions(-) diff --git a/tests/layers/test_decoder_block.py b/tests/layers/test_decoder_block.py index 2e8aece94..00d0f693e 100644 --- a/tests/layers/test_decoder_block.py +++ b/tests/layers/test_decoder_block.py @@ -64,36 +64,36 @@ def expected_output(self, block: DecoderBlock, input_: torch.Tensor, kwargs: dic # Block-assembly test. The mixer and MLP are treated as black boxes (covered by # `test_attention` and `test_mlp` respectively); norms/residual/output_scale are computed # via `torch.rms_norm` so the assembly under test does not appear in its own reference. + # Runs under autograd so the caller can backward through this reference. def _rms_norm(x: torch.Tensor, norm_module) -> torch.Tensor: if isinstance(norm_module, NoNormalization): return x return torch.rms_norm(x, x.shape[-1:], norm_module.weight, 1e-5) - with torch.no_grad(): - norm1_out = _rms_norm(input_, block.norm_1) - mixer_hidden, mixer_bias = block.mixer(norm1_out, kwargs) - if block.post_mixer_norm is not None: - if mixer_bias is not None: - mixer_hidden = mixer_hidden + mixer_bias - mixer_bias = None - mixer_hidden = _rms_norm(mixer_hidden, block.post_mixer_norm) + norm1_out = _rms_norm(input_, block.norm_1) + mixer_hidden, mixer_bias = block.mixer(norm1_out, kwargs) + if block.post_mixer_norm is not None: if mixer_bias is not None: mixer_hidden = mixer_hidden + mixer_bias - after_mixer = input_ + mixer_hidden - - norm2_out = _rms_norm(after_mixer, block.norm_2) - mlp_hidden, mlp_bias = block.mlp(norm2_out, kwargs) - if block.post_mlp_norm is not None: - if mlp_bias is not None: - mlp_hidden = mlp_hidden + mlp_bias - mlp_bias = None - mlp_hidden = _rms_norm(mlp_hidden, block.post_mlp_norm) + mixer_bias = None + mixer_hidden = _rms_norm(mixer_hidden, block.post_mixer_norm) + if mixer_bias is not None: + mixer_hidden = mixer_hidden + mixer_bias + after_mixer = input_ + mixer_hidden + + norm2_out = _rms_norm(after_mixer, block.norm_2) + mlp_hidden, mlp_bias = block.mlp(norm2_out, kwargs) + if block.post_mlp_norm is not None: if mlp_bias is not None: mlp_hidden = mlp_hidden + mlp_bias - output = after_mixer + mlp_hidden - if self.output_scale is not None: - output = output * self.output_scale - return output + mlp_bias = None + mlp_hidden = _rms_norm(mlp_hidden, block.post_mlp_norm) + if mlp_bias is not None: + mlp_hidden = mlp_hidden + mlp_bias + output = after_mixer + mlp_hidden + if self.output_scale is not None: + output = output * self.output_scale + return output _base_post_norm_cases = [ @@ -129,8 +129,10 @@ def test_post_norms(test_config: PostNormTestConfig): block: DecoderBlock = test_config.get_block_config().get_layer( distributed_config, hidden_dim, lr_scale=None, peft=None ) - get_stage([block], distributed) - block.eval() + stage = get_stage([block], distributed) + # Train mode so the codebase's custom-autograd kernels retain backward context. With + # dropout=0 (default), train mode is functionally identical to eval mode here. + block.train() device = distributed.device if test_config.output_scale is not None: @@ -151,7 +153,19 @@ def test_post_norms(test_config: PostNormTestConfig): } block.preprocess(kwargs) - with torch.no_grad(), no_tf32(): - output = block(input_, kwargs) - expected = test_config.expected_output(block, input_, kwargs) + stage.reset_gradients() + with no_tf32(): + input_actual = input_.clone().requires_grad_(True) + output = block(input_actual, kwargs) + output.backward(torch.ones_like(output)) + grad_actual = input_actual.grad.clone() + + stage.reset_gradients() + with no_tf32(): + input_ref = input_.clone().requires_grad_(True) + expected = test_config.expected_output(block, input_ref, kwargs) + expected.backward(torch.ones_like(expected)) + grad_ref = input_ref.grad.clone() + Assert.rms_close_relative(output, expected, _TOLERANCE, 1e-7) + Assert.rms_close_relative(grad_actual, grad_ref, _TOLERANCE, 1e-7) diff --git a/tests/layers/test_mlp.py b/tests/layers/test_mlp.py index f72909e1f..183856248 100644 --- a/tests/layers/test_mlp.py +++ b/tests/layers/test_mlp.py @@ -1,4 +1,5 @@ import dataclasses +import types import pytest import torch @@ -8,7 +9,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.decoder.mlp.config import HybridMoEMLPConfig -from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP +from fast_llm.layers.decoder.mlp.mixture_of_experts import HybridMoEMLP, MixtureOfExpertMLP from fast_llm.utils import Assert from tests.utils.utils import get_stage, no_tf32 @@ -24,6 +25,47 @@ def _norm() -> dict: return {"type": "rms_norm"} +class _RouterBridge(torch.autograd.Function): + """Catches the router's real output and incoming gradient for cross-side comparison and + substitutes deterministic mock data in both directions, so the assembly under test does + not depend on routing. + + Forward: catches `real_logits`, returns `mock_logits` so the downstream `torch.topk` + + softmax + expert dispatch is deterministic regardless of ~1e-7 FP perturbations that + would otherwise flip top-k for near-boundary tokens. + Backward: catches the gradient computed against `mock_logits`, returns `mock_grad` to + `real_logits` so the router's parameters still receive a gradient through their normal + backward path. + + `captured` is a mutable dict the caller passes in to receive both catches. + """ + + @staticmethod + def forward(ctx, real_logits, mock_logits, mock_grad, captured): + captured["real_logits"] = real_logits.detach() + ctx.save_for_backward(mock_grad) + ctx.captured = captured + return mock_logits.detach() + + @staticmethod + def backward(ctx, grad_output): + (mock_grad,) = ctx.saved_tensors + ctx.captured["mock_logits_grad"] = grad_output.detach() + return mock_grad, None, None, None + + +def _wrap_router(routed: MixtureOfExpertMLP, mock_logits, mock_grad, captured) -> None: + if not hasattr(routed, "_orig_topk_routing"): + routed._orig_topk_routing = routed._topk_routing + real_topk = routed._orig_topk_routing + + def _patched(self, logits, grad_scale=None, losses=None): + bridged = _RouterBridge.apply(logits, mock_logits, mock_grad, captured) + return real_topk(bridged, grad_scale, losses) + + routed._topk_routing = types.MethodType(_patched, routed) + + @dataclasses.dataclass class HybridMoEMLPTestConfig: name: str @@ -67,18 +109,18 @@ def get_mlp_config(self) -> HybridMoEMLPConfig: def expected_output(self, hybrid: HybridMoEMLP, input_: torch.Tensor, kwargs: dict) -> torch.Tensor: # Hybrid-assembly test. The dense and routed branches are treated as black boxes (covered # by `MLP` and `MixtureOfExpertMLP` tests); pre/post norms are computed via - # `torch.rms_norm` so the wrapper's norms do not appear in their own reference. + # `torch.rms_norm` so the wrapper's norms do not appear in their own reference. Runs + # under autograd so the caller can backward through this reference. def _rms_norm(x: torch.Tensor, norm_module) -> torch.Tensor: return torch.rms_norm(x, x.shape[-1:], norm_module.weight, 1e-5) - with torch.no_grad(): - shared = _rms_norm(input_, hybrid.pre_norm) if hybrid.pre_norm is not None else input_ - dense_out, _ = hybrid.dense(shared, kwargs) - routed_out, _ = hybrid.routed(shared, kwargs) - out = dense_out + routed_out - if hybrid.post_norm is not None: - out = _rms_norm(out, hybrid.post_norm) - return out + shared = _rms_norm(input_, hybrid.pre_norm) if hybrid.pre_norm is not None else input_ + dense_out, _ = hybrid.dense(shared, kwargs) + routed_out, _ = hybrid.routed(shared, kwargs) + out = dense_out + routed_out + if hybrid.post_norm is not None: + out = _rms_norm(out, hybrid.post_norm) + return out _test_configs = [ @@ -112,14 +154,44 @@ def test_hybrid_moe_mlp(config: HybridMoEMLPTestConfig) -> None: hybrid: HybridMoEMLP = config.get_mlp_config().get_layer( distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False ) - get_stage([hybrid], distributed) - hybrid.eval() + stage = get_stage([hybrid], distributed) + # Train mode so the codebase's custom-autograd kernels retain backward context. With + # dropout=0 and jitter_eps=0 (defaults), train mode is functionally identical to eval + # mode here — the only difference is context retention. + hybrid.train() + + # Predetermined mock router output + incoming gradient, shared between actual and reference. + n_router_experts = hybrid.routed._config.unshared_experts + g = torch.Generator(device=device).manual_seed(0xB007) + mock_logits = torch.randn(_NUM_TOKENS, n_router_experts, device=device, generator=g) + mock_grad = torch.randn(_NUM_TOKENS, n_router_experts, device=device, generator=g) input_ = torch.randn(_NUM_TOKENS, _HIDDEN_SIZE, device=device) token_dim = TensorDim("tokens", _NUM_TOKENS) kwargs = {BlockKwargs.hidden_token_dim: token_dim} - with torch.no_grad(), no_tf32(): - output = hybrid(input_, kwargs) - expected = config.expected_output(hybrid, input_, kwargs) - Assert.rms_close_relative(output, expected, 1e-5, 1e-7) + captures_actual: dict = {} + _wrap_router(hybrid.routed, mock_logits, mock_grad, captures_actual) + stage.reset_gradients() + with no_tf32(): + input_actual = input_.clone().requires_grad_(True) + output = hybrid(input_actual, kwargs) + output.backward(torch.ones_like(output)) + grad_actual = input_actual.grad.clone() + + captures_ref: dict = {} + _wrap_router(hybrid.routed, mock_logits, mock_grad, captures_ref) + stage.reset_gradients() + with no_tf32(): + input_ref = input_.clone().requires_grad_(True) + expected = config.expected_output(hybrid, input_ref, kwargs) + expected.backward(torch.ones_like(expected)) + grad_ref = input_ref.grad.clone() + + # 1e-4 absorbs FP32 noise from the wrapper pre-norm + post-norm Triton-vs-`torch.rms_norm` + # divergence propagated through matmuls (up to ~5e-5 observed for `wrapper_norms` / + # `all_norms`). All other configs are bit-exact or in the 1e-7 range, well below threshold. + Assert.rms_close_relative(output, expected, 1e-4, 1e-7) + Assert.rms_close_relative(grad_actual, grad_ref, 1e-4, 1e-7) + Assert.rms_close_relative(captures_actual["real_logits"], captures_ref["real_logits"], 1e-4, 1e-7) + Assert.rms_close_relative(captures_actual["mock_logits_grad"], captures_ref["mock_logits_grad"], 1e-4, 1e-7) From f95b24683c414c985c2434a4287a703d65cfa789 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 19 May 2026 14:03:23 -0400 Subject: [PATCH 3/3] Loosen `test_grpo_metrics` threshold; drop suspicious min_threshold floor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Assert.rms_close_relative(got_value, ref_value, 1e-5, 1e-6)` compared two scalar GRPO metrics with a 1e-6 absolute floor that masked a slightly- too-tight relative threshold. On the full-suite run an instance hit diff 1.43e-6 / scale 8.26e-2 = 1.73e-5 relative — over the 1e-5 bound but swallowed by the 1e-6 floor most of the time. Drop the min_threshold (no other loss test uses one), and raise the relative threshold to 5e-5 for float32. 20/20 sequential runs pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/layers/test_lm_losses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 19200476a..362a21974 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -357,7 +357,7 @@ def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> if ref_value is None: assert got_value is None, name else: - Assert.rms_close_relative(got_value, ref_value, threshold, 1e-6) + Assert.rms_close_relative(got_value, ref_value, threshold) def _test_grpo_metrics( @@ -394,7 +394,7 @@ def _test_grpo_metrics( group=group, compute_entropy=compute_entropy, ) - _check_grpo_metrics(ref, got, threshold=1e-5 if dtype == DataType.float32 else 1e-4) + _check_grpo_metrics(ref, got, threshold=5e-5 if dtype == DataType.float32 else 1e-4) def _test_z_loss(