summaryrefslogtreecommitdiffstats
path: root/Lib/multiprocessing/pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/multiprocessing/pool.py')
-rw-r--r--Lib/multiprocessing/pool.py112
1 files changed, 74 insertions, 38 deletions
diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index bfb2769..18a56f8 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -151,8 +151,9 @@ class Pool(object):
'''
_wrap_exception = True
- def Process(self, *args, **kwds):
- return self._ctx.Process(*args, **kwds)
+ @staticmethod
+ def Process(ctx, *args, **kwds):
+ return ctx.Process(*args, **kwds)
def __init__(self, processes=None, initializer=None, initargs=(),
maxtasksperchild=None, context=None):
@@ -190,7 +191,10 @@ class Pool(object):
self._worker_handler = threading.Thread(
target=Pool._handle_workers,
- args=(self, )
+ args=(self._cache, self._taskqueue, self._ctx, self.Process,
+ self._processes, self._pool, self._inqueue, self._outqueue,
+ self._initializer, self._initargs, self._maxtasksperchild,
+ self._wrap_exception)
)
self._worker_handler.daemon = True
self._worker_handler._state = RUN
@@ -236,43 +240,61 @@ class Pool(object):
f'state={self._state} '
f'pool_size={len(self._pool)}>')
- def _join_exited_workers(self):
+ @staticmethod
+ def _join_exited_workers(pool):
"""Cleanup after any worker processes which have exited due to reaching
their specified lifetime. Returns True if any workers were cleaned up.
"""
cleaned = False
- for i in reversed(range(len(self._pool))):
- worker = self._pool[i]
+ for i in reversed(range(len(pool))):
+ worker = pool[i]
if worker.exitcode is not None:
# worker exited
util.debug('cleaning up worker %d' % i)
worker.join()
cleaned = True
- del self._pool[i]
+ del pool[i]
return cleaned
def _repopulate_pool(self):
+ return self._repopulate_pool_static(self._ctx, self.Process,
+ self._processes,
+ self._pool, self._inqueue,
+ self._outqueue, self._initializer,
+ self._initargs,
+ self._maxtasksperchild,
+ self._wrap_exception)
+
+ @staticmethod
+ def _repopulate_pool_static(ctx, Process, processes, pool, inqueue,
+ outqueue, initializer, initargs,
+ maxtasksperchild, wrap_exception):
"""Bring the number of pool processes up to the specified number,
for use after reaping workers which have exited.
"""
- for i in range(self._processes - len(self._pool)):
- w = self.Process(target=worker,
- args=(self._inqueue, self._outqueue,
- self._initializer,
- self._initargs, self._maxtasksperchild,
- self._wrap_exception)
- )
+ for i in range(processes - len(pool)):
+ w = Process(ctx, target=worker,
+ args=(inqueue, outqueue,
+ initializer,
+ initargs, maxtasksperchild,
+ wrap_exception))
w.name = w.name.replace('Process', 'PoolWorker')
w.daemon = True
w.start()
- self._pool.append(w)
+ pool.append(w)
util.debug('added worker')
- def _maintain_pool(self):
+ @staticmethod
+ def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue,
+ initializer, initargs, maxtasksperchild,
+ wrap_exception):
"""Clean up any exited workers and start replacements for them.
"""
- if self._join_exited_workers():
- self._repopulate_pool()
+ if Pool._join_exited_workers(pool):
+ Pool._repopulate_pool_static(ctx, Process, processes, pool,
+ inqueue, outqueue, initializer,
+ initargs, maxtasksperchild,
+ wrap_exception)
def _setup_queues(self):
self._inqueue = self._ctx.SimpleQueue()
@@ -331,7 +353,7 @@ class Pool(object):
'''
self._check_running()
if chunksize == 1:
- result = IMapIterator(self._cache)
+ result = IMapIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
@@ -344,7 +366,7 @@ class Pool(object):
"Chunksize must be 1+, not {0:n}".format(
chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = IMapIterator(self._cache)
+ result = IMapIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
@@ -360,7 +382,7 @@ class Pool(object):
'''
self._check_running()
if chunksize == 1:
- result = IMapUnorderedIterator(self._cache)
+ result = IMapUnorderedIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
@@ -372,7 +394,7 @@ class Pool(object):
raise ValueError(
"Chunksize must be 1+, not {0!r}".format(chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = IMapUnorderedIterator(self._cache)
+ result = IMapUnorderedIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
@@ -388,7 +410,7 @@ class Pool(object):
Asynchronous version of `apply()` method.
'''
self._check_running()
- result = ApplyResult(self._cache, callback, error_callback)
+ result = ApplyResult(self, callback, error_callback)
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
return result
@@ -417,7 +439,7 @@ class Pool(object):
chunksize = 0
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = MapResult(self._cache, chunksize, len(iterable), callback,
+ result = MapResult(self, chunksize, len(iterable), callback,
error_callback=error_callback)
self._taskqueue.put(
(
@@ -430,16 +452,20 @@ class Pool(object):
return result
@staticmethod
- def _handle_workers(pool):
+ def _handle_workers(cache, taskqueue, ctx, Process, processes, pool,
+ inqueue, outqueue, initializer, initargs,
+ maxtasksperchild, wrap_exception):
thread = threading.current_thread()
# Keep maintaining workers until the cache gets drained, unless the pool
# is terminated.
- while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
- pool._maintain_pool()
+ while thread._state == RUN or (cache and thread._state != TERMINATE):
+ Pool._maintain_pool(ctx, Process, processes, pool, inqueue,
+ outqueue, initializer, initargs,
+ maxtasksperchild, wrap_exception)
time.sleep(0.1)
# send sentinel to stop workers
- pool._taskqueue.put(None)
+ taskqueue.put(None)
util.debug('worker handler exiting')
@staticmethod
@@ -656,13 +682,14 @@ class Pool(object):
class ApplyResult(object):
- def __init__(self, cache, callback, error_callback):
+ def __init__(self, pool, callback, error_callback):
+ self._pool = pool
self._event = threading.Event()
self._job = next(job_counter)
- self._cache = cache
+ self._cache = pool._cache
self._callback = callback
self._error_callback = error_callback
- cache[self._job] = self
+ self._cache[self._job] = self
def ready(self):
return self._event.is_set()
@@ -692,6 +719,7 @@ class ApplyResult(object):
self._error_callback(self._value)
self._event.set()
del self._cache[self._job]
+ self._pool = None
AsyncResult = ApplyResult # create alias -- see #17805
@@ -701,8 +729,8 @@ AsyncResult = ApplyResult # create alias -- see #17805
class MapResult(ApplyResult):
- def __init__(self, cache, chunksize, length, callback, error_callback):
- ApplyResult.__init__(self, cache, callback,
+ def __init__(self, pool, chunksize, length, callback, error_callback):
+ ApplyResult.__init__(self, pool, callback,
error_callback=error_callback)
self._success = True
self._value = [None] * length
@@ -710,7 +738,7 @@ class MapResult(ApplyResult):
if chunksize <= 0:
self._number_left = 0
self._event.set()
- del cache[self._job]
+ del self._cache[self._job]
else:
self._number_left = length//chunksize + bool(length % chunksize)
@@ -724,6 +752,7 @@ class MapResult(ApplyResult):
self._callback(self._value)
del self._cache[self._job]
self._event.set()
+ self._pool = None
else:
if not success and self._success:
# only store first exception
@@ -735,6 +764,7 @@ class MapResult(ApplyResult):
self._error_callback(self._value)
del self._cache[self._job]
self._event.set()
+ self._pool = None
#
# Class whose instances are returned by `Pool.imap()`
@@ -742,15 +772,16 @@ class MapResult(ApplyResult):
class IMapIterator(object):
- def __init__(self, cache):
+ def __init__(self, pool):
+ self._pool = pool
self._cond = threading.Condition(threading.Lock())
self._job = next(job_counter)
- self._cache = cache
+ self._cache = pool._cache
self._items = collections.deque()
self._index = 0
self._length = None
self._unsorted = {}
- cache[self._job] = self
+ self._cache[self._job] = self
def __iter__(self):
return self
@@ -761,12 +792,14 @@ class IMapIterator(object):
item = self._items.popleft()
except IndexError:
if self._index == self._length:
+ self._pool = None
raise StopIteration from None
self._cond.wait(timeout)
try:
item = self._items.popleft()
except IndexError:
if self._index == self._length:
+ self._pool = None
raise StopIteration from None
raise TimeoutError from None
@@ -792,6 +825,7 @@ class IMapIterator(object):
if self._index == self._length:
del self._cache[self._job]
+ self._pool = None
def _set_length(self, length):
with self._cond:
@@ -799,6 +833,7 @@ class IMapIterator(object):
if self._index == self._length:
self._cond.notify()
del self._cache[self._job]
+ self._pool = None
#
# Class whose instances are returned by `Pool.imap_unordered()`
@@ -813,6 +848,7 @@ class IMapUnorderedIterator(IMapIterator):
self._cond.notify()
if self._index == self._length:
del self._cache[self._job]
+ self._pool = None
#
#
@@ -822,7 +858,7 @@ class ThreadPool(Pool):
_wrap_exception = False
@staticmethod
- def Process(*args, **kwds):
+ def Process(ctx, *args, **kwds):
from .dummy import Process
return Process(*args, **kwds)