feat: ragged continuous-batching decode graph (#449 M3 Stage 2a-i)#466

Merged: inureyes merged 1 commit into main from feature/issue-449-ragged-decode on Jun 28, 2026
@inureyes

Summary

The first half of Stage 2 (#449 M3): a ragged decode_step where each row carries its OWN position and length, so sequences of different lengths can share one batch. This is the graph mechanic continuous batching needs. Validated by reference-equivalence, token-exact on CPU and CUDA.

Scope: this is 2a-i (the graph). The admit/evict scheduler with staggered arrival and slot recycling (2a-ii) is the remaining Stage 2a piece. Spike-only (spike/, outside the Cargo workspace), so CI is unaffected.

What is in it

  • Emitter emit_decode_ragged (CLI decode-ragged[-argmax] <B>): per-row pos[B]/cache_len[B]; RoPE cos/sin as a per-row gather [B,d] by pos[B]; key mask per-row [B,S] (valid iff s <= cache_len[b]); the KV write unrolled per row, each row writing at its own pos[b] via dynamic_update_slice (constant row/layer offsets, dynamic pos[r]; scatter stays avoided per the Phase 2a CUDA caveat). The attention contractions and the LM head are identical to the uniform-B graph; the per-row mask carries the raggedness.
  • Shim xla_llama_decode_ragged (token/pos/cache_len as [B] arrays, threads the rank-5 KV) plus a host-KV-mirror prefill path: xla_llama_ragged_reset, xla_llama_prefill_slot (single-seq prefill, d2h into the mirror slot), xla_llama_commit (h2d the mirror into the resident rank-5 KV). Reuses the scalar prefill vmfb.
  • Driver src/bin/llama_ragged.rs: Phase 1 captures one single-seq reference per prompt; Phase 2 prefills B different-length prompts into slots and runs the ragged decode with each slot at its own position; asserts each slot's stream matches its reference.
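The per-row mask and per-row KV write described above can be modeled on the host in a few lines. This is a minimal reference sketch, not the emitter's code: the function names, the flat [B, S, D] single-layer buffer, and the use of plain slices are all assumptions made for illustration. The real graph expresses the write as one dynamic_update_slice per row and layer.

```rust
/// Per-row key mask of shape [B, S]: key position `s` is visible to row `b`
/// iff `s <= cache_len[b]`, matching the rule stated in the description.
fn ragged_key_mask(cache_len: &[usize], s_max: usize) -> Vec<Vec<bool>> {
    cache_len
        .iter()
        .map(|&len| (0..s_max).map(|s| s <= len).collect())
        .collect()
}

/// Per-row KV write, unrolled over rows: row `r` writes its new vector at
/// its own position `pos[r]`. `cache` is a hypothetical flat [B, S, D]
/// buffer holding K (or V) for a single layer; `new_kv` is [B, D].
fn ragged_kv_write(cache: &mut [f32], new_kv: &[f32], pos: &[usize], s_max: usize, d: usize) {
    assert_eq!(cache.len(), pos.len() * s_max * d, "cache must be [B, S, D]");
    assert_eq!(new_kv.len(), pos.len() * d, "new_kv must be [B, D]");
    for (r, &p) in pos.iter().enumerate() {
        assert!(p < s_max, "row {} position out of range", r);
        // Constant row offset, dynamic position offset.
        let dst = (r * s_max + p) * d;
        cache[dst..dst + d].copy_from_slice(&new_kv[r * d..(r + 1) * d]);
    }
}
```

Because each row only touches its own [S, D] stripe, rows cannot overwrite each other regardless of how their positions differ.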

Validation (reference-equivalence)

Continuous batching has no single reference, so the gate is: B prompts of DIFFERENT lengths (truncations of the 46-token reference prompt), each run independently through the single-seq path, then run together in one ragged batch. Every slot must match its independent reference, proving per-row state is correct and a sequence is unaffected by its batch-mates.
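The gate can be sketched as two small helpers: one that derives the prompt lengths and one that compares streams. This is an illustrative sketch, not the driver's code. Both function names are hypothetical, and the stride of 3 between lengths is read off the results table rather than stated by the source.

```rust
/// Prompt lengths for `b` slots as successive truncations of one full prompt.
/// With full_len = 46 and stride = 3 this yields 46, 43, 40, 37, ...
fn truncation_lengths(full_len: usize, stride: usize, b: usize) -> Vec<usize> {
    assert!(b == 0 || (b - 1) * stride < full_len, "truncation would reach zero");
    (0..b).map(|i| full_len - i * stride).collect()
}

/// Compare each slot's ragged-batch token stream with its independent
/// single-seq reference. Returns Err((slot, step)) at the first divergence.
fn check_reference_equivalence(
    reference: &[Vec<u32>],
    ragged: &[Vec<u32>],
) -> Result<(), (usize, usize)> {
    assert_eq!(reference.len(), ragged.len(), "slot count mismatch");
    for (slot, (want, got)) in reference.iter().zip(ragged).enumerate() {
        // First differing token within the common prefix, if any.
        if let Some(step) = want.iter().zip(got).position(|(a, b)| a != b) {
            return Err((slot, step));
        }
        // Equal prefix but different lengths also counts as a mismatch.
        if want.len() != got.len() {
            return Err((slot, want.len().min(got.len())));
        }
    }
    Ok(())
}
```

Reporting the slot and step of the first divergence, rather than a bare boolean, makes it easier to tell a per-row state bug from cross-slot contamination.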

| Device | B | Prompt lengths | Result | ragged ms/step | agg tok/s | vs uniform-B |
| --- | --- | --- | --- | --- | --- | --- |
| CPU local-task | 4 | 46,43,40,37 | 4/4 slots 48/48 EXACT | 115 | 34.7 | +4.7% |
| CUDA GB10 | 4 | 46,43,40,37 | 4/4 slots 48/48 EXACT | 89 | 44.8 | +1.9% |
| CUDA GB10 | 8 | 46,43,40,37,34,31,28,25 | 8/8 slots 48/48 EXACT | 96 | 83.6 | +7.5% |
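The aggregate throughput column relates to the step time in a simple way if one assumes each decode step emits exactly one token per slot. That assumption, and the helper name below, are not stated in the source; this is only a sanity-check sketch.

```rust
/// Aggregate tokens per second for a batch of `batch` slots, assuming each
/// decode step produces one token per slot (an assumption, not a source fact).
fn aggregate_tok_per_s(batch: usize, ms_per_step: f64) -> f64 {
    batch as f64 * 1000.0 / ms_per_step
}
```

Under that assumption the three rows give roughly 34.8, 44.9, and 83.3 tok/s, within about 0.3 tok/s of the reported 34.7, 44.8, and 83.6. The source does not say how its figures were computed, so the small gaps may simply reflect rounded step times.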

Findings

  • Ragged decode is correct. Eight sequences of eight different lengths in one batch each reproduce their independent single-seq stream exactly; per-row pos/cache_len/mask and the per-row KV write are all right.
  • The per-row KV-write unroll is cheap. B*L dynamic_update_slices per step add only ~2% (B=4) to ~7.5% (B=8) over uniform-B on CUDA and +4.7% on CPU at B=4, and lower fine on CPU and CUDA (compile ~5-7.5s).
  • CPU large-B validation is reference-bound (Phase 1 runs B single-seq references sequentially), a harness cost, not a graph cost; the B sweep runs on CUDA.

Full tables in spike/iree-ffi/FINDINGS_ragged.md; the full Stage 2 plan in spike/openxla/STAGE2_DESIGN.md.

Reproduce

cd spike/iree-ffi
EMIT=../rust-emitter/target/release/emit; IC=./iree-dist/bin/iree-compile
F="--iree-input-type=stablehlo --iree-hal-target-device=local --iree-hal-local-target-device-backends=llvm-cpu"
$EMIT prefill-argmax p.mlir && $IC $F p.mlir -o p.vmfb
$EMIT decode-argmax s.mlir  && $IC $F s.mlir -o s.vmfb
$EMIT decode-ragged-argmax 4 r.mlir && $IC $F r.mlir -o r.vmfb
IREE_DIST=./iree-dist cargo run --release --bin llama_ragged -- \
  --batch 4 --device local-task --prefill p.vmfb --sdecode s.vmfb --decode r.vmfb

Next

2a-ii: the admit/evict scheduler (request queue, mid-stream prefill into a freed slot, eviction on EOS, slot recycling, staggered-arrival validation). Then 2b productizes the engine into mlxcel-xla.

Refs #449

The first half of Stage 2: a ragged decode_step where each row carries its own position and length, so sequences of different lengths can share one batch (the graph mechanic continuous batching needs). Validated by reference-equivalence, token-exact on CPU and CUDA. This is 2a-i (the graph); the admit/evict scheduler with staggered arrival (2a-ii) is the remaining Stage 2a piece. Spike-only, so CI is unaffected.

Emitter (spike/rust-emitter): emit_decode_ragged (CLI decode-ragged[-argmax] <B>) gives each row its own pos[B]/cache_len[B]; RoPE cos/sin become a per-row gather [B,d] by pos[B]; the key mask is per-row [B,S] (valid iff s <= cache_len[b]); and the KV write is unrolled per row, each row writing its new K/V at its own pos[b] via dynamic_update_slice (constant row/layer offsets, dynamic pos[r]; scatter stays avoided per the Phase 2a CUDA caveat). The attention contractions and the LM head are identical to the uniform-B graph; the per-row mask carries the raggedness.
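The per-row RoPE gather can be illustrated on the host as a row lookup into precomputed cos/sin tables. This is a sketch under stated assumptions: the function names are hypothetical, the table uses the common half-duplicated frequency layout, and the base is passed in by the caller because the source does not specify either.

```rust
/// Build cos/sin tables of shape [S, d], flattened row-major.
/// Assumes the half-duplicated layout: entries i and i + d/2 share a frequency.
fn rope_tables(s_max: usize, d: usize, base: f32) -> (Vec<f32>, Vec<f32>) {
    assert!(d > 0 && d % 2 == 0, "d must be a positive even number");
    let half = d / 2;
    let mut cos = Vec::with_capacity(s_max * d);
    let mut sin = Vec::with_capacity(s_max * d);
    for p in 0..s_max {
        for i in 0..d {
            let freq = base.powf(-((2 * (i % half)) as f32) / d as f32);
            let angle = p as f32 * freq;
            cos.push(angle.cos());
            sin.push(angle.sin());
        }
    }
    (cos, sin)
}

/// Per-row gather: row b of the [B, d] output is table row pos[b].
fn gather_rows(table: &[f32], d: usize, pos: &[usize]) -> Vec<f32> {
    let mut out = Vec::with_capacity(pos.len() * d);
    for &p in pos {
        out.extend_from_slice(&table[p * d..(p + 1) * d]);
    }
    out
}
```

In the uniform-B graph every row would read the same table row; here each row indexes with its own position, which is the only change RoPE needs for ragged batches.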

Shim (spike/iree-ffi/iree_gate.c): xla_llama_decode_ragged takes token/pos/cache_len as [B] arrays and threads the rank-5 KV. Prefill-into-slot uses a host KV mirror: xla_llama_ragged_reset, xla_llama_prefill_slot (single-seq prefill, d2h into the mirror slot), and xla_llama_commit (h2d the mirror into the resident rank-5 KV), reusing the scalar prefill vmfb.
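The reset / prefill-into-slot / commit flow can be modeled with a plain host buffer. This is a hypothetical Rust model, not the C shim: the struct, its methods, and the rank-5 layout [L, 2, B, S, D] (layer, K/V, slot, position, dim) are assumptions, and the device transfers are stood in for by slice copies.

```rust
/// Host-side mirror of the resident KV cache, assumed layout [L, 2, B, S, D].
struct HostKvMirror {
    layers: usize,
    slots: usize,
    s_max: usize,
    d: usize,
    data: Vec<f32>,
}

impl HostKvMirror {
    fn new(layers: usize, slots: usize, s_max: usize, d: usize) -> Self {
        let data = vec![0.0; layers * 2 * slots * s_max * d];
        HostKvMirror { layers, slots, s_max, d, data }
    }

    /// Clear every slot before a new batch is assembled.
    fn reset(&mut self) {
        self.data.fill(0.0);
    }

    /// Flat offset of the [S, D] stripe for (layer, kv, slot).
    fn offset(&self, layer: usize, kv: usize, slot: usize) -> usize {
        ((layer * 2 + kv) * self.slots + slot) * self.s_max * self.d
    }

    /// Copy one single-seq prefill result (assumed layout [L, 2, len, D])
    /// into `slot`; this stands in for the d2h copy into the mirror.
    fn prefill_slot(&mut self, slot: usize, seq_kv: &[f32], len: usize) {
        assert!(slot < self.slots && len <= self.s_max);
        assert_eq!(seq_kv.len(), self.layers * 2 * len * self.d);
        let n = len * self.d;
        for layer in 0..self.layers {
            for kv in 0..2 {
                let src = (layer * 2 + kv) * n;
                let dst = self.offset(layer, kv, slot);
                self.data[dst..dst + n].copy_from_slice(&seq_kv[src..src + n]);
            }
        }
    }

    /// Copy the whole mirror into the "device" buffer; stands in for h2d.
    fn commit(&self, device_kv: &mut [f32]) {
        device_kv.copy_from_slice(&self.data);
    }
}
```

Staging on the host lets the existing scalar prefill graph be reused unchanged, at the cost of one extra round trip per admitted sequence.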

Driver spike/iree-ffi/src/bin/llama_ragged.rs validates by reference-equivalence: Phase 1 captures one single-seq reference per prompt for B prompts of different lengths (truncations of the reference prompt); Phase 2 prefills them into slots and runs the ragged decode with every slot at its own position; each slot's stream must match its independent reference.

Results, all reference-exact 48/48: CPU local-task B=4 (lengths 46,43,40,37) 115 ms/step, 34.7 tok/s; CUDA GB10 B=4 89 ms/step, 44.8 tok/s; CUDA B=8 (lengths 46..25) 96 ms/step, 83.6 tok/s. The per-row KV-write unroll adds only ~2% (B=4) to ~7.5% (B=8) over the uniform-B graph and lowers fine on both targets (compile ~5-7.5s). A sequence's output is invariant to its batch-mates and their lengths.

Full results and tables are in spike/iree-ffi/FINDINGS_ragged.md; the full Stage 2 plan (decisions: spike-first, a common BatchEngine trait, contiguous-per-slot KV) is in spike/openxla/STAGE2_DESIGN.md. Next is 2a-ii (the admit/evict scheduler).

Refs #449
inureyes added the type:enhancement, status:review, priority:medium, and area:architecture labels on Jun 28, 2026
inureyes merged commit c2c76a7 into main on Jun 28, 2026
5 checks passed
inureyes deleted the feature/issue-449-ragged-decode branch on June 28, 2026 12:34
inureyes added the status:done label and removed the status:review label on Jun 28, 2026