feat: emit Llama-3.2-1B prefill from the Rust StableHLO emitter (#451 follow-up) #457
Merged
…w-up)

Extend the Rust-native StableHLO text emitter to author the bucketed prefill graph alongside decode_step.

Adds builder ops gather (the multi-token embedding lookup embed[tokens] and the per-position cos/sin lookup), transpose, and linear_seq (the [Lp,K] activation matmul), plus model::emit_prefill matching the JAX prefill signature: tokens[Lp], positions[Lp], real_len, no input caches (zero-initialized internally, bucket Lp=64), an [Lp,Lp] causal mask (j<=i), the [Lp] KV block written per layer with one dynamic_update_slice, and the last logit sliced at real_len-1. The gather dimension_numbers and slice_sizes mirror the JAX-emitted prefill.stablehlo.mlir.

run_prefill.py drives the Rust-emitted prefill (first token) then decode_step (continuation): token-exact 48/48 against the HF temp-0 reference in spike/openxla/artifacts/results.json. Decode stays token-exact 48/48 unchanged. validate.sh gains a prefill mode.

Standalone under spike/rust-emitter; touches no mlxcel crate and no spike/openxla file.
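The prefill input shaping described above can be sketched in plain Python. This is only an illustration of the semantics, not the emitter's code: the helper names, the right-padding with `pad_id=0`, and `positions` being `0..Lp-1` are all assumptions; only the bucket `Lp=64`, the `j<=i` mask rule, and the `real_len-1` slice come from the commit message.

```python
LP = 64  # bucket length stated in the commit message


def pad_to_bucket(tokens, lp=LP, pad_id=0):
    """Right-pad a prompt to the bucket length.

    pad_id=0 and positions = 0..lp-1 are assumptions made for this sketch.
    """
    real_len = len(tokens)
    if not 0 < real_len <= lp:
        raise ValueError(f"prompt length {real_len} does not fit bucket {lp}")
    padded = list(tokens) + [pad_id] * (lp - real_len)
    positions = list(range(lp))
    return padded, positions, real_len


def causal_mask(lp=LP):
    """[lp, lp] boolean mask: query row i may attend to key column j iff j <= i."""
    return [[j <= i for j in range(lp)] for i in range(lp)]


def last_logits(logits, real_len):
    """Select the logits row of the last real token, at index real_len - 1."""
    return logits[real_len - 1]
```

Because the mask is causal, the padded rows after `real_len - 1` cannot influence the selected row, which is why a single fixed bucket can serve prompts of different lengths.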
Follow-up to #451: the Rust-native StableHLO emitter now also emits the Llama-3.2-1B `prefill` graph. The #451 spike emitted `decode_step` token-exact but skipped standalone prefill; its findings noted the one missing op was `stablehlo.gather`.

The new op is `stablehlo.gather` (the embedding lookup and the per-position cos/sin lookup), as #451 predicted. GQA needed one `transpose` per layer; the KV write uses `dynamic_update_slice` (not the range-slice scatter), so it stays GPU-portable.

All changes are under `spike/rust-emitter/` (own empty `[workspace]`, no mlxcel build impact). `FINDINGS.md` and `README.md` are updated to drop the stale "prefill not emitted" note.

Refs #451.
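For readers unfamiliar with the two ops named above, here is a rough plain-Python model of what they compute in this use: a row gather for `embed[tokens]`, and a contiguous block write along the sequence axis. It is a simplified sketch, not the emitted StableHLO. The function names are hypothetical, only the leading axis is modeled, and out-of-range gather indices are rejected here for simplicity. The start-index clamping follows the StableHLO definition of `dynamic_update_slice`.

```python
def gather_rows(table, indices):
    """Row lookup table[indices], the shape of the embedding and cos/sin lookups."""
    for i in indices:
        if not 0 <= i < len(table):
            raise IndexError(f"index {i} out of range for {len(table)} rows")
    return [list(table[i]) for i in indices]


def dynamic_update_slice_rows(operand, update, start):
    """Write `update` into `operand` as one contiguous block of rows.

    The start index is clamped so the block always fits inside the operand,
    matching dynamic_update_slice semantics. Returns a new list.
    """
    if len(update) > len(operand):
        raise ValueError("update is longer than operand")
    start = max(0, min(start, len(operand) - len(update)))
    out = [list(row) for row in operand]
    out[start:start + len(update)] = [list(row) for row in update]
    return out
```

A usage example with a toy 4-row cache: writing a 2-row block at `start=0` fills rows 0 and 1, while `start=3` is clamped to 2 so the block still fits.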