Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 99 additions & 31 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,51 +368,55 @@ def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
self.decode_kv_args_table[agent_name] = decode_kv_args
self.agent.add_remote_agent(decode_kv_args.agent_metadata)

def send_kvcache(
def _send_kvcache_generic(
self,
peer_name: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32],
src_data_ptrs: list[int],
dst_data_ptrs: list[int],
item_lens: list[int],
prefill_data_indices: npt.NDArray[np.int32],
dst_data_indices: npt.NDArray[np.int32],
dst_gpu_id: int,
notif: str,
):
"""Generic KV cache transfer supporting both MHA and MLA architectures.
Used by both send_kvcache and maybe_send_extra."""
# group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
prefill_data_indices, dst_data_indices
)

logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
# Make descs
if self.is_mla_backend:
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
)
layers_params = [
(
src_kv_ptrs[layer_id],
dst_kv_ptrs[layer_id],
self.kv_args.kv_item_lens[layer_id],
item_lens[layer_id],
)
for layer_id in range(layers_current_pp_stage)
]
else:
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
)

layers_params = [
(
src_k_ptrs[layer_id],
dst_k_ptrs[layer_id],
self.kv_args.kv_item_lens[layer_id],
item_lens[layer_id],
)
for layer_id in range(layers_current_pp_stage)
] + [
(
src_v_ptrs[layer_id],
dst_v_ptrs[layer_id],
self.kv_args.kv_item_lens[layer_id],
item_lens[layer_id],
)
for layer_id in range(layers_current_pp_stage)
]
Expand Down Expand Up @@ -455,7 +459,7 @@ def make_req_array(addr_chunks, len_chunks, gpu):
dst_reqs = make_req_array(dst_addrs, dst_lens, dst_gpu_id)

logger.debug(
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
f"len(src_addrs): before group: {len(prefill_data_indices)}, after group: {len(src_addrs)}"
)
src_descs = self.agent.get_xfer_descs(src_reqs, "VRAM")
dst_descs = self.agent.get_xfer_descs(dst_reqs, "VRAM")
Expand All @@ -474,6 +478,26 @@ def make_req_array(addr_chunks, len_chunks, gpu):
raise Exception("KVSender failed to post transfer")
return xfer_handle

def send_kvcache(
    self,
    peer_name: str,
    prefill_kv_indices: npt.NDArray[np.int32],
    dst_kv_ptrs: list[int],
    dst_kv_indices: npt.NDArray[np.int32],
    dst_gpu_id: int,
    notif: str,
):
    """Transfer KV cache entries to ``peer_name``.

    Thin wrapper around ``_send_kvcache_generic``: the source pointers and
    per-layer item lengths are taken from ``self.kv_args`` (``kv_data_ptrs``
    and ``kv_item_lens``), everything else is forwarded from the caller.

    Args:
        peer_name: Name of the remote agent receiving the data.
        prefill_kv_indices: Source indices on the prefill side.
        dst_kv_ptrs: Destination base pointers on the remote side.
        dst_kv_indices: Destination indices on the remote side.
        dst_gpu_id: GPU id forwarded to the generic transfer routine.
        notif: Notification string forwarded to the generic transfer routine.

    Returns:
        Whatever ``_send_kvcache_generic`` returns.
    """
    local_kv = self.kv_args
    # Map the KV-specific argument names onto the generic transfer interface.
    transfer_args = {
        "peer_name": peer_name,
        "src_data_ptrs": local_kv.kv_data_ptrs,
        "dst_data_ptrs": dst_kv_ptrs,
        "item_lens": local_kv.kv_item_lens,
        "prefill_data_indices": prefill_kv_indices,
        "dst_data_indices": dst_kv_indices,
        "dst_gpu_id": dst_gpu_id,
        "notif": notif,
    }
    return self._send_kvcache_generic(**transfer_args)

def send_kvcache_slice(
self,
peer_name: str,
Expand Down Expand Up @@ -684,6 +708,59 @@ def _send_mamba_state(
raise Exception("Failed to post Mamba state transfer")
return xfer_handle

def maybe_send_extra(
    self,
    peer_name: str,
    prefill_state_indices: List[int],
    dst_state_data_ptrs: list[int],
    dst_state_indices: List[int],
    dst_gpu_id: int,
    notif: str,
    decode_tp_size: int,
):
    """Send the extra/state pool for the configured ``state_type``, if any.

    The state type is read from ``self.kv_args.state_type`` and defaults to
    ``"none"`` when the attribute is absent.

    * ``"none"``: nothing is sent and ``None`` is returned.
    * ``"mamba"``: delegates to ``_send_mamba_state``.
    * ``"swa"`` / ``"nsa"``: delegates to ``_send_kvcache_generic`` using
      ``self.kv_args.state_data_ptrs`` / ``state_item_lens`` as the source.

    Args:
        peer_name: Name of the remote agent receiving the data.
        prefill_state_indices: Source state indices on the prefill side.
        dst_state_data_ptrs: Destination state base pointers.
        dst_state_indices: Destination state indices.
        dst_gpu_id: GPU id forwarded to the underlying transfer routine.
        notif: Notification string forwarded to the underlying routine.
        decode_tp_size: TP size of the decode side, compared against
            ``self.attn_tp_size``.

    Returns:
        The value returned by the delegated transfer routine, or ``None``
        when there is no state to send.

    Raises:
        RuntimeError: On a TP-size mismatch (mamba always; swa/nsa only for
            non-MLA backends), on a prefill/dst index length mismatch for
            swa/nsa, or on an unrecognised state type.
    """
    state_type = getattr(self.kv_args, "state_type", "none")

    # Nothing extra to transfer.
    if state_type == "none":
        return None

    if state_type == "mamba":
        if self.attn_tp_size != decode_tp_size:
            raise RuntimeError(
                "PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet."
            )
        return self._send_mamba_state(
            peer_name,
            prefill_state_indices,
            dst_state_data_ptrs,
            dst_state_indices,
            dst_gpu_id,
            notif,
        )

    # Anything other than the KV-shaped state pools is unsupported.
    if state_type not in ("swa", "nsa"):
        raise RuntimeError(
            f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet."
        )

    # MLA backends skip the TP-size comparison entirely.
    if not self.is_mla_backend:
        if self.attn_tp_size != decode_tp_size:
            raise RuntimeError(
                f"PD Disaggregation does NOT support PD different TP sizes for non-MLA {state_type.upper()} hybrid models yet."
            )

    src_count = len(prefill_state_indices)
    dst_count = len(dst_state_indices)
    if src_count != dst_count:
        raise RuntimeError(
            f"State index length mismatch: prefill={src_count}, "
            f"dst={dst_count}"
        )

    state_args = self.kv_args
    return self._send_kvcache_generic(
        peer_name=peer_name,
        src_data_ptrs=state_args.state_data_ptrs,
        dst_data_ptrs=dst_state_data_ptrs,
        item_lens=state_args.state_item_lens,
        prefill_data_indices=np.array(prefill_state_indices, dtype=np.int32),
        dst_data_indices=np.array(dst_state_indices, dtype=np.int32),
        dst_gpu_id=dst_gpu_id,
        notif=notif,
    )
Comment on lines +758 to +762
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability, you could check explicitly for `state_type == "none"` first and handle unexpected state types in an `else` block. This makes the handling of unknown state types clearer.

Suggested change
if state_type != "none":
raise RuntimeError(
f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet."
)
return None
if state_type == "none":
return None
else:
raise RuntimeError(
f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet."
)


def add_transfer_request(
self,
bootstrap_room: int,
Expand Down Expand Up @@ -742,26 +819,17 @@ def add_transfer_request(
# Only the last chunk we need to send the aux data.
if is_last:
if state_indices is not None:
state_type = getattr(self.kv_args, "state_type", "none")
if (
self.attn_tp_size
!= self.decode_kv_args_table[req.agent_name].decode_tp_size
):
raise RuntimeError(
"PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet."
)

if state_type == "mamba":
state_xfer_handle = self._send_mamba_state(
req.agent_name,
state_indices,
self.decode_kv_args_table[
req.agent_name
].dst_state_data_ptrs,
req.dst_state_indices,
self.decode_kv_args_table[req.agent_name].gpu_id,
f"{req.room}_state_{self.kv_args.pp_rank}",
)
dst_info = self.decode_kv_args_table[req.agent_name]
state_xfer_handle = self.maybe_send_extra(
req.agent_name,
state_indices,
dst_info.dst_state_data_ptrs,
req.dst_state_indices,
dst_info.gpu_id,
f"{req.room}_state_{self.kv_args.pp_rank}",
decode_tp_size,
)
if state_xfer_handle is not None:
handles.append(state_xfer_handle)

assert aux_index is not None
Expand Down
Loading