From aa292955fffb9a00485618fb295eea573a35748c Mon Sep 17 00:00:00 2001 From: Ben McKerry <110857332+bmckerry@users.noreply.github.com> Date: Wed, 17 Jun 2026 12:57:36 -0400 Subject: [PATCH] fix(TaskProducer): call close() on shutdown --- clients/python/src/taskbroker_client/types.py | 11 +++++++++++ .../src/taskbroker_client/worker/producer.py | 14 ++++++++++---- clients/python/tests/worker/test_producer.py | 5 +++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/clients/python/src/taskbroker_client/types.py b/clients/python/src/taskbroker_client/types.py index fbc1f4c6..1388bc1a 100644 --- a/clients/python/src/taskbroker_client/types.py +++ b/clients/python/src/taskbroker_client/types.py @@ -1,6 +1,7 @@ import contextlib import dataclasses from collections.abc import MutableMapping +from concurrent.futures import Future from typing import Any, Callable, Protocol from arroyo.backends.abstract import ProducerFuture @@ -42,6 +43,16 @@ def produce( ) -> ProducerFuture[BrokerValue[KafkaPayload]]: ... +class CloseableProducerProtocol(Protocol): + """Interface used by TaskProducer. Represents a producer that has a shutdown method.""" + + def produce( + self, dest: Topic | Partition, payload: KafkaPayload + ) -> ProducerFuture[BrokerValue[KafkaPayload]]: ... + + def close(self) -> Future[None]: ... + + ProducerFactory = Callable[[str], ProducerProtocol] """ A factory interface for resolving topics into a KafkaProducer diff --git a/clients/python/src/taskbroker_client/worker/producer.py b/clients/python/src/taskbroker_client/worker/producer.py index f3757598..607263d9 100644 --- a/clients/python/src/taskbroker_client/worker/producer.py +++ b/clients/python/src/taskbroker_client/worker/producer.py @@ -1,3 +1,4 @@ +import atexit from collections import deque from collections.abc import Callable from concurrent.futures import Future @@ -9,7 +10,7 @@ from taskbroker_client.constants import TASK_PRODUCER_MAX_PENDING_FUTURES from taskbroker_client.metrics import MetricsBackend, NoOpMetricsBackend -from taskbroker_client.types import ProducerProtocol +from taskbroker_client.types import CloseableProducerProtocol # This is global as TaskWorker needs to be able to call TaskProducer.collect_futures() # without a reference to a task's specific instance of TaskProducer. @@ -39,17 +40,18 @@ class TaskProducer: def __init__( self, name: str, - producer_factory: Callable[[], ProducerProtocol], + producer_factory: Callable[[], CloseableProducerProtocol], metrics_backend: MetricsBackend | None = None, ) -> None: self.name = name self._producer_factory = producer_factory - self._inner_producer: ProducerProtocol | None = None + self._inner_producer: CloseableProducerProtocol | None = None self.metrics = metrics_backend if metrics_backend is not None else NoOpMetricsBackend() - def _get(self) -> ProducerProtocol: + def _get(self) -> CloseableProducerProtocol: if self._inner_producer is None: self._inner_producer = self._producer_factory() + atexit.register(self._shutdown) return self._inner_producer def track_future(self, future: ProducerFuture[BrokerValue[KafkaPayload]]) -> None: @@ -99,3 +101,7 @@ def produce( "or instantiate your producer with `use_simple_futures=False`." ) ) + + def _shutdown(self) -> None: + if self._inner_producer is not None: + self._inner_producer.close().result() diff --git a/clients/python/tests/worker/test_producer.py b/clients/python/tests/worker/test_producer.py index 2efb42fb..e8eafd4b 100644 --- a/clients/python/tests/worker/test_producer.py +++ b/clients/python/tests/worker/test_producer.py @@ -36,6 +36,11 @@ def produce( future.set_result(make_broker_value()) return future + def close(self) -> Future[None]: + f: Future[None] = Future() + f.set_result(None) + return f + def get_dummy_producer(use_simple_futures: bool) -> DummyProducer: return DummyProducer(use_simple_futures=use_simple_futures)