[AMD] support two batch overlapping for mori ep #17953

Merged
HaiShaw merged 23 commits into sgl-project:main from HaiShaw:mori_ep_tbo
Feb 20, 2026

Conversation

@billishyahao
Contributor

@billishyahao billishyahao commented Jan 29, 2026

Motivation

Co-authored with @kkHuang-amd @ZhaiFeiyue @Duyi-Wang
cc @HaiShaw

This patch supports TBO (two-batch overlapping) for mori EP. It consists of the following changes:
(1) We introduce the MORI async API to support a CU-free method for low-latency scenarios.
(2) We introduce multiple HIP streams to enable communication-computation overlapping for high-throughput scenarios.
(3) The relation between the sglang arguments and the underlying configs is shown below:

| Server arguments | MoRI kernel type |
| --- | --- |
| `--deepep-mode normal` | Intra / InterNode / InterNode LL kernel |
| `--deepep-mode normal --two-batch-overlap` | Intra / InterNode / InterNode LL kernel + dual stream |
| `--deepep-mode low_latency` | AsyncLL kernel |
| `--deepep-mode low_latency --two-batch-overlap` | AsyncLL kernel + CU-free overlapping |
| `--deepep-mode auto` | Coming soon along with the MoRI kernel shifting feature |
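The mapping above can be sketched as a small selector. This is hypothetical illustration code, not the actual sglang dispatch logic; the function name and return strings are made up for clarity:

```python
# Hypothetical sketch of the server-argument -> MoRI kernel mapping
# described in the table above (not the actual sglang implementation).
def select_mori_kernel(deepep_mode: str, two_batch_overlap: bool) -> str:
    if deepep_mode == "normal":
        base = "Intra / InterNode / InterNode LL kernel"
        # With TBO, normal mode additionally runs dual HIP streams.
        return base + " + dual stream" if two_batch_overlap else base
    if deepep_mode == "low_latency":
        base = "AsyncLL kernel"
        # With TBO, low-latency mode uses CU-free overlapping.
        return base + " + CU-free overlapping" if two_batch_overlap else base
    # "auto" ships later with the MoRI kernel shifting feature.
    raise NotImplementedError(f"deepep mode {deepep_mode!r} not wired up yet")
```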

A unit test is to be added.

Accuracy Tests

Accuracy check pass on gsm8k dataset:

DSR1 FP8 EP8: aiter backend + Mori normal mode + fp8 dispatch + eager

Click to expand the server command
SGLANG_MORI_FP8_DISP=true \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode normal \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--disable-cuda-graph \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:39<00:00,  5.12it/s]
Accuracy: 0.965
Invalid: 0.000
Latency: 39.088 s
Output throughput: 463.318 token/s

DSR1 FP8 EP8: aiter backend + Mori normal mode + fp8 dispatch + graph

Click to expand the server command
SGLANG_MORI_FP8_DISP=true \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode normal \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--cuda-graph-max-bs 16 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:25<00:00,  7.91it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 25.272 s
Output throughput: 720.808 token/s

DSR1 FP8 EP8: aiter backend + Mori normal mode + fp8 dispatch + non-persist mla

Click to expand the server command
SGLANG_AITER_MLA_PERSIST=0 \
SGLANG_MORI_FP8_DISP=true \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode normal \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--cuda-graph-max-bs 16 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:28<00:00,  6.98it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 28.675 s
Output throughput: 634.524 token/s

DSR1 FP8 EP8: aiter backend + Mori low_latency mode + fp8 dispatch + eager

Click to expand the server command
SGLANG_MORI_FP8_DISP=true \
MORI_DISABLE_P2P=1 \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--disable-cuda-graph \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /billhe/sglang-sa-tbo/benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:57<00:00,  3.48it/s]
Accuracy: 0.975
Invalid: 0.000
Latency: 57.551 s
Output throughput: 322.182 token/s

DSR1 FP8 EP8: aiter backend + Mori low_latency mode + fp8 dispatch + graph

Click to expand the server command
SGLANG_MORI_FP8_DISP=true \
MORI_DISABLE_P2P=1 \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--cuda-graph-max-bs 16 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:51<00:00,  3.87it/s]
Accuracy: 0.975
Invalid: 0.000
Latency: 51.684 s
Output throughput: 359.240 token/s

DSR1 FP8 EP8: aiter backend + Mori low_latency mode + fp8 dispatch + two-batch-overlap + eager

Click to expand the server command
SGLANG_MORI_FP8_DISP=true \
MORI_DISABLE_P2P=1 \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-two-batch-overlap \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--disable-cuda-graph \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:23<00:00,  2.41it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 83.107 s
Output throughput: 224.192 token/s

DSR1 FP8 EP8: aiter backend + Mori low_latency mode + fp8 dispatch + two-batch-overlap + graph

Click to expand the server command
SGLANG_MORI_FP8_DISP=true \
MORI_DISABLE_P2P=1 \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-two-batch-overlap \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--cuda-graph-max-bs 16 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:47<00:00,  4.23it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 47.327 s
Output throughput: 383.569 token/s

DSR1 FP8 EP8: aiter backend + Mori low_latency mode + fp8 dispatch + two-batch-overlap + graph + non-persist mla

Click to expand the server command
SGLANG_AITER_MLA_PERSIST=0 \
SGLANG_MORI_FP8_DISP=true \
MORI_DISABLE_P2P=1 \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-two-batch-overlap \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--cuda-graph-max-bs 16 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:45<00:00,  4.40it/s]
Accuracy: 0.985
Invalid: 0.000
Latency: 45.494 s
Output throughput: 401.281 token/s

DSR1 FP8 EP8: aiter backend + Mori low_latency mode + fp8 dispatch + two-batch-overlap + graph + MTP

Click to expand the server command
SGLANG_MORI_FP8_DISP=true \
MORI_DISABLE_P2P=1 \
MORI_SHMEM_MODE=ISOLATION                         \
SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192 \
NCCL_IB_HCA=ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
GLOO_SOCKET_IFNAME=enp81s0f1 \
NCCL_SOCKET_IFNAME=enp81s0f1 \
SGLANG_USE_AITER=1           \
python3 -m sglang.launch_server \
	--model-path /models/DSR1  \
	--tp-size 8 \
	--dp-size 8 \
	--ep-size 8 \
	--moe-a2a-backend mori \
	--deepep-mode low_latency \
	--enable-two-batch-overlap \
	--enable-dp-attention \
	--decode-log-interval 1 \
	--host 0.0.0.0 \
	--port 8321 \
	--nnodes 1 \
	--node-rank 0 \
	--trust-remote-code \
	--moe-dense-tp-size 1 \
	--enable-dp-lm-head \
	--disable-radix-cache \
	--watchdog-timeout 1000000 \
	--mem-fraction-static 0.8 \
	--max-running-requests 128 \
	--chunked-prefill-size 65536 \
	--speculative-algorithm NEXTN \
	--speculative-num-steps 1 \
	--speculative-eagle-topk 1 \
	--speculative-num-draft-tokens 2 \
	--cuda-graph-max-bs 16 \
	--kv-cache-dtype fp8_e4m3 \
	--log-requests \
	--log-requests-level 3 \
	--attention-backend aiter
python /billhe/sglang-sa-tbo/benchmark/gsm8k/bench_sglang.py --num-questions 200 --port 8321
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:43<00:00,  4.62it/s]
Accuracy: 0.970
Invalid: 0.000
Latency: 43.270 s
Output throughput: 426.602 token/s

Benchmarking, Performance Gain and Profiling

DSR1 FP8 1k/1k stress test

| bs per gpu | async (output toks/s/gpu) | async + TBO (output toks/s/gpu) | TBO gain (%) |
| --- | --- | --- | --- |
| 128 | 1548.43 | 1360.84 | -12% |
| 256 | 1958.43 | 2138.30 | +9% |
| 512 | 2518.49 | 2792.93 | +11% |
| 1024 | 2868.89 | 3342.79 | +17% |
| 2048 | 2985.92 | 3735.71 | +25% |
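As a sanity check, the gain column can be recomputed from the two throughput columns (a quick sketch; the numbers are copied from the stress-test results above):

```python
# Recompute "TBO gain (%)" from the async vs. async+TBO throughput
# columns of the 1k/1k stress-test table above.
rows = [
    (128, 1548.43, 1360.84),
    (256, 1958.43, 2138.30),
    (512, 2518.49, 2792.93),
    (1024, 2868.89, 3342.79),
    (2048, 2985.92, 3735.71),
]
gains = {bs: round(100 * (tbo / base - 1)) for bs, base, tbo in rows}
# gains -> {128: -12, 256: 9, 512: 11, 1024: 17, 2048: 25}
```

TBO only pays off once the per-GPU batch size is large enough to keep both micro-batches busy, which matches the crossover between bs 128 and 256.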

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @billishyahao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly upgrades the Mori Expert Parallelism (EP) backend by integrating advanced features for performance optimization. It introduces an asynchronous API to reduce latency and leverages multi-HIP streams to enable efficient overlapping of communication and computation. These enhancements are crucial for improving the overall throughput and responsiveness of models utilizing Mori EP, particularly within a two-batch overlapping context, making the system more adaptable to diverse performance demands.

Highlights

  • Mori EP Integration: The mori backend has been added as a supported option for the --moe-a2a-backend argument, expanding the available expert parallelism communication backends.
  • Asynchronous Mori API for Low Latency: An asynchronous API for Mori EP is introduced to support low-latency scenarios, controlled by the new SGLANG_MORI_ASYNC_MODE environment variable. This enables non-blocking operations for improved responsiveness.
  • Multi-HIP Stream for Communication-Computation Overlapping: Multi-HIP stream functionality has been implemented to allow for communication-computation overlapping, enhancing throughput. This feature is enabled via the SGLANG_MORI_DUAL_STREAM environment variable.
  • Two Batch Overlapping (TBO) for Mori EP: The existing two-batch overlapping mechanism has been extended to fully support the Mori EP dispatcher, allowing for more efficient resource utilization by overlapping processing of different batches.
  • Refactored Mori Dispatcher Logic: The Mori EP dispatcher (MoriEPDispatcher) has been refactored to accommodate both normal and low-latency modes, introducing new dispatch and combine input/output structures (MoriEPLLDispatchOutput, MoriEPLLCombineInput) for better organization and functionality.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for two-batch overlapping for the mori expert parallelism backend, primarily targeting AMD GPUs. The changes are extensive, adding an async API for low latency scenarios and multi-hip stream support to overlap communication and computation. Overall, the implementation is solid and aligns with the PR's objectives. I've identified a critical bug that could cause crashes on non-CUDA platforms and have also included a few suggestions to improve code style and maintainability.

@HaiShaw
Collaborator

HaiShaw commented Feb 15, 2026

@billishyahao conflicts?

@HaiShaw HaiShaw self-assigned this Feb 17, 2026
@HaiShaw
Collaborator

HaiShaw commented Feb 19, 2026

@kkHuang-amd Please review aiter backend, for performance implications

@billishyahao billishyahao requested a review from HaiShaw February 20, 2026 08:29
@HaiShaw
Collaborator

HaiShaw commented Feb 20, 2026

/tag-and-rerun-ci

warp_num_per_block=8,
block_num=64,
rdma_block_num=32,
),
Collaborator

@HaiShaw HaiShaw Feb 20, 2026


Can these parameters be tunable for diff chips (all 3 sections above)?

Contributor Author


Theoretically, yes! But for now we only offer end users two options: (1) manually set the parameters, as done here; (2) let mori decide them automatically by setting the environment variable MORI_EP_LAUNCH_CONFIG_MODE=auto.

For reference:
https://github.com/ROCm/mori/blob/b95cdbd6e0f36a61ae75a570da97d5c308a3fa85/python/mori/ops/dispatch_combine.py#L160C51-L181
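The two configuration paths can be sketched as follows. The parameter names mirror the snippet under review; the dataclass and `resolve_launch_config` helper are illustrative, not mori's actual API:

```python
import os
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch of the two options above: hand-tuned launch
# parameters, or deferring to mori's auto mode via the
# MORI_EP_LAUNCH_CONFIG_MODE environment variable.
@dataclass
class LaunchConfig:
    warp_num_per_block: int = 8
    block_num: int = 64
    rdma_block_num: int = 32

def resolve_launch_config(manual: LaunchConfig) -> Optional[LaunchConfig]:
    # None signals "let mori auto-tune"; otherwise use the manual values.
    if os.environ.get("MORI_EP_LAUNCH_CONFIG_MODE") == "auto":
        return None
    return manual
```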

@HaiShaw HaiShaw merged commit fbb6098 into sgl-project:main Feb 20, 2026
94 of 102 checks passed
michaelzhang-ai added a commit that referenced this pull request Feb 21, 2026
…e_fused_experts

The `op_output` method (TBO path) introduced in PR #17953 used only
`_use_aiter` to decide whether to skip `routed_scaling_factor`, while
`_forward_moe_fused_experts` (non-TBO path) uses the more complete
condition `self.experts.should_fuse_routed_scaling_factor_in_topk or
_use_aiter`. This inconsistency could cause incorrect scaling behavior
when TBO is enabled with certain quantization backends.

Align `op_output` to use the same condition as `_forward_moe_fused_experts`
so both paths handle `routed_scaling_factor` identically.

Co-authored-by: Cursor <cursoragent@cursor.com>
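The fix described above boils down to using one shared predicate on both paths. A minimal sketch (names follow the commit message; the standalone function is illustrative):

```python
# Sketch of the condition alignment from the follow-up commit: both the
# TBO path (op_output) and the non-TBO path (_forward_moe_fused_experts)
# should decide whether to skip routed_scaling_factor with the SAME
# predicate, instead of the TBO path checking only use_aiter.
def should_skip_routed_scaling(fuse_in_topk: bool, use_aiter: bool) -> bool:
    # fuse_in_topk stands in for
    # self.experts.should_fuse_routed_scaling_factor_in_topk;
    # use_aiter stands in for the module-level _use_aiter flag.
    return fuse_in_topk or use_aiter
```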
michaelzhang-ai added a commit to michaelzhang-ai/sglang that referenced this pull request Feb 21, 2026
PR sgl-project#17953 added EpDispatchCombineKernelType.AsyncLL but the CI Docker
image's mori package doesn't have it yet, causing an AttributeError on
every mori EP test. Guard the LOW_LATENCY config behind a hasattr check
and fall back to INTRA_NODE/INTER_NODE mode when AsyncLL is unavailable.
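The guard described in that commit can be sketched like this, with a stub enum standing in for mori's `EpDispatchCombineKernelType` (an older mori package may not expose `AsyncLL`):

```python
from enum import Enum, auto as _auto

# Stub in place of mori's EpDispatchCombineKernelType; AsyncLL is
# intentionally absent to mimic an older mori package.
class EpDispatchCombineKernelType(Enum):
    IntraNode = _auto()
    InterNode = _auto()

# The hasattr guard from the commit message: use AsyncLL when the
# installed mori provides it, otherwise fall back to the
# intra/inter-node kernels so mori EP tests don't hit AttributeError.
if hasattr(EpDispatchCombineKernelType, "AsyncLL"):
    low_latency_kernel = EpDispatchCombineKernelType.AsyncLL
else:
    low_latency_kernel = EpDispatchCombineKernelType.InterNode
```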

@hubertlu-tw
Collaborator

Hi @billishyahao,
Is the following assertion what you intended to add?
https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py#L2414-L2420

assert (self.chunked_prefill_size) <= get_int_env_var(
    "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096
), "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK (default 4096) must be larger or equal to chunked_prefill_size"

Since I saw the commands you ran in the experiment used

SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192
--chunked-prefill-size 65536 

CC: @HaiShaw

@billishyahao
Contributor Author

billishyahao commented Mar 5, 2026

Hi @hubertlu-tw, thanks for the inquiry. Please check out the post-init process in server_args.py:

# Handle data parallelism.
self._handle_data_parallelism()  <----- self.chunked_prefill_size will be adjusted to chunked_prefill_size/dp_size

        ...

self._handle_a2a_moe()  <--- assertion self.chunked_prefill_size <= SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK here

# Handle data parallelism.
self._handle_data_parallelism()
# Handle context parallelism.
self._handle_context_parallelism()
# Handle MoE configurations.
self._handle_moe_kernel_config()
self._handle_a2a_moe()

So for your case:

SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=8192
--chunked-prefill-size 65536 

the new chunked_prefill_size becomes 65536 / dp_size (8) = 8192, so it satisfies the assertion.
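The ordering explained above can be sketched as two small helpers (the function names loosely follow `_handle_data_parallelism` and `_handle_a2a_moe` from the thread; the standalone helpers themselves are a stub, not the real ServerArgs code):

```python
import os

# Stand-in for sglang's get_int_env_var helper.
def get_int_env_var(name: str, default: int) -> int:
    return int(os.environ.get(name, default))

def effective_chunked_prefill(chunked_prefill_size: int, dp_size: int) -> int:
    # _handle_data_parallelism(): chunked_prefill_size is divided by
    # dp_size before the MoRI assertion ever runs.
    return chunked_prefill_size // dp_size

def check_mori_limit(chunked: int) -> None:
    # _handle_a2a_moe(): the assertion from server_args.py, applied to
    # the already-adjusted per-DP-rank value.
    limit = get_int_env_var("SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096)
    assert chunked <= limit, (
        "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK must be >= "
        "chunked_prefill_size"
    )
```

With the experiment's settings (65536 / dp_size 8 = 8192, limit 8192) the check passes.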

@hubertlu-tw
Collaborator


Thank you for answering my question!


Labels

deepseek · documentation (Improvements or additions to documentation) · run-ci

6 participants