Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,11 @@ async def run():

async def wait_for_nannies_to_finish():
    """Wait for all nannies to initialize and finish.

    ``nannies`` and ``signal_fired`` are free variables resolved from the
    enclosing scope (presumably the surrounding ``run()`` coroutine --
    confirm against the full file).

    Raises
    ------
    Exception
        Whatever ``asyncio.gather(*nannies)`` raised, but only when
        ``signal_fired`` is falsy.
    """
    try:
        await asyncio.gather(*nannies)
    except Exception:
        # A failure while starting is only swallowed when a signal has
        # already fired; any other start-up error is propagated unchanged.
        if not signal_fired:
            raise
    # Reached both after a clean start and after a swallowed failure.
    await asyncio.gather(*(n.finished() for n in nannies))

async def wait_for_signals_and_close():
Expand Down
4 changes: 3 additions & 1 deletion distributed/comm/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,9 @@ async def _handle_stream(self, stream, address):
try:
await self.on_connection(comm)
except CommClosedError:
logger.info("Connection from %s closed before handshake completed", address)
logger.debug(
"Connection from %s closed before handshake completed", address
)
return

await self.comm_handler(comm)
Expand Down
49 changes: 43 additions & 6 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections.abc import Awaitable, Generator
from contextlib import suppress
from inspect import isawaitable
from time import time
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from tornado import gen
Expand Down Expand Up @@ -389,28 +390,64 @@ async def _correct_state_internal(self) -> None:
# proper teardown.
await asyncio.gather(*worker_futs)

def _update_worker_status(self, op, worker_addr):
    """Handle a worker status update and schedule delayed worker removal.

    For ``op == "remove"`` the scheduler's record for ``worker_addr`` is
    snapshotted and a callback is scheduled on the running event loop after
    the ``distributed.deploy.lost-worker-timeout`` delay. The callback closes
    the locally tracked worker object only if the worker is still gone at
    that point. The update is always forwarded to the parent implementation.

    Parameters
    ----------
    op
        The operation name; only ``"remove"`` gets special handling here.
    worker_addr
        Key into ``self.scheduler_info["workers"]`` for the affected worker.

    Raises
    ------
    KeyError
        If ``op == "remove"`` and ``worker_addr`` is not present in
        ``self.scheduler_info["workers"]`` when this method is called.
    """
    if op == "remove":
        # Copy now: the callback runs later and compares against this
        # snapshot rather than whatever the scheduler info holds by then.
        worker_info = self.scheduler_info["workers"][worker_addr].copy()
        name = worker_info["name"]

        from distributed import Nanny, Worker

        def f():
            # FIXME: SpecCluster is tracking workers by `name` which are
            # not necessarily unique.
            # Clusters with Nannies (default) are susceptible to falsely
            # removing the Nannies on restart due to this logic since the
            # restart emits an op==remove signal on the worker address but
            # the SpecCluster only tracks the names, i.e. after
            # `lost-worker-timeout` the Nanny is still around and this logic
            # could trigger a false close. The below code should handle this
            # but it would be cleaner if the cluster tracked by address
            # instead of name just like the scheduler does
            if (
                name in self.workers
                and worker_addr not in self.scheduler_info["workers"]
                and not any(
                    d["name"] == name
                    for d in self.scheduler_info["workers"].values()
                )
            ):
                w = self.workers[name]

                async def remove_worker():
                    await w.close(reason=f"lost-worker-timeout-{time()}")
                    # pop() with a default: the entry may already have been
                    # removed while close() was being awaited.
                    self.workers.pop(name, None)

                # Explicit grouping: the removed record must be a Worker AND
                # the tracked object must be the one that owned that address.
                # Without the outer parentheses `and` binds tighter than
                # `or`, so the type check would not guard the Worker branch.
                if worker_info["type"] == "Worker" and (
                    (isinstance(w, Nanny) and w.worker_address == worker_addr)
                    or (isinstance(w, Worker) and w.address == worker_addr)
                ):
                    self._futures.add(
                        asyncio.create_task(
                            remove_worker(),
                            name="remove-worker-lost-worker-timeout",
                        )
                    )
                elif worker_info["type"] == "Nanny":
                    # This should never happen
                    logger.critical(
                        "Unexpected signal encountered. WorkerStatusPlugin "
                        "emitted a op==remove signal for a Nanny which "
                        "should not happen. This might cause a lingering "
                        "Nanny process."
                    )

        delay = parse_timedelta(
            dask.config.get("distributed.deploy.lost-worker-timeout")
        )

        asyncio.get_running_loop().call_later(delay, f)
    super()._update_worker_status(op, worker_addr)

def __await__(self: Self) -> Generator[Any, Any, Self]:
async def _() -> Self:
Expand Down
15 changes: 15 additions & 0 deletions distributed/deploy/tests/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
from tornado.httpclient import AsyncHTTPClient

import dask
from dask.system import CPU_COUNT

from distributed import Client, LocalCluster, Nanny, Worker, get_client
Expand Down Expand Up @@ -1285,3 +1286,17 @@ def test_localcluster_get_client(loop):
with Client(cluster) as client2:
assert client1 != client2
assert client2 == cluster.get_client()


@pytest.mark.slow()
def test_localcluster_restart(loop):
    """Repeated ``client.restart()`` calls must not change the worker count.

    Runs with ``distributed.deploy.lost-worker-timeout`` set to ``0.5s``
    (presumably so that delayed lost-worker handling can fire while the test
    is running -- confirm against the cluster implementation). The number of
    workers answering ``client.run`` is recorded once and re-checked before
    and after each of ten restarts.
    """
    short_timeout = {"distributed.deploy.lost-worker-timeout": "0.5s"}
    with dask.config.set(short_timeout):
        with LocalCluster(
            asynchronous=False, dashboard_address=":0", loop=loop
        ) as cluster:
            with cluster.get_client() as client:

                def responding_workers():
                    # One entry per worker that executed the no-op.
                    return len(client.run(lambda: None))

                nworkers = responding_workers()
                for _ in range(10):
                    assert responding_workers() == nworkers
                    client.restart()
                    assert responding_workers() == nworkers
Loading