From 3cececcbbd53e42e28a6f0e6fcdfdabf2f6d5975 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 11 May 2022 11:33:56 +0100 Subject: [PATCH 1/3] Prevent infinite transition loops --- distributed/distributed-schema.yaml | 7 ++++ distributed/distributed.yaml | 4 ++ distributed/scheduler.py | 29 ++++++++++---- distributed/tests/test_scheduler.py | 62 ++++++++++++++++++++++++++++- distributed/tests/test_stress.py | 14 ++++++- distributed/tests/test_worker.py | 1 + distributed/utils_test.py | 23 +++++------ distributed/worker.py | 39 ++++++++++++++++-- distributed/worker_state_machine.py | 6 ++- 9 files changed, 157 insertions(+), 28 deletions(-) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 942eab8ff9a..86e9acf6b6b 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -1027,6 +1027,13 @@ properties: type: boolean description: Enter Python Debugger on scheduling error + transition-counter-max: + oneOf: + - enum: [false] + - type: integer + description: Cause the scheduler or workers to break if they reach this + number of transitions + system-monitor: type: object description: | diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 74d59addb35..649ddfe2b33 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -277,6 +277,10 @@ distributed: log-length: 10000 # default length of logs to keep in memory log-format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' pdb-on-err: False # enter debug mode on scheduling error + # Cause scheduler and workers to break if they reach this many transitions. + # Used to debug infinite transition loops. + # Note: setting this will cause healthy long-running services to eventually break. + transition-counter-max: False system-monitor: interval: 500ms event-loop: tornado diff --git a/distributed/scheduler.py b/distributed/scheduler.py index ad37548766b..9641c2f5357 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1380,6 +1380,7 @@ class SchedulerState: "validate", "workers", "transition_counter", + "transition_counter_max", "plugins", "UNKNOWN_TASK_DURATION", "MEMORY_RECENT_TO_OLD_TIME", @@ -1472,6 +1473,9 @@ def __init__( / 2.0 ) self.transition_counter = 0 + self.transition_counter_max = dask.config.get( + "distributed.admin.transition-counter-max" + ) @property def memory(self) -> MemoryState: @@ -1548,16 +1552,24 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): Scheduler.transitions : transitive version of this function """ try: - recommendations = {} # type: ignore - worker_msgs = {} # type: ignore - client_msgs = {} # type: ignore - ts: TaskState = self.tasks.get(key) # type: ignore if ts is None: - return recommendations, client_msgs, worker_msgs + return {}, {}, {} start = ts._state if start == finish: - return recommendations, client_msgs, worker_msgs + return {}, {}, {} + + # Notes: + # - in case of transition through released, this counter is incremented by 2 + # - this increase happens before the actual transitions, so that it can + # catch potential infinite recursions + self.transition_counter += 1 + if self.validate and self.transition_counter_max: + assert self.transition_counter < self.transition_counter_max + + recommendations = {} # type: ignore + worker_msgs = {} # type: ignore + client_msgs = {} # type: ignore if self.plugins: dependents = set(ts.dependents) @@ -1569,7 +1581,7 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): recommendations, client_msgs, worker_msgs = func( key, stimulus_id, *args, **kwargs ) # type: ignore - self.transition_counter += 1 + elif "released" not in start_finish: assert not args and not kwargs, (args, kwargs, start_finish) a_recs: dict @@ -3294,6 +3306,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: info = super()._to_dict(exclude=exclude) extra = { "transition_log": self.transition_log, + "transition_counter": self.transition_counter, "log": self.log, "tasks": self.tasks, "task_groups": self.task_groups, @@ -4617,6 +4630,8 @@ def validate_state(self, allow_overlap: bool = False) -> None: actual_total_occupancy, self.total_occupancy, ) + if self.transition_counter_max: + assert self.transition_counter < self.transition_counter_max ################### # Manage Messages # diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 784448c588c..dbbf5c55ded 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -21,6 +21,7 @@ from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename from distributed import ( + CancelledError, Client, Event, Lock, @@ -3215,11 +3216,67 @@ async def test_computations_futures(c, s, a, b): assert "inc" in str(computation.groups) -@gen_cluster(client=True) -async def test_transition_counter(c, s, a, b): +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_transition_counter(c, s, a): assert s.transition_counter == 0 + assert a.transition_counter == 0 await c.submit(inc, 1) assert s.transition_counter > 1 + assert a.transition_counter > 1 + + +@pytest.mark.slow +@gen_cluster(client=True) +async def test_transition_counter_max_scheduler(c, s, a, b): + # This is set by @gen_cluster; it's False in production + assert s.transition_counter_max > 0 + s.transition_counter_max = 1 + with captured_logger("distributed.scheduler") as logger: + with pytest.raises(CancelledError): + await c.submit(inc, 2) + assert s.transition_counter > 1 + with pytest.raises(AssertionError): + s.validate_state() + assert "transition_counter_max" in logger.getvalue() + # Scheduler state is corrupted. Avoid test failure on gen_cluster teardown. + s.validate = False + + +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_transition_counter_max_worker(c, s, a): + # This is set by @gen_cluster; it's False in production + assert s.transition_counter_max > 0 + a.transition_counter_max = 1 + with captured_logger("distributed.core") as logger: + fut = c.submit(inc, 2) + while True: + try: + a.validate_state() + except AssertionError: + break + await asyncio.sleep(0.01) + + assert "TransitionCounterMaxExceeded" in logger.getvalue() + # Worker state is corrupted. Avoid test failure on gen_cluster teardown. + a.validate = False + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + config={"distributed.admin.transition-counter-max": False}, +) +async def test_disable_transition_counter_max(c, s, a, b): + """Test that the cluster can run indefinitely if transition_counter_max is disabled. + This is the default outside of @gen_cluster. + """ + assert s.transition_counter_max is False + assert a.transition_counter_max is False + assert await c.submit(inc, 1) == 2 + assert s.transition_counter > 1 + assert a.transition_counter > 1 + s.validate_state() + a.validate_state() @gen_cluster( @@ -3339,6 +3396,7 @@ async def test_Scheduler__to_dict(c, s, a): "status", "thread_id", "transition_log", + "transition_counter", "log", "memory", "tasks", diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py index 0502f8ee9b3..f8bd600ff6e 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -228,7 +228,12 @@ async def test_stress_steal(c, s, *workers): @pytest.mark.slow -@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=180) +@gen_cluster( + nthreads=[("", 1)] * 10, + client=True, + timeout=180, + config={"distributed.admin.transition-counter-max": 500_000}, +) async def test_close_connections(c, s, *workers): da = pytest.importorskip("dask.array") x = da.random.random(size=(1000, 1000), chunks=(1000, 1)) @@ -291,7 +296,12 @@ async def test_no_delay_during_large_transfer(c, s, w): @pytest.mark.slow -@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)] * 6) +@gen_cluster( + client=True, + Worker=Nanny, + nthreads=[("", 2)] * 6, + config={"distributed.admin.transition-counter-max": 500_000}, +) async def test_chaos_rechunk(c, s, *workers): s.allowed_failures = 10000 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index adc0c16bd08..47bf0c646f4 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3435,6 +3435,7 @@ async def test_Worker__to_dict(c, s, a): "busy_workers", "log", "stimulus_log", + "transition_counter", "tasks", "logs", "config", diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 044bd662ac0..9ca66fae3d3 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -27,14 +27,6 @@ from time import sleep from typing import Any, Generator, Literal -from distributed.compatibility import MACOS -from distributed.scheduler import Scheduler - -try: - import ssl -except ImportError: - ssl = None # type: ignore - import pytest import yaml from tlz import assoc, memoize, merge @@ -43,12 +35,12 @@ import dask -from distributed import system +from distributed import Scheduler, system from distributed import versions as version_module from distributed.client import Client, _global_clients, default_client from distributed.comm import Comm from distributed.comm.tcp import TCP -from distributed.compatibility import WINDOWS +from distributed.compatibility import MACOS, WINDOWS from distributed.config import initialize_logging from distributed.core import ( CommClosedError, @@ -79,6 +71,11 @@ ) from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker +try: + import ssl +except ImportError: + ssl = None # type: ignore + try: import dask.array # register config except ImportError: @@ -447,8 +444,6 @@ async def background_read(): def run_scheduler(q, nputs, config, port=0, **kwargs): with dask.config.set(config): - from distributed import Scheduler - # On Python 2.7 and Unix, fork() is used to spawn child processes, # so avoid inheriting the parent's IO loop. with pristine_loop() as loop: @@ -999,6 +994,7 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture worker_kwargs = merge( {"memory_limit": system.MEMORY_LIMIT, "death_timeout": 15}, worker_kwargs ) + config = merge({"distributed.admin.transition-counter-max": 50_000}, config) def _(func): if not iscoroutinefunction(func): @@ -1054,6 +1050,9 @@ async def coro(): result = await coro2 if s.validate: s.validate_state() + for w in workers: + if w.validate and hasattr(w, "validate_state"): + w.validate_state() except asyncio.TimeoutError: assert task diff --git a/distributed/worker.py b/distributed/worker.py index 601b053cadb..26585b813d4 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -129,6 +129,7 @@ TaskFinishedMsg, TaskState, TaskStateState, + TransitionCounterMaxExceeded, UniqueTaskHeap, UnpauseEvent, merge_recs_instructions, @@ -449,7 +450,8 @@ class Worker(ServerNode): target_message_size: int validate: bool _transitions_table: dict[tuple[str, str], Callable] - _transition_counter: int + transition_counter: int + transition_counter_max: int | Literal[False] incoming_count: int outgoing_count: int outgoing_current_count: int @@ -656,7 +658,10 @@ def __init__( ("waiting", "released"): self.transition_generic_released, } - self._transition_counter = 0 + self.transition_counter = 0 + self.transition_counter_max = dask.config.get( + "distributed.admin.transition-counter-max" + ) self.incoming_transfer_log = deque(maxlen=100000) self.incoming_count = 0 self.outgoing_transfer_log = deque(maxlen=100000) @@ -1113,6 +1118,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: "busy_workers": self.busy_workers, "log": self.log, "stimulus_log": self.stimulus_log, + "transition_counter": self.transition_counter, "tasks": self.tasks, "logs": self.get_logs(), "config": dask.config.config, @@ -2650,8 +2656,29 @@ def _transition( start = ts.state func = self._transitions_table.get((start, cast(str, finish))) + # Notes: + # - in case of transition through released, this counter is incremented by 2 + # - this increase happens before the actual transitions, so that it can + # catch potential infinite recursions + self.transition_counter += 1 + if ( + self.validate + and self.transition_counter_max + and self.transition_counter >= self.transition_counter_max + ): + self.log_event( + "transition-counter-max-exceeded", + { + "key": ts.key, + "start": start, + "finish": finish, + "story": self.story(ts), + "worker": self.address, + }, + ) + raise TransitionCounterMaxExceeded(ts.key, start, finish, self.story(ts)) + if func is not None: - self._transition_counter += 1 recs, instructions = func(ts, *args, stimulus_id=stimulus_id, **kwargs) self._notify_plugins("transition", ts.key, start, finish, **kwargs) @@ -4209,7 +4236,8 @@ def validate_state(self): or ts_wait in self._missing_dep_flight or ts_wait.who_has.issubset(self.in_flight_workers) ), (ts, ts_wait, self.story(ts), self.story(ts_wait)) - assert self.waiting_for_data_count == waiting_for_data_count + # FIXME https://github.com/dask/distributed/issues/6319 + # assert self.waiting_for_data_count == waiting_for_data_count for worker, keys in self.has_what.items(): for k in keys: assert worker in self.tasks[k].who_has @@ -4217,6 +4245,9 @@ def validate_state(self): for ts in self.tasks.values(): self.validate_task(ts) + if self.transition_counter_max: + assert self.transition_counter < self.transition_counter_max + except Exception as e: logger.error("Validate state failed. Closing.", exc_info=e) logger.exception(e) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index abdbc121080..14986c7f257 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -75,7 +75,7 @@ def __init__(self, key, start, finish, story): def __repr__(self): return ( - f"InvalidTransition: {self.key} :: {self.start}->{self.finish}" + f"{self.__class__.__name__}: {self.key} :: {self.start}->{self.finish}" + "\n" + " Story:\n " + "\n ".join(map(str, self.story)) @@ -84,6 +84,10 @@ def __repr__(self): __str__ = __repr__ +class TransitionCounterMaxExceeded(InvalidTransition): + pass + + @lru_cache def _default_data_size() -> int: return parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) From e4f8d9cbcd8965253d91a764e901dd1e7cfaaa86 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 11 May 2022 12:10:06 +0100 Subject: [PATCH 2/3] validate on timeout --- distributed/utils_test.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 9ca66fae3d3..88f3c4e82b2 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1048,11 +1048,7 @@ async def coro(): task = asyncio.create_task(coro) coro2 = asyncio.wait_for(asyncio.shield(task), timeout) result = await coro2 - if s.validate: - s.validate_state() - for w in workers: - if w.validate and hasattr(w, "validate_state"): - w.validate_state() + validate_state(s, *workers) except asyncio.TimeoutError: assert task @@ -1072,6 +1068,10 @@ async def coro(): while not task.cancelled(): await asyncio.sleep(0.01) + # Hopefully, the hang has been caused by inconsistent state, + # which should be much more meaningful than the timeout + validate_state(s, *workers) + # Remove as much of the traceback as possible; it's # uninteresting boilerplate from utils_test and asyncio and # not from the code being tested. @@ -1204,6 +1204,15 @@ async def dump_cluster_state( print(f"Dumped cluster state to {fname}") +def validate_state(*servers: Scheduler | Worker | Nanny) -> None: + """Run validate_state() on the Scheduler and all the Workers of the cluster. + Excludes workers wrapped by Nannies and workers manually started by the test. + """ + for s in servers: + if s.validate and hasattr(s, "validate_state"): + s.validate_state() # type: ignore + + def raises(func, exc=Exception): try: func() From 678fd021ceb9cf64d913cd74d4d73c24e83f78c6 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 12 May 2022 11:52:43 +0100 Subject: [PATCH 3/3] Don't require validate flag --- distributed/scheduler.py | 2 +- distributed/worker.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 4bcf7c48020..f236e7e0d1c 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1446,7 +1446,7 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs): # - this increase happens before the actual transitions, so that it can # catch potential infinite recursions self.transition_counter += 1 - if self.validate and self.transition_counter_max: + if self.transition_counter_max: assert self.transition_counter < self.transition_counter_max recommendations = {} # type: ignore diff --git a/distributed/worker.py b/distributed/worker.py index 9243af6e81a..653b11a2e9c 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2666,8 +2666,7 @@ def _transition( # catch potential infinite recursions self.transition_counter += 1 if ( - self.validate - and self.transition_counter_max + self.transition_counter_max and self.transition_counter >= self.transition_counter_max ): self.log_event(