Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ def get_mha_kv_ptrs_with_pp(
num_kv_layers < dst_num_total_layers
and dst_num_total_layers % num_kv_layers != 0
):
# Case: Decode has more layers than Prefill (e.g., Decode has draft model KV while Prefill is deployed without speculative decoding)
# To prevent empty Value Cache, which leads to wrong response
# Case: Decode has draft model KV while Prefill is deployed without speculative decoding
# dst_kv_ptrs layout: [K_main..., V_main..., draft_K..., draft_V...]
multiplier_ratio = dst_num_total_layers // num_kv_layers
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
dst_v_ptrs = dst_kv_ptrs[
num_kv_layers + start_layer : num_kv_layers + end_layer
num_kv_layers * multiplier_ratio + start_layer : num_kv_layers * multiplier_ratio + end_layer
]
else:
# Decode pp size should be equal to prefill pp size or 1
Expand Down
Loading