From 3b9e8100a916889fca088cf80791cb32b1a0f248 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 11:43:25 -0700 Subject: [PATCH 01/15] throwaway: conc-64 gsm8k eval for DEP8+MTP3 to reproduce dispatch token corruption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Narrow dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp search-space to a single DEP8+MTP3 conc-64 entry. With max(CONC_LIST)=64, the server computes SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32, which is below the 256 threshold that selects the correct All2All kernel. Expected: ~0% gsm8k (silent corruption from the low-latency All2All variant). Not for merge — throwaway validation of the dispatch token bug. --- .github/configs/amd-master.yaml | 140 +------------------------------- perf-changelog.yaml | 7 ++ 2 files changed, 11 insertions(+), 136 deletions(-) diff --git a/.github/configs/amd-master.yaml b/.github/configs/amd-master.yaml index 1ad705468..cb075b438 100644 --- a/.github/configs/amd-master.yaml +++ b/.github/configs/amd-master.yaml @@ -1986,123 +1986,10 @@ dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp: - isl: 8192 osl: 1024 search-space: - # MTP configurations - # 1P1D pure TP8 - - spec-decoding: "mtp" - conc-list: [ 1, 2, 4, 8 ] - prefill: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - # 1P2D TP8 - - spec-decoding: "mtp" - conc-list: [ 2, 4, 8, 16, 32 ] - prefill: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 2 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "DECODE_NODES=2" - - "DECODE_MTP_SIZE=3" - - # 1P2D TP8 - - spec-decoding: "mtp" - conc-list: [ 32, 64 ] - prefill: - num-worker: 1 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - 
- "PREFILL_NODES=1" - decode: - num-worker: 2 - tp: 8 - ep: 1 - dp-attn: false - additional-settings: - - "DECODE_NODES=2" - - "DECODE_MTP_SIZE=3" - - # 1*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 640, 512 ] - prefill: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - # 1*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 256 ] - prefill: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - - # 1*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 128 ] - prefill: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=1" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=3" - - # 1*DEP8 + 1*DEP8 + # THROWAWAY (not for merge): conc-64 only DEP8+MTP3 to reproduce + # SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK < 256 corruption. + # max(CONC_LIST)=64 → dispatch_tokens=64/8*4=32 → broken All2All kernel. + # 1*DEP8 + 1*DEP8, MTP3 - spec-decoding: "mtp" conc-list: [ 64 ] prefill: @@ -2121,25 +2008,6 @@ dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp: - "DECODE_NODES=1" - "DECODE_MTP_SIZE=3" - # 2*DEP8 + 1*DEP8 - - spec-decoding: "mtp" - conc-list: [ 1024, 2048, 4096 ] - prefill: - num-worker: 2 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "PREFILL_NODES=2" - decode: - num-worker: 1 - tp: 8 - ep: 8 - dp-attn: true - additional-settings: - - "DECODE_NODES=1" - - "DECODE_MTP_SIZE=1" - # DSv4-Pro FP4 on MI355X via SGLang. 
Uses a rocm720 mi35x image built off the # amd/deepseek_v4 branch in sgl-project/sglang; the SHA is encoded in the diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 5bee49782..ec1264da3 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3444,3 +3444,10 @@ - "Add MiniMax-M2.5 NVFP4 GB300 disaggregated multinode vLLM benchmarks via Dynamo" - "Add 1k1k/8k1k minimax recipe set under benchmarks/multi_node/srt-slurm-recipes/vllm/minimax-m2.5/" pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1641 + +- config-keys: + - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp + description: + - "Throwaway: conc-64-only gsm8k eval for DEP8+MTP3 to reproduce SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK < 256 corruption (dispatch=32 triggers broken All2All kernel, expect 0pct gsm8k). Not for merge." + pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 + evals-only: true From 45f69f594bcfc78726b395e66b11a3b678dc3174 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 11:45:43 -0700 Subject: [PATCH 02/15] trigger sweep From 9983cc0f853b7012767ccab6119d68cde1acf17d Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 13:52:50 -0700 Subject: [PATCH 03/15] add dispatch token clamp (>=256) and run benchmark+eval at conc-64 Clamp MORI_MAX_DISPATCH_TOKENS_DECODE to minimum 256 when DP+EP are both enabled, preventing SGLang's low-latency All2All kernel from being selected. That kernel silently corrupts outputs at small buffer sizes. Run A of A/B test: benchmark + eval WITH clamp on conc-64 DEP8+MTP3. 
--- benchmarks/multi_node/amd_utils/server_sglang.sh | 9 +++++++++ perf-changelog.yaml | 3 +-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index c28ccab41..627db803a 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -248,6 +248,15 @@ if [[ "$DECODE_MTP_SIZE" -gt 0 ]]; then MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MOE_MAX_INPUT_TOKENS_DECODE * (DECODE_MTP_SIZE + 1))) fi +# Clamp dispatch tokens to >= 256 to avoid the low-latency All2All kernel +# variant in MoRI which silently corrupts outputs at small buffer sizes. +if [[ "$DECODE_ENABLE_DP" == "true" ]] && [[ "$DECODE_ENABLE_EP" == "true" ]]; then + if [[ $MORI_MAX_DISPATCH_TOKENS_DECODE -lt 256 ]]; then + echo "[WARN] Clamping MORI_MAX_DISPATCH_TOKENS_DECODE from $MORI_MAX_DISPATCH_TOKENS_DECODE to 256 (All2All kernel threshold)" + MORI_MAX_DISPATCH_TOKENS_DECODE=256 + fi +fi + # ============================================================================= # Cluster Topology Configuration # ============================================================================= diff --git a/perf-changelog.yaml b/perf-changelog.yaml index ec1264da3..d21e4e5d7 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3448,6 +3448,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Throwaway: conc-64-only gsm8k eval for DEP8+MTP3 to reproduce SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK < 256 corruption (dispatch=32 triggers broken All2All kernel, expect 0pct gsm8k). Not for merge." + - "Throwaway: conc-64 DEP8+MTP3 benchmark+eval WITH dispatch token clamp (MORI_MAX_DISPATCH_TOKENS_DECODE >= 256). A/B test for All2All kernel corruption fix." 
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 - evals-only: true From 139f64671353b7d4d45a77b4b868a7c23821ca50 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 14:10:22 -0700 Subject: [PATCH 04/15] revert clamp for Run B (without fix) benchmark+eval at conc-64 Run B of A/B test: benchmark + eval WITHOUT dispatch token clamp. MORI_MAX_DISPATCH_TOKENS_DECODE will be 32 (<256 threshold). Expected: corrupted output, inflated AL, ~0% gsm8k. --- benchmarks/multi_node/amd_utils/server_sglang.sh | 9 --------- perf-changelog.yaml | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index 627db803a..c28ccab41 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -248,15 +248,6 @@ if [[ "$DECODE_MTP_SIZE" -gt 0 ]]; then MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MOE_MAX_INPUT_TOKENS_DECODE * (DECODE_MTP_SIZE + 1))) fi -# Clamp dispatch tokens to >= 256 to avoid the low-latency All2All kernel -# variant in MoRI which silently corrupts outputs at small buffer sizes. 
-if [[ "$DECODE_ENABLE_DP" == "true" ]] && [[ "$DECODE_ENABLE_EP" == "true" ]]; then - if [[ $MORI_MAX_DISPATCH_TOKENS_DECODE -lt 256 ]]; then - echo "[WARN] Clamping MORI_MAX_DISPATCH_TOKENS_DECODE from $MORI_MAX_DISPATCH_TOKENS_DECODE to 256 (All2All kernel threshold)" - MORI_MAX_DISPATCH_TOKENS_DECODE=256 - fi -fi - # ============================================================================= # Cluster Topology Configuration # ============================================================================= diff --git a/perf-changelog.yaml b/perf-changelog.yaml index d21e4e5d7..488c3ac1e 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3448,5 +3448,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Throwaway: conc-64 DEP8+MTP3 benchmark+eval WITH dispatch token clamp (MORI_MAX_DISPATCH_TOKENS_DECODE >= 256). A/B test for All2All kernel corruption fix." + - "Throwaway: conc-64 DEP8+MTP3 benchmark+eval WITHOUT dispatch token clamp (Run B of A/B test). Expect corrupted output / 0pct gsm8k." pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 From 906a9ae2a272648e046143ff52b63c04e7b4a348 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 16:08:53 -0700 Subject: [PATCH 05/15] sed-patch moriep.py to clamp dispatch tokens >= 64 (warpSize) and run benchmark+eval at conc-64 Validates Option A: instead of clamping the env var, patch the installed SGLang moriep.py at runtime to enforce a minimum of 64 (AMD CDNA3/4 warpSize) on num_max_dispatch_tokens_per_rank before it reaches the MoRI kernel config. If gsm8k recovers (like the 256 clamp did), this confirms warpSize is the minimum viable buffer floor and scopes the upstream fix. 
--- .../multi_node/amd_utils/server_sglang.sh | 33 +++++++++++++++++++ perf-changelog.yaml | 2 +- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index c28ccab41..14c51fbd2 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -712,6 +712,39 @@ else echo "Decode node rank: $RANK" echo "Decode parallelism: TP=${DECODE_TP_SIZE}, EP enabled: ${DECODE_ENABLE_EP}, DP enabled: ${DECODE_ENABLE_DP}" + # ── MoRI dispatch token floor patch ────────────────────────────────── + # When conc/TP is small (e.g. 64/8=8, ×4 MTP → 32 tokens), the MoRI + # All2All dispatch kernel silently corrupts output because buffer strides + # assume a minimum alignment of warpSize (64 on CDNA3/4). Patch the + # installed SGLang moriep.py to clamp num_max_dispatch_tokens_per_rank + # to at least 64 before it reaches the kernel config. + SGLANG_PKG=$(python3 -c "import sglang; print(sglang.__path__[0])" 2>/dev/null || true) + MORIEP_FILE="${SGLANG_PKG}/srt/layers/moe/token_dispatcher/moriep.py" + if [[ -n "$SGLANG_PKG" ]] && [[ -f "$MORIEP_FILE" ]]; then + python3 - "$MORIEP_FILE" << 'MORI_PATCH_EOF' +import re, sys +p = sys.argv[1] +with open(p) as f: + s = f.read() +pattern = r'(self\.num_max_dispatch_tokens_per_rank\s*=\s*get_int_env_var\(\s*"SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK",\s*4096\s*\))' +replacement = (r'\1\n' + r' self.num_max_dispatch_tokens_per_rank = ' + r'max(self.num_max_dispatch_tokens_per_rank, 64)') +s2 = re.sub(pattern, replacement, s, count=1, flags=re.DOTALL) +if s2 == s: + print(f'WARNING: MoRI patch pattern not found in {p}', file=sys.stderr) + sys.exit(1) +with open(p, 'w') as f: + f.write(s2) +print(f'[MORI PATCH] Clamped num_max_dispatch_tokens_per_rank >= 64 (warpSize) in {p}') +MORI_PATCH_EOF + if [[ $? 
-ne 0 ]]; then + echo "WARNING: MoRI sed patch failed — run may still exhibit corruption" + fi + else + echo "WARNING: moriep.py not found at ${MORIEP_FILE} — skipping patch" + fi + DECODE_MORI_MOE_ENV="" set -x if [[ -n "$MORI_MOE_MAX_INPUT_TOKENS_DECODE" ]]; then diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 488c3ac1e..df8f600fd 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3448,5 +3448,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Throwaway: conc-64 DEP8+MTP3 benchmark+eval WITHOUT dispatch token clamp (Run B of A/B test). Expect corrupted output / 0pct gsm8k." + - "Throwaway: validate Option A fix — sed-patch moriep.py to clamp num_max_dispatch_tokens_per_rank >= 64 (AMD warpSize). If gsm8k recovers, confirms warpSize is the minimum viable buffer floor for the MoRI All2All dispatch kernel." pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 From 12dadc15517f0e185f805f3fb12cfddcc4f1a476 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 16:23:21 -0700 Subject: [PATCH 06/15] clamp MoRI dispatch tokens to warpSize floor (64) and run benchmark+eval at conc-64 Root cause: the MoRI All2All dispatch kernel (EpDispatchInterNodeV1Kernel / IntraNode) writes dispatched tokens into warpSize-aligned receive slots (destTokId = flagSlotId*warpSize + laneId, laneId 0..63), so each warp-chunk spans 64 (CDNA3/4 wavefront) token slots. The per-rank receive region is sized to maxNumInpTokenPerRank, which the harness derives as max(CONC_LIST)/TP*(MTP+1). At low concurrency this collapses below 64 (conc-64/TP8/MTP3 -> 64/8*4 = 32), so a single chunk overruns the 32-slot region -> silent out-of-bounds writes -> semantically corrupt output (decodes fine, gsm8k=0). 
Confirmed from the conc-64 Run B decode log: INTER_KERNEL_SWITCH=16 with DISPATCH_TOKENS=32 selects the NON-LL InterNodeV1 kernel (32 > 16), yet output was still corrupt -> the bug is the buffer-size floor, not the LL-vs-non-LL kernel choice. Fix: clamp MORI_MAX_DISPATCH_TOKENS_DECODE to >= 64 after the MTP multiply. Only raises the value at low conc; adds a few MB of staging buffer but no compute, so real throughput is unchanged (the ~3% edge of the corrupt run was an artifact of dropping work). 64 is the principled minimum vs the proven-but- larger 256. --- .../multi_node/amd_utils/server_sglang.sh | 52 +++++++------------ perf-changelog.yaml | 2 +- 2 files changed, 20 insertions(+), 34 deletions(-) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index 14c51fbd2..8a88d4afa 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -248,6 +248,25 @@ if [[ "$DECODE_MTP_SIZE" -gt 0 ]]; then MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MOE_MAX_INPUT_TOKENS_DECODE * (DECODE_MTP_SIZE + 1))) fi +# ── MoRI dispatch-buffer warpSize floor ────────────────────────────────────── +# The MoRI All2All dispatch kernel (EpDispatchInterNodeV1Kernel / IntraNode) +# places dispatched tokens in warpSize-aligned receive slots: +# destTokId = flagSlotId * warpSize + laneId (laneId = 0..warpSize-1) +# i.e. each warp writes up to warpSize=64 (CDNA3/4 wavefront) token slots per +# chunk. The per-rank receive region is sized to maxNumInpTokenPerRank, which +# the harness derives from max(CONC_LIST)/TP*(MTP+1). At low concurrency this +# collapses below 64 (e.g. conc-64 / TP8 / MTP3 -> 64/8*4 = 32), so a single +# warp-chunk overruns the 32-slot region -> silent out-of-bounds writes -> +# semantically corrupt output (decodes fine, gsm8k = 0.0). Clamp to one +# wavefront so the slot arithmetic is always in-bounds. 
This only raises the +# value at low conc (high conc is naturally larger); it adds a few MB of +# staging buffer but no compute, so real throughput is unchanged. +MORI_DISPATCH_TOKENS_FLOOR=64 +if [[ "$MORI_MAX_DISPATCH_TOKENS_DECODE" -lt "$MORI_DISPATCH_TOKENS_FLOOR" ]]; then + echo "[MoRI floor] DISPATCH_TOKENS=${MORI_MAX_DISPATCH_TOKENS_DECODE} < warpSize floor ${MORI_DISPATCH_TOKENS_FLOOR}; clamping to ${MORI_DISPATCH_TOKENS_FLOOR}" + MORI_MAX_DISPATCH_TOKENS_DECODE=$MORI_DISPATCH_TOKENS_FLOOR +fi + # ============================================================================= # Cluster Topology Configuration # ============================================================================= @@ -712,39 +731,6 @@ else echo "Decode node rank: $RANK" echo "Decode parallelism: TP=${DECODE_TP_SIZE}, EP enabled: ${DECODE_ENABLE_EP}, DP enabled: ${DECODE_ENABLE_DP}" - # ── MoRI dispatch token floor patch ────────────────────────────────── - # When conc/TP is small (e.g. 64/8=8, ×4 MTP → 32 tokens), the MoRI - # All2All dispatch kernel silently corrupts output because buffer strides - # assume a minimum alignment of warpSize (64 on CDNA3/4). Patch the - # installed SGLang moriep.py to clamp num_max_dispatch_tokens_per_rank - # to at least 64 before it reaches the kernel config. 
- SGLANG_PKG=$(python3 -c "import sglang; print(sglang.__path__[0])" 2>/dev/null || true) - MORIEP_FILE="${SGLANG_PKG}/srt/layers/moe/token_dispatcher/moriep.py" - if [[ -n "$SGLANG_PKG" ]] && [[ -f "$MORIEP_FILE" ]]; then - python3 - "$MORIEP_FILE" << 'MORI_PATCH_EOF' -import re, sys -p = sys.argv[1] -with open(p) as f: - s = f.read() -pattern = r'(self\.num_max_dispatch_tokens_per_rank\s*=\s*get_int_env_var\(\s*"SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK",\s*4096\s*\))' -replacement = (r'\1\n' - r' self.num_max_dispatch_tokens_per_rank = ' - r'max(self.num_max_dispatch_tokens_per_rank, 64)') -s2 = re.sub(pattern, replacement, s, count=1, flags=re.DOTALL) -if s2 == s: - print(f'WARNING: MoRI patch pattern not found in {p}', file=sys.stderr) - sys.exit(1) -with open(p, 'w') as f: - f.write(s2) -print(f'[MORI PATCH] Clamped num_max_dispatch_tokens_per_rank >= 64 (warpSize) in {p}') -MORI_PATCH_EOF - if [[ $? -ne 0 ]]; then - echo "WARNING: MoRI sed patch failed — run may still exhibit corruption" - fi - else - echo "WARNING: moriep.py not found at ${MORIEP_FILE} — skipping patch" - fi - DECODE_MORI_MOE_ENV="" set -x if [[ -n "$MORI_MOE_MAX_INPUT_TOKENS_DECODE" ]]; then diff --git a/perf-changelog.yaml b/perf-changelog.yaml index df8f600fd..5e0e1acf8 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3448,5 +3448,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Throwaway: validate Option A fix — sed-patch moriep.py to clamp num_max_dispatch_tokens_per_rank >= 64 (AMD warpSize). If gsm8k recovers, confirms warpSize is the minimum viable buffer floor for the MoRI All2All dispatch kernel." + - "Throwaway: validate warpSize floor fix — clamp MORI_MAX_DISPATCH_TOKENS_DECODE >= 64 (CDNA3/4 wavefront) in server_sglang.sh. 
The MoRI All2All dispatch kernel writes warpSize-aligned receive slots (destTokId = flagSlotId*warpSize + laneId), so a per-rank buffer < 64 overruns its region -> silent corruption (conc-64/TP8/MTP3 -> 32 tokens -> gsm8k=0). If gsm8k recovers, 64 is the minimal correct floor (best perf vs the proven-but-larger 256)." pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 From 272be180b3a6652b3b025b4ae459543f6a659ab8 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 16:50:52 -0700 Subject: [PATCH 07/15] raise MoRI dispatch-buffer floor to 256 (warpSize=64 proven insufficient) The conc-64 run with the warpSize floor (64) still scored gsm8k=0.00 (run 26919517564), disproving the one-wavefront hypothesis. The per-rank dispatch buffer must hold the routing fan-in (a receiving rank takes tokens from all worldSize peers), not just one warp-chunk. Empirically on MI355X: dispatch=32 -> 0.00, dispatch=64 -> 0.00, dispatch>=256 -> 0.94. Clamp to the proven 256. Throughput is unchanged; the corrupt run's ~3% edge was dropped work, not real speed. 
--- .../multi_node/amd_utils/server_sglang.sh | 38 ++++++++++++------- perf-changelog.yaml | 2 +- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index 8a88d4afa..f37ad4af2 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -248,22 +248,32 @@ if [[ "$DECODE_MTP_SIZE" -gt 0 ]]; then MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MOE_MAX_INPUT_TOKENS_DECODE * (DECODE_MTP_SIZE + 1))) fi -# ── MoRI dispatch-buffer warpSize floor ────────────────────────────────────── +# ── MoRI dispatch-buffer minimum floor ─────────────────────────────────────── # The MoRI All2All dispatch kernel (EpDispatchInterNodeV1Kernel / IntraNode) -# places dispatched tokens in warpSize-aligned receive slots: -# destTokId = flagSlotId * warpSize + laneId (laneId = 0..warpSize-1) -# i.e. each warp writes up to warpSize=64 (CDNA3/4 wavefront) token slots per -# chunk. The per-rank receive region is sized to maxNumInpTokenPerRank, which -# the harness derives from max(CONC_LIST)/TP*(MTP+1). At low concurrency this -# collapses below 64 (e.g. conc-64 / TP8 / MTP3 -> 64/8*4 = 32), so a single -# warp-chunk overruns the 32-slot region -> silent out-of-bounds writes -> -# semantically corrupt output (decodes fine, gsm8k = 0.0). Clamp to one -# wavefront so the slot arithmetic is always in-bounds. This only raises the -# value at low conc (high conc is naturally larger); it adds a few MB of -# staging buffer but no compute, so real throughput is unchanged. -MORI_DISPATCH_TOKENS_FLOOR=64 +# silently corrupts output when the per-rank dispatch buffer +# (maxNumInpTokenPerRank = SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK) is too +# small. The harness derives that value from max(CONC_LIST)/TP*(MTP+1), which +# collapses at low concurrency (conc-64 / TP8 / MTP3 -> 64/8*4 = 32). 
Two things +# break: (1) the kernel writes tokens in warpSize-aligned chunks +# (destTokId = flagSlotId*warpSize + laneId, laneId 0..63), so a buffer < 64 +# can't even hold one wavefront; (2) a receiving rank takes tokens from all +# `worldSize` peers, so the per-rank buffer must hold the routing fan-in, not +# just the local token count. The result is out-of-bounds receive-slot writes +# -> output that decodes fine (acceptance length stays high) but is semantically +# garbage (gsm8k = 0.0). +# +# Empirically validated on MI355X (conc-64 DEP8+MTP3, this config): +# dispatch=32 -> gsm8k 0.00 (run 26913235190) +# dispatch=64 -> gsm8k 0.00 (run 26919517564) # warpSize alone insufficient +# dispatch>=256 -> gsm8k 0.94 (run 26912330265) +# So clamp to 256. This only raises the value at low conc (high conc is already +# larger); it adds a few MB of staging buffer but no compute, so real throughput +# is unchanged (the ~3% edge of the corrupt run was an artifact of dropped work). +# NOTE: 128 is untested; the proper upstream fix sizes the buffer from the +# routing fan-in rather than a flat constant. +MORI_DISPATCH_TOKENS_FLOOR=256 if [[ "$MORI_MAX_DISPATCH_TOKENS_DECODE" -lt "$MORI_DISPATCH_TOKENS_FLOOR" ]]; then - echo "[MoRI floor] DISPATCH_TOKENS=${MORI_MAX_DISPATCH_TOKENS_DECODE} < warpSize floor ${MORI_DISPATCH_TOKENS_FLOOR}; clamping to ${MORI_DISPATCH_TOKENS_FLOOR}" + echo "[MoRI floor] DISPATCH_TOKENS=${MORI_MAX_DISPATCH_TOKENS_DECODE} < floor ${MORI_DISPATCH_TOKENS_FLOOR}; clamping to ${MORI_DISPATCH_TOKENS_FLOOR}" MORI_MAX_DISPATCH_TOKENS_DECODE=$MORI_DISPATCH_TOKENS_FLOOR fi diff --git a/perf-changelog.yaml b/perf-changelog.yaml index ce84bcb3d..2d6c3d649 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3455,5 +3455,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Throwaway: validate warpSize floor fix — clamp MORI_MAX_DISPATCH_TOKENS_DECODE >= 64 (CDNA3/4 wavefront) in server_sglang.sh. 
The MoRI All2All dispatch kernel writes warpSize-aligned receive slots (destTokId = flagSlotId*warpSize + laneId), so a per-rank buffer < 64 overruns its region -> silent corruption (conc-64/TP8/MTP3 -> 32 tokens -> gsm8k=0). If gsm8k recovers, 64 is the minimal correct floor (best perf vs the proven-but-larger 256)." + - "Fix MoRI dispatch-buffer corruption at low concurrency: clamp MORI_MAX_DISPATCH_TOKENS_DECODE >= 256 in server_sglang.sh. The harness sizes the per-rank All2All dispatch buffer from max(CONC_LIST)/TP*(MTP+1), which collapses to 32 at conc-64/TP8/MTP3 and silently corrupts the dispatch kernel's receive slots (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, dispatch=64->0.00 (warpSize alone insufficient), dispatch>=256->0.94. Throughput unchanged (the corrupt run's ~3% edge was dropped work)." pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 From 998408b898270b8352a4c7ac16f719b39d90b4cc Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 18:04:56 -0700 Subject: [PATCH 08/15] fix MoRI dispatch corruption at the root: moriep.py overlay floors dispatch tokens to 256; remove harness env-clamp bandaid Replaces the server_sglang.sh env clamp (a launch-script workaround) with a library-level fix: an overlay of sglang's moriep.py that floors num_max_dispatch_tokens_per_rank to 256 at its env read -- the single source of truth feeding both the kernel selection and the mori buffer-sizing arg. Root cause (adversarially verified against mori/sglang source): for single-node EP8 the intra-node dispatch (DispatchIntraNodeBlock) writes received tokens into a per-PE buffer sized MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank (dispatch_combine.hpp:126-136; max_total_recv_tokens defaults to 0 -> that fallback, and it is a cap not a floor). The harness derives maxNumInpTokenPerRank from max(CONC_LIST)/TP*(MTP+1), which collapses to 32 at conc-64/TP8/MTP3. 
The per-dest atomic counter then overruns the buffer; the only guard is assert(destTokId < MaxNumTokensToRecv()), compiled out under -DNDEBUG -> silent out-of-bounds writes -> output that decodes fine (high acceptance length) but is semantically garbage (gsm8k=0). Delivery: patches/moriep.py is byte-identical to upstream v0.5.12.post1 (md5 ac626f5459...) plus a +22-line floor; auto-mounted by job.slurm via EXTRA_DOCKER_MOUNTS, gated on the v0.5.12.post1 image tag (same mechanism as the mori_conn.py overlay). Empirically validated on MI355X (conc-64 DEP8+MTP3): dispatch 32 -> gsm8k 0.00, 64 -> 0.00 (one wavefront insufficient), 256 -> 0.94. Throughput unchanged; the corrupt run's ~3% edge was dropped work. Upstream issues: sgl-project/sglang#27194, ROCm/mori#356. --- benchmarks/multi_node/amd_utils/job.slurm | 21 + .../multi_node/amd_utils/patches/README.md | 39 + .../multi_node/amd_utils/patches/moriep.py | 1134 +++++++++++++++++ .../multi_node/amd_utils/server_sglang.sh | 33 +- perf-changelog.yaml | 2 +- 5 files changed, 1200 insertions(+), 29 deletions(-) create mode 100644 benchmarks/multi_node/amd_utils/patches/moriep.py diff --git a/benchmarks/multi_node/amd_utils/job.slurm b/benchmarks/multi_node/amd_utils/job.slurm index 5e8e67606..a2d976ca7 100755 --- a/benchmarks/multi_node/amd_utils/job.slurm +++ b/benchmarks/multi_node/amd_utils/job.slurm @@ -79,6 +79,27 @@ if [[ "${MORI_CONN_PATCH:-auto}" != "skip" ]] \ echo "[job.slurm] auto-applied MoRI conn.py overlay: ${_MORI_PATCH_FILE}" fi +# ── MoRI dispatch-buffer corruption fix: moriep.py overlay ──────────── +# sglang v0.5.12.post1 silently corrupts the MoRI EP dispatch path when the +# per-rank dispatch buffer (num_max_dispatch_tokens_per_rank) is small: the +# receive buffer is sized worldSize*maxNumInpTokenPerRank and the only overflow +# guard is an assert() compiled out in release builds, so low concurrency +# (e.g. conc-64 DEP8+MTP3 -> 32 tokens) yields out-of-bounds writes and gsm8k=0. 
+# The overlay floors num_max_dispatch_tokens_per_rank to 256 at its env read +# (the single source of truth for kernel selection + buffer sizing). The base +# file is byte-identical to upstream v0.5.12.post1 (md5 ac626f5459...), so the +# overlay is a +22-line diff. See patches/README.md and sgl-project/sglang#27194. +_MORIEP_PATCH_FILE="$DI_REPO_DIR/benchmarks/multi_node/amd_utils/patches/moriep.py" +_MORIEP_PATCH_TARGET="/sgl-workspace/sglang/python/sglang/srt/layers/moe/token_dispatcher/moriep.py" +if [[ "${MORIEP_PATCH:-auto}" != "skip" ]] \ + && [[ -f "$_MORIEP_PATCH_FILE" ]] \ + && [[ "${DOCKER_IMAGE_NAME:-}" == *"v0.5.12.post1"* ]] \ + && [[ "${EXTRA_DOCKER_MOUNTS:-}" != *"$_MORIEP_PATCH_TARGET"* ]]; then + EXTRA_DOCKER_MOUNTS="${EXTRA_DOCKER_MOUNTS:-} -v ${_MORIEP_PATCH_FILE}:${_MORIEP_PATCH_TARGET}:ro" + export EXTRA_DOCKER_MOUNTS + echo "[job.slurm] auto-applied MoRI moriep.py dispatch-floor overlay: ${_MORIEP_PATCH_FILE}" +fi + xP="${xP:-1}" yD="${yD:-1}" diff --git a/benchmarks/multi_node/amd_utils/patches/README.md b/benchmarks/multi_node/amd_utils/patches/README.md index d9b5de79d..45bb814a9 100644 --- a/benchmarks/multi_node/amd_utils/patches/README.md +++ b/benchmarks/multi_node/amd_utils/patches/README.md @@ -60,6 +60,45 @@ This is a stop-gap. The proper upstream fix is to migrate MoRI to the plural `state_types: List[StateType]` API (full design + diff in `scripts/sglang_disagg/docs/03-upstream-pr-proposal.md`). +## `moriep.py` + +Overlays +`/sgl-workspace/sglang/python/sglang/srt/layers/moe/token_dispatcher/moriep.py`. + +Source: forked from `lmsysorg/sglang-rocm:v0.5.12.post1-*` (sglang +[v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1)). +The base file is **byte-identical to the upstream tag** +(`md5 ac626f5459a699f9ac953d9d8e71d861`); the overlay is a single ++22-line insertion in `MoriTokenDispatcher.__init__`. 
+ +**Bug it fixes:** at low concurrency the MoRI EP dispatch path silently +corrupts output (decodes fine, acceptance length stays high, but gsm8k +drops to 0). The per-rank dispatch buffer +`num_max_dispatch_tokens_per_rank` (→ mori `max_num_inp_token_per_rank`) +is derived by the harness as `max(CONC_LIST)/TP*(MTP+1)`, which collapses +at low conc (conc-64 / TP8 / MTP3 → `64/8*4 = 32`). MoRI sizes its +receive buffer `MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank` +(`max_total_recv_tokens` defaults to 0 → that fallback, and it is a *cap* +not a floor — `dispatch_combine.hpp:126-136`). The intra-node dispatch +kernel's per-dest atomic counter then runs past that buffer; the only +guard is `assert(destTokId < MaxNumTokensToRecv())`, compiled out under +`-DNDEBUG`, so the result is silent out-of-bounds writes +(`internode_v1.cpp` `DispatchIntraNodeBlock`). + +The overlay floors `num_max_dispatch_tokens_per_rank` to **256** right at +its env read — the single source of truth that feeds both +`get_ep_dispatch_configs()` (kernel selection) and the buffer-sizing +arg. Empirically validated on MI355X (conc-64 DEP8+MTP3): +dispatch `32 → gsm8k 0.00`, `64 → 0.00` (one wavefront is not enough), +`256 → 0.94`. + +This is a stop-gap. The proper upstream fix is in MoRI: size the receive +buffer from the routing fan-in and turn the compiled-out `assert` into a +real bounds guard (see [ROCm/mori#356](https://github.com/ROCm/mori/issues/356)). +The integration-level guard belongs in sglang's `moriep.py` +([sgl-project/sglang#27194](https://github.com/sgl-project/sglang/issues/27194)) — +this overlay is exactly that guard, pending upstream merge. 
+ ## How to enable ```bash diff --git a/benchmarks/multi_node/amd_utils/patches/moriep.py b/benchmarks/multi_node/amd_utils/patches/moriep.py new file mode 100644 index 000000000..4ab882c29 --- /dev/null +++ b/benchmarks/multi_node/amd_utils/patches/moriep.py @@ -0,0 +1,1134 @@ +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple + +from sglang.srt.layers.dp_attention import get_is_extend_in_batch +from sglang.srt.layers.moe.token_dispatcher.base import ( + BaseDispatcher, + CombineInput, + CombineInputFormat, + DispatchOutput, + DispatchOutputFormat, +) +from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPPDispatchHooks +from sglang.srt.layers.moe.topk import TopKOutput +from sglang.srt.layers.moe.utils import ( + DeepEPMode, + is_tbo_enabled, +) +from sglang.srt.utils import ( + get_bool_env_var, + get_int_env_var, + is_hip, +) + +if TYPE_CHECKING: + from sglang.srt.single_batch_overlap import CombineOverlapArgs + import mori + +from enum import Enum, auto +from functools import lru_cache + +import torch + +from sglang.srt.distributed import ( + get_moe_expert_parallel_rank, + get_moe_expert_parallel_world_size, +) +from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype + +# Blockwise quantization group sizes: number of elements sharing one scale factor +FP8_BLOCK_SIZE = 128 +MXFP4_BLOCK_SIZE = 32 + +_is_hip = is_hip() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _use_aiter: + from aiter import QuantType, get_hip_quant + +logger = logging.getLogger(__name__) + + +class MoriEPPDispatchHooks(DeepEPPDispatchHooks): + + def __call__(self, dispatcher: BaseDispatcher): + for hook_fun in self.hook_dict.values(): + hook_fun(dispatcher) + + +class MoriEPNormalDispatchOutput(NamedTuple): + """Mori EP normal dispatch output.""" + + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + topk_ids: 
torch.Tensor + topk_weights: torch.Tensor + num_recv_tokens_per_expert: List[int] + origin_topk_ids: torch.Tensor + origin_topk_weights: torch.Tensor + out_dtype: torch.dtype + + @property + def format(self) -> DispatchOutputFormat: + return DispatchOutputFormat.DEEPEP_NORMAL + + +class MoriEPLLDispatchOutput(NamedTuple): + """Mori EP low latency dispatch output.""" + + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + topk_ids: torch.Tensor + topk_weights: torch.Tensor + num_recv_tokens_per_expert: List[int] + origin_topk_ids: torch.Tensor + origin_topk_weights: torch.Tensor + out_dtype: torch.dtype + + @property + def format(self) -> DispatchOutputFormat: + return DispatchOutputFormat.DEEPEP_LL + + +assert isinstance(MoriEPNormalDispatchOutput, DispatchOutput) +assert isinstance(MoriEPLLDispatchOutput, DispatchOutput) + + +class MoriEPNormalCombineInput(NamedTuple): + """Mori EP combine input.""" + + hidden_states: torch.Tensor + topk_ids: torch.Tensor + topk_weights: torch.Tensor + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.DEEPEP_NORMAL + + +class MoriEPLLCombineInput(NamedTuple): + """Mori EP combine input.""" + + hidden_states: torch.Tensor + topk_ids: torch.Tensor + topk_weights: torch.Tensor + + @property + def format(self) -> CombineInputFormat: + return CombineInputFormat.DEEPEP_LL + + +assert isinstance(MoriEPNormalCombineInput, CombineInput) +assert isinstance(MoriEPLLCombineInput, CombineInput) + + +class EpMode(Enum): + INTRA_NODE = "intra_node" + INTER_NODE = "inter_node" + LOW_LATENCY = "low_latency" + + +class DispatchDtype(Enum): + bf16 = "bfloat16" + fp8 = "float8_blockwise" + fp4 = "mxfp4_blockwise" + + +class CombineDtype(Enum): + bf16 = "bfloat16" + fp8 = "float8_blockwise" + fp8_direct_cast = "float8_direct_cast" + + +@dataclass(frozen=True) +class EpDispatchConfig: + kernel_type: mori.ops.EpDispatchCombineKernelType + warp_num_per_block: int + block_num: int + rdma_block_num: 
int + + +def get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank: int = 4096): + import mori + + # Selects the inter-node kernel. `InterNodeV1LL` is used if `num_max_dispatch_tokens_per_rank` + # is less than or equal to the threshold, otherwise `InterNodeV1` is used. The threshold defaults to 256. + inter_kernel_switch_threshold = get_int_env_var( + "SGLANG_MORI_DISPATCH_INTER_KERNEL_SWITCH_THRESHOLD", 256 + ) + + inter_kernel_type = ( + mori.ops.EpDispatchCombineKernelType.InterNodeV1LL + if num_max_dispatch_tokens_per_rank <= inter_kernel_switch_threshold + else mori.ops.EpDispatchCombineKernelType.InterNodeV1 + ) + + return { + # TODO(billishyahao): need to tune different configs for intra node async + # Also could be tuned for different AMD platform + EpMode.INTRA_NODE: EpDispatchConfig( + kernel_type=mori.ops.EpDispatchCombineKernelType.IntraNode, + warp_num_per_block=16, + block_num=80, + rdma_block_num=0, + ), + EpMode.INTER_NODE: EpDispatchConfig( + kernel_type=inter_kernel_type, + warp_num_per_block=8, + block_num=64, + rdma_block_num=32, + ), + EpMode.LOW_LATENCY: EpDispatchConfig( + kernel_type=mori.ops.EpDispatchCombineKernelType.AsyncLL, + warp_num_per_block=8, + block_num=64, + rdma_block_num=32, + ), + } + + +# init_mori_op only needs do once in model initial stage +# use lru_cache to reuse the same mori_op instance to avoid the init overhead for mori +@lru_cache(maxsize=4) +def init_mori_op( + group, + router_topk, + num_experts, + num_local_experts, + hidden_size, + params_dtype, + num_max_dispatch_tokens_per_rank, + deepep_mode, + instance_id=0, + dispatch_dtype=DispatchDtype.bf16, + combine_dtype=CombineDtype.bf16, + enable_sdma=False, +): + + import mori + + world_size = get_moe_expert_parallel_world_size() + rank = get_moe_expert_parallel_rank() + + gpu_per_node = 8 if world_size >= 8 else world_size + + group_name = f"mori" + cpu_group = group.cpu_group + try: + torch._C._distributed_c10d._register_process_group(group_name, cpu_group) + 
except Exception as e: + if "already registered" in str(e): + logger.info( + f"[MORI init] The same process group is already " + f"registered. Ignoring [{str(e)}]" + ) + else: + raise + else: + # If new group is newly registered then need to init mori shmem. However + # if the group is registered already then need to skip init mori shmem + # and reuse the previous one. + mori.shmem.shmem_torch_process_group_init(group_name) + + mode = EpMode.INTRA_NODE if world_size <= 8 else EpMode.INTER_NODE + async_mode = deepep_mode.enable_low_latency() or enable_sdma + if async_mode: + mode = EpMode.LOW_LATENCY + + cfg = get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank)[mode] + + kernel_type = cfg.kernel_type + warp_num_per_block = cfg.warp_num_per_block + block_num = cfg.block_num + rdma_block_num = cfg.rdma_block_num + + hidden_dim = hidden_size + scale_dim = 1 + data_type = fp8_dtype + scale_type_size = torch.float32.itemsize + + if dispatch_dtype == DispatchDtype.fp8: + scale_dim = hidden_size // FP8_BLOCK_SIZE + elif dispatch_dtype == DispatchDtype.fp4: + # FP4 kernel still takes the original hidden size and do quantization + # internally, so hidden_dim is not reduced. The reason is that for FP4 + # quantization, we need to keep the original hidden size to calculate + # the quantization scale correctly. Don't use packed hidden size for FP4 kernel. 
+ hidden_dim = hidden_size + scale_dim = hidden_size // MXFP4_BLOCK_SIZE + data_type = torch.float4_e2m1fn_x2 + scale_type_size = torch.float8_e8m0fnu.itemsize + + if mode == EpMode.INTRA_NODE: + if num_max_dispatch_tokens_per_rank < 128: + block_num = 225 + warp_num_per_block = 5 + else: + block_num = 256 + warp_num_per_block = 16 + + # Fp8 blockwise combine uses its own internal scale_dim driven which can be + # overridden by env ``MORI_FP8_COMBINE_SCALE_DIM`` (default 56) + # See https://github.com/ROCm/mori/blob/96ffa169710f214e76e07abe5008d686fe54522b/python/mori/ops/dispatch_combine.py#L81-L84 + combine_quant_type = "none" + if combine_dtype == CombineDtype.fp8: + combine_quant_type = "fp8_blockwise" + elif combine_dtype == CombineDtype.fp8_direct_cast: + combine_quant_type = "fp8_direct_cast" + + logger.info( + f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} " + f"{num_max_dispatch_tokens_per_rank=} {num_local_experts=} " + f"{router_topk=} {mode=} {dispatch_dtype=} {combine_dtype=} " + ) + + def check_mori_compatibility(kwargs: dict) -> None: + """Remove kwargs not accepted by the installed mori's EpDispatchCombineConfig.""" + import dataclasses + + config_cls = mori.ops.EpDispatchCombineConfig + valid_kwargs = {f.name for f in dataclasses.fields(config_cls)} + + invalid_kwargs = set(kwargs.keys()) - valid_kwargs + for arg in invalid_kwargs: + logger.warning(f"[MORI compat] Removing incompatible argument {arg} ") + del kwargs[arg] + + # Definition refer to https://github.com/ROCm/mori/blob/f9be5ee2e5ac87256b9523399ae9d4d0e8a54f53/python/mori/ops/dispatch_combine.py#L66-L121 + common_kwargs = dict( + data_type=data_type, + rank=rank, + world_size=world_size, + hidden_dim=hidden_dim, + scale_dim=scale_dim, + scale_type_size=scale_type_size, + max_token_type_size=params_dtype.itemsize, + max_num_inp_token_per_rank=num_max_dispatch_tokens_per_rank, + num_experts_per_rank=num_local_experts, + num_experts_per_token=router_topk, + 
warp_num_per_block=warp_num_per_block, + block_num=block_num, + max_total_recv_tokens=get_int_env_var( + "SGLANG_MORI_PREALLOC_MAX_RECV_TOKENS", 0 + ), + kernel_type=kernel_type, + gpu_per_node=gpu_per_node, + rdma_block_num=rdma_block_num, + num_qp_per_pe=2, # Number of queue pairs per processing element + quant_type=combine_quant_type, + ) + + check_mori_compatibility(common_kwargs) + + mori_config = mori.ops.EpDispatchCombineConfig(**common_kwargs) + mori_op = mori.ops.EpDispatchCombineOp(mori_config) + return mori_op + + +class CommStreamPool: + _streams = {} # key -> torch.cuda.Stream + + @classmethod + def _make_key(cls, group): + return (torch.cuda.current_device(), id(group)) + + @classmethod + def get_stream_from_pool(cls, group) -> torch.cuda.Stream: + key = cls._make_key(group) + stream = cls._streams.get(key) + if stream is None: + stream = torch.cuda.Stream(priority=0) + cls._streams[key] = stream + return stream + + @classmethod + def clear_group(cls, group): + key = (torch.cuda.current_device(), id(group)) + cls._streams.pop(key, None) + + +class _MoriEPDispatcherImplBase: + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool, + num_experts: int, + num_local_experts: int, + hidden_size: int, + params_dtype: torch.dtype, + deepep_mode: DeepEPMode, + instance_id: int = 0, + ): + try: + import mori # noqa: F401 + except ImportError: + raise ImportError("Mori EP is not installed. Please install.") + self.group = group + self.router_topk = router_topk + self.permute_fusion = permute_fusion + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.params_dtype = params_dtype + self.deepep_mode = deepep_mode + self.instance_id = instance_id + + self.num_max_dispatch_tokens_per_rank = get_int_env_var( + "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096 + ) + # [InferenceX overlay] Floor the per-rank dispatch buffer to 256. 
+ # The MoRI receive buffer is sized MaxNumTokensToRecv() = worldSize * + # maxNumInpTokenPerRank (mori dispatch_combine.hpp; max_total_recv_tokens + # defaults to 0 -> that fallback, and it is a cap not a floor). A small + # per-rank value therefore starves the intra-node receive buffer, and the + # only overflow guard is an assert(destTokId < MaxNumTokensToRecv()) that + # is compiled out under -DNDEBUG -> silent out-of-bounds writes -> output + # that decodes fine (high acceptance length) but is semantically garbage. + # Empirically on MI355X (conc-64 DEP8+MTP3): dispatch 32 -> gsm8k 0.00, + # 64 -> 0.00, 256 -> 0.94. Flooring here is the single source of truth: it + # feeds both get_ep_dispatch_configs() (LL/non-LL kernel selection) and + # max_num_inp_token_per_rank (buffer sizing) via the mori_op property. + # Upstream: sgl-project/sglang#27194, ROCm/mori#356. + _MORI_DISPATCH_TOKENS_FLOOR = 256 + if self.num_max_dispatch_tokens_per_rank < _MORI_DISPATCH_TOKENS_FLOOR: + logger.warning( + "[MORI floor] num_max_dispatch_tokens_per_rank=%d < %d; " + "clamping to avoid silent MoRI dispatch corruption", + self.num_max_dispatch_tokens_per_rank, + _MORI_DISPATCH_TOKENS_FLOOR, + ) + self.num_max_dispatch_tokens_per_rank = _MORI_DISPATCH_TOKENS_FLOOR + + self.enable_sdma = get_bool_env_var("MORI_ENABLE_SDMA", "false") + + self._mori_op = None + self.dispatch_dtype = DispatchDtype.bf16 + self.combine_dtype = CombineDtype.bf16 + + self.quant_config: Optional[dict] = None + + self.overlap_args: Optional[CombineOverlapArgs] = None + self.meta_overlap_args: Optional[dict] = None + + @property + def mori_op(self): + if self._mori_op is None: + # If set_quant_config was never called, apply env var override now + if self.quant_config is None: + self._apply_dispatch_dtype_override() + self._mori_op = init_mori_op( + self.group, + self.router_topk, + self.num_experts, + self.num_local_experts, + self.hidden_size, + self.params_dtype, + self.num_max_dispatch_tokens_per_rank, + 
self.deepep_mode, + self.instance_id, + self.dispatch_dtype, + self.combine_dtype, + self.enable_sdma, + ) + return self._mori_op + + def _apply_dispatch_dtype_override(self): + """Apply env var override to fp8_dispatch/fp4_dispatch/fp8_combine flags.""" + if "SGLANG_MORI_DISPATCH_DTYPE" in os.environ: + dispatch_dtype = os.environ["SGLANG_MORI_DISPATCH_DTYPE"].lower() + if dispatch_dtype != "auto": + if dispatch_dtype == "bf16": + self.dispatch_dtype = DispatchDtype.bf16 + elif dispatch_dtype == "fp8": + self.dispatch_dtype = DispatchDtype.fp8 + elif dispatch_dtype == "fp4": + self.dispatch_dtype = DispatchDtype.fp4 + elif ( + "SGLANG_MORI_FP8_DISP" in os.environ or "SGLANG_MORI_FP4_DISP" in os.environ + ): + # Deprecated: will be removed in a future release + logger.warning_once( + "SGLANG_MORI_FP8_DISP and SGLANG_MORI_FP4_DISP are deprecated " + "and will be removed in a future release. " + "Use SGLANG_MORI_DISPATCH_DTYPE=auto|bf16|fp8|fp4 instead." + ) + if get_bool_env_var("SGLANG_MORI_FP8_DISP", "False"): + self.dispatch_dtype = DispatchDtype.fp8 + if get_bool_env_var("SGLANG_MORI_FP4_DISP", "False"): + self.dispatch_dtype = DispatchDtype.fp4 + + if "SGLANG_MORI_COMBINE_DTYPE" in os.environ: + combine_dtype = os.environ["SGLANG_MORI_COMBINE_DTYPE"].lower() + if combine_dtype != "auto": + if combine_dtype == "fp8": + self.combine_dtype = CombineDtype.fp8 + elif combine_dtype == "bf16": + self.combine_dtype = CombineDtype.bf16 + elif combine_dtype == "fp8_direct_cast": + self.combine_dtype = CombineDtype.fp8_direct_cast + elif "SGLANG_MORI_FP8_COMB" in os.environ: + # Deprecated: will be removed in a future release + logger.warning_once( + "SGLANG_MORI_FP8_COMB is deprecated " + "and will be removed in a future release. " + "Use SGLANG_MORI_COMBINE_DTYPE=auto|bf16|fp8|fp8_direct_cast instead." 
+ ) + if get_bool_env_var("SGLANG_MORI_FP8_COMB", "False"): + self.combine_dtype = CombineDtype.fp8 + + def dispatch_a( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ): + raise NotImplementedError + + def dispatch_b(self, *args, **kwargs): + raise NotImplementedError + + def combine_a( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + ): + raise NotImplementedError + + def combine_b(self, *args, **kwargs): + raise NotImplementedError + + def set_quant_config(self, quant_config: dict) -> None: + self.quant_config = quant_config + # Auto-detect dispatch quantization from weight dtype + weight_dtype = quant_config.get("weight_dtype", None) + if weight_dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + self.dispatch_dtype = DispatchDtype.fp8 + self.combine_dtype = CombineDtype.bf16 + elif weight_dtype == torch.float4_e2m1fn_x2: + self.dispatch_dtype = DispatchDtype.fp4 + self.combine_dtype = CombineDtype.fp8 + else: + self.dispatch_dtype = DispatchDtype.bf16 + self.combine_dtype = CombineDtype.bf16 + # Apply env var override immediately so dispatch_a sees correct flags + self._apply_dispatch_dtype_override() + + def set_overlap_args( + self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict + ) -> None: + self.overlap_args = combine_overlap_args + self.meta_overlap_args = meta_overlap_args + + def clear_overlap_args(self) -> None: + self.overlap_args = None + self.meta_overlap_args = None + + +class _MoriEPDispatcherImplNormal(_MoriEPDispatcherImplBase): + def __init__(self, async_finish: bool, **kwargs): + super().__init__(**kwargs) + + self.async_finish = async_finish + self.quant_config = {} + self.fp8_quant_func = get_hip_quant(QuantType.per_1x128) + self.fp4_quant_func = get_hip_quant(QuantType.per_1x32) + self.enable_dual_stream = is_tbo_enabled() + self._comm_stream = None + if self.enable_dual_stream: + self._comm_stream = CommStreamPool.get_stream_from_pool(self.group) + 
+ def _capture_event_if_async(self) -> Optional[torch.cuda.Event]: + assert self.enable_dual_stream, "dual stream must be enabled" + if not self.async_finish: + return None + ev = torch.cuda.Event(blocking=False, interprocess=False) + ev.record(torch.cuda.current_stream()) + return ev + + def dispatch_a( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ): + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + num_token = hidden_states.shape[0] + output_dtype = hidden_states.dtype + scale = None + + if self.dispatch_dtype == DispatchDtype.fp8: + # FP8 quant + if num_token > 0: + # NOTE: aiter is able to handle token=0 case in UT. But for some + # reason it failed at e2e case. Root cause TBD. + hidden_states, scale = self.fp8_quant_func( + hidden_states, quant_dtype=fp8_dtype + ) + else: + hidden_states = torch.empty( + hidden_states.shape, dtype=fp8_dtype, device=hidden_states.device + ) + scale = torch.empty( + (0, self.hidden_size // FP8_BLOCK_SIZE), + dtype=torch.float32, + device=hidden_states.device, + ) + + elif self.dispatch_dtype == DispatchDtype.fp4: + # FP4 quant + if num_token > 0: + hidden_states, scale = self.fp4_quant_func(hidden_states, shuffle=False) + else: + hidden_states = torch.empty( + (0, self.hidden_size // 2), + dtype=torch.float4_e2m1fn_x2, + device=hidden_states.device, + ) + scale = torch.empty( + (0, self.hidden_size // MXFP4_BLOCK_SIZE), + dtype=torch.float8_e8m0fnu, + device=hidden_states.device, + ) + + previous_event = self._capture_event_if_async() if self._comm_stream else None + + return ( + hidden_states, + topk_weights, + topk_ids, + scale, + output_dtype, + previous_event, + ) + + def dispatch_b( + self, + hidden_states, + topk_weights, + topk_ids, + scale, + output_dtype, + previous_event, + ): + + ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + done_event, + ) = self._dispatch_core( + hidden_states, + topk_weights, + topk_ids, + 
scale=scale, + previous_event=previous_event, + ) + + if self._comm_stream and self.async_finish and done_event is not None: + torch.cuda.current_stream().wait_event(done_event) + + return MoriEPNormalDispatchOutput( + hidden_states=packed_recv_hidden, + hidden_states_scale=recv_scales, + topk_ids=recv_topk_ids, + topk_weights=recv_topk_weights, + num_recv_tokens_per_expert=packed_recv_count, + origin_topk_ids=topk_ids, + origin_topk_weights=topk_weights, + out_dtype=output_dtype, + ) + + def _dispatch_core( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + scale: Optional[torch.Tensor] = None, + previous_event: Optional[torch.cuda.Event] = None, + ): + done_event: Optional[torch.cuda.Event] = None + + if self._comm_stream: + compute_stream = torch.cuda.current_stream() + comm_stream = self._comm_stream # comm stream + + for t in (hidden_states, topk_weights, topk_ids): + t.record_stream(comm_stream) + if scale is not None: + scale.record_stream(comm_stream) + + with torch.cuda.stream(comm_stream): + # if (previous_event) stream_wait(comm_stream, previous_event) + # else stream_wait(comm_stream, compute_stream) + + if previous_event is not None: + comm_stream.wait_event(previous_event) + else: + comm_stream.wait_stream(compute_stream) + + dispatch_fn = ( + self.mori_op.dispatch_send + if self.enable_sdma + else self.mori_op.dispatch + ) + ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + ) = dispatch_fn(hidden_states, topk_weights, scale, topk_ids) + if self.enable_sdma: + self.mori_op.dispatch_recv() + + if self.async_finish: + done_event = torch.cuda.Event(blocking=False, interprocess=False) + done_event.record(comm_stream) + else: + compute_stream.wait_stream(comm_stream) + + for t in ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + ): + if t is not None: + t.record_stream(comm_stream) + else: + + ( + packed_recv_hidden, + 
recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids) + + # TODO(billishyahao): EPLB + # get_global_expert_distribution_recorder().on_deepep_dispatch_normal( + + return ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + done_event, + ) + + def combine_a( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + ): + previous_event = self._capture_event_if_async() if self._comm_stream else None + return hidden_states, topk_ids, topk_weights, previous_event + + def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event): + + hidden_states, done_event = self._combine_core( + hidden_states, topk_ids, topk_weights, previous_event + ) + + if self._comm_stream and self.async_finish and done_event is not None: + torch.cuda.current_stream().wait_event(done_event) + + return hidden_states + + def _combine_core( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + previous_event: Optional[torch.cuda.Event], + ): + done_event: Optional[torch.cuda.Event] = None + + if self._comm_stream: + compute_stream = torch.cuda.current_stream() + comm_stream = self._comm_stream + + for t in (hidden_states, topk_ids, topk_weights): + t.record_stream(comm_stream) + + with torch.cuda.stream(comm_stream): + if previous_event is not None: + comm_stream.wait_event(previous_event) + else: + comm_stream.wait_stream(compute_stream) + + combine_fn = ( + self.mori_op.combine_send + if self.enable_sdma + else self.mori_op.combine + ) + combined_hidden_states = combine_fn(hidden_states, None, topk_ids)[0] + if self.enable_sdma: + self.mori_op.combine_recv() + + if self.async_finish: + done_event = torch.cuda.Event(blocking=False, interprocess=False) + done_event.record(comm_stream) + else: + compute_stream.wait_stream(comm_stream) + + 
combined_hidden_states.record_stream(comm_stream) + + else: + combined_hidden_states = self.mori_op.combine( + hidden_states, None, topk_ids + )[0] + + return combined_hidden_states, done_event + + def set_quant_config(self, quant_config: dict): + super().set_quant_config(quant_config) + + +class _MoriEPDispatcherImplLowLatency(_MoriEPDispatcherImplBase): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.quant_config = {} + self.fp8_quant_func = get_hip_quant(QuantType.per_1x128) + self.fp4_quant_func = get_hip_quant(QuantType.per_1x32) + + def dispatch_a( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ): + import mori + + assert ( + self.mori_op.config.kernel_type + is mori.ops.EpDispatchCombineKernelType.AsyncLL + ), "mori asyncll mismatch" + + num_tokens = hidden_states.shape[0] + output_dtype = hidden_states.dtype + scale = None + + if self.dispatch_dtype == DispatchDtype.fp8: + # FP8 quant + if num_tokens > 0: + # NOTE: aiter is able to handle token=0 case in UT. But for some + # reason it failed at e2e case. Root cause TBD. 
+ hidden_states, scale = self.fp8_quant_func( + hidden_states, quant_dtype=fp8_dtype + ) + else: + hidden_states = torch.empty( + hidden_states.shape, dtype=fp8_dtype, device=hidden_states.device + ) + scale = torch.empty( + (0, self.hidden_size // FP8_BLOCK_SIZE), + dtype=torch.float32, + device=hidden_states.device, + ) + + elif self.dispatch_dtype == DispatchDtype.fp4: + # FP4 quant + if num_tokens > 0: + hidden_states, scale = self.fp4_quant_func(hidden_states, shuffle=False) + else: + hidden_states = torch.empty( + (0, self.hidden_size // 2), + dtype=torch.float4_e2m1fn_x2, + device=hidden_states.device, + ) + scale = torch.empty( + (0, self.hidden_size // MXFP4_BLOCK_SIZE), + dtype=torch.float8_e8m0fnu, + device=hidden_states.device, + ) + + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + ) = self._dispatch_core(hidden_states, topk_weights, topk_ids, scale=scale) + + return ( + packed_recv_hidden, + recv_topk_weights, + recv_topk_ids, + recv_scales, + packed_recv_count, + topk_weights, + topk_ids, + output_dtype, + ) + + def dispatch_b( + self, + hidden_states, + recv_topk_weights, + recv_topk_ids, + recv_scales, + packed_recv_count, + topk_weights, + topk_ids, + output_dtype, + ): + + ##TODO(billishyahao): add assertion here to check async + import mori + + assert ( + self.mori_op.config.kernel_type + is mori.ops.EpDispatchCombineKernelType.AsyncLL + ), "mori asyncll mismatch" + + self.mori_op.dispatch_recv() + + return MoriEPLLDispatchOutput( + hidden_states=hidden_states, + hidden_states_scale=recv_scales, + topk_ids=recv_topk_ids, + topk_weights=recv_topk_weights, + num_recv_tokens_per_expert=packed_recv_count, + origin_topk_ids=topk_ids, + origin_topk_weights=topk_weights, + out_dtype=output_dtype, + ) + + def _dispatch_core( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + scale: 
Optional[torch.Tensor] = None, + ): + ##TODO(billishyahao): add assertion here to check async + + ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + ) = self.mori_op.dispatch_send(hidden_states, topk_weights, scale, topk_ids) + + return ( + packed_recv_hidden, + recv_topk_weights, + recv_scales, + recv_topk_ids, + packed_recv_count, + ) + + def combine_a( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + overlap_args: Optional[CombineOverlapArgs] = None, + ): + hidden_states = self._combine_core( + hidden_states, + topk_ids, + topk_weights, + overlap_args=overlap_args, + ) + return hidden_states, topk_ids, topk_weights, overlap_args + + def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event): + + self.mori_op.combine_recv() + + return hidden_states[0] + + def _combine_core( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + overlap_args: Optional[CombineOverlapArgs] = None, + ): + combined_hidden_states = self.mori_op.combine_send( + hidden_states, None, topk_ids + ) + + return combined_hidden_states + + def set_quant_config(self, quant_config: dict): + super().set_quant_config(quant_config) + + +@dataclass +class _Stage(Enum): + INITIAL = auto() + AFTER_DISPATCH_A = auto() + AFTER_DISPATCH_B = auto() + AFTER_COMBINE_A = auto() + + +class MoriEPDispatcher(BaseDispatcher): + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + deepep_mode: DeepEPMode = DeepEPMode.AUTO, + async_finish: bool = False, + return_recv_hook: bool = False, + instance_id: int = 0, + ): + super().__init__() + + self.deepep_mode = deepep_mode + + async_mode = self.deepep_mode.enable_low_latency() + if get_bool_env_var("SGLANG_ROCM_USE_MULTI_STREAM") 
and not async_mode: + logger.warning_once( + "SGLANG_ROCM_USE_MULTI_STREAM=1 is set but Mori AsyncLL is " + "not enabled (--deepep-mode=%s). The alt-stream overlap only " + "frees up CUs when dispatch/combine runs on the AsyncLL " + "copy-engine kernel; otherwise it stays on CUs and competes " + "with the alt-stream work. Pass --deepep-mode low_latency " + "(or auto) to enable the AsyncLL kernel.", + self.deepep_mode.value, + ) + + common_kwargs = dict( + group=group, + router_topk=router_topk, + permute_fusion=permute_fusion, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + params_dtype=params_dtype, + deepep_mode=deepep_mode, + instance_id=instance_id, + ) + + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher = _MoriEPDispatcherImplLowLatency( + **common_kwargs, + ) + + if self.deepep_mode.enable_normal(): + self._normal_dispatcher = _MoriEPDispatcherImplNormal( + async_finish=async_finish, + **common_kwargs, + ) + + self._stage = _Stage.INITIAL + self._deepep_dispatch_hooks = MoriEPPDispatchHooks() + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ) -> DispatchOutput: + self.dispatch_a(hidden_states, topk_output) + if self._deepep_dispatch_hooks is not None: + self._deepep_dispatch_hooks(self) + ret = self.dispatch_b() + return ret + + def dispatch_a( + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, + ): + self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A) + inner_state = self._get_impl().dispatch_a( + hidden_states=hidden_states, + topk_output=topk_output, + ) + self._dispatch_intermediate_state = inner_state + + def dispatch_b(self): + self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B) + inner_state = self._dispatch_intermediate_state + del self._dispatch_intermediate_state + return self._get_impl().dispatch_b(*inner_state) + + def combine( + self, + combine_input: CombineInput, + ) -> Tuple: + 
self.combine_a(combine_input) + ret = self.combine_b() + return ret + + def combine_a( + self, + combine_input: CombineInput, + ): + hidden_states, topk_ids, topk_weights = combine_input + self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) + inner_state = self._get_impl().combine_a( + hidden_states=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + ) + self._combine_intermediate_state = inner_state + + def combine_b(self): + self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL) + inner_state = self._combine_intermediate_state + del self._combine_intermediate_state + return self._get_impl().combine_b(*inner_state) + + def _get_impl(self) -> _MoriEPDispatcherImplBase: + is_extend_in_batch = get_is_extend_in_batch() + resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch) + if resolved_deepep_mode == DeepEPMode.NORMAL: + return self._normal_dispatcher + elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY: + return self._low_latency_dispatcher + else: + raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") + + def _update_stage(self, old_stage, new_stage): + assert self._stage == old_stage + self._stage = new_stage + + def set_quant_config(self, quant_config: dict): + super().set_quant_config(quant_config) + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher.set_quant_config(quant_config) + if self.deepep_mode.enable_normal(): + self._normal_dispatcher.set_quant_config(quant_config) + + def set_overlap_args( + self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict + ): + super().set_overlap_args(combine_overlap_args, meta_overlap_args) + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher.set_overlap_args( + combine_overlap_args, meta_overlap_args + ) + if self.deepep_mode.enable_normal(): + self._normal_dispatcher.set_overlap_args( + combine_overlap_args, meta_overlap_args + ) + + def clear_overlap_args(self): + super().clear_overlap_args() + if 
self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher.clear_overlap_args() + if self.deepep_mode.enable_normal(): + self._normal_dispatcher.clear_overlap_args() + + def register_deepep_dispatch_hook(self, hook): + return self._deepep_dispatch_hooks.register_hook(hook) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index f37ad4af2..fda3aee54 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -248,34 +248,11 @@ if [[ "$DECODE_MTP_SIZE" -gt 0 ]]; then MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MOE_MAX_INPUT_TOKENS_DECODE * (DECODE_MTP_SIZE + 1))) fi -# ── MoRI dispatch-buffer minimum floor ─────────────────────────────────────── -# The MoRI All2All dispatch kernel (EpDispatchInterNodeV1Kernel / IntraNode) -# silently corrupts output when the per-rank dispatch buffer -# (maxNumInpTokenPerRank = SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK) is too -# small. The harness derives that value from max(CONC_LIST)/TP*(MTP+1), which -# collapses at low concurrency (conc-64 / TP8 / MTP3 -> 64/8*4 = 32). Two things -# break: (1) the kernel writes tokens in warpSize-aligned chunks -# (destTokId = flagSlotId*warpSize + laneId, laneId 0..63), so a buffer < 64 -# can't even hold one wavefront; (2) a receiving rank takes tokens from all -# `worldSize` peers, so the per-rank buffer must hold the routing fan-in, not -# just the local token count. The result is out-of-bounds receive-slot writes -# -> output that decodes fine (acceptance length stays high) but is semantically -# garbage (gsm8k = 0.0). -# -# Empirically validated on MI355X (conc-64 DEP8+MTP3, this config): -# dispatch=32 -> gsm8k 0.00 (run 26913235190) -# dispatch=64 -> gsm8k 0.00 (run 26919517564) # warpSize alone insufficient -# dispatch>=256 -> gsm8k 0.94 (run 26912330265) -# So clamp to 256. 
This only raises the value at low conc (high conc is already -# larger); it adds a few MB of staging buffer but no compute, so real throughput -# is unchanged (the ~3% edge of the corrupt run was an artifact of dropped work). -# NOTE: 128 is untested; the proper upstream fix sizes the buffer from the -# routing fan-in rather than a flat constant. -MORI_DISPATCH_TOKENS_FLOOR=256 -if [[ "$MORI_MAX_DISPATCH_TOKENS_DECODE" -lt "$MORI_DISPATCH_TOKENS_FLOOR" ]]; then - echo "[MoRI floor] DISPATCH_TOKENS=${MORI_MAX_DISPATCH_TOKENS_DECODE} < floor ${MORI_DISPATCH_TOKENS_FLOOR}; clamping to ${MORI_DISPATCH_TOKENS_FLOOR}" - MORI_MAX_DISPATCH_TOKENS_DECODE=$MORI_DISPATCH_TOKENS_FLOOR -fi +# NOTE: the low-concurrency MoRI dispatch-buffer corruption (small +# SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK -> silent OOB -> gsm8k=0) is fixed +# at the root cause by the moriep.py overlay (patches/moriep.py, auto-mounted by +# job.slurm), which floors num_max_dispatch_tokens_per_rank to 256 inside sglang. +# The earlier harness-level env clamp here has been removed in favor of that. # ============================================================================= # Cluster Topology Configuration diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 2d6c3d649..039be5042 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3455,5 +3455,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Fix MoRI dispatch-buffer corruption at low concurrency: clamp MORI_MAX_DISPATCH_TOKENS_DECODE >= 256 in server_sglang.sh. The harness sizes the per-rank All2All dispatch buffer from max(CONC_LIST)/TP*(MTP+1), which collapses to 32 at conc-64/TP8/MTP3 and silently corrupts the dispatch kernel's receive slots (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, dispatch=64->0.00 (warpSize alone insufficient), dispatch>=256->0.94. Throughput unchanged (the corrupt run's ~3% edge was dropped work)." 
+ - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with a moriep.py overlay (patches/moriep.py, auto-mounted by job.slurm, gated on v0.5.12.post1) that floors num_max_dispatch_tokens_per_rank to 256 inside sglang. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the harness value collapses to 32, overrunning the dispatch kernel's receive slots (the only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged (corrupt run's ~3% edge was dropped work). Upstream: sgl-project/sglang#27194, ROCm/mori#356." pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 From 6508c77284a3d1258842485069bf12aa714cd64f Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 19:13:29 -0700 Subject: [PATCH 09/15] switch MoRI dispatch-floor fix from full-file overlay to in-place patch (vendor image diverges from upstream) The full-file moriep.py overlay crashed the scheduler at init: AttributeError: 'MoriEPDispatcher' object has no attribute 'expert_mask_gpu' RuntimeError: Rank 0 scheduler died during initialization Root cause of the failure: the lmsysorg/sglang-rocm:v0.5.12.post1 image ships a DOWNSTREAM-patched moriep.py (class MoriEPDispatcher, extra attrs like expert_mask_gpu) that diverges from the upstream v0.5.12.post1 tag. The overlay was built on a base file byte-identical to the upstream tag (md5 ac626f5459...) plus the floor insertion, so bind-mounting it reverted the AMD additions -> AttributeError. (The overlay DID mount and the floor DID fire -- "[MORI floor] num_max_dispatch_tokens_per_rank=32 < 256; clamping" -- so the fix value is right; only the delivery was wrong.)
Fix: replace the overlay with patches/apply_moriep_dispatch_floor.py, a surgical in-place patch run by server_sglang.sh inside the container. It edits the image's own moriep.py, injecting `num_max_dispatch_tokens_per_rank = max(..., 256)` after the dispatch-token env read (line-based, balanced-paren end detection, class-agnostic, idempotent, fail-loud-but-non-fatal with a diagnostic dump of the image's actual source). This preserves all vendor downstream code. The fix value (256) is unchanged and proven (env-clamp run gsm8k 0.94). Upstream: sgl-project/sglang#27194, ROCm/mori#356. --- benchmarks/multi_node/amd_utils/job.slurm | 21 - .../multi_node/amd_utils/patches/README.md | 44 +- .../patches/apply_moriep_dispatch_floor.py | 127 ++ .../multi_node/amd_utils/patches/moriep.py | 1134 ----------------- .../multi_node/amd_utils/server_sglang.sh | 10 + perf-changelog.yaml | 2 +- 6 files changed, 167 insertions(+), 1171 deletions(-) create mode 100644 benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py delete mode 100644 benchmarks/multi_node/amd_utils/patches/moriep.py diff --git a/benchmarks/multi_node/amd_utils/job.slurm b/benchmarks/multi_node/amd_utils/job.slurm index a2d976ca7..5e8e67606 100755 --- a/benchmarks/multi_node/amd_utils/job.slurm +++ b/benchmarks/multi_node/amd_utils/job.slurm @@ -79,27 +79,6 @@ if [[ "${MORI_CONN_PATCH:-auto}" != "skip" ]] \ echo "[job.slurm] auto-applied MoRI conn.py overlay: ${_MORI_PATCH_FILE}" fi -# ── MoRI dispatch-buffer corruption fix: moriep.py overlay ──────────── -# sglang v0.5.12.post1 silently corrupts the MoRI EP dispatch path when the -# per-rank dispatch buffer (num_max_dispatch_tokens_per_rank) is small: the -# receive buffer is sized worldSize*maxNumInpTokenPerRank and the only overflow -# guard is an assert() compiled out in release builds, so low concurrency -# (e.g. conc-64 DEP8+MTP3 -> 32 tokens) yields out-of-bounds writes and gsm8k=0.
-# The overlay floors num_max_dispatch_tokens_per_rank to 256 at its env read -# (the single source of truth for kernel selection + buffer sizing). The base -# file is byte-identical to upstream v0.5.12.post1 (md5 ac626f5459...), so the -# overlay is a +22-line diff. See patches/README.md and sgl-project/sglang#27194. -_MORIEP_PATCH_FILE="$DI_REPO_DIR/benchmarks/multi_node/amd_utils/patches/moriep.py" -_MORIEP_PATCH_TARGET="/sgl-workspace/sglang/python/sglang/srt/layers/moe/token_dispatcher/moriep.py" -if [[ "${MORIEP_PATCH:-auto}" != "skip" ]] \ - && [[ -f "$_MORIEP_PATCH_FILE" ]] \ - && [[ "${DOCKER_IMAGE_NAME:-}" == *"v0.5.12.post1"* ]] \ - && [[ "${EXTRA_DOCKER_MOUNTS:-}" != *"$_MORIEP_PATCH_TARGET"* ]]; then - EXTRA_DOCKER_MOUNTS="${EXTRA_DOCKER_MOUNTS:-} -v ${_MORIEP_PATCH_FILE}:${_MORIEP_PATCH_TARGET}:ro" - export EXTRA_DOCKER_MOUNTS - echo "[job.slurm] auto-applied MoRI moriep.py dispatch-floor overlay: ${_MORIEP_PATCH_FILE}" -fi - xP="${xP:-1}" yD="${yD:-1}" diff --git a/benchmarks/multi_node/amd_utils/patches/README.md b/benchmarks/multi_node/amd_utils/patches/README.md index 45bb814a9..97ab47d26 100644 --- a/benchmarks/multi_node/amd_utils/patches/README.md +++ b/benchmarks/multi_node/amd_utils/patches/README.md @@ -60,16 +60,26 @@ This is a stop-gap. The proper upstream fix is to migrate MoRI to the plural `state_types: List[StateType]` API (full design + diff in `scripts/sglang_disagg/docs/03-upstream-pr-proposal.md`). -## `moriep.py` - -Overlays -`/sgl-workspace/sglang/python/sglang/srt/layers/moe/token_dispatcher/moriep.py`. - -Source: forked from `lmsysorg/sglang-rocm:v0.5.12.post1-*` (sglang -[v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1)). -The base file is **byte-identical to the upstream tag** -(`md5 ac626f5459a699f9ac953d9d8e71d861`); the overlay is a single -+22-line insertion in `MoriTokenDispatcher.__init__`. 
+## `apply_moriep_dispatch_floor.py` (in-place patch, NOT a bind-mount overlay) + +This one is different from `mori_conn.py`: it is a **surgical in-place +patch script**, not a full-file bind-mount overlay. It is run inside the +container by `server_sglang.sh` (right after `env.sh`) and edits the +installed +`/sgl-workspace/sglang/.../token_dispatcher/moriep.py` +in place, injecting a single floor after the dispatch-token env read. + +**Why not a bind-mount overlay (learned the hard way):** the +`lmsysorg/sglang-rocm:v0.5.12.post1-*` image ships a **downstream-patched +`moriep.py`** (class `MoriEPDispatcher`, with attrs such as +`expert_mask_gpu`) that diverges from the upstream +[v0.5.12.post1](https://github.com/sgl-project/sglang/tree/v0.5.12.post1) +tag. A full-file overlay of the upstream file (even one byte-identical to +the tag, `md5 ac626f5459...`) reverts the AMD additions and crashes the +scheduler at init: `AttributeError: 'MoriEPDispatcher' object has no +attribute 'expert_mask_gpu'`. The in-place patch touches only the +dispatch-token read and preserves all downstream code, so it is robust to +the vendor fork. **Bug it fixes:** at low concurrency the MoRI EP dispatch path silently corrupts output (decodes fine, acceptance length stays high, but gsm8k @@ -85,19 +95,23 @@ guard is `assert(destTokId < MaxNumTokensToRecv())`, compiled out under `-DNDEBUG`, so the result is silent out-of-bounds writes (`internode_v1.cpp` `DispatchIntraNodeBlock`). -The overlay floors `num_max_dispatch_tokens_per_rank` to **256** right at +The patch floors `num_max_dispatch_tokens_per_rank` to **256** right at its env read — the single source of truth that feeds both `get_ep_dispatch_configs()` (kernel selection) and the buffer-sizing -arg. Empirically validated on MI355X (conc-64 DEP8+MTP3): -dispatch `32 → gsm8k 0.00`, `64 → 0.00` (one wavefront is not enough), -`256 → 0.94`. +arg. 
It is idempotent and fail-loud-but-non-fatal (a structure miss prints +a clear error marker plus a `[diag]` listing of the lines that reference `num_max_dispatch_tokens_per_rank`, and lets the server proceed unpatched). +Empirically validated on MI355X (conc-64 DEP8+MTP3): dispatch `32 → +gsm8k 0.00`, `64 → 0.00` (one wavefront is not enough), `256 → 0.94`. This is a stop-gap. The proper upstream fix is in MoRI: size the receive buffer from the routing fan-in and turn the compiled-out `assert` into a real bounds guard (see [ROCm/mori#356](https://github.com/ROCm/mori/issues/356)). The integration-level guard belongs in sglang's `moriep.py` ([sgl-project/sglang#27194](https://github.com/sgl-project/sglang/issues/27194)) — -this overlay is exactly that guard, pending upstream merge. +this patch is exactly that guard, pending upstream merge. No +`EXTRA_DOCKER_MOUNTS` wiring is needed; the patch is applied +unconditionally by `server_sglang.sh`, and the injected floor is a runtime no-op when the value is +already ≥256 (e.g. prefill, which uses 8192). ## How to enable diff --git a/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py new file mode 100644 index 000000000..f9263a0d0 --- /dev/null +++ b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +"""Surgically floor the MoRI per-rank dispatch buffer to >=256 in the installed +sglang `moriep.py`, in place, inside the container. + +Why in-place (not a bind-mount overlay): the lmsysorg/sglang-rocm image ships a +*downstream-patched* moriep.py (class `MoriEPDispatcher`, extra attrs such as +`expert_mask_gpu`) that diverges from the upstream v0.5.12.post1 tag. A full-file +overlay of the upstream file reverts those AMD additions and crashes the +scheduler at init (`AttributeError: ... 'expert_mask_gpu'`). So we patch the +image's own file and touch only the dispatch-token read.
+ +The bug being fixed: at low concurrency the per-rank dispatch buffer +(num_max_dispatch_tokens_per_rank -> mori max_num_inp_token_per_rank) collapses +(conc-64/TP8/MTP3 -> 64/8*4 = 32). MoRI sizes its receive buffer +MaxNumTokensToRecv() = worldSize * maxNumInpTokenPerRank (dispatch_combine.hpp; +max_total_recv_tokens defaults to 0 -> that fallback, and it is a cap not a +floor). The intra-node dispatch kernel's per-dest atomic counter then overruns +the buffer; the only guard is assert(destTokId < MaxNumTokensToRecv()) which is +compiled out under -DNDEBUG -> silent out-of-bounds writes -> output that decodes +fine (high acceptance length) but is semantically garbage (gsm8k=0). + +Empirically on MI355X (conc-64 DEP8+MTP3): dispatch 32 -> gsm8k 0.00, +64 -> 0.00 (one wavefront insufficient), 256 -> 0.94. We floor to 256. + +Idempotent and fail-loud-but-non-fatal: a regex/structure miss prints a clear +marker and the surrounding source (for diagnosis) but does not abort the server. + +Upstream: sgl-project/sglang#27194, ROCm/mori#356. +""" +import os +import re +import sys + +FLOOR = 256 +MARKER = "[InferenceX moriep dispatch floor]" +TAG = "[moriep-floor]" + + +def find_target(): + try: + import sglang + except Exception as e: # pragma: no cover + print(f"{TAG} ERROR: could not import sglang ({e}); NOT patched") + return None + path = os.path.join( + os.path.dirname(sglang.__file__), + "srt", "layers", "moe", "token_dispatcher", "moriep.py", + ) + if not os.path.isfile(path): + print(f"{TAG} ERROR: moriep.py not found at {path}; NOT patched") + return None + return path + + +def main(): + path = find_target() + if path is None: + return 0 # non-fatal + + with open(path) as f: + src = f.read() + lines = src.splitlines(keepends=True) + + # Diagnostic: always show where the dispatch-token count is read/used so the + # CI log reveals the image's actual file shape even on a clean apply. 
+ for i, l in enumerate(lines): + if "num_max_dispatch_tokens_per_rank" in l: + print(f"{TAG}[diag] {path}:{i + 1}: {l.rstrip()}") + + if MARKER in src: + print(f"{TAG} already applied; skipping") + return 0 + + # Find the assignment that reads the env var, regardless of class name or + # formatting: `self.num_max_dispatch_tokens_per_rank = get_int_env_var(`. + start = None + for i, l in enumerate(lines): + if re.search( + r"self\.num_max_dispatch_tokens_per_rank\s*=\s*get_int_env_var\s*\(", + l, + ): + start = i + break + if start is None: + print( + f"{TAG} ERROR: dispatch-token env read not found in {path}; " + f"NOT patched (server will run UNPATCHED -> expect corruption at " + f"low conc). See [diag] lines above for the actual source shape." + ) + return 0 # non-fatal: surface loudly but let the run proceed + + # Walk forward to the end of the (possibly multi-line) call by balancing parens. + depth = 0 + end = start + for j in range(start, len(lines)): + depth += lines[j].count("(") - lines[j].count(")") + if depth <= 0: + end = j + break + + indent = re.match(r"\s*", lines[start]).group(0) + floor_block = ( + f"{indent}# {MARKER} floor to {FLOOR} (warpSize/fan-in safe). MoRI recv buffer\n" + f"{indent}# is worldSize*maxNumInpTokenPerRank; values below {FLOOR} silently\n" + f"{indent}# corrupt the dispatch path (gsm8k=0). 
sgl#27194 / mori#356.\n" + f"{indent}self.num_max_dispatch_tokens_per_rank = max(\n" + f"{indent} self.num_max_dispatch_tokens_per_rank, {FLOOR}\n" + f"{indent})\n" + ) + lines.insert(end + 1, floor_block) + + try: + with open(path, "w") as f: + f.write("".join(lines)) + except OSError as e: + print(f"{TAG} ERROR: could not write {path} ({e}); NOT patched") + return 0 + + print( + f"{TAG} applied: floored num_max_dispatch_tokens_per_rank to >= {FLOOR} " + f"in {path} (after line {end + 1})" + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/benchmarks/multi_node/amd_utils/patches/moriep.py b/benchmarks/multi_node/amd_utils/patches/moriep.py deleted file mode 100644 index 4ab882c29..000000000 --- a/benchmarks/multi_node/amd_utils/patches/moriep.py +++ /dev/null @@ -1,1134 +0,0 @@ -from __future__ import annotations - -import logging -import os -from dataclasses import dataclass -from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple - -from sglang.srt.layers.dp_attention import get_is_extend_in_batch -from sglang.srt.layers.moe.token_dispatcher.base import ( - BaseDispatcher, - CombineInput, - CombineInputFormat, - DispatchOutput, - DispatchOutputFormat, -) -from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPPDispatchHooks -from sglang.srt.layers.moe.topk import TopKOutput -from sglang.srt.layers.moe.utils import ( - DeepEPMode, - is_tbo_enabled, -) -from sglang.srt.utils import ( - get_bool_env_var, - get_int_env_var, - is_hip, -) - -if TYPE_CHECKING: - from sglang.srt.single_batch_overlap import CombineOverlapArgs - import mori - -from enum import Enum, auto -from functools import lru_cache - -import torch - -from sglang.srt.distributed import ( - get_moe_expert_parallel_rank, - get_moe_expert_parallel_world_size, -) -from sglang.srt.layers.quantization.fp8_kernel import fp8_dtype - -# Blockwise quantization group sizes: number of elements sharing one scale factor -FP8_BLOCK_SIZE = 128 -MXFP4_BLOCK_SIZE = 32 - 
-_is_hip = is_hip() -_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip - -if _use_aiter: - from aiter import QuantType, get_hip_quant - -logger = logging.getLogger(__name__) - - -class MoriEPPDispatchHooks(DeepEPPDispatchHooks): - - def __call__(self, dispatcher: BaseDispatcher): - for hook_fun in self.hook_dict.values(): - hook_fun(dispatcher) - - -class MoriEPNormalDispatchOutput(NamedTuple): - """Mori EP normal dispatch output.""" - - hidden_states: torch.Tensor - hidden_states_scale: Optional[torch.Tensor] - topk_ids: torch.Tensor - topk_weights: torch.Tensor - num_recv_tokens_per_expert: List[int] - origin_topk_ids: torch.Tensor - origin_topk_weights: torch.Tensor - out_dtype: torch.dtype - - @property - def format(self) -> DispatchOutputFormat: - return DispatchOutputFormat.DEEPEP_NORMAL - - -class MoriEPLLDispatchOutput(NamedTuple): - """Mori EP low latency dispatch output.""" - - hidden_states: torch.Tensor - hidden_states_scale: Optional[torch.Tensor] - topk_ids: torch.Tensor - topk_weights: torch.Tensor - num_recv_tokens_per_expert: List[int] - origin_topk_ids: torch.Tensor - origin_topk_weights: torch.Tensor - out_dtype: torch.dtype - - @property - def format(self) -> DispatchOutputFormat: - return DispatchOutputFormat.DEEPEP_LL - - -assert isinstance(MoriEPNormalDispatchOutput, DispatchOutput) -assert isinstance(MoriEPLLDispatchOutput, DispatchOutput) - - -class MoriEPNormalCombineInput(NamedTuple): - """Mori EP combine input.""" - - hidden_states: torch.Tensor - topk_ids: torch.Tensor - topk_weights: torch.Tensor - - @property - def format(self) -> CombineInputFormat: - return CombineInputFormat.DEEPEP_NORMAL - - -class MoriEPLLCombineInput(NamedTuple): - """Mori EP combine input.""" - - hidden_states: torch.Tensor - topk_ids: torch.Tensor - topk_weights: torch.Tensor - - @property - def format(self) -> CombineInputFormat: - return CombineInputFormat.DEEPEP_LL - - -assert isinstance(MoriEPNormalCombineInput, CombineInput) -assert 
isinstance(MoriEPLLCombineInput, CombineInput) - - -class EpMode(Enum): - INTRA_NODE = "intra_node" - INTER_NODE = "inter_node" - LOW_LATENCY = "low_latency" - - -class DispatchDtype(Enum): - bf16 = "bfloat16" - fp8 = "float8_blockwise" - fp4 = "mxfp4_blockwise" - - -class CombineDtype(Enum): - bf16 = "bfloat16" - fp8 = "float8_blockwise" - fp8_direct_cast = "float8_direct_cast" - - -@dataclass(frozen=True) -class EpDispatchConfig: - kernel_type: mori.ops.EpDispatchCombineKernelType - warp_num_per_block: int - block_num: int - rdma_block_num: int - - -def get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank: int = 4096): - import mori - - # Selects the inter-node kernel. `InterNodeV1LL` is used if `num_max_dispatch_tokens_per_rank` - # is less than or equal to the threshold, otherwise `InterNodeV1` is used. The threshold defaults to 256. - inter_kernel_switch_threshold = get_int_env_var( - "SGLANG_MORI_DISPATCH_INTER_KERNEL_SWITCH_THRESHOLD", 256 - ) - - inter_kernel_type = ( - mori.ops.EpDispatchCombineKernelType.InterNodeV1LL - if num_max_dispatch_tokens_per_rank <= inter_kernel_switch_threshold - else mori.ops.EpDispatchCombineKernelType.InterNodeV1 - ) - - return { - # TODO(billishyahao): need to tune different configs for intra node async - # Also could be tuned for different AMD platform - EpMode.INTRA_NODE: EpDispatchConfig( - kernel_type=mori.ops.EpDispatchCombineKernelType.IntraNode, - warp_num_per_block=16, - block_num=80, - rdma_block_num=0, - ), - EpMode.INTER_NODE: EpDispatchConfig( - kernel_type=inter_kernel_type, - warp_num_per_block=8, - block_num=64, - rdma_block_num=32, - ), - EpMode.LOW_LATENCY: EpDispatchConfig( - kernel_type=mori.ops.EpDispatchCombineKernelType.AsyncLL, - warp_num_per_block=8, - block_num=64, - rdma_block_num=32, - ), - } - - -# init_mori_op only needs do once in model initial stage -# use lru_cache to reuse the same mori_op instance to avoid the init overhead for mori -@lru_cache(maxsize=4) -def init_mori_op( - group, - 
router_topk, - num_experts, - num_local_experts, - hidden_size, - params_dtype, - num_max_dispatch_tokens_per_rank, - deepep_mode, - instance_id=0, - dispatch_dtype=DispatchDtype.bf16, - combine_dtype=CombineDtype.bf16, - enable_sdma=False, -): - - import mori - - world_size = get_moe_expert_parallel_world_size() - rank = get_moe_expert_parallel_rank() - - gpu_per_node = 8 if world_size >= 8 else world_size - - group_name = f"mori" - cpu_group = group.cpu_group - try: - torch._C._distributed_c10d._register_process_group(group_name, cpu_group) - except Exception as e: - if "already registered" in str(e): - logger.info( - f"[MORI init] The same process group is already " - f"registered. Ignoring [{str(e)}]" - ) - else: - raise - else: - # If new group is newly registered then need to init mori shmem. However - # if the group is registered already then need to skip init mori shmem - # and reuse the previous one. - mori.shmem.shmem_torch_process_group_init(group_name) - - mode = EpMode.INTRA_NODE if world_size <= 8 else EpMode.INTER_NODE - async_mode = deepep_mode.enable_low_latency() or enable_sdma - if async_mode: - mode = EpMode.LOW_LATENCY - - cfg = get_ep_dispatch_configs(num_max_dispatch_tokens_per_rank)[mode] - - kernel_type = cfg.kernel_type - warp_num_per_block = cfg.warp_num_per_block - block_num = cfg.block_num - rdma_block_num = cfg.rdma_block_num - - hidden_dim = hidden_size - scale_dim = 1 - data_type = fp8_dtype - scale_type_size = torch.float32.itemsize - - if dispatch_dtype == DispatchDtype.fp8: - scale_dim = hidden_size // FP8_BLOCK_SIZE - elif dispatch_dtype == DispatchDtype.fp4: - # FP4 kernel still takes the original hidden size and do quantization - # internally, so hidden_dim is not reduced. The reason is that for FP4 - # quantization, we need to keep the original hidden size to calculate - # the quantization scale correctly. Don't use packed hidden size for FP4 kernel. 
- hidden_dim = hidden_size - scale_dim = hidden_size // MXFP4_BLOCK_SIZE - data_type = torch.float4_e2m1fn_x2 - scale_type_size = torch.float8_e8m0fnu.itemsize - - if mode == EpMode.INTRA_NODE: - if num_max_dispatch_tokens_per_rank < 128: - block_num = 225 - warp_num_per_block = 5 - else: - block_num = 256 - warp_num_per_block = 16 - - # Fp8 blockwise combine uses its own internal scale_dim driven which can be - # overridden by env ``MORI_FP8_COMBINE_SCALE_DIM`` (default 56) - # See https://github.com/ROCm/mori/blob/96ffa169710f214e76e07abe5008d686fe54522b/python/mori/ops/dispatch_combine.py#L81-L84 - combine_quant_type = "none" - if combine_dtype == CombineDtype.fp8: - combine_quant_type = "fp8_blockwise" - elif combine_dtype == CombineDtype.fp8_direct_cast: - combine_quant_type = "fp8_direct_cast" - - logger.info( - f"[MORI init] {world_size=} {rank=} {hidden_size=} {params_dtype=} " - f"{num_max_dispatch_tokens_per_rank=} {num_local_experts=} " - f"{router_topk=} {mode=} {dispatch_dtype=} {combine_dtype=} " - ) - - def check_mori_compatibility(kwargs: dict) -> None: - """Remove kwargs not accepted by the installed mori's EpDispatchCombineConfig.""" - import dataclasses - - config_cls = mori.ops.EpDispatchCombineConfig - valid_kwargs = {f.name for f in dataclasses.fields(config_cls)} - - invalid_kwargs = set(kwargs.keys()) - valid_kwargs - for arg in invalid_kwargs: - logger.warning(f"[MORI compat] Removing incompatible argument {arg} ") - del kwargs[arg] - - # Definition refer to https://github.com/ROCm/mori/blob/f9be5ee2e5ac87256b9523399ae9d4d0e8a54f53/python/mori/ops/dispatch_combine.py#L66-L121 - common_kwargs = dict( - data_type=data_type, - rank=rank, - world_size=world_size, - hidden_dim=hidden_dim, - scale_dim=scale_dim, - scale_type_size=scale_type_size, - max_token_type_size=params_dtype.itemsize, - max_num_inp_token_per_rank=num_max_dispatch_tokens_per_rank, - num_experts_per_rank=num_local_experts, - num_experts_per_token=router_topk, - 
warp_num_per_block=warp_num_per_block, - block_num=block_num, - max_total_recv_tokens=get_int_env_var( - "SGLANG_MORI_PREALLOC_MAX_RECV_TOKENS", 0 - ), - kernel_type=kernel_type, - gpu_per_node=gpu_per_node, - rdma_block_num=rdma_block_num, - num_qp_per_pe=2, # Number of queue pairs per processing element - quant_type=combine_quant_type, - ) - - check_mori_compatibility(common_kwargs) - - mori_config = mori.ops.EpDispatchCombineConfig(**common_kwargs) - mori_op = mori.ops.EpDispatchCombineOp(mori_config) - return mori_op - - -class CommStreamPool: - _streams = {} # key -> torch.cuda.Stream - - @classmethod - def _make_key(cls, group): - return (torch.cuda.current_device(), id(group)) - - @classmethod - def get_stream_from_pool(cls, group) -> torch.cuda.Stream: - key = cls._make_key(group) - stream = cls._streams.get(key) - if stream is None: - stream = torch.cuda.Stream(priority=0) - cls._streams[key] = stream - return stream - - @classmethod - def clear_group(cls, group): - key = (torch.cuda.current_device(), id(group)) - cls._streams.pop(key, None) - - -class _MoriEPDispatcherImplBase: - def __init__( - self, - group: torch.distributed.ProcessGroup, - router_topk: int, - permute_fusion: bool, - num_experts: int, - num_local_experts: int, - hidden_size: int, - params_dtype: torch.dtype, - deepep_mode: DeepEPMode, - instance_id: int = 0, - ): - try: - import mori # noqa: F401 - except ImportError: - raise ImportError("Mori EP is not installed. Please install.") - self.group = group - self.router_topk = router_topk - self.permute_fusion = permute_fusion - self.num_experts = num_experts - self.num_local_experts = num_local_experts - self.hidden_size = hidden_size - self.params_dtype = params_dtype - self.deepep_mode = deepep_mode - self.instance_id = instance_id - - self.num_max_dispatch_tokens_per_rank = get_int_env_var( - "SGLANG_MORI_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 4096 - ) - # [InferenceX overlay] Floor the per-rank dispatch buffer to 256. 
- # The MoRI receive buffer is sized MaxNumTokensToRecv() = worldSize * - # maxNumInpTokenPerRank (mori dispatch_combine.hpp; max_total_recv_tokens - # defaults to 0 -> that fallback, and it is a cap not a floor). A small - # per-rank value therefore starves the intra-node receive buffer, and the - # only overflow guard is an assert(destTokId < MaxNumTokensToRecv()) that - # is compiled out under -DNDEBUG -> silent out-of-bounds writes -> output - # that decodes fine (high acceptance length) but is semantically garbage. - # Empirically on MI355X (conc-64 DEP8+MTP3): dispatch 32 -> gsm8k 0.00, - # 64 -> 0.00, 256 -> 0.94. Flooring here is the single source of truth: it - # feeds both get_ep_dispatch_configs() (LL/non-LL kernel selection) and - # max_num_inp_token_per_rank (buffer sizing) via the mori_op property. - # Upstream: sgl-project/sglang#27194, ROCm/mori#356. - _MORI_DISPATCH_TOKENS_FLOOR = 256 - if self.num_max_dispatch_tokens_per_rank < _MORI_DISPATCH_TOKENS_FLOOR: - logger.warning( - "[MORI floor] num_max_dispatch_tokens_per_rank=%d < %d; " - "clamping to avoid silent MoRI dispatch corruption", - self.num_max_dispatch_tokens_per_rank, - _MORI_DISPATCH_TOKENS_FLOOR, - ) - self.num_max_dispatch_tokens_per_rank = _MORI_DISPATCH_TOKENS_FLOOR - - self.enable_sdma = get_bool_env_var("MORI_ENABLE_SDMA", "false") - - self._mori_op = None - self.dispatch_dtype = DispatchDtype.bf16 - self.combine_dtype = CombineDtype.bf16 - - self.quant_config: Optional[dict] = None - - self.overlap_args: Optional[CombineOverlapArgs] = None - self.meta_overlap_args: Optional[dict] = None - - @property - def mori_op(self): - if self._mori_op is None: - # If set_quant_config was never called, apply env var override now - if self.quant_config is None: - self._apply_dispatch_dtype_override() - self._mori_op = init_mori_op( - self.group, - self.router_topk, - self.num_experts, - self.num_local_experts, - self.hidden_size, - self.params_dtype, - self.num_max_dispatch_tokens_per_rank, - 
self.deepep_mode, - self.instance_id, - self.dispatch_dtype, - self.combine_dtype, - self.enable_sdma, - ) - return self._mori_op - - def _apply_dispatch_dtype_override(self): - """Apply env var override to fp8_dispatch/fp4_dispatch/fp8_combine flags.""" - if "SGLANG_MORI_DISPATCH_DTYPE" in os.environ: - dispatch_dtype = os.environ["SGLANG_MORI_DISPATCH_DTYPE"].lower() - if dispatch_dtype != "auto": - if dispatch_dtype == "bf16": - self.dispatch_dtype = DispatchDtype.bf16 - elif dispatch_dtype == "fp8": - self.dispatch_dtype = DispatchDtype.fp8 - elif dispatch_dtype == "fp4": - self.dispatch_dtype = DispatchDtype.fp4 - elif ( - "SGLANG_MORI_FP8_DISP" in os.environ or "SGLANG_MORI_FP4_DISP" in os.environ - ): - # Deprecated: will be removed in a future release - logger.warning_once( - "SGLANG_MORI_FP8_DISP and SGLANG_MORI_FP4_DISP are deprecated " - "and will be removed in a future release. " - "Use SGLANG_MORI_DISPATCH_DTYPE=auto|bf16|fp8|fp4 instead." - ) - if get_bool_env_var("SGLANG_MORI_FP8_DISP", "False"): - self.dispatch_dtype = DispatchDtype.fp8 - if get_bool_env_var("SGLANG_MORI_FP4_DISP", "False"): - self.dispatch_dtype = DispatchDtype.fp4 - - if "SGLANG_MORI_COMBINE_DTYPE" in os.environ: - combine_dtype = os.environ["SGLANG_MORI_COMBINE_DTYPE"].lower() - if combine_dtype != "auto": - if combine_dtype == "fp8": - self.combine_dtype = CombineDtype.fp8 - elif combine_dtype == "bf16": - self.combine_dtype = CombineDtype.bf16 - elif combine_dtype == "fp8_direct_cast": - self.combine_dtype = CombineDtype.fp8_direct_cast - elif "SGLANG_MORI_FP8_COMB" in os.environ: - # Deprecated: will be removed in a future release - logger.warning_once( - "SGLANG_MORI_FP8_COMB is deprecated " - "and will be removed in a future release. " - "Use SGLANG_MORI_COMBINE_DTYPE=auto|bf16|fp8|fp8_direct_cast instead." 
- ) - if get_bool_env_var("SGLANG_MORI_FP8_COMB", "False"): - self.combine_dtype = CombineDtype.fp8 - - def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ): - raise NotImplementedError - - def dispatch_b(self, *args, **kwargs): - raise NotImplementedError - - def combine_a( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - ): - raise NotImplementedError - - def combine_b(self, *args, **kwargs): - raise NotImplementedError - - def set_quant_config(self, quant_config: dict) -> None: - self.quant_config = quant_config - # Auto-detect dispatch quantization from weight dtype - weight_dtype = quant_config.get("weight_dtype", None) - if weight_dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): - self.dispatch_dtype = DispatchDtype.fp8 - self.combine_dtype = CombineDtype.bf16 - elif weight_dtype == torch.float4_e2m1fn_x2: - self.dispatch_dtype = DispatchDtype.fp4 - self.combine_dtype = CombineDtype.fp8 - else: - self.dispatch_dtype = DispatchDtype.bf16 - self.combine_dtype = CombineDtype.bf16 - # Apply env var override immediately so dispatch_a sees correct flags - self._apply_dispatch_dtype_override() - - def set_overlap_args( - self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict - ) -> None: - self.overlap_args = combine_overlap_args - self.meta_overlap_args = meta_overlap_args - - def clear_overlap_args(self) -> None: - self.overlap_args = None - self.meta_overlap_args = None - - -class _MoriEPDispatcherImplNormal(_MoriEPDispatcherImplBase): - def __init__(self, async_finish: bool, **kwargs): - super().__init__(**kwargs) - - self.async_finish = async_finish - self.quant_config = {} - self.fp8_quant_func = get_hip_quant(QuantType.per_1x128) - self.fp4_quant_func = get_hip_quant(QuantType.per_1x32) - self.enable_dual_stream = is_tbo_enabled() - self._comm_stream = None - if self.enable_dual_stream: - self._comm_stream = CommStreamPool.get_stream_from_pool(self.group) - 
- def _capture_event_if_async(self) -> Optional[torch.cuda.Event]: - assert self.enable_dual_stream, "dual stream must be enabled" - if not self.async_finish: - return None - ev = torch.cuda.Event(blocking=False, interprocess=False) - ev.record(torch.cuda.current_stream()) - return ev - - def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ): - topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids - - num_token = hidden_states.shape[0] - output_dtype = hidden_states.dtype - scale = None - - if self.dispatch_dtype == DispatchDtype.fp8: - # FP8 quant - if num_token > 0: - # NOTE: aiter is able to handle token=0 case in UT. But for some - # reason it failed at e2e case. Root cause TBD. - hidden_states, scale = self.fp8_quant_func( - hidden_states, quant_dtype=fp8_dtype - ) - else: - hidden_states = torch.empty( - hidden_states.shape, dtype=fp8_dtype, device=hidden_states.device - ) - scale = torch.empty( - (0, self.hidden_size // FP8_BLOCK_SIZE), - dtype=torch.float32, - device=hidden_states.device, - ) - - elif self.dispatch_dtype == DispatchDtype.fp4: - # FP4 quant - if num_token > 0: - hidden_states, scale = self.fp4_quant_func(hidden_states, shuffle=False) - else: - hidden_states = torch.empty( - (0, self.hidden_size // 2), - dtype=torch.float4_e2m1fn_x2, - device=hidden_states.device, - ) - scale = torch.empty( - (0, self.hidden_size // MXFP4_BLOCK_SIZE), - dtype=torch.float8_e8m0fnu, - device=hidden_states.device, - ) - - previous_event = self._capture_event_if_async() if self._comm_stream else None - - return ( - hidden_states, - topk_weights, - topk_ids, - scale, - output_dtype, - previous_event, - ) - - def dispatch_b( - self, - hidden_states, - topk_weights, - topk_ids, - scale, - output_dtype, - previous_event, - ): - - ( - packed_recv_hidden, - recv_topk_weights, - recv_scales, - recv_topk_ids, - packed_recv_count, - done_event, - ) = self._dispatch_core( - hidden_states, - topk_weights, - topk_ids, - 
scale=scale, - previous_event=previous_event, - ) - - if self._comm_stream and self.async_finish and done_event is not None: - torch.cuda.current_stream().wait_event(done_event) - - return MoriEPNormalDispatchOutput( - hidden_states=packed_recv_hidden, - hidden_states_scale=recv_scales, - topk_ids=recv_topk_ids, - topk_weights=recv_topk_weights, - num_recv_tokens_per_expert=packed_recv_count, - origin_topk_ids=topk_ids, - origin_topk_weights=topk_weights, - out_dtype=output_dtype, - ) - - def _dispatch_core( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - scale: Optional[torch.Tensor] = None, - previous_event: Optional[torch.cuda.Event] = None, - ): - done_event: Optional[torch.cuda.Event] = None - - if self._comm_stream: - compute_stream = torch.cuda.current_stream() - comm_stream = self._comm_stream # comm stream - - for t in (hidden_states, topk_weights, topk_ids): - t.record_stream(comm_stream) - if scale is not None: - scale.record_stream(comm_stream) - - with torch.cuda.stream(comm_stream): - # if (previous_event) stream_wait(comm_stream, previous_event) - # else stream_wait(comm_stream, compute_stream) - - if previous_event is not None: - comm_stream.wait_event(previous_event) - else: - comm_stream.wait_stream(compute_stream) - - dispatch_fn = ( - self.mori_op.dispatch_send - if self.enable_sdma - else self.mori_op.dispatch - ) - ( - packed_recv_hidden, - recv_topk_weights, - recv_scales, - recv_topk_ids, - packed_recv_count, - ) = dispatch_fn(hidden_states, topk_weights, scale, topk_ids) - if self.enable_sdma: - self.mori_op.dispatch_recv() - - if self.async_finish: - done_event = torch.cuda.Event(blocking=False, interprocess=False) - done_event.record(comm_stream) - else: - compute_stream.wait_stream(comm_stream) - - for t in ( - packed_recv_hidden, - recv_topk_weights, - recv_scales, - recv_topk_ids, - ): - if t is not None: - t.record_stream(comm_stream) - else: - - ( - packed_recv_hidden, - 
recv_topk_weights, - recv_scales, - recv_topk_ids, - packed_recv_count, - ) = self.mori_op.dispatch(hidden_states, topk_weights, scale, topk_ids) - - # TODO(billishyahao): EPLB - # get_global_expert_distribution_recorder().on_deepep_dispatch_normal( - - return ( - packed_recv_hidden, - recv_topk_weights, - recv_scales, - recv_topk_ids, - packed_recv_count, - done_event, - ) - - def combine_a( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - ): - previous_event = self._capture_event_if_async() if self._comm_stream else None - return hidden_states, topk_ids, topk_weights, previous_event - - def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event): - - hidden_states, done_event = self._combine_core( - hidden_states, topk_ids, topk_weights, previous_event - ) - - if self._comm_stream and self.async_finish and done_event is not None: - torch.cuda.current_stream().wait_event(done_event) - - return hidden_states - - def _combine_core( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - previous_event: Optional[torch.cuda.Event], - ): - done_event: Optional[torch.cuda.Event] = None - - if self._comm_stream: - compute_stream = torch.cuda.current_stream() - comm_stream = self._comm_stream - - for t in (hidden_states, topk_ids, topk_weights): - t.record_stream(comm_stream) - - with torch.cuda.stream(comm_stream): - if previous_event is not None: - comm_stream.wait_event(previous_event) - else: - comm_stream.wait_stream(compute_stream) - - combine_fn = ( - self.mori_op.combine_send - if self.enable_sdma - else self.mori_op.combine - ) - combined_hidden_states = combine_fn(hidden_states, None, topk_ids)[0] - if self.enable_sdma: - self.mori_op.combine_recv() - - if self.async_finish: - done_event = torch.cuda.Event(blocking=False, interprocess=False) - done_event.record(comm_stream) - else: - compute_stream.wait_stream(comm_stream) - - 
combined_hidden_states.record_stream(comm_stream) - - else: - combined_hidden_states = self.mori_op.combine( - hidden_states, None, topk_ids - )[0] - - return combined_hidden_states, done_event - - def set_quant_config(self, quant_config: dict): - super().set_quant_config(quant_config) - - -class _MoriEPDispatcherImplLowLatency(_MoriEPDispatcherImplBase): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.quant_config = {} - self.fp8_quant_func = get_hip_quant(QuantType.per_1x128) - self.fp4_quant_func = get_hip_quant(QuantType.per_1x32) - - def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ): - import mori - - assert ( - self.mori_op.config.kernel_type - is mori.ops.EpDispatchCombineKernelType.AsyncLL - ), "mori asyncll mismatch" - - num_tokens = hidden_states.shape[0] - output_dtype = hidden_states.dtype - scale = None - - if self.dispatch_dtype == DispatchDtype.fp8: - # FP8 quant - if num_tokens > 0: - # NOTE: aiter is able to handle token=0 case in UT. But for some - # reason it failed at e2e case. Root cause TBD. 
- hidden_states, scale = self.fp8_quant_func( - hidden_states, quant_dtype=fp8_dtype - ) - else: - hidden_states = torch.empty( - hidden_states.shape, dtype=fp8_dtype, device=hidden_states.device - ) - scale = torch.empty( - (0, self.hidden_size // FP8_BLOCK_SIZE), - dtype=torch.float32, - device=hidden_states.device, - ) - - elif self.dispatch_dtype == DispatchDtype.fp4: - # FP4 quant - if num_tokens > 0: - hidden_states, scale = self.fp4_quant_func(hidden_states, shuffle=False) - else: - hidden_states = torch.empty( - (0, self.hidden_size // 2), - dtype=torch.float4_e2m1fn_x2, - device=hidden_states.device, - ) - scale = torch.empty( - (0, self.hidden_size // MXFP4_BLOCK_SIZE), - dtype=torch.float8_e8m0fnu, - device=hidden_states.device, - ) - - topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids - - ( - packed_recv_hidden, - recv_topk_weights, - recv_scales, - recv_topk_ids, - packed_recv_count, - ) = self._dispatch_core(hidden_states, topk_weights, topk_ids, scale=scale) - - return ( - packed_recv_hidden, - recv_topk_weights, - recv_topk_ids, - recv_scales, - packed_recv_count, - topk_weights, - topk_ids, - output_dtype, - ) - - def dispatch_b( - self, - hidden_states, - recv_topk_weights, - recv_topk_ids, - recv_scales, - packed_recv_count, - topk_weights, - topk_ids, - output_dtype, - ): - - ##TODO(billishyahao): add assertion here to check async - import mori - - assert ( - self.mori_op.config.kernel_type - is mori.ops.EpDispatchCombineKernelType.AsyncLL - ), "mori asyncll mismatch" - - self.mori_op.dispatch_recv() - - return MoriEPLLDispatchOutput( - hidden_states=hidden_states, - hidden_states_scale=recv_scales, - topk_ids=recv_topk_ids, - topk_weights=recv_topk_weights, - num_recv_tokens_per_expert=packed_recv_count, - origin_topk_ids=topk_ids, - origin_topk_weights=topk_weights, - out_dtype=output_dtype, - ) - - def _dispatch_core( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - scale: 
Optional[torch.Tensor] = None, - ): - ##TODO(billishyahao): add assertion here to check async - - ( - packed_recv_hidden, - recv_topk_weights, - recv_scales, - recv_topk_ids, - packed_recv_count, - ) = self.mori_op.dispatch_send(hidden_states, topk_weights, scale, topk_ids) - - return ( - packed_recv_hidden, - recv_topk_weights, - recv_scales, - recv_topk_ids, - packed_recv_count, - ) - - def combine_a( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - overlap_args: Optional[CombineOverlapArgs] = None, - ): - hidden_states = self._combine_core( - hidden_states, - topk_ids, - topk_weights, - overlap_args=overlap_args, - ) - return hidden_states, topk_ids, topk_weights, overlap_args - - def combine_b(self, hidden_states, topk_ids, topk_weights, previous_event): - - self.mori_op.combine_recv() - - return hidden_states[0] - - def _combine_core( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - overlap_args: Optional[CombineOverlapArgs] = None, - ): - combined_hidden_states = self.mori_op.combine_send( - hidden_states, None, topk_ids - ) - - return combined_hidden_states - - def set_quant_config(self, quant_config: dict): - super().set_quant_config(quant_config) - - -@dataclass -class _Stage(Enum): - INITIAL = auto() - AFTER_DISPATCH_A = auto() - AFTER_DISPATCH_B = auto() - AFTER_COMBINE_A = auto() - - -class MoriEPDispatcher(BaseDispatcher): - def __init__( - self, - group: torch.distributed.ProcessGroup, - router_topk: int, - permute_fusion: bool = False, - num_experts: int = None, - num_local_experts: int = None, - hidden_size: int = None, - params_dtype: torch.dtype = None, - deepep_mode: DeepEPMode = DeepEPMode.AUTO, - async_finish: bool = False, - return_recv_hook: bool = False, - instance_id: int = 0, - ): - super().__init__() - - self.deepep_mode = deepep_mode - - async_mode = self.deepep_mode.enable_low_latency() - if get_bool_env_var("SGLANG_ROCM_USE_MULTI_STREAM") 
and not async_mode: - logger.warning_once( - "SGLANG_ROCM_USE_MULTI_STREAM=1 is set but Mori AsyncLL is " - "not enabled (--deepep-mode=%s). The alt-stream overlap only " - "frees up CUs when dispatch/combine runs on the AsyncLL " - "copy-engine kernel; otherwise it stays on CUs and competes " - "with the alt-stream work. Pass --deepep-mode low_latency " - "(or auto) to enable the AsyncLL kernel.", - self.deepep_mode.value, - ) - - common_kwargs = dict( - group=group, - router_topk=router_topk, - permute_fusion=permute_fusion, - num_experts=num_experts, - num_local_experts=num_local_experts, - hidden_size=hidden_size, - params_dtype=params_dtype, - deepep_mode=deepep_mode, - instance_id=instance_id, - ) - - if self.deepep_mode.enable_low_latency(): - self._low_latency_dispatcher = _MoriEPDispatcherImplLowLatency( - **common_kwargs, - ) - - if self.deepep_mode.enable_normal(): - self._normal_dispatcher = _MoriEPDispatcherImplNormal( - async_finish=async_finish, - **common_kwargs, - ) - - self._stage = _Stage.INITIAL - self._deepep_dispatch_hooks = MoriEPPDispatchHooks() - - def dispatch( - self, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ) -> DispatchOutput: - self.dispatch_a(hidden_states, topk_output) - if self._deepep_dispatch_hooks is not None: - self._deepep_dispatch_hooks(self) - ret = self.dispatch_b() - return ret - - def dispatch_a( - self, - hidden_states: torch.Tensor, - topk_output: TopKOutput, - ): - self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A) - inner_state = self._get_impl().dispatch_a( - hidden_states=hidden_states, - topk_output=topk_output, - ) - self._dispatch_intermediate_state = inner_state - - def dispatch_b(self): - self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B) - inner_state = self._dispatch_intermediate_state - del self._dispatch_intermediate_state - return self._get_impl().dispatch_b(*inner_state) - - def combine( - self, - combine_input: CombineInput, - ) -> Tuple: - 
self.combine_a(combine_input) - ret = self.combine_b() - return ret - - def combine_a( - self, - combine_input: CombineInput, - ): - hidden_states, topk_ids, topk_weights = combine_input - self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A) - inner_state = self._get_impl().combine_a( - hidden_states=hidden_states, - topk_ids=topk_ids, - topk_weights=topk_weights, - ) - self._combine_intermediate_state = inner_state - - def combine_b(self): - self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL) - inner_state = self._combine_intermediate_state - del self._combine_intermediate_state - return self._get_impl().combine_b(*inner_state) - - def _get_impl(self) -> _MoriEPDispatcherImplBase: - is_extend_in_batch = get_is_extend_in_batch() - resolved_deepep_mode = self.deepep_mode.resolve(is_extend_in_batch) - if resolved_deepep_mode == DeepEPMode.NORMAL: - return self._normal_dispatcher - elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY: - return self._low_latency_dispatcher - else: - raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") - - def _update_stage(self, old_stage, new_stage): - assert self._stage == old_stage - self._stage = new_stage - - def set_quant_config(self, quant_config: dict): - super().set_quant_config(quant_config) - if self.deepep_mode.enable_low_latency(): - self._low_latency_dispatcher.set_quant_config(quant_config) - if self.deepep_mode.enable_normal(): - self._normal_dispatcher.set_quant_config(quant_config) - - def set_overlap_args( - self, combine_overlap_args: CombineOverlapArgs, meta_overlap_args: dict - ): - super().set_overlap_args(combine_overlap_args, meta_overlap_args) - if self.deepep_mode.enable_low_latency(): - self._low_latency_dispatcher.set_overlap_args( - combine_overlap_args, meta_overlap_args - ) - if self.deepep_mode.enable_normal(): - self._normal_dispatcher.set_overlap_args( - combine_overlap_args, meta_overlap_args - ) - - def clear_overlap_args(self): - super().clear_overlap_args() - if 
self.deepep_mode.enable_low_latency(): - self._low_latency_dispatcher.clear_overlap_args() - if self.deepep_mode.enable_normal(): - self._normal_dispatcher.clear_overlap_args() - - def register_deepep_dispatch_hook(self, hook): - return self._deepep_dispatch_hooks.register_hook(hook) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index fda3aee54..8b0c8001e 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -49,6 +49,16 @@ GPUS_PER_NODE="${GPUS_PER_NODE:-8}" source $SGLANG_WS_PATH/setup_deps.sh source $SGLANG_WS_PATH/env.sh +# Root-cause fix for low-concurrency MoRI dispatch-buffer corruption: surgically +# floor num_max_dispatch_tokens_per_rank to >=256 in the installed (vendor-patched) +# sglang moriep.py, in place, before any sglang.launch_server starts. A full-file +# overlay can't be used here because the lmsysorg image ships a downstream-patched +# moriep.py (class MoriEPDispatcher / expert_mask_gpu) that diverges from upstream. +# See patches/apply_moriep_dispatch_floor.py and patches/README.md. +echo "[server_sglang] applying MoRI dispatch-floor patch to installed sglang moriep.py" +python3 "$SGLANG_WS_PATH/patches/apply_moriep_dispatch_floor.py" \ + || echo "[server_sglang] WARN: moriep dispatch-floor patch returned non-zero" + host_ip=$(ip route get 1.1.1.1 | awk '/src/ {print $7}') host_name=$(hostname) diff --git a/perf-changelog.yaml b/perf-changelog.yaml index 039be5042..d985a3d8c 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3455,5 +3455,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with a moriep.py overlay (patches/moriep.py, auto-mounted by job.slurm, gated on v0.5.12.post1) that floors num_max_dispatch_tokens_per_rank to 256 inside sglang. 
The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the harness value collapses to 32, overrunning the dispatch kernel's receive slots (the only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged (corrupt run's ~3% edge was dropped work). Upstream: sgl-project/sglang#27194, ROCm/mori#356." + - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356." 
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 From 79bb67aa9463de27c4b49acdaeaf58c82a06ecaa Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 21:04:38 -0700 Subject: [PATCH 10/15] fix moriep dispatch-floor patcher crash when sglang.__file__ is None MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The vendor image installs sglang as a namespace package where __file__ is None. os.path.dirname(None) throws TypeError, so the patcher crashed and the floor was never applied — eval ran unpatched. Fall through __file__ → __path__ → importlib.util.find_spec() to locate the package directory robustly. --- .../patches/apply_moriep_dispatch_floor.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py index f9263a0d0..c2088a00d 100644 --- a/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py +++ b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py @@ -42,9 +42,31 @@ def find_target(): except Exception as e: # pragma: no cover print(f"{TAG} ERROR: could not import sglang ({e}); NOT patched") return None + + # sglang may be a namespace package (no __init__.py) where __file__ is + # None. Fall through several strategies to locate the package root. 
+ pkg_dir = None + if getattr(sglang, "__file__", None) is not None: + pkg_dir = os.path.dirname(sglang.__file__) + elif getattr(sglang, "__path__", None): + pkg_dir = list(sglang.__path__)[0] + else: + try: + import importlib.util + spec = importlib.util.find_spec("sglang") + if spec and spec.submodule_search_locations: + pkg_dir = list(spec.submodule_search_locations)[0] + except Exception: + pass + + if pkg_dir is None: + print(f"{TAG} ERROR: could not determine sglang install path " + f"(__file__={getattr(sglang, '__file__', '?')}, " + f"__path__={getattr(sglang, '__path__', '?')}); NOT patched") + return None + path = os.path.join( - os.path.dirname(sglang.__file__), - "srt", "layers", "moe", "token_dispatcher", "moriep.py", + pkg_dir, "srt", "layers", "moe", "token_dispatcher", "moriep.py", ) if not os.path.isfile(path): print(f"{TAG} ERROR: moriep.py not found at {path}; NOT patched") From 9b50d698b0af2879afc0a30d008a087beb5117cc Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 21:10:21 -0700 Subject: [PATCH 11/15] note namespace-package patcher fix in changelog (re-trigger sweep) --- perf-changelog.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/perf-changelog.yaml b/perf-changelog.yaml index d985a3d8c..3613592ff 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3455,5 +3455,5 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. 
NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356." + - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356. Also fixes namespace-package crash in patcher (sglang.__file__ is None in lmsysorg vendor image)." 
pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 From 9fb87c1d660884271a1c46e432e5b34bddd08285 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 22:36:15 -0700 Subject: [PATCH 12/15] fix moriep patcher path: handle vendor image python/sglang/ layout The vendor image installs sglang at /sgl-workspace/sglang/ but the actual Python package is under python/sglang/ within that tree. When __path__ returns the repo root, the patcher couldn't find moriep.py. Add candidate paths and a bounded walk fallback. --- .../patches/apply_moriep_dispatch_floor.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py index c2088a00d..c2c8f5ecb 100644 --- a/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py +++ b/benchmarks/multi_node/amd_utils/patches/apply_moriep_dispatch_floor.py @@ -65,13 +65,28 @@ def find_target(): f"__path__={getattr(sglang, '__path__', '?')}); NOT patched") return None - path = os.path.join( - pkg_dir, "srt", "layers", "moe", "token_dispatcher", "moriep.py", - ) - if not os.path.isfile(path): - print(f"{TAG} ERROR: moriep.py not found at {path}; NOT patched") - return None - return path + rel = os.path.join("srt", "layers", "moe", "token_dispatcher", "moriep.py") + candidates = [ + os.path.join(pkg_dir, rel), + os.path.join(pkg_dir, "python", "sglang", rel), + ] + for path in candidates: + if os.path.isfile(path): + return path + + # Last resort: walk the tree (bounded to 6 levels to avoid scanning /). 
+ for root, _dirs, files in os.walk(pkg_dir): + if root.count(os.sep) - pkg_dir.count(os.sep) > 6: + _dirs.clear() + continue + if "moriep.py" in files: + found = os.path.join(root, "moriep.py") + print(f"{TAG} found moriep.py via walk: {found}") + return found + + print(f"{TAG} ERROR: moriep.py not found under {pkg_dir} " + f"(tried {candidates}); NOT patched") + return None def main(): From 59dd6c3323d538096b97bca6ad4cec5e3370e3d3 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 22:36:34 -0700 Subject: [PATCH 13/15] update changelog: note patcher path fix (re-trigger sweep) --- perf-changelog.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/perf-changelog.yaml b/perf-changelog.yaml index adbae9d72..cde0c905f 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3455,7 +3455,7 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356. 
Also fixes namespace-package crash in patcher (sglang.__file__ is None in lmsysorg vendor image)." + - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356. Also fixes namespace-package crash (sglang.__file__ is None) and python/sglang/ subdirectory layout in patcher path resolution." pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 - config-keys: From fe96b05682dd3deabe12210375ad3a22b51e7564 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 22:51:23 -0700 Subject: [PATCH 14/15] fix MoRI dispatch token formula: --max-running-requests is per DP rank The formula divided BENCH_MAX_CONC_VALUE by decode_dp_ranks, assuming --max-running-requests is a global limit split across DP ranks. It is actually per-rank: each of the 8 DP schedulers independently allows up to BENCH_MAX_CONC_VALUE requests. 
At conc-64/TP8/MTP3 the old formula produced dispatch=32 (64/8*4), but each rank can hold 64*4=256 tokens, causing 8x buffer overflow in MoRI's intra-node dispatch kernel (the only guard is an assert compiled out under -DNDEBUG) and silent corruption (gsm8k=0). --- benchmarks/multi_node/amd_utils/server_sglang.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/multi_node/amd_utils/server_sglang.sh b/benchmarks/multi_node/amd_utils/server_sglang.sh index 8b0c8001e..f0bfb9d6f 100755 --- a/benchmarks/multi_node/amd_utils/server_sglang.sh +++ b/benchmarks/multi_node/amd_utils/server_sglang.sh @@ -223,7 +223,9 @@ fi if [[ "$DECODE_ENABLE_DP" == "true" ]] && [[ "$DECODE_ENABLE_EP" == "true" ]]; then decode_max_running_requests=$BENCH_MAX_CONC_VALUE decode_dp_ranks=$DECODE_TP_SIZE - MORI_MAX_DISPATCH_TOKENS_DECODE=$((BENCH_MAX_CONC_VALUE / decode_dp_ranks)) + # --max-running-requests is PER DP RANK (not global); each rank can hold + # up to BENCH_MAX_CONC_VALUE requests, so dispatch tokens = that capacity. 
+ MORI_MAX_DISPATCH_TOKENS_DECODE=$BENCH_MAX_CONC_VALUE MORI_MOE_MAX_INPUT_TOKENS_DECODE=$((MORI_MAX_DISPATCH_TOKENS_DECODE * decode_dp_ranks * 7 / 10)) # Update derived variable SGLANG_MORI_DISPATCH_INTER_KERNEL_SWITCH_THRESHOLD=$((MORI_MAX_DISPATCH_TOKENS_DECODE * 2)) From 1fdb89f25fd4853fa402e5ab689d8fbff68562d0 Mon Sep 17 00:00:00 2001 From: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com> Date: Wed, 3 Jun 2026 22:51:28 -0700 Subject: [PATCH 15/15] update changelog: root-cause formula fix (re-trigger sweep) --- perf-changelog.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/perf-changelog.yaml b/perf-changelog.yaml index cde0c905f..cd2f8d1f1 100644 --- a/perf-changelog.yaml +++ b/perf-changelog.yaml @@ -3455,7 +3455,7 @@ - config-keys: - dsr1-fp4-mi355x-sglang-disagg-8k1k-mtp description: - - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356. 
Also fixes namespace-package crash (sglang.__file__ is None) and python/sglang/ subdirectory layout in patcher path resolution." + - "Root-cause fix for MoRI dispatch-buffer corruption at low concurrency: replace the harness env clamp (bandaid) with an in-place patch (patches/apply_moriep_dispatch_floor.py, run by server_sglang.sh) that floors num_max_dispatch_tokens_per_rank to 256 inside the installed sglang moriep.py. NOTE a full-file overlay was tried first and crashed the scheduler (AttributeError: MoriEPDispatcher has no attribute expert_mask_gpu) because the lmsysorg image ships a downstream-patched moriep.py that diverges from the upstream v0.5.12.post1 tag; the surgical in-place patch preserves the vendor fork. The per-rank All2All receive buffer is sized worldSize*maxNumInpTokenPerRank; at conc-64/TP8/MTP3 the value collapses to 32, overrunning the dispatch kernel's receive slots (only guard is an assert compiled out under -DNDEBUG) -> silent corruption (decodes fine, gsm8k=0). Confirmed on MI355X: dispatch=32->0.00, 64->0.00 (one wavefront insufficient), >=256->0.94. Throughput unchanged. Upstream: sgl-project/sglang#27194, ROCm/mori#356. Also fixes the root-cause harness formula (BENCH_MAX_CONC_VALUE/dp_ranks was wrong: --max-running-requests is per-DP-rank, not global) and patcher path resolution for vendor image layout." pr-link: https://github.com/SemiAnalysisAI/InferenceX/pull/1659 - config-keys: