Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions fast_llm/data/document/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class BlockModelInput(ModelInput):
document_index_q: torch.Tensor | None = None
document_index_k: torch.Tensor | None = None
position_index: torch.Tensor | None = None
first_document_begin: int = 0

def to_kwargs(self) -> dict[str, typing.Any]:
return {
Expand All @@ -51,6 +52,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
AttentionKwargs.document_index_q: self.document_index_q,
AttentionKwargs.document_index_k: self.document_index_k,
LanguageModelKwargs.position_ids: self.position_index,
AttentionKwargs.first_document_begin: self.first_document_begin,
}


Expand Down Expand Up @@ -101,6 +103,7 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo
# TODO: Support non-causal cropping (needs to know about the future too).
Assert.eq(model_input.sequence_k_dim.global_size, self.last_document_end)

model_input.first_document_begin = self.first_document_begin
if config.return_cumulative_sequence_lengths:
model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths
if config.return_max_sequence_lengths or config.return_document_index:
Expand All @@ -118,9 +121,13 @@ def length(self) -> int:

@functools.cached_property
def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
# `cu_seqlens_k` follows the canonical prefix-sum layout starting at 0, describing the K
# extent narrowed by `first_document_begin` (the inactive leading prefix from earlier
# documents brought in by the sequence-data-parallel gather is dropped). Downstream
# consumers narrow `key_value` by `first_document_begin` to match.
cumulative_lengths_q = torch.from_numpy(padded_cumsum(self.lengths)).to(dtype=torch.int32, device=self.device)
cumulative_lengths_k = cumulative_lengths_q + self.sequence_k_past
cumulative_lengths_k[0] = self.first_document_begin
cumulative_lengths_k = cumulative_lengths_q + (self.sequence_k_past - self.first_document_begin)
cumulative_lengths_k[0] = 0
return cumulative_lengths_q, cumulative_lengths_k

@functools.cached_property
Expand All @@ -141,30 +148,35 @@ def min_lengths(self) -> tuple[int, int]:

@functools.cached_property
def document_index(self) -> tuple[torch.Tensor, torch.Tensor]:
# `document_index_k` is computed against the narrowed K extent (length `_narrow_total_k`),
# consistent with the canonical `cumulative_lengths_k`. Values start at 1 (no leading
# "before first active document" entries).
cumulative_lengths_q, cumulative_lengths_k = self.cumulative_lengths
# Note: index starts at 1. Index 0 is for sequence k before `self.current_document_begin`.
return (
torch.searchsorted(
cumulative_lengths_q, torch.arange(self.length, device=self.device), side="right", out_int32=True
),
torch.searchsorted(
cumulative_lengths_k,
torch.arange(self.sequence_k_past + self.length, device=self.device),
torch.arange(self._narrow_total_k, device=self.device),
side="right",
out_int32=True,
),
)

@functools.cached_property
def _narrow_total_k(self) -> int:
return self.sequence_k_past + self.length - self.first_document_begin

@functools.cached_property
def position_index(self) -> torch.Tensor:
# Computed in the narrowed K coordinate space; the position-within-document is invariant
# under the narrowing shift, so this matches the un-narrowed result.
_, document_index_k = self.document_index
_, cumulative_lengths_k = self.cumulative_lengths
document_begins = cumulative_lengths_k[
document_index_k[self.sequence_k_past : self.sequence_k_past + self.length] - 1
]
narrow_total = self._narrow_total_k
document_begins = cumulative_lengths_k[document_index_k[narrow_total - self.length : narrow_total] - 1]
return (
torch.arange(
self.sequence_k_past, self.sequence_k_past + self.length, dtype=torch.int32, device=self.device
)
torch.arange(narrow_total - self.length, narrow_total, dtype=torch.int32, device=self.device)
- document_begins
)
120 changes: 81 additions & 39 deletions fast_llm/layers/attention/attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import typing

import torch
Expand Down Expand Up @@ -30,28 +31,20 @@
_flash_available = False


class AttachGrad(torch.autograd.Function):
"""
"Attach" the gradient of y to that of x,
so that the gradient of y is automatically added to that of x during the gradient computation of x.
The gradient of y should be computed first.
@dataclasses.dataclass
class _KVCacheSlot:
"""Per-layer K/V buffer shared across micro-sequences of a full sequence.

In practice this allows inserting a breaking point in autograd to
split the gradient computation of x in two separate backward calls,
by setting `y = x.detach().requires_grad_()`.
Forward fills `buffer[0:frontier]` progressively (one slice per micro-sequence's gather);
backward lazily allocates `grad_buffer` and accumulates each micro-sequence's K/V grad
into the slice it attended to. Because backwards run in reverse temporal order, by the
time micro-sequence i's backward extracts `grad_buffer[frontier_prev_i:frontier_new_i]`
for its projection, all later micro-sequences' contributions are already accumulated.
"""

@staticmethod
def forward(ctx, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # noqa
# TODO: can we do it without saving y? (We only need its grad)
ctx.save_for_backward(y)
return x

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # noqa
(y,) = ctx.saved_tensors
grad = y.grad + grad_output
return grad, None
buffer: torch.Tensor
frontier: int = 0
grad_buffer: torch.Tensor | None = None


class Attention[ConfigType: AttentionConfig](BlockWithBias[ConfigType]):
Expand Down Expand Up @@ -398,13 +391,50 @@ def _query_key_value_forward(
query_unflat, kv_unflat, kwargs, inplace_query=query_norm_context is None
)

if self._sequence_data_parallel_dim.group:
# sequence dim may not be zero, but this needs to be handled after `handle.wait()`
key_value, handle = gather_op(
key_value, group=self._sequence_data_parallel_dim.group, dim=0, async_op=True
# Buffer management absorbs the SDP gather, past/present chaining, and the leading +
# trailing narrows that used to live in `_forward`. Each micro-sequence's gather writes
# directly into the next slice of a shared buffer, so K/V from past micro-sequences is
# accessed in place and there's no per-step `torch.cat`. The narrowed view returned is
# what the kernel consumes; its backward grad accumulates into a shared `grad_buffer`,
# making the cross-micro-sequence gradient splice (the previous `AttachGrad` workaround)
# implicit.
past_key_values = kwargs.get(AttentionKwargs.past_key_values)
presents = kwargs.get(AttentionKwargs.presents)
sdp_group = self._sequence_data_parallel_dim.group
micro_seq_length = key_value.size(0) * (sdp_group.size() if sdp_group else 1)

if past_key_values:
slot = past_key_values.pop(0)
frontier_prev = slot.frontier
else:
slot = _KVCacheSlot(
buffer=torch.empty(
(kwargs[AttentionKwargs.sequence_length], 2 * self._local_head_groups, self._config.head_size),
device=key_value.device,
dtype=key_value.dtype,
),
)
if handle:
frontier_prev = 0
frontier_new = frontier_prev + micro_seq_length

# `.data` accessor sidesteps the version-counter bump that an in-place write to
# `slot.buffer` would otherwise trigger; downstream kernel ops save views of
# `slot.buffer` for their backward, and those views must not see a version change when
# a later micro-sequence writes to a different (non-overlapping) slice.
destination = slot.buffer.data[frontier_prev:frontier_new]
if sdp_group:
_, handle = gather_op(key_value, group=sdp_group, dim=0, async_op=True, out=destination)
handle.wait()
else:
destination.copy_(key_value)

slot.frontier = frontier_new
if presents is not None:
presents.append(slot)

first_document_begin = kwargs.get(AttentionKwargs.first_document_begin, 0)
sequence_k_end = kwargs[AttentionKwargs.sequence_k_dim].size
key_value_view = slot.buffer[first_document_begin:sequence_k_end]

context = {
"query": query_context,
Expand All @@ -413,12 +443,32 @@ def _query_key_value_forward(
"query_norm": query_norm_context,
"key_norm": key_norm_context,
"value_norm": value_norm_context,
"slot": slot,
"first_document_begin": first_document_begin,
"sequence_k_end": sequence_k_end,
"frontier_prev": frontier_prev,
"frontier_new": frontier_new,
}
return query, key_value, context
return query, key_value_view, context

def _query_key_value_backward(
self, query_grad: torch.Tensor, key_value_grad: torch.Tensor, context: dict
) -> torch.Tensor:
# Lazily allocate the shared grad buffer on the first backward call (which corresponds
# to the last micro-sequence's forward); accumulate this micro-sequence's K/V grad into
# the slice it attended to, then take this micro-sequence's own contribution slice
# (which already includes all later micro-sequences' contributions) for the reduce-
# scatter and projection backward.
slot = context.pop("slot")
first_document_begin = context.pop("first_document_begin")
sequence_k_end = context.pop("sequence_k_end")
frontier_prev = context.pop("frontier_prev")
frontier_new = context.pop("frontier_new")
if slot.grad_buffer is None:
slot.grad_buffer = torch.zeros_like(slot.buffer)
slot.grad_buffer[first_document_begin:sequence_k_end].add_(key_value_grad)
key_value_grad = slot.grad_buffer[frontier_prev:frontier_new]

# TODO: De-allocate qkv grads quicker.
key_value_grad, handle = reduce_scatter_op(
key_value_grad,
Expand Down Expand Up @@ -471,6 +521,8 @@ def _forward(
metrics: dict[str, typing.Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
self._debug(input_, "attn_input", (kwargs[AttentionKwargs.hidden_token_dim], self._hidden_dim), kwargs)
# `_query_key_value` absorbs the SDP gather + past/present chaining + leading/trailing
# K/V narrows, so the returned `key_value` is already the right view for the kernel.
query, key_value = self._query_key_value(input_, kwargs)

self._debug(
Expand All @@ -480,19 +532,6 @@ def _forward(
kwargs,
)

# TODO: These get unnecessarily big with lots of small documents.
if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None:
# Clear the lists so tensors can be de-allocated
key_value = torch.cat((past_key_values.pop(0), key_value), dim=0)

if (presents := kwargs.get(AttentionKwargs.presents)) is not None:
# Return the presents as a leaf tensors so the gradients from later micro-sequences
# don't propagate to this one.
presents.append(present := key_value.detach().requires_grad_())
# Manually add the gradients from later micro-sequences.
key_value = AttachGrad.apply(key_value, present)

key_value = key_value[: kwargs[AttentionKwargs.sequence_k_dim].size]
key, value = key_value.chunk(2, dim=1)

with set_generator(self._distributed.tp_generator):
Expand Down Expand Up @@ -609,6 +648,7 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non
device = kwargs[AttentionKwargs.device] if AttentionKwargs.device in kwargs else self._distributed.device
sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size
sequence_q = kwargs[AttentionKwargs.token_dim].size
first_document_begin = kwargs.get(AttentionKwargs.first_document_begin, 0)
if self._config.causal:
if (
sequence_length := kwargs[AttentionKwargs.sequence_length]
Expand All @@ -624,7 +664,9 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non

if self._config.window_size is not None:
self._backup_attention_mask.triu_(-self._config.window_size + 1)
attention_mask = self._backup_attention_mask[None, sequence_k - sequence_q : sequence_k, None, :sequence_k]
attention_mask = self._backup_attention_mask[
None, sequence_k - sequence_q : sequence_k, None, first_document_begin:sequence_k
]
else:
attention_mask = None

Expand Down
1 change: 1 addition & 0 deletions fast_llm/layers/attention/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class MixerKwargs(BlockKwargs):
document_index_q = "document_index_q"
document_index_k = "document_index_k"
position_ids = "position_ids"
first_document_begin = "first_document_begin"


class AttentionKwargs(MixerKwargs):
Expand Down
7 changes: 5 additions & 2 deletions tests/data/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,15 @@ def expected_position_index(self) -> list[torch.Tensor | None]:
def expected_cumulative_lengths(self) -> list[tuple[torch.Tensor | None, torch.Tensor | None]]:
if not self.return_cumulative_sequence_lengths:
return [(None, None)] * self.micro_batch_splits
# `cu_seqlens_k` follows the canonical varlen prefix-sum layout starting at 0; the K
# extent is narrowed by `first_doc_begin` (the inactive leading prefix from earlier
# documents). Downstream consumers narrow `key_value` by the same amount.
result = []
for split_index, (begin, _end) in enumerate(self._split_ranges):
cropped_lengths, first_doc_begin = self._cropped_lengths_per_split[split_index]
cu_q = torch.tensor([0] + cropped_lengths, dtype=torch.int32).cumsum(dim=0)
cu_k = (cu_q + begin).clone()
cu_k[0] = first_doc_begin
cu_k = cu_q + (begin - first_doc_begin)
cu_k[0] = 0
result.append((cu_q, cu_k))
return result

Expand Down
Loading