File tree Expand file tree Collapse file tree 2 files changed +9
-5
lines changed
python/sglang/srt/layers/attention Expand file tree Collapse file tree 2 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -21,7 +21,7 @@ ENV BUILD_TRITON="0"
2121ENV BUILD_LLVM="0"
2222ENV BUILD_AITER_ALL="1"
2323ENV BUILD_MOONCAKE="1"
24- ENV AITER_COMMIT="v0.1.10.post2 "
24+ ENV AITER_COMMIT="v0.1.10.post3 "
2525
2626# ===============================
2727# Base image 950 and args
@@ -31,7 +31,7 @@ ENV BUILD_TRITON="0"
3131ENV BUILD_LLVM="0"
3232ENV BUILD_AITER_ALL="1"
3333ENV BUILD_MOONCAKE="1"
34- ENV AITER_COMMIT="v0.1.10.post2 "
34+ ENV AITER_COMMIT="v0.1.10.post3 "
3535# ===============================
3636# Chosen arch and args
3737FROM ${GPU_ARCH}
Original file line number Diff line number Diff line change @@ -195,11 +195,15 @@ def __init__(
195195 )
196196 global _use_mla_ps_kernel , fast_mode , intra_batch_mode
197197
198+ if self .num_head == 32 :
199+ fast_mode = True
200+ intra_batch_mode = False
201+
198202 # current persist a16w16 mla_decode kernel does not support head_num = 128
199203 # need to fall back to non-persist
200204 # only use mla_ps_kernel when fp8 kv_cache
201- # for non-fp8 kv_cache, use non-persist kernel to avoid performance degradation
202- if self .kv_cache_dtype is not fp8_dtype :
205+ # for non-fp8 kv_cache on tp8 , use non-persist kernel to avoid performance degradation
206+ if self .num_head == 16 and self . kv_cache_dtype is not fp8_dtype :
203207 _use_mla_ps_kernel = False
204208 fast_mode = False
205209 intra_batch_mode = False
@@ -301,7 +305,7 @@ def make_mla_meta_data(
301305 kv_last_page_len ,
302306 self .num_head // nhead_kv ,
303307 nhead_kv ,
304- True ,
308+ False ,
305309 work_metadata ,
306310 work_info_set ,
307311 work_indptr ,
You can’t perform that action at this time.
0 commit comments