feat: uniform-B batched decode graph + validation (#449 M3 Stage 1) #462
Merged
Stage 1 of the #449 throughput milestone (M3). Batch-1 inference uses a small fraction of the device (the GB10 GPU runs batch-1 at ~10% of bandwidth, launch-bound), so batching is the direct throughput lever. This adds a uniform-B (lockstep) static batched decode path to the OpenXLA spike and validates it token-exact with strong aggregate tok/s scaling on both CPU and CUDA. Spike-only: no mlxcel crate changes, so CI is unaffected.

Emitter (spike/rust-emitter): emit_decode_batched (CLI decode-batch[-argmax] <B>) threads a leading batch dim B through the whole decode, with a rank-5 KV cache [B,L,S,nkv,d], two-batch-dim dot_general for the GQA score/context contractions, embedding via gather, and a shared scalar pos/cache_len/key-mask (the lockstep assumption). argmax_batched ([B,V] -> [B] i32) shares its reducer block with the scalar argmax.

Shim (spike/iree-ffi/iree_gate.c): xla_llama_prefill_batch runs the single-seq prefill once and tiles its rank-4 KV across B rows into a resident rank-5 cache, so the scalar prefill vmfb is reused and no batched prefill graph is needed for Stage 1. xla_llama_decode_batch threads the rank-5 KV and returns one token per row (on-device argmax, or host argmax of [B,V] logits).

Driver: spike/iree-ffi/src/bin/llama_batch.rs, plus the sweep spike/iree-ffi/validate_batch.sh.

Two gates pass at every B in 1..64 on CPU and CUDA:
token-exact: B identical rows each reproduce the 48-token HF temp-0 reference, and all rows are identical.
independence: row 0 is seeded real while rows >= 1 are perturbed over the same prompt KV, so row 0 stays exact while row 1 diverges, proving the batch dim is genuinely independent.
B=1 batched reproduces the scalar baseline exactly.

Aggregate tok/s (B * steps / wall): CPU local-task 2.2 (B1) to 69.2 (B32), about 31x, saturating near B16-32; CUDA GB10 5.6 (B1) to 220.7 (B64), about 39x and still climbing.
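The shim-side steps and the two gates described above can be illustrated with a small host-only sketch. It is not the code in iree_gate.c or llama_batch.rs: the function names tile_kv_rows, host_argmax_rows, token_exact and independent are hypothetical, and the buffers are assumed to be flat row-major f32 slices rather than device-resident tensors.

```rust
/// Tile a single-sequence KV cache (flattened [L, S, nkv, d]) across `b` rows,
/// giving a flattened row-major [B, L, S, nkv, d] buffer.
/// Hypothetical helper; the real shim tiles into a resident device cache.
pub fn tile_kv_rows(kv: &[f32], b: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(kv.len() * b);
    for _ in 0..b {
        out.extend_from_slice(kv);
    }
    out
}

/// Host-side argmax over row-major [B, V] logits, one token id per row.
/// Ties resolve to the lowest index (an assumption of this sketch).
pub fn host_argmax_rows(logits: &[f32], b: usize, v: usize) -> Vec<i32> {
    assert!(v > 0, "vocabulary size must be non-zero");
    assert_eq!(logits.len(), b * v, "logits must have shape [B, V]");
    logits
        .chunks_exact(v)
        .map(|row| {
            let mut best = 0usize;
            for (i, &x) in row.iter().enumerate() {
                if x > row[best] {
                    best = i;
                }
            }
            best as i32
        })
        .collect()
}

/// Token-exact gate: every row reproduces the reference token sequence.
pub fn token_exact(rows: &[Vec<i32>], reference: &[i32]) -> bool {
    !rows.is_empty() && rows.iter().all(|r| r.as_slice() == reference)
}

/// Independence gate: row 0 matches the reference while the perturbed row 1 diverges.
pub fn independent(rows: &[Vec<i32>], reference: &[i32]) -> bool {
    rows.len() >= 2
        && rows[0].as_slice() == reference
        && rows[1].as_slice() != reference
}
```

Because every row starts from the same tiled prompt KV, identical rows should decode to identical tokens, which is what the token-exact gate checks; the independence gate then confirms that perturbing one row does not leak into row 0.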
Two findings. First, the batch-1 to batch-2 jump is large on both CPU and CUDA because batch-1 is pathologically underutilized (on CPU the leading-1 dim defeats vectorization/threading; the GPU is bandwidth/launch-starved). Second, the untuned IREE-CUDA codegen is non-monotonic (B=16 reproducibly hits a catastrophic kernel at 1125 ms/step while B=8/24/64 are fine), so productionizing a fixed B needs codegen tuning or empirical selection. Full results and tables are in spike/iree-ffi/FINDINGS_batch.md.

Next is Stage 2 (continuous batching: ragged per-sequence cache_len/positions, a paged-KV block table, an admit/evict scheduler, and a multi-sequence session beyond the single-sequence InferenceSession contract).

Refs #449
This was referenced Jun 28, 2026
Summary
Stage 1 of the #449 throughput milestone (M3). batch-1 inference uses a small fraction of the device (the GB10 GPU runs batch-1 at ~10% of bandwidth, launch-bound), so batching is the direct throughput lever. This adds a uniform-B (lockstep) static batched decode path to the OpenXLA spike and validates it token-exact with strong aggregate tok/s scaling on both CPU and CUDA.
Spike-only (spike/rust-emitter, spike/iree-ffi): no mlxcel crate changes, so CI is unaffected (the spike is outside the Cargo workspace).

What is in it
Emitter: emit_decode_batched (CLI decode-batch[-argmax] <B>): a leading batch dim B threaded through the whole decode, a rank-5 KV cache [B,L,S,nkv,d], two-batch-dim dot_general for the GQA score/context contractions, embedding via gather, and a shared scalar pos/cache_len/key-mask (the lockstep assumption). Plus Builder::argmax_batched ([B,V] -> [B] i32, sharing its reducer block with the scalar argmax).
Shim: xla_llama_prefill_batch (runs the single-seq prefill once and tiles its rank-4 KV across B rows into a resident rank-5 cache, reusing the scalar prefill vmfb so no batched prefill graph is needed) plus xla_llama_decode_batch.
Driver: src/bin/llama_batch.rs and the sweep validate_batch.sh.

Validation
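Two gates pass at every B in 1..64 on both CPU and CUDA:
token-exact: B identical rows each reproduce the 48-token HF temp-0 reference, and all rows are identical.
independence: row 0 is seeded real while rows >= 1 are perturbed over the same prompt KV, so row 0 stays exact while row 1 diverges.

Aggregate tok/s (B * steps / wall): CPU ~31x (saturates near B16-32, compute-bound); CUDA ~39x at B64 and still climbing.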
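To make the throughput arithmetic concrete, here is a minimal sketch of the aggregate tok/s formula and of picking a batch size from measured step latencies. The function names are hypothetical and not part of the spike; the only figures used in the note below are the ones reported in this PR.

```rust
/// Aggregate throughput in tokens per second: `b` rows each emit one token
/// per step, so `b * steps` tokens are produced in `wall_secs` seconds.
pub fn aggregate_tok_s(b: usize, steps: usize, wall_secs: f64) -> f64 {
    (b * steps) as f64 / wall_secs
}

/// Scaling factor of a batched run relative to the batch-1 baseline.
pub fn speedup(baseline_tok_s: f64, batched_tok_s: f64) -> f64 {
    batched_tok_s / baseline_tok_s
}

/// Step latency in milliseconds implied by an aggregate tok/s figure at
/// batch size `b` (each step emits `b` tokens).
pub fn ms_per_step(b: usize, tok_s: f64) -> f64 {
    1000.0 * b as f64 / tok_s
}

/// Given measured (B, ms/step) pairs, return the batch size with the highest
/// aggregate tok/s together with that throughput. Returns None on empty input.
pub fn best_batch(measured: &[(usize, f64)]) -> Option<(usize, f64)> {
    measured
        .iter()
        .map(|&(b, ms)| (b, 1000.0 * b as f64 / ms))
        .max_by(|x, y| x.1.total_cmp(&y.1))
}
```

With the reported figures, speedup(2.2, 69.2) is roughly 31.5 on CPU and speedup(5.6, 220.7) is roughly 39.4 on CUDA, consistent with the rounded ~31x and ~39x. Selecting B from measurements rather than assuming monotonic scaling matters here: a CUDA step that takes 1125 ms at B=16 yields only about 14.2 tok/s, far below the 220.7 tok/s at B=64.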
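Findings
The batch-1 to batch-2 jump is large on both CPU and CUDA because batch-1 is pathologically underutilized (on CPU the leading-1 dim defeats vectorization/threading; the GPU is bandwidth/launch-starved).
The untuned IREE-CUDA codegen is non-monotonic (B=16 reproducibly hits a catastrophic kernel at 1125 ms/step while B=8/24/64 are fine), so productionizing a fixed B needs codegen tuning or empirical selection.

Full tables and analysis in spike/iree-ffi/FINDINGS_batch.md.

Reproduce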
Next
Stage 2 (continuous batching): ragged per-sequence cache_len/positions, a paged-KV block table, an admit/evict scheduler, and a multi-sequence session beyond the single-sequence InferenceSession contract. Then Stage 3 (int4 dequant fusion).

Refs #449