Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 42 additions & 24 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def __init__(
self.digests = None
self._ongoing_background_tasks = AsyncTaskGroup()
self._event_finished = asyncio.Event()
self._event_started = asyncio.Event()

self.listeners = []
self.io_loop = self.loop = IOLoop.current()
Expand Down Expand Up @@ -418,6 +419,7 @@ def set_thread_ident():
self.io_loop.add_callback(set_thread_ident)
self._startup_lock = asyncio.Lock()
self.__startup_exc = None
self.__startup_task = None

self.rpc = ConnectionPool(
limit=connection_limit,
Expand Down Expand Up @@ -489,9 +491,17 @@ async def finished(self):
"""Wait until the server has finished"""
await self._event_finished.wait()

async def started(self):
await self._event_started.wait()

def __await__(self):
return self.start().__await__()

async def cancel_start(self):
if self.__startup_task:
self.__startup_task.cancel()
await self.started()

async def start_unsafe(self):
"""Attempt to start the server. This is not idempotent and not protected against concurrent startup attempts.

Expand All @@ -507,30 +517,35 @@ async def start_unsafe(self):

@final
async def start(self):
async with self._startup_lock:
if self.status == Status.failed:
assert self.__startup_exc is not None
raise self.__startup_exc
elif self.status != Status.init:
return self
timeout = getattr(self, "death_timeout", None)

async def _close_on_failure(exc: Exception) -> None:
await self.close()
self.status = Status.failed
self.__startup_exc = exc
if self.status == Status.failed:
assert self.__startup_exc is not None
raise self.__startup_exc
elif self.status != Status.init:
return self

try:
await asyncio.wait_for(self.start_unsafe(), timeout=timeout)
except asyncio.TimeoutError as exc:
await _close_on_failure(exc)
raise asyncio.TimeoutError(
f"{type(self).__name__} start timed out after {timeout}s."
) from exc
except Exception as exc:
await _close_on_failure(exc)
raise RuntimeError(f"{type(self).__name__} failed to start.") from exc
self.status = Status.running
async def _close_on_failure(exc: Exception) -> None:
self._event_started.set()
await self.close()
self.status = Status.failed
self.__startup_exc = exc

timeout = getattr(self, "death_timeout", None)
try:
async with self._startup_lock:
self.__startup_task = asyncio.create_task(self.start_unsafe())
self.__startup_task.add_done_callback(
lambda _: self._event_started.set()
)
await asyncio.wait_for(self.__startup_task, timeout=timeout)
self.status = Status.running
except asyncio.TimeoutError as exc:
await _close_on_failure(exc)
raise asyncio.TimeoutError(
f"{type(self).__name__} start timed out after {timeout}s."
) from exc
except Exception as exc:
await _close_on_failure(exc)
raise RuntimeError(f"{type(self).__name__} failed to start.") from exc
return self

async def __aenter__(self):
Expand Down Expand Up @@ -741,7 +756,7 @@ async def _handle_comm(self, comm):
logger.debug("Connection from %r to %s", address, type(self).__name__)
self._comms[comm] = op

await self
await self.started()
try:
while not self.__stopped:
try:
Expand Down Expand Up @@ -940,6 +955,9 @@ async def close(self, timeout=None):
await asyncio.gather(*[comm.close() for comm in list(self._comms)])
finally:
self._event_finished.set()
logger.debug(
f"Closed {type(self).__name__} - {self.address_safe} - {self.id}"
)


def pingpong(comm):
Expand Down
149 changes: 75 additions & 74 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from collections.abc import Collection
from inspect import isawaitable
from queue import Empty
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, Callable, ClassVar, Literal

from toolz import merge
Expand Down Expand Up @@ -223,6 +222,8 @@ def __init__( # type: ignore[no-untyped-def]
self.validate = validate
self.resources = resources

self._instantiate_lock = asyncio.Lock()

self.Worker = Worker if worker_class is None else worker_class

self.pre_spawn_env = _get_env_variables("distributed.nanny.pre-spawn-environ")
Expand Down Expand Up @@ -385,66 +386,58 @@ async def kill(self, timeout: float = 2, reason: str = "nanny-kill") -> None:
return

deadline = time() + timeout
await self.process.kill(reason=reason, timeout=0.8 * (deadline - time()))
proc = self.process
await proc.kill(reason=reason, timeout=0.8 * (deadline - time()))
assert proc.status in (Status.stopped, Status.failed), proc.status
await proc.stopped.wait()
assert self.process is not proc

async def instantiate(self) -> Status:
"""Start a local worker process

Blocks until the process is up and the scheduler is properly informed
"""
if self.process is None:
worker_kwargs = dict(
scheduler_ip=self.scheduler_addr,
nthreads=self.nthreads,
local_directory=self._original_local_dir,
services=self.services,
nanny=self.address,
name=self.name,
memory_limit=self.memory_manager.memory_limit,
resources=self.resources,
validate=self.validate,
silence_logs=self.silence_logs,
death_timeout=self.death_timeout,
preload=self.preload,
preload_argv=self.preload_argv,
security=self.security,
contact_address=self.contact_address,
)
worker_kwargs.update(self.worker_kwargs)
self.process = WorkerProcess(
worker_kwargs=worker_kwargs,
silence_logs=self.silence_logs,
on_exit=self._on_worker_exit_sync,
worker=self.Worker,
env=self.env,
pre_spawn_env=self.pre_spawn_env,
config=self.config,
)

if self.death_timeout:
try:
result = await asyncio.wait_for(
self.process.start(), self.death_timeout
# The lock is required since there are many possible race conditions due
# to the worker exit callback
async with self._instantiate_lock:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should kill the exit callbacks eventually. I think this would require something like #6616

if self.status in (
Status.closing,
Status.closed,
Status.closing_gracefully,
Status.failed,
):
raise RuntimeError(
"Tried to start a worker on closed Nanny. This can happen if an error occured during restart. Please check logs for more information."
)
except asyncio.TimeoutError:
logger.error(
"Timed out connecting Nanny '%s' to scheduler '%s'",
self,
self.scheduler_addr,
if self.process is None:
worker_kwargs = dict(
scheduler_ip=self.scheduler_addr,
nthreads=self.nthreads,
local_directory=self._original_local_dir,
services=self.services,
nanny=self.address,
name=self.name,
memory_limit=self.memory_manager.memory_limit,
resources=self.resources,
validate=self.validate,
silence_logs=self.silence_logs,
death_timeout=self.death_timeout,
preload=self.preload,
preload_argv=self.preload_argv,
security=self.security,
contact_address=self.contact_address,
)
await self.close(
timeout=self.death_timeout, reason="nanny-instantiate-timeout"
worker_kwargs.update(self.worker_kwargs)
self.process = WorkerProcess(
worker_kwargs=worker_kwargs,
silence_logs=self.silence_logs,
on_exit=self._on_worker_exit_sync,
worker=self.Worker,
env=self.env,
pre_spawn_env=self.pre_spawn_env,
config=self.config,
)
raise

else:
try:
result = await self.process.start()
except Exception:
logger.error("Failed to start process", exc_info=True)
await self.close(reason="nanny-instantiate-failed")
raise
return result
return await self.process.start()

@log_errors
async def plugin_add(self, plugin=None, name=None):
Expand Down Expand Up @@ -519,6 +512,8 @@ def _on_worker_exit_sync(self, exitcode):

@log_errors
async def _on_worker_exit(self, exitcode):
assert self.process
self.process = None
if self.status not in (
Status.init,
Status.closing,
Expand Down Expand Up @@ -550,6 +545,8 @@ async def _on_worker_exit(self, exitcode):
logger.error(
"Failed to restart worker after its process exited", exc_info=True
)
await self.close(reason="worker-failed-restart")
raise

@property
def pid(self):
Expand Down Expand Up @@ -578,13 +575,17 @@ async def close(
"""
if self.status == Status.closing:
await self.finished()
assert self.status == Status.closed
assert self.status in (Status.closed, Status.failed)

if self.status == Status.closed:
if self.status in (Status.closed, Status.failed):
return "OK"

self.status = Status.closing
# Make sure we're not colliding with the startup coro when setting the
# status to closing
logger.info("Closing Nanny at %r. Reason: %s", self.address_safe, reason)
await self.cancel_start()

self.status = Status.closing

for preload in self.preloads:
await preload.teardown()
Expand Down Expand Up @@ -726,6 +727,7 @@ async def start(self) -> Status:
self.running.set()

init_q.close()
init_q.join_thread()

return self.status

Expand Down Expand Up @@ -760,7 +762,6 @@ def mark_stopped(self):
msg = self._death_message(self.process.pid, r)
logger.info(msg)
self.status = Status.stopped
self.stopped.set()
# Release resources
self.process.close()
self.init_result_q = None
Expand All @@ -773,6 +774,7 @@ def mark_stopped(self):
# User hook
if self.on_exit is not None:
self.on_exit(r)
self.stopped.set()

async def kill(
self,
Expand All @@ -791,13 +793,20 @@ async def kill(
"""
deadline = time() + timeout

# If the process is not properly up it will not watch the closing queue
# and we may end up leaking this process.
# Therefore wait for it to be properly started before killing it.
if self.status == Status.starting:
await self.running.wait()

if self.status == Status.stopped:
return

if self.status == Status.stopping:
await self.stopped.wait()
return

assert self.status in (
Status.starting,
Status.running,
Status.failed, # process failed to start, but hasn't been joined yet
), self.status
Expand All @@ -817,22 +826,20 @@ async def kill(
"reason": reason,
}
)
await asyncio.sleep(0) # otherwise we get broken pipe errors
queue.close()
queue.join_thread()
del queue

try:
try:
await process.join(wait_timeout)
return
except asyncio.TimeoutError:
pass

logger.warning(
f"Worker process still alive after {wait_timeout} seconds, killing"
)
await process.kill()
await process.join(max(0, deadline - time()))
logger.warning(
f"Worker process still alive after {wait_timeout} seconds, killing"
)
await process.kill()
await process.join(max(0, deadline - time()))
await self.stopped.wait()
except ValueError as e:
if "invalid operation on closed AsyncProcess" in str(e):
return
Expand Down Expand Up @@ -934,6 +941,7 @@ async def run() -> None:
}
)
init_result_q.close()
init_result_q.join_thread()
await worker.finished()
logger.info("Worker closed")
except Exception as e:
Expand All @@ -943,14 +951,7 @@ async def run() -> None:
logger.exception(f"Failed to {failure_type} worker")
init_result_q.put({"uid": uid, "exception": e})
init_result_q.close()
# If we hit an exception here we need to wait for a least
# one interval for the outside to pick up this message.
# Otherwise we arrive in a race condition where the process
# cleanup wipes the queue before the exception can be
# properly handled. See also
# WorkerProcess._wait_until_connected (the 3 is for good
# measure)
sync_sleep(cls._init_msg_interval * 3)
init_result_q.join_thread()

with contextlib.ExitStack() as stack:

Expand Down
4 changes: 3 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3958,7 +3958,9 @@ async def log_errors(func):
await asyncio.gather(
*[log_errors(plugin.before_close) for plugin in list(self.plugins.values())]
)

# Make sure we're not colliding with the startup coro when setting the
# status to closing
await self.started()
self.status = Status.closing

logger.info("Scheduler closing...")
Expand Down
Loading