[bugfix] fix mamba slot leak when scheduling fails with radix cache (#15840) #16067

Merged
ispobock merged 8 commits into sgl-project:main from kuafou:fix/mamba-radix-cache-memory-leak on Feb 13, 2026
Conversation

@kuafou
Contributor

@kuafou kuafou commented Dec 29, 2025

Motivation

This PR fixes a server crash caused by a false positive "memory leak detected" error when serving Mamba models with radix-cache enabled.

The issue was discovered during the investigation of #15840 but is a distinct issue regarding the scheduler's failure path cleanup. It occurs when the scheduler is temporarily resource-blocked (e.g., during request retraction), causing check_memory() to flag a temporarily untracked Mamba slot as leaked.

Modifications

  • Explicitly release the COW-allocated Mamba slot if add_one_req fails (returns NO_TOKEN or IDLE) during the scheduling phase.
  • Add a regression test test_mamba_slot_release_after_match_prefix_cow to verify the cleanup logic.

Accuracy Tests

N/A — resource cleanup only, no model output changes.

Benchmarking and Profiling

N/A — failure-path fix, no performance impact.

Checklist

Details

The Problem:

  1. In init_next_round_input(), a Mamba slot is allocated via Copy-On-Write (COW) / match_prefix.
  2. Subsequently, add_one_req() is called but returns NO_TOKEN (e.g., because the KV cache pool is full).
  3. The scheduler enters an "Idle" state (which technically means "resource-blocked" in this context) and triggers self_check_during_idle() -> check_memory().
  4. check_memory() fails because the slot allocated in step 1 is:
    • Not in free_slots (it was allocated).
    • Not yet tracked by tree_cache (the request didn't finish scheduling).
    • Result: The server crashes with ValueError: token_to_kv_pool_allocator memory leak detected!.

The Fix:
When scheduling fails, we now explicitly free the temporary Mamba slot. The request remains in the waiting_queue and will safely re-allocate (and copy) the Mamba state in the next scheduling round.
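
The failure path and its cleanup can be sketched as follows. This is an illustrative reconstruction only: `AddReqResult`, `MambaPool`, and `try_schedule` are simplified stand-ins, not sglang's actual scheduler classes.

```python
# Illustrative sketch of the cleanup described above. AddReqResult, MambaPool,
# and try_schedule are simplified stand-ins, not sglang's real classes.
from enum import Enum, auto
from types import SimpleNamespace


class AddReqResult(Enum):
    CONTINUE = auto()  # request was scheduled
    NO_TOKEN = auto()  # KV cache pool is full; scheduling bails out


class MambaPool:
    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))

    def alloc(self) -> int:
        return self.free_slots.pop()

    def free(self, slot: int) -> None:
        self.free_slots.append(slot)


def try_schedule(req, pool: MambaPool, add_one_req) -> AddReqResult:
    # Step 1: match_prefix() allocates a Mamba slot via COW.
    req.mamba_pool_idx = pool.alloc()
    result = add_one_req(req)
    if result != AddReqResult.CONTINUE:
        # The fix: release the temporary slot so check_memory() sees a
        # consistent pool. The request stays in the waiting queue and will
        # re-allocate (and copy) the Mamba state in the next round.
        pool.free(req.mamba_pool_idx)
        req.mamba_pool_idx = None
    return result
```

With the release in place, a NO_TOKEN failure leaves the pool exactly as it was before the attempt, so the idle-time memory check passes.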

Error Log

The server crashes with the following error during high load/retraction:

ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}

@gemini-code-assist
Contributor

Summary of Changes

Hello @kuafou, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical GPU memory leak identified when serving Mamba models with radix-cache, particularly affecting Qwen3-Next. The leak occurred because Mamba slots allocated through Copy-On-Write were not being released if a subsequent scheduling attempt failed. The changes introduce a mechanism to ensure these slots are properly deallocated in such failure paths, accompanied by a new regression test to validate the fix and prevent future regressions.

Highlights

  • Memory Leak Fix: Addresses a GPU memory leak that occurred when scheduling requests failed after a Mamba slot was allocated via Copy-On-Write (COW) with radix-cache enabled.
  • Mamba Slot Release: Implemented logic in the scheduler to explicitly free COW-allocated Mamba slots if the add_one_req operation does not succeed, ensuring proper resource deallocation.
  • Regression Test: Added a new unit test (test_mamba_slot_release_after_match_prefix_cow) to verify the correct cleanup and release of Mamba slots in scenarios where scheduling fails after a COW allocation.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a GPU memory leak in Mamba-based models that occurred when scheduling failed after a Copy-On-Write (COW) operation in the radix cache. The fix is implemented by ensuring that any Mamba slot allocated during prefix matching is properly released if the request subsequently fails to be scheduled. The logic is sound and correctly placed within the scheduler's prefill logic. Additionally, a new regression test has been added to specifically cover this cleanup path, verifying that the allocated slot is indeed returned to the pool. The changes are clear, targeted, and improve the robustness of the memory management for hybrid SSM models.

@kuafou kuafou force-pushed the fix/mamba-radix-cache-memory-leak branch from 43c6362 to 91e2022 Compare December 29, 2025 10:27
@kuafou kuafou changed the title Fix mamba slot leak when scheduling fails with radix cache (#15840) [bugfix] fix mamba slot leak when scheduling fails with radix cache (#15840) Dec 29, 2025
@kuafou kuafou force-pushed the fix/mamba-radix-cache-memory-leak branch 2 times, most recently from 2564440 to 8028fe1 Compare December 31, 2025 01:33
@yizhang2077
Collaborator

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Jan 6, 2026
@kuafou kuafou force-pushed the fix/mamba-radix-cache-memory-leak branch from a451b1c to 2f8819d Compare January 6, 2026 08:48
@kuafou
Contributor Author

kuafou commented Jan 7, 2026

Looks like one of the tests failed — I’m looking into it now and will push a fix soon.

@yizhang2077
Collaborator

Let me think: without this PR it will not cause an actual Mamba slot leak, but it will cause check_memory to fail. Do I understand correctly?

@kuafou
Contributor Author

kuafou commented Jan 7, 2026

Let me think: without this PR it will not cause an actual Mamba slot leak, but it will cause check_memory to fail. Do I understand correctly?

Yes.

Without this PR, it’s not necessarily a real long-term “slot leak” in the sense of “lost forever,” but it does create an untracked allocated mamba slot in the failure path (e.g., add_one_req() returns NO_TOKEN and scheduling bails out). That slot is:

  • not in mamba_pool.free_slots (because it’s already allocated), and
  • not in tree_cache’s cached mamba values (because the request didn’t finish and the state wasn’t inserted),

so check_memory() computes it as expected - free - cached and flags it as leaked_mamba_pages, which can trigger the “memory leak detected” crash.

That’s why the fix explicitly frees the mamba slot when scheduling fails. The request stays in waiting_queue, and on the next scheduling round it will go through match_prefix() again and safely re-allocate (and copy) the mamba state. Keeping the old mamba_pool_idx would require teaching check_memory to account for “waiting requests holding slots,” which adds complexity for very little gain (since the copy_from still happens either way).
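
The expected - free - cached accounting can be illustrated with a small sketch. The function name and signature below are hypothetical simplifications: sglang's check_memory() works on pool sizes and tree-cache bookkeeping, not raw slot sets.

```python
# Hypothetical simplification of the leak check described above; not
# sglang's actual check_memory() implementation.
def find_leaked_mamba_slots(pool_size, free_slots, cached_slots, protected_slots=()):
    """Flag any slot that is neither free, nor cached in the radix tree,
    nor protected by an in-flight request."""
    tracked = set(free_slots) | set(cached_slots) | set(protected_slots)
    return set(range(pool_size)) - tracked
```

In the error log, 315 free + 138 evictable + 0 protected accounts for only 453 of the 454 slots, so exactly one slot (index 191) shows up in leaked_mamba_pages.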

@yizhang2077
Collaborator

yizhang2077 commented Jan 7, 2026


I see. I think it will not lose this slot forever, since the request's mamba_pool_idx is not None and the slot will be tracked again when the extend runs. But it does cause check_memory to fail.

@kuafou
Contributor Author

kuafou commented Jan 7, 2026


Yes.

@kuafou kuafou force-pushed the fix/mamba-radix-cache-memory-leak branch from 8dca3fd to 76c6a90 Compare January 7, 2026 04:58
Collaborator

@hanming-lu hanming-lu left a comment


Also I don't see the relationship between #15840 and this PR. #15840 is leaking runtime memory, this PR doesn't touch runtime memory.

@kuafou
Contributor Author

kuafou commented Jan 7, 2026

Also I don't see the relationship between #15840 and this PR. #15840 is leaking runtime memory, this PR doesn't touch runtime memory.

Thanks for the clarification.

You’re right — #15840 refers to a runtime GPU memory leak, and this PR does not directly address that issue. The only relationship is that this bug was found while investigating #15840 locally.

This PR fixes a separate scheduler failure-path problem: when scheduling fails, a COW-allocated mamba slot is left temporarily untracked, which can cause check_memory() to report a leak and crash the server.

If it would be clearer, I’m happy to update the PR title and/or Motivation to remove the direct reference to #15840 and instead describe it as “discovered during investigation of #15840.” Please let me know what you prefer.

@MegaGriffin


Hi @kuafou, thanks for your insight. Did your local test crash in check_memory() with the "token_to_kv_pool_allocator memory leak" report, or did it crash with "req_to_token_pool memory leak detected!"?

@kuafou
Contributor Author

kuafou commented Jan 12, 2026


Hi, I have the logs from the test and will post them shortly. In the meantime, could you share exactly what the original error message/report was on your side?

@kuafou
Contributor Author

kuafou commented Jan 12, 2026

@MegaGriffin

Hi, here are my test scripts and error log. Hope this helps!

Server Side Scripts

export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
export MODEL_PATH=/mnt/nvme0n1/home/jiaqi/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct
export MODEL_LENGTH=262144
export TP_SIZE=8
export SERVED_NAME=test_serving

export COMMON_LAUNCH_ARGS="--trust-remote-code \
    --model-path $MODEL_PATH \
    --context-length ${MODEL_LENGTH} \
    --tp-size ${TP_SIZE} \
    --served-model-name ${SERVED_NAME} \
    --host 0.0.0.0 \
    --grammar-backend outlines \
    --enable-memory-saver \
    --allow-auto-truncate \
    --sampling-backend flashinfer \
    --attention-backend triton \
    --disable-overlap-schedule \
    --mem-fraction-static 0.9 \
    --enable-metrics \
    --log-level info "

    # --attention-backend flashinfer \
python -m sglang.launch_server ${COMMON_LAUNCH_ARGS}

Benchmark Script

python3 -m sglang.bench_serving \
    --model /mnt/nvme0n1/home/jiaqi/.cache/modelscope/hub/models/Qwen/Qwen3-Next-80B-A3B-Instruct \
    --backend sglang \
    --host 127.0.0.1 \
    --port 30000 \
    --warmup-requests 0 \
    --dataset-name generated-shared-prefix \
    --gsp-num-groups 1 \
    --gsp-prompts-per-group 50 \
    --gsp-system-prompt-len 150000 \
    --gsp-question-len 400 \
    --gsp-output-len 8192

Error Log

[2026-01-12 10:57:27 TP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP6] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP3] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP4] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP2] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP5] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP7] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 7819, #new_token_ratio: 0.6485 -> 0.9570
[2026-01-12 10:57:27 TP0] Decode batch, #running-req: 28, #full token: 388852, full token usage: 0.98, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 673.12, #queue-req: 22, 
[2026-01-12 10:57:29 TP0] Decode batch, #running-req: 28, #full token: 389972, full token usage: 0.98, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 669.53, #queue-req: 22, 
[2026-01-12 10:57:31 TP0] Decode batch, #running-req: 28, #full token: 391092, full token usage: 0.99, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 673.28, #queue-req: 22, 
[2026-01-12 10:57:32 TP0] Decode batch, #running-req: 28, #full token: 392212, full token usage: 0.99, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 673.00, #queue-req: 22, 
[2026-01-12 10:57:34 TP0] Decode batch, #running-req: 28, #full token: 393332, full token usage: 0.99, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 671.93, #queue-req: 22, 
[2026-01-12 10:57:36 TP0] Decode batch, #running-req: 28, #full token: 394452, full token usage: 0.99, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 670.30, #queue-req: 22, 
[2026-01-12 10:57:37 TP0] Decode batch, #running-req: 28, #full token: 395572, full token usage: 1.00, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 668.98, #queue-req: 22, 
[2026-01-12 10:57:39 TP0] Decode batch, #running-req: 28, #full token: 396692, full token usage: 1.00, mamba num: 57, mamba usage: 0.13, cuda graph: True, gen throughput (token/s): 668.86, #queue-req: 22, 
[2026-01-12 10:57:39 TP0] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:39 TP2] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:39 TP4] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:39 TP5] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:39 TP7] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:39 TP1] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:39 TP6] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:39 TP3] KV cache pool is full. Retract requests. #retracted_reqs: 1, #new_tokens_gained: 8113, #new_token_ratio: 0.6631 -> 0.9929
[2026-01-12 10:57:41 TP0] Decode batch, #running-req: 27, #full token: 389237, full token usage: 0.98, mamba num: 55, mamba usage: 0.12, cuda graph: True, gen throughput (token/s): 631.80, #queue-req: 23, 
[2026-01-12 10:57:43 TP2] Scheduler hit an exception: Traceback (most recent call last):
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 2996, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 1105, in event_loop_normal
    self.self_check_during_idle()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 330, in self_check_during_idle
    self.check_memory()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
    raise_error_or_warn(
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/utils/common.py", line 3774, in raise_error_or_warn
    raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}


[2026-01-12 10:57:43] SIGQUIT received. signum=None, frame=None. It usually means one child failed.
[2026-01-12 10:57:43 TP1] Scheduler hit an exception: Traceback (most recent call last):
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 2996, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 1105, in event_loop_normal
    self.self_check_during_idle()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 330, in self_check_during_idle
    self.check_memory()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
    raise_error_or_warn(
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/utils/common.py", line 3774, in raise_error_or_warn
    raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}


[2026-01-12 10:57:43 TP5] Scheduler hit an exception: Traceback (most recent call last):
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 2996, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 1105, in event_loop_normal
    self.self_check_during_idle()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 330, in self_check_during_idle
    self.check_memory()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
    raise_error_or_warn(
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/utils/common.py", line 3774, in raise_error_or_warn
    raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}


[2026-01-12 10:57:43 TP6] Scheduler hit an exception: Traceback (most recent call last):
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 2996, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 1105, in event_loop_normal
    self.self_check_during_idle()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 330, in self_check_during_idle
    self.check_memory()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
    raise_error_or_warn(
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/utils/common.py", line 3774, in raise_error_or_warn
    raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}


[2026-01-12 10:57:43 TP3] Scheduler hit an exception: Traceback (most recent call last):
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 2996, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 1105, in event_loop_normal
    self.self_check_during_idle()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 330, in self_check_during_idle
    self.check_memory()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
    raise_error_or_warn(
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/utils/common.py", line 3774, in raise_error_or_warn
    raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}


[2026-01-12 10:57:43 TP4] Scheduler hit an exception: Traceback (most recent call last):
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 2996, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 1105, in event_loop_normal
    self.self_check_during_idle()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 330, in self_check_during_idle
    self.check_memory()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
    raise_error_or_warn(
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/utils/common.py", line 3774, in raise_error_or_warn
    raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}


[2026-01-12 10:57:43 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 2996, in run_scheduler_process
    scheduler.event_loop_normal()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler.py", line 1105, in event_loop_normal
    self.self_check_during_idle()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 330, in self_check_during_idle
    self.check_memory()
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py", line 244, in check_memory
    raise_error_or_warn(
  File "/mnt/nvme0n1/home/jiaqi/workspaces/opensources/sglang/python/sglang/srt/utils/common.py", line 3774, in raise_error_or_warn
    raise ValueError(message)
ValueError: token_to_kv_pool_allocator memory leak detected! full_available_size=6016, full_evictable_size=390685, self.token_to_kv_pool_allocator.size=396701, self.tree_cache.full_protected_size()=0
mamba_available_size=315, mamba_evictable_size=138, self.req_to_token_pool.mamba_pool.size=454, self.tree_cache.mamba_protected_size()=0, leaked_full_pages=None, leaked_mamba_pages={191}


[2026-01-12 10:57:43 TP7] Scheduler hit an exception: (identical traceback and memory-leak report as TP0 above)

./outputs/run-mml.sh: line 24: 3557732 Killed                  python -m sglang.launch_server ${COMMON_LAUNCH_ARGS}
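The leak report above follows from a simple accounting invariant: every mamba slot must be either free, evictable from the radix tree, or protected by an in-flight request. A minimal sketch of that check, with hypothetical names (not SGLang's actual implementation):

```python
# Hypothetical sketch of the invariant the leak check enforces (names are
# assumptions, not SGLang's actual implementation): every mamba slot must be
# free, evictable from the radix tree, or protected by an in-flight request.
def find_leaked_slots(pool_size, free_slots, evictable, protected):
    tracked = set(free_slots) | set(evictable) | set(protected)
    return {s for s in range(pool_size) if s not in tracked}

# A slot allocated by COW but then dropped by a failed add_one_req() lands in
# no set, so the check reports it as leaked (like slot 191 in the logs above):
print(find_leaked_slots(8, free_slots=[0, 1, 2], evictable=[3, 4], protected=[5, 6]))
# -> {7}
```

In the crash above, slot 191 was in exactly this state: allocated but not yet tracked by the tree cache, so it showed up in `leaked_mamba_pages`.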

Collaborator

@hanming-lu hanming-lu left a comment

This is a mitigation for a case where the scheduler incorrectly thinks the server is idle, triggering the memory check. From the server logs, it appears to happen during retractions.

Please update the description, thanks.
#16067 (comment)

@kuafou
Contributor Author

kuafou commented Jan 12, 2026

This is a mitigation for a case where the scheduler incorrectly thinks the server is idle, triggering the memory check. From the server logs, it appears to happen during retractions.

Please update the description, thanks. #16067 (comment)

Done! I've updated the description and title to reflect the correct cause. Thanks for the review!

@kuafou
Contributor Author

kuafou commented Jan 13, 2026

Hi @yizhang2077

The CI failures appear unrelated to my changes:

  1. backend-4-gpu0: Flaky Qwen3-Next-80B test (KL div 0.0097 > threshold 0.008).
  2. backend-2-gpu1: "Memory capacity unbalanced" error in mooncake test, likely an env issue or backend-specific behavior.

Could you please re-run? Thanks!

@MegaGriffin

MegaGriffin commented Jan 14, 2026

Also I don't see the relationship between #15840 and this PR. #15840 is leaking runtime memory, this PR doesn't touch runtime memory.

Thanks for the clarification.
You’re right — #15840 refers to a runtime GPU memory leak, and this PR does not directly address that issue. The only relationship is that this bug was found while investigating #15840 locally.
This PR fixes a separate scheduler failure-path problem: when scheduling fails, a COW-allocated mamba slot is left temporarily untracked, which can cause check_memory() to report a leak and crash the server.
If it would be clearer, I’m happy to update the PR title and/or Motivation to remove the direct reference to #15840 and instead describe it as “discovered during investigation of #15840.” Please let me know what you prefer.

Hi @kuafou, thanks for your insight. Did your local test crash in the check_memory() function with a "token_to_kv_pool_allocator memory leak" report, or did it crash with "req_to_token_pool memory leak detected!"?

Hi, I have the logs from the test and will post them shortly. In the meantime, could you share exactly what the original error message/report was on your side?

Hi @kuafou, for the issue you mentioned in this PR, I also noticed it in my local tests. Thanks for fixing it. The exact issue I hit occurs when the qwen3next model receives requests from certain coding datasets or structured-format inputs. I have updated the original issue #15840 with dummy inputs; you can observe GPU memory growing with those inputs.

@yizhang2077
Collaborator

/rerun-failed-ci

kuafou and others added 8 commits February 12, 2026 05:01
…ct#15840)

When add_one_req fails after init_next_round_input has allocated a mamba slot
via COW (copy-on-write) during match_prefix, the slot was not released,
causing a memory leak.

This fix releases the mamba slot when scheduling fails.
COW (copy-on-write) mamba slot allocation only happens in MambaRadixCache.
Add an isinstance check so the slot cleanup only runs when using MambaRadixCache,
following the pattern in scheduler_runtime_checker_mixin.py.
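The cleanup described in these commit messages can be sketched as follows. All names here (MambaPool, AddReqResult, req["mamba_slot"], schedule_one) are illustrative stand-ins, not SGLang's exact API:

```python
# Minimal sketch of the failure-path cleanup, under assumed names (MambaPool,
# AddReqResult, req["mamba_slot"] are not SGLang's actual identifiers):
from enum import Enum, auto

class AddReqResult(Enum):
    CONTINUE = auto()   # request admitted to the batch
    NO_TOKEN = auto()   # KV cache pool full; scheduling failed

class MambaRadixCache:  # stand-in for the real radix cache class
    pass

class MambaPool:
    def __init__(self, size):
        self.free_slots = list(range(size))

    def alloc(self):
        return self.free_slots.pop()

    def free(self, slot):
        self.free_slots.append(slot)

def schedule_one(req, add_one_req, tree_cache, mamba_pool):
    result = add_one_req(req)
    if result != AddReqResult.CONTINUE and isinstance(tree_cache, MambaRadixCache):
        # Failure path: the COW slot allocated in init_next_round_input() is
        # not yet tracked by the tree cache, so release it explicitly to keep
        # the check_memory() accounting consistent.
        if req.get("mamba_slot") is not None:
            mamba_pool.free(req["mamba_slot"])
            req["mamba_slot"] = None
    return result

pool = MambaPool(4)
req = {"mamba_slot": pool.alloc()}   # COW allocation in init_next_round_input()
schedule_one(req, lambda r: AddReqResult.NO_TOKEN, MambaRadixCache(), pool)
assert len(pool.free_slots) == 4     # slot returned; no false leak report
```

The isinstance guard mirrors the point made above: only MambaRadixCache performs COW slot allocation, so other cache types must not run this cleanup.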
@yizhang2077 yizhang2077 force-pushed the fix/mamba-radix-cache-memory-leak branch from 76c6a90 to 3483615 Compare February 12, 2026 05:13
@yizhang2077
Collaborator

/rerun-failed-ci

@ispobock ispobock merged commit 4c6afbe into sgl-project:main Feb 13, 2026
231 of 255 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
@kuafou kuafou deleted the fix/mamba-radix-cache-memory-leak branch February 19, 2026 19:42