Conversation
Pull request overview
This PR updates the CUDA backend’s NoAux-TC MoE router to use a centralized predicate for deciding when to dispatch to the fused Triton routing kernel, rather than checking the condition inline in forward().
Changes:
- Added an `is_power_of_two()` helper and new power-of-two constraints for enabling the fused kernel.
- Introduced `should_enable_custom_kernel()` and cached the result in `self.enable_custom_kernel`.
- Simplified `forward()` to branch solely on the cached enablement flag.
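The PR's exact implementation of the helper isn't reproduced in this review, but a minimal sketch of such an `is_power_of_two()` check, using the standard bit trick, would be:

```python
def is_power_of_two(n: int) -> bool:
    # A positive integer is a power of two iff exactly one bit is set,
    # in which case n & (n - 1) clears that bit and yields zero.
    return n > 0 and (n & (n - 1)) == 0
```

This runs in constant time and correctly rejects zero and negative values.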
```python
if not is_power_of_two(self.n_group):
    return False
```
should_enable_custom_kernel() can return True even when n_routed_experts is not divisible by n_group (e.g., 32 experts / 64 groups if both are powers of two). In that case fused_noaux_tc_routing() will hit its assert num_experts % n_group == 0 at runtime. Add an explicit divisibility check (and/or n_group <= n_routed_experts) to the enablement predicate so the router safely falls back to the default implementation instead of crashing.
```python
# fused_noaux_tc_routing requires num_experts % n_group == 0
if self.n_routed_experts < self.n_group:
    return False
if self.n_routed_experts % self.n_group != 0:
    return False
```
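To make the failure mode concrete (values chosen purely for illustration): 32 experts and 64 groups both pass the power-of-two checks, yet violate the kernel's divisibility precondition, so the proposed guard is what would trigger the fallback here.

```python
def is_power_of_two(n: int) -> bool:
    # Same bit trick as the helper added in this PR.
    return n > 0 and (n & (n - 1)) == 0

n_routed_experts, n_group = 32, 64  # both powers of two...
assert is_power_of_two(n_routed_experts) and is_power_of_two(n_group)
# ...but fused_noaux_tc_routing would still hit its runtime assert,
# since 32 % 64 == 32 != 0.
divisible = n_routed_experts % n_group == 0
print(divisible)  # False: the suggested check would fall back instead
```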
```python
if not is_power_of_two(self.n_group):
    return False
```
The custom kernel path does not appear to validate topk_group <= n_group. The default router uses torch.topk(group_scores, k=self.topk_group, ...) which will raise if topk_group exceeds n_group, while the Triton kernel would silently select duplicate groups (behavior divergence). Consider adding a check in should_enable_custom_kernel() to require self.topk_group <= self.n_group so misconfigurations fail consistently (or fall back).
```python
if self.topk_group > self.n_group:
    return False
```
```python
self.enable_custom_kernel = self.should_enable_custom_kernel()
```
```python
def should_enable_custom_kernel(self) -> bool:
    if self.router_n_groups > 0:
        return False
    if self.scoring_func != 'sigmoid':
        return False
    if self.n_routed_experts % 32 != 0:
        return False
    if not is_power_of_two(self.n_routed_experts):
        return False
    if not is_power_of_two(self.n_group):
        return False
    return True
```
```diff
 def forward(self, logits: torch.Tensor, bias: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """Router forward."""
-    if self.router_n_groups <= 0 and self.scoring_func == 'sigmoid' and self.n_routed_experts % 32 == 0:
+    if self.enable_custom_kernel:
         return fused_noaux_tc_routing(
```
The new kernel enablement logic (power-of-two requirements + cached enable_custom_kernel) isn’t covered by tests. Since there are existing CUDA kernel tests (e.g., tests/pytorch/kernel/test_moe_route.py), it would be good to add a small unit test that exercises TritonRouterNoauxTCImpl.forward() for both an enabled configuration and a disabled configuration (e.g., non-power-of-two n_group or router_n_groups>0) to ensure the correct path/fallback is taken.
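The repository class isn't reproduced in this review, so as a sketch under that limitation, the enablement predicate can be exercised in isolation via a stand-in that mirrors the checks from the diff plus the two guards suggested above (all names below are hypothetical):

```python
from dataclasses import dataclass

def is_power_of_two(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

@dataclass
class RouterConfig:
    router_n_groups: int
    scoring_func: str
    n_routed_experts: int
    n_group: int
    topk_group: int

def should_enable_custom_kernel(cfg: RouterConfig) -> bool:
    # Mirrors the predicate in the diff, plus the divisibility and
    # topk_group guards proposed in the review comments.
    if cfg.router_n_groups > 0:
        return False
    if cfg.scoring_func != 'sigmoid':
        return False
    if cfg.n_routed_experts % 32 != 0:
        return False
    if not is_power_of_two(cfg.n_routed_experts):
        return False
    if not is_power_of_two(cfg.n_group):
        return False
    if cfg.n_routed_experts % cfg.n_group != 0:
        return False
    if cfg.topk_group > cfg.n_group:
        return False
    return True

# Enabled configuration.
assert should_enable_custom_kernel(RouterConfig(0, 'sigmoid', 64, 8, 4))
# Disabled: non-power-of-two n_group.
assert not should_enable_custom_kernel(RouterConfig(0, 'sigmoid', 64, 6, 4))
# Disabled: router_n_groups > 0.
assert not should_enable_custom_kernel(RouterConfig(1, 'sigmoid', 64, 8, 4))
```

A real test in `tests/pytorch/kernel/test_moe_route.py` would instead construct `TritonRouterNoauxTCImpl` with these configurations and check which `forward()` path is taken.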
Thanks for your contribution; we appreciate it a lot. The following instructions will help keep your pull request healthy and make it easier to receive feedback. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.
Motivation
Please describe the motivation for this PR and the goal you want to achieve with it.
Modification
Please briefly describe the modifications made in this PR.
BC-breaking (Optional)
Does the modification introduce changes that break the backward compatibility of downstream repositories?
If so, please describe how it breaks compatibility and how downstream projects should modify their code to remain compatible with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is helpful to list some use cases here and update the documentation.
Checklist