Skip to content

feat: Rust-native StableHLO emitter spike for Llama-3.2-1B (#451)#455

Merged
inureyes merged 1 commit into
mainfrom
spike/rust-emitter-451
Jun 26, 2026
Merged

feat: Rust-native StableHLO emitter spike for Llama-3.2-1B (#451)#455
inureyes merged 1 commit into
mainfrom
spike/rust-emitter-451

Conversation

@inureyes

Copy link
Copy Markdown
Member

Rust-native StableHLO emitter, a parallel alternative authoring path to the JAX reference frontend validated in #449 Phase 1 (#451).

A dependency-free Rust crate emits the Llama-3.2-1B StableHLO as text; the existing iree-compile lowers it and the IREE runtime runs it. No Python in the authoring path, no melior, no MLIR/LLVM C++ build.

Results:

  • P0 (toolchain gate): a Rust-emitted single dot_general compiles and runs through IREE, exact versus numpy.
  • P1: the full Rust-emitted decode_step greedy-decodes token-exact (48/48) against the HF transformers temp-0 reference, same coherent text. 20 StableHLO op kinds, 145 dot_general, zero custom_call, zero f64. Validated with the real bf16 weights (independently re-run).
  • Standalone prefill was not separately emitted (the one new op it needs is stablehlo.gather); decode token-exactness was the gate and it is met. A follow-up adds prefill.

Recommendation for the ADR 0004 authoring-frontend decision: adopt the Rust text emitter as the production authoring path and keep the JAX reference as the per-architecture oracle (author and verify in JAX, emit from Rust, diff against the JAX .mlir and tokens). The emitter wins on graph control (int4 dequant insertion and custom_call routing are local edits to one Builder::linear applied to all matmuls) and on toolchain weight (cargo plus iree-compile only).

Standalone under spike/rust-emitter/ with its own empty [workspace] so it never joins the mlxcel build graph. Findings in spike/rust-emitter/FINDINGS.md.

Refs #451.

Evaluate authoring the compiler-family model graph from Rust instead of a JAX
reference, as a parallel alternative to the path #449 validated. A
dependency-free Rust crate emits StableHLO text, the existing iree-compile lowers
it to a CPU vmfb, and the IREE runtime executes it. CPU target only.

P0 (toolchain gate): a Rust-emitted single dot_general round-trips through
iree-compile and the IREE runtime with an exact match vs numpy. Text emission
needs only cargo plus the iree-compile already in spike/openxla/.venv. No melior
and no MLIR/LLVM C++ build were required.

P1: the full Rust-emitted decode_step greedy-decodes token-exact (48/48) against
the HF temp-0 reference in spike/openxla/artifacts/results.json, same bar as #449.
The emitted graph uses 20 StableHLO op kinds, 145 dot_general, zero custom_call,
zero f64. The prompt is streamed through decode_step (cache_len = i), which is
equivalent to a batched prefill under the iota <= cache_len mask.

Standalone under spike/rust-emitter/ with an empty [workspace] table so it never
joins the mlxcel build graph. Reuses the spike/openxla venv and weights and
touches no tracked mlxcel file. FINDINGS.md compares the JAX-reference and
Rust-emitter authoring paths (per-architecture effort, maintainability, graph
control for int4 and custom_call, toolchain weight) and recommends the emitter as
the production authoring path with the JAX reference as the per-architecture
oracle, feeding the ADR 0004 authoring-frontend decision.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant