
feat: emit Llama-3.2-1B prefill from the Rust StableHLO emitter (#451 follow-up) #457

Merged
inureyes merged 1 commit into main from spike/rust-emitter-prefill-451 on Jun 26, 2026

@inureyes

Member

Follow-up to #451: the Rust-native StableHLO emitter now also emits the Llama-3.2-1B prefill graph. The #451 spike emitted decode_step token-exact but skipped standalone prefill; its findings noted the one missing op was stablehlo.gather.

  • Prefill emits and compiles (iree-compile llvm-cpu, exit 0).
  • Full Rust-emitted prefill + decode greedy-decodes token-exact (48/48) against the HF temp-0 reference; decode remains 48/48 after the shared-builder additions (independently re-run).
  • The one new op was stablehlo.gather (the embedding lookup and the per-position cos/sin lookup), as #451 predicted. GQA needed one transpose per layer; the KV write uses dynamic_update_slice (not the range-slice scatter), so it stays GPU-portable.
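
For readers less familiar with the two ops named above, here is a minimal reference sketch of their semantics in plain Rust. It is not the emitter's builder API and not the emitted StableHLO: `gather_rows` and `dynamic_update_rows` are hypothetical names, the tensors are modelled as nested `Vec<f32>` rows, and the clamping of out-of-range indices is included as an assumption that follows the StableHLO spec's clamping rule.

```rust
/// Reference semantics of the embedding lookup `embed[tokens]`:
/// one row of `embed` is copied out per token id.
/// Out-of-range ids are clamped to the last row (assumption for this sketch,
/// modelled on how StableHLO gather clamps its start indices).
fn gather_rows(embed: &[Vec<f32>], tokens: &[usize]) -> Vec<Vec<f32>> {
    let last = embed.len().saturating_sub(1);
    tokens.iter().map(|&t| embed[t.min(last)].clone()).collect()
}

/// Reference semantics of a dynamic_update_slice along the sequence axis:
/// overwrite `block.len()` consecutive rows of `cache`, starting at `start`.
/// The start index is clamped so that the whole block fits inside `cache`.
fn dynamic_update_rows(cache: &mut [Vec<f32>], block: &[Vec<f32>], start: usize) {
    assert!(block.len() <= cache.len(), "update must fit inside the operand");
    let start = start.min(cache.len() - block.len());
    for (i, row) in block.iter().enumerate() {
        cache[start + i] = row.clone();
    }
}
```

In this model the KV write is a single contiguous block copy at a runtime offset, which is the property the bullet above relies on when it contrasts dynamic_update_slice with a scatter.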

All changes are under spike/rust-emitter/ (its own empty [workspace], so no mlxcel build impact). FINDINGS.md and README.md are updated to drop the stale "prefill not emitted" note.

Refs #451.

…w-up)

Extend the Rust-native StableHLO text emitter to author the bucketed prefill
graph alongside decode_step. Adds builder ops gather (the multi-token embedding
lookup embed[tokens] and the per-position cos/sin lookup), transpose, and
linear_seq (the [Lp,K] activation matmul), plus model::emit_prefill matching the
JAX prefill signature: tokens[Lp], positions[Lp], real_len, no input caches
(zero-initialized internally, bucket Lp=64), an [Lp,Lp] causal mask (j<=i), the
[Lp] KV block written per layer with one dynamic_update_slice, and the last logit
sliced at real_len-1.
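
The mask and last-logit rules described above can be sketched in plain Rust as follows. This is an illustration only: `causal_mask` and `last_logits` are hypothetical helper names, not functions from the spike, and the real graph expresses both as StableHLO ops rather than host-side loops.

```rust
/// Build the [lp, lp] causal mask: query position i may attend to key
/// position j exactly when j <= i (lower triangle, diagonal included).
fn causal_mask(lp: usize) -> Vec<Vec<bool>> {
    (0..lp)
        .map(|i| (0..lp).map(|j| j <= i).collect())
        .collect()
}

/// Pick the logits row of the last real token in a padded bucket.
/// Rows at index real_len and beyond belong to padding and are ignored.
fn last_logits(logits: &[Vec<f32>], real_len: usize) -> &[f32] {
    assert!(real_len >= 1 && real_len <= logits.len());
    &logits[real_len - 1]
}
```

With a bucket of Lp=64 and a 10-token prompt, `last_logits(&logits, 10)` would return row 9. Rows 10 through 63 are computed over padding and never read.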

The gather dimension_numbers and slice_sizes mirror the JAX-emitted
prefill.stablehlo.mlir. run_prefill.py drives the Rust-emitted prefill (first
token) then decode_step (continuation): token-exact 48/48 against the HF temp-0
reference in spike/openxla/artifacts/results.json. Decode stays token-exact 48/48
unchanged. validate.sh gains a prefill mode. Standalone under spike/rust-emitter;
touches no mlxcel crate and no spike/openxla file.
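
As a rough illustration of what a "token-exact 48/48" check amounts to, the sketch below shows greedy (temperature-0) token selection as an argmax over logits, followed by a position-by-position comparison against a reference sequence. The actual driver is run_prefill.py. The Rust names here (`argmax`, `token_matches`) are hypothetical and only model the idea.

```rust
/// Greedy (temperature-0) selection: index of the largest logit.
/// Ties resolve to the lowest index; an empty slice yields 0.
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best
}

/// Count positions where the generated tokens equal the reference tokens.
/// Returns (matches, reference length); "token-exact" means the two are equal.
fn token_matches(got: &[u32], want: &[u32]) -> (usize, usize) {
    let matches = got.iter().zip(want.iter()).filter(|(g, w)| g == w).count();
    (matches, want.len())
}
```

A result of `(48, 48)` from `token_matches` would correspond to the 48/48 figure reported above.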
@inureyes inureyes merged commit 0283067 into main Jun 26, 2026
@inureyes inureyes deleted the spike/rust-emitter-prefill-451 branch June 26, 2026 23:34