Skip to content
Prev Previous commit
Next Next commit
gh-116738: Improve run_concurrently() arguments
  • Loading branch information
yoney committed Jul 3, 2025
commit 851caf4c754e63ca0ca93adb456eb0f439b80af2
7 changes: 7 additions & 0 deletions Doc/library/test.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,13 @@ The :mod:`test.support.threading_helper` module provides support for threading t
.. versionadded:: 3.8


.. function:: run_concurrently(worker_func, nthreads, args=(), kwargs={})

Run the worker function concurrently in multiple threads.
Re-raises an exception if any thread raises one, after all threads have
finished.


:mod:`test.support.os_helper` --- Utilities for os tests
========================================================================

Expand Down
8 changes: 4 additions & 4 deletions Lib/test/support/threading_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,20 +250,20 @@ def requires_working_threading(*, module=False):
return unittest.skipUnless(can_start_thread, msg)


def run_concurrently(worker_func, args, nthreads):
def run_concurrently(worker_func, nthreads, args=(), kwargs={}):
"""
Run the worker function concurrently in multiple threads.
"""
barrier = threading.Barrier(nthreads)

def wrapper_func(*args):
def wrapper_func(*args, **kwargs):
# Wait for all threads to reach this point before proceeding.
barrier.wait()
worker_func(*args)
worker_func(*args, **kwargs)

with catch_threading_exception() as cm:
workers = (
threading.Thread(target=wrapper_func, args=args)
threading.Thread(target=wrapper_func, args=args, kwargs=kwargs)
for _ in range(nthreads)
)
with start_threads(workers):
Expand Down
3 changes: 1 addition & 2 deletions Lib/test/test_free_threading/test_grp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ def setUp(self):
def test_racing_test_values(self):
# test_grp.test_values() calls grp.getgrall() and checks the entries
run_concurrently(
worker_func=self.test_grp.test_values, args=(), nthreads=NTHREADS
worker_func=self.test_grp.test_values, nthreads=NTHREADS
)

def test_racing_test_values_extended(self):
# test_grp.test_values_extended() calls grp.getgrall(), grp.getgrgid(),
# grp.getgrnam() and checks the entries
run_concurrently(
worker_func=self.test_grp.test_values_extended,
args=(),
nthreads=NTHREADS,
)

Expand Down
22 changes: 11 additions & 11 deletions Lib/test/test_free_threading/test_heapq.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_racing_heapify(self):
shuffle(heap)

run_concurrently(
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
worker_func=heapq.heapify, nthreads=NTHREADS, args=(heap,)
)
self.test_heapq.check_invariant(heap)

Expand All @@ -42,7 +42,7 @@ def heappush_func(heap):
heapq.heappush(heap, item)

run_concurrently(
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
worker_func=heappush_func, nthreads=NTHREADS, args=(heap,)
)
self.test_heapq.check_invariant(heap)

Expand All @@ -64,8 +64,8 @@ def heappop_func(heap, pop_count):

run_concurrently(
worker_func=heappop_func,
args=(heap, per_thread_pop_count),
nthreads=NTHREADS,
args=(heap, per_thread_pop_count),
)
self.assertEqual(len(heap), 0)

Expand All @@ -80,8 +80,8 @@ def heappushpop_func(heap, pushpop_items):

run_concurrently(
worker_func=heappushpop_func,
args=(heap, pushpop_items),
nthreads=NTHREADS,
args=(heap, pushpop_items),
)
self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap)
Expand All @@ -96,8 +96,8 @@ def heapreplace_func(heap, replace_items):

run_concurrently(
worker_func=heapreplace_func,
args=(heap, replace_items),
nthreads=NTHREADS,
args=(heap, replace_items),
)
self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap)
Expand All @@ -107,7 +107,7 @@ def test_racing_heapify_max(self):
shuffle(max_heap)

run_concurrently(
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
worker_func=heapq.heapify_max, nthreads=NTHREADS, args=(max_heap,)
)
self.test_heapq.check_max_invariant(max_heap)

Expand All @@ -119,7 +119,7 @@ def heappush_max_func(max_heap):
heapq.heappush_max(max_heap, item)

run_concurrently(
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
worker_func=heappush_max_func, nthreads=NTHREADS, args=(max_heap,)
)
self.test_heapq.check_max_invariant(max_heap)

Expand All @@ -141,8 +141,8 @@ def heappop_max_func(max_heap, pop_count):

run_concurrently(
worker_func=heappop_max_func,
args=(max_heap, per_thread_pop_count),
nthreads=NTHREADS,
args=(max_heap, per_thread_pop_count),
)
self.assertEqual(len(max_heap), 0)

Expand All @@ -157,8 +157,8 @@ def heappushpop_max_func(max_heap, pushpop_items):

run_concurrently(
worker_func=heappushpop_max_func,
args=(max_heap, pushpop_items),
nthreads=NTHREADS,
args=(max_heap, pushpop_items),
)
self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap)
Expand All @@ -173,8 +173,8 @@ def heapreplace_max_func(max_heap, replace_items):

run_concurrently(
worker_func=heapreplace_max_func,
args=(max_heap, replace_items),
nthreads=NTHREADS,
args=(max_heap, replace_items),
)
self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap)
Expand Down Expand Up @@ -204,7 +204,7 @@ def worker():
except IndexError:
pass

run_concurrently(worker, (), n_threads * 2)
run_concurrently(worker, n_threads * 2)

@staticmethod
def is_sorted_ascending(lst):
Expand Down
Loading