Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions distributed/shuffle/_worker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,13 @@ async def shuffle_inputs_done(self, shuffle_id: ShuffleId, run_id: int) -> None:
await shuffle.inputs_done()

async def _close_shuffle_run(self, shuffle: ShuffleRun) -> None:
await shuffle.close()
async with self._runs_cleanup_condition:
self._runs.remove(shuffle)
self._runs_cleanup_condition.notify_all()
with log_errors():
try:
await shuffle.close()
finally:
async with self._runs_cleanup_condition:
self._runs.remove(shuffle)
self._runs_cleanup_condition.notify_all()

def shuffle_fail(self, shuffle_id: ShuffleId, run_id: int, message: str) -> None:
"""Fails the shuffle run with the message as exception and triggers cleanup.
Expand Down Expand Up @@ -277,15 +280,9 @@ async def _refresh_shuffle(
RuntimeError("{existing!r} stale, expected run_id=={run_id}")
)

async def _(
extension: ShuffleWorkerPlugin, shuffle: ShuffleRun
) -> None:
await shuffle.close()
async with extension._runs_cleanup_condition:
extension._runs.remove(shuffle)
extension._runs_cleanup_condition.notify_all()

self.worker._ongoing_background_tasks.call_soon(_, self, existing)
self.worker._ongoing_background_tasks.call_soon(
ShuffleWorkerPlugin._close_shuffle_run, self, existing
)
shuffle: ShuffleRun = result.spec.create_run_on_worker(
result.run_id, result.worker_for, self
)
Expand Down
29 changes: 29 additions & 0 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import io
import itertools
import logging
import os
import random
import shutil
Expand Down Expand Up @@ -610,6 +611,34 @@ async def test_closed_bystanding_worker_during_shuffle(c, s, w1, w2, w3):
await check_scheduler_cleanup(s)


class RaiseOnCloseShuffleRun(DataFrameShuffleRun):
async def close(self, *args, **kwargs):
raise RuntimeError("test-exception-on-close")


@mock.patch(
"distributed.shuffle._shuffle.DataFrameShuffleRun",
RaiseOnCloseShuffleRun,
)
@gen_cluster(client=True, nthreads=[])
async def test_exception_on_close_cleans_up(c, s, caplog):
# Ensure that everything is cleaned up and does not lock up if an exception
# is raised during shuffle close.
with caplog.at_level(logging.ERROR):
async with Worker(s.address) as w:
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-10",
dtypes={"x": float, "y": float},
freq="10 s",
)
shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p")
await c.compute([shuffled, df], sync=True)

assert any("test-exception-on-close" in record.message for record in caplog.records)
await check_worker_cleanup(w, closed=True)


class BlockedInputsDoneShuffle(DataFrameShuffleRun):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down