feat: productize the continuous-batching engine (#449 M3 Stage 2b)#469
Merged
Conversation
Productize the Stage 2a spike into mlxcel-xla as XlaBatchEngine: B_max slots
share one rank-5 KV cache and serve a request stream, advancing every active
slot one token per step through the ragged decode graph. Replaces the spike's
host-mirror admit (a full d2h+h2d KV round-trip) with a DEVICE-SIDE slot write:
a new request's prompt KV is copied into just its slot's region via
iree_hal_device_transfer_d2d, leaving live slots untouched.
- C shim (csrc/xla_iree.c): rank-5 KV state + xla_llama_ragged_reset /
xla_llama_prefill_slot (device-side d2d) / xla_llama_decode_ragged; the
create signature is unchanged.
- Assets: bundled ragged decode graphs for B_max in {4, 8} (emitter
decode-ragged-argmax), selected by b_max at load.
- iree.rs: ragged FFI + IreeRaggedLlama owner; shared create_ctx /
compile_prefill_and helpers; IreeLlama::prefill_first for references.
- batch.rs: XlaBatchEngine (submit / pump / cancel, per-request EngineEvents,
greedy), a backend-neutral Scheduler split out for unit tests, and
XlaReferenceEngine for validation.
- examples/xla_batch_bench.rs (required-features = xla-iree): the
reference-equivalence + throughput harness.
Validated reference-exact (every request matches its independent single-seq
reference) on CPU local-task (B=4 N=8, B=8 N=10) and CUDA GB10 (B=4 N=8,
B=8 N=16), with mid-stream admit and slot recycling. Backend-neutral at the
request level, so the Stage 2c BatchEngine trait + server adapter wrap it
unchanged; supports_batched_serving() stays false until 2c.
Default and CI builds are unaffected: the engine is behind the iree feature
and the example behind xla-iree.
This was referenced Jun 28, 2026
Merged
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Stage 2b of the OpenXLA/IREE throughput milestone (#449 M3): productize the Stage 2a spike into
mlxcel-xlaasXlaBatchEngine, a continuous-batching engine.B_maxslots share one rank-5 KV cache and serve a request stream; every active slot advances one token per step through the ragged decode graph, and requests of different lengths join and leave the batch at different times.The headline change over the spike: admit is now a device-side slot write. The Stage 2a scheduler (PR #468) admitted a request with a full host-mirror round-trip (d2h the whole rank-5 KV, overwrite one slot, h2d back). This PR replaces that with
iree_hal_device_transfer_d2dof only the admitted slot's region, so live slots are never moved off-device. It is both cheaper and simpler (no refresh/commit), and inherently non-disturbing because only one slot's bytes change.What's here
csrc/xla_iree.c): rank-5 KV state +xla_llama_ragged_reset/xla_llama_prefill_slot(device-side d2d slot write) /xla_llama_decode_ragged. Thexla_llama_createsignature is unchanged (the spike's vestigialvocabarg is dropped).B_max ∈ {4, 8}(emitterdecode-ragged-argmax,module @decode_stepso the shim'sdecode_step.mainresolves them). Selected byb_maxat load; more slot counts = regenerate an asset (assets README).iree.rs: ragged FFI +IreeRaggedLlamaowner; sharedcreate_ctx/compile_prefill_andhelpers (dedup with the single-seq path);IreeLlama::prefill_firstfor the reference path.batch.rs:XlaBatchEngine(submit/pump/cancel, per-requestEngineEvents, greedy), a backend-neutralSchedulersplit out so its admit/evict/cancel bookkeeping is unit-tested without a device, andXlaReferenceEnginefor validation.examples/xla_batch_bench.rs(required-features = ["xla-iree"]): the reference-equivalence + throughput harness.Validation
Reference-equivalence gate (the Stage 2a gate, now over the productized engine + device-side admit): every request's batched stream must equal its independent single-sequence reference, regardless of when it was admitted or which peers shared its batch. All reference-exact, with mid-stream admit and slot recycling exercised in every run:
local-tasklocal-task(tok/s is batched vs the same bench's sequential single-seq baseline; both bundled assets validated on both devices.)
Gates:
cargo fmtclean;cargo clippy -D warningsclean (no-features,iree, and the example);cargo test -p mlxcel-xla10 pass (incl. theSchedulerlogic). Three build configs compile: no-features (CI-equivalent, engine absent),--features iree, and thexla-ireeexample.Scope / non-goals
BatchEnginetrait + server adapter wrap it unchanged.XlaBackend::supports_batched_serving()staysfalseuntil 2c wires it in.B_maxfrom the bundled buckets; contiguous per-slot KV. Sampling, runtime bucket selection, paged KV, and chunked prefill are Stage 2c/2d.ireefeature and the example behindxla-iree.Refs #449 (epic). Follows #462 (Stage 1), #466 (2a-i), #468 (2a-ii). Design:
spike/openxla/STAGE2_DESIGN.md.