summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAsk Solem <askh@opera.com>2010-11-09 20:55:52 (GMT)
committerAsk Solem <askh@opera.com>2010-11-09 20:55:52 (GMT)
commit2afcbf2249a04092b1e7cb8ff29e8505a6b20da4 (patch)
tree2bd1e666b3730a2f88dedaaf9cfb850233bad896
parentfb0469112f2e027833a1dc7ff4c678417de0111a (diff)
downloadcpython-2afcbf2249a04092b1e7cb8ff29e8505a6b20da4.zip
cpython-2afcbf2249a04092b1e7cb8ff29e8505a6b20da4.tar.gz
cpython-2afcbf2249a04092b1e7cb8ff29e8505a6b20da4.tar.bz2
Issue #9244: multiprocessing.pool: Worker crashes if result can't be encoded
-rw-r--r--Lib/multiprocessing/pool.py49
-rw-r--r--Lib/test/test_multiprocessing.py49
2 files changed, 88 insertions, 10 deletions
diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index 7154d3c..c170cca 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -42,6 +42,23 @@ def mapstar(args):
# Code run by worker processes
#
+class MaybeEncodingError(Exception):
+ """Wraps possible unpickleable errors, so they can be
+ safely sent through the socket."""
+
+ def __init__(self, exc, value):
+ self.exc = repr(exc)
+ self.value = repr(value)
+ super(MaybeEncodingError, self).__init__(self.exc, self.value)
+
+ def __str__(self):
+ return "Error sending result: '%s'. Reason: '%s'" % (self.value,
+ self.exc)
+
+ def __repr__(self):
+ return "<MaybeEncodingError: %s>" % str(self)
+
+
def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
put = outqueue.put
@@ -70,7 +87,13 @@ def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
result = (True, func(*args, **kwds))
except Exception as e:
result = (False, e)
- put((job, i, result))
+ try:
+ put((job, i, result))
+ except Exception as e:
+ wrapped = MaybeEncodingError(e, result[1])
+ debug("Possible encoding error while sending result: %s" % (
+ wrapped))
+ put((job, i, (False, wrapped)))
completed += 1
debug('worker exiting after %d tasks' % completed)
@@ -235,16 +258,18 @@ class Pool(object):
for i, x in enumerate(task_batches)), result._set_length))
return (item for chunk in result for item in chunk)
- def apply_async(self, func, args=(), kwds={}, callback=None):
+ def apply_async(self, func, args=(), kwds={}, callback=None,
+ error_callback=None):
'''
Asynchronous version of `apply()` method.
'''
assert self._state == RUN
- result = ApplyResult(self._cache, callback)
+ result = ApplyResult(self._cache, callback, error_callback)
self._taskqueue.put(([(result._job, None, func, args, kwds)], None))
return result
- def map_async(self, func, iterable, chunksize=None, callback=None):
+ def map_async(self, func, iterable, chunksize=None, callback=None,
+ error_callback=None):
'''
Asynchronous version of `map()` method.
'''
@@ -260,7 +285,8 @@ class Pool(object):
chunksize = 0
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = MapResult(self._cache, chunksize, len(iterable), callback)
+ result = MapResult(self._cache, chunksize, len(iterable), callback,
+ error_callback=error_callback)
self._taskqueue.put((((result._job, i, mapstar, (x,), {})
for i, x in enumerate(task_batches)), None))
return result
@@ -459,12 +485,13 @@ class Pool(object):
class ApplyResult(object):
- def __init__(self, cache, callback):
+ def __init__(self, cache, callback, error_callback):
self._cond = threading.Condition(threading.Lock())
self._job = next(job_counter)
self._cache = cache
self._ready = False
self._callback = callback
+ self._error_callback = error_callback
cache[self._job] = self
def ready(self):
@@ -495,6 +522,8 @@ class ApplyResult(object):
self._success, self._value = obj
if self._callback and self._success:
self._callback(self._value)
+ if self._error_callback and not self._success:
+ self._error_callback(self._value)
self._cond.acquire()
try:
self._ready = True
@@ -509,8 +538,9 @@ class ApplyResult(object):
class MapResult(ApplyResult):
- def __init__(self, cache, chunksize, length, callback):
- ApplyResult.__init__(self, cache, callback)
+ def __init__(self, cache, chunksize, length, callback, error_callback):
+ ApplyResult.__init__(self, cache, callback,
+ error_callback=error_callback)
self._success = True
self._value = [None] * length
self._chunksize = chunksize
@@ -535,10 +565,11 @@ class MapResult(ApplyResult):
self._cond.notify()
finally:
self._cond.release()
-
else:
self._success = False
self._value = result
+ if self._error_callback:
+ self._error_callback(self._value)
del self._cache[self._job]
self._cond.acquire()
try:
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index 0b3f937..bb0f06a 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -1011,6 +1011,7 @@ class _TestContainers(BaseTestCase):
def sqr(x, wait=0.0):
time.sleep(wait)
return x*x
+
class _TestPool(BaseTestCase):
def test_apply(self):
@@ -1087,9 +1088,55 @@ class _TestPool(BaseTestCase):
join()
self.assertTrue(join.elapsed < 0.2)
-class _TestPoolWorkerLifetime(BaseTestCase):
+def raising():
+ raise KeyError("key")
+
+def unpickleable_result():
+ return lambda: 42
+
+class _TestPoolWorkerErrors(BaseTestCase):
+ ALLOWED_TYPES = ('processes', )
+
+ def test_async_error_callback(self):
+ p = multiprocessing.Pool(2)
+
+ scratchpad = [None]
+ def errback(exc):
+ scratchpad[0] = exc
+
+ res = p.apply_async(raising, error_callback=errback)
+ self.assertRaises(KeyError, res.get)
+ self.assertTrue(scratchpad[0])
+ self.assertIsInstance(scratchpad[0], KeyError)
+
+ p.close()
+ p.join()
+
+ def test_unpickleable_result(self):
+ from multiprocessing.pool import MaybeEncodingError
+ p = multiprocessing.Pool(2)
+
+ # Make sure we don't lose pool processes because of encoding errors.
+ for iteration in range(20):
+
+ scratchpad = [None]
+ def errback(exc):
+ scratchpad[0] = exc
+
+ res = p.apply_async(unpickleable_result, error_callback=errback)
+ self.assertRaises(MaybeEncodingError, res.get)
+ wrapped = scratchpad[0]
+ self.assertTrue(wrapped)
+ self.assertIsInstance(scratchpad[0], MaybeEncodingError)
+ self.assertIsNotNone(wrapped.exc)
+ self.assertIsNotNone(wrapped.value)
+ p.close()
+ p.join()
+
+class _TestPoolWorkerLifetime(BaseTestCase):
ALLOWED_TYPES = ('processes', )
+
def test_pool_worker_lifetime(self):
p = multiprocessing.Pool(3, maxtasksperchild=10)
self.assertEqual(3, len(p._pool))