
feat: add nsa and swa disagg support with nixl #18939

Merged
ShangmingCai merged 3 commits into sgl-project:main from nealvaidya:nealv/nixl_state_fix
Feb 18, 2026

Conversation

@nealvaidya
Contributor

Motivation

Fixes a bug reported on Slack where NIXL PD disaggregation fails with transfer timeouts for hybrid attention models that use SWA (sliding window attention) or NSA (native sparse attention) state types — e.g., DeepSeek V3.2 with Wide EP.

The NIXL connector's add_transfer_request previously only handled state_type == "mamba" via _send_mamba_state, but had no path for "swa" or "nsa" state types. When a hybrid model with SWA/NSA state was used, the state data was silently never transferred, causing the decode side to wait indefinitely for the state notification and eventually time out.

The mooncake connector already handled this correctly via its maybe_send_extra method, which dispatches "swa"/"nsa" state types through _send_kvcache_generic. This PR mirrors that approach in the NIXL connector.

Modifications

  • python/sglang/srt/disaggregation/nixl/conn.py:
    • Added maybe_send_extra() method to NixlKVManager, following the mooncake implementation pattern. Dispatches based on state_type:
      • "mamba" → delegates to existing _send_mamba_state (no change in behavior)
      • "swa" / "nsa" → reuses _send_kvcache_generic with state_data_ptrs and state_item_lens instead of kv_data_ptrs/kv_item_lens
      • "none" → returns None (no state transfer needed)
    • Refactored _send_kvcache_generic to accept generic src_data_ptrs/dst_data_ptrs/item_lens so it can be reused for both KV cache and state transfers
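The dispatch described above can be sketched in simplified form. Note this is an illustrative model only: the real NixlKVManager.maybe_send_extra in conn.py takes transfer-request and peer arguments, and FakeManager, its attribute values, and the return strings here are invented for demonstration.

```python
class FakeManager:
    # Stand-ins for the manager's registered state buffers (values are
    # placeholders, not real device pointers).
    state_data_ptrs = [0x1000]
    state_item_lens = [256]

    def _send_mamba_state(self):
        # Existing mamba path; behavior unchanged by the PR.
        return "mamba-sent"

    def _send_kvcache_generic(self, src_data_ptrs, item_lens):
        # Generic transfer path, reused for swa/nsa state buffers.
        return f"generic-sent:{len(src_data_ptrs)}"

    def maybe_send_extra(self, state_type):
        # Dispatch on the model's hybrid-state type.
        if state_type == "mamba":
            return self._send_mamba_state()
        if state_type in ("swa", "nsa"):
            return self._send_kvcache_generic(
                self.state_data_ptrs, self.state_item_lens
            )
        if state_type == "none":
            return None  # no extra state to transfer
        raise RuntimeError(
            f"PD Disaggregation via NIXL does NOT support "
            f"{state_type} hybrid models yet."
        )
```

Before this change, the "swa" and "nsa" branches did not exist in the NIXL connector, so those state buffers were never transferred and the decode side waited until the transfer timed out.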

Accuracy Tests

  • Mamba (Qwen3-Next-80B-A3B-Instruct): GSM8K 5-shot accuracy >0.93 (existing test, unchanged)
  • SWA (gpt-oss-20b): GSM8K 5-shot accuracy ~0.50 disaggregated vs ~0.53 non-disaggregated baseline (the model is weak on this benchmark)
  • NSA (DeepSeek V3.2): Tested on K8s cluster with 8xH200 nodes (TP=8, EP=8 for both prefill and decode). Completions return successfully with no transfer timeouts — confirms the original reported bug is fixed.

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Signed-off-by: Neal Vaidya <nealv@nvidia.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @nealvaidya, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the NIXL disaggregation connector that previously caused transfer timeouts for hybrid attention models utilizing Sliding Window Attention (SWA) or Native Sparse Attention (NSA). The connector was initially designed to only handle 'mamba' state types, leading to silent failures when other state configurations were encountered. The changes introduce a generalized state transfer mechanism by refactoring the _send_kvcache_generic method and adding a new maybe_send_extra method. This ensures that SWA and NSA models can now correctly transfer their state data during disaggregation, preventing timeouts and enabling their proper operation.

Highlights

  • Enhanced State Type Support: Implemented support for NSA (Native Sparse Attention) and SWA (Sliding Window Attention) state types within the NIXL disaggregation connector, resolving previous transfer timeout issues for these models.
  • Generic KV Cache Transfer Refactoring: Refactored the internal _send_kvcache_generic method to be more flexible, allowing it to handle both standard KV cache transfers and generic state data transfers by accepting generalized data pointers and lengths.
  • New State Dispatch Method: Introduced a new maybe_send_extra method to NixlKVManager that intelligently dispatches state transfers based on the state_type (e.g., 'mamba', 'swa', 'nsa'), mirroring the approach used in the mooncake connector.
  • Bug Fix for Hybrid Models: Fixed a bug that caused transfer timeouts for hybrid attention models (like DeepSeek V3.2 with Wide EP) when using SWA or NSA state types during PD disaggregation, ensuring successful completions.


Changelog
  • python/sglang/srt/disaggregation/nixl/conn.py
    • Renamed the original send_kvcache method to _send_kvcache_generic and generalized its parameters to accept generic source/destination data pointers and item lengths.
    • Updated a logging message within the _send_kvcache_generic method to use prefill_data_indices for clarity.
    • Reintroduced a public send_kvcache method that now acts as a wrapper, calling the refactored _send_kvcache_generic with KV cache specific arguments.
    • Added a new maybe_send_extra method to NixlKVManager to handle conditional state transfers based on state_type ('mamba', 'swa', 'nsa', 'none').
    • Modified the add_transfer_request method to utilize the new maybe_send_extra method for processing state transfers, replacing the previous direct _send_mamba_state call.
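The refactor in the changelog follows a common wrapper pattern: the generalized core is buffer-agnostic, and the old public entry point becomes a thin delegate. A minimal sketch, with invented class name, attribute values, and return value (the actual conn.py signatures differ):

```python
class Manager:
    def __init__(self):
        # Placeholder stand-ins for the manager's registered KV buffers.
        self.kv_data_ptrs = [0x2000, 0x3000]
        self.kv_item_lens = [128, 128]

    def _send_kvcache_generic(self, src_data_ptrs, item_lens):
        # Generalized core: unaware of whether it is moving KV cache
        # or hybrid-state buffers; callers choose the pointer set.
        return list(zip(src_data_ptrs, item_lens))

    def send_kvcache(self):
        # Public wrapper keeps the original entry point, delegating
        # with KV-cache-specific arguments.
        return self._send_kvcache_generic(
            self.kv_data_ptrs, self.kv_item_lens
        )
```

This is why maybe_send_extra can reuse the same core with state_data_ptrs/state_item_lens: the transfer logic no longer hardcodes which buffer family it moves.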


@gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a bug in NIXL disaggregation for models using SWA and NSA state types. The refactoring of send_kvcache into a more generic _send_kvcache_generic is a solid approach that improves code reuse and maintainability. The new maybe_send_extra method correctly handles different state types, aligning the NIXL connector with the existing mooncake implementation. The changes are logical and well-executed. I have one minor suggestion to enhance code readability.

Comment on lines +758 to +762:

    if state_type != "none":
        raise RuntimeError(
            f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet."
        )
    return None

medium

For better readability, you could explicitly check for state_type == "none" and then handle unexpected state types in an else block. This makes the logic for handling unknown state types clearer.

Suggested change:

    # current:
    if state_type != "none":
        raise RuntimeError(
            f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet."
        )
    return None

    # suggested:
    if state_type == "none":
        return None
    else:
        raise RuntimeError(
            f"PD Disaggregation via NIXL does NOT support {state_type} hybrid models yet."
        )

@rainj-me
Collaborator

I have tested it; it works with the DS V3.2 NVFP4 model.

@hnyls2002
Collaborator

hnyls2002 commented Feb 17, 2026

/rerun-stage stage-c-test-8-gpu-h200

@sgl-project deleted a comment from github-actions bot, Feb 17, 2026
@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h200 to run independently (skipping dependencies).

@github-actions
Contributor

🔗 View workflow run

Collaborator

@ShangmingCai left a comment


LGTM

@ShangmingCai
Collaborator

/rerun-stage stage-c-test-8-gpu-h20

@github-actions
Contributor

✅ Triggered stage-c-test-8-gpu-h20 to run independently (skipping dependencies).

@github-actions
Contributor

🔗 View workflow run

@ShangmingCai merged commit ac0e493 into sgl-project:main, Feb 18, 2026
53 of 61 checks passed
