Skip to content
Prev Previous commit
Next Next commit
stash
  • Loading branch information
rshaw@neuralmagic.com committed Dec 31, 2024
commit dfc9deed21aeac4f1fe4dfb5714b3bf3bc6edf0a
171 changes: 139 additions & 32 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
import pickle
import queue
import signal
import threading
import time
from abc import ABC, abstractmethod
from multiprocessing.connection import Connection
from typing import List, Tuple, Type
from typing import List, Optional, Tuple, Type

import psutil
import zmq
import zmq.asyncio
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.utils import get_exception_traceback, make_zmq_socket, zmq_socket_ctx
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine import (EngineCoreAbort, EngineCoreOutput,
EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.utils import BackgroundProcHandle
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
Expand Down Expand Up @@ -127,7 +127,8 @@ def step(self) -> List[EngineCoreOutput]:
return engine_core_outputs

def shutdown(self):
self.model_executor.shutdown()
pass
# self.model_executor.shutdown()

def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
Expand Down Expand Up @@ -164,6 +165,24 @@ def __init__(
# Send Readiness signal to EngineClient.
ready_pipe.send({"status": "READY"})

@staticmethod
def make_process(
vllm_config: VllmConfig,
executor_class: Type[Executor],
input_path: str,
output_path: str,
log_stats: bool,
) -> BackgroundProcHandle:
return BackgroundProcHandle(input_path=input_path,
output_path=output_path,
process_name="EngineCore",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"executor_class": executor_class,
"log_stats": log_stats,
})

@staticmethod
def run_engine_core(*args, **kwargs):
"""Launch EngineCore busy loop in background process."""
Expand Down Expand Up @@ -260,36 +279,18 @@ def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
self.add_request(request)
elif isinstance(request, EngineCoreProfile):
self.model_executor.profile(request.is_start)
elif isinstance(request, EngineCoreAbort):
self.abort_requests(request.request_ids)
else:
# TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list)
self.abort_requests(request)
raise ValueError("Unknown request type: {request}")

def process_input_socket(self, input_path: str):
"""Input socket IO thread."""

# Msgpack serialization decoding.
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()

with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
request_type = type_frame.buffer
request_data = data_frame.buffer

# Deserialize the request data.
if request_type == EngineCoreRequestType.ADD.value:
request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
elif request_type == EngineCoreRequestType.PROFILE.value:
request = pickle.loads(request_data)
else:
raise ValueError(f"Unknown RequestType: {request_type}")

# Push to input queue for core busy loop.
request = socket.recv_pyobj()
self.input_queue.put_nowait(request)

def process_output_socket(self, output_path: str):
Expand All @@ -305,4 +306,110 @@ def process_output_socket(self, output_path: str):
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
encoder.encode_into(outputs, buffer)
socket.send_multipart((buffer, ), copy=False)
msg = (EngineCoreRequestType.FROM_ENGINE_CORE.value, buffer)
socket.send_multipart(msg, copy=False)


class EngineCoreClient(ABC):
"""Client used To interact with EngineCore."""

@abstractmethod
def get_output(self) -> List[EngineCoreOutput]:
...

@abstractmethod
def add_request(self, request: EngineCoreRequest) -> None:
...

@abstractmethod
def abort_requests(self, request_ids: List[str]) -> None:
...

@abstractmethod
def profile(self, is_start: bool = True) -> None:
...

@abstractmethod
def shutdown(self):
...


class InprocEngineCoreClient(EngineCoreClient):
"""
InprocClient: client for in-process EngineCore. Intended
for use in LLMEngine for V0-style add_request() and step()
EngineCore setup in this process (no busy loop).
* pushes EngineCoreRequest directly into the EngineCore
* pulls EngineCoreOutputs by stepping the EngineCore
"""

def __init__(self, engine_core: EngineCore):
self.engine_core = engine_core

def get_output(self) -> List[EngineCoreOutput]:
return self.engine_core.step()

def add_request(self, request: EngineCoreRequest) -> None:
self.engine_core.add_request(request)

def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)

def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start)

def shutdown(self):
self.engine_core.shutdown()


class MpEngineCoreClient(EngineCoreClient):
"""
MPClient: client for multi-proc EngineCore.
EngineCore runs in a background process busy loop, getting
new EngineCoreRequests and returning EngineCoreOutputs

* pushes EngineCoreRequests via input_socket
* pulls EngineCoreOutputs via output_socket
"""

def __init__(
self,
input_path: str,
output_path: str,
proc_handle: Optional[BackgroundProcHandle] = None,
) -> None:

# Use msgpack for hotpath serialization.
self.decoder = msgpack.Decoder(EngineCoreOutputs)

# Setup ZMQ IO.
self.ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined]
self.input_socket = make_zmq_socket(self.ctx, input_path,
zmq.constants.PUSH)
self.output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)

# Optionally hold the proc handle for cleanup at shutdown().
self.proc_handle = proc_handle

def get_output(self) -> List[EngineCoreOutput]:
# TODO(rob): use copy=False
(msg_type, msg_bytes) = self.output_socket.recv_multipart()
assert msg_type == EngineCoreRequestType.FROM_ENGINE_CORE.value
return self.decoder.decode(msg_bytes).outputs

def add_request(self, request: EngineCoreRequest) -> None:
self.input_socket.send_pyobj(request)

def abort_requests(self, request_ids: List[str]) -> None:
self.input_socket.send_pyobj(EngineCoreAbort(request_ids))

def profile(self, is_start: bool = True) -> None:
self.input_socket.send_pyobj(EngineCoreProfile(is_start))

def shutdown(self) -> None:
if hasattr(self, "ctx"):
self.ctx.destroy(linger=0)

if hasattr(self, "proc_handle") and self.proc_handle:
self.proc_handle.shutdown()
2 changes: 1 addition & 1 deletion vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)

def shutdown(self):
self.engine_core.shutdown()
pass

def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start)
Expand Down