Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions distributed/spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

logger = logging.getLogger(__name__)
has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0")
# At the moment of writing, zict 2.2.0 has not been released yet. Support git tip.
has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0.dev2")
has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0")


class SpilledSize(NamedTuple):
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ async def test_worker_dir(c, s, a, b):
test_worker_dir()


@gen_cluster(nthreads=[])
@gen_cluster(nthreads=[], config={"temporary-directory": None})
async def test_false_worker_dir(s):
async with Worker(s.address, local_directory="") as w:
local_directory = w.local_directory
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_worker_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import distributed.system
from distributed import Client, Event, Nanny, Worker, wait
from distributed.compatibility import MACOS
from distributed.core import Status
from distributed.metrics import monotonic
from distributed.spill import has_zict_210
Expand Down Expand Up @@ -763,6 +764,7 @@ def __reduce__(self):


@pytest.mark.slow
@pytest.mark.skipif(MACOS, reason="https://github.com/dask/distributed/issues/6233")
@gen_cluster(
nthreads=[("", 1)],
client=True,
Expand Down
245 changes: 125 additions & 120 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,9 +966,8 @@ def gen_cluster(
allow_unclosed: bool = False,
cluster_dump_directory: str | Literal[False] = "test_cluster_dump",
) -> Callable[[Callable], Callable]:
from distributed import Client

""" Coroutine test with small cluster
"""Coroutine test with small cluster

@gen_cluster()
async def test_foo(scheduler, worker1, worker2):
Expand Down Expand Up @@ -1012,129 +1011,135 @@ def test_func(*outer_args, **kwargs):
with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop:

async def coro():
with dask.config.set(config):
s = False
for _ in range(60):
try:
s, ws = await start_cluster(
nthreads,
scheduler,
loop,
with tempfile.TemporaryDirectory() as tmpdir:
config2 = merge({"temporary-directory": tmpdir}, config)
with dask.config.set(config2):

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything beyond this line is just an extra level of indentation; there are no other changes.

s = False
for _ in range(60):
try:
s, ws = await start_cluster(
nthreads,
scheduler,
loop,
security=security,
Worker=Worker,
scheduler_kwargs=scheduler_kwargs,
worker_kwargs=worker_kwargs,
)
except Exception as e:
logger.error(
"Failed to start gen_cluster: "
f"{e.__class__.__name__}: {e}; retrying",
exc_info=True,
)
await asyncio.sleep(1)
else:
workers[:] = ws
args = [s] + workers
break
if s is False:
raise Exception("Could not start cluster")
if client:
c = await Client(
s.address,
loop=loop,
security=security,
Worker=Worker,
scheduler_kwargs=scheduler_kwargs,
worker_kwargs=worker_kwargs,
)
except Exception as e:
logger.error(
"Failed to start gen_cluster: "
f"{e.__class__.__name__}: {e}; retrying",
exc_info=True,
)
await asyncio.sleep(1)
else:
workers[:] = ws
args = [s] + workers
break
if s is False:
raise Exception("Could not start cluster")
if client:
c = await Client(
s.address,
loop=loop,
security=security,
asynchronous=True,
**client_kwargs,
)
args = [c] + args

try:
coro = func(*args, *outer_args, **kwargs)
task = asyncio.create_task(coro)
coro2 = asyncio.wait_for(asyncio.shield(task), timeout)
result = await coro2
if s.validate:
s.validate_state()

except asyncio.TimeoutError:
assert task
buffer = io.StringIO()
# This stack indicates where the coro/test is suspended
task.print_stack(file=buffer)

if cluster_dump_directory:
await dump_cluster_state(
s,
ws,
output_dir=cluster_dump_directory,
func_name=func.__name__,
asynchronous=True,
**client_kwargs,
)
args = [c] + args

task.cancel()
while not task.cancelled():
await asyncio.sleep(0.01)

# Remove as much of the traceback as possible; it's
# uninteresting boilerplate from utils_test and asyncio and
# not from the code being tested.
raise TimeoutError(
f"Test timeout after {timeout}s.\n"
"========== Test stack trace starts here ==========\n"
f"{buffer.getvalue()}"
) from None

except pytest.xfail.Exception:
raise

except Exception:
if cluster_dump_directory and not has_pytestmark(
test_func, "xfail"
):
await dump_cluster_state(
s,
ws,
output_dir=cluster_dump_directory,
func_name=func.__name__,
)
raise

finally:
if client and c.status not in ("closing", "closed"):
await c._close(fast=s.status == Status.closed)
await end_cluster(s, workers)
await asyncio.wait_for(cleanup_global_workers(), 1)

try:
c = await default_client()
except ValueError:
pass
else:
await c._close(fast=True)

def get_unclosed():
return [c for c in Comm._instances if not c.closed()] + [
c
for c in _global_clients.values()
if c.status != "closed"
]

try:
start = time()
while time() < start + 60:
gc.collect()
if not get_unclosed():
break
await asyncio.sleep(0.05)
try:
coro = func(*args, *outer_args, **kwargs)
task = asyncio.create_task(coro)
coro2 = asyncio.wait_for(asyncio.shield(task), timeout)
result = await coro2
if s.validate:
s.validate_state()

except asyncio.TimeoutError:
assert task
buffer = io.StringIO()
# This stack indicates where the coro/test is suspended
task.print_stack(file=buffer)

if cluster_dump_directory:
await dump_cluster_state(
s,
ws,
output_dir=cluster_dump_directory,
func_name=func.__name__,
)

task.cancel()
while not task.cancelled():
await asyncio.sleep(0.01)

# Remove as much of the traceback as possible; it's
# uninteresting boilerplate from utils_test and asyncio
# and not from the code being tested.
raise TimeoutError(
f"Test timeout after {timeout}s.\n"
"========== Test stack trace starts here ==========\n"
f"{buffer.getvalue()}"
) from None

except pytest.xfail.Exception:
raise

except Exception:
if cluster_dump_directory and not has_pytestmark(
test_func, "xfail"
):
await dump_cluster_state(
s,
ws,
output_dir=cluster_dump_directory,
func_name=func.__name__,
)
raise

finally:
if client and c.status not in ("closing", "closed"):
await c._close(fast=s.status == Status.closed)
await end_cluster(s, workers)
await asyncio.wait_for(cleanup_global_workers(), 1)

try:
c = await default_client()
except ValueError:
pass
else:
if allow_unclosed:
print(f"Unclosed Comms: {get_unclosed()}")
else:
raise RuntimeError("Unclosed Comms", get_unclosed())
finally:
Comm._instances.clear()
_global_clients.clear()
await c._close(fast=True)

def get_unclosed():
return [
c for c in Comm._instances if not c.closed()
] + [
c
for c in _global_clients.values()
if c.status != "closed"
]

return result
try:
start = time()
while time() < start + 60:
gc.collect()
if not get_unclosed():
break
await asyncio.sleep(0.05)
else:
if allow_unclosed:
print(f"Unclosed Comms: {get_unclosed()}")
else:
raise RuntimeError(
"Unclosed Comms", get_unclosed()
)
finally:
Comm._instances.clear()
_global_clients.clear()

return result

result = loop.run_sync(
coro, timeout=timeout * 2 if timeout else timeout
Expand Down