Diffstat (limited to 'Lib/multiprocessing')
-rw-r--r--  Lib/multiprocessing/__init__.py    |   6
-rw-r--r--  Lib/multiprocessing/connection.py  |  29
-rw-r--r--  Lib/multiprocessing/forking.py     |  10
-rw-r--r--  Lib/multiprocessing/managers.py    |   8
-rw-r--r--  Lib/multiprocessing/pool.py        | 161
-rw-r--r--  Lib/multiprocessing/process.py     |  15
-rw-r--r--  Lib/multiprocessing/queues.py      |   6
-rw-r--r--  Lib/multiprocessing/synchronize.py |   3
8 files changed, 182 insertions(+), 56 deletions(-)
diff --git a/Lib/multiprocessing/__init__.py b/Lib/multiprocessing/__init__.py
index fdd012e..e6e16c8 100644
--- a/Lib/multiprocessing/__init__.py
+++ b/Lib/multiprocessing/__init__.py
@@ -9,7 +9,7 @@
# wrapper for 'threading'.
#
# Try calling `multiprocessing.doc.main()` to read the html
-# documentation in in a webbrowser.
+# documentation in a webbrowser.
#
#
# Copyright (c) 2006-2008, R Oudkerk
@@ -223,12 +223,12 @@ def JoinableQueue(maxsize=0):
from multiprocessing.queues import JoinableQueue
return JoinableQueue(maxsize)
-def Pool(processes=None, initializer=None, initargs=()):
+def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None):
'''
Returns a process pool object
'''
from multiprocessing.pool import Pool
- return Pool(processes, initializer, initargs)
+ return Pool(processes, initializer, initargs, maxtasksperchild)
def RawValue(typecode_or_type, *args):
'''
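The new maxtasksperchild argument caps how many tasks a worker process completes before it exits and is replaced by a fresh worker, which bounds the damage from task code that leaks resources. A minimal usage sketch (the function and numbers are illustrative, not part of the patch):

    import multiprocessing

    def square(x):
        return x * x

    if __name__ == '__main__':
        # each worker process is retired and replaced after 100 tasks
        pool = multiprocessing.Pool(processes=4, maxtasksperchild=100)
        print(pool.map(square, range(10)))
        pool.close()
        pool.join()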
diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py
index 846d396..d6c23fb 100644
--- a/Lib/multiprocessing/connection.py
+++ b/Lib/multiprocessing/connection.py
@@ -281,25 +281,24 @@ def SocketClient(address):
Return a connection object connected to the socket given by `address`
'''
family = address_type(address)
- s = socket.socket( getattr(socket, family) )
- t = _init_timeout()
+ with socket.socket( getattr(socket, family) ) as s:
+ t = _init_timeout()
- while 1:
- try:
- s.connect(address)
- except socket.error as e:
- if e.args[0] != errno.ECONNREFUSED or _check_timeout(t):
- debug('failed to connect to address %s', address)
- raise
- time.sleep(0.01)
+ while 1:
+ try:
+ s.connect(address)
+ except socket.error as e:
+ if e.args[0] != errno.ECONNREFUSED or _check_timeout(t):
+ debug('failed to connect to address %s', address)
+ raise
+ time.sleep(0.01)
+ else:
+ break
else:
- break
- else:
- raise
+ raise
- fd = duplicate(s.fileno())
+ fd = duplicate(s.fileno())
conn = _multiprocessing.Connection(fd)
- s.close()
return conn
#
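The loop above keeps reconnecting while the listener is not yet accepting, and gives up once the timeout expires; the patch additionally scopes the socket in a with block so it is closed even when connect() ultimately raises. A standalone sketch of the same retry pattern, using a plain deadline in place of the module's private _init_timeout()/_check_timeout() helpers (the function name and default timeout are illustrative):

    import errno
    import socket
    import time

    def connect_with_retry(address, timeout=20.0, family=socket.AF_INET):
        deadline = time.time() + timeout
        s = socket.socket(family)
        while True:
            try:
                s.connect(address)
                return s                  # caller is responsible for closing
            except socket.error as e:
                # retry only while the server side is not listening yet
                if e.args[0] != errno.ECONNREFUSED or time.time() > deadline:
                    s.close()
                    raise
                time.sleep(0.01)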
diff --git a/Lib/multiprocessing/forking.py b/Lib/multiprocessing/forking.py
index 0300a4b..cc7c326 100644
--- a/Lib/multiprocessing/forking.py
+++ b/Lib/multiprocessing/forking.py
@@ -460,12 +460,20 @@ def prepare(data):
process.ORIGINAL_DIR = data['orig_dir']
if 'main_path' in data:
+ # XXX (ncoghlan): The following code makes several bogus
+ # assumptions regarding the relationship between __file__
+ # and a module's real name. See PEP 302 and issue #10845
main_path = data['main_path']
main_name = os.path.splitext(os.path.basename(main_path))[0]
if main_name == '__init__':
main_name = os.path.basename(os.path.dirname(main_path))
- if main_name != 'ipython':
+ if main_name == '__main__':
+ main_module = sys.modules['__main__']
+ main_module.__file__ = main_path
+ elif main_name != 'ipython':
+ # Main modules not actually called __main__.py may
+ # contain additional code that should still be executed
import imp
if main_path is None:
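The prepare() logic above re-imports the parent's main module in the child on platforms without fork(), which is why scripts that spawn processes must guard their entry point. A minimal illustration (script name is hypothetical):

    # script.py
    from multiprocessing import Process

    def work():
        print('child running')

    if __name__ == '__main__':
        # without this guard, the child's re-import of script.py would
        # try to spawn processes recursively on Windows
        p = Process(target=work)
        p.start()
        p.join()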
diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py
index b77c693..5588ead 100644
--- a/Lib/multiprocessing/managers.py
+++ b/Lib/multiprocessing/managers.py
@@ -58,7 +58,7 @@ from multiprocessing.util import Finalize, info
#
def reduce_array(a):
- return array.array, (a.typecode, a.tostring())
+ return array.array, (a.typecode, a.tobytes())
ForkingPickler.register(array.array, reduce_array)
view_types = [type(getattr({}, name)()) for name in ('items','keys','values')]
@@ -134,7 +134,7 @@ def all_methods(obj):
temp = []
for name in dir(obj):
func = getattr(obj, name)
- if hasattr(func, '__call__'):
+ if callable(func):
temp.append(name)
return temp
@@ -162,7 +162,7 @@ class Server(object):
Listener, Client = listener_client[serializer]
# do authentication later
- self.listener = Listener(address=address, backlog=5)
+ self.listener = Listener(address=address, backlog=16)
self.address = self.listener.address
self.id_to_obj = {'0': (None, ())}
@@ -510,7 +510,7 @@ class BaseManager(object):
'''
assert self._state.value == State.INITIAL
- if initializer is not None and not hasattr(initializer, '__call__'):
+ if initializer is not None and not callable(initializer):
raise TypeError('initializer must be a callable')
# pipe over which we will retrieve address of server
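reduce_array() tells the pickler to rebuild an array.array from its typecode plus its raw bytes; tobytes() is the Python 3.2 spelling of what tostring() did. The reduction round-trips like this (a quick sanity check, not part of the patch):

    import array

    a = array.array('i', [1, 2, 3])
    func, args = array.array, (a.typecode, a.tobytes())
    b = func(*args)            # reconstruct from typecode + raw bytes
    assert a == b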
diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index fc03a0a..0c29e64 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -68,7 +68,25 @@ def mapstar(args):
# Code run by worker processes
#
-def worker(inqueue, outqueue, initializer=None, initargs=()):
+class MaybeEncodingError(Exception):
+ """Wraps possible unpickleable errors, so they can be
+ safely sent through the socket."""
+
+ def __init__(self, exc, value):
+ self.exc = repr(exc)
+ self.value = repr(value)
+ super(MaybeEncodingError, self).__init__(self.exc, self.value)
+
+ def __str__(self):
+ return "Error sending result: '%s'. Reason: '%s'" % (self.value,
+ self.exc)
+
+ def __repr__(self):
+ return "<MaybeEncodingError: %s>" % str(self)
+
+
+def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
+ assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
put = outqueue.put
get = inqueue.get
if hasattr(inqueue, '_writer'):
@@ -78,7 +96,8 @@ def worker(inqueue, outqueue, initializer=None, initargs=()):
if initializer is not None:
initializer(*initargs)
- while 1:
+ completed = 0
+ while maxtasks is None or (maxtasks and completed < maxtasks):
try:
task = get()
except (EOFError, IOError):
@@ -94,7 +113,15 @@ def worker(inqueue, outqueue, initializer=None, initargs=()):
result = (True, func(*args, **kwds))
except Exception as e:
result = (False, e)
- put((job, i, result))
+ try:
+ put((job, i, result))
+ except Exception as e:
+ wrapped = MaybeEncodingError(e, result[1])
+ debug("Possible encoding error while sending result: %s" % (
+ wrapped))
+ put((job, i, (False, wrapped)))
+ completed += 1
+ debug('worker exiting after %d tasks' % completed)
#
# Class representing a process pool
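MaybeEncodingError exists because a task's result can fail to pickle even though the task itself succeeded; without the wrapper, the worker's put() would raise and the caller would block forever waiting for a result that never arrives. A sketch of how it surfaces with this patched Pool (the lambda is deliberately unpicklable):

    from multiprocessing import Pool

    def make_closure(x):
        return lambda: x          # lambdas cannot be pickled

    if __name__ == '__main__':
        pool = Pool(processes=1)
        res = pool.apply_async(make_closure, (1,))
        try:
            res.get(timeout=10)
        except Exception as e:    # raises the MaybeEncodingError wrapper
            print(type(e).__name__, e)
        pool.close()
        pool.join()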
@@ -106,31 +133,39 @@ class Pool(object):
'''
Process = Process
- def __init__(self, processes=None, initializer=None, initargs=()):
+ def __init__(self, processes=None, initializer=None, initargs=(),
+ maxtasksperchild=None):
self._setup_queues()
self._taskqueue = queue.Queue()
self._cache = {}
self._state = RUN
+ self._maxtasksperchild = maxtasksperchild
+ self._initializer = initializer
+ self._initargs = initargs
if processes is None:
try:
processes = cpu_count()
except NotImplementedError:
processes = 1
+ if processes < 1:
+ raise ValueError("Number of processes must be at least 1")
- if initializer is not None and not hasattr(initializer, '__call__'):
+ if initializer is not None and not callable(initializer):
raise TypeError('initializer must be a callable')
+ self._processes = processes
self._pool = []
- for i in range(processes):
- w = self.Process(
- target=worker,
- args=(self._inqueue, self._outqueue, initializer, initargs)
- )
- self._pool.append(w)
- w.name = w.name.replace('Process', 'PoolWorker')
- w.daemon = True
- w.start()
+ self._repopulate_pool()
+
+ self._worker_handler = threading.Thread(
+ target=Pool._handle_workers,
+ args=(self, )
+ )
+ self._worker_handler.daemon = True
+ self._worker_handler._state = RUN
+ self._worker_handler.start()
+
self._task_handler = threading.Thread(
target=Pool._handle_tasks,
@@ -151,10 +186,48 @@ class Pool(object):
self._terminate = Finalize(
self, self._terminate_pool,
args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
- self._task_handler, self._result_handler, self._cache),
+ self._worker_handler, self._task_handler,
+ self._result_handler, self._cache),
exitpriority=15
)
+ def _join_exited_workers(self):
+ """Cleanup after any worker processes which have exited due to reaching
+ their specified lifetime. Returns True if any workers were cleaned up.
+ """
+ cleaned = False
+ for i in reversed(range(len(self._pool))):
+ worker = self._pool[i]
+ if worker.exitcode is not None:
+ # worker exited
+ debug('cleaning up worker %d' % i)
+ worker.join()
+ cleaned = True
+ del self._pool[i]
+ return cleaned
+
+ def _repopulate_pool(self):
+ """Bring the number of pool processes up to the specified number,
+ for use after reaping workers which have exited.
+ """
+ for i in range(self._processes - len(self._pool)):
+ w = self.Process(target=worker,
+ args=(self._inqueue, self._outqueue,
+ self._initializer,
+ self._initargs, self._maxtasksperchild)
+ )
+ self._pool.append(w)
+ w.name = w.name.replace('Process', 'PoolWorker')
+ w.daemon = True
+ w.start()
+ debug('added worker')
+
+ def _maintain_pool(self):
+ """Clean up any exited workers and start replacements for them.
+ """
+ if self._join_exited_workers():
+ self._repopulate_pool()
+
def _setup_queues(self):
from .queues import SimpleQueue
self._inqueue = SimpleQueue()
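Together, _join_exited_workers() and _repopulate_pool() implement the recycling: the handler thread reaps any worker whose exitcode is set and starts a replacement. The effect is observable from the worker PIDs; a timing-dependent sketch (chunksize=1 keeps one task per worker turn):

    import os
    from multiprocessing import Pool

    def pid(_):
        return os.getpid()

    if __name__ == '__main__':
        pool = Pool(processes=1, maxtasksperchild=2)
        pids = pool.map(pid, range(6), chunksize=1)
        print(len(set(pids)))     # expected: 3 distinct worker processes
        pool.close()
        pool.join()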
@@ -213,16 +286,18 @@ class Pool(object):
for i, x in enumerate(task_batches)), result._set_length))
return (item for chunk in result for item in chunk)
- def apply_async(self, func, args=(), kwds={}, callback=None):
+ def apply_async(self, func, args=(), kwds={}, callback=None,
+ error_callback=None):
'''
Asynchronous version of `apply()` method.
'''
assert self._state == RUN
- result = ApplyResult(self._cache, callback)
+ result = ApplyResult(self._cache, callback, error_callback)
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
return result
- def map_async(self, func, iterable, chunksize=None, callback=None):
+ def map_async(self, func, iterable, chunksize=None, callback=None,
+ error_callback=None):
'''
Asynchronous version of `map()` method.
'''
@@ -238,12 +313,26 @@ class Pool(object):
chunksize = 0
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = MapResult(self._cache, chunksize, len(iterable), callback)
+ result = MapResult(self._cache, chunksize, len(iterable), callback,
+ error_callback=error_callback)
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
for i, x in enumerate(task_batches)), None))
return result
@staticmethod
+ def _handle_workers(pool):
+ thread = threading.current_thread()
+
+ # Keep maintaining workers until the cache gets drained, unless the pool
+ # is terminated.
+ while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
+ pool._maintain_pool()
+ time.sleep(0.1)
+ # send sentinel to stop workers
+ pool._taskqueue.put(None)
+ debug('worker handler exiting')
+
+ @staticmethod
def _handle_tasks(taskqueue, put, outqueue, pool):
thread = threading.current_thread()
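error_callback is the counterpart to callback: it is invoked with the exception instance when a task fails, instead of the failure only surfacing when get() is called. A minimal usage sketch:

    from multiprocessing import Pool

    def fail(x):
        raise ValueError(x)

    def on_error(exc):
        # runs in the pool's result-handler thread, so keep it quick
        print('task failed:', exc)

    if __name__ == '__main__':
        pool = Pool(processes=2)
        pool.apply_async(fail, (42,), error_callback=on_error)
        pool.close()
        pool.join()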
@@ -358,16 +447,18 @@ class Pool(object):
debug('closing pool')
if self._state == RUN:
self._state = CLOSE
- self._taskqueue.put(None)
+ self._worker_handler._state = CLOSE
def terminate(self):
debug('terminating pool')
self._state = TERMINATE
+ self._worker_handler._state = TERMINATE
self._terminate()
def join(self):
debug('joining pool')
assert self._state in (CLOSE, TERMINATE)
+ self._worker_handler.join()
self._task_handler.join()
self._result_handler.join()
for p in self._pool:
@@ -384,12 +475,12 @@ class Pool(object):
@classmethod
def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
- task_handler, result_handler, cache):
+ worker_handler, task_handler, result_handler, cache):
# this is guaranteed to only be called once
debug('finalizing pool')
+ worker_handler._state = TERMINATE
task_handler._state = TERMINATE
- taskqueue.put(None) # sentinel
debug('helping task handler/workers to finish')
cls._help_stuff_finish(inqueue, task_handler, len(pool))
@@ -399,16 +490,23 @@ class Pool(object):
result_handler._state = TERMINATE
outqueue.put(None) # sentinel
+ # We must wait for the worker handler to exit before terminating
+ # workers because we don't want workers to be restarted behind our back.
+ debug('joining worker handler')
+ worker_handler.join()
+
+ # Terminate workers which haven't already finished.
if pool and hasattr(pool[0], 'terminate'):
debug('terminating workers')
for p in pool:
- p.terminate()
+ if p.exitcode is None:
+ p.terminate()
debug('joining task handler')
- task_handler.join(1e100)
+ task_handler.join()
debug('joining result handler')
- result_handler.join(1e100)
+ result_handler.join()
if pool and hasattr(pool[0], 'terminate'):
debug('joining pool workers')
@@ -424,12 +522,13 @@ class Pool(object):
class ApplyResult(object):
- def __init__(self, cache, callback):
+ def __init__(self, cache, callback, error_callback):
self._cond = threading.Condition(threading.Lock())
self._job = next(job_counter)
self._cache = cache
self._ready = False
self._callback = callback
+ self._error_callback = error_callback
cache[self._job] = self
def ready(self):
@@ -460,6 +559,8 @@ class ApplyResult(object):
self._success, self._value = obj
if self._callback and self._success:
self._callback(self._value)
+ if self._error_callback and not self._success:
+ self._error_callback(self._value)
self._cond.acquire()
try:
self._ready = True
@@ -474,8 +575,9 @@ class ApplyResult(object):
class MapResult(ApplyResult):
- def __init__(self, cache, chunksize, length, callback):
- ApplyResult.__init__(self, cache, callback)
+ def __init__(self, cache, chunksize, length, callback, error_callback):
+ ApplyResult.__init__(self, cache, callback,
+ error_callback=error_callback)
self._success = True
self._value = [None] * length
self._chunksize = chunksize
@@ -500,10 +602,11 @@ class MapResult(ApplyResult):
self._cond.notify()
finally:
self._cond.release()
-
else:
self._success = False
self._value = result
+ if self._error_callback:
+ self._error_callback(self._value)
del self._cache[self._job]
self._cond.acquire()
try:
diff --git a/Lib/multiprocessing/process.py b/Lib/multiprocessing/process.py
index b56a061..5987af9 100644
--- a/Lib/multiprocessing/process.py
+++ b/Lib/multiprocessing/process.py
@@ -42,6 +42,7 @@ import os
import sys
import signal
import itertools
+from _weakrefset import WeakSet
#
#
@@ -105,6 +106,7 @@ class Process(object):
self._kwargs = dict(kwargs)
self._name = name or type(self).__name__ + '-' + \
':'.join(str(i) for i in self._identity)
+ _dangling.add(self)
def run(self):
'''
@@ -251,9 +253,15 @@ class Process(object):
sys.stdin = open(os.devnull)
except (OSError, ValueError):
pass
+ old_process = _current_process
_current_process = self
- util._finalizer_registry.clear()
- util._run_after_forkers()
+ try:
+ util._finalizer_registry.clear()
+ util._run_after_forkers()
+ finally:
+ # delay finalization of the old process object until after
+ # _run_after_forkers() is executed
+ del old_process
util.info('child process calling self.run()')
try:
self.run()
@@ -322,3 +330,6 @@ _exitcode_to_name = {}
for name, signum in list(signal.__dict__.items()):
if name[:3]=='SIG' and '_' not in name:
_exitcode_to_name[-signum] = name
+
+# For debug and leak testing
+_dangling = WeakSet()
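Because _dangling is a WeakSet, it holds Process objects only as long as something else does, so test code can use it to detect leaked process objects. A sketch of that kind of check (relies on the private multiprocessing.process._dangling attribute, so it is for debugging only):

    import gc
    from multiprocessing import Process, process

    def noop():
        pass

    if __name__ == '__main__':
        p = Process(target=noop)
        p.start()
        p.join()
        del p
        gc.collect()
        # any Process objects still counted here are kept alive by
        # stray references somewhere
        print(len(process._dangling))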
diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py
index 3280a25..51d9912 100644
--- a/Lib/multiprocessing/queues.py
+++ b/Lib/multiprocessing/queues.py
@@ -126,7 +126,11 @@ class Queue(object):
if not self._rlock.acquire(block, timeout):
raise Empty
try:
- if not self._poll(block and (deadline-time.time()) or 0.0):
+ if block:
+ timeout = deadline - time.time()
+ if timeout < 0 or not self._poll(timeout):
+ raise Empty
+ elif not self._poll():
raise Empty
res = self._recv()
self._sem.release()
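The rewritten branch computes the remaining time once and treats an already-expired deadline as an immediate Empty; the old `block and (deadline-time.time()) or 0.0` expression could pass a negative remainder straight to _poll(). The observable behaviour:

    import queue                      # for the Empty exception
    from multiprocessing import Queue

    if __name__ == '__main__':
        q = Queue()
        try:
            q.get(timeout=0.5)        # blocks about 0.5s on an empty queue
        except queue.Empty:
            print('timed out')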
diff --git a/Lib/multiprocessing/synchronize.py b/Lib/multiprocessing/synchronize.py
index 617d0b6..70ae825 100644
--- a/Lib/multiprocessing/synchronize.py
+++ b/Lib/multiprocessing/synchronize.py
@@ -243,7 +243,7 @@ class Condition(object):
try:
# wait for notification or timeout
- self._wait_semaphore.acquire(True, timeout)
+ ret = self._wait_semaphore.acquire(True, timeout)
finally:
# indicate that this thread has woken
self._woken_count.release()
@@ -251,6 +251,7 @@ class Condition(object):
# reacquire lock
for i in range(count):
self._lock.acquire()
+ return ret
def notify(self):
assert self._lock._semlock._is_mine(), 'lock is not owned'
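With ret propagated, Condition.wait() now reports whether the wait semaphore was actually acquired: True when notified, False when the timeout elapsed, so callers can distinguish notification from timeout. A minimal sketch:

    import multiprocessing

    if __name__ == '__main__':
        cond = multiprocessing.Condition()
        with cond:
            woke = cond.wait(timeout=0.5)
        print(woke)                   # False: nobody called notify()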