Skip to content

Commit 1edc69b

Browse files
Authored commit
[Ascend]Support qwen3.5 (sgl-project#18544)
This PR affects only the NPU. If any issues arise, please contact iforgetmyname.
1 parent 0305d12 commit 1edc69b

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
3636
from sglang.srt.utils.common import rank0_log
3737

38-
if not is_cpu():
38+
if not is_cpu() and not is_npu():
3939
# fix import error on CPU device, no impacts when non-CPU path
4040
from sglang.jit_kernel.cutedsl_gdn import (
4141
cutedsl_fused_sigmoid_gating_delta_rule_update,
@@ -814,7 +814,7 @@ def __init__(self, model_runner: ModelRunner):
814814
self.conv_states_shape = (
815815
model_runner.req_to_token_pool.mamba_pool.mamba_cache.conv[0].shape
816816
)
817-
if not is_cpu():
817+
if not is_cpu() and not is_npu():
818818
assert (
819819
self.conv_states_shape[-1] < FLA_CHUNK_SIZE
820820
), f"{self.conv_states_shape[-1]=} should be less than {FLA_CHUNK_SIZE}"

python/sglang/srt/layers/quantization/modelslim/modelslim.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,15 @@ def is_layer_skipped(
193193
):
194194
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
195195
proj_name = prefix.split(".")[-1]
196+
if not hasattr(self, "_quant_description_normalized"):
197+
quant_description = {}
198+
for prefix_, value in self.quant_description.items():
199+
prefix_ = prefix_.replace("language_model.", "")
200+
if "visual" in prefix_:
201+
prefix_ = prefix_.replace("model.", "")
202+
quant_description[prefix_] = value
203+
self.quant_description = quant_description
204+
self._quant_description_normalized = True
196205
if proj_name in fused_mapping:
197206
shard_prefixes = [
198207
prefix.replace(proj_name, shard_proj_name)

python/sglang/srt/models/qwen3_5.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
# Distributed
3535
from sglang.srt.distributed import get_pp_group
3636
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
37+
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
3738

3839
# Layers - Attention
3940
from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
@@ -328,7 +329,7 @@ def __init__(
328329
config=config,
329330
quant_config=quant_config,
330331
alt_stream=alt_stream,
331-
prefix=add_prefix("mlp", prefix.replace(".self_attn", "")),
332+
prefix=add_prefix("mlp", prefix.replace(".linear_attn", "")),
332333
)
333334
is_layer_sparse = True
334335
is_previous_layer_sparse = True
@@ -339,7 +340,7 @@ def __init__(
339340
intermediate_size=config.intermediate_size,
340341
hidden_act=config.hidden_act,
341342
quant_config=quant_config,
342-
prefix=add_prefix("mlp", prefix.replace(".self_attn", "")),
343+
prefix=add_prefix("mlp", prefix.replace(".linear_attn", "")),
343344
)
344345
is_layer_sparse = False
345346
is_previous_layer_sparse = False
@@ -1318,5 +1319,14 @@ def load_fused_expert_weights(
13181319

13191320
return loaded_params
13201321

1322+
@classmethod
1323+
def get_model_config_for_expert_location(cls, config):
1324+
text_config = getattr(config, "text_config", config)
1325+
return ModelConfigForExpertLocation(
1326+
num_layers=text_config.num_hidden_layers,
1327+
num_logical_experts=text_config.num_experts,
1328+
num_groups=None,
1329+
)
1330+
13211331

13221332
EntryClass = [Qwen3_5MoeForConditionalGeneration, Qwen3_5ForConditionalGeneration]

0 commit comments

Comments (0)