
Faster MLA kernels#4391

Merged
lvhan028 merged 3 commits into InternLM:main from lzhangzz:faster-mla
Mar 5, 2026

Conversation

@lzhangzz
Collaborator

@lzhangzz lzhangzz commented Mar 3, 2026

  • Add MLA mainloop for sm_80
  • Allow CTA_H to be multiple of WARP_H
  • Adjust hyper-parameters to lower register-spilling

set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_compile_options(attention PRIVATE -O3
-  $<$<COMPILE_LANGUAGE:CUDA>:-use_fast_math --expt-relaxed-constexpr>)
+  $<$<COMPILE_LANGUAGE:CUDA>:-use_fast_math --expt-relaxed-constexpr -Xptxas=-v --threads 16>)
Collaborator


Cool! Now the compilation time of target kv_cache_utils_v2 dropped from 600+s to 110+s

Contributor

Copilot AI left a comment


Pull request overview

This PR speeds up Turbomind attention/decoding for MLA (HeadDim=576) by adding an Sm80 cp.async MLA mainloop, relaxing CTA/WARP head tiling constraints, and retuning kernel hyperparameters to reduce register pressure/spills.

Changes:

  • Add an Sm80 MLA mainloop path (selected via Impl::MLA) and introduce 576-specific decoding/attention configs.
  • Generalize warp mapping/tiling to allow CTA_H to be a multiple of WARP_H, updating thread maps and shared-memory indexing accordingly.
  • Retune KV-cache processing/flattening for HeadDim=576 (e.g., smaller CTA_S) and adjust build flags for CUDA compilation.

Reviewed changes

Copilot reviewed 19 out of 19 changed files in this pull request and generated 4 comments.

Summary per file:

  • src/turbomind/kernels/core/thread_map.h: Extends RakedThreadMap to support configurable warp partitioning and uses cdiv() for iteration counts.
  • src/turbomind/kernels/attention/mainloop_sm80.h: Adds MLA-specific Sm80 cp.async mainloop and routes selection via a new Impl::MLA boolean.
  • src/turbomind/kernels/attention/kv_cache_utils_v2.cu: Adjusts CTA shape for HeadDim=576 (share-KV) and updates checks/grid math.
  • src/turbomind/kernels/attention/impl_simt.h: Allows multi-warp head tiling (kWarpCntH>1) with updated warp-id mapping and sync behavior; adds MLA flag.
  • src/turbomind/kernels/attention/impl_884.h: Adds Impl::MLA = false for compatibility with Sm80 mainloop dispatch.
  • src/turbomind/kernels/attention/impl_81616.h: Updates warp-counting/shared-memory layout for multi-warp head tiling; adds MLA flag and warp-id helper.
  • src/turbomind/kernels/attention/impl_1688.h: Adds Impl::MLA flag for the Sm75 tensorcore implementation.
  • src/turbomind/kernels/attention/impl_16816.h: Adds Impl::MLA flag for the Sm80 tensorcore implementation.
  • src/turbomind/kernels/attention/decoding_config.h: Adds HeadDim=576 specializations and modifies Sm80 dispatch to avoid generic paths for 576.
  • src/turbomind/kernels/attention/codegen/*.cu: Removes some unused decoding instantiations for 576 variants (primarily Qh != 8 cases).
  • src/turbomind/kernels/attention/attention_config.h: Adds linear-cache attention configs for HeadDim=576 on Sm80 and Sm75.
  • src/turbomind/kernels/attention/CMakeLists.txt: Adds CUDA compile flags (-Xptxas=-v --threads 16) for the attention target.


minor fix

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@lvhan028 lvhan028 merged commit 035dd4e into InternLM:main Mar 5, 2026
1 of 9 checks passed