Conversation
Pull request overview
This pull request adds router replay functionality to the gpt_oss model and aligns dtype handling in MoE kernels with xtuner for rollout compatibility.
Changes:
- Added router replay support for gpt_oss model (tracking expert routing decisions across layers)
- Modified dtype handling in qwen3_moe gate to use torch.float32 for routing computations
- Aligned kernel dtype behavior in `_silu_and_mul_kernel` and `_fwd_kernel_ep_gather` with xtuner
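As a rough sketch of what tracking routing decisions for replay can look like (the names here, e.g. `RouterReplayRecorder`, are hypothetical and not the PR's actual code, which propagates an `all_routed_experts` tensor through the layers):

```python
import torch

class RouterReplayRecorder:
    """Hypothetical helper: collects per-layer expert choices for replay."""

    def __init__(self):
        self.per_layer_indices: list[torch.Tensor] = []

    def record(self, router_indices: torch.Tensor) -> None:
        # router_indices: (seq_len, top_k) expert ids chosen for each token.
        self.per_layer_indices.append(router_indices.detach())

    def all_routed_experts(self) -> torch.Tensor:
        # Stack into (num_layers, seq_len, top_k) so the rollout side can
        # replay the exact routing decisions of this forward pass.
        return torch.stack(self.per_layer_indices, dim=0)
```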
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| lmdeploy/pytorch/models/qwen3_moe.py | Changed gate dtype to torch.float32 and added dtype casting for router inputs to align with xtuner |
| lmdeploy/pytorch/models/gpt_oss.py | Added router replay functionality including all_routed_experts tensor propagation through model layers, context management, and return logic |
| lmdeploy/pytorch/kernels/cuda/fused_moe_ep.py | Modified accumulation dtype to use output tensor dtype instead of always float32, aligning with xtuner for rollout |
| lmdeploy/pytorch/kernels/cuda/activation.py | Optimized dtype conversions by only converting gate to fp32 for exponential operations, keeping up in original dtype |
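For the activation.py row, a minimal PyTorch reference of the intended numerics (a sketch of the dtype policy only, not the Triton kernel itself):

```python
import torch

def silu_and_mul_ref(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Upcast only `gate` to fp32 for the numerically sensitive sigmoid /
    # exponential; `up` stays in its original dtype (e.g. bf16) and the
    # activation is cast back before the elementwise product.
    gate_fp32 = gate.to(torch.float32)
    silu = gate_fp32 * torch.sigmoid(gate_fp32)
    return silu.to(up.dtype) * up
```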
```diff
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, all_routed_experts: torch.Tensor = None):
         router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)
```
The comment incorrectly states the shape as "(num_experts, seq_len)". Based on the router's forward method (line 244), router_indices has shape (seq_len, top_k), where top_k is the number of experts per token (config.num_experts_per_tok).
```diff
-        router_scores, router_indices = self.router(hidden_states)  # (num_experts, seq_len)
+        router_scores, router_indices = self.router(hidden_states)  # (seq_len, top_k)
```
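A quick standalone check of the shapes (toy sizes, assuming the router selects experts via `torch.topk` over the expert dimension):

```python
import torch

seq_len, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(seq_len, num_experts)
# topk over the expert dimension picks `top_k` experts per token, so the
# indices come out as (seq_len, top_k), not (num_experts, seq_len).
router_scores, router_indices = torch.topk(router_logits, top_k, dim=-1)
assert router_indices.shape == (seq_len, top_k)
```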
lmdeploy/pytorch/models/qwen3_moe.py (Outdated)

```diff
             self.num_experts,
             bias=False,
-            dtype=dtype,
+            dtype=torch.float32,
```
I am still concerned about the hard-coded `torch.float32`. After all, the dtype of `gate.weight` in the checkpoint is BF16:
https://huggingface.co/Qwen/Qwen3-30B-A3B-Thinking-2507/blob/main/model-00001-of-00016.safetensors
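One possible middle ground (a sketch only, not what this PR does) is to keep the weight in its checkpoint dtype and upcast at compute time:

```python
import torch
import torch.nn.functional as F

def route_logits_fp32(hidden_states: torch.Tensor,
                      gate_weight: torch.Tensor) -> torch.Tensor:
    # Keep `gate_weight` in its checkpoint dtype (BF16) and upcast both
    # operands only for the routing matmul, so fp32 routing numerics are
    # obtained without hard-coding the module's dtype to torch.float32.
    return F.linear(hidden_states.float(), gate_weight.float())
```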
I can accept this PR without changing qwen3_moe.py
bf16 ep use
Motivation
Please describe the motivation of this PR and the goal you want to achieve through this PR.
Modification
- Align `_silu_and_mul_kernel` and `_fwd_kernel_ep_gather` with xtuner for rollout

BC-breaking (Optional)
Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.
Checklist