From 0b0d368ce8c3746244df2f46b2b16d538cea29b5 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 22 Dec 2022 13:05:18 +0000 Subject: [PATCH] Do not count spilled memory when comparing vs. process memory --- distributed/tests/test_worker_memory.py | 30 ++++++++++++++++++------- distributed/worker.py | 16 ++++++++----- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 6ebddb924e4..35b32e8855c 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -1158,18 +1158,32 @@ async def test_deprecated_params(s, name): @gen_cluster( client=True, - nthreads=[("", 1)], - config={ - "distributed.worker.memory.target": False, - "distributed.worker.memory.monitor-interval": "100ms", - }, + config={"distributed.worker.memory.target": False}, + worker_kwargs={"heartbeat_interval": "10ms"}, ) -async def test_warn_on_sizeof_overestimate(c, s, a): +async def test_warn_on_sizeof_overestimate(c, s, a, b): + class C: + def __sizeof__(self): + return 2**40 + + with captured_logger("distributed.worker") as log: + x = c.submit(C) + # Wait for heartbeat + while "exceeds process memory" not in log.getvalue(): + await asyncio.sleep(0.01) + + +@gen_cluster(client=True, worker_kwargs={"heartbeat_interval": "10ms"}) +async def test_warn_on_sizeof_overestimate_spill(c, s, a, b): class C: def __sizeof__(self): return 2**40 - with captured_logger("distributed.worker", level=logging.WARNING) as log: + with captured_logger("distributed.worker") as log: x = c.submit(C) - while "Managed memory exceeds process memory" not in log.getvalue(): + # Wait for heartbeat + while not s.memory.spilled: await asyncio.sleep(0.01) + + # Measure managed, not managed+spilled + assert "exceeds process memory" not in log.getvalue() diff --git a/distributed/worker.py b/distributed/worker.py index 7f94ec0d490..b159e42c4f6 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1050,12 +1050,18 @@ async def get_metrics(self) -> dict: except Exception: # TODO: log error once pass - if out["managed_bytes"] > out["memory"]: + managed = out["managed_bytes"] - out["spilled_bytes"]["memory"] + if managed > out["memory"]: + # Maybe the new managed memory was added after the latest monitor run + out["memory"] = self.monitor.get_process_memory() + if managed > out["memory"]: logger.warning( - "Managed memory exceeds process memory; this will cause premature " - "spilling as well as malfunctions in several heuristics. Please ensure " - "that sizeof() returns accurate outputs for your data. Read more: " - "https://distributed.dask.org/en/stable/worker-memory.html" + "Managed memory (%s) exceeds process memory (%s); this will cause " + "premature spilling as well as malfunctions in several heuristics. " + "Please ensure that sizeof() returns accurate outputs for your data. " + "Read more: https://distributed.dask.org/en/stable/worker-memory.html", + format_bytes(managed), + format_bytes(out["memory"]), ) return out