feat: OpenXLA reference backend - export-route spike through 4-bit quantized decode #449

Description

@inureyes

TL;DR

Stand up the first non-MLX compute backend for mlxcel: an OpenXLA / StableHLO backend, validated on one small model (Llama-3.2-1B class) all the way through 4-bit quantized decode. This issue is written to be self-contained so an engineer or agent on a Linux (CUDA) host can start immediately without prior mlxcel context. The early phases need no mlxcel session contract and produce a standalone runnable plus a findings writeup; only the final integration phase binds to the session contract from #448. As the second reference backend, it forces the seam in ADR 0004 to be genuine, and it validates the export-first model-definition strategy.

This work runs in parallel with #448 (the session-contract redesign on the MLX side). The only synchronization point is Phase 3 (integration). Phases 0 to 2 are independent of mlxcel and can start now.

Background: what mlxcel is, and why a non-MLX backend

mlxcel is a Rust inference runtime for Apple Silicon (primary) and Linux/CUDA (secondary). It links MLX C++ through a cxx FFI crate, and every model runs its forward pass through MLX. MLX owns the device abstraction (Metal, CUDA, CPU) and a dynamic-graph executor with its own memory planner. Adding an accelerator that already has an MLX backend is cheap. There is, however, no way today to host an execution engine that MLX does not cover, that is, one where forward computation, KV-cache representation, and weight loading happen entirely outside MLX.

The motivating targets are FuriosaAI (TCP / RNGD), Tenstorrent, and an OpenXLA-based path. All three are graph-compiler backends: they ingest a whole-graph (or whole-module) description, compile it, and execute through their own runtime with static shapes and their own memory placement. OpenXLA (StableHLO compiled by XLA, run through PJRT) is the broadest of the three and the natural first non-MLX reference backend, because OpenXLA and Tenstorrent (via TT-MLIR) converge on StableHLO, and PJRT/IREE turn hardware support into a target-plugin problem.

The architecture direction is recorded in docs/adr/0004-compute-backend-session-seam-and-stablehlo-family.md (ADR 0004). Read it. The summary: the compute-backend seam is an inference-session engine (token-in, token-out, backend owns its KV), not an op-level abstraction and not a load factory. A non-MLX backend implements that session contract with its own engine behind it. PR #446 (issue #338) landed a provisional load-boundary seam; #448 redraws it as the session contract this issue's Phase 3 will bind to.

Decisions already locked (ADR 0004, 2026-06-26)

  • Model definition for the compiler family is export-first, validated by the spike in this issue: import an exported graph (PyTorch via torch.export / torch-mlir, or a JAX/Flax reference, lowered to StableHLO) rather than hand-writing StableHLO per architecture. In-tree hand-written StableHLO emission is the fallback if the export route does not hold. Confirming or rejecting the export route is an explicit deliverable of this issue.
  • The first-milestone success bar includes 4-bit quantized decode, not just fp16. fp16 single-sequence coherent output is the intermediate checkpoint.
  • The backend starts single-sequence. Batching, paged KV, and speculative decode stay MLX-session features for now and are a later abstraction phase, out of scope here.
  • The OpenXLA backend lives in its own default-off crate. XLA / PJRT dependencies must never touch the default Apple-Silicon or CUDA builds of mlxcel.
  • Reference model is small (Llama-3.2-1B class): 4-bit for the success bar, an fp16 variant as the intermediate checkpoint.

The core technical challenge

An autoregressive LLM decode loop with a KV cache has to be expressed for a static-shape graph compiler. The standard pattern is to export two graphs (or one graph parameterized by a small set of bucketed shapes) and let the host drive the token loop:

  • Prefill graph. Input: prompt token ids (padded to a fixed bucket length), position ids, attention mask. Output: logits for the last position (or all positions) plus the populated KV cache tensors. The KV cache is a graph output (or a donated in/out buffer).
  • Decode-step graph. Input: one new token id, its position id, the current KV cache tensors (as donated in/out buffers), and the running sequence length. Output: next-token logits and the updated KV cache. The host calls this once per generated token.
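The host-driven loop over the two graphs can be sketched as follows. This is a minimal illustration, not the spike harness: `prefill`, `decode_step`, `toy_logits`, `MAX_SEQ`, and `VOCAB` are hypothetical stand-ins, and the "model" simply predicts the previous token plus one so the control flow can be followed without a compiled graph. In a real harness the two functions would call the compiled prefill and decode-step executables.

```python
from typing import List, Tuple

MAX_SEQ = 16   # hypothetical fixed KV capacity
VOCAB = 8      # hypothetical toy vocabulary size


def toy_logits(token_id: int) -> List[float]:
    """Fake logits: the highest score goes to (token_id + 1) % VOCAB."""
    return [1.0 if v == (token_id + 1) % VOCAB else 0.0 for v in range(VOCAB)]


def prefill(prompt: List[int]) -> Tuple[List[float], List[int], int]:
    """Stand-in for the prefill graph.

    Returns last-position logits, a fixed-capacity cache, and the filled length.
    """
    cache = [0] * MAX_SEQ
    cache[: len(prompt)] = prompt
    return toy_logits(prompt[-1]), cache, len(prompt)


def decode_step(token_id: int, cache: List[int], seq_len: int) -> Tuple[List[float], List[int], int]:
    """Stand-in for the decode-step graph.

    Writes at the running position; the in-place update mimics a donated buffer.
    """
    cache[seq_len] = token_id
    return toy_logits(token_id), cache, seq_len + 1


def argmax(xs: List[float]) -> int:
    return max(range(len(xs)), key=xs.__getitem__)


def greedy_generate(prompt: List[int], max_new_tokens: int) -> List[int]:
    """Host loop: one prefill call, then one decode-step call per generated token."""
    logits, cache, seq_len = prefill(prompt)
    out: List[int] = []
    for _ in range(max_new_tokens):
        if seq_len >= MAX_SEQ:  # fixed capacity reached, stop generating
            break
        tok = argmax(logits)    # host-side greedy sampling
        out.append(tok)
        logits, cache, seq_len = decode_step(tok, cache, seq_len)
    return out
```

The host samples from logits here; moving `argmax` into the graph would change only what the two stand-ins return.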

Things to pin down and record during the spike:

  • KV cache representation as graph arguments: fixed-capacity tensors [layers, 2, n_kv_heads, max_seq, head_dim] (or per-layer), with an index/position telling the graph where to write. Donated buffers (XLA input_output_alias / buffer donation, or PJRT donation) avoid reallocating the cache each step.
  • Static shapes: choose a max sequence length and, for prefill, a small set of padded prompt-length buckets. Decode-step is shape-static (one token). Document the bucketing scheme.
  • RoPE / position handling and the causal attention mask under fixed shapes (mask out padding and future positions).
  • Sampling: do it on-device in the graph where possible (argmax for greedy; the full sampler later) and return token ids, to avoid a per-token device-to-host logits copy. For the spike, returning last-position logits and sampling on host is acceptable to get coherence first; note the perf implication.
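One possible shape for the bucketing and masking items above, as a sketch under stated assumptions: the bucket sizes, `PAD_ID`, and the choice of right padding are all hypothetical, and every function name is illustrative. With right padding, the prefill logits of interest sit at index `prompt_len - 1`, not at the last index of the bucket.

```python
from typing import List, Sequence

PREFILL_BUCKETS = (32, 128, 512)  # hypothetical padded prompt lengths
PAD_ID = 0                        # hypothetical padding token id


def pick_bucket(prompt_len: int, buckets: Sequence[int] = PREFILL_BUCKETS) -> int:
    """Return the smallest bucket that fits the prompt."""
    for b in sorted(buckets):
        if prompt_len <= b:
            return b
    raise ValueError(f"prompt of {prompt_len} tokens exceeds the largest bucket")


def pad_prompt(token_ids: Sequence[int], bucket: int, pad_id: int = PAD_ID) -> List[int]:
    """Right-pad the prompt to the bucket length (assumed padding side)."""
    return list(token_ids) + [pad_id] * (bucket - len(token_ids))


def prefill_mask(prompt_len: int, bucket: int) -> List[List[bool]]:
    """mask[q][k] is True where query q may attend to key k.

    A key is visible when it is not in the future (k <= q) and not padding
    (k < prompt_len).
    """
    return [[k <= q and k < prompt_len for k in range(bucket)] for q in range(bucket)]


def decode_mask(seq_len: int, max_seq: int) -> List[bool]:
    """Mask for a one-token decode step over a fixed-capacity cache.

    The new token is written at index seq_len, so slots 0..seq_len are visible.
    """
    return [k <= seq_len for k in range(max_seq)]
```

For example, a 33-token prompt lands in the 128 bucket under these assumed sizes, and each distinct bucket corresponds to one compiled prefill executable.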

Quantization: the success bar and its open question

4-bit quantized decode is the success bar, and how XLA / StableHLO represents 4-bit weights is the central unknown to characterize. Evaluate these routes and report which is viable, with rough perf and memory numbers:

  • Dequantize-in-graph. Store weights packed int4 plus scales (and zero points / biases), and dequantize to bf16 inside the graph immediately before each matmul. This is the correctness-first route. It likely does not save memory bandwidth on the matmul itself unless XLA fuses the dequant into the matmul, but it is the simplest to get correct. Measure whether XLA fuses it.
  • Custom call to a target int4 kernel. StableHLO custom_call to a hand-provided or vendor int4 GEMM. More work, target-specific, but the only route to real 4-bit memory and bandwidth wins on hardware that supports it.
  • A framework quantized export. Whatever int4 path the chosen export toolchain already supports, lowered to StableHLO. Characterize whether it exists and what it lowers to.
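To make the dequantize-in-graph route concrete, here is a host-side reference of the arithmetic such a graph would perform before each matmul. It is a sketch under assumptions: the packing order (low nibble first), the affine form `scale * (q - zero_point)`, and all function names are illustrative choices, not mlxcel's or any toolchain's format. A reference like this is mainly useful for checking the in-graph result numerically.

```python
from typing import List, Sequence


def pack_int4(values: Sequence[int]) -> bytes:
    """Pack unsigned 4-bit values two per byte, low nibble first (assumed order)."""
    if len(values) % 2:
        raise ValueError("need an even number of values")
    if any(not 0 <= v <= 15 for v in values):
        raise ValueError("values must fit in 4 bits")
    return bytes(lo | (hi << 4) for lo, hi in zip(values[0::2], values[1::2]))


def unpack_int4(packed: bytes) -> List[int]:
    """Inverse of pack_int4."""
    out: List[int] = []
    for byte in packed:
        out.append(byte & 0x0F)  # low nibble
        out.append(byte >> 4)    # high nibble
    return out


def dequantize(
    packed: bytes,
    scales: Sequence[float],
    zero_points: Sequence[int],
    group_size: int,
) -> List[float]:
    """Group-wise affine dequantization: w = scale * (q - zero_point)."""
    q = unpack_int4(packed)
    if len(q) != len(scales) * group_size or len(scales) != len(zero_points):
        raise ValueError("scales and zero_points must have one entry per group")
    return [
        scales[i // group_size] * (v - zero_points[i // group_size])
        for i, v in enumerate(q)
    ]
```

A scheme that stores a per-group bias instead of a zero point changes only the last expression to `scale * q + bias`.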

Note: matching mlxcel's exact 4-bit format (mlx-community affine quantization, group size, bf16 scales) is NOT required for this spike. The spike validates the route and characterizes feasibility. Exact weight-format alignment with mlxcel's loader is an integration concern for Phase 3.

Environment setup (Linux / CUDA host)

The author's Mac has no XLA/export stack, so this work runs on a Linux (CUDA) host. The agent chooses and records the exact stack in Phase 0. Two viable routes:

  • PyTorch route. torch plus torch.export to capture the model graph, then either torch-mlir (PyTorch to StableHLO) or torch-xla (PJRT execution). HF transformers provides the reference Llama implementation and the tokenizer.
  • JAX route. A JAX/Flax Llama reference (for example the Flax implementation in HF transformers, or another known-good JAX LLM), jax.jit, then StableHLO export via jax.export, run via PJRT or IREE.

Runtime options: OpenXLA PJRT (CPU plugin is sufficient for the milestone; CUDA plugin if available on the box), or IREE (iree-compile to a target, iree-runtime to execute). CPU is an acceptable target for the milestone; GPU is a bonus. Pin all versions and record them in the findings doc so the result is reproducible.
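One lightweight way to capture the pinned versions for the findings doc is to read them from the installed distributions. The sketch below uses only the standard library; the package list is an assumption about which distributions a given route installs (the names are guesses and should be adjusted to the stack actually chosen), and packages that are absent are recorded as `None` instead of raising.

```python
import json
import platform
import sys
from importlib import metadata
from typing import Dict, Optional, Sequence

# Assumed distribution names; edit to match the chosen export and runtime stack.
CANDIDATE_PACKAGES = (
    "torch",
    "torch-xla",
    "torch-mlir",
    "jax",
    "jaxlib",
    "transformers",
    "safetensors",
)


def collect_env(packages: Sequence[str] = CANDIDATE_PACKAGES) -> dict:
    """Return interpreter, platform, and installed package versions."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "packages": versions,
    }


def env_as_json(packages: Sequence[str] = CANDIDATE_PACKAGES) -> str:
    """Render the environment record as JSON for pasting into the findings doc."""
    return json.dumps(collect_env(packages), indent=2, sort_keys=True)
```

Calling `print(env_as_json())` in the spike environment produces a block that can be pasted into the findings doc next to the lockfile or container digest.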

Reference model

Llama-3.2-1B-Instruct (small, well-understood dense transformer, fast iteration). Use the HF safetensors weights: bf16/fp16 for the intermediate checkpoint, and a 4-bit scheme the chosen route supports for the success bar (characterize the scheme). HF tokenizer for the spike; the mlxcel tokenizer is reused only in Phase 3 integration.

Provisional session-contract sketch (the Phase 3 integration target, defined finally in #448)

Build toward this shape so Phase 3 integration is small. It is provisional; #448 finalizes it.

  • Backend::load(model_path, config) -> Session (the OpenXLA backend compiles/loads the exported graphs and weights here).
  • Session::prefill(token_ids: &[u32]) -> PrefillOut where the session allocates and owns the per-sequence KV state internally and returns the first-step logits or the first sampled token id.
  • Session::decode_step(&mut state, token_id: u32) -> StepOut returning the next token id (sampling on-device) or logits, advancing the session-owned KV.
  • The session advertises capabilities (single-sequence only for this backend at first; no batching/paged/speculative).
  • The control plane above the seam (tokenizer, chat template, sampling policy, OpenAI / llama-server API, request lifecycle) is mlxcel's and is reused unchanged.
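If the Phase 1 and Phase 2 harness is written in Python, it can already mirror this shape so that the later Rust binding is mostly mechanical. The following is a hedged sketch only: the real contract is a Rust trait defined in #448, and `Capabilities`, `StepOut`, `Session`, and `EchoSession` here are hypothetical names. The toy session owns its KV stand-in internally and "predicts" the previous token plus one.

```python
from dataclasses import dataclass
from typing import List, Protocol, Sequence


@dataclass(frozen=True)
class Capabilities:
    """What the session advertises; single-sequence only in this sketch."""
    max_batch: int = 1
    paged_kv: bool = False
    speculative: bool = False


@dataclass
class StepOut:
    token_id: int


class Session(Protocol):
    """Hypothetical Python mirror of the provisional session shape."""
    capabilities: Capabilities

    def prefill(self, token_ids: Sequence[int]) -> StepOut: ...

    def decode_step(self, token_id: int) -> StepOut: ...


class EchoSession:
    """Toy implementation: the session owns its KV stand-in internally."""
    capabilities = Capabilities()

    def __init__(self, max_seq: int = 8) -> None:
        self._max_seq = max_seq
        self._kv: List[int] = []

    @property
    def seq_len(self) -> int:
        return len(self._kv)

    def prefill(self, token_ids: Sequence[int]) -> StepOut:
        if not token_ids or len(token_ids) > self._max_seq:
            raise ValueError("prompt must be non-empty and fit the KV capacity")
        self._kv = list(token_ids)
        return StepOut(token_id=token_ids[-1] + 1)

    def decode_step(self, token_id: int) -> StepOut:
        if len(self._kv) >= self._max_seq:
            raise RuntimeError("KV capacity exhausted")
        self._kv.append(token_id)
        return StepOut(token_id=token_id + 1)
```

Keeping the KV state private to the session means the caller only ever passes token ids across the seam.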

Phased plan, deliverables, and acceptance per phase

Phase 0: environment + stack selection.
Set up the Linux env, choose the export+runtime stack (PyTorch+torch-mlir / torch-xla, or JAX+PJRT/IREE), pin versions. Deliverable: a reproducible env spec recorded in the findings doc.

Phase 1: fp16 export spike (mlxcel-independent).
Export prefill and decode-step StableHLO for Llama-3.2-1B fp16, run on PJRT/IREE, greedy-decode a fixed prompt, and produce a coherent continuation. Verify against a reference: run the same model+prompt at temperature 0 through HF transformers and confirm the OpenXLA output is coherent and ideally token-matching for greedy. Deliverable: a standalone runnable harness (a script or a small Rust/Python program, in the new crate or a spike/ dir) plus a findings doc covering the chosen toolchain, how KV cache and shape bucketing were handled, the RoPE/mask handling, and what worked or failed.
Acceptance: fp16 single-sequence greedy produces coherent output that matches the HF reference continuation for a fixed prompt (token-level match for greedy is the goal; explain any divergence).
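The token-level check in this acceptance criterion reduces to comparing two lists of token ids and reporting where they first differ. A small helper along these lines (the name `first_divergence` is illustrative) keeps the comparison explicit; the two inputs would be the greedy continuations from HF transformers and from the OpenXLA harness.

```python
from typing import Optional, Sequence


def first_divergence(reference: Sequence[int], candidate: Sequence[int]) -> Optional[int]:
    """Return the index of the first differing token, or None on an exact match.

    If one sequence is a strict prefix of the other, the divergence index is
    the length of the shorter sequence.
    """
    for i, (r, c) in enumerate(zip(reference, candidate)):
        if r != c:
            return i
    if len(reference) != len(candidate):
        return min(len(reference), len(candidate))
    return None
```

A late divergence after a long matching prefix usually points at numeric drift, while a divergence at index 0 or 1 more often points at a mask, position, or weight-layout bug, which is worth stating in the findings doc when explaining any mismatch.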

Phase 2: 4-bit quantized decode (the success bar).
Get coherent output with 4-bit weights, using whichever lowering route (dequant-in-graph / custom_call / framework path) the spike finds viable. Characterize perf and memory versus the fp16 path. Deliverable: the quantized harness plus a findings section resolving which int4 route works on the target and its cost.
Acceptance: 4-bit quantized greedy decode of the reference model on OpenXLA produces coherent output; the int4 lowering route is documented with rough perf/memory.
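For the memory side of the characterization, a back-of-the-envelope estimate helps sanity-check measured numbers. The sketch below assumes group-wise quantization with one scale and one zero point (or bias) per group, each stored in 2 bytes; those sizes, the group size, and the function names are assumptions to be replaced by whatever scheme the spike actually uses. It counts only the quantized weights, so tensors left in higher precision must be added separately.

```python
def weight_bytes_fp16(n_params: int) -> int:
    """Bytes for n_params weights stored at 16 bits each."""
    return 2 * n_params


def weight_bytes_int4(
    n_params: int,
    group_size: int,
    scale_bytes: int = 2,  # assumed 16-bit scale per group
    zero_bytes: int = 2,   # assumed 16-bit zero point or bias per group
) -> int:
    """Bytes for packed 4-bit weights plus per-group metadata."""
    if group_size % 2 or n_params % group_size:
        raise ValueError("group_size must be even and divide n_params")
    n_groups = n_params // group_size
    return n_params // 2 + n_groups * (scale_bytes + zero_bytes)


def bits_per_weight(n_params: int, group_size: int) -> float:
    """Effective storage cost per weight, including group metadata."""
    return 8 * weight_bytes_int4(n_params, group_size) / n_params
```

Under these assumptions a group size of 64 costs 4.5 bits per weight, so the quantized weights take about 28% of the fp16 footprint. Whether that translates into a bandwidth win depends on whether the dequantization fuses into the matmul.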

Phase 3: integration behind the mlxcel session contract (joins #448).
Wrap the backend in a new default-off crate behind the session contract, reuse mlxcel weight loading / tokenizer / sampling / serving, and produce coherent output through mlxcel's generate (and ideally the server) for the reference model. This phase depends on the #448 session trait existing; until then, build against the provisional sketch above.
Acceptance: the reference model runs through mlxcel end to end on the OpenXLA backend (4-bit), the backend is isolated in a default-off crate with zero effect on default Apple-Silicon/CUDA builds, and the export-vs-handwrite model-definition decision is recorded.

Overall deliverables

  • A standalone OpenXLA runnable for the reference model (fp16 and 4-bit), reproducible from the pinned env.
  • A findings writeup (suggested: docs/backends/openxla-spike.md, or issue comments if a doc PR is premature) that: confirms or rejects the export-first route (this resolves the open model-definition-strategy decision in ADR 0004); documents the KV/shape-bucketing approach; characterizes the int4 lowering route with perf/memory; records the pinned toolchain.
  • Phase 3 only: a default-off crate (suggested name mlxcel-backend-xla or mlxcel-xla) wiring the backend into mlxcel behind the session contract.

Conventions and constraints

  • Work on lablup/mlxcel. Branch and PR per the repo conventions: English commit messages / PR titles / bodies, conventional-commit prefixes, no AI attribution anywhere (no Co-Authored-By, no "Generated with"), no em dashes (use commas, periods, parentheses).
  • The XLA/PJRT crate or spike code must be default-off and must not enter the default build graph. Verify that a plain cargo build --release --features metal,accelerate (macOS) and the default Linux/CUDA build are both unaffected.
  • Working-tree safety: stage only your own files by explicit path; never run tree-wide git add -A, git stash -u, git clean, or git reset --hard.
  • Validate with the real reference model before claiming a phase done; coherence against the HF reference is the bar, not synthetic shapes.
