From b13c8509505378ae520287fad1f84b5057fff008 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 18:53:31 -0400 Subject: [PATCH 1/2] Fix non-canonical cu_seqlens_k from preprocessor The data preprocessor emitted `cu_seqlens_k[0] = first_document_begin` rather than 0, violating the canonical varlen prefix-sum layout required by every public varlen attention API. SDPA's EFFICIENT backward writes corrupted dK/dV rows when fed this layout, propagating wrong gradients through the K/V projection's reduce-scatter under sequence-data-parallel + micro-batch splits. Three changes that compose: - `LengthModelInputPreprocessor` now produces `cu_seqlens_k` starting at 0 and narrows `document_index_k` / `position_index` to the active K extent. The dropped leading-prefix length is exposed as a new `first_document_begin` int kwarg. - Pre-allocate one K/V buffer per attention layer across all micro-sequences of a sequence. Each forward writes the SDP-gather result into the next slice via `gather_op(out=)`; backward accumulates each micro-seq's K/V grad into a shared grad buffer slice. The leading + trailing narrows and the per-step `torch.cat` / `AttachGrad` workaround for the cross-micro-seq splice are all absorbed into the `_query_key_value` custom autograd region. - `_preprocess_for_backup_attention` builds the attention mask against the narrowed K cols so `sdpa_dense` and `backup` consume the same K extent as flash and `sdpa_nested`. Update `tests/data/test_preprocessing.py` to expect the canonical layout. --- fast_llm/data/document/block.py | 32 ++++--- fast_llm/layers/attention/attention.py | 120 +++++++++++++++++-------- fast_llm/layers/attention/config.py | 1 + tests/data/test_preprocessing.py | 7 +- 4 files changed, 109 insertions(+), 51 deletions(-) diff --git a/fast_llm/data/document/block.py b/fast_llm/data/document/block.py index a02f92bdf..82b7657bd 100644 --- a/fast_llm/data/document/block.py +++ b/fast_llm/data/document/block.py @@ -31,6 +31,7 @@ class BlockModelInput(ModelInput): document_index_q: torch.Tensor | None = None document_index_k: torch.Tensor | None = None position_index: torch.Tensor | None = None + first_document_begin: int = 0 def to_kwargs(self) -> dict[str, typing.Any]: return { @@ -51,6 +52,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: AttentionKwargs.document_index_q: self.document_index_q, AttentionKwargs.document_index_k: self.document_index_k, LanguageModelKwargs.position_ids: self.position_index, + AttentionKwargs.first_document_begin: self.first_document_begin, } @@ -101,6 +103,7 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo # TODO: Support non-causal cropping (needs to know about the future too). Assert.eq(model_input.sequence_k_dim.global_size, self.last_document_end) + model_input.first_document_begin = self.first_document_begin if config.return_cumulative_sequence_lengths: model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths if config.return_max_sequence_lengths or config.return_document_index: @@ -118,9 +121,13 @@ def length(self) -> int: @functools.cached_property def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]: + # `cu_seqlens_k` follows the canonical prefix-sum layout starting at 0, describing the K + # extent narrowed by `first_document_begin` (the inactive leading prefix from earlier + # documents brought in by the sequence-data-parallel gather is dropped). Downstream + # consumers narrow `key_value` by `first_document_begin` to match. cumulative_lengths_q = torch.from_numpy(padded_cumsum(self.lengths)).to(dtype=torch.int32, device=self.device) - cumulative_lengths_k = cumulative_lengths_q + self.sequence_k_past - cumulative_lengths_k[0] = self.first_document_begin + cumulative_lengths_k = cumulative_lengths_q + (self.sequence_k_past - self.first_document_begin) + cumulative_lengths_k[0] = 0 return cumulative_lengths_q, cumulative_lengths_k @functools.cached_property @@ -141,30 +148,35 @@ def min_lengths(self) -> tuple[int, int]: @functools.cached_property def document_index(self) -> tuple[torch.Tensor, torch.Tensor]: + # `document_index_k` is computed against the narrowed K extent (length `_narrow_total_k`), + # consistent with the canonical `cumulative_lengths_k`. Values start at 1 (no leading + # "before first active document" entries). cumulative_lengths_q, cumulative_lengths_k = self.cumulative_lengths - # Note: index starts at 1. Index 0 is for sequence k before `self.current_document_begin`. return ( torch.searchsorted( cumulative_lengths_q, torch.arange(self.length, device=self.device), side="right", out_int32=True ), torch.searchsorted( cumulative_lengths_k, - torch.arange(self.sequence_k_past + self.length, device=self.device), + torch.arange(self._narrow_total_k, device=self.device), side="right", out_int32=True, ), ) + @functools.cached_property + def _narrow_total_k(self) -> int: + return self.sequence_k_past + self.length - self.first_document_begin + @functools.cached_property def position_index(self) -> torch.Tensor: + # Computed in the narrowed K coordinate space; the position-within-document is invariant + # under the narrowing shift, so this matches the un-narrowed result. _, document_index_k = self.document_index _, cumulative_lengths_k = self.cumulative_lengths - document_begins = cumulative_lengths_k[ - document_index_k[self.sequence_k_past : self.sequence_k_past + self.length] - 1 - ] + narrow_total = self._narrow_total_k + document_begins = cumulative_lengths_k[document_index_k[narrow_total - self.length : narrow_total] - 1] return ( - torch.arange( - self.sequence_k_past, self.sequence_k_past + self.length, dtype=torch.int32, device=self.device - ) + torch.arange(narrow_total - self.length, narrow_total, dtype=torch.int32, device=self.device) - document_begins ) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index dba863e94..04948aa28 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -1,3 +1,4 @@ +import dataclasses import typing import torch @@ -30,28 +31,20 @@ _flash_available = False -class AttachGrad(torch.autograd.Function): - """ - "Attach" the gradient of y to that of x, - so that the gradient of y is automatically added to that of x during the gradient computation of x. - The gradient of y should be computed first. +@dataclasses.dataclass +class _KVCacheSlot: + """Per-layer K/V buffer shared across micro-sequences of a full sequence. - In practice this allows inserting a breaking point in autograd to - split the gradient computation of x in two separate backward calls, - by setting `y = x.detach().requires_grad_()`. + Forward fills `buffer[0:frontier]` progressively (one slice per micro-sequence's gather); + backward lazily allocates `grad_buffer` and accumulates each micro-sequence's K/V grad + into the slice it attended to. Because backwards run in reverse temporal order, by the + time micro-sequence i's backward extracts `grad_buffer[frontier_prev_i:frontier_new_i]` + for its projection, all later micro-sequences' contributions are already accumulated. """ - @staticmethod - def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # noqa - # TODO: can we do it without saving y? (We only need its grad) - ctx.save_for_backward(y) - return x - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # noqa - (y,) = ctx.saved_tensors - grad = y.grad + grad_output - return grad, None + buffer: torch.Tensor + frontier: int = 0 + grad_buffer: torch.Tensor | None = None class Attention[ConfigType: AttentionConfig](BlockWithBias[ConfigType]): @@ -398,13 +391,50 @@ def _query_key_value_forward( query_unflat, kv_unflat, kwargs, inplace_query=query_norm_context is None ) - if self._sequence_data_parallel_dim.group: - # sequence dim may not be zero, but this needs to be handled after `handle.wait()` - key_value, handle = gather_op( - key_value, group=self._sequence_data_parallel_dim.group, dim=0, async_op=True + # Buffer management absorbs the SDP gather, past/present chaining, and the leading + + # trailing narrows that used to live in `_forward`. Each micro-sequence's gather writes + # directly into the next slice of a shared buffer, so K/V from past micro-sequences is + # accessed in place and there's no per-step `torch.cat`. The narrowed view returned is + # what the kernel consumes; its backward grad accumulates into a shared `grad_buffer`, + # making the cross-micro-sequence gradient splice (the previous `AttachGrad` workaround) + # implicit. + past_key_values = kwargs.get(AttentionKwargs.past_key_values) + presents = kwargs.get(AttentionKwargs.presents) + sdp_group = self._sequence_data_parallel_dim.group + micro_seq_length = key_value.size(0) * (sdp_group.size() if sdp_group else 1) + + if past_key_values: + slot = past_key_values.pop(0) + frontier_prev = slot.frontier + else: + slot = _KVCacheSlot( + buffer=torch.empty( + (kwargs[AttentionKwargs.sequence_length], 2 * self._local_head_groups, self._config.head_size), + device=key_value.device, + dtype=key_value.dtype, + ), ) - if handle: + frontier_prev = 0 + frontier_new = frontier_prev + micro_seq_length + + # `.data` accessor sidesteps the version-counter bump that an in-place write to + # `slot.buffer` would otherwise trigger; downstream kernel ops save views of + # `slot.buffer` for their backward, and those views must not see a version change when + # a later micro-sequence writes to a different (non-overlapping) slice. + destination = slot.buffer.data[frontier_prev:frontier_new] + if sdp_group: + _, handle = gather_op(key_value, group=sdp_group, dim=0, async_op=True, out=destination) handle.wait() + else: + destination.copy_(key_value) + + slot.frontier = frontier_new + if presents is not None: + presents.append(slot) + + first_document_begin = kwargs.get(AttentionKwargs.first_document_begin, 0) + sequence_k_end = kwargs[AttentionKwargs.sequence_k_dim].size + key_value_view = slot.buffer[first_document_begin:sequence_k_end] context = { "query": query_context, @@ -413,12 +443,32 @@ def _query_key_value_forward( "query_norm": query_norm_context, "key_norm": key_norm_context, "value_norm": value_norm_context, + "slot": slot, + "first_document_begin": first_document_begin, + "sequence_k_end": sequence_k_end, + "frontier_prev": frontier_prev, + "frontier_new": frontier_new, } - return query, key_value, context + return query, key_value_view, context def _query_key_value_backward( self, query_grad: torch.Tensor, key_value_grad: torch.Tensor, context: dict ) -> torch.Tensor: + # Lazily allocate the shared grad buffer on the first backward call (which corresponds + # to the last micro-sequence's forward); accumulate this micro-sequence's K/V grad into + # the slice it attended to, then take this micro-sequence's own contribution slice + # (which already includes all later micro-sequences' contributions) for the reduce- + # scatter and projection backward. + slot = context.pop("slot") + first_document_begin = context.pop("first_document_begin") + sequence_k_end = context.pop("sequence_k_end") + frontier_prev = context.pop("frontier_prev") + frontier_new = context.pop("frontier_new") + if slot.grad_buffer is None: + slot.grad_buffer = torch.zeros_like(slot.buffer) + slot.grad_buffer[first_document_begin:sequence_k_end].add_(key_value_grad) + key_value_grad = slot.grad_buffer[frontier_prev:frontier_new] + # TODO: De-allocate qkv grads quicker. key_value_grad, handle = reduce_scatter_op( key_value_grad, @@ -471,6 +521,8 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: self._debug(input_, "attn_input", (kwargs[AttentionKwargs.hidden_token_dim], self._hidden_dim), kwargs) + # `_query_key_value` absorbs the SDP gather + past/present chaining + leading/trailing + # K/V narrows, so the returned `key_value` is already the right view for the kernel. query, key_value = self._query_key_value(input_, kwargs) self._debug( @@ -480,19 +532,6 @@ def _forward( kwargs, ) - # TODO: These get unnecessarily big with lots of small documents. - if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: - # Clear the lists so tensors can be de-allocated - key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) - - if (presents := kwargs.get(AttentionKwargs.presents)) is not None: - # Return the presents as a leaf tensors so the gradients from later micro-sequences - # don't propagate to this one. - presents.append(present := key_value.detach().requires_grad_()) - # Manually add the gradients from later micro-sequences. - key_value = AttachGrad.apply(key_value, present) - - key_value = key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] key, value = key_value.chunk(2, dim=1) with set_generator(self._distributed.tp_generator): @@ -609,6 +648,7 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non device = kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.token_dim].size + first_document_begin = kwargs.get(AttentionKwargs.first_document_begin, 0) if self._config.causal: if ( sequence_length := kwargs[AttentionKwargs.sequence_length] @@ -624,7 +664,9 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non if self._config.window_size is not None: self._backup_attention_mask.triu_(-self._config.window_size + 1) - attention_mask = self._backup_attention_mask[None, sequence_k - sequence_q : sequence_k, None, :sequence_k] + attention_mask = self._backup_attention_mask[ + None, sequence_k - sequence_q : sequence_k, None, first_document_begin:sequence_k + ] else: attention_mask = None diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 95c67e7a3..d7ce1ad7e 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -26,6 +26,7 @@ class MixerKwargs(BlockKwargs): document_index_q = "document_index_q" document_index_k = "document_index_k" position_ids = "position_ids" + first_document_begin = "first_document_begin" class AttentionKwargs(MixerKwargs): diff --git a/tests/data/test_preprocessing.py b/tests/data/test_preprocessing.py index ae58121ae..6a3a6302e 100644 --- a/tests/data/test_preprocessing.py +++ b/tests/data/test_preprocessing.py @@ -227,12 +227,15 @@ def expected_position_index(self) -> list[torch.Tensor | None]: def expected_cumulative_lengths(self) -> list[tuple[torch.Tensor | None, torch.Tensor | None]]: if not self.return_cumulative_sequence_lengths: return [(None, None)] * self.micro_batch_splits + # `cu_seqlens_k` follows the canonical varlen prefix-sum layout starting at 0; the K + # extent is narrowed by `first_doc_begin` (the inactive leading prefix from earlier + # documents). Downstream consumers narrow `key_value` by the same amount. result = [] for split_index, (begin, _end) in enumerate(self._split_ranges): cropped_lengths, first_doc_begin = self._cropped_lengths_per_split[split_index] cu_q = torch.tensor([0] + cropped_lengths, dtype=torch.int32).cumsum(dim=0) - cu_k = (cu_q + begin).clone() - cu_k[0] = first_doc_begin + cu_k = cu_q + (begin - first_doc_begin) + cu_k[0] = 0 result.append((cu_q, cu_k)) return result From 5761ff6d280e3201da68ef86acae00c1bf5bf4a7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 20:08:03 -0400 Subject: [PATCH 2/2] Add regression test for first_document_begin > 0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `_test_first_document_begin` injects a fake past K/V slot with arbitrary leading data, drives attention through a manually-built kwargs with `sequence_k_past` and `first_document_begin` both set to a non-zero `past_length`, and verifies: - forward output matches a per-doc reference computed on the active documents alone (the dropped prefix has no observable effect), - parameter gradients match the reference, - the K/V grad buffer at `[:past_length]` is exactly zero — the specific guarantee of the cu_seqlens_k canonicalization fix. Runs backup + sdpa_dense on fp32, flash + sdpa_nested on bf16 (flash rejects fp32). Plugged into the existing `test_attention` parametrization as a new case with `name="first_document_begin"`, dispatched via name check. --- tests/layers/test_attention.py | 188 ++++++++++++++++++++++++++++++++- 1 file changed, 183 insertions(+), 5 deletions(-) diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 4940631b9..a58ed7269 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -5,14 +5,15 @@ import torch import torch.nn.functional -from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig -from fast_llm.data.document.language_model import LanguageModelBatch +from fast_llm.data.document.block import LengthModelInputPreprocessor +from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig, LengthPreprocessingConfig +from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelInput from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.attention.attention import Attention, _flash_available -from fast_llm.layers.attention.config import AttentionConfig +from fast_llm.layers.attention.attention import Attention, _flash_available, _KVCacheSlot +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.utils import Assert from tests.utils.utils import get_stage @@ -270,6 +271,11 @@ def expected_output( for lengths in _LENGTHS_SHORT: _attention_test_cases.append((AttentionTestConfig(name=name, **kwargs), lengths)) +# Regression for the cu_seqlens_k canonicalization: attention with a non-zero leading K/V +# prefix from earlier documents (first_document_begin > 0) must match attention on the +# active documents alone, and the K/V grad for the inactive prefix must be exactly zero. +_attention_test_cases.append((AttentionTestConfig(name="first_document_begin"), [4, 1, 10])) + def _check_packed( implementation: str, @@ -479,4 +485,176 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: ) def test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: with _no_tf32(): - _test_attention(config, lengths) + if config.name == "first_document_begin": + _test_first_document_begin(config, lengths) + else: + _test_attention(config, lengths) + + +def _check_first_document_begin( + implementation: str, + config: AttentionTestConfig, + active_lengths: list[int], + past_length: int, + distributed_config: DistributedConfig, + distributed: Distributed, + hidden_dim: TensorDim, + hidden_states: torch.Tensor, + out_ref: torch.Tensor, + grads_ref: list[torch.Tensor], + ref_params: list[torch.Tensor], + out_rtol: float, + grad_rtol: float, +) -> None: + """Run attention with `first_document_begin > 0` via a fake past slot, compare against + the per-doc reference, and verify that the inactive prefix's K/V grad is exactly zero.""" + active_total = sum(active_lengths) + total_k = past_length + active_total + + attention: Attention = config.get_attention_config(implementation).get_layer( + distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + stage = get_stage([attention], distributed) + for param, ref_param in zip(attention.parameters(), ref_params, strict=True): + param.data.copy_(ref_param.data) + + # Build kwargs directly so we can inject `sequence_k_past` / `first_document_begin` > 0. + model_input = LanguageModelInput(tokens=torch.empty(active_total, dtype=torch.int64, device=hidden_states.device)) + LengthModelInputPreprocessor( + lengths=active_lengths, + sequence_k_past=past_length, + first_document_begin=past_length, + last_document_end=total_k, + device=hidden_states.device, + unpadded_length=active_total, + sequence_length=total_k, + ).preprocess( + model_input, + LengthPreprocessingConfig(distributed=distributed_config, **attention.get_preprocessing_config()), + ) + + # Fake "past" K/V slot with arbitrary data in the leading-prefix region; the narrow must + # drop those positions, so their contents should not influence output or gradients. + slot = _KVCacheSlot( + buffer=torch.randn( + total_k, 2 * config.kv_heads, config.head_size, device=hidden_states.device, dtype=hidden_states.dtype + ), + frontier=past_length, + ) + kwargs = model_input.to_kwargs() + kwargs[AttentionKwargs.past_key_values] = [slot] + attention.preprocess(kwargs) + + hidden_states_test = hidden_states.detach().clone().requires_grad_() + out, context = stage.forward(hidden_states_test, kwargs) + stage.backward(torch.ones_like(out), context) + + Assert.rms_close_relative(out, out_ref, out_rtol, 1e-7, msg=implementation) + for param, grad_ref in zip(attention.parameters(), grads_ref, strict=True): + Assert.rms_close_relative(param.grad_buffer, grad_ref, grad_rtol, 1e-7, msg=implementation) + # Specific guarantee of the fix: K/V grad for the inactive leading prefix must be zero. + assert slot.grad_buffer is not None, f"{implementation}: grad_buffer not populated" + assert (slot.grad_buffer[:past_length] == 0).all(), f"{implementation}: dK/dV for inactive prefix not zero" + + +def _test_first_document_begin(config: AttentionTestConfig, lengths: list[int]) -> None: + """Regression for the cu_seqlens_k canonicalization. Attention with + `first_document_begin > 0` (a non-zero K/V prefix from earlier documents that the current + micro-sequence does not attend to) must produce the same output and gradients as + attention on the active documents alone, and the K/V grad for the inactive prefix must + be exactly zero — the specific guarantee of the fix. + """ + past_length = 7 + active_total = sum(lengths) + + distributed_config = DistributedConfig(use_cuda=torch.cuda.is_available()) + distributed = Distributed(distributed_config) + device = distributed.device + hidden_dim = TensorDim("hidden", config.hidden_size) + + # fp32 reference: per-doc backup attention on the active documents alone. + attention_ref: Attention = config.get_attention_config("backup").get_layer( + distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + stage_ref = get_stage([attention_ref], distributed) + hidden_states = torch.randn(active_total, config.hidden_size, device=device, requires_grad=True) + (model_input_ref,) = LanguageModelBatch( + tokens=torch.empty(active_total, dtype=torch.int64, device=device), lengths=lengths + ).get_model_inputs( + LanguageModelBatchPreprocessingConfig( + distributed=distributed_config, + predicted_tokens=0, + **attention_ref.get_preprocessing_config(), + ) + ) + kwargs_ref = model_input_ref.to_kwargs() + attention_ref.preprocess(kwargs_ref) + out_ref, context_ref = stage_ref.forward(hidden_states, kwargs_ref) + stage_ref.backward(torch.ones_like(out_ref), context_ref) + grads_ref = [param.grad_buffer.clone() for param in attention_ref.parameters()] + ref_params = list(attention_ref.parameters()) + + for implementation in ("backup", "sdpa_dense"): + _check_first_document_begin( + implementation, + config, + lengths, + past_length, + distributed_config, + distributed, + hidden_dim, + hidden_states, + out_ref.detach(), + grads_ref, + ref_params, + 1e-5, + 1e-5, + ) + + if not torch.cuda.is_available(): + return + + # bf16 reference for flash + sdpa_nested (flash rejects fp32). + distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True) + distributed_bf16 = Distributed(distributed_config_bf16) + attention_ref_bf16: Attention = config.get_attention_config("backup").get_layer( + distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + stage_ref_bf16 = get_stage([attention_ref_bf16], distributed_bf16) + for param_bf16, param_f32 in zip(attention_ref_bf16.parameters(), ref_params, strict=True): + param_bf16.data.copy_(param_f32.data) + hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16).requires_grad_() + (model_input_ref_bf16,) = LanguageModelBatch( + tokens=torch.empty(active_total, dtype=torch.int64, device=device), lengths=lengths + ).get_model_inputs( + LanguageModelBatchPreprocessingConfig( + distributed=distributed_config_bf16, + predicted_tokens=0, + **attention_ref_bf16.get_preprocessing_config(), + ) + ) + kwargs_ref_bf16 = model_input_ref_bf16.to_kwargs() + attention_ref_bf16.preprocess(kwargs_ref_bf16) + out_ref_bf16, context_ref_bf16 = stage_ref_bf16.forward(hidden_states_bf16, kwargs_ref_bf16) + stage_ref_bf16.backward(torch.ones_like(out_ref_bf16), context_ref_bf16) + grads_ref_bf16 = [param.grad_buffer.clone() for param in attention_ref_bf16.parameters()] + ref_params_bf16 = list(attention_ref_bf16.parameters()) + + for implementation in ("flash", "sdpa_nested"): + if implementation == "flash" and (not _flash_available or config.head_size > 256): + continue + _check_first_document_begin( + implementation, + config, + lengths, + past_length, + distributed_config_bf16, + distributed_bf16, + hidden_dim, + hidden_states_bf16, + out_ref_bf16.detach(), + grads_ref_bf16, + ref_params_bf16, + 5e-3, + 1.5e-2, + )