
Kernel: optimize decoding metadata in NSA multi-spec backend with fused kernels#17554

Merged
Fridge003 merged 9 commits into sgl-project:main from Johnsonms:nsa-metadata-copy-kernel-v2 on Feb 14, 2026

Conversation

@Johnsonms (Contributor) commented Jan 22, 2026

Motivation

Implement fused CUDA kernels that eliminate redundant metadata copies in the Native Sparse Attention (NSA) backend during CUDA graph replay for speculative decoding. This optimization yields a 3-5x speedup for multi-backend metadata operations.

Changes

Core Implementation

  • Add fused_metadata_copy_cuda: Single-backend fused kernel supporting DECODE, TARGET_VERIFY, and DRAFT_EXTEND forward modes
  • Add fused_metadata_copy_multi_cuda: Multi-backend kernel that copies metadata to 3 backends simultaneously in a single kernel launch

Runtime Optimization

  • Update nsa_backend.py to use the fused kernels intelligently:
    • speculative_num_steps >= 3: use the fused kernel for the first 3 backends, then copy the remaining backends individually
    • speculative_num_steps < 3: use individual copies (the fusion overhead is not worth it)
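The dispatch rule above can be sketched in a few lines of Python (a minimal stand-in, not the actual nsa_backend.py code; `fused_copy` and `single_copy` are hypothetical stand-ins for the kernel calls):

```python
# Illustrative sketch of the dispatch described above; function names and
# signatures are hypothetical, not the real nsa_backend.py API.
def copy_metadata(backends, speculative_num_steps, fused_copy, single_copy):
    """Copy decode metadata to each per-step attention backend.

    fused_copy(backends) handles up to 3 backends in one kernel launch;
    single_copy(backend) copies one backend's metadata on its own.
    """
    if speculative_num_steps >= 3:
        fused_copy(backends[:3])        # one launch covers the first 3 backends
        for backend in backends[3:]:    # any remaining backends, one by one
            single_copy(backend)
    else:
        # With fewer than 3 backends, the fusion setup overhead outweighs
        # the saved launches, so copy individually.
        for backend in backends:
            single_copy(backend)
```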

Testing

  • Create a comprehensive test suite covering:
    • Single-backend kernel: all forward modes, optional tensors
    • Multi-backend kernel: 3-backend simultaneous copy
    • Performance benchmarks with timing and speedup measurements
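The central correctness check in a suite like this asserts that the fused path produces output identical to per-backend copies; a pure-Python sketch of that shape (stand-in helpers, not the actual test code):

```python
# Stand-ins for the two copy paths; names are illustrative, not the real
# sgl_kernel test helpers.
def individual_copy(src, dst):
    """Reference path: one copy per backend."""
    for key, value in src.items():
        dst[key] = list(value)

def fused_copy_stand_in(srcs, dsts):
    """Emulates the fused kernel: all backends handled in one call."""
    for src, dst in zip(srcs, dsts):
        individual_copy(src, dst)

def fused_matches_individual(srcs):
    """Fused and individual paths must produce identical metadata."""
    dsts_fused = [{} for _ in srcs]
    dsts_ref = [{} for _ in srcs]
    fused_copy_stand_in(srcs, dsts_fused)
    for src, dst in zip(srcs, dsts_ref):
        individual_copy(src, dst)
    return dsts_fused == dsts_ref
```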

Performance Impact

CPU side

Before and after CPU-side profiles (screenshots omitted).

GPU side

Before and after GPU-side profiles (screenshots omitted).

Performance Improvements

CPU-side metadata processing:

  • Before: 193μs (24 kernel launches)
  • After: 12μs (1 fused kernel launch)
  • Speedup: 16x faster (193μs → 12μs)

Kernel execution time:

  • Before: 177μs
  • After: 4.7μs
  • Speedup: 37.6x faster (177μs → 4.7μs)

End-to-end throughput:

  • Before: 151 tokens/sec
  • After: 179 tokens/sec
  • Improvement: +18.5% (+28 tokens/sec)
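The quoted ratios follow directly from the measured timings:

```python
# Sanity check of the speedup arithmetic quoted above.
cpu_before_us, cpu_after_us = 193, 12
gpu_before_us, gpu_after_us = 177, 4.7
tok_before, tok_after = 151, 179

cpu_speedup = cpu_before_us / cpu_after_us                      # ~16.1x ("16x")
kernel_speedup = gpu_before_us / gpu_after_us                   # ~37.7x ("37.6x")
throughput_gain = (tok_after - tok_before) / tok_before * 100   # ~18.5%
```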

Technical Details

The fused kernels handle:

  • Basic metadata: cache_seqlens, cu_seqlens_k, page_table, nsa_cache_seqlens
  • Optional tensors: real_page_table, flashmla_num_splits, flashmla_metadata
  • All copies are performed in a single kernel launch, reducing launch and memory-traffic overhead
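A plain-Python sketch of the copy layout the fused kernel performs in one launch (the field names follow the PR description; everything else is illustrative):

```python
# Emulates the single-launch copy: basic metadata always, optionals only
# when present. This is a stand-in, not the CUDA kernel itself.
BASIC_KEYS = ("cache_seqlens", "cu_seqlens_k", "page_table", "nsa_cache_seqlens")
OPTIONAL_KEYS = ("real_page_table", "flashmla_num_splits", "flashmla_metadata")

def fused_metadata_copy(src, dst):
    for key in BASIC_KEYS:                 # always copied
        dst[key] = list(src[key])
    for key in OPTIONAL_KEYS:              # skipped when absent or None
        if src.get(key) is not None:
            dst[key] = list(src[key])
    return dst
```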

Files Modified

  • sgl-kernel/csrc/elementwise/fused_metadata_copy.cu: CUDA kernel impl
  • sgl-kernel/python/sgl_kernel/elementwise.py: Python bindings
  • python/sglang/srt/layers/attention/nsa_backend.py: Runtime integration
  • sgl-kernel/tests/test_fused_metadata_copy.py: Comprehensive tests

Tested: All 40 tests passing (100% pass rate)

Modifications

Accuracy Tests

  1. Accuracy Test with gsm8k
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3.2-Exp \
  --trust-remote-code \
  --tp-size 8 --dp-size 8 --enable-dp-attention \
  --tool-call-parser deepseekv31 \
  --reasoning-parser deepseek-v3 \
  --chat-template ./examples/chat_template/tool_chat_template_deepseekv32.jinja

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319

  2. Accuracy Test with gpqa-diamond
    Service:
    python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4

python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 120000 --repeat 8 --thinking-mode deepseek-v3

  3. Accuracy Test with aime 2025
    python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tool-call-parser deepseekv32 --reasoning-parser deepseek-v3
#!/bin/bash
export NEMO_SKILLS_DISABLE_UNCOMMITTED_CHANGES_CHECK=1

ns prepare_data aime25

PORT=30000
BACKEND=sglang
MODEL="deepseek-ai/DeepSeek-V3.2-Exp" # Should be changed to the model name
MODEL_NAME="dsv32-fp8"

echo "Starting AIME25 evaluation with model $MODEL on port $PORT using backend $BACKEND..."
ns eval \
  --benchmarks=aime25:4 \
  --server_type=$BACKEND \
  --model=$MODEL \
  --server_address=http://localhost:${PORT}/v1 \
  --output_dir=nemo_skills_aime25_${MODEL_NAME}_output_${BACKEND}_$(date +%Y%m%d_%H%M%S) \
  ++chat_template_kwargs.thinking=true \
  ++inference.temperature=1.0 \
  ++inference.top_p=0.95 \
  ++inference.tokens_to_generate=64000
  # ++inference.tokens_to_generate=120000 for Speciale model
  4. Correctness
    After running the above tests with both kernels, every step was verified.

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@Fridge003 (Collaborator)

Can you please move the kernels to the jit folder, thanks~

@Johnsonms force-pushed the nsa-metadata-copy-kernel-v2 branch from f1744cc to b59ce0b on January 26, 2026 19:17
@Johnsonms force-pushed the nsa-metadata-copy-kernel-v2 branch from 8533df9 to f9cd4d8 on January 26, 2026 23:38
@Johnsonms (Contributor, Author) commented Jan 27, 2026

Re-ran the accuracy tests.

Accuracy Tests

  1. Accuracy Test with gsm8k
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3.2-Exp \
  --trust-remote-code \
  --tp-size 8 --dp-size 8 --enable-dp-attention \
  --tool-call-parser deepseekv31 \
  --reasoning-parser deepseek-v3 \
  --chat-template ./examples/chat_template/tool_chat_template_deepseekv32.jinja

python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319

  2. Accuracy Test with gpqa-diamond
    Service:
    python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4
    python3 -m sglang.test.run_eval --port 30000 --eval-name gpqa --num-examples 198 --max-tokens 120000 --repeat 8 --thinking-mode deepseek-v3
  3. Accuracy Test with aime 2025
    python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --dp 8 --enable-dp-attention --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --tool-call-parser deepseekv32 --reasoning-parser deepseek-v3
#!/bin/bash
export NEMO_SKILLS_DISABLE_UNCOMMITTED_CHANGES_CHECK=1

ns prepare_data aime25

PORT=30000
BACKEND=sglang
MODEL="deepseek-ai/DeepSeek-V3.2-Exp" # Should be changed to the model name
MODEL_NAME="dsv32-fp8"

echo "Starting AIME25 evaluation with model $MODEL on port $PORT using backend $BACKEND..."
ns eval \
  --benchmarks=aime25:4 \
  --server_type=$BACKEND \
  --model=$MODEL \
  --server_address=http://localhost:${PORT}/v1 \
  --output_dir=nemo_skills_aime25_${MODEL_NAME}_output_${BACKEND}_$(date +%Y%m%d_%H%M%S) \
  ++chat_template_kwargs.thinking=true \
  ++inference.temperature=1.0 \
  ++inference.top_p=0.95 \
  ++inference.tokens_to_generate=64000
  # ++inference.tokens_to_generate=120000 for Speciale model

@Johnsonms force-pushed the nsa-metadata-copy-kernel-v2 branch from d920fe3 to 094323f on January 27, 2026 23:36
@Johnsonms (Contributor, Author)

> Can you please move the kernels to jit folder, thanks~

Done, and re-ran the accuracy tests. Thanks @Fridge003!

@Johnsonms force-pushed the nsa-metadata-copy-kernel-v2 branch from af349b7 to ceec97a on January 28, 2026 18:50
@DarkSharpness (Collaborator)

/tag-and-rerun-ci

@github-actions bot added labels: documentation, quant, amd, dependencies, lora, Multi-modal, deepseek, speculative-decoding, hicache, blackwell, npu, deterministic, piecewise-cuda-graph, diffusion, model-gateway, mthreads on Feb 14, 2026
… parameters

Consolidate three separate CUDA kernels (decode, target_verify, draft_extend)
into a single unified kernel with runtime mode selection and structured
parameter passing.

Key changes:
- Merge fused_metadata_copy_{decode,target_verify,draft_extend}_kernel into
  single fused_metadata_copy_kernel with runtime forward_mode branching
- Introduce structured parameter passing via SourcePointers, DestinationPointers,
  FusedMetadataCopyParams, and FusedMetadataCopyMultiParams structs
- Use __grid_constant__ attribute for efficient constant memory parameter access
- Reduce parameter count from 29 individual parameters to 1 struct
- Maintain compile-time optimization via template parameters (HAS_REAL_PAGE_TABLE,
  HAS_FLASHMLA)
- Update multi-backend kernel with same structured parameter pattern

Benefits:
- Code reduction: 233 fewer lines
- Eliminates code duplication across three nearly-identical kernels
- Improves maintainability: modify logic once instead of three times
- Cleaner API: structured parameters group related pointers logically
- Better alignment with other kernels (qknorm.cuh, rmsnorm.cuh patterns)

No performance regression expected as hot paths remain optimized via template
specialization and __grid_constant__ provides efficient parameter access.
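A Python analogue of the structured-parameter refactor this commit describes (the struct names mirror the commit message; the fields and the launch shim are illustrative):

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class SourcePointers:
    # A few representative fields; the real struct groups device pointers.
    cache_seqlens: List[int]
    page_table: List[int]
    real_page_table: Optional[List[int]] = None  # optional, like HAS_REAL_PAGE_TABLE

@dataclass
class FusedMetadataCopyParams:
    src: SourcePointers
    forward_mode: str  # "decode" | "target_verify" | "draft_extend"

def launch(params: FusedMetadataCopyParams):
    # One structured argument replaces 29 loose parameters.
    assert params.forward_mode in ("decode", "target_verify", "draft_extend")
    return params.src.cache_seqlens
```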
…used metadata copy

Replace throw statements with RuntimeCheck and add type-safe helper functions
for tensor data pointer extraction. Switch from relying on torch.empty(0) null
pointer behavior to proper tvm::ffi::Optional<TensorView> for optional tensors.

Changes:
- Add unwrap_data_ptr/unwrap_optional_data_ptr helper functions with integrated
  dtype validation using RuntimeCheck
- Update function signatures to use Optional<TensorView> for seqlens_expanded,
  real_page_table, and flashmla tensors
- Replace .data_ptr() null checks with .has_value()/.value() pattern
- Update Python wrapper to pass None directly instead of empty tensors
- Remove _make_empty_tensor_if_none helper (no longer needed)

Addresses review feedback from @DarkSharpness
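The calling-convention change described here can be illustrated with a small sketch (hypothetical wrapper names; the real binding lives in sgl_kernel):

```python
def fused_copy_old(kernel, required, real_page_table=None):
    # Old pattern: fabricate an empty sentinel tensor for absent inputs.
    sentinel = real_page_table if real_page_table is not None else []
    return kernel(required, sentinel)

def fused_copy_new(kernel, required, real_page_table=None):
    # New pattern: forward None as-is; the C++ side receives
    # tvm::ffi::Optional<TensorView> and checks has_value().
    return kernel(required, real_page_table)
```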
- Add linear_metadata.py for linear attention metadata handling
- Update communicator_nsa_cp.py for context parallel support
- Add test_nsa_pool_host_unit.py for HiCache NSA pool testing
- Update test_nsa_indexer.py with new test cases
- Update discover_metadata.rs for Rust gateway support
- Improve error messages in fused_metadata_copy.cuh
@Johnsonms force-pushed the nsa-metadata-copy-kernel-v2 branch from 3e7e878 to 50b1638 on February 14, 2026 02:09
@Fridge003 Fridge003 merged commit 34132d6 into sgl-project:main Feb 14, 2026
82 of 94 checks passed
@yuan-luo (Collaborator)

Awesome job!
