Skip to content

Commit 4f7422f

Browse files
[NPU] support model skywork-reward-gemma2-2-27B-v0.2 (sgl-project#16947)
Co-authored-by: cy <chenyang08056032@163.com>
1 parent 72c1526 commit 4f7422f

File tree

7 files changed

+135
-10
lines changed

7 files changed

+135
-10
lines changed

python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
11631163
or "BertForSequenceClassification" in model_architectures
11641164
or "XLMRobertaModel" in model_architectures
11651165
or "XLMRobertaForSequenceClassification" in model_architectures
1166+
or "Gemma2ForSequenceClassification" in model_architectures
11661167
):
11671168
return False
11681169
else:

python/sglang/srt/environ.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,10 @@ class Envs:
294294
SGLANG_NPU_DISABLE_ACL_FORMAT_WEIGHT = EnvBool(False)
295295
SGLANG_NPU_USE_MULTI_STREAM = EnvBool(False)
296296
SGLANG_NPU_USE_MLAPO = EnvBool(False)
297+
# Forward native implementation for activation gelu tanh for model Skywork-Reward-Gemma-2-27B-v0.2
298+
SGLANG_NPU_FORWARD_NATIVE_GELUTANH = EnvBool(False)
299+
# Forward native implementation for gemma rms norm for model Skywork-Reward-Gemma-2-27B-v0.2
300+
SGLANG_NPU_FORWARD_NATIVE_GEMMA_RMS_NORM = EnvBool(False)
297301

298302
# Quantization
299303
SGLANG_INT4_WEIGHT = EnvBool(False)

python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,11 @@ def __init__(self, model_runner: ModelRunner):
225225
self.q_head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim
226226
else:
227227
self.use_alibi = getattr(model_runner.model_config, "use_alibi", False)
228+
if (
229+
"Gemma2ForSequenceClassification"
230+
in model_runner.model_config.hf_config.architectures
231+
):
232+
self.use_native_sdpa = True
228233
self.native_attn = AscendTorchNativeAttnBackend()
229234
self.graph_metadata = {}
230235
self.max_context_len = model_runner.model_config.context_len
@@ -821,10 +826,12 @@ def forward_extend(
821826

822827
# there are some accuracy issues in cross attention scene to use torch_npu._npu_flash_attention_qlens
823828
# forward_batch.encoder_lens is not None in cross attention scend, we add native attn to solve accuracy issues
829+
# Model skywork-reward-gemma2-2-27B also suffers from precision anomalies, thus the torch native backend becomes beneficial approach.
824830
if (
825831
layer.qk_head_dim <= 128
826832
and causal
827833
and forward_batch.encoder_lens is None
834+
and not getattr(self, "use_native_sdpa", False)
828835
):
829836
if not self.use_alibi:
830837
query = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

python/sglang/srt/layers/activation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_tensor_model_parallel_rank,
2828
get_tensor_model_parallel_world_size,
2929
)
30+
from sglang.srt.environ import envs
3031
from sglang.srt.layers.quantization.base_config import QuantizationConfig
3132
from sglang.srt.layers.utils import MultiPlatformOp
3233
from sglang.srt.server_args import get_global_server_args
@@ -131,6 +132,8 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
131132
return self._forward_impl(x)
132133

133134
def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
135+
if envs.SGLANG_NPU_FORWARD_NATIVE_GELUTANH.get():
136+
return self.forward_native(x)
134137
y_npu, gelu_npu = torch_npu.npu_geglu(
135138
x,
136139
dim=-1,

python/sglang/srt/layers/layernorm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
is_batch_invariant_mode_enabled,
2525
rms_norm_batch_invariant,
2626
)
27+
from sglang.srt.environ import envs
2728
from sglang.srt.layers.utils import MultiPlatformOp
2829
from sglang.srt.server_args import get_global_server_args
2930
from sglang.srt.utils import (
@@ -468,6 +469,8 @@ def forward_npu(
468469
residual: Optional[torch.Tensor] = None,
469470
post_residual_addition: Optional[torch.Tensor] = None,
470471
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
472+
if envs.SGLANG_NPU_FORWARD_NATIVE_GEMMA_RMS_NORM.get():
473+
return self.forward_native(x, residual)
471474
if residual is not None:
472475
if post_residual_addition is not None:
473476
residual = residual + post_residual_addition

python/sglang/srt/models/gemma2.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@
3939
default_weight_loader,
4040
maybe_remap_kv_scale_name,
4141
)
42-
from sglang.srt.utils import add_prefix, make_layers
42+
from sglang.srt.utils import add_prefix, is_npu, make_layers
43+
44+
_is_npu = is_npu()
4345

4446

4547
# Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -142,13 +144,28 @@ def __init__(
142144
quant_config=quant_config,
143145
prefix=add_prefix("o_proj", prefix),
144146
)
145-
self.rotary_emb = get_rope(
146-
self.head_dim,
147-
rotary_dim=self.head_dim,
148-
max_position=max_position_embeddings,
149-
base=self.rope_theta,
150-
is_neox_style=True,
151-
)
147+
if (
148+
not _is_npu
149+
or "Gemma2ForSequenceClassification" not in self.config.architectures
150+
):
151+
self.rotary_emb = get_rope(
152+
self.head_dim,
153+
rotary_dim=self.head_dim,
154+
max_position=max_position_embeddings,
155+
base=self.rope_theta,
156+
is_neox_style=True,
157+
)
158+
logit_cap = self.config.attn_logit_softcapping
159+
else:
160+
self.rotary_emb = get_rope(
161+
self.head_dim,
162+
rotary_dim=self.head_dim,
163+
max_position=max_position_embeddings,
164+
base=self.rope_theta,
165+
is_neox_style=True,
166+
dtype=torch.float32,
167+
)
168+
logit_cap = 0.0
152169

153170
use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
154171
self.attn = RadixAttention(
@@ -157,7 +174,7 @@ def __init__(
157174
self.scaling,
158175
num_kv_heads=self.num_kv_heads,
159176
layer_id=layer_id,
160-
logit_cap=self.config.attn_logit_softcapping,
177+
logit_cap=logit_cap,
161178
sliding_window_size=(
162179
get_attention_sliding_window_size(config)
163180
if use_sliding_window
@@ -294,7 +311,9 @@ def forward(
294311
hidden_states = self.embed_tokens(input_ids)
295312
else:
296313
hidden_states = input_embeds
297-
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=torch.float16)
314+
normalizer = torch.tensor(
315+
self.config.hidden_size**0.5, dtype=hidden_states.dtype
316+
)
298317
hidden_states *= normalizer
299318

300319
residual = None
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import logging
2+
import multiprocessing as mp
3+
import os
4+
import unittest
5+
6+
import torch
7+
8+
from sglang.test.ci.ci_register import register_npu_ci
9+
from sglang.test.runners import HFRunner, SRTRunner
10+
from sglang.test.test_utils import CustomTestCase
11+
12+
logger = logging.getLogger(__name__)
13+
register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True)
14+
15+
MODELS = [
16+
(
17+
"/root/.cache/modelscope/hub/models/AI-ModelScope/Skywork-Reward-Gemma-2-27B-v0.2",
18+
1,
19+
4e-2,
20+
),
21+
]
22+
TORCH_DTYPES = [torch.bfloat16]
23+
24+
PROMPT = (
25+
"What is the range of the numeric output of a sigmoid node in a neural network?"
26+
)
27+
RESPONSE1 = "The output of a sigmoid node is bounded between -1 and 1."
28+
RESPONSE2 = "The output of a sigmoid node is bounded between 0 and 1."
29+
30+
CONVS = [
31+
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE1}],
32+
[{"role": "user", "content": PROMPT}, {"role": "assistant", "content": RESPONSE2}],
33+
]
34+
35+
36+
class TestRewardModels(CustomTestCase):
37+
38+
@classmethod
39+
def setUpClass(cls):
40+
mp.set_start_method("spawn", force=True)
41+
42+
def assert_close_reward_scores(
43+
self,
44+
convs,
45+
model_path,
46+
tp_size,
47+
torch_dtype,
48+
tolerance,
49+
) -> None:
50+
with HFRunner(
51+
model_path,
52+
torch_dtype=torch_dtype,
53+
model_type="reward",
54+
) as hf_runner:
55+
hf_outputs = hf_runner.forward(convs)
56+
57+
with SRTRunner(
58+
model_path,
59+
torch_dtype=torch_dtype,
60+
model_type="reward",
61+
mem_fraction_static=0.95,
62+
) as srt_runner:
63+
prompts = srt_runner.tokenizer.apply_chat_template(
64+
convs, tokenize=False, return_dict=False
65+
)
66+
srt_outputs = srt_runner.forward(prompts)
67+
68+
hf_scores = torch.tensor(hf_outputs.scores)
69+
srt_scores = torch.tensor(srt_outputs.scores)
70+
logger.info(f"{hf_scores=}")
71+
logger.info(f"{srt_scores=}")
72+
73+
assert torch.all(
74+
abs(hf_scores - srt_scores) < tolerance
75+
), "reward scores are not all close"
76+
77+
def test_reward_scores(self):
78+
for model, tp_size, tolerance in MODELS:
79+
for torch_dtype in TORCH_DTYPES:
80+
self.assert_close_reward_scores(
81+
CONVS, model, tp_size, torch_dtype, tolerance
82+
)
83+
84+
85+
if __name__ == "__main__":
86+
os.environ["SGLANG_NPU_FORWARD_NATIVE_GELUTANH"] = "1"
87+
os.environ["SGLANG_NPU_FORWARD_NATIVE_GEMMA_RMS_NORM"] = "1"
88+
unittest.main()

0 commit comments

Comments
 (0)