diff options
Diffstat (limited to 'Lib/multiprocessing/reduction.py')
-rw-r--r-- | Lib/multiprocessing/reduction.py | 363 |
1 files changed, 162 insertions, 201 deletions
diff --git a/Lib/multiprocessing/reduction.py b/Lib/multiprocessing/reduction.py index 656fa8f..5bbbcf4 100644 --- a/Lib/multiprocessing/reduction.py +++ b/Lib/multiprocessing/reduction.py @@ -1,6 +1,5 @@ # -# Module to allow connection and socket objects to be transferred -# between processes +# Module which deals with pickling of objects. # # multiprocessing/reduction.py # @@ -8,27 +7,57 @@ # Licensed to PSF under a Contributor Agreement. # -__all__ = ['reduce_socket', 'reduce_connection', 'send_handle', 'recv_handle'] - +import copyreg +import functools +import io import os -import sys +import pickle import socket -import threading -import struct -import signal +import sys -from multiprocessing import current_process -from multiprocessing.util import register_after_fork, debug, sub_debug -from multiprocessing.util import is_exiting, sub_warning +from . import popen +from . import util +__all__ = ['send_handle', 'recv_handle', 'ForkingPickler', 'register', 'dump'] + + +HAVE_SEND_HANDLE = (sys.platform == 'win32' or + (hasattr(socket, 'CMSG_LEN') and + hasattr(socket, 'SCM_RIGHTS') and + hasattr(socket.socket, 'sendmsg'))) # +# Pickler subclass # -# -if not(sys.platform == 'win32' or (hasattr(socket, 'CMSG_LEN') and - hasattr(socket, 'SCM_RIGHTS'))): - raise ImportError('pickling of connections not supported') +class ForkingPickler(pickle.Pickler): + '''Pickler subclass used by multiprocessing.''' + _extra_reducers = {} + _copyreg_dispatch_table = copyreg.dispatch_table + + def __init__(self, *args): + super().__init__(*args) + self.dispatch_table = self._copyreg_dispatch_table.copy() + self.dispatch_table.update(self._extra_reducers) + + @classmethod + def register(cls, type, reduce): + '''Register a reduce function for a type.''' + cls._extra_reducers[type] = reduce + + @classmethod + def dumps(cls, obj, protocol=None): + buf = io.BytesIO() + cls(buf, protocol).dump(obj) + return buf.getbuffer() + + loads = pickle.loads + +register = ForkingPickler.register + +def dump(obj, file, protocol=None): + '''Replacement for pickle.dump() using ForkingPickler.''' + ForkingPickler(file, protocol).dump(obj) # # Platform specific definitions @@ -36,20 +65,44 @@ if not(sys.platform == 'win32' or (hasattr(socket, 'CMSG_LEN') and if sys.platform == 'win32': # Windows - __all__ += ['reduce_pipe_connection'] + __all__ += ['DupHandle', 'duplicate', 'steal_handle'] import _winapi + def duplicate(handle, target_process=None, inheritable=False): + '''Duplicate a handle. (target_process is a handle not a pid!)''' + if target_process is None: + target_process = _winapi.GetCurrentProcess() + return _winapi.DuplicateHandle( + _winapi.GetCurrentProcess(), handle, target_process, + 0, inheritable, _winapi.DUPLICATE_SAME_ACCESS) + + def steal_handle(source_pid, handle): + '''Steal a handle from process identified by source_pid.''' + source_process_handle = _winapi.OpenProcess( + _winapi.PROCESS_DUP_HANDLE, False, source_pid) + try: + return _winapi.DuplicateHandle( + source_process_handle, handle, + _winapi.GetCurrentProcess(), 0, False, + _winapi.DUPLICATE_SAME_ACCESS | _winapi.DUPLICATE_CLOSE_SOURCE) + finally: + _winapi.CloseHandle(source_process_handle) + def send_handle(conn, handle, destination_pid): + '''Send a handle over a local connection.''' dh = DupHandle(handle, _winapi.DUPLICATE_SAME_ACCESS, destination_pid) conn.send(dh) def recv_handle(conn): + '''Receive a handle over a local connection.''' return conn.recv().detach() class DupHandle(object): + '''Picklable wrapper for a handle.''' def __init__(self, handle, access, pid=None): - # duplicate handle for process with given pid if pid is None: + # We just duplicate the handle in the current process and + # let the receiving process steal the handle. pid = os.getpid() proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, pid) try: @@ -62,9 +115,12 @@ if sys.platform == 'win32': self._pid = pid def detach(self): + '''Get the handle. This should only be called once.''' # retrieve handle from process which currently owns it if self._pid == os.getpid(): + # The handle has already been duplicated for this process. return self._handle + # We must steal the handle from the process whose pid is self._pid. proc = _winapi.OpenProcess(_winapi.PROCESS_DUP_HANDLE, False, self._pid) try: @@ -74,207 +130,112 @@ if sys.platform == 'win32': finally: _winapi.CloseHandle(proc) - class DupSocket(object): - def __init__(self, sock): - new_sock = sock.dup() - def send(conn, pid): - share = new_sock.share(pid) - conn.send_bytes(share) - self._id = resource_sharer.register(send, new_sock.close) - - def detach(self): - conn = resource_sharer.get_connection(self._id) - try: - share = conn.recv_bytes() - return socket.fromshare(share) - finally: - conn.close() - - def reduce_socket(s): - return rebuild_socket, (DupSocket(s),) - - def rebuild_socket(ds): - return ds.detach() - - def reduce_connection(conn): - handle = conn.fileno() - with socket.fromfd(handle, socket.AF_INET, socket.SOCK_STREAM) as s: - ds = DupSocket(s) - return rebuild_connection, (ds, conn.readable, conn.writable) - - def rebuild_connection(ds, readable, writable): - from .connection import Connection - sock = ds.detach() - return Connection(sock.detach(), readable, writable) - - def reduce_pipe_connection(conn): - access = ((_winapi.FILE_GENERIC_READ if conn.readable else 0) | - (_winapi.FILE_GENERIC_WRITE if conn.writable else 0)) - dh = DupHandle(conn.fileno(), access) - return rebuild_pipe_connection, (dh, conn.readable, conn.writable) - - def rebuild_pipe_connection(dh, readable, writable): - from .connection import PipeConnection - handle = dh.detach() - return PipeConnection(handle, readable, writable) - else: # Unix + __all__ += ['DupFd', 'sendfds', 'recvfds'] + import array # On MacOSX we should acknowledge receipt of fds -- see Issue14669 ACKNOWLEDGE = sys.platform == 'darwin' + def sendfds(sock, fds): + '''Send an array of fds over an AF_UNIX socket.''' + fds = array.array('i', fds) + msg = bytes([len(fds) % 256]) + sock.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)]) + if ACKNOWLEDGE and sock.recv(1) != b'A': + raise RuntimeError('did not receive acknowledgement of fd') + + def recvfds(sock, size): + '''Receive an array of fds over an AF_UNIX socket.''' + a = array.array('i') + bytes_size = a.itemsize * size + msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_LEN(bytes_size)) + if not msg and not ancdata: + raise EOFError + try: + if ACKNOWLEDGE: + sock.send(b'A') + if len(ancdata) != 1: + raise RuntimeError('received %d items of ancdata' % + len(ancdata)) + cmsg_level, cmsg_type, cmsg_data = ancdata[0] + if (cmsg_level == socket.SOL_SOCKET and + cmsg_type == socket.SCM_RIGHTS): + if len(cmsg_data) % a.itemsize != 0: + raise ValueError + a.frombytes(cmsg_data) + assert len(a) % 256 == msg[0] + return list(a) + except (ValueError, IndexError): + pass + raise RuntimeError('Invalid data received') + def send_handle(conn, handle, destination_pid): + '''Send a handle over a local connection.''' with socket.fromfd(conn.fileno(), socket.AF_UNIX, socket.SOCK_STREAM) as s: - s.sendmsg([b'x'], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, - struct.pack("@i", handle))]) - if ACKNOWLEDGE and conn.recv_bytes() != b'ACK': - raise RuntimeError('did not receive acknowledgement of fd') + sendfds(s, [handle]) def recv_handle(conn): - size = struct.calcsize("@i") + '''Receive a handle over a local connection.''' with socket.fromfd(conn.fileno(), socket.AF_UNIX, socket.SOCK_STREAM) as s: - msg, ancdata, flags, addr = s.recvmsg(1, socket.CMSG_LEN(size)) - try: - if ACKNOWLEDGE: - conn.send_bytes(b'ACK') - cmsg_level, cmsg_type, cmsg_data = ancdata[0] - if (cmsg_level == socket.SOL_SOCKET and - cmsg_type == socket.SCM_RIGHTS): - return struct.unpack("@i", cmsg_data[:size])[0] - except (ValueError, IndexError, struct.error): - pass - raise RuntimeError('Invalid data received') - - class DupFd(object): - def __init__(self, fd): - new_fd = os.dup(fd) - def send(conn, pid): - send_handle(conn, new_fd, pid) - def close(): - os.close(new_fd) - self._id = resource_sharer.register(send, close) + return recvfds(s, 1)[0] + + def DupFd(fd): + '''Return a wrapper for an fd.''' + popen_obj = popen.get_spawning_popen() + if popen_obj is not None: + return popen_obj.DupFd(popen_obj.duplicate_for_child(fd)) + elif HAVE_SEND_HANDLE: + from . import resource_sharer + return resource_sharer.DupFd(fd) + else: + raise ValueError('SCM_RIGHTS appears not to be available') - def detach(self): - conn = resource_sharer.get_connection(self._id) - try: - return recv_handle(conn) - finally: - conn.close() +# +# Try making some callable types picklable +# - def reduce_socket(s): - df = DupFd(s.fileno()) - return rebuild_socket, (df, s.family, s.type, s.proto) +def _reduce_method(m): + if m.__self__ is None: + return getattr, (m.__class__, m.__func__.__name__) + else: + return getattr, (m.__self__, m.__func__.__name__) +class _C: + def f(self): + pass +register(type(_C().f), _reduce_method) - def rebuild_socket(df, family, type, proto): - fd = df.detach() - s = socket.fromfd(fd, family, type, proto) - os.close(fd) - return s - def reduce_connection(conn): - df = DupFd(conn.fileno()) - return rebuild_connection, (df, conn.readable, conn.writable) +def _reduce_method_descriptor(m): + return getattr, (m.__objclass__, m.__name__) +register(type(list.append), _reduce_method_descriptor) +register(type(int.__add__), _reduce_method_descriptor) - def rebuild_connection(df, readable, writable): - from .connection import Connection - fd = df.detach() - return Connection(fd, readable, writable) + +def _reduce_partial(p): + return _rebuild_partial, (p.func, p.args, p.keywords or {}) +def _rebuild_partial(func, args, keywords): + return functools.partial(func, *args, **keywords) +register(functools.partial, _reduce_partial) # -# Server which shares registered resources with clients +# Make sockets picklable # -class ResourceSharer(object): - def __init__(self): - self._key = 0 - self._cache = {} - self._old_locks = [] - self._lock = threading.Lock() - self._listener = None - self._address = None - self._thread = None - register_after_fork(self, ResourceSharer._afterfork) - - def register(self, send, close): - with self._lock: - if self._address is None: - self._start() - self._key += 1 - self._cache[self._key] = (send, close) - return (self._address, self._key) - - @staticmethod - def get_connection(ident): - from .connection import Client - address, key = ident - c = Client(address, authkey=current_process().authkey) - c.send((key, os.getpid())) - return c - - def stop(self, timeout=None): - from .connection import Client - with self._lock: - if self._address is not None: - c = Client(self._address, authkey=current_process().authkey) - c.send(None) - c.close() - self._thread.join(timeout) - if self._thread.is_alive(): - sub_warn('ResourceSharer thread did not stop when asked') - self._listener.close() - self._thread = None - self._address = None - self._listener = None - for key, (send, close) in self._cache.items(): - close() - self._cache.clear() - - def _afterfork(self): - for key, (send, close) in self._cache.items(): - close() - self._cache.clear() - # If self._lock was locked at the time of the fork, it may be broken - # -- see issue 6721. Replace it without letting it be gc'ed. - self._old_locks.append(self._lock) - self._lock = threading.Lock() - if self._listener is not None: - self._listener.close() - self._listener = None - self._address = None - self._thread = None - - def _start(self): - from .connection import Listener - assert self._listener is None - debug('starting listener and thread for sending handles') - self._listener = Listener(authkey=current_process().authkey) - self._address = self._listener.address - t = threading.Thread(target=self._serve) - t.daemon = True - t.start() - self._thread = t - - def _serve(self): - if hasattr(signal, 'pthread_sigmask'): - signal.pthread_sigmask(signal.SIG_BLOCK, range(1, signal.NSIG)) - while 1: - try: - conn = self._listener.accept() - msg = conn.recv() - if msg is None: - break - key, destination_pid = msg - send, close = self._cache.pop(key) - send(conn, destination_pid) - close() - conn.close() - except: - if not is_exiting(): - import traceback - sub_warning( - 'thread for sharing handles raised exception :\n' + - '-'*79 + '\n' + traceback.format_exc() + '-'*79 - ) - -resource_sharer = ResourceSharer() +if sys.platform == 'win32': + def _reduce_socket(s): + from .resource_sharer import DupSocket + return _rebuild_socket, (DupSocket(s),) + def _rebuild_socket(ds): + return ds.detach() + register(socket.socket, _reduce_socket) + +else: + def _reduce_socket(s): + df = DupFd(s.fileno()) + return _rebuild_socket, (df, s.family, s.type, s.proto) + def _rebuild_socket(df, family, type, proto): + fd = df.detach() + return socket.socket(family, type, proto, fileno=fd) + register(socket.socket, _reduce_socket) |