feat: ragged continuous-batching decode graph (#449 M3 Stage 2a-i)#466
Merged
Conversation
The first half of Stage 2: a ragged decode_step where each row carries its own position and length, so sequences of different lengths can share one batch (the graph mechanic continuous batching needs). Validated by reference-equivalence, token-exact on CPU and CUDA. This is 2a-i (the graph); the admit/evict scheduler with staggered arrival (2a-ii) is the remaining Stage 2a piece. Spike-only, so CI is unaffected. Emitter (spike/rust-emitter): emit_decode_ragged (CLI decode-ragged[-argmax] <B>) gives each row its own pos[B]/cache_len[B]; RoPE cos/sin become a per-row gather [B,d] by pos[B]; the key mask is per-row [B,S] (valid iff s <= cache_len[b]); and the KV write is unrolled per row, each row writing its new K/V at its own pos[b] via dynamic_update_slice (constant row/layer offsets, dynamic pos[r]; scatter stays avoided per the Phase 2a CUDA caveat). The attention contractions and the LM head are identical to the uniform-B graph; the per-row mask carries the raggedness. Shim (spike/iree-ffi/iree_gate.c): xla_llama_decode_ragged takes token/pos/cache_len as [B] arrays and threads the rank-5 KV. Prefill-into-slot uses a host KV mirror: xla_llama_ragged_reset, xla_llama_prefill_slot (single-seq prefill, d2h into the mirror slot), and xla_llama_commit (h2d the mirror into the resident rank-5 KV), reusing the scalar prefill vmfb. Driver spike/iree-ffi/src/bin/llama_ragged.rs validates by reference-equivalence: Phase 1 captures one single-seq reference per prompt for B prompts of different lengths (truncations of the reference prompt); Phase 2 prefills them into slots and runs the ragged decode with every slot at its own position; each slot's stream must match its independent reference. Results, all reference-exact 48/48: CPU local-task B=4 (lengths 46,43,40,37) 115 ms/step, 34.7 tok/s; CUDA GB10 B=4 89 ms/step, 44.8 tok/s; CUDA B=8 (lengths 46..25) 96 ms/step, 83.6 tok/s. The per-row KV-write unroll adds only ~2% (B=4) to ~7.5% (B=8) over the uniform-B graph and lowers fine on both targets (compile ~5-7.5s). A sequence's output is invariant to its batch-mates and their lengths. Full result and tables in spike/iree-ffi/FINDINGS_ragged.md; the full Stage 2 plan (decisions: spike-first, a common BatchEngine trait, contiguous-per-slot KV) is in spike/openxla/STAGE2_DESIGN.md. Next is 2a-ii (the admit/evict scheduler). Refs #449
This was referenced Jun 28, 2026
Merged
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The first half of Stage 2 (#449 M3): a ragged
decode_stepwhere each row carries its OWN position and length, so sequences of different lengths can share one batch. This is the graph mechanic continuous batching needs. Validated by reference-equivalence, token-exact on CPU and CUDA.Scope: this is 2a-i (the graph). The admit/evict scheduler with staggered arrival and slot recycling (2a-ii) is the remaining Stage 2a piece. Spike-only (
spike/, outside the Cargo workspace), so CI is unaffected.What is in it
emit_decode_ragged(CLIdecode-ragged[-argmax] <B>): per-rowpos[B]/cache_len[B]; RoPE cos/sin as a per-row gather[B,d]bypos[B]; key mask per-row[B,S](valid iffs <= cache_len[b]); the KV write unrolled per row, each row writing at its ownpos[b]viadynamic_update_slice(constant row/layer offsets, dynamicpos[r]; scatter stays avoided per the Phase 2a CUDA caveat). The attention contractions and the LM head are identical to the uniform-B graph; the per-row mask carries the raggedness.xla_llama_decode_ragged(token/pos/cache_len as[B]arrays, threads the rank-5 KV) plus a host-KV-mirror prefill path:xla_llama_ragged_reset,xla_llama_prefill_slot(single-seq prefill, d2h into the mirror slot),xla_llama_commit(h2d the mirror into the resident rank-5 KV). Reuses the scalar prefill vmfb.src/bin/llama_ragged.rs: Phase 1 captures one single-seq reference per prompt; Phase 2 prefills B different-length prompts into slots and runs the ragged decode with each slot at its own position; asserts each slot's stream matches its reference.Validation (reference-equivalence)
Continuous batching has no single reference, so the gate is: B prompts of DIFFERENT lengths (truncations of the 46-token reference prompt), each run independently through the single-seq path, then run together in one ragged batch. Every slot must match its independent reference, proving per-row state is correct and a sequence is unaffected by its batch-mates.
local-taskFindings
B*Ldynamic_update_slices per step add only ~2% (B=4) to ~7.5% (B=8) over uniform-B, and lower fine on CPU and CUDA (compile ~5-7.5s).Full tables in
spike/iree-ffi/FINDINGS_ragged.md; the full Stage 2 plan inspike/openxla/STAGE2_DESIGN.md.Reproduce
Next
2a-ii: the admit/evict scheduler (request queue, mid-stream prefill into a freed slot, eviction on EOS, slot recycling, staggered-arrival validation). Then 2b productizes the engine into
mlxcel-xla.Refs #449