From b172ed819ae390af566a2b72e9abfc6e49364d45 Mon Sep 17 00:00:00 2001 From: Viraat Chandra Date: Wed, 18 Feb 2026 15:47:39 -0800 Subject: [PATCH 1/2] updates --- .../async_utils/transport/zmq/transport.py | 45 ++++++++++++++----- .../async_utils/transport/test_zmq.py | 31 ++++--------- 2 files changed, 44 insertions(+), 32 deletions(-) diff --git a/src/inference_endpoint/async_utils/transport/zmq/transport.py b/src/inference_endpoint/async_utils/transport/zmq/transport.py index 669ecf9d..caf56876 100644 --- a/src/inference_endpoint/async_utils/transport/zmq/transport.py +++ b/src/inference_endpoint/async_utils/transport/zmq/transport.py @@ -95,6 +95,19 @@ class _ZMQSocketConfig: recv_buffer_size: int = 4 * 1024 * 1024 # 4MB send_buffer_size: int = 4 * 1024 * 1024 # 4MB + def apply_recv(self, sock: zmq.Socket) -> None: + """Apply receiver socket options.""" + sock.setsockopt(zmq.LINGER, self.linger) + sock.setsockopt(zmq.RCVHWM, self.high_water_mark) + sock.setsockopt(zmq.RCVBUF, self.recv_buffer_size) + + def apply_send(self, sock: zmq.Socket) -> None: + """Apply sender socket options.""" + sock.setsockopt(zmq.LINGER, self.linger) + sock.setsockopt(zmq.SNDHWM, self.high_water_mark) + sock.setsockopt(zmq.SNDBUF, self.send_buffer_size) + sock.setsockopt(zmq.IMMEDIATE, self.immediate) + class _ZmqReceiverTransport(ReceiverTransport): """ @@ -122,6 +135,8 @@ class _ZmqReceiverTransport(ReceiverTransport): "_waiter", "_closing", "_soon_call", + "_recv_buf", + "_recv_view", ) def __init__( @@ -140,6 +155,13 @@ def __init__( self._closing = False self._soon_call: asyncio.Handle | None = None + # NOTE(vir): + # zmq recv_into with Pre-allocated buffer. + # msgspec can decode in-place, avoiding per-message bytes allocation. + recv_buf_size = sock.getsockopt(zmq.RCVBUF) + self._recv_buf = bytearray(recv_buf_size) + self._recv_view = memoryview(self._recv_buf) + self._loop.add_reader(self._fd, self._on_readable) def _on_readable(self) -> None: @@ -170,10 +192,18 @@ def _on_readable(self) -> None: return count = 0 + recv_buf = self._recv_buf + recv_view = self._recv_view + buf_len = len(recv_buf) try: while True: - data = self._sock.recv(zmq.NOBLOCK, copy=False, track=False) - self._deque.append(self._decoder.decode(data)) + nbytes = self._sock.recv_into(recv_buf, flags=zmq.NOBLOCK) + if nbytes > buf_len: + raise RuntimeError( + f"ZMQ message truncated ({nbytes} > {buf_len} bytes). " + f"Increase recv_buffer_size in _ZMQSocketConfig." + ) + self._deque.append(self._decoder.decode(recv_view[:nbytes])) count += 1 except zmq.Again: # Normal: no more messages @@ -181,7 +211,7 @@ def _on_readable(self) -> None: except zmq.ZMQError as e: if e.errno not in (errno.EAGAIN, errno.EINTR, errno.ENOTSOCK): logger.error(f"ZMQ recv error: {e}") - except Exception as e: + except msgspec.DecodeError as e: logger.error(f"Decode error: {e}") # Wake waiter once after draining (not per message) @@ -402,9 +432,7 @@ def _create_receiver( Configured receiver transport. """ sock = zmq_context.socket(zmq.PULL) - sock.setsockopt(zmq.LINGER, config.linger) - sock.setsockopt(zmq.RCVHWM, config.high_water_mark) - sock.setsockopt(zmq.RCVBUF, config.recv_buffer_size) + config.apply_recv(sock) if bind: sock.bind(address) @@ -429,10 +457,7 @@ def _create_sender( ) -> _ZmqSenderTransport: """Create a ZMQ sender transport.""" sock = zmq_context.socket(zmq.PUSH) - sock.setsockopt(zmq.LINGER, config.linger) - sock.setsockopt(zmq.SNDHWM, config.high_water_mark) - sock.setsockopt(zmq.SNDBUF, config.send_buffer_size) - sock.setsockopt(zmq.IMMEDIATE, config.immediate) + config.apply_send(sock) if bind: sock.bind(address) diff --git a/tests/performance/async_utils/transport/test_zmq.py b/tests/performance/async_utils/transport/test_zmq.py index 72f8792f..49d16d53 100644 --- a/tests/performance/async_utils/transport/test_zmq.py +++ b/tests/performance/async_utils/transport/test_zmq.py @@ -28,11 +28,11 @@ import msgspec import pytest import uvloop -import zmq from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.async_utils.transport.zmq.transport import ( - _ZmqReceiverTransport, - _ZmqSenderTransport, + _ZMQSocketConfig, + _create_receiver, + _create_sender, ) from inference_endpoint.core.types import Query, QueryResult, StreamChunk @@ -44,7 +44,6 @@ TEST_DURATION_SECONDS = 5.0 WARMUP_MESSAGES = 100 -BUFFER_SIZE = 10 * 1024 * 1024 # Payload sizes in chars PAYLOAD_SIZES_CHARS = [32, 128, 512, 1024, 4096, 16384, 32768] @@ -140,27 +139,15 @@ async def benchmark( loop = asyncio.get_running_loop() - with ManagedZMQContext.scoped(io_threads=4) as zmq_ctx: + config = _ZMQSocketConfig() + + with ManagedZMQContext.scoped(io_threads=config.io_threads) as zmq_ctx: with tempfile.TemporaryDirectory(prefix="zmq_") as tmp: addr = f"ipc://{tmp}/bench" - # Sender (main proc perspective) - push = zmq_ctx.socket(zmq.PUSH) - push.setsockopt(zmq.LINGER, -1) - push.setsockopt(zmq.SNDHWM, 0) - push.setsockopt(zmq.SNDBUF, BUFFER_SIZE) - push.setsockopt(zmq.IMMEDIATE, 1) - push.bind(addr) - sender = _ZmqSenderTransport(loop, push, msgspec.msgpack.Encoder()) - - # Receiver (worker proc perspective) - pull = zmq_ctx.socket(zmq.PULL) - pull.setsockopt(zmq.LINGER, -1) - pull.setsockopt(zmq.RCVHWM, 0) - pull.setsockopt(zmq.RCVBUF, BUFFER_SIZE) - pull.connect(addr) - receiver = _ZmqReceiverTransport( - loop, pull, msgspec.msgpack.Decoder(type=msg_type) + sender = _create_sender(loop, addr, zmq_ctx, config, bind=True) + receiver = _create_receiver( + loop, addr, zmq_ctx, config, msg_type, bind=False ) await asyncio.sleep(0.01) From 7d0926e4d90e9b65440acee88454e1845ddbc790 Mon Sep 17 00:00:00 2001 From: Viraat Chandra Date: Mon, 9 Mar 2026 15:50:08 -0700 Subject: [PATCH 2/2] ci fix --- tests/performance/async_utils/transport/test_zmq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/performance/async_utils/transport/test_zmq.py b/tests/performance/async_utils/transport/test_zmq.py index 49d16d53..ecf9015e 100644 --- a/tests/performance/async_utils/transport/test_zmq.py +++ b/tests/performance/async_utils/transport/test_zmq.py @@ -30,9 +30,9 @@ import uvloop from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.async_utils.transport.zmq.transport import ( - _ZMQSocketConfig, _create_receiver, _create_sender, + _ZMQSocketConfig, ) from inference_endpoint.core.types import Query, QueryResult, StreamChunk