
[Model] Add Qwen3ForRewardModel and fix Qwen3ForSequenceClassification#17992

Merged: Kangyan-Zhou merged 4 commits into sgl-project:main from shvmjndl:feature/qwen3-reward-model on Feb 16, 2026

Conversation

@shvmjndl (Contributor) commented Jan 30, 2026

Summary

This PR adds support for Qwen3 reward models and fixes a bug in the existing Qwen3ForSequenceClassification class.

Changes

  • Add new Qwen3ForRewardModel class for RLHF and best-of-N sampling
  • Fix AttributeError in Qwen3ForSequenceClassification.load_weights() (was failing with 'Qwen3ForSequenceClassification' object has no attribute 'pp_group')
  • Add Qwen3ForRewardModel to non-generation model list in model_config.py
  • Add Skywork-Reward-V2-Qwen3-0.6B to reward model tests

Bug Details

Qwen3ForSequenceClassification.load_weights() delegated to Qwen3ForCausalLM.load_weights(), which accesses self.pp_group. Classification and reward models never set pp_group, so loading failed with:

AttributeError: 'Qwen3ForSequenceClassification' object has no attribute 'pp_group'

This prevented loading Qwen3-based reward models like the Skywork-Reward-V2-Qwen3 series.

Testing

  • Tested with Skywork/Skywork-Reward-V2-Qwen3-0.6B
  • HuggingFace vs SGLang scores match within tolerance
  • Official test data from Skywork model card
Response              HuggingFace   SGLang   Status
Correct (9÷3=3)       7.8750        7.9375   ✅ PASS
Incorrect (9÷2=4.5)   0.9453        1.0469   ✅ PASS

@gemini-code-assist (Contributor)

Summary of Changes

Hello @shvmjndl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request adds support for Qwen3-based reward models, which are used in training techniques such as RLHF and best-of-N sampling. It also fixes a bug in the existing Qwen3ForSequenceClassification class that prevented these models from loading their weights.

Highlights

  • New Model Support: Added the Qwen3ForRewardModel class, enabling support for Qwen3-based reward models for applications like Reinforcement Learning from Human Feedback (RLHF) and best-of-N sampling.
  • Bug Fix: Resolved an AttributeError in Qwen3ForSequenceClassification.load_weights() by implementing a custom weight loading mechanism, which correctly handles stacked parameters and filters out irrelevant weights, preventing crashes during model loading.
  • Configuration Update: Included Qwen3ForRewardModel in the list of non-generation models within model_config.py to ensure proper classification and handling.
  • Expanded Testing: Added Skywork-Reward-V2-Qwen3-0.6B to the reward model test suite, validating the new Qwen3 reward model implementation against a known model.


@shvmjndl force-pushed the feature/qwen3-reward-model branch from 70b8a52 to 88bd8a7 on January 30, 2026 at 13:38
@shvmjndl (Contributor, Author)

Local Test Results

Tested with Skywork/Skywork-Reward-V2-Qwen3-0.6B:

$ python test/registered/models/test_reward_models.py

hf_scores=tensor([1.2598, 1.8564])
srt_scores=tensor([1.2471, 1.8281])

----------------------------------------------------------------------
Ran 1 test in 76.390s

OK
Metric       HuggingFace   SGLang   Diff
Response 1   1.2598        1.2471   0.0127
Response 2   1.8564        1.8281   0.0283

Both within tolerance (0.15). Test PASSED.

@shvmjndl force-pushed the feature/qwen3-reward-model branch from 88bd8a7 to 311488f on January 30, 2026 at 13:41
@gemini-code-assist bot (Contributor) left a comment
Code Review

This pull request introduces Qwen3ForRewardModel and fixes a bug in Qwen3ForSequenceClassification where loading weights failed due to an AttributeError. The changes look good and correctly address the issue. However, I've identified significant code duplication between the new Qwen3ForRewardModel and the modified Qwen3ForSequenceClassification. Refactoring this into a common base class would greatly improve maintainability.

Comment on lines +35 to +123
class Qwen3ForRewardModel(nn.Module):
    def __init__(
        self,
        config: Qwen3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.num_labels = 1
        self.model = Qwen3Model(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

        self.eos_token_id = config.eos_token_id

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        assert get_embedding, "Qwen3ForRewardModel is only used for embedding"

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        logits = self.score(hidden_states)
        pooled_logits = self.pooler(logits, forward_batch).embeddings

        return EmbeddingPoolerOutput(pooled_logits)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("gate_up_proj", "gate_proj", 0),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # Skip lm_head weights (reward model doesn't have lm_head)
            if name.startswith("lm_head"):
                continue

            # Skip rotary embeddings and other non-parameter tensors
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue

            # Handle stacked parameters (qkv_proj, gate_up_proj)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")
Severity: medium

The load_weights method is identical to the one in Qwen3ForSequenceClassification. Additionally, the forward method is also duplicated. To improve maintainability and avoid code duplication, consider creating a common base class for both Qwen3ForRewardModel and Qwen3ForSequenceClassification. This base class could contain the shared forward and load_weights methods, while the subclasses would only need to define their specific score head and pooler in their __init__ methods.

Commit message:

This PR adds support for Qwen3 reward models and fixes a bug in the
existing Qwen3ForSequenceClassification class.

Changes:
- Add new Qwen3ForRewardModel class for RLHF and best-of-N sampling
- Fix AttributeError in Qwen3ForSequenceClassification.load_weights()
  (was failing with 'object has no attribute pp_group')
- Add Qwen3ForRewardModel to non-generation model list in model_config.py
- Add Skywork-Reward-V2-Qwen3-0.6B to reward model tests

The bug fix enables SGLang to serve Qwen3-based reward models such as
the Skywork-Reward-V2-Qwen3 series which were previously failing to load.

Tested with Skywork/Skywork-Reward-V2-Qwen3-0.6B - scores match
HuggingFace outputs within tolerance.
@shvmjndl force-pushed the feature/qwen3-reward-model branch from 311488f to c194759 on January 30, 2026 at 13:55
@shvmjndl (Contributor, Author)

Thanks for the feedback! I've refactored the code to eliminate duplication.

Changes made:

  • Created Qwen3ForPooledOutput base class in qwen3_classification.py containing:
    • Common __init__ (config, model, eos_token_id)
    • Shared forward() method
    • Shared load_weights() method
  • Qwen3ForSequenceClassification now inherits from Qwen3ForPooledOutput and only sets up its specific score (Linear) and pooler (with conditional normalization)
  • Qwen3ForRewardModel inherits from Qwen3ForPooledOutput and only sets up its specific score (2-layer MLP) and pooler

Result:

  • Net reduction of ~70 lines
  • qwen3_rm.py is now 47 lines (down from 128)
  • All tests pass

The base class can also be reused if more Qwen3 pooled output model variants are needed in the future.

@Kangyan-Zhou Kangyan-Zhou self-assigned this Feb 3, 2026
@zhaochenyang20 (Collaborator)

/rerun-failed-ci

1 similar comment
@zhaochenyang20 (Collaborator)

/rerun-failed-ci

@zhaochenyang20 (Collaborator)

After passing CIs, I will merge it @shvmjndl

@shvmjndl (Contributor, Author) commented Feb 9, 2026

@zhaochenyang20 can you please check this

@zhaochenyang20 (Collaborator)

/rerun-failed-ci

2 similar comments
@zhaochenyang20 (Collaborator)

/rerun-failed-ci

@zhaochenyang20 (Collaborator)

/rerun-failed-ci

@zhaochenyang20 (Collaborator)

/rerun-failed-ci

@JustinTong0323 (Collaborator)

/rerun-failed-ci

@Kangyan-Zhou Kangyan-Zhou merged commit 4f0409f into sgl-project:main Feb 16, 2026
218 of 232 checks passed