diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 78f4db5e400..f9d98dd7882 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -34,6 +34,7 @@ assert_story, captured_logger, check_process_leak, + check_thread_leak, cluster, dump_cluster_state, freeze_batched_send, @@ -755,6 +756,49 @@ def test_raises_with_cause(): raise RuntimeError("exception") from ValueError("cause") +@pytest.mark.slow +def test_check_thread_leak(): + event = threading.Event() + + t1 = threading.Thread(target=lambda: (event.wait(), "one")) + t1.start() + + t2 = t3 = None + try: + with pytest.raises( + pytest.fail.Exception, match=r"2 thread\(s\) were leaked" + ) as exc: + with check_thread_leak(): + t2 = threading.Thread(target=lambda: (event.wait(), "two")) + t2.start() + t3 = threading.Thread(target=lambda: (event.wait(), "three")) + t3.start() + + msg = exc.value.msg + assert msg + print(msg) # For reference, if test fails + + # First, outer thread is ignored + assert msg.count("Call stack of leaked thread") == 2 + assert "one" not in msg + + # Make sure we can see the full traceback, not just the last line + assert msg.count(__file__) == 2 + assert 'target=lambda: (event.wait(), "two")' in msg + assert 'target=lambda: (event.wait(), "three")' in msg + + # Ensure there aren't too many or too few newlines + exc.match(r'event.wait\(\), "three"\)\)\n +File') + finally: + # Clean up + event.set() + t1.join(5) + if t2: + t2.join(5) + if t3: + t3.join(5) + + @pytest.mark.parametrize("sync", [True, False]) def test_fail_hard(sync): """@fail_hard is a last resort when error handling for everything that we foresaw diff --git a/distributed/utils_test.py b/distributed/utils_test.py index dbf5a36352d..23f8817ab10 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1737,9 +1737,23 @@ def check_thread_leak(): # Raise an error with information about leaked threads from distributed import profile - bad_thread = bad_threads[0] - call_stacks = profile.call_stack(sys._current_frames()[bad_thread.ident]) - assert False, (bad_thread, call_stacks) # noqa: B011 + frames = sys._current_frames() + try: + lines: list[str] = [ + f"{len(bad_threads)} thread(s) were leaked from test\n" + ] + for i, thread in enumerate(bad_threads, 1): + lines.append( + f"------ Call stack of leaked thread {i}/{len(bad_threads)}: {thread} ------" + ) + lines.append( + "".join(profile.call_stack(frames[thread.ident])) + # NOTE: `call_stack` already adds newlines + ) + finally: + del frames + + pytest.fail("\n".join(lines), pytrace=False) def wait_active_children(timeout: float) -> list[multiprocessing.Process]: