_patch_kv_cache_grouping()


# ---------------------------------------------------------------------------
# Monkey-patch: fix coordinator selection for pure-recurrent configurations
# ---------------------------------------------------------------------------
# If every layer uses MambaSpec (e.g. all-GDN or all-KDA placement), vLLM
# builds a single KV cache group and picks UnitaryKVCacheCoordinator. That
# coordinator asserts hash_block_size == block_size; for MambaSpec, however,
# block_size is max_model_len (e.g. 4096) whereas hash_block_size is taken
# from cache_config.block_size (16 on CUDA), so the assertion trips.
# Prefix caching is meaningless for pure recurrent state anyway (state
# prefixes cannot be shared across requests), so that case is routed to
# KVCacheCoordinatorNoPrefixCache instead.
def _patch_coordinator_selection() -> None:
    """Wrap vLLM's ``get_kv_cache_coordinator`` with a pure-recurrent check.

    The wrapper is installed by rebinding the ``get_kv_cache_coordinator``
    attribute of ``vllm.v1.core.kv_cache_coordinator``. Every call that does
    not match the pure-recurrent pattern is forwarded unchanged to the
    function that was bound there before this patch ran.

    NOTE(review): only the module attribute is rebound; code that already
    did ``from vllm.v1.core.kv_cache_coordinator import
    get_kv_cache_coordinator`` before this runs would presumably keep the
    original function -- confirm the import order.
    """
    import vllm.v1.core.kv_cache_coordinator as coordinator_module

    # Capture the current factory so unaffected configurations can delegate.
    upstream_factory = coordinator_module.get_kv_cache_coordinator

    def _select_coordinator(
        kv_cache_config,
        max_model_len,
        use_eagle,
        enable_caching,
        enable_kv_cache_events,
        dcp_world_size,
        pcp_world_size,
        hash_block_size,
        metrics_collector=None,
    ):
        """Return a no-prefix-cache coordinator for pure-recurrent configs.

        The special case applies only when caching is enabled, there is
        exactly one KV cache group, that group's spec is a ``MambaSpec``,
        and its ``block_size`` differs from ``hash_block_size``. Anything
        else is handed to the captured upstream factory as-is.
        """
        # Checks run in the same order as a short-circuiting ``and`` chain:
        # the config is not inspected at all when caching is disabled.
        if enable_caching:
            groups = kv_cache_config.kv_cache_groups
            if len(groups) == 1:
                spec = groups[0].kv_cache_spec
                if (
                    isinstance(spec, MambaSpec)
                    and spec.block_size != hash_block_size
                ):
                    apriel2_logger.info(
                        "Pure recurrent config detected (MambaSpec "
                        "block_size=%d != hash_block_size=%d). Disabling "
                        "prefix caching for this config (recurrent state "
                        "is not prefix-shareable).",
                        spec.block_size,
                        hash_block_size,
                    )
                    # ``enable_caching`` is deliberately not forwarded: the
                    # no-prefix-cache coordinator is called without it.
                    return coordinator_module.KVCacheCoordinatorNoPrefixCache(
                        kv_cache_config,
                        max_model_len,
                        use_eagle,
                        enable_kv_cache_events,
                        dcp_world_size=dcp_world_size,
                        pcp_world_size=pcp_world_size,
                        hash_block_size=hash_block_size,
                        metrics_collector=metrics_collector,
                    )

        # Default path: defer to the factory that was in place before patching.
        return upstream_factory(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_caching,
            enable_kv_cache_events,
            dcp_world_size,
            pcp_world_size,
            hash_block_size,
            metrics_collector=metrics_collector,
        )

    coordinator_module.get_kv_cache_coordinator = _select_coordinator


_patch_coordinator_selection()


# =============================================================================
# KV Cache Spec Computation
# =============================================================================