Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
33d29f6
[AMD] support two batch overlapping for mori ep
billishyahao Jan 29, 2026
e0b73b9
clean the debug log
billishyahao Jan 29, 2026
aa84e36
fix lint
billishyahao Jan 29, 2026
03a032f
[ROCm] prepare aiter mla metadata for TBO
billishyahao Jan 30, 2026
2e5bf5a
skip validation when decode only
billishyahao Feb 2, 2026
297e34a
fix bench one batch server script
billishyahao Feb 2, 2026
3168743
fix skip_token_capacity_threshold
billishyahao Feb 2, 2026
fc9feff
fix none issue for mtp
billishyahao Feb 2, 2026
8c0b0a7
Add mori inter-kernel auto-detection.
Duyi-Wang Jan 29, 2026
d95089d
[TBO] fix bug where CUDA graph intermittently becomes disabled
billishyahao Feb 5, 2026
89b10fe
adapt new AITER API
billishyahao Feb 16, 2026
e2f3521
skip deepgemm sm num calculation in hip env
billishyahao Feb 17, 2026
5d41a83
Merge branch 'main' into mori_ep_tbo
HaiShaw Feb 17, 2026
419c850
Lint fix
HaiShaw Feb 19, 2026
f677861
Upd: comment on low_latency support
HaiShaw Feb 19, 2026
2341fea
Merge branch 'main' into mori_ep_tbo
HaiShaw Feb 19, 2026
25d430f
Merge branch 'main' into mori_ep_tbo
billishyahao Feb 19, 2026
d2a89d5
fix: ensure scheme attribute exists before type checking in MoriEPMoE
Duyi-Wang Feb 19, 2026
f2299a0
reduce mori env
billishyahao Feb 20, 2026
46f23c1
fix fusedmoe output dtype magic number
billishyahao Feb 20, 2026
ba78277
fix comment
billishyahao Feb 20, 2026
f9f0ff3
Merge branch 'main' into mori_ep_tbo
billishyahao Feb 20, 2026
d916500
fix the comments
billishyahao Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
clean the debug log
  • Loading branch information
billishyahao committed Feb 16, 2026
commit e0b73b9518b9c8ee4dea3c84a66ecc0f0e3ef13b
116 changes: 35 additions & 81 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,28 +150,18 @@ def __init__(
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter:
if get_moe_a2a_backend().is_deepep():
# expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
# for instance, if we have 4 experts on this rank, we would have a expert_mask like:
# self.expert_mask = [1, 1, 1, 1, 0]
# idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
self.expert_mask = torch.zeros(
(self.num_local_experts + 1),
device=torch.cuda.current_device(),
dtype=torch.int,
)
# the last one is invalid rank_id
self.expert_mask[:-1] = 1
elif get_moe_a2a_backend().is_mori():
self.expert_mask = torch.zeros(
(self.num_experts),
device=torch.cuda.current_device(),
dtype=torch.int32,
)
expert_start_idx = self.moe_ep_rank * self.num_local_experts
expert_end_idx = expert_start_idx + self.num_local_experts
self.expert_mask[expert_start_idx : expert_end_idx] = 1
# expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
# for instance, if we have 4 experts on this rank, we would have a expert_mask like:
# self.expert_mask = [1, 1, 1, 1, 0]
# idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
self.expert_mask = torch.zeros(
(self.num_local_experts + 1),
device=torch.cuda.current_device(),
dtype=torch.int,
)
# the last one is invalid rank_id
self.expert_mask[:-1] = 1


def forward(
Expand Down Expand Up @@ -261,25 +251,16 @@ def run_moe_core(
else:
assert False, "forward_deepgemm_masked is deprecated"



if get_moe_a2a_backend().is_deepep():
combine_input_wrapper = (
DeepEPNormalCombineInput
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
else DeepEPLLCombineInput
)
elif get_moe_a2a_backend().is_mori():
combine_input_wrapper = (
MoriEPNormalCombineInput
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
else MoriEPLLCombineInput
)
combine_input_wrapper = (
DeepEPNormalCombineInput
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
else DeepEPLLCombineInput
)

return combine_input_wrapper(
hidden_states=output,
topk_ids=dispatch_output.origin_topk_ids,
topk_weights=dispatch_output.origin_topk_weights,
topk_ids=dispatch_output.topk_ids,
topk_weights=dispatch_output.topk_weights,
)

def combine(
Expand All @@ -298,65 +279,38 @@ def combine(

def forward_aiter(
self,
dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput, MoriEPNormalDispatchOutput, MoriEPLLDispatchOutput],
dispatch_output: Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput],
):
recv_scales = None

if get_moe_a2a_backend().is_deepep():
hidden_states, topk_ids, topk_weights = (
dispatch_output.hidden_states,
dispatch_output.topk_ids,
dispatch_output.topk_weights,
)
elif get_moe_a2a_backend().is_mori():
hidden_states, recv_scales, topk_ids, topk_weights, packed_recv_count = (
dispatch_output.hidden_states,
dispatch_output.hidden_states_scale,
dispatch_output.topk_ids,
dispatch_output.topk_weights,
dispatch_output.num_recv_tokens_per_expert
)

# logger.info(f"bill-dbg: {hidden_states=}")
# logger.info(f"bill-dbg: {recv_scales=}")
# logger.info(f"bill-dbg: {topk_ids=}")
# logger.info(f"bill-dbg: {topk_weights=}")
# logger.info(f"bill-dbg: {packed_recv_count=}")

#billishyahao: for now, fused_moe only support torch.bfloat16
output_dtype = torch.bfloat16

# logger.info(f"bill-dbg: {output_dtype=}")
hidden_states, topk_ids, topk_weights = (
dispatch_output.hidden_states,
dispatch_output.topk_ids,
dispatch_output.topk_weights,
)

if hidden_states.shape[0] == 0:
return hidden_states

if get_moe_a2a_backend().is_deepep():
# in original deepep, idx == -1 meaning invalid and will not be processed.
# aiter does not accept -1, we use a expert mask to make these idx invalid
# (idx == num_local_experts) meaning not used in aiter fused_moe
topk_ids_copy = topk_ids.to(torch.int32)
topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

# in original deepep, idx == -1 meaning invalid and will not be processed.
# aiter does not accept -1, we use a expert mask to make these idx invalid
# (idx == num_local_experts) meaning not used in aiter fused_moe
topk_ids_copy = topk_ids.to(torch.int32)
topk_ids_copy[topk_ids_copy == -1] = self.num_local_experts

return fused_moe(
hidden_states=hidden_states,
w1=self.w13_weight,
w2=self.w2_weight,
hidden_states,
self.w13_weight,
self.w2_weight,
topk_weights,
topk_ids_copy,
w1_scale=self.w13_weight_scale_inv,
w2_scale=self.w2_weight_scale_inv,
a1_scale=recv_scales,
topk_weight=topk_weights,
topk_ids=topk_ids_copy if get_moe_a2a_backend().is_deepep() else topk_ids,
quant_type=QuantType.per_128x128,
activation=(
ActivationType.Silu
if self.moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
expert_mask=self.expert_mask,
num_local_tokens=packed_recv_count,
dtype=output_dtype
)

def forward_flashinfer_cutedsl(
Expand Down
42 changes: 7 additions & 35 deletions python/sglang/srt/layers/moe/token_dispatcher/moriep.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from functools import lru_cache

import torch
import torch.distributed as dist

from sglang.srt.distributed import (
get_moe_expert_parallel_rank,
Expand Down Expand Up @@ -186,15 +185,16 @@ def init_mori_op(
cpu_group = group.cpu_group
torch._C._distributed_c10d._register_process_group("mori", cpu_group)
mori.shmem.shmem_torch_process_group_init("mori")
logger.info(
f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} {num_max_dispatch_tokens_per_rank=} {num_local_experts=} {router_topk=}"
)


mode = EpMode.INTRA_NODE if world_size <= 8 else EpMode.INTER_NODE
async_mode = get_bool_env_var("SGLANG_MORI_ASYNC_MODE", "false")
# logger.info(f"bill-dbg: {async_mode=}")
if async_mode:
mode = EpMode.LOW_LATENCY

logger.info(
f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} {num_max_dispatch_tokens_per_rank=} {num_local_experts=} {router_topk=} {mode=}"
)

cfg = get_ep_dispatch_configs()[mode]

kernel_type = cfg.kernel_type
Expand Down Expand Up @@ -505,9 +505,6 @@ def combine_a(
return hidden_states, topk_ids, topk_weights, previous_event

def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event):
# logger.info(f"bill-dbg: before comb combine_b: {hidden_states.shape=}")
# logger.info(f"bill-dbg: before comb combine_b: {topk_ids.shape=}")
# logger.info(f"bill-dbg: before comb combine_b: {topk_weights.shape=}")

hidden_states, done_event = self._combine_core(
hidden_states, topk_ids, topk_weights, previous_event
Expand All @@ -516,7 +513,6 @@ def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event):
if self._comm_stream and self.async_finish and done_event is not None:
torch.cuda.current_stream().wait_event(done_event)

# logger.info(f"bill-dbg: after comb combine_b: {hidden_states.shape=}")
return hidden_states

def _combine_core(
Expand Down Expand Up @@ -640,10 +636,7 @@ def dispatch_b(
import mori
assert self.mori_op.config.kernel_type is mori.ops.EpDispatchCombineKernelType.AsyncLL, "mori asyncll mismatch"

# logger.info(f"bill-dbg: mori_op.dispatch_recv start")

self.mori_op.dispatch_recv()
# logger.info(f"bill-dbg: mori_op.dispatch_recv end")

return MoriEPLLDispatchOutput(
hidden_states=hidden_states,
Expand All @@ -664,15 +657,13 @@ def _dispatch_core(
):
##TODO(billishyahao): add assertion here to check async

# logger.info(f"bill-dbg: mori_op.dispatch_send start")
(
packed_recv_hidden,
recv_topk_weights,
recv_scales,
recv_topk_ids,
packed_recv_count
) = self.mori_op.dispatch_send(hidden_states, topk_weights, scale, topk_ids)
# logger.info(f"bill-dbg: mori_op.dispatch_send end")

return (
packed_recv_hidden,
Expand All @@ -698,17 +689,9 @@ def combine_a(
return hidden_states, topk_ids, topk_weights, overlap_args

def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event):
# logger.info(f"bill-dbg: before comb combine_b: {hidden_states.shape=}")
# logger.info(f"bill-dbg: before comb combine_b: {topk_ids.shape=}")
# logger.info(f"bill-dbg: before comb combine_b: {topk_weights.shape=}")

# logger.info(f"bill-dbg: mori_op.combine_recv start")

self.mori_op.combine_recv()

# logger.info(f"bill-dbg: mori_op.combine_recv end")

# logger.info(f"bill-dbg: after comb combine_b: {hidden_states.shape=}")
return hidden_states[0]

def _combine_core(
Expand All @@ -717,17 +700,11 @@ def _combine_core(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
overlap_args: Optional[CombineOverlapArgs] = None,
):

# logger.info(f"bill-dbg: mori_op.combine_send start")

):
combined_hidden_states = self.mori_op.combine_send(
hidden_states, None, topk_ids
)

# logger.info(f"bill-dbg: mori_op.combine_send end")


return combined_hidden_states

def set_quant_config(self, quant_config: dict):
Expand Down Expand Up @@ -772,21 +749,16 @@ def __init__(
deepep_mode=deepep_mode,
)

# logger.info(f"bill-dbg: {self.deepep_mode.enable_low_latency()=}")

if self.deepep_mode.enable_low_latency():
# logger.info(f"bill-dbg: low latency ")
self._low_latency_dispatcher = _MoriEPDispatcherImplLowLatency(
**common_kwargs,
)

if self.deepep_mode.enable_normal():
# logger.info(f"bill-dbg: normal ")
self._normal_dispatcher = _MoriEPDispatcherImplNormal(
async_finish=async_finish,
**common_kwargs,
)


self._stage = _Stage.INITIAL
self._deepep_dispatch_hooks = MoriEPPDispatchHooks()
Expand Down
17 changes: 1 addition & 16 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,9 +822,7 @@ def _post_combine_hook(
):
dispatcher.clear_overlap_args()
self.experts.clear_overlap_args()
post_combine_hook_handle.remove()


post_combine_hook_handle.remove()

assert isinstance(self.experts.dispatcher, MaybeTboDeepEPDispatcher)
deepep_dispatch_hook_handle = (
Expand Down Expand Up @@ -1058,22 +1056,9 @@ def op_combine_b(self, state):
def op_output(self, state):
final_hidden_states = state.pop("hidden_states_after_combine")

##TODO(billishyahao): fix this bug

if get_moe_a2a_backend().is_mori():
# logger.info(f"bill-dbg: op_output {state._data=}")
num_tokens = state.pop("num_tokens")
# logger.info(f"bill-dbg: {num_tokens=}")
final_hidden_states = final_hidden_states[:num_tokens]
# if (shared_output := state.get("shared_output")) is not None:
# num_tokens = shared_output.shape[0]
# logger.info(f"bill-dbg: {num_tokens=}")
# final_hidden_states = final_hidden_states[:num_tokens]
# else:
# # logger.info(f"bill-dbg: shared_output is None...")
# # logger.info(f"bill-dbg: {final_hidden_states.shape=}")
# num_tokens = 0
# final_hidden_states = []

if (shared_output := state.pop("shared_output")) is not None:
x = shared_output
Expand Down