Diffstat (limited to 'Lib/test/test_multiprocessing.py')
-rw-r--r-- | Lib/test/test_multiprocessing.py | 128
1 file changed, 115 insertions(+), 13 deletions(-)
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index d1ac4b7..91e2cbb 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -1,11 +1,10 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
 
 #
 # Unit tests for the multiprocessing package
 #
 
 import unittest
-import threading
 import queue as pyqueue
 import time
 import io
@@ -24,6 +23,10 @@ import test.support
 _multiprocessing = test.support.import_module('_multiprocessing')
 # Skip tests if sem_open implementation is broken.
 test.support.import_module('multiprocessing.synchronize')
+# import threading after _multiprocessing to raise a more revelant error
+# message: "No module named _multiprocessing". _multiprocessing is not compiled
+# without thread support.
+import threading
 
 import multiprocessing.dummy
 import multiprocessing.connection
@@ -51,7 +54,7 @@ def latin(s):
 #
 
 LOG_LEVEL = util.SUBWARNING
-#LOG_LEVEL = logging.WARNING
+#LOG_LEVEL = logging.DEBUG
 
 DELTA = 0.1
 CHECK_TIMINGS = False     # making true makes tests take a lot longer
@@ -155,7 +158,7 @@ class _TestProcess(BaseTestCase):
 
         self.assertTrue(current.is_alive())
         self.assertTrue(not current.daemon)
-        self.assertTrue(isinstance(authkey, bytes))
+        self.assertIsInstance(authkey, bytes)
         self.assertTrue(len(authkey) > 0)
         self.assertEqual(current.ident, os.getpid())
         self.assertEqual(current.exitcode, None)
@@ -186,7 +189,7 @@ class _TestProcess(BaseTestCase):
         self.assertEqual(p.authkey, current.authkey)
         self.assertEqual(p.is_alive(), False)
         self.assertEqual(p.daemon, True)
-        self.assertTrue(p not in self.active_children())
+        self.assertNotIn(p, self.active_children())
         self.assertTrue(type(self.active_children()) is list)
         self.assertEqual(p.exitcode, None)
 
@@ -194,7 +197,7 @@ class _TestProcess(BaseTestCase):
 
         self.assertEqual(p.exitcode, None)
         self.assertEqual(p.is_alive(), True)
-        self.assertTrue(p in self.active_children())
+        self.assertIn(p, self.active_children())
 
         self.assertEqual(q.get(), args[1:])
         self.assertEqual(q.get(), kwargs)
@@ -207,7 +210,7 @@ class _TestProcess(BaseTestCase):
 
         self.assertEqual(p.exitcode, 0)
         self.assertEqual(p.is_alive(), False)
-        self.assertTrue(p not in self.active_children())
+        self.assertNotIn(p, self.active_children())
 
     @classmethod
     def _test_terminate(cls):
@@ -222,7 +225,7 @@ class _TestProcess(BaseTestCase):
         p.start()
 
         self.assertEqual(p.is_alive(), True)
-        self.assertTrue(p in self.active_children())
+        self.assertIn(p, self.active_children())
         self.assertEqual(p.exitcode, None)
 
         p.terminate()
@@ -232,7 +235,7 @@ class _TestProcess(BaseTestCase):
         self.assertTimingAlmostEqual(join.elapsed, 0.0)
 
         self.assertEqual(p.is_alive(), False)
-        self.assertTrue(p not in self.active_children())
+        self.assertNotIn(p, self.active_children())
 
         p.join()
 
@@ -251,13 +254,13 @@ class _TestProcess(BaseTestCase):
         self.assertEqual(type(self.active_children()), list)
 
         p = self.Process(target=time.sleep, args=(DELTA,))
-        self.assertTrue(p not in self.active_children())
+        self.assertNotIn(p, self.active_children())
 
         p.start()
-        self.assertTrue(p in self.active_children())
+        self.assertIn(p, self.active_children())
 
         p.join()
-        self.assertTrue(p not in self.active_children())
+        self.assertNotIn(p, self.active_children())
 
     @classmethod
     def _test_recursion(cls, wconn, id):
@@ -765,7 +768,7 @@ class _TestCondition(BaseTestCase):
         cond.acquire()
         res = wait(TIMEOUT1)
         cond.release()
-        self.assertEqual(res, None)
+        self.assertEqual(res, False)
         self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1)
 
 
@@ -1010,6 +1013,7 @@ class _TestContainers(BaseTestCase):
 def sqr(x, wait=0.0):
     time.sleep(wait)
     return x*x
+
 class _TestPool(BaseTestCase):
 
     def test_apply(self):
@@ -1085,6 +1089,84 @@ class _TestPool(BaseTestCase):
         join = TimingWrapper(self.pool.join)
         join()
         self.assertTrue(join.elapsed < 0.2)
+
+def raising():
+    raise KeyError("key")
+
+def unpickleable_result():
+    return lambda: 42
+
+class _TestPoolWorkerErrors(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
+    def test_async_error_callback(self):
+        p = multiprocessing.Pool(2)
+
+        scratchpad = [None]
+        def errback(exc):
+            scratchpad[0] = exc
+
+        res = p.apply_async(raising, error_callback=errback)
+        self.assertRaises(KeyError, res.get)
+        self.assertTrue(scratchpad[0])
+        self.assertIsInstance(scratchpad[0], KeyError)
+
+        p.close()
+        p.join()
+
+    def test_unpickleable_result(self):
+        from multiprocessing.pool import MaybeEncodingError
+        p = multiprocessing.Pool(2)
+
+        # Make sure we don't lose pool processes because of encoding errors.
+        for iteration in range(20):
+
+            scratchpad = [None]
+            def errback(exc):
+                scratchpad[0] = exc
+
+            res = p.apply_async(unpickleable_result, error_callback=errback)
+            self.assertRaises(MaybeEncodingError, res.get)
+            wrapped = scratchpad[0]
+            self.assertTrue(wrapped)
+            self.assertIsInstance(scratchpad[0], MaybeEncodingError)
+            self.assertIsNotNone(wrapped.exc)
+            self.assertIsNotNone(wrapped.value)
+
+        p.close()
+        p.join()
+
+class _TestPoolWorkerLifetime(BaseTestCase):
+    ALLOWED_TYPES = ('processes', )
+
+    def test_pool_worker_lifetime(self):
+        p = multiprocessing.Pool(3, maxtasksperchild=10)
+        self.assertEqual(3, len(p._pool))
+        origworkerpids = [w.pid for w in p._pool]
+        # Run many tasks so each worker gets replaced (hopefully)
+        results = []
+        for i in range(100):
+            results.append(p.apply_async(sqr, (i, )))
+        # Fetch the results and verify we got the right answers,
+        # also ensuring all the tasks have completed.
+        for (j, res) in enumerate(results):
+            self.assertEqual(res.get(), sqr(j))
+        # Refill the pool
+        p._repopulate_pool()
+        # Wait until all workers are alive
+        countdown = 5
+        while countdown and not all(w.is_alive() for w in p._pool):
+            countdown -= 1
+            time.sleep(DELTA)
+        finalworkerpids = [w.pid for w in p._pool]
+        # All pids should be assigned.  See issue #7805.
+        self.assertNotIn(None, origworkerpids)
+        self.assertNotIn(None, finalworkerpids)
+        # Finally, check that the worker pids have changed
+        self.assertNotEqual(sorted(origworkerpids), sorted(finalworkerpids))
+        p.close()
+        p.join()
+
 #
 # Test that manager has expected number of shared objects left
 #
@@ -1761,6 +1843,26 @@ class _TestLogging(BaseTestCase):
         root_logger.setLevel(root_level)
         logger.setLevel(level=LOG_LEVEL)
 
+
+# class _TestLoggingProcessName(BaseTestCase):
+#
+#     def handle(self, record):
+#         assert record.processName == multiprocessing.current_process().name
+#         self.__handled = True
+#
+#     def test_logging(self):
+#         handler = logging.Handler()
+#         handler.handle = self.handle
+#         self.__handled = False
+#         # Bypass getLogger() and side-effects
+#         logger = logging.getLoggerClass()(
+#                 'multiprocessing.test.TestLoggingProcessName')
+#         logger.addHandler(handler)
+#         logger.propagate = False
+#
+#         logger.warn('foo')
+#         assert self.__handled
+
 #
 # Test to verify handle verification, see issue 3321
 #
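
For context, the new tests above exercise two Pool features: the error_callback argument accepted by apply_async() and the maxtasksperchild argument accepted by the Pool constructor. Below is a minimal usage sketch; it is not part of the patch, and the helper names (fail, square, on_error) are illustrative only.

import multiprocessing

def fail():
    raise KeyError("key")

def square(x):
    return x * x

def on_error(exc):
    # error_callback runs in the parent process and receives the
    # exception instance raised by the worker.
    print("task failed:", repr(exc))

if __name__ == '__main__':
    # maxtasksperchild=10: each worker process is retired and replaced
    # after completing ten tasks.
    pool = multiprocessing.Pool(processes=2, maxtasksperchild=10)

    res = pool.apply_async(fail, error_callback=on_error)
    try:
        res.get()          # get() also re-raises the worker's KeyError here
    except KeyError:
        pass

    print(pool.apply_async(square, (3,)).get())  # prints 9

    pool.close()
    pool.join()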