Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
33d29f6
[AMD] support two batch overlapping for mori ep
billishyahao Jan 29, 2026
e0b73b9
clean the debug log
billishyahao Jan 29, 2026
aa84e36
fix lint
billishyahao Jan 29, 2026
03a032f
[ROCm] prepare aiter mla metadata for TBO
billishyahao Jan 30, 2026
2e5bf5a
skip validation when decode only
billishyahao Feb 2, 2026
297e34a
fix bench one batch server script
billishyahao Feb 2, 2026
3168743
fix skip_token_capacity_threshold
billishyahao Feb 2, 2026
fc9feff
fix none issue for mtp
billishyahao Feb 2, 2026
8c0b0a7
Add mori inter kernel auto detected.
Duyi-Wang Jan 29, 2026
d95089d
[TBO] fix cuda graph intermittently becomes disabled bug
billishyahao Feb 5, 2026
89b10fe
adapt new AITER API
billishyahao Feb 16, 2026
e2f3521
skip deepgemm sm num calculation in hip env
billishyahao Feb 17, 2026
5d41a83
Merge branch 'main' into mori_ep_tbo
HaiShaw Feb 17, 2026
419c850
Lint fix
HaiShaw Feb 19, 2026
f677861
Upd: comment on low_latency supports
HaiShaw Feb 19, 2026
2341fea
Merge branch 'main' into mori_ep_tbo
HaiShaw Feb 19, 2026
25d430f
Merge branch 'main' into mori_ep_tbo
billishyahao Feb 19, 2026
d2a89d5
fix: ensure scheme attribute exists before type checking in MoriEPMoE
Duyi-Wang Feb 19, 2026
f2299a0
reduce mori env
billishyahao Feb 20, 2026
46f23c1
fix fusedmoe output dtype magic number
billishyahao Feb 20, 2026
ba78277
fix comment
billishyahao Feb 20, 2026
f9f0ff3
Merge branch 'main' into mori_ep_tbo
billishyahao Feb 20, 2026
d916500
fix the comments
billishyahao Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix lint
  • Loading branch information
billishyahao committed Feb 16, 2026
commit aa84e36b8a10a5dc5378707f4a66f4b0be32bd36
11 changes: 8 additions & 3 deletions python/sglang/srt/batch_overlap/operations_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from sglang.srt.batch_overlap.operations import Operation
from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()


@dataclass
Expand Down Expand Up @@ -91,9 +94,9 @@ def _compute_moe_deepseek_layer_operations_strategy_tbo(
def _compute_moe_deepseek_blog_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
# deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
##TODO(billishyahao): fixme
deep_gemm_num_sms = None
if _is_cuda:
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
Expand Down Expand Up @@ -170,7 +173,9 @@ def _compute_moe_qwen3_layer_operations_strategy_tbo(
def _compute_moe_qwen3_prefill(layer):
device_properties = torch.cuda.get_device_properties(device="cuda")
total_num_sms = device_properties.multi_processor_count
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
deep_gemm_num_sms = None
if _is_cuda:
deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

return OperationsStrategy(
deep_gemm_num_sms=deep_gemm_num_sms,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/batch_overlap/two_batch_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPDispatcher,
MooncakeEPDispatcher,
MoriEPDispatcher
MoriEPDispatcher,
)
from sglang.srt.layers.moe.token_dispatcher.base import BaseDispatcher
from sglang.srt.managers.schedule_batch import ScheduleBatch
Expand Down
29 changes: 10 additions & 19 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,11 @@
DeepEPLLCombineInput,
DeepEPNormalCombineInput,
)
from sglang.srt.layers.moe.token_dispatcher.moriep import MoriEPNormalCombineInput
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.moe.token_dispatcher.moriep import (
MoriEPLLCombineInput,
MoriEPNormalCombineInput,
)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
NPUCompressedTensorsW4A16Int4DynamicMoEMethod,
Expand All @@ -46,8 +43,6 @@
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLDispatchOutput,
DeepEPNormalDispatchOutput,
MoriEPNormalDispatchOutput,
MoriEPLLDispatchOutput,
DispatchOutput,
)

Expand Down Expand Up @@ -163,7 +158,6 @@ def __init__(
# the last one is invalid rank_id
self.expert_mask[:-1] = 1


def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -256,12 +250,12 @@ def run_moe_core(
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
else DeepEPLLCombineInput
)

return combine_input_wrapper(
hidden_states=output,
topk_ids=dispatch_output.topk_ids,
topk_weights=dispatch_output.topk_weights,
)
hidden_states=output,
topk_ids=dispatch_output.topk_ids,
topk_weights=dispatch_output.topk_weights,
)

def combine(
self,
Expand Down Expand Up @@ -606,7 +600,7 @@ def forward(
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
num_token = hidden_states.shape[0]
num_token = hidden_states.shape[0]
dispatch_output = self.dispatcher.dispatch(
hidden_states=hidden_states, topk_output=topk_output
)
Expand All @@ -617,13 +611,12 @@ def forward(

return hidden_states[:num_token]


def run_moe_core(
self,
dispatch_output: DispatchOutput,
):
#TODO(billishyahao): check aiter path
#billishyahao: for now, fused_moe only support torch.bfloat16
# TODO(billishyahao): check aiter path
# billishyahao: for now, fused_moe only support torch.bfloat16
output_dtype = torch.bfloat16
scale = None
is_fp8_quant = isinstance(self.quant_method, Fp8MoEMethod)
Expand All @@ -636,15 +629,15 @@ def run_moe_core(
dispatch_weights,
dispatch_recv_token_num,
origin_topk_ids,
origin_topk_weights
origin_topk_weights,
) = (
dispatch_output.hidden_states,
dispatch_output.hidden_states_scale,
dispatch_output.topk_ids,
dispatch_output.topk_weights,
dispatch_output.num_recv_tokens_per_expert,
dispatch_output.origin_topk_ids,
dispatch_output.origin_topk_weights
dispatch_output.origin_topk_weights,
)

w13_weight = self.w13_weight
Expand Down Expand Up @@ -717,8 +710,6 @@ def run_moe_core(
)




def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
# [TODO] kk, temporary solution
if get_moe_a2a_backend().is_mori():
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def create_moe_dispatcher(moe_runner_config: MoeRunnerConfig) -> BaseDispatcher:
return StandardDispatcher(moe_runner_config)
elif a2a_backend.is_deepep() or a2a_backend.is_mooncake() or a2a_backend.is_mori():
return MaybeTboDeepEPDispatcher(
group=get_tp_group().device_group if not a2a_backend.is_mori() else get_tp_group(),
group=(
get_tp_group().device_group
if not a2a_backend.is_mori()
else get_tp_group()
),
router_topk=moe_runner_config.top_k,
permute_fusion=True,
num_experts=moe_runner_config.num_experts,
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/moe/token_dispatcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
)
from sglang.srt.layers.moe.token_dispatcher.moriep import (
MoriEPDispatcher,
MoriEPLLCombineInput,
MoriEPLLDispatchOutput,
MoriEPNormalCombineInput,
MoriEPNormalDispatchOutput,
MoriEPLLDispatchOutput,
MoriEPLLCombineInput,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
StandardCombineInput,
Expand Down
Loading