Kernel: optimize decoding metadata in NSA multi-spec backend with fused kernels #17554
Merged
Fridge003 merged 9 commits into sgl-project:main (Feb 14, 2026)
Conversation
Collaborator:
Can you please move the kernels to the jit folder, thanks~
Contributor (Author):
Done, and re-ran the performance and accuracy tests. Thanks @Fridge003 !
Review comments on python/sglang/jit_kernel/csrc/elementwise/fused_metadata_copy.cuh (outdated, resolved)
Collaborator:
/tag-and-rerun-ci
… parameters
Consolidate three separate CUDA kernels (decode, target_verify, draft_extend)
into a single unified kernel with runtime mode selection and structured
parameter passing.
Key changes:
- Merge fused_metadata_copy_{decode,target_verify,draft_extend}_kernel into
single fused_metadata_copy_kernel with runtime forward_mode branching
- Introduce structured parameter passing via SourcePointers, DestinationPointers,
FusedMetadataCopyParams, and FusedMetadataCopyMultiParams structs
- Use __grid_constant__ attribute for efficient constant memory parameter access
- Reduce parameter count from 29 individual parameters to 1 struct
- Maintain compile-time optimization via template parameters (HAS_REAL_PAGE_TABLE,
HAS_FLASHMLA)
- Update multi-backend kernel with same structured parameter pattern
Benefits:
- Code reduction: 233 fewer lines
- Eliminates code duplication across three nearly-identical kernels
- Improves maintainability: modify logic once instead of three times
- Cleaner API: structured parameters group related pointers logically
- Better alignment with other kernels (qknorm.cuh, rmsnorm.cuh patterns)
No performance regression expected as hot paths remain optimized via template
specialization and __grid_constant__ provides efficient parameter access.
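The structured-parameter consolidation described in this commit can be mirrored on the host/binding side. Below is a minimal Python sketch of the idea: one struct argument instead of 29 scalars, with runtime branching on the forward mode replacing three near-identical kernels. All names here echo the structs named in the commit (SourcePointers, DestinationPointers, FusedMetadataCopyParams), but the fields and the branching shown are illustrative assumptions, not the actual CUDA layout.

```python
from dataclasses import dataclass
from enum import Enum

class ForwardMode(Enum):
    DECODE = 0
    TARGET_VERIFY = 1
    DRAFT_EXTEND = 2

@dataclass
class SourcePointers:
    # Hypothetical grouping of source buffers; the real fields live in the CUDA struct.
    seqlens: list
    page_table: list

@dataclass
class DestinationPointers:
    seqlens: list
    page_table: list

@dataclass
class FusedMetadataCopyParams:
    # Single parameter object, analogous to the __grid_constant__ struct
    # passed to the unified kernel.
    src: SourcePointers
    dst: DestinationPointers
    mode: ForwardMode
    num_tokens: int

def fused_metadata_copy(params: FusedMetadataCopyParams) -> None:
    """Host-side stand-in for the unified kernel: one struct argument,
    runtime branching on forward mode instead of three separate kernels.
    Which buffers each mode copies is an illustrative assumption."""
    n = params.num_tokens
    params.dst.seqlens[:n] = params.src.seqlens[:n]
    if params.mode in (ForwardMode.TARGET_VERIFY, ForwardMode.DRAFT_EXTEND):
        params.dst.page_table[:n] = params.src.page_table[:n]
```

The design point is the call-site ergonomics: callers build one params object, and adding a field later changes the struct rather than every kernel signature.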
…used metadata copy
Replace throw statements with RuntimeCheck and add type-safe helper functions for tensor data pointer extraction. Switch from relying on torch.empty(0) null-pointer behavior to a proper tvm::ffi::Optional<TensorView> for optional tensors.
Changes:
- Add unwrap_data_ptr/unwrap_optional_data_ptr helper functions with integrated dtype validation using RuntimeCheck
- Update function signatures to use Optional<TensorView> for seqlens_expanded, real_page_table, and flashmla tensors
- Replace .data_ptr() null checks with .has_value()/.value() pattern
- Update Python wrapper to pass None directly instead of empty tensors
- Remove _make_empty_tensor_if_none helper (no longer needed)
Addresses review feedback from @DarkSharpness
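The optional-tensor change above can be sketched in a few lines. This is an illustrative Python analog of the unwrap_optional_data_ptr helper named in the commit, using a fake tensor type so the sketch is self-contained; the real helper operates on tvm::ffi::Optional<TensorView> in C++, and the exact signature is an assumption here.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeTensor:
    # Minimal stand-in for a tensor; only what the helper needs.
    dtype: str
    ptr: int

def unwrap_optional_data_ptr(t: Optional[FakeTensor], expected_dtype: str) -> int:
    """Illustrative analog of the C++ unwrap_optional_data_ptr helper:
    an absent tensor becomes a null pointer (0) instead of an allocated
    torch.empty(0), and a present tensor is dtype-checked before its
    data pointer is handed to the kernel."""
    if t is None:
        return 0  # maps to nullptr on the device side
    if t.dtype != expected_dtype:
        raise TypeError(f"expected {expected_dtype}, got {t.dtype}")
    return t.ptr
```

The benefit over the empty-tensor trick is that absence is explicit in the type, so callers cannot accidentally pass a zero-length tensor of the wrong dtype and have it silently treated as "missing".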
- Add linear_metadata.py for linear attention metadata handling
- Update communicator_nsa_cp.py for context parallel support
- Add test_nsa_pool_host_unit.py for HiCache NSA pool testing
- Update test_nsa_indexer.py with new test cases
- Update discover_metadata.rs for Rust gateway support
- Improve error messages in fused_metadata_copy.cuh
Collaborator:
Awesome job!
Motivation
Implement fused CUDA kernels to eliminate redundant metadata copies in Native Sparse Attention (NSA) backend during CUDA graph replay for speculative decoding. This optimization provides 3-5x speedup for multi-backend metadata operations.
Changes
Core Implementation
- fused_metadata_copy_cuda: single-backend fused kernel supporting DECODE, TARGET_VERIFY, and DRAFT_EXTEND forward modes
- fused_metadata_copy_multi_cuda: multi-backend kernel that copies metadata to 3 backends simultaneously in a single kernel launch
Runtime Optimization
Updated nsa_backend.py to intelligently use the fused kernels:
- speculative_num_steps >= 3: use the fused kernel for the first 3 backends, then copy remaining backends individually
- speculative_num_steps < 3: use individual copies (the fusion overhead is not worth it)
Testing
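The backend-selection logic described in the Changes section can be sketched as follows. The helper names (copy_metadata_for_backends, fused_copy_multi, individual_copy) are hypothetical stand-ins, not the actual functions in nsa_backend.py; only the threshold behavior comes from the PR description.

```python
from typing import Callable, List, Sequence

def copy_metadata_for_backends(
    backends: List[object],
    speculative_num_steps: int,
    fused_copy_multi: Callable[[Sequence[object]], None],
    individual_copy: Callable[[object], None],
) -> None:
    """Dispatch metadata copies: when there are at least 3 speculative
    steps, fuse the first 3 backends into one kernel launch and copy the
    rest individually; otherwise copy every backend individually."""
    FUSED_BACKENDS = 3  # the multi kernel writes to 3 backends per launch
    if speculative_num_steps >= FUSED_BACKENDS:
        fused_copy_multi(backends[:FUSED_BACKENDS])
        for backend in backends[FUSED_BACKENDS:]:
            individual_copy(backend)
    else:
        for backend in backends:
            individual_copy(backend)
```

The threshold exists because a fused launch only pays off once it replaces at least three separate copies; below that, the per-launch setup cost dominates.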
Performance Impact
CPU side
Before/after profiler traces were attached as screenshots (not preserved in this export).
GPU side
Before/after kernel timelines were attached as screenshots (not preserved in this export).
Performance Improvements
CPU-side metadata processing:
Kernel execution time:
End-to-end throughput:
Technical Details
The fused kernels handle:
Files Modified
- sgl-kernel/csrc/elementwise/fused_metadata_copy.cu: CUDA kernel implementation
- sgl-kernel/python/sgl_kernel/elementwise.py: Python bindings
- python/sglang/srt/layers/attention/nsa_backend.py: runtime integration
- sgl-kernel/tests/test_fused_metadata_copy.py: comprehensive tests
Tested: all 40 tests passing (100% pass rate)
Modifications
Accuracy Tests
Benchmark:
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319
Service:
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4
Eval:
python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 120000 --repeat 8 --thinking-mode deepseek-v3
Service (with parsers):
python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tool-call-parser deepseekv32 --reasoning-parser deepseek-v3
After the above tests, both kernels were exercised and every step was verified.
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci