Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clients/python/src/taskbroker_client/worker/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ def _connect_to_host(self, host: str) -> ConsumerServiceStub:
channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets))
return ConsumerServiceStub(channel)

def emit_health_check(self) -> None:
    """
    Public entry point for emitting a health check.

    Delegates to the private ``_emit_health_check`` helper so callers
    outside this class do not have to reach into a non-public method.
    Any value the helper returns is discarded.
    """
    # NOTE(review): presumably this touches the configured health check
    # file -- the helper's body is not visible here; confirm.
    self._emit_health_check()

def update_task(
self,
processing_result: ProcessingResult,
Expand Down
44 changes: 44 additions & 0 deletions clients/python/src/taskbroker_client/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def __init__(
self._metrics = app.metrics
self._concurrency = concurrency
self._grpc_sync_event = self._mp_context.Event()
self._health_check_sec_per_touch = (
None if health_check_file_path is None else health_check_sec_per_touch
)
self._health_check_stop_event = threading.Event()
self._health_check_thread: threading.Thread | None = None

self._setstatus_backoff_seconds = 0

Expand Down Expand Up @@ -236,6 +241,42 @@ def _send_result(

return None

def _start_health_check_thread(self) -> None:
    """
    Launch the background thread that periodically emits health checks.

    Does nothing when ``_health_check_sec_per_touch`` is ``None`` or when
    a previously started health check thread is still alive. Otherwise
    the stop event is cleared and a daemon thread named
    ``push-health-check`` is started. The thread calls
    ``self.client.emit_health_check()`` in a loop until the stop event
    is set.
    """
    interval = self._health_check_sec_per_touch
    if interval is None:
        return

    existing = self._health_check_thread
    if existing is not None and existing.is_alive():
        return

    self._health_check_stop_event.clear()

    def _run_health_check_loop() -> None:
        logger.debug("taskworker.worker.health_check_thread.started")
        while True:
            if self._health_check_stop_event.is_set():
                break

            try:
                self.client.emit_health_check()
            except Exception as exc:
                # A failed emit must not kill the loop; log it and retry
                # on the next iteration.
                logger.warning(
                    "taskworker.worker.health_check.failed",
                    extra={
                        "error": exc,
                        "processing_pool": self._processing_pool_name,
                    },
                )

            # Waiting on the event (rather than sleeping) lets a stop
            # request wake the thread before the interval elapses.
            self._health_check_stop_event.wait(interval)

    thread = threading.Thread(
        name="push-health-check",
        target=_run_health_check_loop,
        daemon=True,
    )
    self._health_check_thread = thread
    thread.start()

def _stop_health_check_thread(self) -> None:
    """
    Signal the health check thread to stop and wait briefly for it.

    Always sets the stop event. If a thread reference is held, joins it
    for at most 5 seconds and then drops the reference, whether or not
    the thread actually finished within that time.
    """
    self._health_check_stop_event.set()

    thread = self._health_check_thread
    if thread is None:
        return

    thread.join(timeout=5)
    self._health_check_thread = None

def start(self) -> int:
"""
This starts the worker gRPC server.
Expand Down Expand Up @@ -294,6 +335,7 @@ def signal_handler(*args: Any) -> None:
health_servicer.set(WORKER_SERVICE_NAME, health_pb2.HealthCheckResponse.SERVING)

logger.info("taskworker.grpc_server.started", extra={"port": self._grpc_port})
self._start_health_check_thread()

try:
server.wait_for_termination()
Expand All @@ -309,6 +351,7 @@ def signal_handler(*args: Any) -> None:
if server is not None:
server.stop(grace=5)

self._stop_health_check_thread()
self.worker_pool.shutdown()

return 0
Expand All @@ -317,6 +360,7 @@ def shutdown(self) -> None:
"""
Shutdown the worker.
"""
self._stop_health_check_thread()
self._grpc_sync_event.set()
self.worker_pool.shutdown()

Expand Down
24 changes: 24 additions & 0 deletions clients/python/tests/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from concurrent.futures import Future
from datetime import datetime
from multiprocessing import Event, get_context
from pathlib import Path
from typing import Any
from unittest import TestCase, mock

Expand Down Expand Up @@ -535,6 +536,29 @@ def test_constructor_push_mode(self) -> None:
self.assertEqual(taskworker._grpc_port, 50099)


def test_push_worker_health_check_touches_while_idle(tmp_path: Path) -> None:
    """
    The health check thread keeps emitting even when no tasks arrive.

    Starts the health check thread on an otherwise idle worker with a
    very short interval and polls (for up to one second) until the
    patched ``emit_health_check`` has been called at least twice. Also
    verifies that stopping the thread clears the stored reference.
    """
    health_file = str(tmp_path / "health")
    worker = PushTaskWorker(
        app_module="examples.app:app",
        broker_service="127.0.0.1:50051",
        max_child_task_count=100,
        process_type="fork",
        health_check_file_path=health_file,
        health_check_sec_per_touch=0.01,
    )

    with mock.patch.object(worker.client, "emit_health_check") as emit_spy:
        worker._start_health_check_thread()
        try:
            # Poll instead of sleeping a fixed time so the test finishes
            # as soon as two emits have been observed.
            began = time.time()
            while True:
                if emit_spy.call_count >= 2:
                    break
                if time.time() - began >= 1:
                    break
                time.sleep(0.01)
        finally:
            worker._stop_health_check_thread()

    assert emit_spy.call_count >= 2
    assert worker._health_check_thread is None


class TestWorkerServicer(TestCase):
def test_push_task_success(self) -> None:
taskworker = PushTaskWorker(
Expand Down
Loading