summaryrefslogtreecommitdiffstats
path: root/Lib/test/lock_tests.py
diff options
context:
space:
mode:
authorKristján Valur Jónsson <kristjan@ccpgames.com>2010-10-28 09:43:10 (GMT)
committerKristján Valur Jónsson <kristjan@ccpgames.com>2010-10-28 09:43:10 (GMT)
commit3be00037d65178644b20a826f68eb3d0b25ccb5f (patch)
tree31d7ff67d789c0ca1fc0ce927afaa0fb5afe5714 /Lib/test/lock_tests.py
parent65ffae0aa3b31f26503182cbc7cd79943b6b8ff5 (diff)
downloadcpython-3be00037d65178644b20a826f68eb3d0b25ccb5f.zip
cpython-3be00037d65178644b20a826f68eb3d0b25ccb5f.tar.gz
cpython-3be00037d65178644b20a826f68eb3d0b25ccb5f.tar.bz2
issue 8777
Add threading.Barrier
Diffstat (limited to 'Lib/test/lock_tests.py')
-rw-r--r--Lib/test/lock_tests.py190
1 files changed, 190 insertions, 0 deletions
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py
index 1ff6af0..f256a80 100644
--- a/Lib/test/lock_tests.py
+++ b/Lib/test/lock_tests.py
@@ -597,3 +597,193 @@ class BoundedSemaphoreTests(BaseSemaphoreTests):
sem.acquire()
sem.release()
self.assertRaises(ValueError, sem.release)
+
+
+class BarrierTests(BaseTestCase):
+ """
+ Tests for Barrier objects.
+ """
+ N = 5
+
+ def setUp(self):
+ self.barrier = self.barriertype(self.N, timeout=0.1)
+ def tearDown(self):
+ self.barrier.abort()
+
+ def run_threads(self, f):
+ b = Bunch(f, self.N-1)
+ f()
+ b.wait_for_finished()
+
+ def multipass(self, results, n):
+ m = self.barrier.parties
+ self.assertEqual(m, self.N)
+ for i in range(n):
+ results[0].append(True)
+ self.assertEqual(len(results[1]), i * m)
+ self.barrier.wait()
+ results[1].append(True)
+ self.assertEqual(len(results[0]), (i + 1) * m)
+ self.barrier.wait()
+ self.assertEqual(self.barrier.n_waiting, 0)
+ self.assertFalse(self.barrier.broken)
+
+ def test_barrier(self, passes=1):
+ """
+ Test that a barrier is passed in lockstep
+ """
+ results = [[],[]]
+ def f():
+ self.multipass(results, passes)
+ self.run_threads(f)
+
+ def test_barrier_10(self):
+ """
+ Test that a barrier works for 10 consecutive runs
+ """
+ return self.test_barrier(10)
+
+ def test_wait_return(self):
+ """
+ test the return value from barrier.wait
+ """
+ results = []
+ def f():
+ r = self.barrier.wait()
+ results.append(r)
+
+ self.run_threads(f)
+ self.assertEqual(sum(results), sum(range(self.N)))
+
+ def test_action(self):
+ """
+ Test the 'action' callback
+ """
+ results = []
+ def action():
+ results.append(True)
+ barrier = self.barriertype(self.N, action)
+ def f():
+ barrier.wait()
+ self.assertEqual(len(results), 1)
+
+ self.run_threads(f)
+
+ def test_abort(self):
+ """
+ Test that an abort will put the barrier in a broken state
+ """
+ results1 = []
+ results2 = []
+ def f():
+ try:
+ i = self.barrier.wait()
+ if i == self.N//2:
+ raise RuntimeError
+ self.barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ self.barrier.abort()
+ pass
+
+ self.run_threads(f)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(self.barrier.broken)
+
+ def test_reset(self):
+ """
+ Test that a 'reset' on a barrier frees the waiting threads
+ """
+ results1 = []
+ results2 = []
+ results3 = []
+ def f():
+ i = self.barrier.wait()
+ if i == self.N//2:
+ # Wait until the other threads are all in the barrier.
+ while self.barrier.n_waiting < self.N-1:
+ time.sleep(0.001)
+ self.barrier.reset()
+ else:
+ try:
+ self.barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ # Now, pass the barrier again
+ self.barrier.wait()
+ results3.append(True)
+
+ self.run_threads(f)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+
+ def test_abort_and_reset(self):
+ """
+ Test that a barrier can be reset after being broken.
+ """
+ results1 = []
+ results2 = []
+ results3 = []
+ barrier2 = self.barriertype(self.N)
+ def f():
+ try:
+ i = self.barrier.wait()
+ if i == self.N//2:
+ raise RuntimeError
+ self.barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ self.barrier.abort()
+ pass
+ # Synchronize and reset the barrier. Must synchronize first so
+ # that everyone has left it when we reset, and after so that no
+ # one enters it before the reset.
+ if barrier2.wait() == self.N//2:
+ self.barrier.reset()
+ barrier2.wait()
+ self.barrier.wait()
+ results3.append(True)
+
+ self.run_threads(f)
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+ def test_timeout(self):
+ """
+ Test wait(timeout)
+ """
+ def f():
+ i = self.barrier.wait()
+ if i == self.N // 2:
+ # One thread is late!
+ time.sleep(0.1)
+ # Default timeout is 0.1, so this is shorter.
+ self.assertRaises(threading.BrokenBarrierError,
+ self.barrier.wait, 0.05)
+ self.run_threads(f)
+
+ def test_default_timeout(self):
+ """
+ Test the barrier's default timeout
+ """
+ def f():
+ i = self.barrier.wait()
+ if i == self.N // 2:
+ # One thread is later than the default timeout of 0.1s.
+ time.sleep(0.15)
+ self.assertRaises(threading.BrokenBarrierError, self.barrier.wait)
+ self.run_threads(f)
+
+ def test_single_thread(self):
+ b = self.barriertype(1)
+ b.wait()
+ b.wait()