feat: OpenXLA export-route spike for Llama-3.2-1B (Phase 0/1) #453
Merged
Standalone spike validating the export-first model-definition route from ADR 0004. Exports Llama-3.2-1B prefill and a bucketed decode-step graph to StableHLO via jax.export, runs greedy decode from the serialized artifact on PJRT, and checks the continuation against an HF transformers temp-0 reference.

Result: 48/48 token-exact match with the HF reference and a coherent continuation. The same exported StableHLO compiles and runs on IREE (llvm-cpu) with argmax matching PJRT, and contains no custom_call ops. Weights are graph inputs, not baked constants, so one graph serves any same-architecture checkpoint.

Scope: Phase 0 (environment, stack selection) and Phase 1 (fp16 export) only. Phase 2 (4-bit) and Phase 3 (mlxcel integration) are not started.

The spike lives entirely under spike/openxla/, is not part of the Cargo workspace, and has zero effect on the default Apple-Silicon or CUDA builds. The venv, weights, and generated artifacts are gitignored; see README.md to reproduce the run and FINDINGS.md for the writeup.

Refs #449.
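The export route described above can be illustrated with a minimal sketch. It is not the spike's code: `toy_step`, its shapes, and the random weights are hypothetical stand-ins for the Llama decode step, chosen only to show weights passed as graph inputs, serialization of the exported artifact, and a textual scan of the StableHLO for `custom_call`. It assumes a JAX release that ships the `jax.export` module.

```python
import jax
import jax.numpy as jnp
from jax import export

# Hypothetical stand-in for a decode step: one projection followed by argmax.
# `w` is a function argument, so it becomes a graph input of the exported
# module instead of a constant baked into it.
def toy_step(w, x):
    logits = x @ w
    return jnp.argmax(logits, axis=-1)

# Abstract shapes only; no real weights are needed to export.
w_spec = jax.ShapeDtypeStruct((8, 16), jnp.float32)
x_spec = jax.ShapeDtypeStruct((1, 8), jnp.float32)

exported = export.export(jax.jit(toy_step))(w_spec, x_spec)

# The StableHLO module as text; a plain substring scan is enough to flag
# custom_call ops in a sketch like this one.
stablehlo_text = exported.mlir_module()
has_custom_call = "custom_call" in stablehlo_text

# Round-trip through the serialized artifact, as a separate runtime would.
blob = exported.serialize()
restored = export.deserialize(blob)

# The same restored graph serves two different "checkpoints" of equal shape.
key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (1, 8), dtype=jnp.float32)
all_match = True
for key in (key_a, key_b):
    w = jax.random.normal(key, (8, 16), dtype=jnp.float32)
    same = jnp.array_equal(restored.call(w, x), jax.jit(toy_step)(w, x))
    all_match = all_match and bool(same)

print("inputs:", len(exported.in_avals), "custom_call present:", has_custom_call)
print("restored graph matches direct call:", all_match)
```

Because the weights enter through `in_avals` rather than as constants, swapping checkpoints means changing only the arrays passed to `restored.call`, not re-exporting the graph.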
Summary
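Lands the Phase 0/1 OpenXLA export-route spike from #449. Self-contained under spike/openxla/; touches no mlxcel crate or build, and adds no Rust changes.
What this validates (ADR 0004)
Llama-3.2-1B exports to StableHLO via jax.export and greedy-decodes token-exact (48/48) against an HF transformers temp-0 reference.
Scope
Phase 0 (environment, stack selection) and Phase 1 (fp16 export) only. Phase 2 (4-bit) and Phase 3 (mlxcel integration) are not started.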
Notes
spark-101 (NVIDIA GB10 / Grace-Blackwell, aarch64), CPU target. aarch64 + Blackwell has no CUDA jaxlib wheel, so the spike runs on the CPU PJRT/IREE plugin.
Writeup: spike/openxla/FINDINGS.md.
Refs #449 (Phase 0/1 of 4; the issue stays open for Phase 2/3).
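For readers who want to see what a token-exact check of the kind reported here (48/48) amounts to, the following is a minimal sketch using only the standard library. The helper names `first_mismatch` and `match_report` and the token ids are hypothetical; the spike's actual comparison code may differ.

```python
from typing import Optional, Sequence

def first_mismatch(candidate: Sequence[int], reference: Sequence[int]) -> Optional[int]:
    """Return the index of the first differing token, or None if identical."""
    for i, (c, r) in enumerate(zip(candidate, reference)):
        if c != r:
            return i
    # A length difference counts as a mismatch at the shorter length.
    if len(candidate) != len(reference):
        return min(len(candidate), len(reference))
    return None

def match_report(candidate: Sequence[int], reference: Sequence[int]) -> str:
    """Format position-wise agreement as 'matched/total' against the reference."""
    matched = sum(c == r for c, r in zip(candidate, reference))
    return f"{matched}/{len(reference)}"

# Illustrative token ids only, not real model output.
reference_ids = list(range(100, 148))   # 48 reference tokens
exported_ids = list(reference_ids)      # identical continuation
diverged_ids = reference_ids[:10] + [0] + reference_ids[11:]

print(match_report(exported_ids, reference_ids), first_mismatch(exported_ids, reference_ids))
print(match_report(diverged_ids, reference_ids), first_mismatch(diverged_ids, reference_ids))
```

Reporting the first mismatching index alongside the count is a deliberate choice: under greedy decoding a single early divergence changes every later token, so the index is usually more informative than the ratio.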