Skip to content

feat: int4 dequant-in-graph spike for Llama-3.2-1B (#449 Phase 2a)#454

Merged
inureyes merged 1 commit into
mainfrom
spike/openxla-int4-phase2a-449
Jun 26, 2026
Merged

feat: int4 dequant-in-graph spike for Llama-3.2-1B (#449 Phase 2a)#454
inureyes merged 1 commit into
mainfrom
spike/openxla-int4-phase2a-449

Conversation

@inureyes

Copy link
Copy Markdown
Member

int4 dequant-in-graph for the OpenXLA spike, characterized on the JAX harness (#449 Phase 2a).

Affine int4 (group 64, packed 8 per uint32) on the 7 per-layer linears; the unpack, convert, and scale run in-graph before each dot_general, while embedding and norms stay fp32.

Findings:

  • In-graph dequant matches a host dequant-then-fp32 baseline (max logit diff 7e-5); int4 greedy stays coherent; the emitted StableHLO carries the dequant as standard ops with no custom_call.
  • Fuse question: none of XLA-CPU, IREE-CPU, or IREE-CUDA fuse the dequant into the matmul. XLA wraps both in one kLoop fusion but rebuilds the full fp32 weight inside it; IREE forms a separate dequant dispatch that materializes the fp32 weight on both CPU and CUDA. So int4 here is an 8x weight-storage win, not a matmul win; a real int4 matmul needs a custom_call to an int4 GEMM or a quantized-matmul fusion.
  • CUDA is available on this box (GB10, sm_121). The int4 graphs compile to CUDA via IREE and run on the GB10 token-exact (48/48) with the CPU int4 run. Prefill needed a pad and stack cache write (dynamic_update_slice) to lower to the CUDA backend. No perf claim (the harness re-uploads weights each step); real GPU perf is Phase 2b.

Standalone under spike/openxla/, no effect on mlxcel crates or builds. Findings in spike/openxla/FINDINGS_phase2a.md.

Refs #449.

Adds the int4 dequant-in-graph path to the OpenXLA spike and characterizes its
lowering on the JAX harness. Affine int4 (group 64, packed 8 per uint32) on the
7 per-layer linears; the unpack, convert, and scale run in-graph before each
dot_general, while embedding and norms stay fp32.

Findings. The in-graph dequant matches a host dequant-then-fp32 baseline
(max logit diff 7e-5), int4 greedy stays coherent, and the emitted StableHLO
carries the dequant as standard ops with no custom_call. On the fuse question,
none of XLA-CPU, IREE-CPU, or IREE-CUDA fuse the dequant into the matmul: XLA
wraps both in one kLoop fusion but rebuilds the full fp32 weight inside it, and
IREE forms a separate dequant dispatch that materializes the fp32 weight on both
CPU and CUDA. So int4 here is an 8x weight-storage win, not a matmul win; a real
int4 matmul needs a custom_call to an int4 GEMM or a quantized-matmul fusion.

CUDA turned out to be available on this box (GB10, sm_121). The int4 graphs
compile to CUDA via IREE and run on the GB10 token-exact (48/48) with the CPU
int4 run. Prefill's range-slice cache write needed a pad and stack rewrite
(dynamic_update_slice) to lower to the CUDA backend. No perf claim here (the
harness re-uploads weights each step); real GPU perf is Phase 2b, now runnable
on this box via IREE-CUDA.

Standalone under spike/openxla/, no effect on mlxcel crates or builds.

Refs #449.
@inureyes inureyes merged commit 6a0846c into main Jun 26, 2026
@inureyes inureyes deleted the spike/openxla-int4-phase2a-449 branch June 26, 2026 23:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant