diff options
Diffstat (limited to 'Lib/test/test_multiprocessing.py')
-rw-r--r-- | Lib/test/test_multiprocessing.py | 1053 |
1 files changed, 1001 insertions, 52 deletions
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 05fba7f..3fb07f6 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -8,6 +8,7 @@ import unittest import queue as pyqueue import time import io +import itertools import sys import os import gc @@ -17,6 +18,7 @@ import array import socket import random import logging +import struct import test.support import test.script_helper @@ -84,6 +86,13 @@ HAVE_GETVALUE = not getattr(_multiprocessing, WIN32 = (sys.platform == "win32") +from multiprocessing.connection import wait + +def wait_for_handle(handle, timeout): + if timeout is not None and timeout < 0.0: + timeout = None + return wait([handle], timeout) + try: MAXFD = os.sysconf("SC_OPEN_MAX") except: @@ -197,6 +206,18 @@ class _TestProcess(BaseTestCase): self.assertEqual(current.ident, os.getpid()) self.assertEqual(current.exitcode, None) + def test_daemon_argument(self): + if self.TYPE == "threads": + return + + # By default uses the current process's daemon flag. + proc0 = self.Process(target=self._test) + self.assertEqual(proc0.daemon, self.current_process().daemon) + proc1 = self.Process(target=self._test, daemon=True) + self.assertTrue(proc1.daemon) + proc2 = self.Process(target=self._test, daemon=False) + self.assertFalse(proc2.daemon) + @classmethod def _test(cls, q, *args, **kwds): current = cls.current_process() @@ -262,9 +283,18 @@ class _TestProcess(BaseTestCase): self.assertIn(p, self.active_children()) self.assertEqual(p.exitcode, None) + join = TimingWrapper(p.join) + + self.assertEqual(join(0), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + + self.assertEqual(join(-1), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + p.terminate() - join = TimingWrapper(p.join) self.assertEqual(join(), None) self.assertTimingAlmostEqual(join.elapsed, 0.0) @@ -329,6 +359,26 @@ class _TestProcess(BaseTestCase): ] self.assertEqual(result, expected) + @classmethod + def _test_sentinel(cls, event): + event.wait(10.0) + + def test_sentinel(self): + if self.TYPE == "threads": + return + event = self.Event() + p = self.Process(target=self._test_sentinel, args=(event,)) + with self.assertRaises(ValueError): + p.sentinel + p.start() + self.addCleanup(p.join) + sentinel = p.sentinel + self.assertIsInstance(sentinel, int) + self.assertFalse(wait_for_handle(sentinel, timeout=0.0)) + event.set() + p.join() + self.assertTrue(wait_for_handle(sentinel, timeout=DELTA)) + # # # @@ -868,6 +918,104 @@ class _TestCondition(BaseTestCase): self.assertEqual(res, False) self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1) + @classmethod + def _test_waitfor_f(cls, cond, state): + with cond: + state.value = 0 + cond.notify() + result = cond.wait_for(lambda : state.value==4) + if not result or state.value != 4: + sys.exit(1) + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', -1) + + p = self.Process(target=self._test_waitfor_f, args=(cond, state)) + p.daemon = True + p.start() + + with cond: + result = cond.wait_for(lambda : state.value==0) + self.assertTrue(result) + self.assertEqual(state.value, 0) + + for i in range(4): + time.sleep(0.01) + with cond: + state.value += 1 + cond.notify() + + p.join(5) + self.assertFalse(p.is_alive()) + self.assertEqual(p.exitcode, 0) + + @classmethod + def _test_waitfor_timeout_f(cls, cond, state, success, sem): + sem.release() + with cond: + expected = 0.1 + dt = time.time() + result = cond.wait_for(lambda : state.value==4, timeout=expected) + dt = time.time() - dt + # borrow logic in assertTimeout() from test/lock_tests.py + if not result and expected * 0.6 < dt < expected * 10.0: + success.value = True + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor_timeout(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', 0) + success = self.Value('i', False) + sem = self.Semaphore(0) + + p = self.Process(target=self._test_waitfor_timeout_f, + args=(cond, state, success, sem)) + p.daemon = True + p.start() + self.assertTrue(sem.acquire(timeout=10)) + + # Only increment 3 times, so state == 4 is never reached. + for i in range(3): + time.sleep(0.01) + with cond: + state.value += 1 + cond.notify() + + p.join(5) + self.assertTrue(success.value) + + @classmethod + def _test_wait_result(cls, c, pid): + with c: + c.notify() + time.sleep(1) + if pid is not None: + os.kill(pid, signal.SIGINT) + + def test_wait_result(self): + if isinstance(self, ProcessesMixin) and sys.platform != 'win32': + pid = os.getpid() + else: + pid = None + + c = self.Condition() + with c: + self.assertFalse(c.wait(0)) + self.assertFalse(c.wait(0.1)) + + p = self.Process(target=self._test_wait_result, args=(c, pid)) + p.start() + + self.assertTrue(c.wait(10)) + if pid is not None: + self.assertRaises(KeyboardInterrupt, c.wait, 10) + + p.join() + class _TestEvent(BaseTestCase): @@ -911,6 +1059,340 @@ class _TestEvent(BaseTestCase): self.assertEqual(wait(), True) # +# Tests for Barrier - adapted from tests in test/lock_tests.py +# + +# Many of the tests for threading.Barrier use a list as an atomic +# counter: a value is appended to increment the counter, and the +# length of the list gives the value. We use the class DummyList +# for the same purpose. + +class _DummyList(object): + + def __init__(self): + wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i')) + lock = multiprocessing.Lock() + self.__setstate__((wrapper, lock)) + self._lengthbuf[0] = 0 + + def __setstate__(self, state): + (self._wrapper, self._lock) = state + self._lengthbuf = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._wrapper, self._lock) + + def append(self, _): + with self._lock: + self._lengthbuf[0] += 1 + + def __len__(self): + with self._lock: + return self._lengthbuf[0] + +def _wait(): + # A crude wait/yield function not relying on synchronization primitives. + time.sleep(0.01) + + +class Bunch(object): + """ + A bunch of threads. + """ + def __init__(self, namespace, f, args, n, wait_before_exit=False): + """ + Construct a bunch of `n` threads running the same function `f`. + If `wait_before_exit` is True, the threads won't terminate until + do_finish() is called. + """ + self.f = f + self.args = args + self.n = n + self.started = namespace.DummyList() + self.finished = namespace.DummyList() + self._can_exit = namespace.Event() + if not wait_before_exit: + self._can_exit.set() + for i in range(n): + p = namespace.Process(target=self.task) + p.daemon = True + p.start() + + def task(self): + pid = os.getpid() + self.started.append(pid) + try: + self.f(*self.args) + finally: + self.finished.append(pid) + self._can_exit.wait(30) + assert self._can_exit.is_set() + + def wait_for_started(self): + while len(self.started) < self.n: + _wait() + + def wait_for_finished(self): + while len(self.finished) < self.n: + _wait() + + def do_finish(self): + self._can_exit.set() + + +class AppendTrue(object): + def __init__(self, obj): + self.obj = obj + def __call__(self): + self.obj.append(True) + + +class _TestBarrier(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + defaultTimeout = 30.0 # XXX Slow Windows buildbots need generous timeout + + def setUp(self): + self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout) + + def tearDown(self): + self.barrier.abort() + self.barrier = None + + def DummyList(self): + if self.TYPE == 'threads': + return [] + elif self.TYPE == 'manager': + return self.manager.list() + else: + return _DummyList() + + def run_threads(self, f, args): + b = Bunch(self, f, args, self.N-1) + f(*args) + b.wait_for_finished() + + @classmethod + def multipass(cls, barrier, results, n): + m = barrier.parties + assert m == cls.N + for i in range(n): + results[0].append(True) + assert len(results[1]) == i * m + barrier.wait() + results[1].append(True) + assert len(results[0]) == (i + 1) * m + barrier.wait() + try: + assert barrier.n_waiting == 0 + except NotImplementedError: + pass + assert not barrier.broken + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [self.DummyList(), self.DummyList()] + self.run_threads(self.multipass, (self.barrier, results, passes)) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + @classmethod + def _test_wait_return_f(cls, barrier, queue): + res = barrier.wait() + queue.put(res) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + queue = self.Queue() + self.run_threads(self._test_wait_return_f, (self.barrier, queue)) + results = [queue.get() for i in range(self.N)] + self.assertEqual(results.count(0), 1) + + @classmethod + def _test_action_f(cls, barrier, results): + barrier.wait() + if len(results) != 1: + raise RuntimeError + + def test_action(self): + """ + Test the 'action' callback + """ + results = self.DummyList() + barrier = self.Barrier(self.N, action=AppendTrue(results)) + self.run_threads(self._test_action_f, (barrier, results)) + self.assertEqual(len(results), 1) + + @classmethod + def _test_abort_f(cls, barrier, results1, results2): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = self.DummyList() + results2 = self.DummyList() + self.run_threads(self._test_abort_f, + (self.barrier, results1, results2)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + @classmethod + def _test_reset_f(cls, barrier, results1, results2, results3): + i = barrier.wait() + if i == cls.N//2: + # Wait until the other threads are all in the barrier. + while barrier.n_waiting < cls.N-1: + time.sleep(0.001) + barrier.reset() + else: + try: + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + barrier.wait() + results3.append(True) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + self.run_threads(self._test_reset_f, + (self.barrier, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_abort_and_reset_f(cls, barrier, barrier2, + results1, results2, results3): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == cls.N//2: + barrier.reset() + barrier2.wait() + barrier.wait() + results3.append(True) + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + barrier2 = self.Barrier(self.N) + + self.run_threads(self._test_abort_and_reset_f, + (self.barrier, barrier2, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_timeout_f(cls, barrier, results): + i = barrier.wait() + if i == cls.N//2: + # One thread is late! + time.sleep(1.0) + try: + barrier.wait(0.5) + except threading.BrokenBarrierError: + results.append(True) + + def test_timeout(self): + """ + Test wait(timeout) + """ + results = self.DummyList() + self.run_threads(self._test_timeout_f, (self.barrier, results)) + self.assertEqual(len(results), self.barrier.parties) + + @classmethod + def _test_default_timeout_f(cls, barrier, results): + i = barrier.wait(cls.defaultTimeout) + if i == cls.N//2: + # One thread is later than the default timeout + time.sleep(1.0) + try: + barrier.wait() + except threading.BrokenBarrierError: + results.append(True) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + barrier = self.Barrier(self.N, timeout=0.5) + results = self.DummyList() + self.run_threads(self._test_default_timeout_f, (barrier, results)) + self.assertEqual(len(results), barrier.parties) + + def test_single_thread(self): + b = self.Barrier(1) + b.wait() + b.wait() + + @classmethod + def _test_thousand_f(cls, barrier, passes, conn, lock): + for i in range(passes): + barrier.wait() + with lock: + conn.send(i) + + def test_thousand(self): + if self.TYPE == 'manager': + return + passes = 1000 + lock = self.Lock() + conn, child_conn = self.Pipe(False) + for j in range(self.N): + p = self.Process(target=self._test_thousand_f, + args=(self.barrier, passes, child_conn, lock)) + p.start() + + for i in range(passes): + for j in range(self.N): + self.assertEqual(conn.recv(), i) + +# # # @@ -1130,6 +1612,9 @@ def sqr(x, wait=0.0): time.sleep(wait) return x*x +def mul(x, y): + return x*y + class _TestPool(BaseTestCase): def test_apply(self): @@ -1143,6 +1628,37 @@ class _TestPool(BaseTestCase): self.assertEqual(pmap(sqr, list(range(100)), chunksize=20), list(map(sqr, list(range(100))))) + def test_starmap(self): + psmap = self.pool.starmap + tuples = list(zip(range(10), range(9,-1, -1))) + self.assertEqual(psmap(mul, tuples), + list(itertools.starmap(mul, tuples))) + tuples = list(zip(range(100), range(99,-1, -1))) + self.assertEqual(psmap(mul, tuples, chunksize=20), + list(itertools.starmap(mul, tuples))) + + def test_starmap_async(self): + tuples = list(zip(range(100), range(99,-1, -1))) + self.assertEqual(self.pool.starmap_async(mul, tuples).get(), + list(itertools.starmap(mul, tuples))) + + def test_map_async(self): + self.assertEqual(self.pool.map_async(sqr, list(range(10))).get(), + list(map(sqr, list(range(10))))) + + def test_map_async_callbacks(self): + call_args = self.manager.list() if self.TYPE == 'manager' else [] + self.pool.map_async(int, ['1'], + callback=call_args.append, + error_callback=call_args.append).wait() + self.assertEqual(1, len(call_args)) + self.assertEqual([1], call_args[0]) + self.pool.map_async(int, ['a'], + callback=call_args.append, + error_callback=call_args.append).wait() + self.assertEqual(2, len(call_args)) + self.assertIsInstance(call_args[1], ValueError) + def test_map_chunksize(self): try: self.pool.map_async(sqr, [], chunksize=1).get(timeout=TIMEOUT1) @@ -1221,6 +1737,16 @@ class _TestPool(BaseTestCase): p.close() p.join() + def test_context(self): + if self.TYPE == 'processes': + L = list(range(10)) + expected = [sqr(i) for i in L] + with multiprocessing.Pool(2) as p: + r = p.map_async(sqr, L) + self.assertEqual(r.get(), expected) + print(p._state) + self.assertRaises(ValueError, p.map_async, sqr, L) + def raising(): raise KeyError("key") @@ -1322,6 +1848,11 @@ class _TestZZZNumberOfObjects(BaseTestCase): # run after all the other tests for the manager. It tests that # there have been no "reference leaks" for the manager's shared # objects. Note the comment in _TestPool.test_terminate(). + + # If some other test using ManagerMixin.manager fails, then the + # raised exception may keep alive a frame which holds a reference + # to a managed object. This will cause test_number_of_objects to + # also fail. ALLOWED_TYPES = ('manager',) def test_number_of_objects(self): @@ -1376,7 +1907,27 @@ class _TestMyManager(BaseTestCase): def test_mymanager(self): manager = MyManager() manager.start() + self.common(manager) + manager.shutdown() + + # If the manager process exited cleanly then the exitcode + # will be zero. Otherwise (after a short timeout) + # terminate() is used, resulting in an exitcode of -SIGTERM. + self.assertEqual(manager._process.exitcode, 0) + + def test_mymanager_context(self): + with MyManager() as manager: + self.common(manager) + self.assertEqual(manager._process.exitcode, 0) + + def test_mymanager_context_prestarted(self): + manager = MyManager() + manager.start() + with manager: + self.common(manager) + self.assertEqual(manager._process.exitcode, 0) + def common(self, manager): foo = manager.Foo() bar = manager.Bar() baz = manager.baz() @@ -1399,7 +1950,6 @@ class _TestMyManager(BaseTestCase): self.assertEqual(list(baz), [i*i for i in range(10)]) - manager.shutdown() # # Test of connecting to a remote server and using xmlrpclib for serialization @@ -1570,6 +2120,9 @@ class _TestConnection(BaseTestCase): self.assertEqual(poll(), False) self.assertTimingAlmostEqual(poll.elapsed, 0) + self.assertEqual(poll(-1), False) + self.assertTimingAlmostEqual(poll.elapsed, 0) + self.assertEqual(poll(TIMEOUT1), False) self.assertTimingAlmostEqual(poll.elapsed, TIMEOUT1) @@ -1756,6 +2309,43 @@ class _TestConnection(BaseTestCase): self.assertRaises(RuntimeError, reduction.recv_handle, conn) p.join() + def test_context(self): + a, b = self.Pipe() + + with a, b: + a.send(1729) + self.assertEqual(b.recv(), 1729) + if self.TYPE == 'processes': + self.assertFalse(a.closed) + self.assertFalse(b.closed) + + if self.TYPE == 'processes': + self.assertTrue(a.closed) + self.assertTrue(b.closed) + self.assertRaises(IOError, a.recv) + self.assertRaises(IOError, b.recv) + +class _TestListener(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def test_multiple_bind(self): + for family in self.connection.families: + l = self.connection.Listener(family=family) + self.addCleanup(l.close) + self.assertRaises(OSError, self.connection.Listener, + l.address, family) + + def test_context(self): + with self.connection.Listener() as l: + with self.connection.Client(l.address) as c: + with l.accept() as d: + c.send(1729) + self.assertEqual(d.recv(), 1729) + + if self.TYPE == 'processes': + self.assertRaises(IOError, l.accept) + class _TestListenerClient(BaseTestCase): ALLOWED_TYPES = ('processes', 'threads') @@ -1793,52 +2383,146 @@ class _TestListenerClient(BaseTestCase): p.join() l.close() + def test_issue16955(self): + for fam in self.connection.families: + l = self.connection.Listener(family=fam) + c = self.connection.Client(l.address) + a = l.accept() + a.send_bytes(b"hello") + self.assertTrue(c.poll(1)) + a.close() + c.close() + l.close() + +class _TestPoll(unittest.TestCase): + + ALLOWED_TYPES = ('processes', 'threads') + + def test_empty_string(self): + a, b = self.Pipe() + self.assertEqual(a.poll(), False) + b.send_bytes(b'') + self.assertEqual(a.poll(), True) + self.assertEqual(a.poll(), True) + + @classmethod + def _child_strings(cls, conn, strings): + for s in strings: + time.sleep(0.1) + conn.send_bytes(s) + conn.close() + + def test_strings(self): + strings = (b'hello', b'', b'a', b'b', b'', b'bye', b'', b'lop') + a, b = self.Pipe() + p = self.Process(target=self._child_strings, args=(b, strings)) + p.start() + + for s in strings: + for i in range(200): + if a.poll(0.01): + break + x = a.recv_bytes() + self.assertEqual(s, x) + + p.join() + + @classmethod + def _child_boundaries(cls, r): + # Polling may "pull" a message in to the child process, but we + # don't want it to pull only part of a message, as that would + # corrupt the pipe for any other processes which might later + # read from it. + r.poll(5) + + def test_boundaries(self): + r, w = self.Pipe(False) + p = self.Process(target=self._child_boundaries, args=(r,)) + p.start() + time.sleep(2) + L = [b"first", b"second"] + for obj in L: + w.send_bytes(obj) + w.close() + p.join() + self.assertIn(r.recv_bytes(), L) + + @classmethod + def _child_dont_merge(cls, b): + b.send_bytes(b'a') + b.send_bytes(b'b') + b.send_bytes(b'cd') + + def test_dont_merge(self): + a, b = self.Pipe() + self.assertEqual(a.poll(0.0), False) + self.assertEqual(a.poll(0.1), False) + + p = self.Process(target=self._child_dont_merge, args=(b,)) + p.start() + + self.assertEqual(a.recv_bytes(), b'a') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.recv_bytes(), b'b') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(0.0), True) + self.assertEqual(a.recv_bytes(), b'cd') + + p.join() + # # Test of sending connection and socket objects between processes # -""" + +@unittest.skipUnless(HAS_REDUCTION, "test needs multiprocessing.reduction") class _TestPicklingConnections(BaseTestCase): ALLOWED_TYPES = ('processes',) - def _listener(self, conn, families): + @classmethod + def tearDownClass(cls): + from multiprocessing.reduction import resource_sharer + resource_sharer.stop(timeout=5) + + @classmethod + def _listener(cls, conn, families): for fam in families: - l = self.connection.Listener(family=fam) + l = cls.connection.Listener(family=fam) conn.send(l.address) new_conn = l.accept() conn.send(new_conn) + new_conn.close() + l.close() - if self.TYPE == 'processes': - l = socket.socket() - l.bind(('localhost', 0)) - conn.send(l.getsockname()) - l.listen(1) - new_conn, addr = l.accept() - conn.send(new_conn) + l = socket.socket() + l.bind(('localhost', 0)) + l.listen(1) + conn.send(l.getsockname()) + new_conn, addr = l.accept() + conn.send(new_conn) + new_conn.close() + l.close() conn.recv() - def _remote(self, conn): + @classmethod + def _remote(cls, conn): for (address, msg) in iter(conn.recv, None): - client = self.connection.Client(address) + client = cls.connection.Client(address) client.send(msg.upper()) client.close() - if self.TYPE == 'processes': - address, msg = conn.recv() - client = socket.socket() - client.connect(address) - client.sendall(msg.upper()) - client.close() + address, msg = conn.recv() + client = socket.socket() + client.connect(address) + client.sendall(msg.upper()) + client.close() conn.close() def test_pickling(self): - try: - multiprocessing.allow_connection_pickling() - except ImportError: - return - families = self.connection.families lconn, lconn0 = self.Pipe() @@ -1862,16 +2546,19 @@ class _TestPicklingConnections(BaseTestCase): rconn.send(None) - if self.TYPE == 'processes': - msg = latin('This connection uses a normal socket') - address = lconn.recv() - rconn.send((address, msg)) - if hasattr(socket, 'fromfd'): - new_conn = lconn.recv() - self.assertEqual(new_conn.recv(100), msg.upper()) - else: - # XXX On Windows with Py2.6 need to backport fromfd() - discard = lconn.recv_bytes() + msg = latin('This connection uses a normal socket') + address = lconn.recv() + rconn.send((address, msg)) + new_conn = lconn.recv() + buf = [] + while True: + s = new_conn.recv(100) + if not s: + break + buf.append(s) + buf = b''.join(buf) + self.assertEqual(buf, msg.upper()) + new_conn.close() lconn.send(None) @@ -1880,7 +2567,46 @@ class _TestPicklingConnections(BaseTestCase): lp.join() rp.join() -""" + + @classmethod + def child_access(cls, conn): + w = conn.recv() + w.send('all is well') + w.close() + + r = conn.recv() + msg = r.recv() + conn.send(msg*2) + + conn.close() + + def test_access(self): + # On Windows, if we do not specify a destination pid when + # using DupHandle then we need to be careful to use the + # correct access flags for DuplicateHandle(), or else + # DupHandle.detach() will raise PermissionError. For example, + # for a read only pipe handle we should use + # access=FILE_GENERIC_READ. (Unfortunately + # DUPLICATE_SAME_ACCESS does not work.) + conn, child_conn = self.Pipe() + p = self.Process(target=self.child_access, args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + + r, w = self.Pipe(duplex=False) + conn.send(w) + w.close() + self.assertEqual(r.recv(), 'all is well') + r.close() + + r, w = self.Pipe(duplex=False) + conn.send(r) + r.close() + w.send('foobar') + w.close() + self.assertEqual(conn.recv(), 'foobar'*2) + # # # @@ -2207,9 +2933,15 @@ class TestInvalidHandle(unittest.TestCase): @unittest.skipIf(WIN32, "skipped on Windows") def test_invalid_handles(self): - conn = _multiprocessing.Connection(44977608) - self.assertRaises(IOError, conn.poll) - self.assertRaises(IOError, _multiprocessing.Connection, -1) + conn = multiprocessing.connection.Connection(44977608) + try: + self.assertRaises((ValueError, IOError), conn.poll) + finally: + # Hack private attribute _handle to avoid printing an error + # in conn.__del__ + conn._handle = None + self.assertRaises((ValueError, IOError), + multiprocessing.connection.Connection, -1) # # Functions used to create test cases from the base ones in this module @@ -2228,10 +2960,12 @@ def create_test_cases(Mixin, type): result = {} glob = globals() Type = type.capitalize() + ALL_TYPES = {'processes', 'threads', 'manager'} for name in list(glob.keys()): if name.startswith('_Test'): base = glob[name] + assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES) if type in base.ALLOWED_TYPES: newname = 'With' + Type + name[1:] class Temp(base, unittest.TestCase, Mixin): @@ -2250,7 +2984,7 @@ class ProcessesMixin(object): Process = multiprocessing.Process locals().update(get_attributes(multiprocessing, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'RawValue', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'RawValue', 'RawArray', 'current_process', 'active_children', 'Pipe', 'connection', 'JoinableQueue', 'Pool' ))) @@ -2265,7 +2999,7 @@ class ManagerMixin(object): manager = object.__new__(multiprocessing.managers.SyncManager) locals().update(get_attributes(manager, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'list', 'dict', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'list', 'dict', 'Namespace', 'JoinableQueue', 'Pool' ))) @@ -2278,7 +3012,7 @@ class ThreadsMixin(object): Process = multiprocessing.dummy.Process locals().update(get_attributes(multiprocessing.dummy, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'current_process', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'current_process', 'active_children', 'Pipe', 'connection', 'dict', 'list', 'Namespace', 'JoinableQueue', 'Pool' ))) @@ -2330,6 +3064,7 @@ class TestInitializers(unittest.TestCase): def tearDown(self): self.mgr.shutdown() + self.mgr.join() def test_manager_initializer(self): m = multiprocessing.managers.SyncManager() @@ -2337,6 +3072,7 @@ class TestInitializers(unittest.TestCase): m.start(initializer, (self.ns,)) self.assertEqual(self.ns.test, 1) m.shutdown() + m.join() def test_pool_initializer(self): self.assertRaises(TypeError, multiprocessing.Pool, initializer=1) @@ -2369,6 +3105,8 @@ def _afunc(x): def pool_in_process(): pool = multiprocessing.Pool(processes=4) x = pool.map(_afunc, [1, 2, 3, 4, 5, 6, 7]) + pool.close() + pool.join() class _file_like(object): def __init__(self, delegate): @@ -2413,6 +3151,181 @@ class TestStdinBadfiledescriptor(unittest.TestCase): assert sio.getvalue() == 'foo' +class TestWait(unittest.TestCase): + + @classmethod + def _child_test_wait(cls, w, slow): + for i in range(10): + if slow: + time.sleep(random.random()*0.1) + w.send((i, os.getpid())) + w.close() + + def test_wait(self, slow=False): + from multiprocessing.connection import wait + readers = [] + procs = [] + messages = [] + + for i in range(4): + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=self._child_test_wait, args=(w, slow)) + p.daemon = True + p.start() + w.close() + readers.append(r) + procs.append(p) + self.addCleanup(p.join) + + while readers: + for r in wait(readers): + try: + msg = r.recv() + except EOFError: + readers.remove(r) + r.close() + else: + messages.append(msg) + + messages.sort() + expected = sorted((i, p.pid) for i in range(10) for p in procs) + self.assertEqual(messages, expected) + + @classmethod + def _child_test_wait_socket(cls, address, slow): + s = socket.socket() + s.connect(address) + for i in range(10): + if slow: + time.sleep(random.random()*0.1) + s.sendall(('%s\n' % i).encode('ascii')) + s.close() + + def test_wait_socket(self, slow=False): + from multiprocessing.connection import wait + l = socket.socket() + l.bind(('', 0)) + l.listen(4) + addr = ('localhost', l.getsockname()[1]) + readers = [] + procs = [] + dic = {} + + for i in range(4): + p = multiprocessing.Process(target=self._child_test_wait_socket, + args=(addr, slow)) + p.daemon = True + p.start() + procs.append(p) + self.addCleanup(p.join) + + for i in range(4): + r, _ = l.accept() + readers.append(r) + dic[r] = [] + l.close() + + while readers: + for r in wait(readers): + msg = r.recv(32) + if not msg: + readers.remove(r) + r.close() + else: + dic[r].append(msg) + + expected = ''.join('%s\n' % i for i in range(10)).encode('ascii') + for v in dic.values(): + self.assertEqual(b''.join(v), expected) + + def test_wait_slow(self): + self.test_wait(True) + + def test_wait_socket_slow(self): + self.test_wait_socket(True) + + def test_wait_timeout(self): + from multiprocessing.connection import wait + + expected = 5 + a, b = multiprocessing.Pipe() + + start = time.time() + res = wait([a, b], expected) + delta = time.time() - start + + self.assertEqual(res, []) + self.assertLess(delta, expected * 2) + self.assertGreater(delta, expected * 0.5) + + b.send(None) + + start = time.time() + res = wait([a, b], 20) + delta = time.time() - start + + self.assertEqual(res, [a]) + self.assertLess(delta, 0.4) + + @classmethod + def signal_and_sleep(cls, sem, period): + sem.release() + time.sleep(period) + + def test_wait_integer(self): + from multiprocessing.connection import wait + + expected = 3 + sorted_ = lambda l: sorted(l, key=lambda x: id(x)) + sem = multiprocessing.Semaphore(0) + a, b = multiprocessing.Pipe() + p = multiprocessing.Process(target=self.signal_and_sleep, + args=(sem, expected)) + + p.start() + self.assertIsInstance(p.sentinel, int) + self.assertTrue(sem.acquire(timeout=20)) + + start = time.time() + res = wait([a, p.sentinel, b], expected + 20) + delta = time.time() - start + + self.assertEqual(res, [p.sentinel]) + self.assertLess(delta, expected + 2) + self.assertGreater(delta, expected - 2) + + a.send(None) + + start = time.time() + res = wait([a, p.sentinel, b], 20) + delta = time.time() - start + + self.assertEqual(sorted_(res), sorted_([p.sentinel, b])) + self.assertLess(delta, 0.4) + + b.send(None) + + start = time.time() + res = wait([a, p.sentinel, b], 20) + delta = time.time() - start + + self.assertEqual(sorted_(res), sorted_([a, p.sentinel, b])) + self.assertLess(delta, 0.4) + + p.terminate() + p.join() + + def test_neg_timeout(self): + from multiprocessing.connection import wait + a, b = multiprocessing.Pipe() + t = time.time() + res = wait([a], timeout=-1) + t = time.time() - t + self.assertEqual(res, []) + self.assertLess(t, 1) + a.close() + b.close() + # # Issue 14151: Test invalid family on invalid environment # @@ -2430,6 +3343,38 @@ class TestInvalidFamily(unittest.TestCase): multiprocessing.connection.Listener('/var/test.pipe') # +# Issue 12098: check sys.flags of child matches that for parent +# + +class TestFlags(unittest.TestCase): + @classmethod + def run_in_grandchild(cls, conn): + conn.send(tuple(sys.flags)) + + @classmethod + def run_in_child(cls): + import json + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=cls.run_in_grandchild, args=(w,)) + p.start() + grandchild_flags = r.recv() + p.join() + r.close() + w.close() + flags = (tuple(sys.flags), grandchild_flags) + print(json.dumps(flags)) + + def test_flags(self): + import json, subprocess + # start child process using unusual flags + prog = ('from test.test_multiprocessing import TestFlags; ' + + 'TestFlags.run_in_child()') + data = subprocess.check_output( + [sys.executable, '-E', '-S', '-O', '-c', prog]) + child_flags, grandchild_flags = json.loads(data.decode('ascii')) + self.assertEqual(child_flags, grandchild_flags) + +# # Test interaction with socket timeouts - see Issue #6056 # @@ -2484,8 +3429,8 @@ class TestNoForkBomb(unittest.TestCase): # testcases_other = [OtherTest, TestInvalidHandle, TestInitializers, - TestStdinBadfiledescriptor, TestInvalidFamily, - TestTimeouts, TestNoForkBomb] + TestStdinBadfiledescriptor, TestWait, TestInvalidFamily, + TestFlags, TestTimeouts, TestNoForkBomb] # # @@ -2522,14 +3467,18 @@ def test_main(run=None): loadTestsFromTestCase = unittest.defaultTestLoader.loadTestsFromTestCase suite = unittest.TestSuite(loadTestsFromTestCase(tc) for tc in testcases) - run(suite) - - ThreadsMixin.pool.terminate() - ProcessesMixin.pool.terminate() - ManagerMixin.pool.terminate() - ManagerMixin.manager.shutdown() - - del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool + try: + run(suite) + finally: + ThreadsMixin.pool.terminate() + ProcessesMixin.pool.terminate() + ManagerMixin.pool.terminate() + ManagerMixin.pool.join() + ManagerMixin.manager.shutdown() + ManagerMixin.manager.join() + ThreadsMixin.pool.join() + ProcessesMixin.pool.join() + del ProcessesMixin.pool, ThreadsMixin.pool, ManagerMixin.pool def main(): test_main(unittest.TextTestRunner(verbosity=2).run) |