Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions fast_llm_external_models/apriel2/vllm/modeling_apriel2.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,62 @@ def _patched(kv_cache_spec: dict) -> list:
_patch_kv_cache_grouping()


# ---------------------------------------------------------------------------
# Monkey-patch: fix coordinator selection for pure-recurrent configurations
# ---------------------------------------------------------------------------
# When ALL layers use MambaSpec (e.g. all-GDN or all-KDA placement), vLLM
# creates a single KV cache group and selects UnitaryKVCacheCoordinator.
# That coordinator asserts hash_block_size == block_size, but for MambaSpec
# block_size = max_model_len (e.g. 4096) while hash_block_size comes from
# cache_config.block_size (16 on CUDA). Prefix caching is meaningless for
# pure recurrent state anyway (can't share state prefixes across requests),
# so we fall back to KVCacheCoordinatorNoPrefixCache in this case.
def _patch_coordinator_selection() -> None:
    """Replace vLLM's ``get_kv_cache_coordinator`` with a guarded wrapper.

    The wrapper returns a ``KVCacheCoordinatorNoPrefixCache`` when all of the
    following hold:

    * prefix caching was requested (``enable_caching`` is truthy),
    * the KV cache config has exactly one group,
    * that group's spec is a ``MambaSpec``, and
    * the spec's ``block_size`` differs from ``hash_block_size``.

    That combination is the pure-recurrent case which would otherwise reach
    the unitary coordinator and trip its ``hash_block_size == block_size``
    assertion. Every other call is forwarded unchanged to the function that
    was installed before this patch ran.
    """
    import vllm.v1.core.kv_cache_coordinator as coordinator_module

    # Capture the current factory before it is overwritten below.
    upstream_factory = coordinator_module.get_kv_cache_coordinator

    def _patched(
        kv_cache_config,
        max_model_len,
        use_eagle,
        enable_caching,
        enable_kv_cache_events,
        dcp_world_size,
        pcp_world_size,
        hash_block_size,
        metrics_collector=None,
    ):
        def _delegate():
            # Hand the call to the pre-patch factory with untouched arguments.
            return upstream_factory(
                kv_cache_config,
                max_model_len,
                use_eagle,
                enable_caching,
                enable_kv_cache_events,
                dcp_world_size,
                pcp_world_size,
                hash_block_size,
                metrics_collector=metrics_collector,
            )

        # Without prefix caching there is nothing to override.
        if not enable_caching:
            return _delegate()

        cache_groups = kv_cache_config.kv_cache_groups
        if len(cache_groups) != 1:
            return _delegate()

        sole_spec = cache_groups[0].kv_cache_spec
        if not isinstance(sole_spec, MambaSpec):
            return _delegate()

        spec_block_size = sole_spec.block_size
        sizes_differ = spec_block_size != hash_block_size
        if not sizes_differ:
            return _delegate()

        # Pure-recurrent setup with mismatched block sizes: bypass the
        # default selection and build the no-prefix-cache coordinator.
        apriel2_logger.info(
            "Pure recurrent config detected (MambaSpec block_size=%d != "
            "hash_block_size=%d). Disabling prefix caching for this "
            "config (recurrent state is not prefix-shareable).",
            spec_block_size,
            hash_block_size,
        )
        # enable_caching is not passed to this constructor.
        return coordinator_module.KVCacheCoordinatorNoPrefixCache(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
            pcp_world_size=pcp_world_size,
            hash_block_size=hash_block_size,
            metrics_collector=metrics_collector,
        )

    coordinator_module.get_kv_cache_coordinator = _patched


# Install the coordinator-selection patch as a side effect of importing this
# module.
_patch_coordinator_selection()


# =============================================================================
# KV Cache Spec Computation
# =============================================================================
Expand Down