diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index c2d75d8206f..1b36807f133 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -17,7 +17,7 @@ from distributed import Lock from distributed.utils import is_valid_xml -from distributed.utils_test import gen_cluster, inc, lock_inc, slowinc +from distributed.utils_test import fetch_metrics, gen_cluster, inc, lock_inc, slowinc DEFAULT_ROUTES = dask.config.get("distributed.scheduler.http.routes") @@ -89,43 +89,32 @@ async def test_prefix(c, s, a, b): @gen_cluster(client=True, clean_kwargs={"threads": False}) async def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") - from prometheus_client.parser import text_string_to_metric_families - http_client = AsyncHTTPClient() + active_metrics = await fetch_metrics(s.http_server.port, "dask_scheduler_") - # request data twice since there once was a case where metrics got registered multiple times resulting in - # prometheus_client errors - for _ in range(2): - response = await http_client.fetch( - "http://localhost:%d/metrics" % s.http_server.port - ) - assert response.code == 200 - assert response.headers["Content-Type"] == "text/plain; version=0.0.4" + expected_metrics = { + "dask_scheduler_clients", + "dask_scheduler_desired_workers", + "dask_scheduler_workers", + "dask_scheduler_tasks", + "dask_scheduler_tasks_suspicious", + "dask_scheduler_tasks_forgotten", + } - txt = response.body.decode("utf8") - families = { - family.name: family for family in text_string_to_metric_families(txt) - } - assert "dask_scheduler_workers" in families + assert active_metrics.keys() == expected_metrics + assert active_metrics["dask_scheduler_clients"].samples[0].value == 1.0 - client = families["dask_scheduler_clients"] - assert client.samples[0].value == 1.0 + # request data twice since there once was a case where metrics got registered multiple times resulting in + # prometheus_client errors + await fetch_metrics(s.http_server.port, "dask_scheduler_") @gen_cluster(client=True, clean_kwargs={"threads": False}) async def test_prometheus_collect_task_states(c, s, a, b): pytest.importorskip("prometheus_client") - from prometheus_client.parser import text_string_to_metric_families - http_client = AsyncHTTPClient() - - async def fetch_metrics(): - port = s.http_server.port - response = await http_client.fetch(f"http://localhost:{port}/metrics") - txt = response.body.decode("utf8") - families = { - family.name: family for family in text_string_to_metric_families(txt) - } + async def fetch_state_metrics(): + families = await fetch_metrics(s.http_server.port, prefix="dask_scheduler_") active_metrics = { sample.labels["state"]: sample.value @@ -142,7 +131,7 @@ async def fetch_metrics(): # Ensure that we get full zero metrics for all states even though the # scheduler did nothing, yet assert not s.tasks - active_metrics, forgotten_tasks = await fetch_metrics() + active_metrics, forgotten_tasks = await fetch_state_metrics() assert active_metrics.keys() == expected assert sum(active_metrics.values()) == 0.0 assert sum(forgotten_tasks) == 0.0 @@ -152,7 +141,7 @@ async def fetch_metrics(): while not any(future.key in w.state.tasks for w in [a, b]): await asyncio.sleep(0.001) - active_metrics, forgotten_tasks = await fetch_metrics() + active_metrics, forgotten_tasks = await fetch_state_metrics() assert active_metrics.keys() == expected assert sum(active_metrics.values()) == 1.0 assert sum(forgotten_tasks) == 0.0 @@ -165,7 +154,7 @@ async def fetch_metrics(): while any(future.key in w.state.tasks for w in [a, b]): await asyncio.sleep(0.001) - active_metrics, forgotten_tasks = await fetch_metrics() + active_metrics, forgotten_tasks = await fetch_state_metrics() assert active_metrics.keys() == expected assert sum(active_metrics.values()) == 0.0 assert sum(forgotten_tasks) == 0.0 diff --git a/distributed/http/scheduler/tests/test_semaphore_http.py b/distributed/http/scheduler/tests/test_semaphore_http.py index 93ac888f237..cae0395487e 100644 --- a/distributed/http/scheduler/tests/test_semaphore_http.py +++ b/distributed/http/scheduler/tests/test_semaphore_http.py @@ -1,31 +1,16 @@ from __future__ import annotations import pytest -from tornado.httpclient import AsyncHTTPClient from distributed import Semaphore -from distributed.utils_test import gen_cluster +from distributed.utils_test import fetch_metrics, gen_cluster @gen_cluster(client=True, clean_kwargs={"threads": False}) -async def test_prometheus_collect_task_states(c, s, a, b): +async def test_prometheus(c, s, a, b): pytest.importorskip("prometheus_client") - from prometheus_client.parser import text_string_to_metric_families - http_client = AsyncHTTPClient() - - async def fetch_metrics(): - port = s.http_server.port - response = await http_client.fetch(f"http://localhost:{port}/metrics") - txt = response.body.decode("utf8") - families = { - family.name: family - for family in text_string_to_metric_families(txt) - if family.name.startswith("dask_semaphore_") - } - return families - - active_metrics = await fetch_metrics() + active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_") expected_metrics = { "dask_semaphore_max_leases", @@ -42,7 +27,7 @@ async def fetch_metrics(): sem = await Semaphore(name="test", max_leases=2) - active_metrics = await fetch_metrics() + active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_") assert active_metrics.keys() == expected_metrics # Assert values are set upon intialization for name, v in active_metrics.items(): @@ -56,7 +41,7 @@ async def fetch_metrics(): assert sample.value == 0 assert await sem.acquire() - active_metrics = await fetch_metrics() + active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_") assert active_metrics["dask_semaphore_max_leases"].samples[0].value == 2 assert active_metrics["dask_semaphore_active_leases"].samples[0].value == 1 assert ( @@ -68,7 +53,7 @@ async def fetch_metrics(): assert active_metrics["dask_semaphore_pending_leases"].samples[0].value == 0 assert await sem.release() is True - active_metrics = await fetch_metrics() + active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_") assert active_metrics["dask_semaphore_max_leases"].samples[0].value == 2 assert active_metrics["dask_semaphore_active_leases"].samples[0].value == 0 assert ( @@ -80,7 +65,7 @@ async def fetch_metrics(): assert active_metrics["dask_semaphore_pending_leases"].samples[0].value == 0 await sem.close() - active_metrics = await fetch_metrics() + active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_") assert active_metrics.keys() == expected_metrics for v in active_metrics.values(): assert v.samples == [] diff --git a/distributed/http/worker/prometheus/core.py b/distributed/http/worker/prometheus/core.py index f95cf415b9f..8e40957bf83 100644 --- a/distributed/http/worker/prometheus/core.py +++ b/distributed/http/worker/prometheus/core.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from typing import ClassVar from distributed.http.prometheus import PrometheusCollector from distributed.http.utils import RequestHandler @@ -77,19 +78,22 @@ def collect(self): class PrometheusHandler(RequestHandler): - _initialized = False + _collector: ClassVar[WorkerMetricCollector | None] = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, dask_server=None, **kwargs): import prometheus_client - super().__init__(*args, **kwargs) + super().__init__(*args, dask_server=dask_server, **kwargs) - if PrometheusHandler._initialized: + if PrometheusHandler._collector: + # Especially during testing, multiple workers are started + # sequentially in the same python process + PrometheusHandler._collector.server = self.server return - prometheus_client.REGISTRY.register(WorkerMetricCollector(self.server)) - - PrometheusHandler._initialized = True + PrometheusHandler._collector = WorkerMetricCollector(self.server) + # Register collector + prometheus_client.REGISTRY.register(PrometheusHandler._collector) def get(self): import prometheus_client diff --git a/distributed/http/worker/tests/test_worker_http.py b/distributed/http/worker/tests/test_worker_http.py index b0fc7c1227c..02d3d4dba59 100644 --- a/distributed/http/worker/tests/test_worker_http.py +++ b/distributed/http/worker/tests/test_worker_http.py @@ -1,32 +1,100 @@ from __future__ import annotations +import asyncio import json import pytest from tornado.httpclient import AsyncHTTPClient -from distributed.utils_test import gen_cluster +from distributed import Event +from distributed.utils_test import fetch_metrics, gen_cluster -@gen_cluster(client=True) -async def test_prometheus(c, s, a, b): +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +async def test_prometheus(c, s, a): pytest.importorskip("prometheus_client") - from prometheus_client.parser import text_string_to_metric_families - http_client = AsyncHTTPClient() + active_metrics = await fetch_metrics(a.http_server.port, prefix="dask_worker_") + + expected_metrics = { + "dask_worker_tasks", + "dask_worker_concurrent_fetch_requests", + "dask_worker_threads", + "dask_worker_latency_seconds", + } + + try: + import crick # noqa: F401 + except ImportError: + pass + else: + expected_metrics = expected_metrics.union( + { + "dask_worker_tick_duration_median_seconds", + "dask_worker_task_duration_median_seconds", + "dask_worker_transfer_bandwidth_median_bytes", + } + ) + + assert active_metrics.keys() == expected_metrics # request data twice since there once was a case where metrics got registered # multiple times resulting in prometheus_client errors - for _ in range(2): - response = await http_client.fetch( - "http://localhost:%d/metrics" % a.http_server.port - ) - assert response.code == 200 - assert response.headers["Content-Type"] == "text/plain; version=0.0.4" + await fetch_metrics(a.http_server.port, prefix="dask_worker_") + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +async def test_prometheus_collect_task_states(c, s, a): + pytest.importorskip("prometheus_client") + + async def fetch_state_metrics(): + families = await fetch_metrics(a.http_server.port, prefix="dask_worker_") + active_metrics = { + sample.labels["state"]: sample.value + for sample in families["dask_worker_tasks"].samples + } + return active_metrics + + expected_metrics = {"stored", "executing", "ready", "waiting"} + assert not a.state.tasks + active_metrics = await fetch_state_metrics() + assert active_metrics == { + "stored": 0.0, + "executing": 0.0, + "ready": 0.0, + "waiting": 0.0, + } + + ev = Event() + + # submit a task which should show up in the prometheus scraping + future = c.submit(ev.wait) + while not a.state.executing: + await asyncio.sleep(0.001) + + active_metrics = await fetch_state_metrics() + assert active_metrics == { + "stored": 0.0, + "executing": 1.0, + "ready": 0.0, + "waiting": 0.0, + } + + await ev.set() + await c.gather(future) + + future.release() + + while future.key in a.state.tasks: + await asyncio.sleep(0.001) - txt = response.body.decode("utf8") - families = {familiy.name for familiy in text_string_to_metric_families(txt)} - assert "dask_worker_latency_seconds" in families + active_metrics = await fetch_state_metrics() + assert active_metrics == { + "stored": 0.0, + "executing": 0.0, + "ready": 0.0, + "waiting": 0.0, + } @gen_cluster(client=True) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 450290a0b81..da02ee9f9d2 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -33,6 +33,7 @@ import pytest import yaml from tlz import assoc, memoize, merge +from tornado.httpclient import AsyncHTTPClient from tornado.ioloop import IOLoop import dask @@ -2440,3 +2441,18 @@ def _bind_port(port): raise TimeoutError(f"Default ports didn't open up in time for {name_of_test}") yield + + +async def fetch_metrics(port: int, prefix: str | None = None) -> dict[str, Any]: + from prometheus_client.parser import text_string_to_metric_families + + http_client = AsyncHTTPClient() + response = await http_client.fetch(f"http://localhost:{port}/metrics") + assert response.code == 200 + txt = response.body.decode("utf8") + families = { + family.name: family + for family in text_string_to_metric_families(txt) + if prefix is None or family.name.startswith(prefix) + } + return families