feat: wire OpenXLA/IREE execution into mlxcel-xla and the CLI (#449 Phase 3 M2) #460
Merged
…hase 3 M2)

Phase 3 M2 takes the OpenXLA backend from the M1 scaffold and the M2 FFI gate to real, token-exact execution driven end to end from the mlxcel CLI.

The C shim (proven first in spike/iree-ffi, then vendored to src/lib/mlxcel-xla/csrc/xla_iree.c) loads the #451-emitted prefill and decode_step vmfbs into one IREE session, uploads the 146 model weights once as resident device buffers (bf16 to f32 via the safetensors crate, in the emitter's arg order), threads the KV cache across steps, and returns a token id per step. The output is read by element type, so the same shim drives both the logits-returning graph (host argmax) and the on-device-argmax graph (a 4-byte scalar readback, the Phase 2b pattern). Rust drives the real Llama-3.2-1B token-exact (48/48) against the HF temp-0 reference on CPU.

Real execution sits behind a new mlxcel-xla `iree` feature, exposed at the root as `xla-iree = [xla-backend, mlxcel-xla/iree]`. Without it, `--features xla-backend` still builds with no IREE distribution, so CI is unchanged and prefill/decode_step return a clear built-without-iree error. mlxcel-xla/build.rs compiles the shim against the IREE dist headers; the runtime link recipe (whole-archive of the unified runtime plus a flatcc and libgcc/libm group) lives in a new root build.rs, because a dependency's cargo:rustc-link-arg does not propagate to the final binary link.

The #451 emitter gains on-device argmax (Builder::argmax, the index-tracking two-operand reduce) behind a sample flag on emit_decode and emit_prefill, and the prefill bucket is bumped to MAX_SEQ (256) so the default chat-template prompt fits one bucket; both stay token-exact. The bundled graphs (assets/llama-3.2-1b) are compiled to vmfbs at session load by the dist iree-compile and cached by content hash. ComputeBackend::create_session now takes the model directory path so a session-driven backend can load its own weights and config (the MLX backend ignores it).
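To make the two output paths concrete, here is a minimal sketch of an index-tracking argmax written as a reduce over (value, index) pairs, plus a reader that picks the token id from either a logits vector or a 4-byte scalar. It is not the emitter's or the shim's code: `StepOutput`, `argmax_combine`, `host_argmax`, and `token_id` are hypothetical names, the lower-index tie-break is an assumption, and the scalar is assumed to be a little-endian `i32`.

```rust
/// Combiner for an index-tracking reduce: each operand carries a value and
/// the index it came from. Keeps the larger value; on a tie, keeps the lower
/// index (an assumed tie-break). A NaN never wins a comparison here.
fn argmax_combine(a: (f32, u32), b: (f32, u32)) -> (f32, u32) {
    if b.0 > a.0 || (b.0 == a.0 && b.1 < a.1) {
        b
    } else {
        a
    }
}

/// Host-side argmax over a logits slice, expressed with the same combiner.
/// Returns None for an empty slice.
fn host_argmax(logits: &[f32]) -> Option<u32> {
    logits
        .iter()
        .copied()
        .zip(0u32..)
        .reduce(argmax_combine)
        .map(|(_, index)| index)
}

/// Hypothetical view of one step's output, distinguished by element type.
enum StepOutput {
    /// Logits-returning graph: the host still has to run argmax.
    Logits(Vec<f32>),
    /// On-device-argmax graph: a 4-byte scalar (assumed little-endian i32).
    TokenScalar([u8; 4]),
}

/// Turn either kind of output into a token id.
fn token_id(output: &StepOutput) -> Option<u32> {
    match output {
        StepOutput::Logits(logits) => host_argmax(logits),
        StepOutput::TokenScalar(bytes) => {
            let value = i32::from_le_bytes(*bytes);
            if value >= 0 {
                Some(value as u32)
            } else {
                None
            }
        }
    }
}
```

Carrying the index alongside the value is what lets the reduction run on the device, so only the winning index has to be read back instead of the full logits vector.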
run_generate_once gains a self-contained XLA branch, taken when MLXCEL_BACKEND=xla, that skips load_model (which the XLA backend rejects) and drives the session's own greedy loop, then rejoins the shared decode and print path. End to end, MLXCEL_BACKEND=xla mlxcel generate produces coherent text through IREE at about 2 tok/s on CPU.

Scope: the bundled graphs are authored for Llama-3.2-1B-Instruct (config.json is verified and a mismatch errors), prompts are capped at the 256-token bucket, and sampling is greedy and text-only. The prebuilt aarch64 IREE dist registers local and vulkan HAL drivers but not CUDA, so M2 runs on CPU; a GPU device is a follow-up.

Validation: spike token-exact 48/48 (bucket 64 and 256); cargo build default, --features xla-backend, and --features xla-iree (with IREE_DIST); cargo test -p mlxcel-xla 5/5 and --features xla-backend backend:: 6/6; fmt and clippy clean; MLXCEL_BACKEND=xla mlxcel generate runs end to end.

Refs #449.
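The session-driven greedy loop and the prompt cap described in this message can be sketched roughly as follows. This is an illustration under assumptions, not the code in `run_generate_once`: the `GreedySession` trait, its method signatures, `generate_greedy`, the `eos` parameter, and the toy `CountingSession` are all hypothetical; only the names `prefill`, `decode_step`, and `MAX_SEQ` (256) come from the PR.

```rust
/// Prefill bucket size from the PR; longer prompts are rejected.
const MAX_SEQ: usize = 256;

/// Hypothetical session interface: each call returns the next token id,
/// and the session keeps the KV cache internally across calls.
trait GreedySession {
    fn prefill(&mut self, prompt: &[u32]) -> Result<u32, String>;
    fn decode_step(&mut self, token: u32) -> Result<u32, String>;
}

/// Greedy generation: one prefill, then one decode step per new token.
/// Stops after `max_new_tokens` tokens or once `eos` has been emitted.
fn generate_greedy<S: GreedySession>(
    session: &mut S,
    prompt: &[u32],
    max_new_tokens: usize,
    eos: u32,
) -> Result<Vec<u32>, String> {
    if prompt.is_empty() {
        return Err("prompt is empty".to_string());
    }
    if prompt.len() > MAX_SEQ {
        return Err(format!(
            "prompt has {} tokens; the prefill bucket holds {}",
            prompt.len(),
            MAX_SEQ
        ));
    }
    let mut generated = Vec::new();
    if max_new_tokens == 0 {
        return Ok(generated);
    }
    let mut next = session.prefill(prompt)?;
    loop {
        generated.push(next);
        if next == eos || generated.len() >= max_new_tokens {
            break;
        }
        next = session.decode_step(next)?;
    }
    Ok(generated)
}

/// Toy stand-in for a real session: always "predicts" the previous token + 1.
struct CountingSession;

impl GreedySession for CountingSession {
    fn prefill(&mut self, prompt: &[u32]) -> Result<u32, String> {
        match prompt.last() {
            Some(token) => Ok(token + 1),
            None => Err("prompt is empty".to_string()),
        }
    }

    fn decode_step(&mut self, token: u32) -> Result<u32, String> {
        Ok(token + 1)
    }
}
```

Because the session returns a token id per step, the caller never touches logits or weights, which is why a branch like this can skip the regular model-loading path and hand the collected ids back to shared decode and print code.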
Summary
Phase 3 M2 wires the OpenXLA backend (issue #449, ADR 0004 Track B) from the M1 scaffold and the M2 FFI gate to real, token-exact execution driven end to end from the `mlxcel` CLI.

- The C shim (proven first in `spike/iree-ffi`, then vendored to `src/lib/mlxcel-xla/csrc/xla_iree.c`) loads the #451-emitted `prefill`/`decode_step` vmfbs into one IREE session, uploads the 146 weights once as resident device buffers (bf16 to f32 via the `safetensors` crate, in the emitter's arg order), threads the KV cache across steps, and returns a token id per step. Output is read by element type, so the same shim drives both the logits graph (host argmax) and the on-device-argmax graph (a 4-byte scalar, the Phase 2b pattern). Rust drives the real Llama-3.2-1B token-exact (48/48) vs HF temp-0 on CPU.
- Real execution sits behind a new `mlxcel-xla` `iree` feature, exposed at the root as `xla-iree = ["xla-backend", "mlxcel-xla/iree"]`. Without it, `--features xla-backend` still builds with no IREE distribution (CI unchanged) and `prefill`/`decode_step` return a built-without-iree error.
- The #451 emitter gains on-device argmax (`Builder::argmax`, the index-tracking two-operand reduce) behind a `sample` flag, and the prefill bucket is bumped to `MAX_SEQ` (256) so the default chat-template prompt fits one bucket; both stay token-exact. Bundled graphs (`assets/llama-3.2-1b`) compile to vmfbs at session load via the dist `iree-compile`, cached by content hash.
- `ComputeBackend::create_session` now takes the model directory path (MLX ignores it; XLA loads its own weights and config).
- `run_generate_once` gains a self-contained XLA branch, taken when `MLXCEL_BACKEND=xla`, that skips `load_model` (which the XLA backend rejects) and drives the session's own greedy loop, then rejoins the shared decode and print path.

How to run (local)
Build gotcha (recorded)
A dependency's `cargo:rustc-link-arg` does not propagate to the final binary link, so `mlxcel-xla/build.rs` only compiles the shim (its `rustc-link-lib` propagates) and the IREE runtime link recipe (`--whole-archive` of the unified runtime plus a `--start-group` of flatcc and libgcc/libm) lives in a new root `build.rs` gated on `CARGO_FEATURE_XLA_IREE`. No-op for default and `xla-backend` builds.

Scope and limits (M2)
- The bundled graphs are authored for Llama-3.2-1B-Instruct (`config.json` is verified, a mismatch errors); prompts are capped at the 256-token bucket; sampling is greedy and text-only.
- The prebuilt aarch64 IREE dist registers `local`/`vulkan` HAL drivers but not CUDA, so M2 runs on CPU (about 2 tok/s). Phase 2b's CUDA used the pip runtime; a GPU device is a follow-up. (This corrects the earlier FINDINGS claim that the dist had a cuda driver.)

Validation
- `cargo build` default, `--features xla-backend`, and `--features xla-iree` (with `IREE_DIST`) all green.
- `cargo test -p mlxcel-xla` 5/5; `cargo test --features xla-backend --lib backend::` 6/6.
- `cargo fmt --check` and `cargo clippy` (default, `xla-backend`, `xla-iree`) clean.
- `MLXCEL_BACKEND=xla mlxcel generate` end to end (raw prompt and default chat template).

Refs #449.
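For readers unfamiliar with the token-exact figure quoted in this PR (48/48 against the HF temp-0 reference), a check of that kind could look roughly like the sketch below. The function names are hypothetical, and it assumes the score counts leading tokens that agree with the reference before the first divergence; the spike's actual harness is not shown in this PR.

```rust
/// Count how many leading tokens of `generated` agree with `reference`.
/// Returns (matched, reference length), e.g. (48, 48) for a full match
/// against a 48-token reference.
fn token_match(generated: &[u32], reference: &[u32]) -> (usize, usize) {
    let matched = generated
        .iter()
        .zip(reference)
        .take_while(|(g, r)| g == r)
        .count();
    (matched, reference.len())
}

/// Token-exact means the two sequences are identical, not merely similar.
fn is_token_exact(generated: &[u32], reference: &[u32]) -> bool {
    generated == reference
}
```

Counting only the matching prefix is deliberate in this sketch: with greedy decoding, every token after the first divergence is conditioned on a different history, so later agreement says little.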