diff --git a/distributed/stealing.py b/distributed/stealing.py index b352040bff7..2a68a232a0e 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -452,4 +452,4 @@ def _can_steal(thief, ts, victim): return True -fast_tasks = {"shuffle-split"} +fast_tasks = {"split-shuffle"} diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index ee2695cea87..fbabd2a6086 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -827,3 +827,38 @@ async def test_balance_with_longer_task(c, s, a, b): ) # a task after y, suggesting a, but open to b await z assert z.key in b.data + + +@gen_cluster(client=True) +async def test_blacklist_shuffle_split(c, s, a, b): + + pd = pytest.importorskip("pandas") + dd = pytest.importorskip("dask.dataframe") + npart = 10 + df = dd.from_pandas(pd.DataFrame({"A": range(100), "B": 1}), npartitions=npart) + graph = df.shuffle( + "A", + shuffle="tasks", + # If we don't have enough partitions, we'll fall back to a simple shuffle + max_branch=npart - 1, + ).sum() + res = c.compute(graph) + + while not s.tasks: + await asyncio.sleep(0.005) + prefixes = set(s.task_prefixes.keys()) + from distributed.stealing import fast_tasks + + blacklisted = fast_tasks & prefixes + assert blacklisted + assert any(["split" in prefix for prefix in blacklisted]) + + stealable = s.extensions["stealing"].stealable + while not res.done(): + for tasks_per_level in stealable.values(): + for tasks in tasks_per_level: + for ts in tasks: + assert ts.prefix.name not in fast_tasks + assert "split" not in ts.prefix.name + await asyncio.sleep(0.001) + await res