Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 20 additions & 31 deletions distributed/http/scheduler/tests/test_scheduler_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from distributed import Lock
from distributed.utils import is_valid_xml
from distributed.utils_test import gen_cluster, inc, lock_inc, slowinc
from distributed.utils_test import fetch_metrics, gen_cluster, inc, lock_inc, slowinc

DEFAULT_ROUTES = dask.config.get("distributed.scheduler.http.routes")

Expand Down Expand Up @@ -89,43 +89,32 @@ async def test_prefix(c, s, a, b):
@gen_cluster(client=True, clean_kwargs={"threads": False})
async def test_prometheus(c, s, a, b):
pytest.importorskip("prometheus_client")
from prometheus_client.parser import text_string_to_metric_families

http_client = AsyncHTTPClient()
active_metrics = await fetch_metrics(s.http_server.port, "dask_scheduler_")

# request data twice since there once was a case where metrics got registered multiple times resulting in
# prometheus_client errors
for _ in range(2):
response = await http_client.fetch(
"http://localhost:%d/metrics" % s.http_server.port
)
assert response.code == 200
assert response.headers["Content-Type"] == "text/plain; version=0.0.4"
expected_metrics = {
"dask_scheduler_clients",
"dask_scheduler_desired_workers",
"dask_scheduler_workers",
"dask_scheduler_tasks",
"dask_scheduler_tasks_suspicious",
"dask_scheduler_tasks_forgotten",
}

txt = response.body.decode("utf8")
families = {
family.name: family for family in text_string_to_metric_families(txt)
}
assert "dask_scheduler_workers" in families
assert active_metrics.keys() == expected_metrics
assert active_metrics["dask_scheduler_clients"].samples[0].value == 1.0

client = families["dask_scheduler_clients"]
assert client.samples[0].value == 1.0
# request data twice since there once was a case where metrics got registered multiple times resulting in
# prometheus_client errors
await fetch_metrics(s.http_server.port, "dask_scheduler_")


@gen_cluster(client=True, clean_kwargs={"threads": False})
async def test_prometheus_collect_task_states(c, s, a, b):
pytest.importorskip("prometheus_client")
from prometheus_client.parser import text_string_to_metric_families

http_client = AsyncHTTPClient()

async def fetch_metrics():
port = s.http_server.port
response = await http_client.fetch(f"http://localhost:{port}/metrics")
txt = response.body.decode("utf8")
families = {
family.name: family for family in text_string_to_metric_families(txt)
}
async def fetch_state_metrics():
families = await fetch_metrics(s.http_server.port, prefix="dask_scheduler_")

active_metrics = {
sample.labels["state"]: sample.value
Expand All @@ -142,7 +131,7 @@ async def fetch_metrics():
# Ensure that we get full zero metrics for all states even though the
# scheduler did nothing, yet
assert not s.tasks
active_metrics, forgotten_tasks = await fetch_metrics()
active_metrics, forgotten_tasks = await fetch_state_metrics()
assert active_metrics.keys() == expected
assert sum(active_metrics.values()) == 0.0
assert sum(forgotten_tasks) == 0.0
Expand All @@ -152,7 +141,7 @@ async def fetch_metrics():
while not any(future.key in w.state.tasks for w in [a, b]):
await asyncio.sleep(0.001)

active_metrics, forgotten_tasks = await fetch_metrics()
active_metrics, forgotten_tasks = await fetch_state_metrics()
assert active_metrics.keys() == expected
assert sum(active_metrics.values()) == 1.0
assert sum(forgotten_tasks) == 0.0
Expand All @@ -165,7 +154,7 @@ async def fetch_metrics():
while any(future.key in w.state.tasks for w in [a, b]):
await asyncio.sleep(0.001)

active_metrics, forgotten_tasks = await fetch_metrics()
active_metrics, forgotten_tasks = await fetch_state_metrics()
assert active_metrics.keys() == expected
assert sum(active_metrics.values()) == 0.0
assert sum(forgotten_tasks) == 0.0
Expand Down
29 changes: 7 additions & 22 deletions distributed/http/scheduler/tests/test_semaphore_http.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,16 @@
from __future__ import annotations

import pytest
from tornado.httpclient import AsyncHTTPClient

from distributed import Semaphore
from distributed.utils_test import gen_cluster
from distributed.utils_test import fetch_metrics, gen_cluster


@gen_cluster(client=True, clean_kwargs={"threads": False})
async def test_prometheus_collect_task_states(c, s, a, b):
async def test_prometheus(c, s, a, b):
pytest.importorskip("prometheus_client")
from prometheus_client.parser import text_string_to_metric_families

http_client = AsyncHTTPClient()

async def fetch_metrics():
port = s.http_server.port
response = await http_client.fetch(f"http://localhost:{port}/metrics")
txt = response.body.decode("utf8")
families = {
family.name: family
for family in text_string_to_metric_families(txt)
if family.name.startswith("dask_semaphore_")
}
return families

active_metrics = await fetch_metrics()
active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_")

expected_metrics = {
"dask_semaphore_max_leases",
Expand All @@ -42,7 +27,7 @@ async def fetch_metrics():

sem = await Semaphore(name="test", max_leases=2)

active_metrics = await fetch_metrics()
active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_")
assert active_metrics.keys() == expected_metrics
# Assert values are set upon initialization
for name, v in active_metrics.items():
Expand All @@ -56,7 +41,7 @@ async def fetch_metrics():
assert sample.value == 0

assert await sem.acquire()
active_metrics = await fetch_metrics()
active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_")
assert active_metrics["dask_semaphore_max_leases"].samples[0].value == 2
assert active_metrics["dask_semaphore_active_leases"].samples[0].value == 1
assert (
Expand All @@ -68,7 +53,7 @@ async def fetch_metrics():
assert active_metrics["dask_semaphore_pending_leases"].samples[0].value == 0

assert await sem.release() is True
active_metrics = await fetch_metrics()
active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_")
assert active_metrics["dask_semaphore_max_leases"].samples[0].value == 2
assert active_metrics["dask_semaphore_active_leases"].samples[0].value == 0
assert (
Expand All @@ -80,7 +65,7 @@ async def fetch_metrics():
assert active_metrics["dask_semaphore_pending_leases"].samples[0].value == 0

await sem.close()
active_metrics = await fetch_metrics()
active_metrics = await fetch_metrics(s.http_server.port, "dask_semaphore_")
assert active_metrics.keys() == expected_metrics
for v in active_metrics.values():
assert v.samples == []
18 changes: 11 additions & 7 deletions distributed/http/worker/prometheus/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from typing import ClassVar

from distributed.http.prometheus import PrometheusCollector
from distributed.http.utils import RequestHandler
Expand Down Expand Up @@ -77,19 +78,22 @@ def collect(self):


class PrometheusHandler(RequestHandler):
_initialized = False
_collector: ClassVar[WorkerMetricCollector | None] = None

def __init__(self, *args, **kwargs):
def __init__(self, *args, dask_server=None, **kwargs):
import prometheus_client

super().__init__(*args, **kwargs)
super().__init__(*args, dask_server=dask_server, **kwargs)

if PrometheusHandler._initialized:
if PrometheusHandler._collector:
# Especially during testing, multiple workers are started
# sequentially in the same python process
PrometheusHandler._collector.server = self.server
return

prometheus_client.REGISTRY.register(WorkerMetricCollector(self.server))

PrometheusHandler._initialized = True
PrometheusHandler._collector = WorkerMetricCollector(self.server)
# Register collector
prometheus_client.REGISTRY.register(PrometheusHandler._collector)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are multiple concurrent Worker instances in the same process (very common with async workers in all gen_cluster tests), this means PrometheusHandler._collector will only be set to the WorkerMetricCollector of the last worker to start. Is that okay? Does the prometheus_client.REGISTRY.register(PrometheusHandler._collector) call mean that all the collectors will be registered somewhere anyway?

I don't know what either _collector or prometheus_client.REGISTRY are for, I'm just seeing some global variables getting overwritten in a place where I know there might not just be multiple workers sequentially, but also multiple workers in parallel.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way this had been implemented before was that we'd simply set up a PrometheusHandler for the first worker spinning up in the process, set _initialized to True and call it a day. This led to test_prometheus_collect_task_states failing because that PrometheusHandler did not belong to the newly spun-up worker when running multiple tests. This new pattern surely isn't perfect, as we only ever have a single instance of a PrometheusHandler per process talking to the last worker we spun up. However, IMO this is better than what we had before, and it would require some more serious thinking about the Prometheus client to handle parallel workers gracefully. I'm happy to file a follow-up issue for that.

Regarding the singleton implementation: I've stolen that from https://github.com/hendrikmakait/distributed/blob/13e315c11c7277ba0cadd9f5c2a16364cdeaf14b/distributed/http/scheduler/prometheus/core.py#L79

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was pretty painful when we started testing this stuff. I'm OK with either solution and don't think we need anything more sophisticated for now.

As was already pointed out, this is merely an artifact of testing, and no real-world application would ever run multiple workers in the same process AND use prometheus. At the very least, this is something we can then choose not to properly support.

TLDR As long as the testing works out, I'm happy


def get(self):
import prometheus_client
Expand Down
96 changes: 82 additions & 14 deletions distributed/http/worker/tests/test_worker_http.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,100 @@
from __future__ import annotations

import asyncio
import json

import pytest
from tornado.httpclient import AsyncHTTPClient

from distributed.utils_test import gen_cluster
from distributed import Event
from distributed.utils_test import fetch_metrics, gen_cluster


@gen_cluster(client=True)
async def test_prometheus(c, s, a, b):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_prometheus(c, s, a):
pytest.importorskip("prometheus_client")
from prometheus_client.parser import text_string_to_metric_families

http_client = AsyncHTTPClient()
active_metrics = await fetch_metrics(a.http_server.port, prefix="dask_worker_")

expected_metrics = {
"dask_worker_tasks",
"dask_worker_concurrent_fetch_requests",
"dask_worker_threads",
"dask_worker_latency_seconds",
}

try:
import crick # noqa: F401
except ImportError:
pass
else:
expected_metrics = expected_metrics.union(
{
"dask_worker_tick_duration_median_seconds",
"dask_worker_task_duration_median_seconds",
"dask_worker_transfer_bandwidth_median_bytes",
}
)

assert active_metrics.keys() == expected_metrics

# request data twice since there once was a case where metrics got registered
# multiple times resulting in prometheus_client errors
for _ in range(2):
response = await http_client.fetch(
"http://localhost:%d/metrics" % a.http_server.port
)
assert response.code == 200
assert response.headers["Content-Type"] == "text/plain; version=0.0.4"
await fetch_metrics(a.http_server.port, prefix="dask_worker_")


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_prometheus_collect_task_states(c, s, a):
pytest.importorskip("prometheus_client")

async def fetch_state_metrics():
families = await fetch_metrics(a.http_server.port, prefix="dask_worker_")
active_metrics = {
sample.labels["state"]: sample.value
for sample in families["dask_worker_tasks"].samples
}
return active_metrics

expected_metrics = {"stored", "executing", "ready", "waiting"}
assert not a.state.tasks
active_metrics = await fetch_state_metrics()
assert active_metrics == {
"stored": 0.0,
"executing": 0.0,
"ready": 0.0,
"waiting": 0.0,
}

ev = Event()

# submit a task which should show up in the prometheus scraping
future = c.submit(ev.wait)
while not a.state.executing:
await asyncio.sleep(0.001)

active_metrics = await fetch_state_metrics()
assert active_metrics == {
"stored": 0.0,
"executing": 1.0,
"ready": 0.0,
"waiting": 0.0,
}

await ev.set()
await c.gather(future)

future.release()

while future.key in a.state.tasks:
await asyncio.sleep(0.001)

txt = response.body.decode("utf8")
families = {familiy.name for familiy in text_string_to_metric_families(txt)}
assert "dask_worker_latency_seconds" in families
active_metrics = await fetch_state_metrics()
assert active_metrics == {
"stored": 0.0,
"executing": 0.0,
"ready": 0.0,
"waiting": 0.0,
}


@gen_cluster(client=True)
Expand Down
16 changes: 16 additions & 0 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import pytest
import yaml
from tlz import assoc, memoize, merge
from tornado.httpclient import AsyncHTTPClient
from tornado.ioloop import IOLoop

import dask
Expand Down Expand Up @@ -2440,3 +2441,18 @@ def _bind_port(port):
raise TimeoutError(f"Default ports didn't open up in time for {name_of_test}")

yield


async def fetch_metrics(port: int, prefix: str | None = None) -> dict[str, Any]:
    """Scrape the ``/metrics`` endpoint on ``localhost`` and parse the result.

    Parameters
    ----------
    port
        Port of the HTTP server to query; the request goes to
        ``http://localhost:{port}/metrics``.
    prefix
        If given, only metric families whose name starts with this prefix
        are kept. ``None`` (the default) keeps every family.

    Returns
    -------
    dict
        Mapping of metric family name to the family object yielded by
        ``prometheus_client.parser.text_string_to_metric_families``.

    Raises
    ------
    AssertionError
        If the response status code is not 200.
    """
    from prometheus_client.parser import text_string_to_metric_families

    response = await AsyncHTTPClient().fetch(f"http://localhost:{port}/metrics")
    assert response.code == 200

    # The response body is bytes; the parser expects text.
    payload = response.body.decode("utf8")

    selected = {}
    for metric_family in text_string_to_metric_families(payload):
        family_name = metric_family.name
        if prefix is None or family_name.startswith(prefix):
            selected[family_name] = metric_family
    return selected