[WebGPU] Support warp-level shuffle primitives with subgroup #17699

Draft
CharlieFRuan wants to merge 1 commit into apache:main from CharlieFRuan:pr-0302-webgpu-shuffle

Conversation

CharlieFRuan (Member) commented Mar 3, 2025

Overview

This PR adds support for warp-level shuffle primitives using the newly introduced subgroups feature in WebGPU, and then uses them when lowering allreduce. A short sketch follows the list of primitives below.

The introduced primitives are:

  • subgroupShuffle()
  • subgroupShuffleUp()
  • subgroupShuffleDown()
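
As a concrete illustration, a minimal hand-written WGSL allreduce over one subgroup might look like the sketch below. This is not the TVM-generated code (the buffer name `data`, the 64-thread workgroup, and the kernel shape are made up for the example; the real dumped kernels are in the gist linked under the testing note):

```wgsl
enable subgroups;

@group(0) @binding(0) var<storage, read_write> data: array<f32>;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>,
        @builtin(subgroup_size) sg_size: u32) {
  var sum = data[gid.x];
  // Shuffle-down tree: after the loop, lane 0 of each subgroup holds
  // the sum over the whole subgroup.
  for (var offset = sg_size / 2u; offset > 0u; offset = offset / 2u) {
    sum = sum + subgroupShuffleDown(sum, offset);
  }
  // subgroupShuffle reads from an arbitrary lane; broadcasting lane 0's
  // value back to every lane completes the allreduce. subgroupShuffleUp
  // is the mirror of subgroupShuffleDown (it reads from lane - delta).
  sum = subgroupShuffle(sum, 0u);
  data[gid.x] = sum;
}
```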

This PR largely follows the Metal counterpart.

Tested end-to-end with WebLLM on Llama3.2-1B-q4f16_1 and Llama3.1-8B-q4f16_1. The dumped WebGPU kernels indeed contain subgroup shuffle primitives: https://gist.github.com/CharlieFRuan/cb54a8db0513ecbbc16c5de8df5ab845

Remaining TODOs

  • Benchmark speedup
  • Be able to parameterize whether to use subgroups when targeting WebGPU, since not all devices support them
  • Check GPUFeatureName's inclusion of subgroups in @webgpu/types
  • Some WebGPU devices support more than 256 max threads per block; be able to target these different limits (see the sketch after this list)
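
On the last point, one possible mechanism (an assumption on my part, not necessarily what this PR will end up doing) is a pipeline-overridable workgroup size that the host picks per device, e.g. from the adapter's reported maxComputeInvocationsPerWorkgroup limit:

```wgsl
// Hypothetical sketch: BLOCK_SIZE is a pipeline-overridable constant;
// the host sets it via the `constants` field of createComputePipeline()
// after inspecting the device's limits.
override BLOCK_SIZE: u32 = 256u;

@group(0) @binding(0) var<storage, read_write> data: array<f32>;

@compute @workgroup_size(BLOCK_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  data[gid.x] = data[gid.x] * 2.0;  // placeholder body
}
```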


Comment thread on web/src/webgpu.ts
}

const requiredFeatures: GPUFeatureName[] = [];
// TODO(Charlie): cannot type-annotate "subgroups" here because
// @webgpu/types does not yet include it in GPUFeatureName.

A contributor replied:

@webgpu/types 0.1.55 should work now. See gpuweb/types#167

CharlieFRuan (Member, Author) replied:

Great, thanks!

MasterJH5574 pushed a commit that referenced this pull request Apr 6, 2026
## Summary
This adds gating logic on top of #17699 to support optional subgroup shuffle primitives based on a compile-time flag.

## Problem
The PR #17699 always generates subgroup shuffle ops when targeting WebGPU. However, not all WebGPU devices support subgroups. We need a way to:
- Default to shared memory reductions (universally compatible)
- Optionally enable subgroup shuffles for devices that support them

## Solution
Implement gating via a TVM target parameter:
- Default `thread_warp_size=1` disables warp reductions (uses shared memory + barriers)
- Add target parser `UpdateWebGPUAttrs()` that sets `thread_warp_size=32` when `supports_subgroups=true`
- Add an `--enable-subgroups` CLI flag in mlc-llm to surface the option to users

The gating happens at the reduction-path selection level (`IsWarpReduction()` in `lower_thread_allreduce.cc`), ensuring subgroup ops are never generated unless explicitly enabled.

## Testing

Tested with Llama-3.2-1B-q4f16_1. The baseline (no flag) uses shared-memory reductions; with the flag, it generates subgroupShuffle* ops.
Both generated WGSLs are here:
https://gist.github.com/ksgr5566/301664a5dda3e46f44092be4d09b2d4f
Benchmarking:
https://gist.github.com/ksgr5566/c9bd5bc5aadba999ec2f2c38eb0c49b3
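
For contrast with the subgroup path, the shared-memory reduction that the default (`thread_warp_size=1`) path falls back to looks roughly like the following. This is an illustrative WGSL sketch with a made-up buffer and a fixed 64-thread workgroup, not the generated kernel (see the gist above for the real output):

```wgsl
@group(0) @binding(0) var<storage, read_write> data: array<f32>;

var<workgroup> scratch: array<f32, 64>;

@compute @workgroup_size(64)
fn reduce_shared(@builtin(local_invocation_index) tid: u32,
                 @builtin(global_invocation_id) gid: vec3<u32>) {
  scratch[tid] = data[gid.x];
  workgroupBarrier();
  // Tree reduction in shared memory; every step needs a barrier,
  // which is what the subgroup shuffle path avoids.
  for (var stride = 32u; stride > 0u; stride = stride / 2u) {
    if (tid < stride) {
      scratch[tid] = scratch[tid] + scratch[tid + stride];
    }
    workgroupBarrier();
  }
  if (tid == 0u) {
    data[gid.x] = scratch[0u];
  }
}
```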
Aharrypotter pushed a commit to Aharrypotter/tvm that referenced this pull request Apr 10, 2026 (same change as above).