diff options
Diffstat (limited to 'Lib/test/test_multiprocessing.py')
-rw-r--r-- | Lib/test/test_multiprocessing.py | 355 |
1 files changed, 351 insertions, 4 deletions
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 65e7b0b..2704827 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -18,6 +18,7 @@ import array import socket import random import logging +import struct import test.support @@ -1057,6 +1058,340 @@ class _TestEvent(BaseTestCase): self.assertEqual(wait(), True) # +# Tests for Barrier - adapted from tests in test/lock_tests.py +# + +# Many of the tests for threading.Barrier use a list as an atomic +# counter: a value is appended to increment the counter, and the +# length of the list gives the value. We use the class DummyList +# for the same purpose. + +class _DummyList(object): + + def __init__(self): + wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i')) + lock = multiprocessing.Lock() + self.__setstate__((wrapper, lock)) + self._lengthbuf[0] = 0 + + def __setstate__(self, state): + (self._wrapper, self._lock) = state + self._lengthbuf = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._wrapper, self._lock) + + def append(self, _): + with self._lock: + self._lengthbuf[0] += 1 + + def __len__(self): + with self._lock: + return self._lengthbuf[0] + +def _wait(): + # A crude wait/yield function not relying on synchronization primitives. + time.sleep(0.01) + + +class Bunch(object): + """ + A bunch of threads. + """ + def __init__(self, namespace, f, args, n, wait_before_exit=False): + """ + Construct a bunch of `n` threads running the same function `f`. + If `wait_before_exit` is True, the threads won't terminate until + do_finish() is called. + """ + self.f = f + self.args = args + self.n = n + self.started = namespace.DummyList() + self.finished = namespace.DummyList() + self._can_exit = namespace.Event() + if not wait_before_exit: + self._can_exit.set() + for i in range(n): + p = namespace.Process(target=self.task) + p.daemon = True + p.start() + + def task(self): + pid = os.getpid() + self.started.append(pid) + try: + self.f(*self.args) + finally: + self.finished.append(pid) + self._can_exit.wait(30) + assert self._can_exit.is_set() + + def wait_for_started(self): + while len(self.started) < self.n: + _wait() + + def wait_for_finished(self): + while len(self.finished) < self.n: + _wait() + + def do_finish(self): + self._can_exit.set() + + +class AppendTrue(object): + def __init__(self, obj): + self.obj = obj + def __call__(self): + self.obj.append(True) + + +class _TestBarrier(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + defaultTimeout = 10.0 # XXX Slow Windows buildbots need generous timeout + + def setUp(self): + self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout) + + def tearDown(self): + self.barrier.abort() + self.barrier = None + + def DummyList(self): + if self.TYPE == 'threads': + return [] + elif self.TYPE == 'manager': + return self.manager.list() + else: + return _DummyList() + + def run_threads(self, f, args): + b = Bunch(self, f, args, self.N-1) + f(*args) + b.wait_for_finished() + + @classmethod + def multipass(cls, barrier, results, n): + m = barrier.parties + assert m == cls.N + for i in range(n): + results[0].append(True) + assert len(results[1]) == i * m + barrier.wait() + results[1].append(True) + assert len(results[0]) == (i + 1) * m + barrier.wait() + try: + assert barrier.n_waiting == 0 + except NotImplementedError: + pass + assert not barrier.broken + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [self.DummyList(), self.DummyList()] + self.run_threads(self.multipass, (self.barrier, results, passes)) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + @classmethod + def _test_wait_return_f(cls, barrier, queue): + res = barrier.wait() + queue.put(res) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + queue = self.Queue() + self.run_threads(self._test_wait_return_f, (self.barrier, queue)) + results = [queue.get() for i in range(self.N)] + self.assertEqual(results.count(0), 1) + + @classmethod + def _test_action_f(cls, barrier, results): + barrier.wait() + if len(results) != 1: + raise RuntimeError + + def test_action(self): + """ + Test the 'action' callback + """ + results = self.DummyList() + barrier = self.Barrier(self.N, action=AppendTrue(results)) + self.run_threads(self._test_action_f, (barrier, results)) + self.assertEqual(len(results), 1) + + @classmethod + def _test_abort_f(cls, barrier, results1, results2): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = self.DummyList() + results2 = self.DummyList() + self.run_threads(self._test_abort_f, + (self.barrier, results1, results2)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + @classmethod + def _test_reset_f(cls, barrier, results1, results2, results3): + i = barrier.wait() + if i == cls.N//2: + # Wait until the other threads are all in the barrier. + while barrier.n_waiting < cls.N-1: + time.sleep(0.001) + barrier.reset() + else: + try: + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + barrier.wait() + results3.append(True) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + self.run_threads(self._test_reset_f, + (self.barrier, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_abort_and_reset_f(cls, barrier, barrier2, + results1, results2, results3): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == cls.N//2: + barrier.reset() + barrier2.wait() + barrier.wait() + results3.append(True) + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + barrier2 = self.Barrier(self.N) + + self.run_threads(self._test_abort_and_reset_f, + (self.barrier, barrier2, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_timeout_f(cls, barrier, results): + i = barrier.wait(20) + if i == cls.N//2: + # One thread is late! + time.sleep(4.0) + try: + barrier.wait(0.5) + except threading.BrokenBarrierError: + results.append(True) + + def test_timeout(self): + """ + Test wait(timeout) + """ + results = self.DummyList() + self.run_threads(self._test_timeout_f, (self.barrier, results)) + self.assertEqual(len(results), self.barrier.parties) + + @classmethod + def _test_default_timeout_f(cls, barrier, results): + i = barrier.wait(20) + if i == cls.N//2: + # One thread is later than the default timeout + time.sleep(4.0) + try: + barrier.wait() + except threading.BrokenBarrierError: + results.append(True) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + barrier = self.Barrier(self.N, timeout=1.0) + results = self.DummyList() + self.run_threads(self._test_default_timeout_f, (barrier, results)) + self.assertEqual(len(results), barrier.parties) + + def test_single_thread(self): + b = self.Barrier(1) + b.wait() + b.wait() + + @classmethod + def _test_thousand_f(cls, barrier, passes, conn, lock): + for i in range(passes): + barrier.wait() + with lock: + conn.send(i) + + def test_thousand(self): + if self.TYPE == 'manager': + return + passes = 1000 + lock = self.Lock() + conn, child_conn = self.Pipe(False) + for j in range(self.N): + p = self.Process(target=self._test_thousand_f, + args=(self.barrier, passes, child_conn, lock)) + p.start() + + for i in range(passes): + for j in range(self.N): + self.assertEqual(conn.recv(), i) + +# # # @@ -1485,6 +1820,11 @@ class _TestZZZNumberOfObjects(BaseTestCase): # run after all the other tests for the manager. It tests that # there have been no "reference leaks" for the manager's shared # objects. Note the comment in _TestPool.test_terminate(). + + # If some other test using ManagerMixin.manager fails, then the + # raised exception may keep alive a frame which holds a reference + # to a managed object. This will cause test_number_of_objects to + # also fail. ALLOWED_TYPES = ('manager',) def test_number_of_objects(self): @@ -1564,6 +1904,11 @@ class _TestMyManager(BaseTestCase): manager.shutdown() + # If the manager process exited cleanly then the exitcode + # will be zero. Otherwise (after a short timeout) + # terminate() is used, resulting in an exitcode of -SIGTERM. + self.assertEqual(manager._process.exitcode, 0) + # # Test of connecting to a remote server and using xmlrpclib for serialization # @@ -1923,7 +2268,7 @@ class _TestConnection(BaseTestCase): class _TestListener(BaseTestCase): - ALLOWED_TYPES = ('processes') + ALLOWED_TYPES = ('processes',) def test_multiple_bind(self): for family in self.connection.families: @@ -2505,10 +2850,12 @@ def create_test_cases(Mixin, type): result = {} glob = globals() Type = type.capitalize() + ALL_TYPES = {'processes', 'threads', 'manager'} for name in list(glob.keys()): if name.startswith('_Test'): base = glob[name] + assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES) if type in base.ALLOWED_TYPES: newname = 'With' + Type + name[1:] class Temp(base, unittest.TestCase, Mixin): @@ -2527,7 +2874,7 @@ class ProcessesMixin(object): Process = multiprocessing.Process locals().update(get_attributes(multiprocessing, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'RawValue', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'RawValue', 'RawArray', 'current_process', 'active_children', 'Pipe', 'connection', 'JoinableQueue', 'Pool' ))) @@ -2542,7 +2889,7 @@ class ManagerMixin(object): manager = object.__new__(multiprocessing.managers.SyncManager) locals().update(get_attributes(manager, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'list', 'dict', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'list', 'dict', 'Namespace', 'JoinableQueue', 'Pool' ))) @@ -2555,7 +2902,7 @@ class ThreadsMixin(object): Process = multiprocessing.dummy.Process locals().update(get_attributes(multiprocessing.dummy, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'current_process', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'current_process', 'active_children', 'Pipe', 'connection', 'dict', 'list', 'Namespace', 'JoinableQueue', 'Pool' ))) |