From abc9968910e30f73ea27375dba97fc7557485500 Mon Sep 17 00:00:00 2001
From: bigximik
Date: Wed, 15 Apr 2026 14:20:17 +0000
Subject: [PATCH] Fix vLLM coordinator crash for pure-recurrent placements
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When all layers use MambaSpec (e.g. all-GDN, all-KDA, or GDN+KDA only),
vLLM creates a single KV cache group and selects UnitaryKVCacheCoordinator,
which asserts hash_block_size == block_size. This fails because
hash_block_size = cache_config.block_size (16 on CUDA) while
MambaSpec.block_size = max_model_len (e.g. 4096).

Prefix caching is meaningless for pure recurrent state (can't share
recurrent state prefixes across requests), so fall back to
KVCacheCoordinatorNoPrefixCache for this case. Mixed configs with any
attention-type layer are unaffected and continue using the normal
coordinator with prefix caching.

Tested on SuperApriel-0.5b-Base with all 8 placement combinations:
all-attention, all-swa, all-gdn, all-kda, mixed-attn-gdn, mixed-swa-kda,
mixed-gdn-kda, mixed-all-4 — all pass.

Co-Authored-By: Claude Opus 4.6
---
 .../apriel2/vllm/modeling_apriel2.py | 56 +++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
index 6049f913a..cdc651f9f 100644
--- a/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
+++ b/fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
@@ -157,6 +157,62 @@ def _patched(kv_cache_spec: dict) -> list:
 _patch_kv_cache_grouping()
 
 
+# ---------------------------------------------------------------------------
+# Monkey-patch: fix coordinator selection for pure-recurrent configurations
+# ---------------------------------------------------------------------------
+# When ALL layers use MambaSpec (e.g. all-GDN or all-KDA placement), vLLM
+# creates a single KV cache group and selects UnitaryKVCacheCoordinator.
+# That coordinator asserts hash_block_size == block_size, but for MambaSpec
+# block_size = max_model_len (e.g. 4096) while hash_block_size comes from
+# cache_config.block_size (16 on CUDA). Prefix caching is meaningless for
+# pure recurrent state anyway (can't share state prefixes across requests),
+# so we fall back to KVCacheCoordinatorNoPrefixCache in this case.
+def _patch_coordinator_selection() -> None:
+    """Install the pure-recurrent coordinator fallback (idempotent)."""
+    import functools
+
+    import vllm.v1.core.kv_cache_coordinator as _coord
+
+    _original = _coord.get_kv_cache_coordinator
+    if getattr(_original, "_apriel2_coordinator_patch", False):
+        return  # already wrapped by an earlier import; never stack wrappers
+
+    @functools.wraps(_original)
+    def _patched(kv_cache_config, max_model_len, use_eagle, enable_caching,
+                 enable_kv_cache_events, dcp_world_size, pcp_world_size,
+                 hash_block_size, metrics_collector=None):
+        # Only inspect the groups when caching is on (keeps short-circuit).
+        groups = kv_cache_config.kv_cache_groups if enable_caching else ()
+        spec = groups[0].kv_cache_spec if len(groups) == 1 else None
+        # Lone MambaSpec group that would trip the Unitary coordinator assert.
+        if isinstance(spec, MambaSpec) and spec.block_size != hash_block_size:
+            apriel2_logger.info(
+                "Pure recurrent config detected (MambaSpec block_size=%d != "
+                "hash_block_size=%d). Disabling prefix caching for this "
+                "config (recurrent state is not prefix-shareable).",
+                spec.block_size, hash_block_size,
+            )
+            return _coord.KVCacheCoordinatorNoPrefixCache(
+                kv_cache_config, max_model_len, use_eagle,
+                enable_kv_cache_events,
+                dcp_world_size=dcp_world_size,
+                pcp_world_size=pcp_world_size,
+                hash_block_size=hash_block_size,
+                metrics_collector=metrics_collector,
+            )
+        return _original(
+            kv_cache_config, max_model_len, use_eagle, enable_caching,
+            enable_kv_cache_events, dcp_world_size, pcp_world_size,
+            hash_block_size, metrics_collector=metrics_collector,
+        )
+
+    _patched._apriel2_coordinator_patch = True
+    _coord.get_kv_cache_coordinator = _patched
+
+
+_patch_coordinator_selection()
+
+
 # =============================================================================
 # KV Cache Spec Computation
 # =============================================================================