From d8c654f98a544f7e20ee731d38fce654278803e4 Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 24 Jun 2022 15:45:14 +0200 Subject: [PATCH] Use pytest tmpdir_factory in gen_cluster --- distributed/tests/test_utils_test.py | 9 ++++++ distributed/utils_test.py | 44 ++++++++++++++++++---------- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 438f6faa475..32f8fcf5bbb 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -88,6 +88,15 @@ async def test_gen_cluster_pytest_fixture(c, s, a, b, tmp_path): assert isinstance(w, Worker) +@gen_cluster(client=True) +async def test_gen_cluster_pytest_fixture_tmpdir_factory(c, s, a, b, tmpdir_factory): + assert isinstance(tmpdir_factory, pytest.TempdirFactory) + assert isinstance(c, Client) + assert isinstance(s, Scheduler) + for w in [a, b]: + assert isinstance(w, Worker) + + @pytest.mark.parametrize("foo", [True]) @gen_cluster(client=True) async def test_gen_cluster_parametrized(c, s, a, b, foo): diff --git a/distributed/utils_test.py b/distributed/utils_test.py index f14eef1b041..4750ef3bfa7 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1064,15 +1064,40 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture def _(func): if not iscoroutinefunction(func): raise RuntimeError("gen_cluster only works for coroutine functions.") + # Patch the signature so pytest can inject fixtures + orig_sig = inspect.signature(func) + args = [None] * (1 + len(nthreads)) # scheduler, *workers + if client: + args.insert(0, None) + + bound = orig_sig.bind_partial(*args) + new_parameters = [ + p for name, p in orig_sig.parameters.items() if name not in bound.arguments + ] + requires_tmpdir_factory_fixture = True + if "tmp_path_factory" not in orig_sig.parameters: + requires_tmpdir_factory_fixture = False + new_parameters.append( + inspect.Parameter("tmp_path_factory", inspect.Parameter.KEYWORD_ONLY) + ) @functools.wraps(func) @clean(**clean_kwargs) def test_func(*outer_args, **kwargs): async def async_fn(): - result = None - with tempfile.TemporaryDirectory() as tmpdir: + with contextlib.ExitStack() as exitstack: + tmp_path_factory = kwargs.get("tmp_path_factory") + if tmp_path_factory is None: + # For usage as plain ctxmanager for non-pytest functions + tmpdir = str(exitstack.enter_context(_SafeTemporaryDirectory())) + else: + if not requires_tmpdir_factory_fixture: + del kwargs["tmp_path_factory"] + tmpdir = tmp_path_factory.mktemp(func.__name__) + result = None config2 = merge({"temporary-directory": tmpdir}, config) with dask.config.set(config2): + workers = [] s = False @@ -1217,20 +1242,7 @@ async def async_fn_outer(): return _run_and_close_tornado(async_fn_outer) - # Patch the signature so pytest can inject fixtures - orig_sig = inspect.signature(func) - args = [None] * (1 + len(nthreads)) # scheduler, *workers - if client: - args.insert(0, None) - - bound = orig_sig.bind_partial(*args) - test_func.__signature__ = orig_sig.replace( - parameters=[ - p - for name, p in orig_sig.parameters.items() - if name not in bound.arguments - ] - ) + test_func.__signature__ = orig_sig.replace(parameters=new_parameters) return test_func