[Model] Add Qwen3ForRewardModel and fix Qwen3ForSequenceClassification#17992
Conversation
Summary of Changes

Hello @shvmjndl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the model ecosystem by introducing support for Qwen3-based reward models, which are used in AI training techniques such as RLHF and best-of-N sampling. Concurrently, it rectifies a bug in the existing `Qwen3ForSequenceClassification` model that prevented its weights from loading.
Force-pushed from 70b8a52 to 88bd8a7
Local Test Results

Tested with `Skywork/Skywork-Reward-V2-Qwen3-0.6B`. Both scores within tolerance (0.15). Test PASSED.
Force-pushed from 88bd8a7 to 311488f
Code Review
This pull request introduces `Qwen3ForRewardModel` and fixes a bug in `Qwen3ForSequenceClassification` where loading weights failed with an `AttributeError`. The changes look good and correctly address the issue. However, I've identified significant code duplication between the new `Qwen3ForRewardModel` and the modified `Qwen3ForSequenceClassification`. Refactoring this into a common base class would greatly improve maintainability.
`python/sglang/srt/models/qwen3_rm.py` (Outdated)
```python
class Qwen3ForRewardModel(nn.Module):
    def __init__(
        self,
        config: Qwen3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.num_labels = 1
        self.model = Qwen3Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
        self.eos_token_id = config.eos_token_id

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        assert get_embedding, "Qwen3ForRewardModel is only used for embedding"

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        logits = self.score(hidden_states)
        pooled_logits = self.pooler(logits, forward_batch).embeddings

        return EmbeddingPoolerOutput(pooled_logits)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Skip lm_head weights (reward model doesn't have lm_head)
            if name.startswith("lm_head"):
                continue

            # Skip rotary embeddings and other non-parameter tensors
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue

            # Handle stacked parameters (qkv_proj, gate_up_proj)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")
```
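The stacked-parameter remapping in `load_weights` can be illustrated with a self-contained sketch. The checkpoint tensor names below are illustrative stand-ins, not keys from a real Qwen3 checkpoint:

```python
# Minimal sketch of the qkv_proj / gate_up_proj name remapping used above.
# Checkpoints store q/k/v and gate/up as separate tensors, while the model
# holds them as fused parameters loaded shard-by-shard.
stacked_params_mapping = [
    # (fused param in the model, per-tensor name in the checkpoint, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name):
    """Return (model_param_name, shard_id) for a checkpoint tensor name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # unfused tensors load as-is

print(remap("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("model.layers.0.mlp.up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)
```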
The `load_weights` method is identical to the one in `Qwen3ForSequenceClassification`. Additionally, the `forward` method is also duplicated. To improve maintainability and avoid code duplication, consider creating a common base class for both `Qwen3ForRewardModel` and `Qwen3ForSequenceClassification`. This base class could contain the shared `forward` and `load_weights` methods, while the subclasses would only need to define their specific score head and pooler in their `__init__` methods.
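The shape of the suggested refactor can be sketched in miniature. The class names and toy score heads below are hypothetical, not from the sglang codebase; the real version would keep the torch modules and the full `load_weights` logic in the base class:

```python
# Hypothetical sketch of the base-class refactor: shared forward/pooling and
# weight-loading live in one place; subclasses only supply their score head.
class Qwen3PooledModelBase:
    def __init__(self, head):
        self.head = head  # subclass-specific scoring head

    def forward(self, hidden_states):
        # Shared path: score every position, then keep the last one,
        # mirroring PoolingType.LAST in the real model.
        scores = [self.head(h) for h in hidden_states]
        return scores[-1]

    def load_weights(self, weights):
        # The shared weight-loading logic would live here in the real refactor.
        self.loaded = dict(weights)


class RewardModel(Qwen3PooledModelBase):
    def __init__(self):
        # Toy stand-in for the Linear -> ReLU -> Linear reward head.
        super().__init__(head=lambda h: max(0.0, 2.0 * h) - 1.0)


class SequenceClassifier(Qwen3PooledModelBase):
    def __init__(self):
        # A different toy head; forward and load_weights are inherited.
        super().__init__(head=lambda h: 1.0 if h > 0 else 0.0)


rm = RewardModel()
print(rm.forward([0.1, 0.5, 2.0]))  # 3.0
```

With this structure, adding another pooled-output Qwen3 variant means writing only an `__init__` that builds its head.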
This PR adds support for Qwen3 reward models and fixes a bug in the existing `Qwen3ForSequenceClassification` class.

Changes:
- Add new `Qwen3ForRewardModel` class for RLHF and best-of-N sampling
- Fix `AttributeError` in `Qwen3ForSequenceClassification.load_weights()` (was failing with `'object has no attribute pp_group'`)
- Add `Qwen3ForRewardModel` to non-generation model list in `model_config.py`
- Add `Skywork-Reward-V2-Qwen3-0.6B` to reward model tests

The bug fix enables SGLang to serve Qwen3-based reward models such as the Skywork-Reward-V2-Qwen3 series, which were previously failing to load. Tested with `Skywork/Skywork-Reward-V2-Qwen3-0.6B`: scores match HuggingFace outputs within tolerance.
Force-pushed from 311488f to c194759
Thanks for the feedback! I've refactored the code to eliminate the duplication. The base class can also be reused if more Qwen3 pooled-output model variants are needed in the future.
/rerun-failed-ci

*1 similar comment*

/rerun-failed-ci

After passing CIs, I will merge it @shvmjndl

@zhaochenyang20 can you please check this

/rerun-failed-ci

*2 similar comments*

/rerun-failed-ci

/rerun-failed-ci

/rerun-failed-ci

/rerun-failed-ci
Summary

This PR adds support for Qwen3 reward models and fixes a bug in the existing `Qwen3ForSequenceClassification` class.

Changes
- Add new `Qwen3ForRewardModel` class for RLHF and best-of-N sampling
- Fix `AttributeError` in `Qwen3ForSequenceClassification.load_weights()` (was failing with `'Qwen3ForSequenceClassification' object has no attribute 'pp_group'`)
- Add `Qwen3ForRewardModel` to non-generation model list in `model_config.py`
- Add `Skywork-Reward-V2-Qwen3-0.6B` to reward model tests

Bug Details
`Qwen3ForSequenceClassification.load_weights()` was delegating to `Qwen3ForCausalLM.load_weights()`, which accesses `self.pp_group`. However, classification/reward models don't have `pp_group`, causing `AttributeError: 'Qwen3ForSequenceClassification' object has no attribute 'pp_group'`. This prevented loading Qwen3-based reward models like the Skywork-Reward-V2-Qwen3 series.
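The failure mode can be reproduced in miniature: a loader inherited from the generation model touches an attribute the pooled-output subclass never initializes. This is a stripped-down illustration with made-up class names, not sglang code:

```python
# Stripped-down illustration of the pp_group bug: the causal-LM loader
# assumes an attribute that classification/reward models never set.
class CausalLM:
    def __init__(self):
        self.pp_group = "pipeline-parallel group"  # set only for generation models

    def load_weights(self, weights):
        _ = self.pp_group  # generation-specific state the loader depends on
        self.weights = dict(weights)


class SequenceClassifier(CausalLM):
    def __init__(self):
        # Does NOT set pp_group -- analogous to
        # Qwen3ForSequenceClassification before the fix.
        pass

    # The fix: implement load_weights directly instead of delegating to
    # the causal-LM version that assumes pp_group exists.
    def load_weights(self, weights):
        self.weights = dict(weights)


clf = SequenceClassifier()
try:
    CausalLM.load_weights(clf, [("score.weight", 1.0)])  # old delegation path
except AttributeError as e:
    print(e)  # AttributeError complaining about the missing pp_group attribute

clf.load_weights([("score.weight", 1.0)])  # fixed path works
```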
Testing

Tested with `Skywork/Skywork-Reward-V2-Qwen3-0.6B`; scores match HuggingFace outputs within tolerance.
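The parity check behind this test can be sketched as follows. The score values are made up for illustration; only the 0.15 tolerance comes from the test results quoted earlier in the thread:

```python
# Sketch of a score-parity check between two serving backends.
# The score values are illustrative; 0.15 matches the quoted tolerance.
TOLERANCE = 0.15

def within_tolerance(sglang_scores, hf_scores, tol=TOLERANCE):
    """True if every pair of per-prompt scores differs by at most tol."""
    return all(abs(a - b) <= tol for a, b in zip(sglang_scores, hf_scores))

# Illustrative reward scores for a chosen vs. rejected response pair.
sglang_scores = [2.31, -1.07]
hf_scores = [2.28, -0.98]

print(within_tolerance(sglang_scores, hf_scores))  # True
```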