feat: GPU decode throughput for the OpenXLA spike on GB10 (#449 Phase 2b) #456

Merged: inureyes merged 1 commit into main from spike/openxla-phase2b-gpu-449 on Jun 26, 2026
@inureyes

Member

Measures GPU decode throughput for the OpenXLA spike on the GB10 (#449 Phase 2b, deferred from Phase 2a until a GPU host existed; CUDA is available on this box).

Removes the Phase 2a harness artifacts that made the GPU look slow: weights are now resident on the device (uploaded once), argmax runs on-device (4 bytes per step instead of a 513 KB logits copy), and the KV cache stays resident. CUDA (GB10) and CPU use the same IREE path.
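The 513 KB vs 4 bytes figure can be reproduced with a little arithmetic. This is a minimal sketch, not the harness code: it assumes the logits are fp32, the vocabulary is the 128,256-entry Llama 3.2 vocabulary, and the token id comes back as a single int32. The helper names are hypothetical.

```python
# Per-step device-to-host traffic for greedy decode, with and without
# on-device argmax. Assumptions: fp32 logits, int32 token id.
VOCAB_SIZE = 128_256  # Llama 3.2 tokenizer vocabulary size
BYTES_FP32 = 4
BYTES_INT32 = 4


def logits_copy_bytes(vocab_size=VOCAB_SIZE):
    """Bytes copied per step when the full logits row goes back to the host."""
    return vocab_size * BYTES_FP32


def token_copy_bytes():
    """Bytes copied per step when only the argmax token id goes back."""
    return BYTES_INT32


def host_argmax(logits):
    """Greedy pick on the host: index of the largest logit."""
    return max(range(len(logits)), key=logits.__getitem__)


print(f"logits copy: {logits_copy_bytes() / 1000:.0f} KB per step")
print(f"token copy:  {token_copy_bytes()} bytes per step")
print(f"reduction:   {logits_copy_bytes() // token_copy_bytes()}x")
```

Moving the argmax onto the device does not change which token is picked; it only changes how many bytes cross the host/device boundary on every step.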

Results (tok/s, single-sequence greedy, Llama-3.2-1B):

| variant | device | tok/s | ms/tok |
|---------|--------|-------|--------|
| int4    | GB10   | 4.0   | 252    |
| int4    | CPU    | 0.8   | 1274   |
| fp32    | GB10   | 5.5   | 182    |
| fp32    | CPU    | 1.8   | 541    |
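As a quick consistency check on these results, the tok/s column and the GPU-over-CPU ratios can be derived from the ms/tok values. The sketch below only does arithmetic on the reported numbers; the function names are illustrative and not part of the spike.

```python
# ms/tok values copied from the results above.
MS_PER_TOK = {
    ("int4", "GB10"): 252,
    ("int4", "CPU"): 1274,
    ("fp32", "GB10"): 182,
    ("fp32", "CPU"): 541,
}


def tok_per_s(ms_per_tok):
    """Convert per-token latency in milliseconds to tokens per second."""
    return 1000.0 / ms_per_tok


def gpu_speedup(variant):
    """How many times faster the GB10 is than the CPU for one variant."""
    return MS_PER_TOK[(variant, "CPU")] / MS_PER_TOK[(variant, "GB10")]


for (variant, device), ms in MS_PER_TOK.items():
    print(f"{variant} {device}: {tok_per_s(ms):.1f} tok/s")

for variant in ("int4", "fp32"):
    print(f"{variant}: GB10 is {gpu_speedup(variant):.1f}x the CPU")
```

The ratios come out at roughly 5x for int4 and 3x for fp32.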

Two findings:

  • With the harness fixed, the GPU beats the CPU (5x int4, 3x fp32), so the Phase 2a "GPU slower" reading was the weight-reupload artifact, not the GB10.
  • int4 is slower than fp32 on the GPU (4.0 vs 5.5 tok/s), confirming the Phase 2a fusion result: the unfused dequant runs as a separate kernel that materializes the fp32 weight, so int4 generates more memory traffic per step. Dequant-in-graph int4 is therefore a weight-storage win (8x smaller than fp32), not a decode-latency win, until the dequant fuses into the GEMM (via a custom_call to an int4 GEMM, or a recognized quantized-matmul pattern).
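The memory-traffic argument in the second finding can be made concrete with a rough byte count. This is a deliberately simplified model, not a measurement: it assumes every weight byte goes through device memory once per access, and it ignores quantization scales and zero points, activations, and cache effects. The mode names and the 2048 x 2048 matrix are hypothetical.

```python
def weight_traffic_bytes(n_weights, mode):
    """Rough per-step weight traffic for one matmul under a simplified model."""
    int4_bytes = n_weights // 2  # two 4-bit values packed per byte
    fp32_bytes = n_weights * 4
    if mode == "fp32":
        # The GEMM reads the fp32 weight once.
        return fp32_bytes
    if mode == "int4_unfused":
        # The dequant kernel reads the packed int4 weight and writes an fp32
        # copy; the GEMM then reads that fp32 copy back.
        return int4_bytes + fp32_bytes + fp32_bytes
    if mode == "int4_fused":
        # The GEMM reads the packed weight and dequantizes in registers.
        return int4_bytes
    raise ValueError(f"unknown mode: {mode}")


n = 2048 * 2048  # hypothetical weight matrix
baseline = weight_traffic_bytes(n, "fp32")
for mode in ("fp32", "int4_unfused", "int4_fused"):
    traffic = weight_traffic_bytes(n, mode)
    print(f"{mode:13s} {traffic / 1e6:6.1f} MB  ({traffic / baseline:.3f}x fp32)")
```

Under these assumptions the unfused path moves about 2.1x the bytes of plain fp32, while a fused int4 GEMM would move one eighth. The measured gap (252 vs 182 ms/tok) is smaller than the model predicts, so the model should be read as directional only.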

Absolute tok/s is low and not a ceiling (batch-1 decode, 145 tiny matmuls per token, untuned IREE CUDA codegen, no CUDA-graph capture); the relative results are the point.
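The PR does not break down the 145 figure. One accounting that happens to reproduce it, offered here purely as an assumption, is 16 decoder layers with four attention projections, three MLP projections, and two attention matmuls each, plus the LM head. Dividing the measured per-token latency by that count gives an average time budget per matmul, which includes all non-matmul work as well.

```python
# Assumed accounting of matmuls per decoded token (not confirmed by the PR).
N_LAYERS = 16      # Llama-3.2-1B decoder layers
ATTN_PROJ = 4      # q, k, v, o projections
MLP_PROJ = 3       # gate, up, down projections
ATTN_MATMULS = 2   # Q @ K^T and softmax(...) @ V
LM_HEAD = 1        # final vocabulary projection


def matmuls_per_token():
    """Matmul count per decode step under the accounting assumed above."""
    return N_LAYERS * (ATTN_PROJ + MLP_PROJ + ATTN_MATMULS) + LM_HEAD


def avg_ms_per_matmul(ms_per_tok, n_matmuls):
    """Per-token latency spread evenly over the matmuls (an upper bound per op)."""
    return ms_per_tok / n_matmuls


n = matmuls_per_token()
print(f"matmuls per token: {n}")
for label, ms in (("fp32 GB10", 182), ("int4 GB10", 252)):
    print(f"{label}: {avg_ms_per_matmul(ms, n):.2f} ms per matmul on average")
```

An average above a millisecond per matmul at batch 1 is consistent with the PR's point that the absolute numbers reflect untuned codegen and per-kernel overhead rather than a hardware ceiling.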

Standalone under spike/openxla/, no effect on mlxcel crates or builds. Findings in spike/openxla/FINDINGS_phase2b.md.

Refs #449.

… 2b)

Measures real single-sequence decode throughput for the exported StableHLO on the
GB10, with the Phase 2a harness artifacts removed: weights resident on the device
(uploaded once), on-device argmax (4 bytes per step instead of a 513 KB logits
copy), and the KV cache kept resident. Same IREE path for CUDA and CPU.

Results (tok/s): int4 GB10 4.0 vs CPU 0.8; fp32 GB10 5.5 vs CPU 1.8. Two findings.
With the harness fixed the GPU beats the CPU (5x int4, 3x fp32), so the Phase 2a
"GPU slower" reading was the weight-reupload artifact, not the GB10. And int4 is
slower than fp32 on the GPU (4.0 vs 5.5), confirming the Phase 2a fusion result:
the unfused dequant is a separate kernel that materializes the fp32 weight, so
int4 does more memory traffic per step. dequant-in-graph int4 is a weight-storage
win, not a decode-latency win, until the dequant fuses into the GEMM.

Absolute tok/s is low and not a ceiling (batch-1 decode, 145 tiny matmuls per
token, untuned IREE CUDA codegen, no CUDA-graph capture); the relative results are
the point.

Standalone under spike/openxla/, no effect on mlxcel crates or builds.

Refs #449.
@inureyes inureyes merged commit 270e0cf into main Jun 26, 2026
@inureyes inureyes deleted the spike/openxla-phase2b-gpu-449 branch June 26, 2026 23:19