diff options
author | Kristján Valur Jónsson <kristjan@ccpgames.com> | 2010-10-28 09:43:10 (GMT) |
---|---|---|
committer | Kristján Valur Jónsson <kristjan@ccpgames.com> | 2010-10-28 09:43:10 (GMT) |
commit | 3be00037d65178644b20a826f68eb3d0b25ccb5f (patch) | |
tree | 31d7ff67d789c0ca1fc0ce927afaa0fb5afe5714 /Lib/test/lock_tests.py | |
parent | 65ffae0aa3b31f26503182cbc7cd79943b6b8ff5 (diff) | |
download | cpython-3be00037d65178644b20a826f68eb3d0b25ccb5f.zip cpython-3be00037d65178644b20a826f68eb3d0b25ccb5f.tar.gz cpython-3be00037d65178644b20a826f68eb3d0b25ccb5f.tar.bz2 |
issue 8777
Add threading.Barrier
Diffstat (limited to 'Lib/test/lock_tests.py')
-rw-r--r-- | Lib/test/lock_tests.py | 190 |
1 files changed, 190 insertions, 0 deletions
diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py index 1ff6af0..f256a80 100644 --- a/Lib/test/lock_tests.py +++ b/Lib/test/lock_tests.py @@ -597,3 +597,193 @@ class BoundedSemaphoreTests(BaseSemaphoreTests): sem.acquire() sem.release() self.assertRaises(ValueError, sem.release) + + +class BarrierTests(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + + def setUp(self): + self.barrier = self.barriertype(self.N, timeout=0.1) + def tearDown(self): + self.barrier.abort() + + def run_threads(self, f): + b = Bunch(f, self.N-1) + f() + b.wait_for_finished() + + def multipass(self, results, n): + m = self.barrier.parties + self.assertEqual(m, self.N) + for i in range(n): + results[0].append(True) + self.assertEqual(len(results[1]), i * m) + self.barrier.wait() + results[1].append(True) + self.assertEqual(len(results[0]), (i + 1) * m) + self.barrier.wait() + self.assertEqual(self.barrier.n_waiting, 0) + self.assertFalse(self.barrier.broken) + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [[],[]] + def f(): + self.multipass(results, passes) + self.run_threads(f) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + results = [] + def f(): + r = self.barrier.wait() + results.append(r) + + self.run_threads(f) + self.assertEqual(sum(results), sum(range(self.N))) + + def test_action(self): + """ + Test the 'action' callback + """ + results = [] + def action(): + results.append(True) + barrier = self.barriertype(self.N, action) + def f(): + barrier.wait() + self.assertEqual(len(results), 1) + + self.run_threads(f) + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = [] + results2 = [] + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = [] + results2 = [] + results3 = [] + def f(): + i = self.barrier.wait() + if i == self.N//2: + # Wait until the other threads are all in the barrier. + while self.barrier.n_waiting < self.N-1: + time.sleep(0.001) + self.barrier.reset() + else: + try: + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = [] + results2 = [] + results3 = [] + barrier2 = self.barriertype(self.N) + def f(): + try: + i = self.barrier.wait() + if i == self.N//2: + raise RuntimeError + self.barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + self.barrier.abort() + pass + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == self.N//2: + self.barrier.reset() + barrier2.wait() + self.barrier.wait() + results3.append(True) + + self.run_threads(f) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + def test_timeout(self): + """ + Test wait(timeout) + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is late! + time.sleep(0.1) + # Default timeout is 0.1, so this is shorter. + self.assertRaises(threading.BrokenBarrierError, + self.barrier.wait, 0.05) + self.run_threads(f) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + def f(): + i = self.barrier.wait() + if i == self.N // 2: + # One thread is later than the default timeout of 0.1s. + time.sleep(0.15) + self.assertRaises(threading.BrokenBarrierError, self.barrier.wait) + self.run_threads(f) + + def test_single_thread(self): + b = self.barriertype(1) + b.wait() + b.wait() |