feat: uniform-B batched decode graph + validation (#449 M3 Stage 1) #462

Merged: inureyes merged 1 commit into main from feature/issue-449-batched-decode on Jun 28, 2026

@inureyes

Summary

Stage 1 of the #449 throughput milestone (M3). batch-1 inference uses a small fraction of the device (the GB10 GPU runs batch-1 at ~10% of bandwidth, launch-bound), so batching is the direct throughput lever. This adds a uniform-B (lockstep) static batched decode path to the OpenXLA spike and validates it token-exact with strong aggregate tok/s scaling on both CPU and CUDA.

Spike-only (spike/rust-emitter, spike/iree-ffi): no mlxcel crate changes, so CI is unaffected (the spike is outside the Cargo workspace).

What is in it

  • Emitter emit_decode_batched (CLI decode-batch[-argmax] <B>): a leading batch dim B threaded through the whole decode, a rank-5 KV cache [B,L,S,nkv,d], two-batch-dim dot_general for the GQA score/context contractions, embedding via gather, and a shared scalar pos/cache_len/key-mask (the lockstep assumption). Plus Builder::argmax_batched ([B,V] -> [B] i32, sharing its reducer block with the scalar argmax).
  • Shim xla_llama_prefill_batch (runs the single-seq prefill once and tiles its rank-4 KV across B rows into a resident rank-5 cache, reusing the scalar prefill vmfb so no batched prefill graph is needed) plus xla_llama_decode_batch.
  • Driver src/bin/llama_batch.rs and the sweep validate_batch.sh.
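The prefill tiling described for the shim can be pictured on a flat host buffer. This is only a minimal sketch, assuming a row-major layout; `tile_kv` and `kv_offset` are hypothetical names for illustration, and the actual shim works on resident device buffers rather than a `Vec<f32>`.

```rust
/// Tile a single-sequence KV cache of shape [L, S, nkv, d] (flattened,
/// row-major) into a batched cache of shape [B, L, S, nkv, d] by
/// repeating it once per batch row. Hypothetical host-side illustration.
fn tile_kv(single: &[f32], batch: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(single.len() * batch);
    for _ in 0..batch {
        out.extend_from_slice(single);
    }
    out
}

/// Flat offset of element idx = (b, l, s, h, d) in a row-major cache
/// whose shape is dims = [B, L, S, nkv, d].
fn kv_offset(dims: [usize; 5], idx: [usize; 5]) -> usize {
    let mut off = 0;
    for i in 0..5 {
        debug_assert!(idx[i] < dims[i]);
        off = off * dims[i] + idx[i];
    }
    off
}
```

Because the batch dim is outermost, each row's cache is a contiguous copy of the single-sequence cache, which is why one scalar prefill can seed all B rows.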

Validation

Two gates pass at every B in 1..64 on both CPU and CUDA:

  1. token-exact: B identical rows each reproduce the 48-token HF temp-0 reference, all rows byte-identical.
  2. independence: row 0 seeded from the real first token while rows >= 1 are perturbed over the same prompt KV, so row 0 stays 48/48 while row 1 diverges (the batch dim is genuinely independent, not collapsed or averaged). B=1 batched reproduces the scalar baseline exactly.
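The two gates can be sketched as predicates over the decoded rows. This is an illustrative sketch, not the logic in llama_batch.rs: the function names are hypothetical, and treating the independence gate as trivially satisfied at B=1 (no perturbed row exists) is an assumption.

```rust
/// Gate 1 (token-exact): every row must equal the reference sequence.
fn token_exact(rows: &[Vec<u32>], reference: &[u32]) -> bool {
    !rows.is_empty() && rows.iter().all(|r| r.as_slice() == reference)
}

/// Gate 2 (independence): row 0, seeded with the real first token, must
/// match the reference, while row 1 (perturbed) must differ from it.
/// Assumption: with a single row there is nothing to perturb, so only
/// the row-0 check applies.
fn independent(rows: &[Vec<u32>], reference: &[u32]) -> bool {
    match rows.split_first() {
        Some((row0, rest)) => {
            row0.as_slice() == reference
                && rest.first().map_or(true, |row1| row1.as_slice() != reference)
        }
        None => false,
    }
}
```

If the batch dim were collapsed or averaged, the perturbed rows would either drag row 0 off the reference or fail to diverge themselves, so the second predicate would fail.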

Aggregate tok/s (B * steps / wall):

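| B | CPU local-task (tok/s) | CUDA GB10 (tok/s) |
|---|---|---|
| 1 | 2.20 | 5.61 |
| 2 | 23.7 | 38.6 |
| 4 | 36.4 | 45.7 |
| 8 | 54.9 | 89.9 |
| 16 | 64.3 | 14.2 (codegen cliff) |
| 24 | n/a | 195.7 |
| 32 | 69.2 | 147.1 |
| 64 | n/a | 220.7 |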

Scaling over batch-1: CPU ~31x at B32 (saturates near B16-32, compute-bound); CUDA ~39x at B64 and still climbing.
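The scaling figures follow from the table by simple arithmetic, sketched here with hypothetical helper names. The same relation gives the per-step latency implied by a tok/s entry: the B=16 CUDA figure of 14.2 tok/s works out to roughly 1.13 s per step.

```rust
/// Aggregate throughput in tokens per second: B * steps / wall.
fn aggregate_tok_s(batch: usize, steps: usize, wall_secs: f64) -> f64 {
    (batch * steps) as f64 / wall_secs
}

/// Per-step latency in milliseconds implied by an aggregate tok/s figure.
/// Each lockstep decode step emits one token per row, so B tokens per step.
fn ms_per_step(batch: usize, tok_s: f64) -> f64 {
    batch as f64 / tok_s * 1000.0
}

/// Speedup of a batched run over the batch-1 baseline.
fn speedup(tok_s: f64, baseline_tok_s: f64) -> f64 {
    tok_s / baseline_tok_s
}
```

For example, `speedup(69.2, 2.20)` is about 31.5 and `speedup(220.7, 5.61)` is about 39.3, matching the rounded factors quoted above.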

Findings

  • The batch-1 to batch-2 jump is large on both devices: batch-1 is pathologically underutilized (CPU leading-1 dim defeats vectorization/threading; GPU is bandwidth/launch-starved), so the "1x" baseline is genuinely bad and even B=2 is a big win.
  • The untuned IREE-CUDA codegen is non-monotonic: B=16 is a reproducible catastrophic kernel (1125 ms/step, 12x slower than B=8) while B=8/24/64 are fine and B=64 is the peak. A codegen-selection artifact, not a hardware or approach limit; productionizing a fixed B needs codegen tuning or empirical selection (do not assume a smooth curve).
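One way to act on the "empirical selection" point is to choose B from measured throughput rather than from an assumed curve. This is a minimal sketch with a hypothetical `best_batch` helper; the constant simply restates the CUDA GB10 column of the table above.

```rust
/// Measured aggregate tok/s per batch size (CUDA GB10 column of the table).
const CUDA_GB10: [(usize, f64); 8] = [
    (1, 5.61),
    (2, 38.6),
    (4, 45.7),
    (8, 89.9),
    (16, 14.2),
    (24, 195.7),
    (32, 147.1),
    (64, 220.7),
];

/// Pick the batch size with the highest measured throughput among the
/// candidates that do not exceed `max_batch`. Returns None if none fit.
fn best_batch(measured: &[(usize, f64)], max_batch: usize) -> Option<usize> {
    measured
        .iter()
        .filter(|(b, _)| *b <= max_batch)
        .max_by(|a, b| a.1.total_cmp(&b.1))
        .map(|(b, _)| *b)
}
```

With a cap of 16 this picks B=8 rather than B=16, and with a cap of 32 it picks B=24 rather than B=32, which is exactly the behavior a "bigger is better" heuristic would get wrong on this data.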

Full tables and analysis in spike/iree-ffi/FINDINGS_batch.md.

Reproduce

cd spike/iree-ffi
./validate_batch.sh cpu       # local-task sweep via the prebuilt dist
./validate_batch.sh cuda      # cuda sweep via the source-built runtime

Next

Stage 2 (continuous batching): ragged per-sequence cache_len/positions, a paged-KV block table, an admit/evict scheduler, and a multi-sequence session beyond the single-sequence InferenceSession contract. Then Stage 3 (int4 dequant fusion).

Refs #449

Stage 1 of the #449 throughput milestone (M3). batch-1 inference uses a small fraction of the device (the GB10 GPU runs batch-1 at ~10% of bandwidth, launch-bound), so batching is the direct throughput lever. This adds a uniform-B (lockstep) static batched decode path to the OpenXLA spike and validates it token-exact with strong aggregate tok/s scaling on both CPU and CUDA. Spike-only: no mlxcel crate changes, so CI is unaffected.

Emitter (spike/rust-emitter): emit_decode_batched (CLI decode-batch[-argmax] <B>) threads a leading batch dim B through the whole decode, with a rank-5 KV cache [B,L,S,nkv,d], two-batch-dim dot_general for the GQA score/context contractions, embedding via gather, and a shared scalar pos/cache_len/key-mask (the lockstep assumption). argmax_batched ([B,V] -> [B] i32) shares its reducer block with the scalar argmax.
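As a plain-loop reference for what a two-batch-dim contraction computes in the GQA score step, here is a sketch under assumed layouts: q as [B, nkv, g, d] with g query heads per KV head, and one layer's keys as [B, S, nkv, d]. The function name and the q layout are assumptions for illustration, and the 1/sqrt(d) scaling and key mask are left out.

```rust
/// Reference GQA score contraction with two batch dims (B and nkv),
/// contracting over d.
/// q: [B, nkv, g, d], k: [B, S, nkv, d], result: [B, nkv, g, S],
/// all flattened row-major. Scaling and masking are omitted.
fn gqa_scores(
    q: &[f32],
    k: &[f32],
    b: usize,
    s: usize,
    nkv: usize,
    g: usize,
    d: usize,
) -> Vec<f32> {
    assert_eq!(q.len(), b * nkv * g * d);
    assert_eq!(k.len(), b * s * nkv * d);
    let mut out = vec![0.0f32; b * nkv * g * s];
    for bi in 0..b {
        for h in 0..nkv {
            for gi in 0..g {
                for si in 0..s {
                    let q_base = ((bi * nkv + h) * g + gi) * d;
                    let k_base = ((bi * s + si) * nkv + h) * d;
                    let mut acc = 0.0f32;
                    for di in 0..d {
                        acc += q[q_base + di] * k[k_base + di];
                    }
                    out[((bi * nkv + h) * g + gi) * s + si] = acc;
                }
            }
        }
    }
    out
}
```

Row `bi` of the output reads only row `bi` of q and k, which is the property the independence gate checks end to end.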

Shim (spike/iree-ffi/iree_gate.c): xla_llama_prefill_batch runs the single-seq prefill once and tiles its rank-4 KV across B rows into a resident rank-5 cache, so the scalar prefill vmfb is reused and no batched prefill graph is needed for Stage 1. xla_llama_decode_batch threads the rank-5 KV and returns one token per row (on-device argmax, or host argmax of [B,V] logits).
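The host-side fallback (argmax over [B,V] logits) amounts to a row-wise reduction. A minimal sketch, assuming row-major logits and lowest-index tie-breaking; `host_argmax_rows` is a hypothetical name, and the tie rule is an assumption rather than something the shim documents here.

```rust
/// Row-wise argmax over row-major [B, V] logits, returning [B] i32 token ids.
/// Ties resolve to the lowest index (assumed tie rule).
fn host_argmax_rows(logits: &[f32], vocab: usize) -> Vec<i32> {
    assert!(vocab > 0 && logits.len() % vocab == 0);
    logits
        .chunks_exact(vocab)
        .map(|row| {
            let mut best = 0usize;
            for (i, &v) in row.iter().enumerate() {
                // Strict comparison keeps the earliest maximum on ties.
                if v > row[best] {
                    best = i;
                }
            }
            best as i32
        })
        .collect()
}
```

The on-device variant avoids copying B * V logits back to the host each step, returning only B token ids.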

Driver spike/iree-ffi/src/bin/llama_batch.rs plus the sweep spike/iree-ffi/validate_batch.sh. Two gates pass at every B in 1..64 on CPU and CUDA: token-exact (B identical rows each reproduce the 48-token HF temp-0 reference, all rows identical) and independence (row 0 seeded real while rows >= 1 are perturbed over the same prompt KV, so row 0 stays exact while row 1 diverges, proving the batch dim is genuinely independent). B=1 batched reproduces the scalar baseline exactly.

Aggregate tok/s (B * steps / wall): CPU local-task 2.2 (B1) to 69.2 (B32), about 31x, saturating near B16-32; CUDA GB10 5.6 (B1) to 220.7 (B64), about 39x and still climbing. Two findings: the batch-1 to batch-2 jump is large on both because batch-1 is pathologically underutilized (CPU leading-1 dim defeats vectorization/threading, GPU is bandwidth/launch-starved); and the untuned IREE-CUDA codegen is non-monotonic (B=16 is a reproducible catastrophic kernel at 1125 ms/step while B=8/24/64 are fine), so productionizing a fixed B needs codegen tuning or empirical selection.

Full result and tables in spike/iree-ffi/FINDINGS_batch.md. Next is Stage 2 (continuous batching: ragged per-sequence cache_len/positions, paged-KV block table, an admit/evict scheduler, a multi-sequence session beyond the single-sequence InferenceSession contract).

Refs #449
@inureyes added the type:enhancement, status:review, priority:medium, and area:architecture labels on Jun 28, 2026
@inureyes merged commit e5a8dd4 into main on Jun 28, 2026
5 checks passed
@inureyes deleted the feature/issue-449-batched-decode branch on June 28, 2026, 11:08
@inureyes added the status:done label and removed the status:review label on Jun 28, 2026