Commit e20e6c2

[AMD] Fix accuracy issue when running TP4 dsv3 model with mtp (sgl-project#18607)

Authored-by: 1am9trashy
Co-authored-by: YC Tseng <yctseng@amd.com>
Co-authored-by: kkHuang-amd <wunhuang@amd.com>

1 parent d6f0ef6 commit e20e6c2

2 files changed: +9 −5 lines

docker/rocm.Dockerfile — 2 additions, 2 deletions

```diff
@@ -21,7 +21,7 @@ ENV BUILD_TRITON="0"
 ENV BUILD_LLVM="0"
 ENV BUILD_AITER_ALL="1"
 ENV BUILD_MOONCAKE="1"
-ENV AITER_COMMIT="v0.1.10.post2"
+ENV AITER_COMMIT="v0.1.10.post3"

 # ===============================
 # Base image 950 and args
@@ -31,7 +31,7 @@ ENV BUILD_TRITON="0"
 ENV BUILD_LLVM="0"
 ENV BUILD_AITER_ALL="1"
 ENV BUILD_MOONCAKE="1"
-ENV AITER_COMMIT="v0.1.10.post2"
+ENV AITER_COMMIT="v0.1.10.post3"
 # ===============================
 # Chosen arch and args
 FROM ${GPU_ARCH}
```

python/sglang/srt/layers/attention/aiter_backend.py — 7 additions, 3 deletions

```diff
@@ -195,11 +195,15 @@ def __init__(
         )
         global _use_mla_ps_kernel, fast_mode, intra_batch_mode

+        if self.num_head == 32:
+            fast_mode = True
+            intra_batch_mode = False
+
         # current persist a16w16 mla_decode kernel does not support head_num = 128
         # need to fall back to non-persist
         # only use mla_ps_kernel when fp8 kv_cache
-        # for non-fp8 kv_cache, use non-persist kernel to avoid performance degradation
-        if self.kv_cache_dtype is not fp8_dtype:
+        # for non-fp8 kv_cache on tp8, use non-persist kernel to avoid performance degradation
+        if self.num_head == 16 and self.kv_cache_dtype is not fp8_dtype:
             _use_mla_ps_kernel = False
             fast_mode = False
             intra_batch_mode = False
@@ -301,7 +305,7 @@ def make_mla_meta_data(
             kv_last_page_len,
             self.num_head // nhead_kv,
             nhead_kv,
-            True,
+            False,
             work_metadata,
             work_info_set,
             work_indptr,
```
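The mode selection this patch establishes can be sketched as standalone logic. This is a hypothetical simplification for illustration only: the real code mutates module-level globals inside `AiterAttnBackend.__init__`, and the default flag values assumed below are not shown in the diff.

```python
def select_mla_modes(num_head, kv_cache_dtype, fp8_dtype="fp8"):
    """Return (use_mla_ps_kernel, fast_mode, intra_batch_mode).

    Hypothetical helper mirroring the patched branches; defaults are assumed.
    """
    use_mla_ps_kernel = True
    fast_mode = False
    intra_batch_mode = True

    # New in this patch: with TP4 (num_head == 32), enable fast mode and
    # disable intra-batch mode.
    if num_head == 32:
        fast_mode = True
        intra_batch_mode = False

    # Only use the persistent a16w16 mla_decode kernel with an fp8 kv_cache.
    # On TP8 (num_head == 16) with a non-fp8 cache, fall back to the
    # non-persistent kernel to avoid performance degradation.
    if num_head == 16 and kv_cache_dtype != fp8_dtype:
        use_mla_ps_kernel = False
        fast_mode = False
        intra_batch_mode = False

    return use_mla_ps_kernel, fast_mode, intra_batch_mode
```

The key behavioral change for the reported TP4 accuracy issue is that the fp8 fallback branch is now restricted to `num_head == 16`, so the TP4 configuration no longer takes that path.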
