From 176dc6db44fdd30e8f2409b49bed1eafbcf58993 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 31 Dec 2024 18:40:50 -0800 Subject: [PATCH 01/18] hide kv cache behind torch.compile Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 36 +++++++++++++++--------------- vllm/config.py | 12 +++++++--- vllm/forward_context.py | 30 +++++++++++++------------ vllm/v1/worker/gpu_model_runner.py | 11 +++++++++ vllm/worker/cache_engine.py | 16 +++++++++++-- vllm/worker/hpu_worker.py | 3 ++- vllm/worker/worker.py | 3 ++- vllm/worker/worker_base.py | 1 + 8 files changed, 73 insertions(+), 39 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 69b6d1e4648d..d1e6f4241f14 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,7 +7,8 @@ from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend -from vllm.config import CacheConfig, get_current_vllm_config +from vllm.config import (CacheConfig, LayerForwardContext, + get_current_vllm_config) from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -117,7 +118,10 @@ def __init__( compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self + # use a placeholder kv cache tensor during init, which will be replaced + # after kv cache initialization + compilation_config.static_forward_context[ + prefix] = LayerForwardContext(self, torch.tensor([])) self.layer_name = prefix def forward( @@ -152,13 +156,11 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) torch.ops.vllm.unified_attention_with_output( - query, key, value, output, kv_cache, attn_type, - self.layer_name) + query, key, value, output, attn_type, self.layer_name) return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, attn_type, - self.layer_name) + attn_type, self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -236,17 +238,17 @@ def unified_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, attn_type: str, layer_name: str, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context - self = forward_context.static_forward_context[layer_name] + attn_metadata = forward_context.attn_metadata + ctx = forward_context.layers[layer_name] + self = ctx.attn_module return self.impl.forward(query, key, value, - kv_cache, + ctx.kv_cache, attn_metadata, self._k_scale, self._v_scale, @@ -257,7 +259,6 @@ def unified_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: torch.Tensor, attn_type: str, layer_name: str, ) -> torch.Tensor: @@ -267,7 +268,7 @@ def unified_attention_fake( direct_register_custom_op( op_name="unified_attention", op_func=unified_attention, - mutates_args=["kv_cache"], + mutates_args=[], fake_impl=unified_attention_fake, dispatch_key=current_platform.dispatch_key, ) @@ -278,17 +279,17 @@ def unified_attention_with_output( key: torch.Tensor, value: torch.Tensor, output: torch.Tensor, - kv_cache: torch.Tensor, attn_type: str, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.dynamic_forward_context - self = forward_context.static_forward_context[layer_name] + attn_metadata = forward_context.attn_metadata + ctx = forward_context.layers[layer_name] + self = ctx.attn_module self.impl.forward(query, key, value, - kv_cache, + ctx.kv_cache, attn_metadata, self._k_scale, self._v_scale, @@ -301,7 +302,6 @@ def unified_attention_with_output_fake( key: torch.Tensor, value: torch.Tensor, output: torch.Tensor, - kv_cache: torch.Tensor, attn_type: str, layer_name: str, ) -> None: @@ -311,7 +311,7 @@ def unified_attention_with_output_fake( direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, - mutates_args=["kv_cache", "output"], + mutates_args=["output"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/config.py b/vllm/config.py index e72c53b6130d..1f88b4fca57d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2616,6 +2616,12 @@ class CompilationLevel: PIECEWISE = 3 +@dataclass +class LayerForwardContext: + attn_module: Any # vllm.attention.layer.Attention + kv_cache: Any # torch.Tensor + + class CompilationConfig(BaseModel): """ Configuration for compilation. @@ -2769,9 +2775,9 @@ def model_post_init(self, __context: Any) -> None: inductor_hash_cache: Any = PrivateAttr # Per-model forward context - # Mainly used to store attention cls - # Map from layer name to the attention cls - static_forward_context: Dict[str, Any] = PrivateAttr + # Map from layer name to the layer's forward context, which stores + # attention cls and kv_cache + static_forward_context: Dict[str, LayerForwardContext] = PrivateAttr def compute_hash(self) -> str: """ diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7f56575279e9..fb163e842529 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -2,14 +2,17 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional import torch import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.config import LayerForwardContext, VllmConfig from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + logger = init_logger(__name__) track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0 @@ -21,9 +24,10 @@ @dataclass class ForwardContext: - static_forward_context: Dict[str, Any] + # copy from vllm_config.compilation_config.static_forward_context + layers: Dict[str, LayerForwardContext] # TODO: extend to support per-layer dynamic forward context - dynamic_forward_context: Any + attn_metadata: "AttentionMetadata" # set dynamically for each forward pass _forward_context: Optional[ForwardContext] = None @@ -38,34 +42,32 @@ def get_forward_context() -> ForwardContext: @contextmanager -def set_forward_context(context: Any, vllm_config: VllmConfig): +def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. """ global forward_start_time - need_to_track_batchsize = track_batchsize and context is not None + need_to_track_batchsize = track_batchsize and attn_metadata is not None if need_to_track_batchsize: forward_start_time = time.perf_counter() global _forward_context prev_context = _forward_context _forward_context = ForwardContext( - static_forward_context=vllm_config.compilation_config. - static_forward_context, - dynamic_forward_context=context) + layers=vllm_config.compilation_config.static_forward_context, + attn_metadata=attn_metadata) try: yield finally: - global batchsize_counter global last_logging_time, batchsize_logging_interval if need_to_track_batchsize: - if hasattr(context, "num_prefill_tokens"): + if hasattr(attn_metadata, "num_prefill_tokens"): # for v0 attention backends - batchsize = context.num_prefill_tokens + \ - context.num_decode_tokens + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens else: # for v1 attention backends - batchsize = context.num_input_tokens + batchsize = attn_metadata.num_input_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 509771b7e2e5..b29847975394 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,6 +13,7 @@ from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, @@ -752,3 +753,13 @@ def initialize_kv_cache(self, num_blocks: int) -> None: torch.zeros(kv_cache_shape, dtype=self.kv_cache_dtype, device=self.device)) + # register kv_cache for forward_context + if self.vllm_config.parallel_config.pipeline_parallel_size > 1: + # TODO(Chen): In pipeline parallelism, layer_name 'layers.i.xxx' + # is mapped to kv_caches[i - start_layer_idx]. Need to implement + # and verify after supporting PP in v1 + raise NotImplementedError("Pipeline parallelism is not supported.") + ctx = self.vllm_config.compilation_config.static_forward_context + for layer_name, forward_ctx in ctx.items(): + layer_id = extract_layer_index(layer_name) + forward_ctx.kv_cache = self.kv_caches[layer_id] diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 7ccd4571b19d..499bd7a1cda6 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,11 +1,13 @@ """CacheEngine class for managing the KV cache.""" -from typing import List +from typing import Dict, List import torch from vllm.attention import get_attn_backend -from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig +from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, + LayerForwardContext, ModelConfig, ParallelConfig) from vllm.logger import init_logger +from vllm.model_executor.models.utils import extract_layer_index from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size, is_pin_memory_available) @@ -26,6 +28,7 @@ def __init__( model_config: ModelConfig, parallel_config: ParallelConfig, device_config: DeviceConfig, + compilation_config: CompilationConfig, ) -> None: self.cache_config = cache_config self.model_config = model_config @@ -62,6 +65,7 @@ def __init__( self.gpu_cache = self._allocate_kv_cache( self.num_gpu_blocks, self.device_config.device_type) self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") + self._register_gpu_kv_cache(compilation_config.static_forward_context) def _allocate_kv_cache( self, @@ -84,6 +88,14 @@ def _allocate_kv_cache( device=device)) return kv_cache + def _register_gpu_kv_cache(self, ctx: Dict[str, + LayerForwardContext]) -> None: + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError + for layer_name, forward_ctx in ctx.items(): + layer_id = extract_layer_index(layer_name) + forward_ctx.kv_cache = self.gpu_cache[layer_id] + def swap_in(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index cca7cd50bfc7..6dfa805717bf 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -208,7 +208,8 @@ def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ HPUCacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) + self.parallel_config, self.device_config, + self.compilation_config) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.hpu_cache = [ diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f51b51d433d3..0af5b0cc515c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -278,7 +278,8 @@ def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) + self.parallel_config, self.device_config, + self.compilation_config) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.gpu_cache = [ diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 249b3ed2dfd3..a835718e1db1 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -43,6 +43,7 @@ def __init__( self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self.kv_transfer_config = vllm_config.kv_transfer_config + self.compilation_config = vllm_config.compilation_config from vllm.platforms import current_platform self.current_platform = current_platform From c5a5155454e8b2aa7bedecb9b49377e995b99829 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 1 Jan 2025 22:37:56 -0800 Subject: [PATCH 02/18] support pp & non-attn layers Signed-off-by: Chen Zhang --- tests/test_utils.py | 30 ++++++++++++++++++++++++++++-- vllm/utils.py | 14 +++++++++++++- vllm/v1/worker/gpu_model_runner.py | 16 +++++----------- vllm/worker/cache_engine.py | 19 ++++++------------- 4 files changed, 52 insertions(+), 27 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 32a6b0aed66a..40d9bc11c393 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,14 +1,14 @@ import asyncio import os import socket -from typing import AsyncIterator, Tuple +from typing import TYPE_CHECKING, AsyncIterator, Tuple import pytest import torch from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, get_open_port, memory_profiling, merge_async_iterators, - supports_kw) + supports_kw, register_kv_cache) from .utils import error_on_warning, fork_new_process_for_each_test @@ -306,3 +306,29 @@ def test_memory_profiling(): del weights lib.cudaFree(handle1) lib.cudaFree(handle2) + + +def test_register_gpu_kv_cache(): + from vllm.config import LayerForwardContext + from vllm.attention import Attention + + # example from Jamba PP=2 + ctx = { + 'model.layers.20.attn': + LayerForwardContext( + attn_module=Attention(32, 128, 0.1), + kv_cache=None, + ), + 'model.layers.28.attn': + LayerForwardContext( + attn_module=Attention(32, 128, 0.1), + kv_cache=None, + ) + } + kv_cache = [ + torch.zeros((1, )), + torch.zeros((1, )), + ] + register_kv_cache(ctx, kv_cache) + assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0] + assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1] diff --git a/vllm/utils.py b/vllm/utils.py index 8ef07d2c326a..6ddf17dd0364 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -52,7 +52,7 @@ from vllm.logger import enable_trace_function_call, init_logger if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import LayerForwardContext, VllmConfig logger = init_logger(__name__) @@ -1947,3 +1947,15 @@ def get_mp_context(): _check_multiproc_method() mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD return multiprocessing.get_context(mp_method) + + +def register_kv_cache(ctx: Dict[str, "LayerForwardContext"], + kv_cache: List[torch.Tensor]) -> None: + # Two things needed to be handled here: + # 1. Some models have non-attention layers, e.g., Jamba + # 2. Pipeline parallelism, each rank only has a subset of layers + from vllm.model_executor.models.utils import extract_layer_index + layer_name_sorted = sorted(ctx.keys(), key=extract_layer_index) + for i, layer_name in enumerate(layer_name_sorted): + forward_ctx = ctx[layer_name] + forward_ctx.kv_cache = kv_cache[i] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b29847975394..d19f0b229777 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,11 +13,11 @@ from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.utils import extract_layer_index from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv, is_pin_memory_available) + LayerBlockType, cdiv, is_pin_memory_available, + register_kv_cache) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient @@ -754,12 +754,6 @@ def initialize_kv_cache(self, num_blocks: int) -> None: dtype=self.kv_cache_dtype, device=self.device)) # register kv_cache for forward_context - if self.vllm_config.parallel_config.pipeline_parallel_size > 1: - # TODO(Chen): In pipeline parallelism, layer_name 'layers.i.xxx' - # is mapped to kv_caches[i - start_layer_idx]. Need to implement - # and verify after supporting PP in v1 - raise NotImplementedError("Pipeline parallelism is not supported.") - ctx = self.vllm_config.compilation_config.static_forward_context - for layer_name, forward_ctx in ctx.items(): - layer_id = extract_layer_index(layer_name) - forward_ctx.kv_cache = self.kv_caches[layer_id] + register_kv_cache( + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 499bd7a1cda6..951d894b536e 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,15 +1,15 @@ """CacheEngine class for managing the KV cache.""" -from typing import Dict, List +from typing import List import torch from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, - LayerForwardContext, ModelConfig, ParallelConfig) + ModelConfig, ParallelConfig) from vllm.logger import init_logger -from vllm.model_executor.models.utils import extract_layer_index from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, - get_dtype_size, is_pin_memory_available) + get_dtype_size, is_pin_memory_available, + register_kv_cache) logger = init_logger(__name__) @@ -65,7 +65,8 @@ def __init__( self.gpu_cache = self._allocate_kv_cache( self.num_gpu_blocks, self.device_config.device_type) self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") - self._register_gpu_kv_cache(compilation_config.static_forward_context) + register_kv_cache(compilation_config.static_forward_context, + self.gpu_cache) def _allocate_kv_cache( self, @@ -88,14 +89,6 @@ def _allocate_kv_cache( device=device)) return kv_cache - def _register_gpu_kv_cache(self, ctx: Dict[str, - LayerForwardContext]) -> None: - if self.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError - for layer_name, forward_ctx in ctx.items(): - layer_id = extract_layer_index(layer_name) - forward_ctx.kv_cache = self.gpu_cache[layer_id] - def swap_in(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_attention_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], From de8324bd1169234615d3c6fec031fd93f9b86c91 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 1 Jan 2025 22:55:58 -0800 Subject: [PATCH 03/18] format Signed-off-by: Chen Zhang --- tests/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 40d9bc11c393..c862ae568e5d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,14 +1,14 @@ import asyncio import os import socket -from typing import TYPE_CHECKING, AsyncIterator, Tuple +from typing import AsyncIterator, Tuple import pytest import torch from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, get_open_port, memory_profiling, merge_async_iterators, - supports_kw, register_kv_cache) + register_kv_cache, supports_kw) from .utils import error_on_warning, fork_new_process_for_each_test @@ -309,8 +309,8 @@ def test_memory_profiling(): def test_register_gpu_kv_cache(): - from vllm.config import LayerForwardContext from vllm.attention import Attention + from vllm.config import LayerForwardContext # example from Jamba PP=2 ctx = { From fa9b0bb5c4e2ad1f79bba7632ac43401ff471315 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 1 Jan 2025 23:29:36 -0800 Subject: [PATCH 04/18] update cpu engine Signed-off-by: Chen Zhang --- vllm/worker/cpu_worker.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index b5dfebfce6f7..b33a059ac473 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,14 +6,14 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, VllmConfig) +from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, + ModelConfig, ParallelConfig, VllmConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, register_kv_cache from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner @@ -33,8 +33,8 @@ class CPUCacheEngine: """ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, - parallel_config: ParallelConfig, - device_config: DeviceConfig) -> None: + parallel_config: ParallelConfig, device_config: DeviceConfig, + compilation_config: CompilationConfig) -> None: assert device_config.device_type == "cpu" self.cache_config = cache_config self.model_config = model_config @@ -66,6 +66,8 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + register_kv_cache(compilation_config.static_forward_context, + self.cpu_cache) def _allocate_kv_cache( self, @@ -285,9 +287,13 @@ def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: def _init_cache_engine(self) -> None: self.cache_engine = [ - CPUCacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) + CPUCacheEngine( + self.cache_config, + self.model_config, + self.parallel_config, + self.device_config, + self.compilation_config, + ) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.cpu_cache = [ self.cache_engine[ve].cpu_cache From 3bb7d6d6c2e5da9f58672466274307724569f41a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Jan 2025 19:21:06 -0800 Subject: [PATCH 05/18] support encoder-decoder and move kv_cache to Attention Signed-off-by: Chen Zhang --- tests/test_utils.py | 68 +++++++++++++++++++++++------- vllm/attention/layer.py | 21 ++++----- vllm/config.py | 5 +-- vllm/utils.py | 31 +++++++++++--- vllm/v1/worker/gpu_model_runner.py | 7 ++- vllm/worker/cache_engine.py | 7 ++- vllm/worker/cpu_worker.py | 6 +-- 7 files changed, 96 insertions(+), 49 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index c862ae568e5d..e049bfd59726 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,9 +6,9 @@ import pytest import torch -from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, - get_open_port, memory_profiling, merge_async_iterators, - register_kv_cache, supports_kw) +from vllm.utils import (FlexibleArgumentParser, StoreBoolean, bind_kv_cache, + deprecate_kwargs, get_open_port, memory_profiling, + merge_async_iterators, supports_kw) from .utils import error_on_warning, fork_new_process_for_each_test @@ -308,27 +308,63 @@ def test_memory_profiling(): lib.cudaFree(handle2) -def test_register_gpu_kv_cache(): +def test_bind_kv_cache(): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + 'layers.1.self_attn': Attention(32, 128, 0.1), + 'layers.2.self_attn': Attention(32, 128, 0.1), + 'layers.3.self_attn': Attention(32, 128, 0.1), + } + kv_cache = [ + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + ] + bind_kv_cache(ctx, kv_cache) + assert ctx['layers.0.self_attn'].kv_cache is kv_cache[0] + assert ctx['layers.1.self_attn'].kv_cache is kv_cache[1] + assert ctx['layers.2.self_attn'].kv_cache is kv_cache[2] + assert ctx['layers.3.self_attn'].kv_cache is kv_cache[3] + +def test_bind_kv_cache_non_attention(): from vllm.attention import Attention - from vllm.config import LayerForwardContext # example from Jamba PP=2 ctx = { - 'model.layers.20.attn': - LayerForwardContext( - attn_module=Attention(32, 128, 0.1), - kv_cache=None, - ), - 'model.layers.28.attn': - LayerForwardContext( - attn_module=Attention(32, 128, 0.1), - kv_cache=None, - ) + 'model.layers.20.attn': Attention(32, 128, 0.1), + 'model.layers.28.attn': Attention(32, 128, 0.1), } kv_cache = [ torch.zeros((1, )), torch.zeros((1, )), ] - register_kv_cache(ctx, kv_cache) + bind_kv_cache(ctx, kv_cache) assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0] assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1] + + +def test_bind_kv_cache_encoder_decoder(): + from vllm.attention import Attention, AttentionType + + # example from bart + ctx = { + 'encoder.layers.0.self_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), + 'decoder.layers.0.encoder_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), + 'decoder.layers.0.self_attn.attn': + Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), + } + + kv_cache = [ + torch.zeros((1, )), + ] + encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache + + bind_kv_cache(ctx, kv_cache) + assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache + assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache is kv_cache[0] + assert ctx['decoder.layers.0.self_attn.attn'].kv_cache is kv_cache[0] diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 3cad3080fc1b..9196309278b6 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -7,8 +7,7 @@ from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend -from vllm.config import (CacheConfig, LayerForwardContext, - get_current_vllm_config) +from vllm.config import CacheConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -119,12 +118,12 @@ def __init__( compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") - # use a placeholder kv cache tensor during init, which will be replaced - # after kv cache initialization - compilation_config.static_forward_context[ - prefix] = LayerForwardContext(self, torch.tensor([])) + compilation_config.static_forward_context[prefix] = self self.layer_name = prefix self.attn_type = attn_type + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + self.kv_cache = torch.tensor([]) def forward( self, @@ -238,9 +237,8 @@ def unified_attention( ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - ctx = forward_context.layers[layer_name] - self = ctx.attn_module - return self.impl.forward(query, key, value, ctx.kv_cache, attn_metadata, + self = forward_context.layers[layer_name] + return self.impl.forward(query, key, value, self.kv_cache, attn_metadata, self._k_scale, self._v_scale) @@ -271,12 +269,11 @@ def unified_attention_with_output( ) -> None: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - ctx = forward_context.layers[layer_name] - self = ctx.attn_module + self = forward_context.layers[layer_name] self.impl.forward(query, key, value, - ctx.kv_cache, + self.kv_cache, attn_metadata, self._k_scale, self._v_scale, diff --git a/vllm/config.py b/vllm/config.py index 56772b8a71f2..d63bd2463db8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2772,9 +2772,8 @@ def model_post_init(self, __context: Any) -> None: inductor_hash_cache: Any = PrivateAttr # Per-model forward context - # Map from layer name to the layer's forward context, which stores - # attention cls and kv_cache - static_forward_context: Dict[str, LayerForwardContext] = PrivateAttr + # Map from layer name to the attention cls + static_forward_context: Dict[str, Any] = PrivateAttr def compute_hash(self) -> str: """ diff --git a/vllm/utils.py b/vllm/utils.py index 8c7f1592718c..cc6cf37c7029 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -52,7 +52,7 @@ from vllm.logger import enable_trace_function_call, init_logger if TYPE_CHECKING: - from vllm.config import LayerForwardContext, VllmConfig + from vllm.config import VllmConfig logger = init_logger(__name__) @@ -1949,13 +1949,30 @@ def get_mp_context(): return multiprocessing.get_context(mp_method) -def register_kv_cache(ctx: Dict[str, "LayerForwardContext"], - kv_cache: List[torch.Tensor]) -> None: - # Two things needed to be handled here: +def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None: + # Bind the kv_cache tensor to Attention modules, similar to + # ctx[layer_name].kv_cache = kv_cache[extract_layer_index(layer_name)] + # Special things handled here: # 1. Some models have non-attention layers, e.g., Jamba # 2. Pipeline parallelism, each rank only has a subset of layers + # 3. Encoder attention has no kv cache + # 3. Encoder-decoder models, e.g., Bart, encoder-decoder attention and + # decoder-only attention of the same layer (e.g., bart's + # decoder.layers.1.self_attn and decoder.layers.1.encoder_attn is mapped + # to the same kv cache tensor + from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index - layer_name_sorted = sorted(ctx.keys(), key=extract_layer_index) - for i, layer_name in enumerate(layer_name_sorted): + layer_need_kv_cache = [ + layer_name for layer_name in ctx + if ctx[layer_name].attn_type in (AttentionType.DECODER, + AttentionType.ENCODER_DECODER) + ] + layer_index_sorted = sorted( + set( + extract_layer_index(layer_name) + for layer_name in layer_need_kv_cache)) + for layer_name in layer_need_kv_cache: + kv_cache_idx = layer_index_sorted.index( + extract_layer_index(layer_name)) forward_ctx = ctx[layer_name] - forward_ctx.kv_cache = kv_cache[i] + forward_ctx.kv_cache = kv_cache[kv_cache_idx] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 69558055ffcb..93045a6b5a09 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -16,8 +16,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.sampling_params import SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv, is_pin_memory_available, - register_kv_cache) + LayerBlockType, bind_kv_cache, cdiv, + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient @@ -859,7 +859,6 @@ def initialize_kv_cache(self, num_blocks: int) -> None: torch.zeros(kv_cache_shape, dtype=self.kv_cache_dtype, device=self.device)) - # register kv_cache for forward_context - register_kv_cache( + bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, self.kv_caches) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 951d894b536e..b10584b0d5a6 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -8,8 +8,7 @@ ModelConfig, ParallelConfig) from vllm.logger import init_logger from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, - get_dtype_size, is_pin_memory_available, - register_kv_cache) + bind_kv_cache, get_dtype_size, is_pin_memory_available) logger = init_logger(__name__) @@ -65,8 +64,8 @@ def __init__( self.gpu_cache = self._allocate_kv_cache( self.num_gpu_blocks, self.device_config.device_type) self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") - register_kv_cache(compilation_config.static_forward_context, - self.gpu_cache) + bind_kv_cache(compilation_config.static_forward_context, + self.gpu_cache) def _allocate_kv_cache( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index b33a059ac473..92f519bdff91 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,7 +13,7 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, register_kv_cache +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner @@ -66,8 +66,8 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) - register_kv_cache(compilation_config.static_forward_context, - self.cpu_cache) + bind_kv_cache(compilation_config.static_forward_context, + self.cpu_cache) def _allocate_kv_cache( self, From 7a3c15439dbd671ca7d5d91d9b1b67ee71a36073 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Jan 2025 21:43:10 -0800 Subject: [PATCH 06/18] fix bug Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9196309278b6..74df82355151 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -155,7 +155,7 @@ def forward( return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, self.layer_name) + self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore From 9f4794252b8190110d3245a656b9e6ebede0dd28 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Jan 2025 21:50:24 -0800 Subject: [PATCH 07/18] update comment Signed-off-by: Chen Zhang --- vllm/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index cc6cf37c7029..ed67abffbc8a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1956,10 +1956,9 @@ def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None: # 1. Some models have non-attention layers, e.g., Jamba # 2. Pipeline parallelism, each rank only has a subset of layers # 3. Encoder attention has no kv cache - # 3. Encoder-decoder models, e.g., Bart, encoder-decoder attention and - # decoder-only attention of the same layer (e.g., bart's - # decoder.layers.1.self_attn and decoder.layers.1.encoder_attn is mapped - # to the same kv cache tensor + # 4. Encoder-decoder models, encoder-decoder attention and decoder-only + # attention of the same layer (e.g., bart's decoder.layers.1.self_attn + # and decoder.layers.1.encoder_attn is mapped to the same kv cache tensor from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index layer_need_kv_cache = [ From 4418608c8b060249fab96364d9c1debfe27525be Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Jan 2025 21:54:34 -0800 Subject: [PATCH 08/18] update format Signed-off-by: Chen Zhang --- tests/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e049bfd59726..ba96f005e69f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -312,10 +312,10 @@ def test_bind_kv_cache(): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), + 'layers.0.self_attn': Attention(32, 128, 0.1), + 'layers.1.self_attn': Attention(32, 128, 0.1), + 'layers.2.self_attn': Attention(32, 128, 0.1), + 'layers.3.self_attn': Attention(32, 128, 0.1), } kv_cache = [ torch.zeros((1, )), From 3590e550e931043a3217dc4cb909a638b8fc613c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 6 Jan 2025 21:57:23 -0800 Subject: [PATCH 09/18] remove unused code Signed-off-by: Chen Zhang --- vllm/config.py | 6 ------ vllm/forward_context.py | 4 ++-- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d63bd2463db8..8d7c199d75f8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2613,12 +2613,6 @@ class CompilationLevel: PIECEWISE = 3 -@dataclass -class LayerForwardContext: - attn_module: Any # vllm.attention.layer.Attention - kv_cache: Any # torch.Tensor - - class CompilationConfig(BaseModel): """ Configuration for compilation. diff --git a/vllm/forward_context.py b/vllm/forward_context.py index fb163e842529..7c997b106cd6 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -7,7 +7,7 @@ import torch import vllm.envs as envs -from vllm.config import LayerForwardContext, VllmConfig +from vllm.config import VllmConfig from vllm.logger import init_logger if TYPE_CHECKING: @@ -25,7 +25,7 @@ @dataclass class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context - layers: Dict[str, LayerForwardContext] + layers: Dict[str, Any] # TODO: extend to support per-layer dynamic forward context attn_metadata: "AttentionMetadata" # set dynamically for each forward pass From beb0dee65a58811b905f595aa500fe3890f4a896 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 7 Jan 2025 06:17:01 -0800 Subject: [PATCH 10/18] layers->attn_layers Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 4 ++-- vllm/forward_context.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 74df82355151..7bf28b9b4bd9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -237,7 +237,7 @@ def unified_attention( ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - self = forward_context.layers[layer_name] + self = forward_context.attn_layers[layer_name] return self.impl.forward(query, key, value, self.kv_cache, attn_metadata, self._k_scale, self._v_scale) @@ -269,7 +269,7 @@ def unified_attention_with_output( ) -> None: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata - self = forward_context.layers[layer_name] + self = forward_context.attn_layers[layer_name] self.impl.forward(query, key, value, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7c997b106cd6..5e9fe27fbdd9 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -25,7 +25,7 @@ @dataclass class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context - layers: Dict[str, Any] + attn_layers: Dict[str, Any] # TODO: extend to support per-layer dynamic forward context attn_metadata: "AttentionMetadata" # set dynamically for each forward pass @@ -54,7 +54,7 @@ def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig): global _forward_context prev_context = _forward_context _forward_context = ForwardContext( - layers=vllm_config.compilation_config.static_forward_context, + attn_layers=vllm_config.compilation_config.static_forward_context, attn_metadata=attn_metadata) try: yield From ffe8cdd07d96327ef3bf2e80bf63777e35a59932 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 7 Jan 2025 18:46:19 -0800 Subject: [PATCH 11/18] update test Signed-off-by: Chen Zhang --- tests/kernels/test_encoder_decoder_attn.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 614674375786..f70120ae202d 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -142,12 +142,18 @@ class that Attention will automatically select when it is constructed. torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) # Construct KV cache - kv_cache = make_kv_cache(test_pt.num_blocks, - test_pt.num_heads, - test_pt.head_size, - test_pt.block_size, - device=CUDA_DEVICE, - backend=test_pt.backend_name) + if test_pt.attn_type in (AttentionType.DECODER, + AttentionType.ENCODER_DECODER): + kv_cache = make_kv_cache(test_pt.num_blocks, + test_pt.num_heads, + test_pt.head_size, + test_pt.block_size, + device=CUDA_DEVICE, + backend=test_pt.backend_name) + else: + kv_cache = torch.tensor([]) + + attn.kv_cache = kv_cache return TestResources(scale, attn, kv_cache) From 2cb84f2fc770bfd69bb63bc033f5226ec5a734a8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 8 Jan 2025 05:58:06 -0800 Subject: [PATCH 12/18] support pp virtual engine Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 12 +++++++++--- vllm/forward_context.py | 7 ++++++- vllm/utils.py | 15 +++++++++++---- vllm/v1/worker/gpu_model_runner.py | 2 +- vllm/worker/cache_engine.py | 18 +++++------------- vllm/worker/cpu_enc_dec_model_runner.py | 3 ++- vllm/worker/cpu_model_runner.py | 3 ++- vllm/worker/cpu_pooling_model_runner.py | 3 ++- vllm/worker/cpu_worker.py | 14 +++++++------- vllm/worker/enc_dec_model_runner.py | 3 ++- vllm/worker/model_runner.py | 5 +++-- vllm/worker/pooling_model_runner.py | 3 ++- vllm/worker/worker.py | 7 ++++--- 13 files changed, 56 insertions(+), 39 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 7bf28b9b4bd9..55e4e14027f7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -123,7 +123,11 @@ def __init__( self.attn_type = attn_type # use a placeholder kv cache tensor during init, which will be replaced # by bind_kv_cache - self.kv_cache = torch.tensor([]) + # this variable will not be accessed if use_direct_call is True + self.kv_cache = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] def forward( self, @@ -238,7 +242,8 @@ def unified_attention( forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] - return self.impl.forward(query, key, value, self.kv_cache, attn_metadata, + kv_cache = self.kv_cache[forward_context.virtual_engine] + return self.impl.forward(query, key, value, kv_cache, attn_metadata, self._k_scale, self._v_scale) @@ -270,10 +275,11 @@ def unified_attention_with_output( forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] + kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(query, key, value, - self.kv_cache, + kv_cache, attn_metadata, self._k_scale, self._v_scale, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 5e9fe27fbdd9..828b394ec5d2 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -28,6 +28,8 @@ class ForwardContext: attn_layers: Dict[str, Any] # TODO: extend to support per-layer dynamic forward context attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + # TODO: remove after making all virtual_engines share the same kv cache + virtual_engine: int # set dynamically for each forward pass _forward_context: Optional[ForwardContext] = None @@ -42,7 +44,9 @@ def get_forward_context() -> ForwardContext: @contextmanager -def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig): +def set_forward_context(attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -55,6 +59,7 @@ def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig): prev_context = _forward_context _forward_context = ForwardContext( attn_layers=vllm_config.compilation_config.static_forward_context, + virtual_engine=virtual_engine, attn_metadata=attn_metadata) try: yield diff --git a/vllm/utils.py b/vllm/utils.py index ed67abffbc8a..7643fc7223a6 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1949,16 +1949,20 @@ def get_mp_context(): return multiprocessing.get_context(mp_method) -def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None: +def bind_kv_cache( + ctx: Dict[str, Any], + kv_cache: List[List[torch.Tensor]], # [virtual_engine][layer_index] +) -> None: # Bind the kv_cache tensor to Attention modules, similar to - # ctx[layer_name].kv_cache = kv_cache[extract_layer_index(layer_name)] + # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] # Special things handled here: # 1. Some models have non-attention layers, e.g., Jamba # 2. Pipeline parallelism, each rank only has a subset of layers # 3. Encoder attention has no kv cache # 4. Encoder-decoder models, encoder-decoder attention and decoder-only # attention of the same layer (e.g., bart's decoder.layers.1.self_attn - # and decoder.layers.1.encoder_attn is mapped to the same kv cache tensor + # and decoder.layers.1.encoder_attn) is mapped to the same kv cache + # tensor from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index layer_need_kv_cache = [ @@ -1974,4 +1978,7 @@ def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None: kv_cache_idx = layer_index_sorted.index( extract_layer_index(layer_name)) forward_ctx = ctx[layer_name] - forward_ctx.kv_cache = kv_cache[kv_cache_idx] + assert len(forward_ctx.kv_cache) == len(kv_cache) + for ve, ve_kv_cache in enumerate(kv_cache): + assert forward_ctx.kv_cache[ve].numel() == 0 + forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a6649b982fc..fb87dc5a8222 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -863,4 +863,4 @@ def initialize_kv_cache(self, num_blocks: int) -> None: device=self.device)) bind_kv_cache( self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + [self.kv_caches]) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index b10584b0d5a6..3a297f3c41a1 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -4,11 +4,10 @@ import torch from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, - ModelConfig, ParallelConfig) +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, - bind_kv_cache, get_dtype_size, is_pin_memory_available) + get_dtype_size, is_pin_memory_available) logger = init_logger(__name__) @@ -21,14 +20,9 @@ class CacheEngine: as swapping and copying. """ - def __init__( - self, - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, - device_config: DeviceConfig, - compilation_config: CompilationConfig, - ) -> None: + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: self.cache_config = cache_config self.model_config = model_config self.parallel_config = parallel_config @@ -64,8 +58,6 @@ def __init__( self.gpu_cache = self._allocate_kv_cache( self.num_gpu_blocks, self.device_config.device_type) self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") - bind_kv_cache(compilation_config.static_forward_context, - self.gpu_cache) def _allocate_kv_cache( self, diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index cc24cfe04d2b..fa6775cbd6c6 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -305,7 +305,8 @@ def execute_model( intermediate_tensors, } - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f1531e0fc067..d99db4e0c6c4 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -526,7 +526,8 @@ def execute_model( execute_model_kwargs.update( {"previous_hidden_states": previous_hidden_states}) - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py index 17b2fd2564a0..d31ba89e1237 100644 --- a/vllm/worker/cpu_pooling_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -69,7 +69,8 @@ def execute_model( intermediate_tensors, } - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 92f519bdff91..90b08d6197ad 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,8 +6,8 @@ import vllm.envs as envs from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, - ModelConfig, ParallelConfig, VllmConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, VllmConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -33,8 +33,8 @@ class CPUCacheEngine: """ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, - parallel_config: ParallelConfig, device_config: DeviceConfig, - compilation_config: CompilationConfig) -> None: + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: assert device_config.device_type == "cpu" self.cache_config = cache_config self.model_config = model_config @@ -66,8 +66,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, # Initialize the cache. self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) - bind_kv_cache(compilation_config.static_forward_context, - self.cpu_cache) def _allocate_kv_cache( self, @@ -292,13 +290,15 @@ def _init_cache_engine(self) -> None: self.model_config, self.parallel_config, self.device_config, - self.compilation_config, ) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.cpu_cache = [ self.cache_engine[ve].cpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + for ve in range(self.parallel_config.pipeline_parallel_size): + bind_kv_cache(self.compilation_config.static_forward_context, + self.cpu_cache[ve], ve) self.model_runner.block_size = self.cache_engine[0].block_size assert all( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4d5d918087be..8a161b740042 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -175,7 +175,8 @@ def execute_model( } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + model_input.virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1c6d1bbee78e..2b918483d367 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1527,7 +1527,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context(attn_metadata, self.vllm_config, + virtual_engine): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1695,7 +1696,7 @@ def execute_model( if not bypass_model_exec: with set_forward_context(model_input.attn_metadata, - self.vllm_config): + self.vllm_config, virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index f79b3773bcbd..6de227f3cb2b 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -105,7 +105,8 @@ def execute_model( if model_input.token_types is not None: cross_enc_kwargs["token_type_ids"] = model_input.token_types - with set_forward_context(model_input.attn_metadata, self.vllm_config): + with set_forward_context(model_input.attn_metadata, self.vllm_config, + virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0af5b0cc515c..0f12549e3f3f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -21,7 +21,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.utils import GiB_bytes, memory_profiling +from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -278,14 +278,15 @@ def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ CacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config, - self.compilation_config) + self.parallel_config, self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.gpu_cache = [ self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + bind_kv_cache(self.compilation_config.static_forward_context, + self.gpu_cache) def _warm_up_model(self) -> None: if not self.model_config.enforce_eager: From 76712f83f5928e4cf478ac2c046e879ebddb17bf Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 8 Jan 2025 06:42:10 -0800 Subject: [PATCH 13/18] fix Signed-off-by: Chen Zhang --- tests/test_utils.py | 40 ++++++++++++++++++++++++++++----------- vllm/worker/hpu_worker.py | 3 +-- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index ba96f005e69f..209a09f9578b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,7 @@ import pytest import torch +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (FlexibleArgumentParser, StoreBoolean, bind_kv_cache, deprecate_kwargs, get_open_port, memory_profiling, merge_async_iterators, supports_kw) @@ -323,11 +324,11 @@ def test_bind_kv_cache(): torch.zeros((1, )), torch.zeros((1, )), ] - bind_kv_cache(ctx, kv_cache) - assert ctx['layers.0.self_attn'].kv_cache is kv_cache[0] - assert ctx['layers.1.self_attn'].kv_cache is kv_cache[1] - assert ctx['layers.2.self_attn'].kv_cache is kv_cache[2] - assert ctx['layers.3.self_attn'].kv_cache is kv_cache[3] + bind_kv_cache(ctx, [kv_cache]) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] + assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] + assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] + assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] def test_bind_kv_cache_non_attention(): from vllm.attention import Attention @@ -341,9 +342,9 @@ def test_bind_kv_cache_non_attention(): torch.zeros((1, )), torch.zeros((1, )), ] - bind_kv_cache(ctx, kv_cache) - assert ctx['model.layers.20.attn'].kv_cache is kv_cache[0] - assert ctx['model.layers.28.attn'].kv_cache is kv_cache[1] + bind_kv_cache(ctx, [kv_cache]) + assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0] + assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] def test_bind_kv_cache_encoder_decoder(): @@ -364,7 +365,24 @@ def test_bind_kv_cache_encoder_decoder(): ] encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache - bind_kv_cache(ctx, kv_cache) + bind_kv_cache(ctx, [kv_cache]) assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache - assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache is kv_cache[0] - assert ctx['decoder.layers.0.self_attn.attn'].kv_cache is kv_cache[0] + assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0] + assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0] + + +def test_bind_kv_cache_pp(): + cfg = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=2)) + with set_current_vllm_config(cfg): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + } + kv_cache = [ + [torch.zeros((1, ))], + [torch.zeros((1, ))] + ] + bind_kv_cache(ctx, kv_cache) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0] + assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 6dfa805717bf..cca7cd50bfc7 100644 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -208,8 +208,7 @@ def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = [ HPUCacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config, - self.compilation_config) + self.parallel_config, self.device_config) for _ in range(self.parallel_config.pipeline_parallel_size) ] self.hpu_cache = [ From bab3bea6017d3a3904db2c40fff5e1f98c454e2d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 8 Jan 2025 07:10:09 -0800 Subject: [PATCH 14/18] revert unrealted change Signed-off-by: Chen Zhang --- vllm/worker/cache_engine.py | 10 +++++++--- vllm/worker/cpu_worker.py | 9 +++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 3a297f3c41a1..7ccd4571b19d 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -20,9 +20,13 @@ class CacheEngine: as swapping and copying. """ - def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, - parallel_config: ParallelConfig, - device_config: DeviceConfig) -> None: + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + ) -> None: self.cache_config = cache_config self.model_config = model_config self.parallel_config = parallel_config diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 90b08d6197ad..fb2b22bcfbc4 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -285,12 +285,9 @@ def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: def _init_cache_engine(self) -> None: self.cache_engine = [ - CPUCacheEngine( - self.cache_config, - self.model_config, - self.parallel_config, - self.device_config, - ) for _ in range(self.parallel_config.pipeline_parallel_size) + CPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) ] self.cpu_cache = [ self.cache_engine[ve].cpu_cache From 10f5353ba26928a08d577c82d25d6e0eb282676a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 8 Jan 2025 07:10:52 -0800 Subject: [PATCH 15/18] fix test Signed-off-by: Chen Zhang --- tests/kernels/test_encoder_decoder_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index f70120ae202d..e008a56de620 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -153,7 +153,7 @@ class that Attention will automatically select when it is constructed. else: kv_cache = torch.tensor([]) - attn.kv_cache = kv_cache + attn.kv_cache = [kv_cache] return TestResources(scale, attn, kv_cache) From dcded5b78f525c5c48ab0801fe9482fa35efe6fa Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 9 Jan 2025 00:50:23 -0800 Subject: [PATCH 16/18] fork new process for v1 teste Signed-off-by: Chen Zhang --- tests/v1/engine/test_engine_core.py | 3 +++ tests/v1/engine/test_engine_core_client.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 8dd9b23fbdd5..5b1732036e80 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -4,6 +4,7 @@ import pytest from transformers import AutoTokenizer +from tests.utils import fork_new_process_for_each_test from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest: ) +@fork_new_process_for_each_test def test_engine_core(monkeypatch): with monkeypatch.context() as m: @@ -138,6 +140,7 @@ def test_engine_core(monkeypatch): assert len(engine_core.scheduler.running) == 0 +@fork_new_process_for_each_test def test_engine_core_advanced_sampling(monkeypatch): """ A basic end-to-end test to verify that the engine functions correctly diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 5a21806e57a1..7eac16f2cf54 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -6,6 +6,7 @@ import pytest from transformers import AutoTokenizer +from tests.utils import fork_new_process_for_each_test from vllm import SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict): break +@fork_new_process_for_each_test @pytest.mark.parametrize("multiprocessing_mode", [True, False]) def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): @@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): client.abort_requests([request.request_id]) +@fork_new_process_for_each_test @pytest.mark.asyncio async def test_engine_core_client_asyncio(monkeypatch): From 616b36ccbcdb3d272b40dcb9fa43b0fe3599453d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 9 Jan 2025 18:12:19 -0800 Subject: [PATCH 17/18] fix bug in cpu test Signed-off-by: Chen Zhang --- vllm/worker/cpu_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index fb2b22bcfbc4..c45fb24bc682 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -295,7 +295,7 @@ def _init_cache_engine(self) -> None: ] for ve in range(self.parallel_config.pipeline_parallel_size): bind_kv_cache(self.compilation_config.static_forward_context, - self.cpu_cache[ve], ve) + self.cpu_cache[ve]) self.model_runner.block_size = self.cache_engine[0].block_size assert all( From ff088e18775d817c1324eda064f235ced2386b39 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 9 Jan 2025 18:38:43 -0800 Subject: [PATCH 18/18] fix bug in cpu test Signed-off-by: Chen Zhang --- vllm/worker/cpu_worker.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index c45fb24bc682..494c6506f3c0 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -293,9 +293,8 @@ def _init_cache_engine(self) -> None: self.cache_engine[ve].cpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] - for ve in range(self.parallel_config.pipeline_parallel_size): - bind_kv_cache(self.compilation_config.static_forward_context, - self.cpu_cache[ve]) + bind_kv_cache(self.compilation_config.static_forward_context, + self.cpu_cache) self.model_runner.block_size = self.cache_engine[0].block_size assert all(