diff options
Diffstat (limited to 'Lib/test/test_free_threading/test_heapq.py')
-rw-r--r-- | Lib/test/test_free_threading/test_heapq.py | 66 |
1 files changed, 23 insertions, 43 deletions
diff --git a/Lib/test/test_free_threading/test_heapq.py b/Lib/test/test_free_threading/test_heapq.py index ee7adfb..d771333 100644 --- a/Lib/test/test_free_threading/test_heapq.py +++ b/Lib/test/test_free_threading/test_heapq.py @@ -3,10 +3,11 @@ import unittest import heapq from enum import Enum -from threading import Thread, Barrier, Lock +from threading import Barrier, Lock from random import shuffle, randint from test.support import threading_helper +from test.support.threading_helper import run_concurrently from test import test_heapq @@ -28,8 +29,8 @@ class TestHeapq(unittest.TestCase): heap = list(range(OBJECT_COUNT)) shuffle(heap) - self.run_concurrently( - worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS + run_concurrently( + worker_func=heapq.heapify, nthreads=NTHREADS, args=(heap,) ) self.test_heapq.check_invariant(heap) @@ -40,8 +41,8 @@ class TestHeapq(unittest.TestCase): for item in reversed(range(OBJECT_COUNT)): heapq.heappush(heap, item) - self.run_concurrently( - worker_func=heappush_func, args=(heap,), nthreads=NTHREADS + run_concurrently( + worker_func=heappush_func, nthreads=NTHREADS, args=(heap,) ) self.test_heapq.check_invariant(heap) @@ -61,10 +62,10 @@ class TestHeapq(unittest.TestCase): # Each local list should be sorted self.assertTrue(self.is_sorted_ascending(local_list)) - self.run_concurrently( + run_concurrently( worker_func=heappop_func, - args=(heap, per_thread_pop_count), nthreads=NTHREADS, + args=(heap, per_thread_pop_count), ) self.assertEqual(len(heap), 0) @@ -77,10 +78,10 @@ class TestHeapq(unittest.TestCase): popped_item = heapq.heappushpop(heap, item) self.assertTrue(popped_item <= item) - self.run_concurrently( + run_concurrently( worker_func=heappushpop_func, - args=(heap, pushpop_items), nthreads=NTHREADS, + args=(heap, pushpop_items), ) self.assertEqual(len(heap), OBJECT_COUNT) self.test_heapq.check_invariant(heap) @@ -93,10 +94,10 @@ class TestHeapq(unittest.TestCase): for item in replace_items: heapq.heapreplace(heap, item) - self.run_concurrently( + run_concurrently( worker_func=heapreplace_func, - args=(heap, replace_items), nthreads=NTHREADS, + args=(heap, replace_items), ) self.assertEqual(len(heap), OBJECT_COUNT) self.test_heapq.check_invariant(heap) @@ -105,8 +106,8 @@ class TestHeapq(unittest.TestCase): max_heap = list(range(OBJECT_COUNT)) shuffle(max_heap) - self.run_concurrently( - worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS + run_concurrently( + worker_func=heapq.heapify_max, nthreads=NTHREADS, args=(max_heap,) ) self.test_heapq.check_max_invariant(max_heap) @@ -117,8 +118,8 @@ class TestHeapq(unittest.TestCase): for item in range(OBJECT_COUNT): heapq.heappush_max(max_heap, item) - self.run_concurrently( - worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS + run_concurrently( + worker_func=heappush_max_func, nthreads=NTHREADS, args=(max_heap,) ) self.test_heapq.check_max_invariant(max_heap) @@ -138,10 +139,10 @@ class TestHeapq(unittest.TestCase): # Each local list should be sorted self.assertTrue(self.is_sorted_descending(local_list)) - self.run_concurrently( + run_concurrently( worker_func=heappop_max_func, - args=(max_heap, per_thread_pop_count), nthreads=NTHREADS, + args=(max_heap, per_thread_pop_count), ) self.assertEqual(len(max_heap), 0) @@ -154,10 +155,10 @@ class TestHeapq(unittest.TestCase): popped_item = heapq.heappushpop_max(max_heap, item) self.assertTrue(popped_item >= item) - self.run_concurrently( + run_concurrently( worker_func=heappushpop_max_func, - args=(max_heap, pushpop_items), nthreads=NTHREADS, + args=(max_heap, pushpop_items), ) self.assertEqual(len(max_heap), OBJECT_COUNT) self.test_heapq.check_max_invariant(max_heap) @@ -170,10 +171,10 @@ class TestHeapq(unittest.TestCase): for item in replace_items: heapq.heapreplace_max(max_heap, item) - self.run_concurrently( + run_concurrently( worker_func=heapreplace_max_func, - args=(max_heap, replace_items), nthreads=NTHREADS, + args=(max_heap, replace_items), ) self.assertEqual(len(max_heap), OBJECT_COUNT) self.test_heapq.check_max_invariant(max_heap) @@ -203,7 +204,7 @@ class TestHeapq(unittest.TestCase): except IndexError: pass - self.run_concurrently(worker, (), n_threads * 2) + run_concurrently(worker, n_threads * 2) @staticmethod def is_sorted_ascending(lst): @@ -241,27 +242,6 @@ class TestHeapq(unittest.TestCase): """ return [randint(-a, b) for _ in range(size)] - def run_concurrently(self, worker_func, args, nthreads): - """ - Run the worker function concurrently in multiple threads. - """ - barrier = Barrier(nthreads) - - def wrapper_func(*args): - # Wait for all threads to reach this point before proceeding. - barrier.wait() - worker_func(*args) - - with threading_helper.catch_threading_exception() as cm: - workers = ( - Thread(target=wrapper_func, args=args) for _ in range(nthreads) - ) - with threading_helper.start_threads(workers): - pass - - # Worker threads should not raise any exceptions - self.assertIsNone(cm.exc_value) - if __name__ == "__main__": unittest.main() |