[Fix] data race in req_to_token pool#17850

Merged
merrymercy merged 6 commits into main from csy/fix_req_to_pool on Feb 2, 2026
Conversation


@cctry cctry commented Jan 28, 2026

Motivation

A chunked prefill request frees its slot in req_to_token_pool and is allocated a slot again when preparing for its next prefill batch.

As a result, if a prefill batch contains multiple requests and req_to_token_pool is at capacity, the write of another request's matched KV indices can land in the slot of the chunked request while that slot is still being read on the forward stream.

Example

Prepare & launch prefill batch N:
    req A (first half) --> idx 1

model runner reads idx 1

Prepare batch N+1:
    req A (second half) --> idx 2
    req B --> idx 1

scheduler writes req B's matched indices to idx 1

Modifications

  1. alloc(reqs: list[Req]) - now takes the request list, sets req.req_pool_idx directly, and reuses the slot if it is already set. cc @hnyls2002
  2. Split mamba-state freeing out of free() into free_mamba_cache(req, ...) in HybridReqToTokenPool - frees only the mamba state, not the req slot. cc @hanming-lu @yizhang2077
  3. release_kv_cache() - now calls free(req) at the end; handles the early mamba-only free case.
  4. Removed the free() calls in process_prefill_chunk and cache_finished_req.
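The slot-reuse idea in modification 1 can be sketched as follows. This is a hypothetical minimal version, not the actual sglang implementation; the `Req` class and the pool internals are heavily simplified:

```python
# Minimal sketch (assumed names, not the real sglang code) of an
# alloc() that takes the request list and reuses an already-held slot,
# instead of the free-then-realloc pattern that caused the race.
from typing import List, Optional


class Req:
    def __init__(self, rid: str):
        self.rid = rid
        self.req_pool_idx: Optional[int] = None


class ReqToTokenPool:
    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, reqs: List[Req]) -> bool:
        # Only requests without a slot need a new index; a chunked
        # request keeps the slot from its previous prefill chunk, so
        # its row can never be handed to another request mid-read.
        need = [r for r in reqs if r.req_pool_idx is None]
        if len(need) > len(self.free_slots):
            return False  # pool at capacity
        for r in need:
            r.req_pool_idx = self.free_slots.pop()
        return True

    def free(self, req: Req) -> None:
        if req.req_pool_idx is not None:
            self.free_slots.append(req.req_pool_idx)
            req.req_pool_idx = None
```

In the example timeline above, req A would keep idx 1 across both chunks, so req B would be given a different slot rather than A's still-in-use row.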

Accuracy Tests

Benchmarking and Profiling

Checklist



cctry commented Jan 28, 2026

/tag-run-ci-label

@Henrry-CHEN

So if a prefill batch contains 2 or more requests or request chunks, is the mamba state for these requests incorrect?


cctry commented Jan 28, 2026 via email

@cctry cctry force-pushed the csy/fix_req_to_pool branch from 60ab814 to 30b9b41 on January 28, 2026 21:31
@merrymercy merrymercy merged commit 027f314 into main Feb 2, 2026
194 of 214 checks passed
@merrymercy merrymercy deleted the csy/fix_req_to_pool branch February 2, 2026 22:38
charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Feb 5, 2026
sfiisf pushed a commit to sfiisf/sglang that referenced this pull request Feb 5, 2026
@ClawSeven
Collaborator

Hi @cctry,
I have a question regarding the described race conditions. If we disable overlap scheduling, could there still be other write race conditions?


ClawSeven commented Feb 10, 2026

@cctry, Oh, I see. Since the kernel is launched asynchronously, even with overlap scheduling disabled, there could still be a race condition between write_cache_indices and the backend reading the KV caches.
Thank you for the fix—the PR not only resolves the chunked prefill issue but also addresses a potential problem in dLLM.
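The launch-order-versus-execution-order point above can be illustrated with a pure-Python analogy (hypothetical, no GPU involved): a "kernel" thread is launched first but only reads the shared row later, so a host-side overwrite issued after the launch is still visible to it.

```python
# Pure-Python analogy of the race: a kernel launched for request A
# ends up reading indices that the scheduler wrote for request B,
# because launch order is not execution order.
import threading

req_to_token = [[1, 1, 1, 1]]   # pool row 0, filled for request A
write_done = threading.Event()
observed = []


def forward_kernel():
    # Simulates the asynchronously executing kernel: it was launched
    # earlier, but only gets around to reading the row later.
    write_done.wait()
    observed.append(list(req_to_token[0]))


t = threading.Thread(target=forward_kernel)
t.start()                        # "launch" returns immediately
req_to_token[0] = [9, 9, 9, 9]   # scheduler recycles row 0 for request B
write_done.set()
t.join()
# observed now holds request B's indices, not request A's.
```

The real fix is of course not synchronization here but never recycling a slot whose row may still be read in flight.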

nvcastet added a commit to nvcastet/sglang that referenced this pull request Feb 13, 2026
…vent cross-stream data race

In overlap scheduling (MTPv2), `process_batch_result(N-1)` runs on the
default stream concurrently with `forward(N)` on the forward stream.
When a request finishes, `release_kv_cache` immediately returns its
`req_pool_idx` to the free list.  A new request can then recycle that
pool index and `prepare_for_decode` overwrites the `req_to_token` row on
the default stream while `forward(N)` still reads it — causing an
"index out of bounds" assertion in IndexKernel.cu.

Fix: defer the pool-index free by one overlap iteration.

- `ReqToTokenPool.deferred_free(req)`: withholds the pool index from
  the free list (the slot cannot be reallocated).
- `ReqToTokenPool.flush_deferred_frees()`: moves deferred slots back to
  the free list once the forward that read them has completed.
- `release_kv_cache(..., defer_pool_free=True)`: used in the decode
  result-processing path when overlap is enabled.
- `process_batch_result_decode`: flushes deferred frees right after
  `copy_done.synchronize()`, which guarantees the previous forward
  has finished reading `req_to_token`.

This is the overlap-scheduling counterpart of PR sgl-project#17850, which fixed the
same class of race for chunked prefill.
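The deferred-free mechanism described in this commit message can be sketched as follows (a hypothetical minimal version whose method names mirror the commit message, not the actual sglang code):

```python
# Sketch of a deferred pool-index free: a freed slot is withheld from
# the free list for one overlap iteration, so the in-flight forward
# can never observe its row being recycled.
from typing import List


class ReqToTokenPool:
    def __init__(self, size: int):
        self.free_slots: List[int] = list(range(size))
        self._deferred: List[int] = []

    def deferred_free(self, req_pool_idx: int) -> None:
        # The request no longer owns the index, but it must not be
        # reallocated while the previous forward may still read its
        # req_to_token row.
        self._deferred.append(req_pool_idx)

    def flush_deferred_frees(self) -> None:
        # Called right after copy_done.synchronize(), i.e. once the
        # previous forward is guaranteed to have finished reading
        # req_to_token; only now do the slots become allocatable.
        self.free_slots.extend(self._deferred)
        self._deferred.clear()
```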
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026