diff options
author | Antoine Pitrou <solipsis@pitrou.net> | 2011-05-09 15:04:27 (GMT) |
---|---|---|
committer | Antoine Pitrou <solipsis@pitrou.net> | 2011-05-09 15:04:27 (GMT) |
commit | 87cf220972c9cb400ddcd577962883dcc5dca51a (patch) | |
tree | 3f1ab5b64ae538a2ced622637cc7e4112b1c6ffd /Lib | |
parent | df77e3d4a07223ebfe049e66d4d8a8c0b4315e04 (diff) | |
download | cpython-87cf220972c9cb400ddcd577962883dcc5dca51a.zip cpython-87cf220972c9cb400ddcd577962883dcc5dca51a.tar.gz cpython-87cf220972c9cb400ddcd577962883dcc5dca51a.tar.bz2 |
Issue #11743: Rewrite multiprocessing connection classes in pure Python.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/multiprocessing/connection.py | 315 | ||||
-rw-r--r-- | Lib/multiprocessing/forking.py | 5 | ||||
-rw-r--r-- | Lib/multiprocessing/reduction.py | 13 | ||||
-rw-r--r-- | Lib/test/test_multiprocessing.py | 12 |
4 files changed, 320 insertions, 25 deletions
diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py index d6627e5..afd580b 100644 --- a/Lib/multiprocessing/connection.py +++ b/Lib/multiprocessing/connection.py @@ -34,19 +34,27 @@ __all__ = [ 'Client', 'Listener', 'Pipe' ] +import io import os import sys +import pickle +import select import socket +import struct import errno import time import tempfile import itertools import _multiprocessing -from multiprocessing import current_process, AuthenticationError +from multiprocessing import current_process, AuthenticationError, BufferTooShort from multiprocessing.util import get_temp_dir, Finalize, sub_debug, debug -from multiprocessing.forking import duplicate, close - +try: + from _multiprocessing import win32 +except ImportError: + if sys.platform == 'win32': + raise + win32 = None # # @@ -111,6 +119,281 @@ def address_type(address): raise ValueError('address type of %r unrecognized' % address) # +# Connection classes +# + +class _ConnectionBase: + _handle = None + + def __init__(self, handle, readable=True, writable=True): + handle = handle.__index__() + if handle < 0: + raise ValueError("invalid handle") + if not readable and not writable: + raise ValueError( + "at least one of `readable` and `writable` must be True") + self._handle = handle + self._readable = readable + self._writable = writable + + def __del__(self): + if self._handle is not None: + self._close() + + def _check_closed(self): + if self._handle is None: + raise IOError("handle is closed") + + def _check_readable(self): + if not self._readable: + raise IOError("connection is write-only") + + def _check_writable(self): + if not self._writable: + raise IOError("connection is read-only") + + def _bad_message_length(self): + if self._writable: + self._readable = False + else: + self.close() + raise IOError("bad message length") + + @property + def closed(self): + """True if the connection is closed""" + return self._handle is None + + @property + def readable(self): + """True if the connection is readable""" + return self._readable + + @property + def writable(self): + """True if the connection is writable""" + return self._writable + + def fileno(self): + """File descriptor or handle of the connection""" + self._check_closed() + return self._handle + + def close(self): + """Close the connection""" + if self._handle is not None: + try: + self._close() + finally: + self._handle = None + + def send_bytes(self, buf, offset=0, size=None): + """Send the bytes data from a bytes-like object""" + self._check_closed() + self._check_writable() + m = memoryview(buf) + # HACK for byte-indexing of non-bytewise buffers (e.g. array.array) + if m.itemsize > 1: + m = memoryview(bytes(m)) + n = len(m) + if offset < 0: + raise ValueError("offset is negative") + if n < offset: + raise ValueError("buffer length < offset") + if size is None: + size = n - offset + elif size < 0: + raise ValueError("size is negative") + elif offset + size > n: + raise ValueError("buffer length < offset + size") + self._send_bytes(m[offset:offset + size]) + + def send(self, obj): + """Send a (picklable) object""" + self._check_closed() + self._check_writable() + buf = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + self._send_bytes(memoryview(buf)) + + def recv_bytes(self, maxlength=None): + """ + Receive bytes data as a bytes object. + """ + self._check_closed() + self._check_readable() + if maxlength is not None and maxlength < 0: + raise ValueError("negative maxlength") + buf = self._recv_bytes(maxlength) + if buf is None: + self._bad_message_length() + return buf.getvalue() + + def recv_bytes_into(self, buf, offset=0): + """ + Receive bytes data into a writeable buffer-like object. + Return the number of bytes read. + """ + self._check_closed() + self._check_readable() + with memoryview(buf) as m: + # Get bytesize of arbitrary buffer + itemsize = m.itemsize + bytesize = itemsize * len(m) + if offset < 0: + raise ValueError("negative offset") + elif offset > bytesize: + raise ValueError("offset too large") + result = self._recv_bytes() + size = result.tell() + if bytesize < offset + size: + raise BufferTooShort(result.getvalue()) + # Message can fit in dest + result.seek(0) + result.readinto(m[offset // itemsize : + (offset + size) // itemsize]) + return size + + def recv(self): + """Receive a (picklable) object""" + self._check_closed() + self._check_readable() + buf = self._recv_bytes() + return pickle.loads(buf.getbuffer()) + + def poll(self, timeout=0.0): + """Whether there is any input available to be read""" + self._check_closed() + self._check_readable() + if timeout < 0.0: + timeout = None + return self._poll(timeout) + + +if win32: + + class PipeConnection(_ConnectionBase): + """ + Connection class based on a Windows named pipe. + """ + + def _close(self): + win32.CloseHandle(self._handle) + + def _send_bytes(self, buf): + nwritten = win32.WriteFile(self._handle, buf) + assert nwritten == len(buf) + + def _recv_bytes(self, maxsize=None): + buf = io.BytesIO() + bufsize = 512 + if maxsize is not None: + bufsize = min(bufsize, maxsize) + try: + firstchunk, complete = win32.ReadFile(self._handle, bufsize) + except IOError as e: + if e.errno == win32.ERROR_BROKEN_PIPE: + raise EOFError + raise + lenfirstchunk = len(firstchunk) + buf.write(firstchunk) + if complete: + return buf + navail, nleft = win32.PeekNamedPipe(self._handle) + if maxsize is not None and lenfirstchunk + nleft > maxsize: + return None + lastchunk, complete = win32.ReadFile(self._handle, nleft) + assert complete + buf.write(lastchunk) + return buf + + def _poll(self, timeout): + navail, nleft = win32.PeekNamedPipe(self._handle) + if navail > 0: + return True + elif timeout == 0.0: + return False + # Setup a polling loop (translated straight from old + # pipe_connection.c) + if timeout < 0.0: + deadline = None + else: + deadline = time.time() + timeout + delay = 0.001 + max_delay = 0.02 + while True: + time.sleep(delay) + navail, nleft = win32.PeekNamedPipe(self._handle) + if navail > 0: + return True + if deadline and time.time() > deadline: + return False + if delay < max_delay: + delay += 0.001 + + +class Connection(_ConnectionBase): + """ + Connection class based on an arbitrary file descriptor (Unix only), or + a socket handle (Windows). + """ + + if win32: + def _close(self): + win32.closesocket(self._handle) + _write = win32.send + _read = win32.recv + else: + def _close(self): + os.close(self._handle) + _write = os.write + _read = os.read + + def _send(self, buf, write=_write): + remaining = len(buf) + while True: + n = write(self._handle, buf) + remaining -= n + if remaining == 0: + break + buf = buf[n:] + + def _recv(self, size, read=_read): + buf = io.BytesIO() + remaining = size + while remaining > 0: + chunk = read(self._handle, remaining) + n = len(chunk) + if n == 0: + if remaining == size: + raise EOFError + else: + raise IOError("got end of file during message") + buf.write(chunk) + remaining -= n + return buf + + def _send_bytes(self, buf): + # For wire compatibility with 3.2 and lower + n = len(buf) + self._send(struct.pack("=i", len(buf))) + # The condition is necessary to avoid "broken pipe" errors + # when sending a 0-length buffer if the other end closed the pipe. + if n > 0: + self._send(buf) + + def _recv_bytes(self, maxsize=None): + buf = self._recv(4) + size, = struct.unpack("=i", buf.getvalue()) + if maxsize is not None and size > maxsize: + return None + return self._recv(size) + + def _poll(self, timeout): + r = select.select([self._handle], [], [], timeout)[0] + return bool(r) + + +# # Public functions # @@ -186,21 +469,19 @@ if sys.platform != 'win32': ''' if duplex: s1, s2 = socket.socketpair() - c1 = _multiprocessing.Connection(os.dup(s1.fileno())) - c2 = _multiprocessing.Connection(os.dup(s2.fileno())) + c1 = Connection(os.dup(s1.fileno())) + c2 = Connection(os.dup(s2.fileno())) s1.close() s2.close() else: fd1, fd2 = os.pipe() - c1 = _multiprocessing.Connection(fd1, writable=False) - c2 = _multiprocessing.Connection(fd2, readable=False) + c1 = Connection(fd1, writable=False) + c2 = Connection(fd2, readable=False) return c1, c2 else: - from _multiprocessing import win32 - def Pipe(duplex=True): ''' Returns pair of connection objects at either end of a pipe @@ -234,8 +515,8 @@ else: if e.args[0] != win32.ERROR_PIPE_CONNECTED: raise - c1 = _multiprocessing.PipeConnection(h1, writable=duplex) - c2 = _multiprocessing.PipeConnection(h2, readable=duplex) + c1 = PipeConnection(h1, writable=duplex) + c2 = PipeConnection(h2, readable=duplex) return c1, c2 @@ -266,7 +547,7 @@ class SocketListener(object): def accept(self): s, self._last_accepted = self._socket.accept() fd = duplicate(s.fileno()) - conn = _multiprocessing.Connection(fd) + conn = Connection(fd) s.close() return conn @@ -298,7 +579,7 @@ def SocketClient(address): raise fd = duplicate(s.fileno()) - conn = _multiprocessing.Connection(fd) + conn = Connection(fd) return conn # @@ -345,7 +626,7 @@ if sys.platform == 'win32': except WindowsError as e: if e.args[0] != win32.ERROR_PIPE_CONNECTED: raise - return _multiprocessing.PipeConnection(handle) + return PipeConnection(handle) @staticmethod def _finalize_pipe_listener(queue, address): @@ -377,7 +658,7 @@ if sys.platform == 'win32': win32.SetNamedPipeHandleState( h, win32.PIPE_READMODE_MESSAGE, None, None ) - return _multiprocessing.PipeConnection(h) + return PipeConnection(h) # # Authentication stuff @@ -451,3 +732,7 @@ def XmlClient(*args, **kwds): global xmlrpclib import xmlrpc.client as xmlrpclib return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads) + + +# Late import because of circular import +from multiprocessing.forking import duplicate, close diff --git a/Lib/multiprocessing/forking.py b/Lib/multiprocessing/forking.py index cc7c326..3d95557 100644 --- a/Lib/multiprocessing/forking.py +++ b/Lib/multiprocessing/forking.py @@ -183,7 +183,7 @@ else: import time from pickle import dump, load, HIGHEST_PROTOCOL - from _multiprocessing import win32, Connection, PipeConnection + from _multiprocessing import win32 from .util import Finalize def dump(obj, file, protocol=None): @@ -411,6 +411,9 @@ else: # Make (Pipe)Connection picklable # + # Late import because of circular import + from .connection import Connection, PipeConnection + def reduce_connection(conn): if not Popen.thread_is_spawning(): raise RuntimeError( diff --git a/Lib/multiprocessing/reduction.py b/Lib/multiprocessing/reduction.py index 6e5e5bc..b32c725 100644 --- a/Lib/multiprocessing/reduction.py +++ b/Lib/multiprocessing/reduction.py @@ -44,7 +44,7 @@ import _multiprocessing from multiprocessing import current_process from multiprocessing.forking import Popen, duplicate, close, ForkingPickler from multiprocessing.util import register_after_fork, debug, sub_debug -from multiprocessing.connection import Client, Listener +from multiprocessing.connection import Client, Listener, Connection # @@ -159,7 +159,7 @@ def rebuild_handle(pickled_data): return new_handle # -# Register `_multiprocessing.Connection` with `ForkingPickler` +# Register `Connection` with `ForkingPickler` # def reduce_connection(conn): @@ -168,11 +168,11 @@ def reduce_connection(conn): def rebuild_connection(reduced_handle, readable, writable): handle = rebuild_handle(reduced_handle) - return _multiprocessing.Connection( + return Connection( handle, readable=readable, writable=writable ) -ForkingPickler.register(_multiprocessing.Connection, reduce_connection) +ForkingPickler.register(Connection, reduce_connection) # # Register `socket.socket` with `ForkingPickler` @@ -201,6 +201,7 @@ ForkingPickler.register(socket.socket, reduce_socket) # if sys.platform == 'win32': + from multiprocessing.connection import PipeConnection def reduce_pipe_connection(conn): rh = reduce_handle(conn.fileno()) @@ -208,8 +209,8 @@ if sys.platform == 'win32': def rebuild_pipe_connection(reduced_handle, readable, writable): handle = rebuild_handle(reduced_handle) - return _multiprocessing.PipeConnection( + return PipeConnection( handle, readable=readable, writable=writable ) - ForkingPickler.register(_multiprocessing.PipeConnection, reduce_pipe_connection) + ForkingPickler.register(PipeConnection, reduce_pipe_connection) diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index a7f0391..0c05ff6 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -1915,9 +1915,15 @@ class TestInvalidHandle(unittest.TestCase): @unittest.skipIf(WIN32, "skipped on Windows") def test_invalid_handles(self): - conn = _multiprocessing.Connection(44977608) - self.assertRaises(IOError, conn.poll) - self.assertRaises(IOError, _multiprocessing.Connection, -1) + conn = multiprocessing.connection.Connection(44977608) + try: + self.assertRaises((ValueError, IOError), conn.poll) + finally: + # Hack private attribute _handle to avoid printing an error + # in conn.__del__ + conn._handle = None + self.assertRaises((ValueError, IOError), + multiprocessing.connection.Connection, -1) # # Functions used to create test cases from the base ones in this module |