feat: wire OpenXLA/IREE execution into mlxcel-xla and the CLI (#449 Phase 3 M2) #460

Merged: inureyes merged 1 commit into main from feat/449-m2-execution-wiring on Jun 27, 2026

@inureyes (Member)

Summary

Phase 3 M2 wires the OpenXLA backend (issue #449, ADR 0004 Track B) from the M1 scaffold and the M2 FFI gate to real, token-exact execution driven end to end from the mlxcel CLI.

  • The C shim (proven first in spike/iree-ffi, then vendored to src/lib/mlxcel-xla/csrc/xla_iree.c) loads the #451-emitted prefill/decode_step vmfbs into one IREE session, uploads the 146 weights once as resident device buffers (bf16 to f32 via the safetensors crate, in the emitter's arg order), threads the KV cache across steps, and returns a token id per step. Output is read by element type, so the same shim drives both the logits graph (host argmax) and the on-device-argmax graph (a 4-byte scalar, the Phase 2b pattern). Driven from Rust, the real Llama-3.2-1B is token-exact (48/48) against HF temp-0 on CPU.
  • Real execution is behind a new mlxcel-xla iree feature, exposed at the root as xla-iree = ["xla-backend", "mlxcel-xla/iree"]. Without it, --features xla-backend still builds with no IREE distribution (CI unchanged) and prefill/decode_step return a built-without-iree error.
  • The #451 emitter gains on-device argmax (Builder::argmax, the index-tracking two-operand reduce) behind a sample flag, and the prefill bucket is bumped to MAX_SEQ (256) so the default chat-template prompt fits one bucket; both stay token-exact. Bundled graphs (assets/llama-3.2-1b) compile to vmfbs at session load via the dist iree-compile, cached by content hash.
  • ComputeBackend::create_session now takes the model directory path (MLX ignores it; XLA loads its own weights and config). run_generate_once gains a self-contained XLA branch, taken when MLXCEL_BACKEND=xla, that skips load_model (which the XLA backend rejects) and drives the session's own greedy loop, then rejoins the shared decode and print path.
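
The content-hash caching of compiled vmfbs mentioned above can be sketched roughly as follows. This is an illustration only, not the PR's code: the helper name `cache_path`, the cache directory, and the file naming are hypothetical, and std's `DefaultHasher` stands in for whatever hash the real implementation uses (a persistent cache would more likely use a stable cryptographic hash such as SHA-256, since `DefaultHasher` output is not guaranteed stable across Rust releases).

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::path::{Path, PathBuf};

/// Hypothetical helper: derive the cached vmfb path from the graph's bytes,
/// so any change to the emitted graph text yields a different cache entry.
fn cache_path(cache_dir: &Path, name: &str, graph: &[u8]) -> PathBuf {
    let mut hasher = DefaultHasher::new();
    graph.hash(&mut hasher);
    cache_dir.join(format!("{name}-{:016x}.vmfb", hasher.finish()))
}

fn main() {
    // Assumed cache location, for illustration only.
    let dir = Path::new("/tmp/mlxcel-xla-cache");
    let path = cache_path(dir, "prefill", b"module @prefill {}");
    // A caller would run iree-compile only when `path` does not exist yet.
    println!("{}", path.display());
}
```

Keying on the graph content rather than a timestamp means an unchanged bundled graph is never recompiled, while an emitter change invalidates the entry automatically.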

How to run (local)

export IREE_DIST=/path/to/iree-dist-3.12.0rc20260626-linux-aarch64
cargo build --release --features xla-iree
MLXCEL_BACKEND=xla ./target/release/mlxcel generate -m <Llama-3.2-1B-Instruct> -p "..." -n 48

Build gotcha (recorded)

A dependency's cargo:rustc-link-arg does not propagate to the final binary link, so the work is split across two build scripts. mlxcel-xla/build.rs only compiles the shim (its rustc-link-lib does propagate), and the IREE runtime link recipe (--whole-archive of the unified runtime plus a --start-group of flatcc and libgcc/libm) lives in a new root build.rs gated on CARGO_FEATURE_XLA_IREE. The root build.rs is a no-op for default and xla-backend builds.
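
A root build.rs along these lines would implement the gating described above. It is a sketch under assumptions, not the merged file: the archive names (`iree_runtime_unified`, `flatcc_parsing`) and the `lib` subdirectory of the dist are guesses for illustration, and `link_directives` is a hypothetical helper. Only the `CARGO_FEATURE_XLA_IREE` and `IREE_DIST` variables and the general shape of the recipe come from the description.

```rust
use std::env;

/// Hypothetical helper: the cargo directives for the IREE runtime link.
/// Returns nothing when the feature is off, so default and xla-backend
/// builds are unaffected.
fn link_directives(feature_on: bool, iree_dist: Option<&str>) -> Vec<String> {
    if !feature_on {
        return Vec::new();
    }
    let dist = iree_dist.expect("IREE_DIST must be set when xla-iree is enabled");
    vec![
        // Assumed layout: static archives under <dist>/lib.
        format!("cargo:rustc-link-search=native={dist}/lib"),
        "cargo:rustc-link-arg=-Wl,--whole-archive".to_string(),
        "cargo:rustc-link-arg=-liree_runtime_unified".to_string(), // assumed name
        "cargo:rustc-link-arg=-Wl,--no-whole-archive".to_string(),
        "cargo:rustc-link-arg=-Wl,--start-group".to_string(),
        "cargo:rustc-link-arg=-lflatcc_parsing".to_string(), // assumed name
        "cargo:rustc-link-arg=-lgcc".to_string(),
        "cargo:rustc-link-arg=-lm".to_string(),
        "cargo:rustc-link-arg=-Wl,--end-group".to_string(),
    ]
}

fn main() {
    // Cargo sets CARGO_FEATURE_<NAME> for each enabled feature of this package.
    let feature_on = env::var_os("CARGO_FEATURE_XLA_IREE").is_some();
    let dist = env::var("IREE_DIST").ok();
    println!("cargo:rerun-if-env-changed=IREE_DIST");
    for line in link_directives(feature_on, dist.as_deref()) {
        println!("{line}");
    }
}
```

Because these directives are emitted by the root package's own build script, they apply to the final binary link, which is what a dependency's build script cannot do.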

Scope and limits (M2)

  • Bundled graphs are authored for Llama-3.2-1B-Instruct (config.json is verified, a mismatch errors); prompts are capped at the 256-token bucket; sampling is greedy and text-only.
  • The prebuilt aarch64 IREE dist registers local/vulkan HAL drivers but not CUDA, so M2 runs on CPU (about 2 tok/s). Phase 2b's CUDA used the pip runtime; a GPU device is a follow-up. (This corrects the earlier FINDINGS claim that the dist had a cuda driver.)
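
The 256-token cap follows from the prefill graph having a fixed shape. A minimal sketch of the kind of check this implies is below; whether the PR pads the prompt this way, and with which token id, is not stated, so `pad_to_bucket` and `pad_id` are hypothetical. Only `MAX_SEQ` (256) comes from the description.

```rust
/// Bucket size of the bundled prefill graph, per the PR description.
const MAX_SEQ: usize = 256;

/// Hypothetical helper: reject prompts that exceed the bucket, otherwise
/// pad to the fixed length and report how many tokens are real.
fn pad_to_bucket(prompt: &[u32], pad_id: u32) -> Result<(Vec<u32>, usize), String> {
    if prompt.is_empty() {
        return Err("prompt is empty".to_string());
    }
    if prompt.len() > MAX_SEQ {
        return Err(format!(
            "prompt has {} tokens, but the prefill bucket holds {MAX_SEQ}",
            prompt.len()
        ));
    }
    let mut padded = prompt.to_vec();
    padded.resize(MAX_SEQ, pad_id);
    Ok((padded, prompt.len()))
}

fn main() {
    match pad_to_bucket(&[1, 2, 3], 0) {
        Ok((padded, real)) => println!("{real} real tokens in a bucket of {}", padded.len()),
        Err(e) => eprintln!("error: {e}"),
    }
}
```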

Validation

  • Spike token-exact 48/48 (bucket 64 and 256).
  • cargo build default, --features xla-backend, and --features xla-iree (with IREE_DIST) all green.
  • cargo test -p mlxcel-xla 5/5; cargo test --features xla-backend --lib backend:: 6/6.
  • cargo fmt --check and cargo clippy (default, xla-backend, xla-iree) clean.
  • MLXCEL_BACKEND=xla mlxcel generate end to end (raw prompt and default chat template).

Refs #449.

…hase 3 M2)

Phase 3 M2 takes the OpenXLA backend from the M1 scaffold and the M2 FFI gate to real, token-exact execution driven end to end from the mlxcel CLI.

The C shim (proven first in spike/iree-ffi, then vendored to src/lib/mlxcel-xla/csrc/xla_iree.c) loads the #451-emitted prefill and decode_step vmfbs into one IREE session, uploads the 146 model weights once as resident device buffers (bf16 to f32 via the safetensors crate, in the emitter's arg order), threads the KV cache across steps, and returns a token id per step. The output is read by element type, so the same shim drives both the logits-returning graph (host argmax) and the on-device-argmax graph (a 4-byte scalar readback, the Phase 2b pattern). Rust drives the real Llama-3.2-1B token-exact (48/48) against the HF temp-0 reference on CPU.
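
Two pieces of the paragraph above lend themselves to a short sketch: widening bf16 weights to f32, and picking the token from an output whose element type differs per graph. The bf16 widening is the standard bit shift; the `StepOutput` enum and `token_from_output` are hypothetical names, and treating the 4-byte scalar as an `i32` is an assumption.

```rust
/// Widen one bf16 value (raw bits) to f32. bf16 is the top 16 bits of an
/// f32, so the conversion is an exact left shift.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Hypothetical view of a step's output, distinguished by element type.
enum StepOutput<'a> {
    /// Logits graph: the host still has to take the argmax.
    Logits(&'a [f32]),
    /// On-device-argmax graph: a single 4-byte scalar (assumed i32).
    TokenId(i32),
}

/// Hypothetical helper: resolve either output form to a token id.
fn token_from_output(out: StepOutput) -> Option<u32> {
    match out {
        StepOutput::Logits(logits) => {
            // Host argmax; the first maximum wins on ties.
            let mut best: Option<(usize, f32)> = None;
            for (i, &v) in logits.iter().enumerate() {
                if best.map_or(true, |(_, b)| v > b) {
                    best = Some((i, v));
                }
            }
            best.map(|(i, _)| i as u32)
        }
        StepOutput::TokenId(id) => u32::try_from(id).ok(),
    }
}

fn main() {
    println!("{}", bf16_to_f32(0x3F80));
    println!("{:?}", token_from_output(StepOutput::Logits(&[0.1, 2.0, 2.0, -1.0])));
    println!("{:?}", token_from_output(StepOutput::TokenId(7)));
}
```

Reading a single scalar instead of the full logits vector is what makes the on-device-argmax path cheap on readback: the transfer is 4 bytes per step rather than one float per vocabulary entry.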

Real execution sits behind a new mlxcel-xla `iree` feature, exposed at the root as `xla-iree = ["xla-backend", "mlxcel-xla/iree"]`. Without it, `--features xla-backend` still builds with no IREE distribution, so CI is unchanged and prefill/decode_step return a clear built-without-iree error. mlxcel-xla/build.rs compiles the shim against the IREE dist headers; the runtime link recipe (whole-archive of the unified runtime plus a flatcc and libgcc/libm group) lives in a new root build.rs, because a dependency's cargo:rustc-link-arg does not propagate to the final binary link.
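
The feature gate can be pictured as two definitions of the same entry point selected by `cfg`. This is a sketch of the pattern only: the `XlaError` type, its variant, and the `decode_step` signature are hypothetical and do not reflect the crate's actual API.

```rust
/// Hypothetical error type for illustration.
#[derive(Debug, PartialEq)]
enum XlaError {
    BuiltWithoutIree,
}

/// With the `iree` feature, the real build would call into the vendored
/// C shim here. The FFI call is left out of this sketch.
#[cfg(feature = "iree")]
fn decode_step(_token: u32) -> Result<u32, XlaError> {
    unimplemented!("FFI call into the vendored shim")
}

/// Without the feature, the crate still compiles with no IREE distribution
/// and reports a clear error at call time.
#[cfg(not(feature = "iree"))]
fn decode_step(_token: u32) -> Result<u32, XlaError> {
    Err(XlaError::BuiltWithoutIree)
}

fn main() {
    // Without the feature this prints the error; with it, the sketch stops
    // at the unimplemented FFI placeholder.
    match decode_step(0) {
        Ok(tok) => println!("token {tok}"),
        Err(e) => println!("error: {e:?}"),
    }
}
```

Keeping the stub behind the same signature lets callers compile identically in both configurations, which is why CI needs no IREE distribution.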

The #451 emitter gains on-device argmax (Builder::argmax, the index-tracking two-operand reduce) behind a sample flag on emit_decode and emit_prefill, and the prefill bucket is bumped to MAX_SEQ (256) so the default chat-template prompt fits one bucket; both stay token-exact. The bundled graphs (assets/llama-3.2-1b) are compiled to vmfbs at session load by the dist iree-compile and cached by content hash.
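
The "index-tracking two-operand reduce" can be modelled on the host as a reduction over (value, index) pairs. The sketch below shows the reducer semantics in plain Rust rather than the emitted StableHLO; the function names and the tie-break toward the lower index are assumptions, since the PR does not state a tie policy.

```rust
/// Combiner for the reduction: keep the pair with the larger value and,
/// on equal values, the pair with the lower index (assumed tie policy).
fn combine(a: (f32, i32), b: (f32, i32)) -> (f32, i32) {
    if a.0 > b.0 || (a.0 == b.0 && a.1 < b.1) {
        a
    } else {
        b
    }
}

/// Argmax as a reduce over two operands, the values and their indices,
/// starting from the identity element (-inf, i32::MAX).
fn argmax(logits: &[f32]) -> i32 {
    let init = (f32::NEG_INFINITY, i32::MAX);
    logits.iter().copied().zip(0i32..).fold(init, combine).1
}

fn main() {
    println!("{}", argmax(&[0.5, 3.0, 3.0, -1.0]));
}
```

Because the combiner is associative and commutative under this tie rule, the reduction gives the same answer regardless of evaluation order, which is what a parallel on-device reduce needs. An empty input returns the identity index `i32::MAX` in this sketch.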

ComputeBackend::create_session now takes the model directory path so a session-driven backend can load its own weights and config (the MLX backend ignores it). run_generate_once gains a self-contained XLA branch, taken when MLXCEL_BACKEND=xla, that skips load_model (which the XLA backend rejects) and drives the session's own greedy loop, then rejoins the shared decode and print path. End to end, MLXCEL_BACKEND=xla mlxcel generate produces coherent text through IREE at about 2 tok/s on CPU.
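
The session's greedy loop described above has roughly the following shape. The `Session` trait, its method signatures, and the `Counting` mock are hypothetical stand-ins so the loop can run on its own; the real session wraps the IREE-backed prefill and decode_step, and stopping on an end-of-sequence id is an assumption.

```rust
/// Hypothetical session interface: each call returns the next token id.
trait Session {
    fn prefill(&mut self, prompt: &[u32]) -> Result<u32, String>;
    fn decode_step(&mut self, token: u32) -> Result<u32, String>;
}

/// Greedy generation: one prefill, then one decode_step per new token,
/// stopping at `max_new` tokens or at the (assumed) end-of-sequence id.
fn greedy_generate(
    session: &mut dyn Session,
    prompt: &[u32],
    max_new: usize,
    eos: u32,
) -> Result<Vec<u32>, String> {
    let mut out = Vec::with_capacity(max_new);
    if max_new == 0 {
        return Ok(out);
    }
    let mut tok = session.prefill(prompt)?;
    loop {
        if tok == eos {
            break;
        }
        out.push(tok);
        if out.len() == max_new {
            break;
        }
        tok = session.decode_step(tok)?;
    }
    Ok(out)
}

/// Mock session for the example: prefill returns the prompt length and
/// every decode step returns the previous token plus one.
struct Counting;

impl Session for Counting {
    fn prefill(&mut self, prompt: &[u32]) -> Result<u32, String> {
        Ok(prompt.len() as u32)
    }
    fn decode_step(&mut self, token: u32) -> Result<u32, String> {
        Ok(token + 1)
    }
}

fn main() {
    let mut session = Counting;
    let tokens = greedy_generate(&mut session, &[1, 2, 3], 4, 100).unwrap();
    println!("{tokens:?}");
}
```

The returned token ids are what would then rejoin the shared decode and print path.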

Scope: the bundled graphs are authored for Llama-3.2-1B-Instruct (config.json is verified and a mismatch errors), prompts are capped at the 256-token bucket, and sampling is greedy and text-only. The prebuilt aarch64 IREE dist registers local and vulkan HAL drivers but not CUDA, so M2 runs on CPU; a GPU device is a follow-up.

Validation: spike token-exact 48/48 (bucket 64 and 256); cargo build default, --features xla-backend, and --features xla-iree (with IREE_DIST); cargo test -p mlxcel-xla 5/5 and --features xla-backend backend:: 6/6; fmt and clippy clean; MLXCEL_BACKEND=xla mlxcel generate runs end to end.

Refs #449.
@inureyes added the type:enhancement, priority:medium, and area:architecture labels on Jun 27, 2026
@inureyes merged commit 89bb4e3 into main on Jun 27, 2026 (5 checks passed)
@inureyes deleted the feat/449-m2-execution-wiring branch on June 27, 2026 14:29
Labels

  • area:architecture: Architecture and code structure changes
  • priority:medium: Medium priority
  • type:enhancement: New features, capabilities, or significant additions
