summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_free_threading/test_heapq.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_free_threading/test_heapq.py')
-rw-r--r--Lib/test/test_free_threading/test_heapq.py66
1 files changed, 23 insertions, 43 deletions
diff --git a/Lib/test/test_free_threading/test_heapq.py b/Lib/test/test_free_threading/test_heapq.py
index ee7adfb..d771333 100644
--- a/Lib/test/test_free_threading/test_heapq.py
+++ b/Lib/test/test_free_threading/test_heapq.py
@@ -3,10 +3,11 @@ import unittest
import heapq
from enum import Enum
-from threading import Thread, Barrier, Lock
+from threading import Barrier, Lock
from random import shuffle, randint
from test.support import threading_helper
+from test.support.threading_helper import run_concurrently
from test import test_heapq
@@ -28,8 +29,8 @@ class TestHeapq(unittest.TestCase):
heap = list(range(OBJECT_COUNT))
shuffle(heap)
- self.run_concurrently(
- worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
+ run_concurrently(
+ worker_func=heapq.heapify, nthreads=NTHREADS, args=(heap,)
)
self.test_heapq.check_invariant(heap)
@@ -40,8 +41,8 @@ class TestHeapq(unittest.TestCase):
for item in reversed(range(OBJECT_COUNT)):
heapq.heappush(heap, item)
- self.run_concurrently(
- worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
+ run_concurrently(
+ worker_func=heappush_func, nthreads=NTHREADS, args=(heap,)
)
self.test_heapq.check_invariant(heap)
@@ -61,10 +62,10 @@ class TestHeapq(unittest.TestCase):
# Each local list should be sorted
self.assertTrue(self.is_sorted_ascending(local_list))
- self.run_concurrently(
+ run_concurrently(
worker_func=heappop_func,
- args=(heap, per_thread_pop_count),
nthreads=NTHREADS,
+ args=(heap, per_thread_pop_count),
)
self.assertEqual(len(heap), 0)
@@ -77,10 +78,10 @@ class TestHeapq(unittest.TestCase):
popped_item = heapq.heappushpop(heap, item)
self.assertTrue(popped_item <= item)
- self.run_concurrently(
+ run_concurrently(
worker_func=heappushpop_func,
- args=(heap, pushpop_items),
nthreads=NTHREADS,
+ args=(heap, pushpop_items),
)
self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap)
@@ -93,10 +94,10 @@ class TestHeapq(unittest.TestCase):
for item in replace_items:
heapq.heapreplace(heap, item)
- self.run_concurrently(
+ run_concurrently(
worker_func=heapreplace_func,
- args=(heap, replace_items),
nthreads=NTHREADS,
+ args=(heap, replace_items),
)
self.assertEqual(len(heap), OBJECT_COUNT)
self.test_heapq.check_invariant(heap)
@@ -105,8 +106,8 @@ class TestHeapq(unittest.TestCase):
max_heap = list(range(OBJECT_COUNT))
shuffle(max_heap)
- self.run_concurrently(
- worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
+ run_concurrently(
+ worker_func=heapq.heapify_max, nthreads=NTHREADS, args=(max_heap,)
)
self.test_heapq.check_max_invariant(max_heap)
@@ -117,8 +118,8 @@ class TestHeapq(unittest.TestCase):
for item in range(OBJECT_COUNT):
heapq.heappush_max(max_heap, item)
- self.run_concurrently(
- worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
+ run_concurrently(
+ worker_func=heappush_max_func, nthreads=NTHREADS, args=(max_heap,)
)
self.test_heapq.check_max_invariant(max_heap)
@@ -138,10 +139,10 @@ class TestHeapq(unittest.TestCase):
# Each local list should be sorted
self.assertTrue(self.is_sorted_descending(local_list))
- self.run_concurrently(
+ run_concurrently(
worker_func=heappop_max_func,
- args=(max_heap, per_thread_pop_count),
nthreads=NTHREADS,
+ args=(max_heap, per_thread_pop_count),
)
self.assertEqual(len(max_heap), 0)
@@ -154,10 +155,10 @@ class TestHeapq(unittest.TestCase):
popped_item = heapq.heappushpop_max(max_heap, item)
self.assertTrue(popped_item >= item)
- self.run_concurrently(
+ run_concurrently(
worker_func=heappushpop_max_func,
- args=(max_heap, pushpop_items),
nthreads=NTHREADS,
+ args=(max_heap, pushpop_items),
)
self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap)
@@ -170,10 +171,10 @@ class TestHeapq(unittest.TestCase):
for item in replace_items:
heapq.heapreplace_max(max_heap, item)
- self.run_concurrently(
+ run_concurrently(
worker_func=heapreplace_max_func,
- args=(max_heap, replace_items),
nthreads=NTHREADS,
+ args=(max_heap, replace_items),
)
self.assertEqual(len(max_heap), OBJECT_COUNT)
self.test_heapq.check_max_invariant(max_heap)
@@ -203,7 +204,7 @@ class TestHeapq(unittest.TestCase):
except IndexError:
pass
- self.run_concurrently(worker, (), n_threads * 2)
+ run_concurrently(worker, n_threads * 2)
@staticmethod
def is_sorted_ascending(lst):
@@ -241,27 +242,6 @@ class TestHeapq(unittest.TestCase):
"""
return [randint(-a, b) for _ in range(size)]
- def run_concurrently(self, worker_func, args, nthreads):
- """
- Run the worker function concurrently in multiple threads.
- """
- barrier = Barrier(nthreads)
-
- def wrapper_func(*args):
- # Wait for all threads to reach this point before proceeding.
- barrier.wait()
- worker_func(*args)
-
- with threading_helper.catch_threading_exception() as cm:
- workers = (
- Thread(target=wrapper_func, args=args) for _ in range(nthreads)
- )
- with threading_helper.start_threads(workers):
- pass
-
- # Worker threads should not raise any exceptions
- self.assertIsNone(cm.exc_value)
-
if __name__ == "__main__":
unittest.main()