Skip to content

feat: OpenXLA export-route spike for Llama-3.2-1B (Phase 0/1)#453

Merged
inureyes merged 1 commit into
mainfrom
spike/openxla-export-449
Jun 26, 2026
Merged

feat: OpenXLA export-route spike for Llama-3.2-1B (Phase 0/1)#453
inureyes merged 1 commit into
mainfrom
spike/openxla-export-449

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

Lands the Phase 0/1 OpenXLA export-route spike from #449. Self-contained under spike/openxla/; touches no mlxcel crate or build, and adds no Rust changes.

What this validates (ADR 0004)

  • Export-first model definition holds for the fp16 path: a Llama-3.2-1B JAX reference exports to StableHLO via jax.export and greedy-decodes token-exact (48/48) against an HF transformers temp-0 reference.
  • The same StableHLO module runs unmodified on two independent runtimes (PJRT and IREE) with matching argmax. No hand-written StableHLO, no per-op work.

Scope

  • Phase 0 (environment + stack selection) and Phase 1 (fp16 export spike): done.
  • Phase 2 (4-bit quantized decode) and Phase 3 (mlxcel session-contract integration): not started.
  • Validated the JAX-reference to StableHLO route, not the HF-PyTorch to torch.export route (deliberate choice, recorded in FINDINGS.md).

Notes

  • Validated on spark-101 (NVIDIA GB10 / Grace-Blackwell, aarch64), CPU target. aarch64 + Blackwell has no CUDA jaxlib wheel, so the spike runs on the CPU PJRT/IREE plugin.
  • Findings writeup: spike/openxla/FINDINGS.md.

Refs #449 (Phase 0/1 of 4; the issue stays open for Phase 2/3).

Standalone spike validating the export-first model-definition route from
ADR 0004. Exports Llama-3.2-1B prefill and a bucketed decode-step graph to
StableHLO via jax.export, runs greedy decode from the serialized artifact on
PJRT, and checks the continuation against an HF transformers temp-0 reference.

Result: 48/48 token-exact match with the HF reference and a coherent
continuation. The same exported StableHLO compiles and runs on IREE (llvm-cpu)
with argmax matching PJRT, and contains no custom_call ops. Weights are graph
inputs, not baked constants, so one graph serves any same-architecture
checkpoint.

Scope: Phase 0 (environment, stack selection) and Phase 1 (fp16 export) only.
Phase 2 (4-bit) and Phase 3 (mlxcel integration) are not started. The spike
lives entirely under spike/openxla/, is not part of the Cargo workspace, and
has zero effect on the default Apple-Silicon or CUDA builds. The venv, weights,
and generated artifacts are gitignored; see README.md to reproduce the run and
FINDINGS.md for the writeup.

Refs #449.
@inureyes inureyes merged commit c068d12 into main Jun 26, 2026
5 checks passed
@inureyes inureyes deleted the spike/openxla-export-449 branch June 26, 2026 22:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant