Skip to content

Commit 9c4b947

Browse files
HandH1998 and Johnsonms
authored and committed
Support mxint4 flashinfer_trtllm moe gemm (sgl-project#16892)
1 parent 15f4218 commit 9c4b947

File tree

4 files changed

+367
-6
lines changed

4 files changed

+367
-6
lines changed

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@
5454
FusedMoEMethodBase,
5555
QuantizationConfig,
5656
)
57+
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
58+
CompressedTensorsMxInt4MoEMethod,
59+
)
5760
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
5861
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
5962
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
@@ -253,6 +256,7 @@ def __init__(
253256
gemm1_alpha=gemm1_alpha,
254257
gemm1_clamp_limit=gemm1_clamp_limit,
255258
is_gated=is_gated,
259+
routing_method_type=routing_method_type,
256260
)
257261

258262
self.quant_method: Optional[FusedMoEMethodBase] = None
@@ -688,6 +692,7 @@ def _weight_loader_impl(
688692
isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
689693
or isinstance(self.quant_method, Fp8MoEMethod)
690694
or isinstance(self.quant_method, UnquantizedFusedMoEMethod)
695+
or isinstance(self.quant_method, CompressedTensorsMxInt4MoEMethod)
691696
):
692697
shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
693698

@@ -1140,6 +1145,7 @@ def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
11401145
router_logits = topk_output.router_logits
11411146
topk_config = topk_output.topk_config
11421147
correction_bias = topk_config.correction_bias
1148+
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
11431149

11441150
if isinstance(self.quant_method, UnquantizedFusedMoEMethod):
11451151
# lazy import
@@ -1170,6 +1176,7 @@ def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
11701176
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
11711177
local_num_experts=self.num_local_experts,
11721178
routing_method_type=self.routing_method_type,
1179+
routed_scaling_factor=routed_scaling_factor,
11731180
tune_max_num_tokens=next_power_of_2(hidden_states.shape[0]),
11741181
)
11751182

python/sglang/srt/layers/moe/moe_runner/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66

77
import torch
88

9-
from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
9+
from sglang.srt.layers.moe.utils import (
10+
MoeA2ABackend,
11+
MoeRunnerBackend,
12+
RoutingMethodType,
13+
)
1014

1115
if TYPE_CHECKING:
1216
from sglang.srt.layers.moe.moe_runner.triton import (
@@ -33,6 +37,7 @@ class MoeRunnerConfig:
3337
top_k: Optional[int] = None
3438
num_fused_shared_experts: Optional[int] = None
3539
params_dtype: Optional[torch.dtype] = None
40+
routing_method_type: Optional[RoutingMethodType] = None
3641

3742
# Runner configuration
3843
activation: str = "silu"

python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,19 @@ def _is_wNa16_group_channel(
471471

472472
return is_channel_group and input_quant_none and is_symmetric and is_static
473473

474+
def _is_mxint4a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
    """Return True when the scheme is MXINT4 weight-only quantization.

    Matches a config with symmetric, static (non-dynamic) 4-bit INT weights
    using group-wise quantization with group_size 32, and no input
    (activation) quantization at all.
    """
    no_input_quant = input_quant is None
    # Weight side must be exactly INT4, group strategy, group size 32.
    weight_is_mxint4 = (
        weight_quant.num_bits == 4
        and weight_quant.type == QuantizationType.INT
        and weight_quant.strategy == QuantizationStrategy.GROUP.value
        and weight_quant.group_size == 32
    )
    return (
        weight_is_mxint4
        and no_input_quant
        and weight_quant.symmetric
        and not weight_quant.dynamic
    )
486+
474487
def _is_dynamic_token_w4(
475488
self, weight_quant: BaseModel, input_quant: BaseModel
476489
) -> bool:

0 commit comments

Comments (0)