Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
update cpu engine
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
  • Loading branch information
heheda12345 committed Jan 2, 2025
commit fa9b0bb5c4e2ad1f79bba7632ac43401ff471315
22 changes: 14 additions & 8 deletions vllm/worker/cpu_worker.py
Original file line number | Diff line number | Diff line change
@@ -6,14 +6,14 @@

import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
ModelConfig, ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, register_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -33,8 +33,8 @@ class CPUCacheEngine:
"""

def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
parallel_config: ParallelConfig,
device_config: DeviceConfig) -> None:
parallel_config: ParallelConfig, device_config: DeviceConfig,
compilation_config: CompilationConfig) -> None:
assert device_config.device_type == "cpu"
self.cache_config = cache_config
self.model_config = model_config
@@ -66,6 +66,8 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,

# Initialize the cache.
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
register_kv_cache(compilation_config.static_forward_context,
self.cpu_cache)

def _allocate_kv_cache(
self,
@@ -285,9 +287,13 @@ def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:

def _init_cache_engine(self) -> None:
self.cache_engine = [
CPUCacheEngine(self.cache_config, self.model_config,
self.parallel_config, self.device_config)
for _ in range(self.parallel_config.pipeline_parallel_size)
CPUCacheEngine(
self.cache_config,
self.model_config,
self.parallel_config,
self.device_config,
self.compilation_config,
) for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.cpu_cache = [
self.cache_engine[ve].cpu_cache
Expand Down