Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
or "InternLM2ForRewardModel" in model_architectures
or "Qwen2ForRewardModel" in model_architectures
or "Qwen3ForRewardModel" in model_architectures
or "Qwen2ForSequenceClassification" in model_architectures
or "Qwen3ForSequenceClassification" in model_architectures
or "CLIPModel" in model_architectures
Expand Down
102 changes: 81 additions & 21 deletions python/sglang/srt/models/qwen3_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================

import logging
from typing import Iterable, Optional, Tuple

import torch
Expand All @@ -21,11 +22,19 @@
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class Qwen3ForPooledOutput(nn.Module):
"""Base class for Qwen3 models that produce pooled output (classification, reward).

Subclasses should set self.score and self.pooler in their __init__.
"""

class Qwen3ForSequenceClassification(nn.Module):
def __init__(
self,
config: Qwen2Config,
Expand All @@ -38,19 +47,8 @@ def __init__(
self.model = Qwen3Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Use normalize=True for qwen3 embedding based on official implementation
# Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
# Official code: output = F.normalize(output, p=2, dim=1)
normalize = True

# We don't want to normalize the embedding if we have a classification head
if config.id2label is not None or config.label2id is not None:
normalize = False

self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=normalize)

self.eos_token_id = config.eos_token_id
# Subclasses must set self.score and self.pooler

@torch.no_grad()
def forward(
Expand All @@ -61,9 +59,7 @@ def forward(
input_embeds: Optional[torch.Tensor] = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "Qwen3ForSequenceClassification is only used for embedding"
assert get_embedding, f"{self.__class__.__name__} is only used for embedding"

hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
logits = self.score(hidden_states)
Expand All @@ -72,11 +68,75 @@ def forward(
return EmbeddingPoolerOutput(pooled_logits)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Filter out lm_head weights of Qwen3ForCausalLM
filtered_weights = [
(name, w) for name, w in weights if not name.startswith("lm_head")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
return Qwen3ForCausalLM.load_weights(self, filtered_weights)

params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# Skip lm_head weights (pooled output models don't have lm_head)
if name.startswith("lm_head"):
continue

# Skip rotary embeddings and other non-parameter tensors
if "rotary_emb.inv_freq" in name or "projector" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
continue

# Handle stacked parameters (qkv_proj, gate_up_proj)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models
if name.endswith(".bias") and name not in params_dict:
continue

if name in params_dict:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")


class Qwen3ForSequenceClassification(Qwen3ForPooledOutput):
    """Qwen3 with a single linear head, used for embedding or classification.

    Pooled output is L2-normalized only in the embedding case (no label
    mapping in the config), matching the official Qwen3-Embedding code.
    """

    def __init__(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config, prefix)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Official Qwen3-Embedding normalizes: output = F.normalize(output, p=2, dim=1)
        # Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
        # A classification head (id2label/label2id present in the config)
        # must receive raw logits, so normalization is disabled then.
        has_label_mapping = (
            config.id2label is not None or config.label2id is not None
        )
        self.pooler = Pooler(
            pooling_type=PoolingType.LAST, normalize=not has_label_mapping
        )


EntryClass = [
Expand Down
47 changes: 47 additions & 0 deletions python/sglang/srt/models/qwen3_rm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Qwen3 Reward Model for RLHF and best-of-N sampling."""

from typing import Optional

from torch import nn
from transformers import Qwen2Config # Qwen3 uses Qwen2Config

from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.qwen3_classification import Qwen3ForPooledOutput


class Qwen3ForRewardModel(Qwen3ForPooledOutput):
    """Qwen3 Reward Model with 2-layer MLP scoring head for RLHF."""

    def __init__(
        self,
        config: Qwen2Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config, prefix)
        # A reward model produces a single scalar score per sequence.
        self.num_labels = 1
        hidden = config.hidden_size
        # Two-layer MLP head: hidden -> hidden -> scalar reward.
        head_layers = [
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.num_labels),
        ]
        self.score = nn.Sequential(*head_layers)
        # Reward scores are consumed raw; never L2-normalize them.
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)


# Model classes exported to sglang's model registry.
EntryClass = [Qwen3ForRewardModel]
2 changes: 2 additions & 0 deletions test/registered/models/test_reward_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
MODELS = [
("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2),
# Qwen3-based reward model (uses Qwen3ForSequenceClassification)
("Skywork/Skywork-Reward-V2-Qwen3-0.6B", 1, 1.5e-1),
]
TORCH_DTYPES = [torch.float16]

Expand Down
Loading