diff options
author | Antoine Pitrou <pitrou@free.fr> | 2017-11-04 10:05:49 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-11-04 10:05:49 (GMT) |
commit | 63ff4131af86e8a48cbedb9fbba95bd65ca90061 (patch) | |
tree | e6b205d0bc509e1be7d03a1d755f328f650f5ea1 /Lib | |
parent | b838cc3ff4e039af949c6a19bd896e98e944dcbe (diff) | |
download | cpython-63ff4131af86e8a48cbedb9fbba95bd65ca90061.zip cpython-63ff4131af86e8a48cbedb9fbba95bd65ca90061.tar.gz cpython-63ff4131af86e8a48cbedb9fbba95bd65ca90061.tar.bz2 |
bpo-21423: Add an initializer argument to {Process,Thread}PoolExecutor (#4241)
* bpo-21423: Add an initializer argument to {Process,Thread}PoolExecutor
* Fix docstring
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/concurrent/futures/__init__.py | 1 | ||||
-rw-r--r-- | Lib/concurrent/futures/_base.py | 6 | ||||
-rw-r--r-- | Lib/concurrent/futures/process.py | 36 | ||||
-rw-r--r-- | Lib/concurrent/futures/thread.py | 51 | ||||
-rw-r--r-- | Lib/test/test_concurrent_futures.py | 184 |
5 files changed, 203 insertions, 75 deletions
diff --git a/Lib/concurrent/futures/__init__.py b/Lib/concurrent/futures/__init__.py index b5231f8..ba8de16 100644 --- a/Lib/concurrent/futures/__init__.py +++ b/Lib/concurrent/futures/__init__.py @@ -10,6 +10,7 @@ from concurrent.futures._base import (FIRST_COMPLETED, ALL_COMPLETED, CancelledError, TimeoutError, + BrokenExecutor, Future, Executor, wait, diff --git a/Lib/concurrent/futures/_base.py b/Lib/concurrent/futures/_base.py index 6bace6c..4f22f7e 100644 --- a/Lib/concurrent/futures/_base.py +++ b/Lib/concurrent/futures/_base.py @@ -610,3 +610,9 @@ class Executor(object): def __exit__(self, exc_type, exc_val, exc_tb): self.shutdown(wait=True) return False + + +class BrokenExecutor(RuntimeError): + """ + Raised when a executor has become non-functional after a severe failure. + """ diff --git a/Lib/concurrent/futures/process.py b/Lib/concurrent/futures/process.py index 67ebbf5..35af65d 100644 --- a/Lib/concurrent/futures/process.py +++ b/Lib/concurrent/futures/process.py @@ -131,6 +131,7 @@ class _CallItem(object): self.args = args self.kwargs = kwargs + def _get_chunks(*iterables, chunksize): """ Iterates over zip()ed iterables in chunks. """ it = zip(*iterables) @@ -151,7 +152,7 @@ def _process_chunk(fn, chunk): """ return [fn(*args) for args in chunk] -def _process_worker(call_queue, result_queue): +def _process_worker(call_queue, result_queue, initializer, initargs): """Evaluates calls from call_queue and places the results in result_queue. This worker is run in a separate process. @@ -161,7 +162,17 @@ def _process_worker(call_queue, result_queue): evaluated by the worker. result_queue: A ctx.Queue of _ResultItems that will written to by the worker. + initializer: A callable initializer, or None + initargs: A tuple of args for the initializer """ + if initializer is not None: + try: + initializer(*initargs) + except BaseException: + _base.LOGGER.critical('Exception in initializer:', exc_info=True) + # The parent will notice that the process stopped and + # mark the pool broken + return while True: call_item = call_queue.get(block=True) if call_item is None: @@ -277,7 +288,9 @@ def _queue_management_worker(executor_reference, # Mark the process pool broken so that submits fail right now. executor = executor_reference() if executor is not None: - executor._broken = True + executor._broken = ('A child process terminated ' + 'abruptly, the process pool is not ' + 'usable anymore') executor._shutdown_thread = True executor = None # All futures in flight must be marked failed @@ -372,7 +385,7 @@ def _chain_from_iterable_of_lists(iterable): yield element.pop() -class BrokenProcessPool(RuntimeError): +class BrokenProcessPool(_base.BrokenExecutor): """ Raised when a process in a ProcessPoolExecutor terminated abruptly while a future was in the running state. @@ -380,7 +393,8 @@ class BrokenProcessPool(RuntimeError): class ProcessPoolExecutor(_base.Executor): - def __init__(self, max_workers=None, mp_context=None): + def __init__(self, max_workers=None, mp_context=None, + initializer=None, initargs=()): """Initializes a new ProcessPoolExecutor instance. Args: @@ -389,6 +403,8 @@ class ProcessPoolExecutor(_base.Executor): worker processes will be created as the machine has processors. mp_context: A multiprocessing context to launch the workers. This object should provide SimpleQueue, Queue and Process. + initializer: An callable used to initialize worker processes. + initargs: A tuple of arguments to pass to the initializer. """ _check_system_limits() @@ -403,6 +419,11 @@ class ProcessPoolExecutor(_base.Executor): mp_context = mp.get_context() self._mp_context = mp_context + if initializer is not None and not callable(initializer): + raise TypeError("initializer must be a callable") + self._initializer = initializer + self._initargs = initargs + # Make the call queue slightly larger than the number of processes to # prevent the worker processes from idling. But don't make it too big # because futures in the call queue cannot be cancelled. @@ -450,15 +471,16 @@ class ProcessPoolExecutor(_base.Executor): p = self._mp_context.Process( target=_process_worker, args=(self._call_queue, - self._result_queue)) + self._result_queue, + self._initializer, + self._initargs)) p.start() self._processes[p.pid] = p def submit(self, fn, *args, **kwargs): with self._shutdown_lock: if self._broken: - raise BrokenProcessPool('A child process terminated ' - 'abruptly, the process pool is not usable anymore') + raise BrokenProcessPool(self._broken) if self._shutdown_thread: raise RuntimeError('cannot schedule new futures after shutdown') diff --git a/Lib/concurrent/futures/thread.py b/Lib/concurrent/futures/thread.py index 0b5d537..2e7100b 100644 --- a/Lib/concurrent/futures/thread.py +++ b/Lib/concurrent/futures/thread.py @@ -41,6 +41,7 @@ def _python_exit(): atexit.register(_python_exit) + class _WorkItem(object): def __init__(self, future, fn, args, kwargs): self.future = future @@ -61,7 +62,17 @@ class _WorkItem(object): else: self.future.set_result(result) -def _worker(executor_reference, work_queue): + +def _worker(executor_reference, work_queue, initializer, initargs): + if initializer is not None: + try: + initializer(*initargs) + except BaseException: + _base.LOGGER.critical('Exception in initializer:', exc_info=True) + executor = executor_reference() + if executor is not None: + executor._initializer_failed() + return try: while True: work_item = work_queue.get(block=True) @@ -83,18 +94,28 @@ def _worker(executor_reference, work_queue): except BaseException: _base.LOGGER.critical('Exception in worker', exc_info=True) + +class BrokenThreadPool(_base.BrokenExecutor): + """ + Raised when a worker thread in a ThreadPoolExecutor failed initializing. + """ + + class ThreadPoolExecutor(_base.Executor): # Used to assign unique thread names when thread_name_prefix is not supplied. _counter = itertools.count().__next__ - def __init__(self, max_workers=None, thread_name_prefix=''): + def __init__(self, max_workers=None, thread_name_prefix='', + initializer=None, initargs=()): """Initializes a new ThreadPoolExecutor instance. Args: max_workers: The maximum number of threads that can be used to execute the given calls. thread_name_prefix: An optional name prefix to give our threads. + initializer: An callable used to initialize worker threads. + initargs: A tuple of arguments to pass to the initializer. """ if max_workers is None: # Use this number because ThreadPoolExecutor is often @@ -103,16 +124,25 @@ class ThreadPoolExecutor(_base.Executor): if max_workers <= 0: raise ValueError("max_workers must be greater than 0") + if initializer is not None and not callable(initializer): + raise TypeError("initializer must be a callable") + self._max_workers = max_workers self._work_queue = queue.Queue() self._threads = set() + self._broken = False self._shutdown = False self._shutdown_lock = threading.Lock() self._thread_name_prefix = (thread_name_prefix or ("ThreadPoolExecutor-%d" % self._counter())) + self._initializer = initializer + self._initargs = initargs def submit(self, fn, *args, **kwargs): with self._shutdown_lock: + if self._broken: + raise BrokenThreadPool(self._broken) + if self._shutdown: raise RuntimeError('cannot schedule new futures after shutdown') @@ -137,12 +167,27 @@ class ThreadPoolExecutor(_base.Executor): num_threads) t = threading.Thread(name=thread_name, target=_worker, args=(weakref.ref(self, weakref_cb), - self._work_queue)) + self._work_queue, + self._initializer, + self._initargs)) t.daemon = True t.start() self._threads.add(t) _threads_queues[t] = self._work_queue + def _initializer_failed(self): + with self._shutdown_lock: + self._broken = ('A thread initializer failed, the thread pool ' + 'is not usable anymore') + # Drain work queue and mark pending futures failed + while True: + try: + work_item = self._work_queue.get_nowait() + except queue.Empty: + break + if work_item is not None: + work_item.future.set_exception(BrokenThreadPool(self._broken)) + def shutdown(self, wait=True): with self._shutdown_lock: self._shutdown = True diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py index ed8ad41..296398f 100644 --- a/Lib/test/test_concurrent_futures.py +++ b/Lib/test/test_concurrent_futures.py @@ -7,6 +7,7 @@ test.support.import_module('multiprocessing.synchronize') from test.support.script_helper import assert_python_ok +import contextlib import itertools import os import sys @@ -17,7 +18,8 @@ import weakref from concurrent import futures from concurrent.futures._base import ( - PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future) + PENDING, RUNNING, CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, Future, + BrokenExecutor) from concurrent.futures.process import BrokenProcessPool from multiprocessing import get_context @@ -37,11 +39,12 @@ CANCELLED_AND_NOTIFIED_FUTURE = create_future(state=CANCELLED_AND_NOTIFIED) EXCEPTION_FUTURE = create_future(state=FINISHED, exception=OSError()) SUCCESSFUL_FUTURE = create_future(state=FINISHED, result=42) +INITIALIZER_STATUS = 'uninitialized' + def mul(x, y): return x * y - def sleep_and_raise(t): time.sleep(t) raise Exception('this is an exception') @@ -51,6 +54,17 @@ def sleep_and_print(t, msg): print(msg) sys.stdout.flush() +def init(x): + global INITIALIZER_STATUS + INITIALIZER_STATUS = x + +def get_init_status(): + return INITIALIZER_STATUS + +def init_fail(): + time.sleep(0.1) # let some futures be scheduled + raise ValueError('error in initializer') + class MyObject(object): def my_method(self): @@ -81,6 +95,7 @@ class BaseTestCase(unittest.TestCase): class ExecutorMixin: worker_count = 5 + executor_kwargs = {} def setUp(self): super().setUp() @@ -90,10 +105,12 @@ class ExecutorMixin: if hasattr(self, "ctx"): self.executor = self.executor_type( max_workers=self.worker_count, - mp_context=get_context(self.ctx)) + mp_context=get_context(self.ctx), + **self.executor_kwargs) else: self.executor = self.executor_type( - max_workers=self.worker_count) + max_workers=self.worker_count, + **self.executor_kwargs) except NotImplementedError as e: self.skipTest(str(e)) self._prime_executor() @@ -114,7 +131,6 @@ class ExecutorMixin: # tests. This should reduce the probability of timeouts in the tests. futures = [self.executor.submit(time.sleep, 0.1) for _ in range(self.worker_count)] - for f in futures: f.result() @@ -148,6 +164,90 @@ class ProcessPoolForkserverMixin(ExecutorMixin): super().setUp() +def create_executor_tests(mixin, bases=(BaseTestCase,), + executor_mixins=(ThreadPoolMixin, + ProcessPoolForkMixin, + ProcessPoolForkserverMixin, + ProcessPoolSpawnMixin)): + def strip_mixin(name): + if name.endswith(('Mixin', 'Tests')): + return name[:-5] + elif name.endswith('Test'): + return name[:-4] + else: + return name + + for exe in executor_mixins: + name = ("%s%sTest" + % (strip_mixin(exe.__name__), strip_mixin(mixin.__name__))) + cls = type(name, (mixin,) + (exe,) + bases, {}) + globals()[name] = cls + + +class InitializerMixin(ExecutorMixin): + worker_count = 2 + + def setUp(self): + global INITIALIZER_STATUS + INITIALIZER_STATUS = 'uninitialized' + self.executor_kwargs = dict(initializer=init, + initargs=('initialized',)) + super().setUp() + + def test_initializer(self): + futures = [self.executor.submit(get_init_status) + for _ in range(self.worker_count)] + + for f in futures: + self.assertEqual(f.result(), 'initialized') + + +class FailingInitializerMixin(ExecutorMixin): + worker_count = 2 + + def setUp(self): + self.executor_kwargs = dict(initializer=init_fail) + super().setUp() + + def test_initializer(self): + with self._assert_logged('ValueError: error in initializer'): + try: + future = self.executor.submit(get_init_status) + except BrokenExecutor: + # Perhaps the executor is already broken + pass + else: + with self.assertRaises(BrokenExecutor): + future.result() + # At some point, the executor should break + t1 = time.time() + while not self.executor._broken: + if time.time() - t1 > 5: + self.fail("executor not broken after 5 s.") + time.sleep(0.01) + # ... and from this point submit() is guaranteed to fail + with self.assertRaises(BrokenExecutor): + self.executor.submit(get_init_status) + + def _prime_executor(self): + pass + + @contextlib.contextmanager + def _assert_logged(self, msg): + if self.executor_type is futures.ProcessPoolExecutor: + # No easy way to catch the child processes' stderr + yield + else: + with self.assertLogs('concurrent.futures', 'CRITICAL') as cm: + yield + self.assertTrue(any(msg in line for line in cm.output), + cm.output) + + +create_executor_tests(InitializerMixin) +create_executor_tests(FailingInitializerMixin) + + class ExecutorShutdownTest: def test_run_after_shutdown(self): self.executor.shutdown() @@ -278,20 +378,11 @@ class ProcessPoolShutdownTest(ExecutorShutdownTest): call_queue.join_thread() -class ProcessPoolForkShutdownTest(ProcessPoolForkMixin, BaseTestCase, - ProcessPoolShutdownTest): - pass - - -class ProcessPoolForkserverShutdownTest(ProcessPoolForkserverMixin, - BaseTestCase, - ProcessPoolShutdownTest): - pass - -class ProcessPoolSpawnShutdownTest(ProcessPoolSpawnMixin, BaseTestCase, - ProcessPoolShutdownTest): - pass +create_executor_tests(ProcessPoolShutdownTest, + executor_mixins=(ProcessPoolForkMixin, + ProcessPoolForkserverMixin, + ProcessPoolSpawnMixin)) class WaitTests: @@ -413,18 +504,10 @@ class ThreadPoolWaitTests(ThreadPoolMixin, WaitTests, BaseTestCase): sys.setswitchinterval(oldswitchinterval) -class ProcessPoolForkWaitTests(ProcessPoolForkMixin, WaitTests, BaseTestCase): - pass - - -class ProcessPoolForkserverWaitTests(ProcessPoolForkserverMixin, WaitTests, - BaseTestCase): - pass - - -class ProcessPoolSpawnWaitTests(ProcessPoolSpawnMixin, BaseTestCase, - WaitTests): - pass +create_executor_tests(WaitTests, + executor_mixins=(ProcessPoolForkMixin, + ProcessPoolForkserverMixin, + ProcessPoolSpawnMixin)) class AsCompletedTests: @@ -507,24 +590,7 @@ class AsCompletedTests: self.assertEqual(str(cm.exception), '2 (of 4) futures unfinished') -class ThreadPoolAsCompletedTests(ThreadPoolMixin, AsCompletedTests, BaseTestCase): - pass - - -class ProcessPoolForkAsCompletedTests(ProcessPoolForkMixin, AsCompletedTests, - BaseTestCase): - pass - - -class ProcessPoolForkserverAsCompletedTests(ProcessPoolForkserverMixin, - AsCompletedTests, - BaseTestCase): - pass - - -class ProcessPoolSpawnAsCompletedTests(ProcessPoolSpawnMixin, AsCompletedTests, - BaseTestCase): - pass +create_executor_tests(AsCompletedTests) class ExecutorTest: @@ -688,23 +754,10 @@ class ProcessPoolExecutorTest(ExecutorTest): self.assertTrue(obj.event.wait(timeout=1)) -class ProcessPoolForkExecutorTest(ProcessPoolForkMixin, - ProcessPoolExecutorTest, - BaseTestCase): - pass - - -class ProcessPoolForkserverExecutorTest(ProcessPoolForkserverMixin, - ProcessPoolExecutorTest, - BaseTestCase): - pass - - -class ProcessPoolSpawnExecutorTest(ProcessPoolSpawnMixin, - ProcessPoolExecutorTest, - BaseTestCase): - pass - +create_executor_tests(ProcessPoolExecutorTest, + executor_mixins=(ProcessPoolForkMixin, + ProcessPoolForkserverMixin, + ProcessPoolSpawnMixin)) class FutureTests(BaseTestCase): @@ -932,6 +985,7 @@ class FutureTests(BaseTestCase): self.assertTrue(isinstance(f1.exception(timeout=5), OSError)) t.join() + @test.support.reap_threads def test_main(): try: |