SDPA decode perf improvements for qwen-3.5-35B-A3B #18759
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18759
Note: links to docs will display an error until the docs builds have completed. 1 currently active SEV; if your PR is affected, please view it. No failures, 114 pending as of commit d5209fc with merge base 841181e. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed 2cb04c3 to febc419
Pull request overview
This PR improves ExecuTorch CUDA SDPA decode performance for the common case where L_q = 1 (e.g., Qwen3.5 MoE token generation) by introducing a Split-K "flash-decoding" Triton path and dispatching to it at runtime.
Changes:
- Add a Split-K decode SDPA Triton kernel (`sdpa_decode_splitk`) plus a reduction kernel to improve occupancy when `L_q == 1`.
- Update the Qwen3.5 MoE attention path to dispatch between Split-K (decode) and tiled SDPA (prefill) via `torch.cond`.
- Add correctness tests and a benchmark script for SDPA decode shapes; update export example shapes to avoid overly-small AOTI shape specialization.
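For readers unfamiliar with the split-K scheme, its math can be sketched in plain PyTorch. Names here (`sdpa_decode_splitk_reference`, `num_splits`) are illustrative, not the PR's kernel API: each split produces unnormalized partials, and a second pass merges them with a max-rescaling correction.

```python
import torch
import torch.nn.functional as F

def sdpa_decode_splitk_reference(q, k, v, num_splits=4):
    """q: (B, H, 1, D); k, v: (B, H, L_kv, D).
    Softmax over L_kv is computed per split; partials are then merged
    with a max-rescaling (log-sum-exp) correction."""
    B, H, _, D = q.shape
    L_kv = k.shape[2]
    scale = D ** -0.5
    bounds = [round(i * L_kv / num_splits) for i in range(num_splits + 1)]
    accs, maxes, sums = [], [], []
    for lo, hi in zip(bounds, bounds[1:]):
        s = (q @ k[:, :, lo:hi].transpose(-1, -2)) * scale  # (B, H, 1, split)
        m = s.amax(dim=-1, keepdim=True)                    # split-local max
        p = torch.exp(s - m)
        accs.append(p @ v[:, :, lo:hi])                     # unnormalized output
        maxes.append(m)
        sums.append(p.sum(dim=-1, keepdim=True))
    m_all = torch.cat(maxes, dim=-1).amax(dim=-1, keepdim=True)
    out = torch.zeros_like(accs[0])
    denom = torch.zeros_like(sums[0])
    for acc, m, l in zip(accs, maxes, sums):
        alpha = torch.exp(m - m_all)  # rescale each split to the global max
        out = out + alpha * acc
        denom = denom + alpha * l
    return out / denom
```

On CPU float32 this matches `F.scaled_dot_product_attention` to within float tolerance, which is essentially what the PR's unit tests verify on CUDA bf16.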
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| examples/models/qwen3_5_moe/model.py | Switch attention to Triton SDPA and add decode-time Split-K dispatch via torch.cond. |
| examples/models/qwen3_5_moe/main.cpp | Plumb a stats callback into generation and print throughput/timing breakdown. |
| examples/models/qwen3_5_moe/export.py | Use a max-length example sequence to prevent AOTI from baking in too-small intermediate buffers. |
| backends/cuda/triton/kernels/sdpa.py | Implement Split-K decode kernel + reduction and expose sdpa_decode_splitk. |
| backends/cuda/triton/kernels/__init__.py | Export sdpa_decode_splitk from the kernels package. |
| backends/cuda/tests/test_triton_sdpa_splitk.py | Add CUDA BF16 unit tests validating Split-K correctness vs PyTorch SDPA reference. |
| backends/cuda/benchmarks/benchmark_sdpa.py | Add a benchmark script comparing Triton SDPA/Split-K vs PyTorch SDPA backends. |
```python
@triton_op("triton::sdpa_decode_splitk", mutates_args={})
def sdpa_decode_splitk(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float = 0.0,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
    B, H_q, L_q, D = query.shape
    _, H_kv, L_kv, _ = key.shape

    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)
```
sdpa_decode_splitk() launches kernels that assume CUDA + bfloat16 inputs (and the reduce kernel stores bfloat16 unconditionally), but unlike sdpa() it never calls _validate_sdpa_inputs(). This can lead to hard-to-debug failures or incorrect behavior if the op is called with non-bfloat16 or non-CUDA tensors. Consider reusing _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa) at the top of sdpa_decode_splitk (and keep the existing feature/shape checks).
```python
        HAS_MASK=HAS_MASK,
        HEAD_DIM=D,
        NUM_GROUPS=num_groups,
        BLOCK_G=_next_power_of_2(num_groups),
    )
```
BLOCK_G is derived via _next_power_of_2(num_groups), but _next_power_of_2() is documented as clamping to [16, 256] for head-dim tiling. Using it for num_groups forces a minimum of 16 (wasted work for MHA/low-group configs) and would be incorrect if num_groups > 256 (since BLOCK_G would clamp to 256 and not cover all groups). Consider a separate next_pow2 helper for group counts (no min clamp) and/or add a clear runtime check that num_groups <= 256 if that’s an intended constraint.
```python
    device_name = torch.cuda.get_device_name()
    print()
    print("=" * 100)
    print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")
```
The benchmark header prints "Qwen3.5-35B-A4B", but this benchmark (and the PR description) refers to the A3B variant. This looks like a typo and can confuse readers when comparing numbers; consider correcting the printed model name.
```diff
- print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")
+ print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}")
```
@claude review this PR.
Claude finished @digantdesai's task in 2m 21s

**PR Review: SDPA decode perf improvements for qwen-3.5-35B-A3B**

Overall this is a well-structured PR. The Split-K flash-decoding approach is the right strategy for L_q=1 with low H_kv — the occupancy problem is real and the 1.45x end-to-end speedup is meaningful.

Issues

1. Missing input validation in `sdpa_decode_splitk` — call `_validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa)` at the top, as `sdpa()` does.
2. … This doesn't cause incorrect results.
3. Reduce kernel output addressing assumes B=1 — the same pattern applies to the partial buffers.
4. Benchmark typo: "A4B" should be "A3B" — `print(f"SDPA Benchmark Qwen3.5-35B-A4B — {scenario}")` should print "A3B" to match the model variant. Copilot flagged this too.

Suggestions (non-blocking)

5. Partial accumulator stores unscaled values — the split-K kernel stores the raw unnormalized accumulator.
6. `num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)` — the choice of 256 tokens per split and max 128 splits is reasonable for A100 with ~108 SMs, but this could benefit from being SM-count-aware.
7. Test tolerance — all tests use …
8. The dispatch:

```python
y = torch.cond(
    q.shape[2] == 1,
    lambda q, k, v, mask: sdpa_decode_splitk(q, k, v, attn_mask=mask),
    lambda q, k, v, mask: sdpa(q, k, v, attn_mask=mask, enable_gqa=True),
    [q, k, v, attn_mask],
)
```

Note that during AOTI tracing, both branches are traced with symbolic shapes.

9. Export shape change — changing …

Summary

The kernel implementation is correct and well-tested. The main actionable item is issue #1 (missing CUDA/bf16 validation in `sdpa_decode_splitk`).
Force-pushed ebe61e8 to 5d3b620
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
```cpp
    printf(
        "\n\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
        stats.num_prompt_tokens,
        stats.num_generated_tokens);
```
```python
    for name, label, _ in backends:
        if name == ref_name or outputs[name] is None:
            continue
        err = _max_abs_error(outputs[name], ref_out)
        assert err < 1e-2, (
            f"Output mismatch for {_shape_label(shape)}: "
            f"{label} vs {BACKENDS[ref_name][0]}, "
            f"max abs error {err:.3e} >= 1e-2"
        )
```
```python
        out = self.splitk(q, k, v, attn_mask=mask)

        self.assertFalse(torch.isnan(out).any(), "All-masked should not NaN")
        self.assertFalse(torch.isinf(out).any(), "All-masked should not Inf")
```
Force-pushed 62428be to f011e54
Can you also list the prefill performance in the benchmark result?
Force-pushed f011e54 to 3836bea
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
```python
        self.assertFalse(torch.isnan(out).any(), "NaN in output")
        self.assertLess(
            _max_abs_error(out, ref),
            0.05,
```
```python
    key=["Lk", "HEAD_DIM", "NUM_GROUPS", "HAS_MASK"],
)
@triton.jit
def _sdpa_decode_splitk_kernel(
```
`decode` is not related to the sdpa kernel. Consider renaming it to `_sdpa_splitk_kernel`. Same for the others.
This kernel only works with Lq == 1. Hence the `decode` in the name, and the assert in `check_args`.
```python
    # The export produces two methods — decode (T=1, static) and
    # prefill (T>=2, dynamic). Each traces only one branch, so no
    # torch.cond is needed and we avoid GPU→CPU sync overhead.
    if T == 1:
```
Maybe in another PR, but I'm thinking about whether we should apply the new split-K SDPA only when T == 1. IIUC, the core idea of the split-K algorithm is to fully utilize the GPU's compute units. Given that:
- batch size will always be 1 in our usage
- there are 108 SMs on an A100

maybe we can apply the split-K SDPA even for T > 1. What if we used torch.cond here and let the runtime dynamically choose the right kernel?
Not sure I follow. The split-K kernel is for decode, where the existing kernel doesn't work well. The existing kernel works well for the prefill case with T > 1.
```python
    @classmethod
    def setUpClass(cls):
        _skip_if_no_cuda()
        cls.sdpa = _import_sdpa()
```
Right now we only test the regular sdpa, not the split version. Please add a test for it.
```python
# LICENSE file in the root directory of this source tree.

"""
Benchmark the Triton SDPA kernel against PyTorch SDPA backends.
```
Maybe in another PR, but we could make this a perf CI job to guard the perf.
We can, but I don't trust CI perf numbers to be stable.
I think some fluctuation should be fine, as long as the perf change is not too large.

In the PR summary.
Force-pushed 0a46be3 to 0c0f132
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
```python
            # prefill (T>=2, dynamic). Each traces only one branch, so no
            # torch.cond is needed and we avoid GPU→CPU sync overhead.
            if T == 1:
                y = sdpa_decode_splitk(q, k, v, attn_mask=attn_mask)
```
```python
        self.assertFalse(torch.isnan(out).any(), "NaN in output")
        self.assertLess(
            _max_abs_error(out, ref),
            0.05,
```
```python
    # is_causal is a no-op at L_q=1 (single query can't attend to future
    # positions), so we accept it silently for API compatibility with callers
    # that always pass is_causal=True for decode.
```
```python
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
```
Force-pushed 0c0f132 to 0609ae2
Force-pushed 0609ae2 to 4de4538
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
```python
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
```
```python
    max_seq_len: int = 4096
    use_splitk_decode: bool = True
    layer_types: list = field(default_factory=list)
```
```python
        self.assertFalse(torch.isnan(out).any(), "NaN in output")
        self.assertLess(
            _max_abs_error(out, ref),
            0.05,
```
```python
    # is_causal is a no-op at L_q=1 (single query can't attend to future
    # positions), so we accept it silently for API compatibility with callers
    # that always pass is_causal=True for decode.

    # Validation — only check at runtime (concrete shapes), not during AOTI
    # tracing where shapes are symbolic. torch.cond traces both branches with
    # the same symbolic L_q, so L_q is not necessarily 1 during tracing.
    if isinstance(L_q, int):
        if L_q != 1:
            raise RuntimeError(
                f"sdpa_decode_splitk requires L_q == 1 (decode); got L_q={L_q}"
            )
```
Force-pushed 4de4538 to 14bd4cb
Force-pushed 14bd4cb to 3069c79
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
```python
import torch
import torch.nn as nn

from executorch.backends.cuda.triton.kernels.sdpa import sdpa, sdpa_decode_splitk
```
```python
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float = 0.0,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Split-K flash-decoding SDPA for L_q=1 (decode step).

    Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
    enable_gqa is accepted but ignored — GQA is handled natively via
    H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
    """
    _validate_sdpa_inputs(query, key, value, dropout_p, enable_gqa)

    B, H_q, L_q, D = query.shape
    _, H_kv, L_kv, _ = key.shape

    out = torch.empty((B, H_q, L_q, D), device=query.device, dtype=query.dtype)

    # is_causal is a no-op at L_q=1 (single query can't attend to future
    # positions), so we accept it silently for API compatibility with callers
    # that always pass is_causal=True for decode.
```
```python
    def test_is_causal_rejected(self):
        """is_causal=True should raise RuntimeError."""
        B, H_q, H_kv, D = 1, 8, 2, 64
        q = torch.randn(B, H_q, 1, D, dtype=torch.bfloat16, device="cuda")
        k = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda")
        v = torch.randn(B, H_kv, 64, D, dtype=torch.bfloat16, device="cuda")
        with self.assertRaises(RuntimeError):
            self.splitk(q, k, v, is_causal=True)
```
Force-pushed 3069c79 to 4d08bf0
Force-pushed 4d08bf0 to f02b19a
Compares ET Triton SDPA (native GQA) against PyTorch Flash/Efficient/Math backends (expanded KV) across Lk=64..16K on A100. Uses triton.testing.do_bench for timing. Standalone script, no changes to the kernel. This PR was authored with the assistance of Claude
Register `triton::sdpa_decode_splitk` as an independent op so AOTI can trace and compile it without the runtime L_kv conditional that prevents the split-K path from appearing in the standard `sdpa` op. The split-K (flash-decoding) approach partitions the KV sequence across CTAs and reduces partial softmax results in a second kernel. The benchmark script now includes the split-K column for comparison.

BLOCK_G (the GQA group tile) uses `_next_power_of_2_unclamped()` to avoid inflating small group counts to 16. Phantom rows from over-sized tiles change register pressure and instruction scheduling, altering fp32 accumulation order enough to degrade output quality over long autoregressive sequences.

Standalone kernel benchmark on H100 (Qwen3.5 MoE decode, B=1, H_q=16, H_kv=2, D=256, bf16):

| Lk | ET Tiled (us) | ET Split-K (us) | Speedup |
|---|---|---|---|
| 64 | 131.8 | 259.5 | 0.5x |
| 512 | 98.9 | 221.5 | 0.4x |
| 4096 | 199.9 | 214.4 | 0.9x |
| 8192 | 392.2 | 211.3 | 1.9x |
| 16384 | 775.3 | 211.8 | 3.7x |

Split-K breaks even around Lk=4096 and dominates at longer sequences, where the tiled kernel's single-CTA-per-head bottleneck becomes severe.

This PR was authored with the assistance of Claude
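The rescaling identity the reduction kernel relies on can be checked with scalars (a sketch; the real kernel applies this over (B, H, D) tiles). Each split i keeps (m_i = split max, l_i = sum of exp(s - m_i), acc_i = unnormalized weighted sum), and two splits merge as:

```python
import math

def split_partials(scores, values):
    """Per-split softmax partials: (max, exp-sum, unnormalized output)."""
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    acc = sum(math.exp(s - m) * v for s, v in zip(scores, values))
    return m, l, acc

def merge(p1, p2):
    """Combine two partials by rescaling each to the shared max."""
    (m1, l1, a1), (m2, l2, a2) = p1, p2
    m = max(m1, m2)
    c1, c2 = math.exp(m1 - m), math.exp(m2 - m)
    return m, c1 * l1 + c2 * l2, c1 * a1 + c2 * a2
```

After the final merge, `acc / l` equals the full softmax-weighted sum, regardless of how the KV sequence was partitioned.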
The previous example used T=2, which caused AOTI to compile the
chunk_gated_delta_rule kernel for a single chunk (NT=1). At runtime,
prompts longer than 64 tokens (requiring NT>1 chunks) failed with
"Error resizing tensor at input 0". Using max_seq_len-1 as the
example ensures AOTI generalizes intermediate buffer sizes for the
full sequence length range.
Comparison against original export (tq4_sdpa fused kernel)
on H100 (Qwen3.5-35B-A3B, HQQ-INT4, max_seq_len=4096, 5 runs median):
| | Original (tq4_sdpa) | Baseline (Triton SDPA) |
|---|---|---|
| Decode tok/s | 68.4 | 61.7 |
| Prefill tok/s | 275.7 | 378.2 |
Baseline prefill is 1.37x faster; decode is 0.90x (tq4_sdpa's fused
decode kernel is faster than the tiled Triton SDPA at L_q=1). The
split-K commit addresses the decode gap.
This PR was authored with the assistance of Claude
Dual-method export (decode T=1, prefill T>=2) lets the model use a simple if/else on T instead of torch.cond, eliminating the GPU-to-CPU sync overhead that torch.cond's predicate evaluation requires. Decode calls sdpa_decode_splitk (split-K flash-decoding for high KV occupancy); prefill calls the tiled sdpa.

Guard sdpa_decode_splitk validation behind isinstance(L_q, int) so AOTI tracing with symbolic shapes doesn't trip the L_q==1 check. Align the sdpa_decode_splitk signature with sdpa (dropout_p, is_causal, enable_gqa) for a consistent API; unsupported args fail with clear messages.

This PR was authored with the assistance of Claude
Add a `use_splitk_decode` config flag to control whether FullAttention uses the split-K (flash-decoding) SDPA kernel or the tiled SDPA for decode (T=1). The split-K kernel partitions the KV sequence across CTAs, yielding ~20% higher decode throughput on H100:

| Variant | Decode tok/s (avg across prompts) |
|---|---|
| Tiled SDPA | 88.5 |
| Split-K SDPA | 107.5 (+21%) |

The flag defaults to True (split-K on). Pass `--no-splitk` at export time to disable. Quality is verified identical at temperature=0.

This PR was authored with the assistance of Claude
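The `--no-splitk` flag name comes from the commit message above; how it is wired into export.py is not shown in this diff, but a typical argparse pattern would be (illustrative only):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no-splitk",
    dest="use_splitk_decode",
    action="store_false",  # flag presence turns split-K off
    help="Use the tiled SDPA kernel for decode instead of split-K",
)
parser.set_defaults(use_splitk_decode=True)
```

With this wiring, the default export keeps split-K on, matching the commit message.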
Force-pushed 1af2029 to d5209fc
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
```python
        self.assertFalse(torch.isnan(out).any(), "NaN in output")
        self.assertLess(
            _max_abs_error(out, ref),
            0.05,
```
Performance Improvements for SDPA
Improves SDPA performance for decode sequences where $L_q = 1$.
Benchmark: qwen3.5-35B-A3B
Decode Performance (tok/s)
Prefill Performance (tok/s)
Summary
Output quality verified identical at temperature=0. ~25x speedup at the SDPA op level: across ~10.2K calls (1024 tokens x 10 layers) we saw 5.3 sec drop to 209 ms.
Implementation Details
_sdpa_fwd_kernel_m64).
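The occupancy argument behind these numbers can be made concrete. With B=1 and L_q=1, the tiled kernel launches roughly one CTA per query head, far short of a modern GPU's SM count; split-K multiplies that by the number of KV splits. The 256-tokens-per-split and 128-split cap come from the review discussion above; the helper names are illustrative:

```python
def tiled_ctas(B, H_q):
    # tiled SDPA at L_q == 1: roughly one CTA per (batch, query head)
    return B * H_q

def splitk_ctas(B, H_q, L_kv, tokens_per_split=256, max_splits=128):
    # split-K: each (batch, head) pair also parallelizes over KV splits
    num_splits = min(max(-(-L_kv // tokens_per_split), 1), max_splits)
    return B * H_q * num_splits
```

For the benchmarked shape (B=1, H_q=16), the tiled kernel occupies 16 CTAs at any L_kv, while split-K reaches 512 CTAs at L_kv=8192, which is consistent with split-K pulling ahead only at the longer sequence lengths in the table above.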