
[Refactor][Pipeline] Run pipeline rewriting before layout inference and stabilize tiled WS #2002

Merged
LeiWang1999 merged 89 commits into tile-ai:main from LeiWang1999:pipeline_refactor_0329
Apr 7, 2026
Conversation

LeiWang1999 (Member) commented Mar 31, 2026

Summary

  • run pipeline planning, software-pipeline injection, and tiled warp specialization before layout inference while preserving the stage context needed by downstream passes
  • harden multi-version buffer rewriting, layout inference, tile-op lowering, and sparse GEMM lowering for the new pass order, including BufferRegion-aware sparse MMA loads and stricter TMA stride validation
  • stabilize tiled warp specialization and mixed TMA/cp.async handling with branch-private buffer remapping, consumer-init sinking, safer barrier reuse, and targeted regression coverage
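To make the reordering concrete, here is an illustrative Python sketch of the pass sequence the summary describes. This is not the actual `tilelang/engine/phase.py` code; pass names are taken from the PR text, and the relative order of the WS pass versus pipeline planning is an assumption based on the commit messages below.

```python
# Sketch only: the real sequence lives in tilelang/engine/phase.py.
# Names come from the PR text; exact neighbours are assumptions.
new_order = [
    "ProducerConsumerWarpSpecializedTiled",  # tiled warp specialization
    "PipelinePlanning",
    "InjectSoftwarePipeline",
    "LayoutInference",   # now runs AFTER the pipeline-rewriting passes
    "LowerTileOp",
]

def runs_before(order, a, b):
    """True when pass `a` is scheduled ahead of pass `b` in `order`."""
    return order.index(a) < order.index(b)

# The point of the refactor: every pipeline-rewriting pass precedes
# layout inference.
for p in ("PipelinePlanning", "InjectSoftwarePipeline",
          "ProducerConsumerWarpSpecializedTiled"):
    assert runs_before(new_order, p, "LayoutInference")
```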

Testing

  • ./format.sh
  • cmake --build build -j$(nproc)
  • python -m pytest -q testing/python/transform/test_tilelang_transform_Inject_software_pipeline.py testing/python/transform/test_tilelang_transform_pipeline_planning.py testing/python/transform/test_tilelang_transform_producer_consumer_ws_tiled.py
  • python -m pytest -q testing/python/issue/test_tilelang_issue_tma_no_ws.py -k mixed_tma_cp_async_shared_stage_barriers
  • python -m pytest -q testing/python/issue/test_tilelang_issue_ws_simt_copy_full_producer_extent.py

Summary by CodeRabbit

  • New Features

    • Added warp-specialized producer/consumer transform and an instruction-classification pass; new transform APIs and layout-construction helpers exposed.
  • Performance Improvements

    • Reordered pipeline passes and improved lowering, multi-version buffering, swizzle/stride handling, and TMA/copy selection for more consistent codegen and layouts.
  • Bug Fixes

    • Improved buffer-region detection, pipeline/barrier handling, thread-bounds, and memory-access legality; more robust layout propagation for multi-version buffers.
  • Tests

    • Added GPU correctness tests, relaxed some kernel-source assertions, and introduced pipeline-aware test helpers.

LeiWang1999 and others added 30 commits March 20, 2026 15:09
…ling

- Consolidated the handling of shared barriers and pipeline planning by removing redundant conditional checks.
- Ensured that `LowerSharedBarrier`, `PipelinePlanning`, and `InjectSoftwarePipeline` are consistently applied, enhancing the clarity and efficiency of the optimization process.

This change improves the maintainability of the code while preserving existing functionality.
…tor gemm.h

- Added a new `InstructionAnnotation` pass to annotate tile operations with their instruction kind before layout inference, improving the optimization pipeline's ability to reason about instruction mixes.
- Refactored `gemm.h` to move the `allowTcgen5Mma` and `allowWgmma` methods under the private section, enhancing code organization and encapsulation.

These changes improve the clarity and maintainability of the code while preserving existing functionality.
…pecialization

- Updated `multi_version_buffer_rewriter.cc` to improve read/write access detection for tile operations by analyzing `tl.tileop.region` calls, ensuring accurate buffer access tracking.
- Modified `phase.py` to integrate `ProducerConsumerWarpSpecializedTiled` before layout inference, allowing for high-level tile-op IR transformations that enhance producer/consumer splits.
- Added a new `ProducerConsumerWarpSpecializedTiled` function in `__init__.py` to facilitate tile-level warp specialization, improving the optimization pipeline's efficiency.

These changes enhance the handling of multi-version buffers and optimize the transformation process for tiled operations.
… plan

- inject_pipeline.cc: guard reads/writes recalculation with subtree_modified_
  flag to prevent local.var buffer promotion to kernel parameters
- phase.py: add temporary debug prints after PipelinePlanning and
  InjectSoftwarePipeline (to be removed during CI fix work)
- example_group_per_split_token_cast_to_fp8.py: add disable_cache for debugging
- docs/plan.md: implementation plan for fixing CI test failures
- draft.md: original design draft

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Remove print("After PipelinePlanning"), print(mod),
  print("After InjectSoftwarePipeline"), print(mod) from
  LowerAndLegalize in tilelang/engine/phase.py
- Remove tilelang.disable_cache() from
  examples/cast/example_group_per_split_token_cast_to_fp8.py

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
After InjectSoftwarePipeline, multi-versioned buffers share the same
data Var as the original but have an extra leading dimension (num_stages).
LayoutInference's alias propagation and annotation handling tried to
Reshape layouts between these buffers, which failed because the total
element counts differ.

Guard three Reshape call sites in layout_inference.cc to skip sibling
buffers whose total storage size is incompatible with the source layout.
This lets multi-versioned buffers get their own layout inference instead
of inheriting an incompatible layout from the original buffer.

Fixes compilation failures in dequantize_gemm, GDN, and other kernels
that use software pipelining with shared memory buffers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
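The incompatibility above comes down to a total-element-count check. A minimal Python stand-in for the guard added to `layout_inference.cc` (the real code operates on TVM layouts and buffers, not shapes):

```python
from math import prod

def can_reshape(src_shape, dst_shape):
    # A layout can only be Reshape'd between sibling buffers when the
    # total element counts agree. Multi-versioned buffers carry an extra
    # leading num_stages dimension, so their totals differ from the
    # original buffer's and must be skipped; they then get their own
    # layout inference instead of inheriting an incompatible layout.
    return prod(src_shape) == prod(dst_shape)

orig = (128, 64)     # original shared buffer
mvb = (3, 128, 64)   # after InjectSoftwarePipeline with num_stages = 3
assert can_reshape(orig, (64, 128))   # plain reshape: same element count
assert not can_reshape(orig, mvb)     # MVB sibling: skip the Reshape
```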
When ProducerConsumerWarpSpecializedTiled identifies a TMA kernel as a
warp-specialization candidate but the tiled rewriter cannot handle it
(e.g., conditional loop bodies like sparse block masks), the fallback
previously returned the original function with num_stages annotations
intact.  PipelinePlanning and InjectSoftwarePipeline would then generate
non-WS TMA pipeline code with broken barrier phase tracking for
conditional pipeline bodies (barrier waits outside conditionals cause
deadlocks when the condition is false).

Fix: on WS fallback, strip num_stages annotations from pipeline loops
so that the pipeline passes skip the function.  The kernel runs
unpipelined but correctly.

Fixes CUDA_ERROR_LAUNCH_FAILED in blocksparse_gemm and related TMA
kernels with conditional loop bodies.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Port PhaseCounter and StageExprReplacer from the legacy
ProducerConsumerWarpSpecialized pass into the tiled WS pass to handle
conditional loop bodies (e.g., sparse block masks).

When the pipeline loop body is wrapped in an IfThenElse without else:
1. Unwrap the condition before classifying statements
2. Create separate producer/consumer PhaseCounters (local int32 buffers)
3. Use counter-based stage/parity expressions instead of loop-variable
4. Wrap producer and consumer bodies in the original condition
5. Increment counters at end of each guarded iteration
6. Rewrite shared-buffer stage indices via StageExprReplacer

This ensures barrier parity stays correct when iterations are
conditionally skipped, fixing CUDA_ERROR_LAUNCH_FAILED in
blocksparse_gemm and related TMA kernels with conditional execution.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
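The counter-versus-loop-variable distinction can be simulated directly. A hedged Python sketch (the real PhaseCounter is IR rewriting, not a runtime loop) showing why the guarded counter keeps stage/parity consecutive when iterations are conditionally skipped:

```python
def parity_trace(mask, num_stages=2, use_counter=True):
    """Return (stage, parity) pairs seen by executed iterations.

    With use_counter=False, stage/parity derive from the loop variable,
    which advances even for skipped iterations; with use_counter=True a
    private counter advances only inside the guard, as in step 5 above.
    """
    trace, counter = [], 0
    for k, active in enumerate(mask):
        if not active:
            continue
        idx = counter if use_counter else k
        trace.append((idx % num_stages, (idx // num_stages) % 2))
        counter += 1
    return trace

mask = [True, False, True, True]  # e.g. a sparse block mask
# Counter-based indexing yields consecutive stages regardless of skips...
assert parity_trace(mask) == [(0, 0), (1, 0), (0, 1)]
# ...while loop-variable indexing skips stage slots and flips parity
# early, which is what desynchronized the barriers.
assert parity_trace(mask, use_counter=False) == [(0, 0), (0, 1), (1, 1)]
```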
Add GemmSPNode handling to:
- inject_pipeline.cc AddReadsWritesForTileOp: model A, E, B as reads
  and C as write (E is the sparse metadata buffer)
- pipeline_planning.cc: same access model for dependency analysis

This makes sparse GEMM visible to the pipeline machinery for correct
stage assignment and buffer multi-versioning. However, the tile-op's
consumer-side buffer accesses still don't get stage-indexed because
the pipeline body rewriter can't rewrite high-level tile-op Call
arguments (an architectural limitation of running InjectSoftwarePipeline
before LowerTileOp).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace whole-buffer access_ptr(1) calls with MakeAccessPtrFromRegion
for A, B, C, and E buffers in sparse GEMM lowering. This preserves
stage-specific region offsets from pipeline multi-versioning, matching
the dense GemmNode::Lower pattern.

CUDA output now shows correct stage-indexed consumer accesses:
  gemm_sp_ss(..., (k%3)*8192, (k%3)*8192+27648, C_local, (k%3)*2048+49152)
instead of always using stage-0 offsets.

Note: gemm_sp still produces incorrect results because the kernel needs
warp specialization but TiledWSCandidate::Check doesn't recognize it
as a TMA candidate. The non-WS pipeline path generates structurally
different code from the reference WS path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
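The stage-indexed offsets in the CUDA output above follow a simple formula. A small Python sketch of that arithmetic (the strides 8192/2048 and num_stages=3 are taken from the example output, not from any API):

```python
def stage_offset(base, stage_stride, k, num_stages=3):
    # Stage-indexed offset into a multi-versioned shared buffer: pipeline
    # iteration k maps to region (k % num_stages), each stage_stride
    # elements apart, past `base`. The pre-fix bug corresponds to
    # dropping the (k % num_stages) term and always reading stage 0.
    return (k % num_stages) * stage_stride + base

# A-buffer offsets cycle through the three stages:
assert [stage_offset(0, 8192, k) for k in range(4)] == [0, 8192, 16384, 0]
# E (sparse metadata) at base 49152 with stride 2048, as in the example:
assert stage_offset(49152, 2048, 4) == (4 % 3) * 2048 + 49152
```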
Three changes in producer_consumer_ws_tiled.cc:
1. Require num_stages >= 2 for WS candidacy (was >= 1). Single-stage
   kernels like seer_attention don't need WS and the transformation
   produces incorrect results for them.
2. Add HasTmaPipeline() check to detect TMA kernels with pipeline
   annotations that are rejected by the full WS candidate check (e.g.,
   kernels with manual layout annotations like gemm_sp).
3. Strip num_stages annotations for rejected TMA pipeline kernels to
   prevent InjectSoftwarePipeline from generating broken non-WS TMA
   pipeline code.

One change in gemm_sp.cc:
- Use MakeAccessPtrFromRegion for A, B, C, E buffer access pointers
  instead of whole-buffer access_ptr. This preserves stage-specific
  region offsets from pipeline multi-versioning.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Only unwrap IfThenElse wrapper when the then-branch is a simple flat
sequence of tile-op Evaluate calls. Skip unwrapping for complex bodies
with LetStmt, For, or other control flow that could break variable
scoping when split into producer/consumer for WS.

Fixes variable-used-before-definition error in
blocksparse_attention sparse_gqa_decode_varlen_indice, which has a
conditional loop body containing LetStmt bindings.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The >= 2 threshold broke test_num_stages_one_pure_tma_keeps_auto_warp_specialize.
Pure TMA kernels with num_stages=1 should still be WS candidates.
The seer_attention issue (num_stages=1 with manual layout) is handled
by the has_manual_layout_ check, not the num_stages threshold.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When LayoutInference encounters an MVB-expanded buffer (with leading
stage dimensions) whose trailing dimensions match the original layout,
use Layout::Expand to propagate the manual layout instead of rejecting
or skipping the buffer. Applied to all 3 layout propagation paths:
annotated layout map, alias propagation, and finalization.

Also remove the blanket !has_manual_layout_ WS candidate rejection
since manual layouts now survive onto versioned shared buffers via
Layout::Expand.

Fixes test_sparse_ws_regular_metadata_copy_stays_in_producer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
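At the shape level, the Expand decision can be sketched as follows. This is a deliberately simplified stand-in: `Layout::Expand` operates on layout objects, not shape tuples, and the real check compares the layout's input shape against the buffer's trailing dimensions.

```python
def expand_layout(orig_shape, buf_shape):
    """Shape-level stand-in for the Layout::Expand path described above:
    an MVB-expanded buffer whose trailing dims match the original layout
    keeps that layout, prefixed by identity dims for the leading stage
    axes; any other mismatch is rejected (no propagation)."""
    n = len(orig_shape)
    if len(buf_shape) >= n and tuple(buf_shape[-n:]) == tuple(orig_shape):
        return tuple(buf_shape[:-n]) + tuple(orig_shape)
    return None

# 3-stage MVB expansion of a (128, 64) shared buffer: propagate.
assert expand_layout((128, 64), (3, 128, 64)) == (3, 128, 64)
# Trailing dims differ: reject rather than force an incompatible layout.
assert expand_layout((128, 64), (3, 64, 128)) is None
```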
has_manual_layout_ guard

1. LetStmt chain peeling: when IfThenElse then_case starts with
   LetStmt bindings, peel them and append to let_bindings before
   checking the simple-body guard. This allows WS for conditional
   bodies with variable definitions (e.g., sparse attention patterns).

2. Restore !has_manual_layout_ in WS candidacy check: removing it
   caused dequant_groupedgemm_bf16_mxfp4_hopper to fail because
   MXFP4 layouts don't survive MVB expansion. The Layout::Expand fix
   handles sparse metadata layouts but not all manual layout types.

3. Layout::Expand propagation (from Round 6) remains in place for
   future use when MVB learns to handle all manual layout types.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The simple-body guard only accepted flat Evaluate sequences inside
IfThenElse, blocking legitimate WS for complex conditional bodies
like sparse flash attention (T.clear, T.reduce_max, T.gemm inside
the guard). The LetStmt peeling already handles variable scoping.

Fixes test_pure_tma_consumer_local_init_does_not_leak_into_producer
and test_sparse_ws_regular_metadata_copy_stays_in_producer.

The remaining test_mixed_tma_cp_async_shared_stage_barriers failure
is a pre-existing issue on the original branch: the tiled WS pass
produces SIMT copies instead of cp.async because LowerPTXAsyncCopy
is not in the pass pipeline (the comment at phase.py:308 says it
runs earlier but no actual call exists on either branch or reference).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Wrap the WS producer loop body in a kPipelineContextNumStages AttrStmt
so that LowerTileOp's pipelined_depth_ is > 0 when processing SIMT
producer copies. This enables InjectPTXAsyncCopy to generate cp.async
for global-to-shared copies in the WS producer branch.

Without this, the WS rewriter strips all pipeline annotations from
the rewritten loops, causing LowerTileOp to skip cp.async injection
for SIMT producers. The consumer loop stays annotation-free since it
doesn't need cp.async.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…out guard

1. Mixed TMA+cp.async barrier: use ptx_cp_async_barrier_noinc for
   forward barrier arrival in mixed producer groups, matching the
   reference producer_consumer_ws.cc protocol.

2. Consumer-only pre-loop init sinking: in ReplacePipelineLoopInStmt,
   guard pre-loop siblings as consumer-only when they're not classified
   as producer (TMA/SIMT/cp.async). Fragment init (T.fill, T.clear)
   and local buffer init are placed in the consumer branch instead of
   the shared prelude.

3. Restored blanket has_manual_layout_ guard: the dtype-based heuristic
   to distinguish sparse metadata from MXFP4 layouts doesn't work
   because both use uint8. Dequant_groupedgemm_bf16_mxfp4_hopper
   requires the guard to prevent broken WS.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… barrier

1. Consumer-init sinking: only sink pre-loop stmts that are FillNode
   writing to fragment/local buffers. Keep block_mask setup and shared
   state in the shared prelude.

2. Manual-layout: attempt targeted check that only rejects when TMA
   copy destinations match layout_map entries. Collect layout_map vars
   and compare against TMA copy destinations.

3. Per-group cp.async barrier: use group-level cp.async flag (single
   group for now) instead of function-wide boolean.

The 3 WS issue tests still fail because the layout_map annotation
parsing falls through to the conservative rejection path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The layout_map annotation uses Map<Var, Layout> before LayoutInference
(not Map<Buffer, Layout>). Parse as Map<ObjectRef, ObjectRef> and
handle both key types: Buffer (post-inference) and Var (pre-inference).
For Var keys, look up the corresponding alloc_buffer by data Var match.

Compare collected manual-layout buffers against pipeline copy
destinations: reject only when a manually-laid-out buffer is also
a producer copy target (TMA/SIMT/cp.async) inside the pipeline.
This allows sparse metadata (E_shared, SIMT-copied) onto the WS path
while rejecting dequant MXFP4 (B_shared, SIMT-copied with swizzle).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
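The dual-key parsing can be sketched in Python. These are hypothetical stand-ins: the real code walks TVM `Buffer`/`Var` objects through a `Map<ObjectRef, ObjectRef>`; here a Buffer is a namedtuple and a Var is a plain string naming the buffer's data Var.

```python
from collections import namedtuple

# Hypothetical stand-in for a TVM buffer: name plus its data Var.
Buffer = namedtuple("Buffer", ["name", "data"])

def resolve_layout_buffers(layout_map, alloc_buffers):
    """Accept both key types described above: Buffer keys
    (post-inference) are used directly; Var keys (pre-inference) are
    resolved to the alloc_buffer whose data Var matches."""
    by_var = {buf.data: buf for buf in alloc_buffers}
    resolved = []
    for key, layout in layout_map.items():
        buf = key if isinstance(key, Buffer) else by_var.get(key)
        if buf is not None:
            resolved.append((buf.name, layout))
    return resolved

E = Buffer("E_shared", "vE")   # sparse metadata buffer
B = Buffer("B_shared", "vB")   # dequant weight buffer
# Mixed key types: a pre-inference Var key and a post-inference Buffer key.
assert resolve_layout_buffers({"vE": "L0", B: "L1"}, [E, B]) == [
    ("E_shared", "L0"), ("B_shared", "L1")]
```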
1. Fixed SEGFAULT: removed dangling layout_map_layouts_ vector that
   was never populated, causing OOB access. Now stores Buffer+Layout
   pairs in layout_map_entries_.

2. Use DetectSwizzleMode to distinguish swizzled layouts (MXFP4,
   incompatible with MVB) from non-swizzled (sparse metadata, safe).
   Swizzled layouts reject WS candidacy; non-swizzled layouts allow it.

3. Removed debug LOG(WARNING) from hot path.

4. Parse layout_map annotation keys as both Buffer and Var (via
   Map<ObjectRef, ObjectRef>), resolving Var keys to alloc_buffers.

Results: sparse metadata WS test PASSES, dequant PASSES (protected).
Down to 2 failures: mixed barrier pattern + consumer init sinking.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sink pre-loop Evaluate nodes classified as kConsumer into consumer
branch. Keep For loops (block_mask setup), producer copies, and other
control flow in the shared prelude. This is simpler and safer than
the FillNode scope check approach.

The test_pure_tma_consumer_local_init test still fails because the
T.fill statements are at a different structural level than the
SeqStmt where the pipeline loop lives. Fixing this requires deeper
IR structure analysis.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extract consumer-only pre-loop Evaluate statements (T.fill on
fragments) from the shared prelude and prepend them to the consumer
branch inside the WS if/else structure. This ensures fragment init
like acc_o, logsum, scores_max appears only in the consumer branch,
not in the shared prelude or producer branch.

Uses a two-pass approach: first ReplacePipelineLoopInStmt extracts
consumer-only stmts into extracted_consumer_init_, then the WS body
is rebuilt with the extracted stmts prepended to the consumer branch.

Fixes test_pure_tma_consumer_local_init_does_not_leak_into_producer.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
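The first pass of that extraction is essentially a classification-driven partition of the prelude. A minimal Python sketch under that assumption (statement names and the classifier are illustrative; the real pass works on TIR statements):

```python
def split_prelude(prelude, classify):
    """Sketch of consumer-init sinking: pre-loop statements classified
    as consumer-only (e.g. T.fill on fragments) are extracted so they
    can be prepended to the consumer branch; everything else stays in
    the shared prelude visible to both warp groups."""
    shared, consumer_init = [], []
    for stmt in prelude:
        (consumer_init if classify(stmt) == "consumer" else shared).append(stmt)
    return shared, consumer_init

# Illustrative prelude statements and their (assumed) classifications.
kinds = {"init_block_mask": "shared", "fill_acc_o": "consumer",
         "fill_logsum": "consumer", "tma_prefetch": "producer"}
shared, init = split_prelude(list(kinds), kinds.get)
assert shared == ["init_block_mask", "tma_prefetch"]
assert init == ["fill_acc_o", "fill_logsum"]   # prepended to consumer branch
```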
LeiWang1999 (Member Author) commented:

@regression-perf

github-actions Bot commented Apr 5, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/23998738578

Results

File Original Latency Current Latency Speedup
example_tilelang_gemm_splitk 1.09 1.14124 0.955102
example_tilelang_gemm_splitk_vectorize_atomicadd 1.10217 1.14628 0.96152
example_warp_specialize_gemm_softpipe_stage2 0.0269336 0.0277273 0.971374
example_warp_specialize_gemm_copy_1_gemm_0 0.0269413 0.0277208 0.971879
example_warp_specialize_gemm_copy_0_gemm_1 0.03885 0.0399303 0.972947
example_warp_specialize_gemm_barrierpipe_stage2 0.0397743 0.0407113 0.976983
example_mhc_pre 0.149518 0.152303 0.981716
example_mha_bwd_bshd 0.0392325 0.0399242 0.982675
example_mha_fwd_varlen 0.0455182 0.0461243 0.98686
example_mha_sink_bwd_bhsd 0.0629021 0.0636204 0.98871
example_gqa_decode 0.0482183 0.0487408 0.989281
example_mha_bwd_bhsd 0.038846 0.0392603 0.989447
example_linear_attn_fwd 0.0360089 0.036352 0.990561
example_dequant_gemm_bf16_fp4_hopper 0.568597 0.573274 0.991842
example_gqa_sink_bwd_bhsd 0.0419182 0.0422393 0.992398
example_mla_decode 0.460575 0.462221 0.996438
example_vertical_slash_sparse_attn 0.230965 0.231426 0.998009
example_mhc_post 0.108988 0.109124 0.998746
example_elementwise_add 0.115712 0.115827 0.999001
example_gemm 0.0224023 0.0224144 0.999461
example_dequant_gemv_fp16xint4 0.0283372 0.0283474 0.999641
example_topk 0.0111128 0.0111145 0.999845
example_dequant_gemm_fp4_hopper 1.05168 1.05171 0.999966
example_gemv 0.288209 0.28821 0.999998
tilelang_example_sparse_tensorcore 0.014646 0.0146435 1.00017
example_dynamic 0.643648 0.643115 1.00083
example_gemm_intrinsics 0.03487 0.0348379 1.00092
example_mha_inference 0.0784942 0.0784137 1.00103
example_gemm_autotune 0.0225312 0.0225029 1.00126
example_convolution_autotune 0.984679 0.983216 1.00149
example_tilelang_gemm_fp8 0.311647 0.31118 1.0015
example_fusedmoe_tilelang 0.13315 0.132905 1.00185
example_tilelang_gemm_fp8_intrinsic 0.843726 0.842151 1.00187
example_per_token_cast_to_fp8 0.00739086 0.00737472 1.00219
example_gqa_fwd_bshd 0.070394 0.0702373 1.00223
example_group_per_split_token_cast_to_fp8 0.0104383 0.0103969 1.00398
fp8_lighting_indexer 0.0361452 0.0359887 1.00435
example_tilelang_gemm_fp8_2xAcc 0.187801 0.186795 1.00538
example_tilelang_nsa_decode 0.00746856 0.00742754 1.00552
example_mha_fwd_bshd 0.0259248 0.0257791 1.00565
example_tilelang_nsa_fwd 0.0069037 0.00686397 1.00579
example_blocksparse_gemm 0.0202157 0.0200824 1.00664
topk_selector 0.0543412 0.0539223 1.00777
example_linear_attn_bwd 0.152995 0.151672 1.00872
example_convolution 1.31051 1.29898 1.00888
example_tilelang_block_sparse_attn 0.00883077 0.00874075 1.0103
example_tilelang_sparse_gqa_decode_varlen_mask 0.0178499 0.0176193 1.01309
example_gqa_sink_bwd_bhsd_sliding_window 0.0259062 0.0255349 1.01454
example_mha_fwd_bhsd 0.0108134 0.0106092 1.01925
example_dequant_gemm_w4a8 5.68837 5.57947 1.01952
sparse_mla_fwd_pipelined 0.0965303 0.0945469 1.02098
example_tilelang_sparse_gqa_decode_varlen_indice 0.0163759 0.0160163 1.02245
example_dequant_gemm_bf16_mxfp4_hopper 0.521254 0.509405 1.02326
example_mha_sink_bwd_bhsd_sliding_window 0.0447449 0.0432303 1.03504
example_mha_sink_fwd_bhsd 0.0155182 0.0149903 1.03522
sparse_mla_fwd 0.132773 0.127851 1.0385
block_sparse_attn_tilelang 0.00922644 0.00884454 1.04318
example_mha_sink_fwd_bhsd_sliding_window 0.0159049 0.0150199 1.05892
example_gqa_bwd 0.0507554 0.0463927 1.09404
example_gqa_bwd_tma_reduce_varlen 0.0525196 0.0474874 1.10597
sparse_mla_bwd 0.425934 0.3069 1.38786

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

LeiWang1999 (Member Author) commented:

@regression-perf

github-actions Bot commented Apr 6, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24024286051

Results

File Original Latency Current Latency Speedup
example_gemm_autotune 0.0224844 0.0256021 0.878227
example_mha_sink_bwd_bhsd_sliding_window 0.0443557 0.0486504 0.911724
example_blocksparse_gemm 0.0199906 0.0215128 0.92924
example_tilelang_block_sparse_attn 0.00876865 0.00943179 0.929691
example_convolution_autotune 0.982797 1.05651 0.930227
example_mha_sink_fwd_bhsd 0.0153866 0.0163441 0.941415
example_vertical_slash_sparse_attn 0.231049 0.242895 0.951232
example_mha_sink_bwd_bhsd 0.0624514 0.0656417 0.951399
example_mha_bwd_bhsd 0.0387989 0.0406864 0.953611
example_mha_fwd_bhsd 0.0108164 0.0113331 0.954413
example_mha_bwd_bshd 0.0391646 0.0410219 0.954726
example_convolution 1.29596 1.34542 0.963242
example_mha_fwd_bshd 0.0258537 0.0266323 0.970766
example_warp_specialize_gemm_copy_1_gemm_0 0.0269154 0.0276903 0.972016
example_warp_specialize_gemm_softpipe_stage2 0.026919 0.0276843 0.972355
example_warp_specialize_gemm_copy_0_gemm_1 0.03882 0.0398691 0.973687
example_gqa_fwd_bshd 0.0703474 0.0719813 0.977301
example_warp_specialize_gemm_barrierpipe_stage2 0.0397814 0.0405745 0.980453
example_gqa_sink_bwd_bhsd 0.0413668 0.0421738 0.980864
example_mha_fwd_varlen 0.0454986 0.0462076 0.984656
example_tilelang_gemm_fp8_2xAcc 0.187109 0.189205 0.988922
example_mla_decode 0.454266 0.459332 0.98897
example_mha_sink_fwd_bhsd_sliding_window 0.0157457 0.0157932 0.996991
example_topk 0.0109947 0.0110162 0.99805
example_group_per_split_token_cast_to_fp8 0.0103416 0.0103518 0.999019
tilelang_example_sparse_tensorcore 0.0145877 0.0146009 0.999094
example_dynamic 0.642984 0.643242 0.999598
example_tilelang_gemm_fp8 0.311491 0.311576 0.999727
example_gemm_intrinsics 0.0348357 0.0348364 0.99998
example_per_token_cast_to_fp8 0.00735897 0.00735614 1.00038
example_tilelang_gemm_fp8_intrinsic 0.843636 0.842143 1.00177
example_gemm 0.0224349 0.0223786 1.00252
example_gqa_sink_bwd_bhsd_sliding_window 0.0255811 0.0254763 1.00411
example_tilelang_sparse_gqa_decode_varlen_mask 0.0176198 0.0175368 1.00474
example_elementwise_add 0.115711 0.114975 1.0064
block_sparse_attn_tilelang 0.00919018 0.00908102 1.01202
example_tilelang_sparse_gqa_decode_varlen_indice 0.0161803 0.0159441 1.01482
example_gqa_bwd_tma_reduce_varlen 0.0524131 0.0514666 1.01839
example_tilelang_gemm_splitk 1.09043 1.02697 1.06179
example_tilelang_gemm_splitk_vectorize_atomicadd 1.10049 1.03287 1.06547
example_gqa_bwd 0.0506896 0.0471402 1.07529

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

LeiWang1999 (Member Author) commented:

@regression-perf

github-actions Bot commented Apr 6, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24030730315

Results

File Original Latency Current Latency Speedup
example_vertical_slash_sparse_attn 0.226222 0.2389 0.946931
example_warp_specialize_gemm_copy_0_gemm_1 0.0384375 0.0396086 0.970434
example_warp_specialize_gemm_softpipe_stage2 0.0264394 0.0272203 0.97131
example_warp_specialize_gemm_copy_1_gemm_0 0.0264346 0.0272093 0.971526
example_warp_specialize_gemm_barrierpipe_stage2 0.0393872 0.0402669 0.978152
example_mha_sink_bwd_bhsd 0.0614657 0.0626678 0.980817
example_gqa_sink_bwd_bhsd 0.0405153 0.0412793 0.98149
example_mha_bwd_bshd 0.0385998 0.0393161 0.981779
example_mla_decode 0.442456 0.450163 0.982879
example_mha_fwd_varlen 0.0449755 0.0457313 0.983472
example_mha_bwd_bhsd 0.0382115 0.0386523 0.988596
example_tilelang_gemm_fp8_2xAcc 0.192797 0.193242 0.997698
example_convolution 1.27266 1.2755 0.997772
example_blocksparse_gemm 0.019813 0.0198347 0.998903
example_per_token_cast_to_fp8 0.00734028 0.00734608 0.99921
example_tilelang_gemm_fp8 0.308206 0.308323 0.999618
example_gemm_autotune 0.0222989 0.0223059 0.999686
tilelang_example_sparse_tensorcore 0.0144787 0.0144811 0.999831
example_gemm_intrinsics 0.0342651 0.0342649 1.00001
example_dynamic 0.637831 0.637826 1.00001
example_topk 0.0108872 0.0108861 1.00011
example_convolution_autotune 0.979676 0.979333 1.00035
example_group_per_split_token_cast_to_fp8 0.010326 0.0103201 1.00057
example_gqa_fwd_bshd 0.0691743 0.0691009 1.00106
example_gemm 0.0223091 0.0222819 1.00122
example_tilelang_gemm_fp8_intrinsic 0.823992 0.822985 1.00122
example_elementwise_add 0.115813 0.115475 1.00292
example_mha_fwd_bshd 0.0255816 0.0254701 1.00438
example_gqa_sink_bwd_bhsd_sliding_window 0.0251406 0.0250186 1.00488
example_tilelang_sparse_gqa_decode_varlen_mask 0.0174402 0.0173549 1.00492
example_tilelang_block_sparse_attn 0.00873252 0.00865662 1.00877
example_tilelang_sparse_gqa_decode_varlen_indice 0.0160373 0.0157966 1.01524
example_mha_fwd_bhsd 0.0107429 0.0105489 1.01839
example_mha_sink_fwd_bhsd 0.0152338 0.0148184 1.02803
example_mha_sink_bwd_bhsd_sliding_window 0.0440208 0.0427673 1.02931
block_sparse_attn_tilelang 0.00911659 0.00873719 1.04342
example_mha_sink_fwd_bhsd_sliding_window 0.0156315 0.0148515 1.05252
example_tilelang_gemm_splitk_vectorize_atomicadd 1.0848 1.02445 1.05891
example_tilelang_gemm_splitk 1.07871 1.0161 1.06162
example_gqa_bwd 0.0498685 0.045719 1.09076
example_gqa_bwd_tma_reduce_varlen 0.0515569 0.0467737 1.10226

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

LeiWang1999 (Member Author) commented:

@regression-perf

github-actions Bot commented Apr 6, 2026

Performance Regression Test Report

Triggered by: @LeiWang1999
Workflow run: https://github.com/tile-ai/tilelang/actions/runs/24045066648

Results

File Original Latency Current Latency Speedup
example_vertical_slash_sparse_attn 0.226136 0.238922 0.946483
example_warp_specialize_gemm_barrierpipe_stage2 0.0392853 0.0404261 0.971779
example_warp_specialize_gemm_copy_0_gemm_1 0.0384511 0.0395622 0.971915
example_warp_specialize_gemm_copy_1_gemm_0 0.0264535 0.027195 0.972734
example_warp_specialize_gemm_softpipe_stage2 0.0264598 0.0271897 0.973155
example_mha_bwd_bshd 0.0385429 0.0393904 0.978484
example_mha_sink_bwd_bhsd 0.0613384 0.0625817 0.980132
example_gqa_sink_bwd_bhsd 0.0405288 0.0412971 0.981395
example_mla_decode 0.442537 0.450189 0.983004
example_mha_fwd_varlen 0.0448769 0.04562 0.983711
example_mha_bwd_bhsd 0.0382868 0.0386584 0.990387
example_per_token_cast_to_fp8 0.00732985 0.00735217 0.996964
example_tilelang_gemm_fp8_2xAcc 0.192806 0.193287 0.997508
example_convolution 1.27271 1.27496 0.998233
example_blocksparse_gemm 0.019786 0.0198161 0.998483
example_group_per_split_token_cast_to_fp8 0.0103248 0.0103375 0.998772
example_gemm 0.022273 0.0222903 0.999226
example_gemm_autotune 0.0222929 0.0223082 0.999317
example_convolution_autotune 0.979152 0.979405 0.999742
example_topk 0.010897 0.0108998 0.99975
tilelang_example_sparse_tensorcore 0.0144696 0.0144705 0.999942
example_dynamic 0.638063 0.637839 1.00035
example_tilelang_gemm_fp8 0.308287 0.308177 1.00036
example_gemm_intrinsics 0.0342599 0.0342474 1.00037
example_gqa_fwd_bshd 0.0690314 0.0689696 1.0009
example_elementwise_add 0.115783 0.115401 1.00331
example_mha_fwd_bshd 0.0255118 0.0253945 1.00462
example_gqa_sink_bwd_bhsd_sliding_window 0.0251145 0.0249901 1.00498
example_tilelang_sparse_gqa_decode_varlen_mask 0.0174437 0.0173556 1.00508
example_tilelang_block_sparse_attn 0.00872245 0.0086575 1.0075
example_tilelang_gemm_fp8_intrinsic 0.824282 0.814198 1.01239
example_tilelang_sparse_gqa_decode_varlen_indice 0.0160249 0.015788 1.015
example_mha_fwd_bhsd 0.0107432 0.0105372 1.01954
example_mha_sink_fwd_bhsd 0.0152298 0.0148072 1.02854
example_mha_sink_bwd_bhsd_sliding_window 0.0439841 0.0426634 1.03095
block_sparse_attn_tilelang 0.00910388 0.00873056 1.04276
example_mha_sink_fwd_bhsd_sliding_window 0.0155792 0.0148316 1.05041
example_tilelang_gemm_splitk 1.07255 1.01718 1.05444
example_tilelang_gemm_splitk_vectorize_atomicadd 1.0844 1.02609 1.05683
example_gqa_bwd 0.0498591 0.0457105 1.09076
example_gqa_bwd_tma_reduce_varlen 0.0515541 0.0467743 1.10219

Artifacts

  • regression_result.png (speedup plot) is attached as a workflow artifact. Download it from the workflow run page above.

@LeiWang1999 LeiWang1999 merged commit 3ee0988 into tile-ai:main Apr 7, 2026
5 of 6 checks passed
LeiWang1999 added a commit to LeiWang1999/tilelang that referenced this pull request Apr 12, 2026
Two unrelated ROCm bugs were latent on main and surfaced on this PR's
ROCm CI run. Both predate the GEMM v1 removal work - the previous "green"
ROCm run (job 70894846938 on commit 0a34a6a) hit the exact same nine
failures but happened to receive exit code 0 from pytest-xdist after
maxfail kicked in, masking the failure. Fix both so the ROCm job is
actually green, not accidentally green.

Bug 1: HIP threadblock swizzle codegen emits runtime call instead of
template instantiation
----------------------------------------------------------------------

Introduced in 3ee0988 (tile-ai#2002). codegen_hip.cc rewrote the
threadblock_swizzle_pattern attribute handler to take a
(func_name, panel_size) tuple but emits panel_size as a runtime
argument: const dim3 blockIdx = tl::rasterization2DRow(10); while the
template in tl_templates/hip/threadblock_swizzle.h declares it as a
template parameter: template <int panel_width> dim3 rasterization2DRow().
hipcc rightly rejects the call with "no matching function". The CUDA
counterpart in codegen_cuda.cc already emits tl::func_name<panel_size>()
- mirror that in HIP.

Affects test_tilelang_gemm_mfma_intrinsic.py::test_assert_tl_matmul.

Bug 2: MFMA macro generator hardcodes 2D buffer indexing and breaks
under pipelined shared buffers
----------------------------------------------------------------------

MatrixCoreIntrinEmitter.ldmatrix_a / ldmatrix_b in
intrinsics/mfma_macro_generator.py extract A_base0/A_base1 from the last
two region dims but then index A_buf with exactly two indices:
A_buf[A_base0 + l + row, A_base1 + r + col]. This works when the
user-declared shared buffer is 2D, but pipeline multi-versioning
(T.Pipelined(..., num_stages >= 2)) rewrites the shared buffer to carry
a leading stage dimension, making A_buf 3D. The access then fails layout
inference with "Buffer A_shared is 3-dimensional, cannot be indexed with
the 2-dimensional indices provided". The CUDA counterpart in
intrinsics/mma_macro_generator.py handles this correctly by collecting
leading base offsets into A_other and indexing as
A_buf[tuple(A_other) + (A_base0 + ..., A_base1 + ...)]. Mirror that
pattern in the MFMA generator for both ldmatrix_a and ldmatrix_b.

Affects test_block_sparse_matmul_{global,shared,local} and
test_tilelang_jit_{callback,gemm_cython}::test_gemm_jit_kernel plus
test_cython_kernel_multi_stream.
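The indexing fix mirrored from the CUDA generator can be shown with plain nested lists. A hedged sketch (buffer contents and names are illustrative; the real emitters build TIR index expressions, not Python subscripts):

```python
def load_tile(buf, bases, base0, base1, rows, cols):
    """Mirror of the mma_macro_generator indexing pattern described
    above: leading base offsets (e.g. the pipeline stage index, here
    `bases` standing in for A_other) are consumed first, then the last
    two dims are indexed with base0/base1. The same code thus handles a
    2D shared buffer and a 3D multi-versioned one."""
    for b in bases:   # descend leading (e.g. stage) dimensions first
        buf = buf[b]
    return [[buf[base0 + r][base1 + c] for c in range(cols)]
            for r in range(rows)]

A2d = [[r * 16 + c for c in range(16)] for r in range(16)]
A3d = [A2d, [[v + 1000 for v in row] for row in A2d]]  # leading stage dim
# Stage 1 of the 3D buffer returns the stage-1 copy of the same tile;
# the hardcoded 2D indexing could not express this at all.
assert load_tile(A3d, (1,), 4, 8, 2, 2) == [
    [v + 1000 for v in row] for row in load_tile(A2d, (), 4, 8, 2, 2)]
```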

Why these were masked on main
----------------------------------------------------------------------

.github/workflows/ci.yml runs the ROCm test step with
--maxfail=3 --numprocesses=4. With xdist parallelism four workers can
accumulate failures concurrently; the controller eventually raises
xdist.dsession.Interrupted: stopping after 3 failures but pytest's exit
code on that interrupt path is non-deterministic. Previous runs got
exit 0 (success) while this PR happened to get exit 2 (failure). Fixing
the underlying bugs is the only way to make the job reliably green.
