feat: GPU decode throughput for the OpenXLA spike on GB10 (#449 Phase 2b)#456
Merged
Conversation
… 2b) Measures real single-sequence decode throughput for the exported StableHLO on the GB10, with the Phase 2a harness artifacts removed: weights resident on the device (uploaded once), on-device argmax (4 bytes per step instead of a 513 KB logits copy), and the KV cache kept resident. Same IREE path for CUDA and CPU. Results (tok/s): int4 GB10 4.0 vs CPU 0.8; fp32 GB10 5.5 vs CPU 1.8. Two findings. With the harness fixed the GPU beats the CPU (5x int4, 3x fp32), so the Phase 2a "GPU slower" reading was the weight-reupload artifact, not the GB10. And int4 is slower than fp32 on the GPU (4.0 vs 5.5), confirming the Phase 2a fusion result: the unfused dequant is a separate kernel that materializes the fp32 weight, so int4 does more memory traffic per step. dequant-in-graph int4 is a weight-storage win, not a decode-latency win, until the dequant fuses into the GEMM. Absolute tok/s is low and not a ceiling (batch-1 decode, 145 tiny matmuls per token, untuned IREE CUDA codegen, no CUDA-graph capture); the relative results are the point. Standalone under spike/openxla/, no effect on mlxcel crates or builds. Refs #449.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
GPU decode throughput for the OpenXLA spike on the GB10 (#449 Phase 2b, deferred from Phase 2a until a GPU host existed; CUDA is available on this box).
Removes the Phase 2a harness artifacts that made the GPU look slow: weights resident on the device (uploaded once), on-device argmax (4 bytes per step instead of a 513 KB logits copy), and the KV cache kept resident. Same IREE path for CUDA (GB10) and CPU.
Results (tok/s, single-sequence greedy, Llama-3.2-1B):
Two findings:
custom_callto an int4 GEMM, or a recognized quantized-matmul pattern).Absolute tok/s is low and not a ceiling (batch-1 decode, 145 tiny matmuls per token, untuned IREE CUDA codegen, no CUDA-graph capture); the relative results are the point.
Standalone under
spike/openxla/, no effect on mlxcel crates or builds. Findings inspike/openxla/FINDINGS_phase2b.md.Refs #449.