Skip to content

Router replay for gpt oss#4298

Merged
lvhan028 merged 3 commits intoInternLM:mainfrom
RunningLeon:align-rl
Jan 28, 2026
Merged

Router replay for gpt oss#4298
lvhan028 merged 3 commits intoInternLM:mainfrom
RunningLeon:align-rl

Conversation

@RunningLeon
Copy link
Copy Markdown
Collaborator

Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

Modification

  1. add router replay for gpt oss
  2. align dtype in _silu_and_mul_kernel and _fwd_kernel_ep_gather with xtuner for rollout

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure the correctness.
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

Copilot AI review requested due to automatic review settings January 27, 2026 09:13
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request adds router replay functionality to the gpt_oss model and aligns dtype handling in MoE kernels with xtuner for rollout compatibility.

Changes:

  • Added router replay support for gpt_oss model (tracking expert routing decisions across layers)
  • Modified dtype handling in qwen3_moe gate to use torch.float32 for routing computations
  • Aligned kernel dtype behavior in _silu_and_mul_kernel and _fwd_kernel_ep_gather with xtuner

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File Description
lmdeploy/pytorch/models/qwen3_moe.py Changed gate dtype to torch.float32 and added dtype casting for router inputs to align with xtuner
lmdeploy/pytorch/models/gpt_oss.py Added router replay functionality including all_routed_experts tensor propagation through model layers, context management, and return logic
lmdeploy/pytorch/kernels/cuda/fused_moe_ep.py Modified accumulation dtype to use output tensor dtype instead of always float32, aligning with xtuner for rollout
lmdeploy/pytorch/kernels/cuda/activation.py Optimized dtype conversions by only converting gate to fp32 for exponential operations, keeping up in original dtype

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


def forward(self, hidden_states):
def forward(self, hidden_states, all_routed_experts: torch.Tensor = None):
router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
Copy link

Copilot AI Jan 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment incorrectly states the shape as "(num_experts, seq_len)". Based on the router's forward method (line 244), router_indices has shape (seq_len, top_k), where top_k is the number of experts per token (config.num_experts_per_tok).

Suggested change
router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len)
router_scores, router_indices = self.router(hidden_states) # (seq_len, top_k)

Copilot uses AI. Check for mistakes.
@lvhan028 lvhan028 requested a review from grimoire January 28, 2026 02:43
@lvhan028 lvhan028 added the enhancement New feature or request label Jan 28, 2026
self.num_experts,
bias=False,
dtype=dtype,
dtype=torch.float32,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am still concerned about the hard-coded torch.float32. After all the data type of gate.weight is BF16
https://huggingface.co/Qwen/Qwen3-30B-A3B-Thinking-2507/blob/main/model-00001-of-00016.safetensors

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can accept this PR without changing qwen3_moe.py

@grimoire
Copy link
Copy Markdown
Collaborator

bf16 ep use _silu_and_mul_moe_ep_kernel

@lvhan028 lvhan028 merged commit 2184ce7 into InternLM:main Jan 28, 2026
4 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants