feat: Rust-native StableHLO emitter spike for Llama-3.2-1B (#451)#455
Merged
Conversation
Evaluate authoring the compiler-family model graph from Rust instead of a JAX reference, as a parallel alternative to the path #449 validated. A dependency-free Rust crate emits StableHLO text, the existing iree-compile lowers it to a CPU vmfb, and the IREE runtime executes it. CPU target only. P0 (toolchain gate): a Rust-emitted single dot_general round-trips through iree-compile and the IREE runtime with an exact match vs numpy. Text emission needs only cargo plus the iree-compile already in spike/openxla/.venv. No melior and no MLIR/LLVM C++ build were required. P1: the full Rust-emitted decode_step greedy-decodes token-exact (48/48) against the HF temp-0 reference in spike/openxla/artifacts/results.json, same bar as #449. The emitted graph uses 20 StableHLO op kinds, 145 dot_general, zero custom_call, zero f64. The prompt is streamed through decode_step (cache_len = i), which is equivalent to a batched prefill under the iota <= cache_len mask. Standalone under spike/rust-emitter/ with an empty [workspace] table so it never joins the mlxcel build graph. Reuses the spike/openxla venv and weights and touches no tracked mlxcel file. FINDINGS.md compares the JAX-reference and Rust-emitter authoring paths (per-architecture effort, maintainability, graph control for int4 and custom_call, toolchain weight) and recommends the emitter as the production authoring path with the JAX reference as the per-architecture oracle, feeding the ADR 0004 authoring-frontend decision.
Open
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Rust-native StableHLO emitter, a parallel alternative authoring path to the JAX reference frontend validated in #449 Phase 1 (#451).
A dependency-free Rust crate emits the Llama-3.2-1B StableHLO as text; the existing
iree-compilelowers it and the IREE runtime runs it. No Python in the authoring path, nomelior, no MLIR/LLVM C++ build.Results:
dot_generalcompiles and runs through IREE, exact versus numpy.decode_stepgreedy-decodes token-exact (48/48) against the HF transformers temp-0 reference, same coherent text. 20 StableHLO op kinds, 145dot_general, zerocustom_call, zero f64. Validated with the real bf16 weights (independently re-run).prefillwas not separately emitted (the one new op it needs isstablehlo.gather); decode token-exactness was the gate and it is met. A follow-up adds prefill.Recommendation for the ADR 0004 authoring-frontend decision: adopt the Rust text emitter as the production authoring path and keep the JAX reference as the per-architecture oracle (author and verify in JAX, emit from Rust, diff against the JAX
.mlirand tokens). The emitter wins on graph control (int4 dequant insertion andcustom_callrouting are local edits to oneBuilder::linearapplied to all matmuls) and on toolchain weight (cargo plus iree-compile only).Standalone under
spike/rust-emitter/with its own empty[workspace]so it never joins the mlxcel build graph. Findings inspike/rust-emitter/FINDINGS.md.Refs #451.