feat: int4 dequant-in-graph spike for Llama-3.2-1B (#449 Phase 2a) #454
Merged
Adds the int4 dequant-in-graph path to the OpenXLA spike and characterizes its lowering on the JAX harness. Affine int4 (group 64, packed 8 per uint32) is applied to the 7 per-layer linears; the unpack, convert, and scale run in-graph before each dot_general, while embedding and norms stay fp32.

Findings:

- The in-graph dequant matches a host dequant-then-fp32 baseline (max logit diff 7e-5), and int4 greedy decoding stays coherent.
- The emitted StableHLO carries the dequant as standard ops with no custom_call.
- On the fuse question, none of XLA-CPU, IREE-CPU, or IREE-CUDA fuses the dequant into the matmul: XLA wraps both in one kLoop fusion but rebuilds the full fp32 weight inside it, and IREE forms a separate dequant dispatch that materializes the fp32 weight on both CPU and CUDA. So int4 here is an 8x weight-storage win over fp32, not a matmul win; a real int4 matmul needs a custom_call to an int4 GEMM or a quantized-matmul fusion.
- CUDA turned out to be available on this box (GB10, sm_121). The int4 graphs compile to CUDA via IREE and run on the GB10 token-exact (48/48) against the CPU int4 run.
- Prefill's range-slice cache write needed a pad and stack rewrite (dynamic_update_slice) to lower to the CUDA backend.

No perf claim here (the harness re-uploads weights each step); real GPU perf is Phase 2b, now runnable on this box via IREE-CUDA.

Standalone under spike/openxla/, no effect on mlxcel crates or builds.

Refs #449.
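For readers who want to see the shape of the in-graph path, here is a minimal JAX sketch of an affine int4 linear with group size 64 and 8 values packed per uint32. It is not the spike's code. The nibble order (low nibble first), the (out, in) weight layout, grouping along the input dimension, the affine form `w = q * scale + bias`, the toy shapes, and the helper names `quantize`, `dequant_host`, `dequant_in_graph`, and `linear_int4` are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

GROUP = 64     # quantization group size (from the PR description)
PER_WORD = 8   # int4 values per uint32 (from the PR description)

def quantize(w):
    """Host-side affine int4 quantization (hypothetical helper, assumed layout)."""
    out_dim, in_dim = w.shape
    g = w.reshape(out_dim, in_dim // GROUP, GROUP)
    lo = g.min(axis=-1, keepdims=True)
    hi = g.max(axis=-1, keepdims=True)
    scale = np.maximum((hi - lo) / 15.0, 1e-8)
    q = np.clip(np.round((g - lo) / scale), 0, 15).astype(np.uint32)
    # Pack 8 consecutive nibbles into one uint32, low nibble first (assumption).
    q = q.reshape(out_dim, in_dim // PER_WORD, PER_WORD)
    shifts = 4 * np.arange(PER_WORD, dtype=np.uint32)
    packed = np.bitwise_or.reduce(q << shifts, axis=-1)
    return packed, scale[..., 0].astype(np.float32), lo[..., 0].astype(np.float32)

def dequant_host(packed, scale, bias):
    """NumPy dequant, used only to build the host fp32 baseline."""
    shifts = 4 * np.arange(PER_WORD, dtype=np.uint32)
    nib = (packed[:, :, None] >> shifts) & np.uint32(0xF)
    q = nib.reshape(packed.shape[0], -1, GROUP).astype(np.float32)
    w = q * scale[:, :, None] + bias[:, :, None]
    return w.reshape(packed.shape[0], -1)

def dequant_in_graph(packed, scale, bias):
    """Unpack, convert, and scale as ordinary traced ops."""
    out_dim, n_words = packed.shape
    shifts = 4 * jnp.arange(PER_WORD, dtype=jnp.uint32)
    nib = (packed[:, :, None] >> shifts) & jnp.uint32(0xF)   # unpack
    q = nib.reshape(out_dim, -1, GROUP).astype(jnp.float32)  # convert
    w = q * scale[:, :, None] + bias[:, :, None]             # scale
    return w.reshape(out_dim, n_words * PER_WORD)

@jax.jit
def linear_int4(x, packed, scale, bias):
    w = dequant_in_graph(packed, scale, bias)  # full fp32 (out, in) weight
    # Contract x's feature axis with w's input axis, i.e. x @ w.T.
    return jax.lax.dot_general(x, w, (((1,), (1,)), ((), ())))

rng = np.random.default_rng(0)
w = rng.standard_normal((128, 256)).astype(np.float32)
x = rng.standard_normal((4, 256)).astype(np.float32)
packed, scale, bias = quantize(w)

# In-graph dequant vs. host dequant followed by an fp32 matmul.
y_graph = np.asarray(linear_int4(x, packed, scale, bias))
w_host = dequant_host(packed, scale, bias)
y_host = x @ w_host.T
max_diff = float(np.max(np.abs(y_graph - y_host)))

# Lowered module text, for checking which ops carry the dequant.
text = linear_int4.lower(x, packed, scale, bias).as_text()
print("max diff vs host baseline:", max_diff)
print("custom_call in lowered text:", "custom_call" in text)
```

The packed payload in this sketch is 8x smaller than the fp32 weight; the per-group scales and biases add a small amount on top. Note that `dequant_in_graph` still produces the full fp32 weight before `dot_general`, which mirrors the finding that the saving is in storage rather than in the matmul itself.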
Findings are in spike/openxla/FINDINGS_phase2a.md.