summaryrefslogtreecommitdiffstats
path: root/Lib/multiprocessing/connection.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/multiprocessing/connection.py')
-rw-r--r--Lib/multiprocessing/connection.py415
1 files changed, 368 insertions, 47 deletions
diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py
index df00f1d..8807618 100644
--- a/Lib/multiprocessing/connection.py
+++ b/Lib/multiprocessing/connection.py
@@ -34,19 +34,31 @@
__all__ = [ 'Client', 'Listener', 'Pipe' ]
+import io
import os
import sys
+import pickle
+import select
import socket
+import struct
import errno
import time
import tempfile
import itertools
import _multiprocessing
-from multiprocessing import current_process, AuthenticationError
-from multiprocessing.util import get_temp_dir, Finalize, sub_debug, debug
-from multiprocessing.forking import duplicate, close
+from multiprocessing import current_process, AuthenticationError, BufferTooShort
+from multiprocessing.util import (
+ get_temp_dir, Finalize, sub_debug, debug, _eintr_retry)
+try:
+ from _multiprocessing import win32
+ from _subprocess import WAIT_OBJECT_0, WAIT_TIMEOUT, INFINITE
+except ImportError:
+ if sys.platform == 'win32':
+ raise
+ win32 = None
+_select = _eintr_retry(select.select)
#
#
@@ -110,6 +122,326 @@ def address_type(address):
else:
raise ValueError('address type of %r unrecognized' % address)
+
+class SentinelReady(Exception):
+ """
+ Raised when a sentinel is ready when polling.
+ """
+ def __init__(self, *args):
+ Exception.__init__(self, *args)
+ self.sentinels = args[0]
+
+#
+# Connection classes
+#
+
+class _ConnectionBase:
+ _handle = None
+
+ def __init__(self, handle, readable=True, writable=True):
+ handle = handle.__index__()
+ if handle < 0:
+ raise ValueError("invalid handle")
+ if not readable and not writable:
+ raise ValueError(
+ "at least one of `readable` and `writable` must be True")
+ self._handle = handle
+ self._readable = readable
+ self._writable = writable
+
+ # XXX should we use util.Finalize instead of a __del__?
+
+ def __del__(self):
+ if self._handle is not None:
+ self._close()
+
+ def _check_closed(self):
+ if self._handle is None:
+ raise IOError("handle is closed")
+
+ def _check_readable(self):
+ if not self._readable:
+ raise IOError("connection is write-only")
+
+ def _check_writable(self):
+ if not self._writable:
+ raise IOError("connection is read-only")
+
+ def _bad_message_length(self):
+ if self._writable:
+ self._readable = False
+ else:
+ self.close()
+ raise IOError("bad message length")
+
+ @property
+ def closed(self):
+ """True if the connection is closed"""
+ return self._handle is None
+
+ @property
+ def readable(self):
+ """True if the connection is readable"""
+ return self._readable
+
+ @property
+ def writable(self):
+ """True if the connection is writable"""
+ return self._writable
+
+ def fileno(self):
+ """File descriptor or handle of the connection"""
+ self._check_closed()
+ return self._handle
+
+ def close(self):
+ """Close the connection"""
+ if self._handle is not None:
+ try:
+ self._close()
+ finally:
+ self._handle = None
+
+ def send_bytes(self, buf, offset=0, size=None):
+ """Send the bytes data from a bytes-like object"""
+ self._check_closed()
+ self._check_writable()
+ m = memoryview(buf)
+ # HACK for byte-indexing of non-bytewise buffers (e.g. array.array)
+ if m.itemsize > 1:
+ m = memoryview(bytes(m))
+ n = len(m)
+ if offset < 0:
+ raise ValueError("offset is negative")
+ if n < offset:
+ raise ValueError("buffer length < offset")
+ if size is None:
+ size = n - offset
+ elif size < 0:
+ raise ValueError("size is negative")
+ elif offset + size > n:
+ raise ValueError("buffer length < offset + size")
+ self._send_bytes(m[offset:offset + size])
+
+ def send(self, obj):
+ """Send a (picklable) object"""
+ self._check_closed()
+ self._check_writable()
+ buf = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
+ self._send_bytes(memoryview(buf))
+
+ def recv_bytes(self, maxlength=None):
+ """
+ Receive bytes data as a bytes object.
+ """
+ self._check_closed()
+ self._check_readable()
+ if maxlength is not None and maxlength < 0:
+ raise ValueError("negative maxlength")
+ buf = self._recv_bytes(maxlength)
+ if buf is None:
+ self._bad_message_length()
+ return buf.getvalue()
+
+ def recv_bytes_into(self, buf, offset=0):
+ """
+ Receive bytes data into a writeable buffer-like object.
+ Return the number of bytes read.
+ """
+ self._check_closed()
+ self._check_readable()
+ with memoryview(buf) as m:
+ # Get bytesize of arbitrary buffer
+ itemsize = m.itemsize
+ bytesize = itemsize * len(m)
+ if offset < 0:
+ raise ValueError("negative offset")
+ elif offset > bytesize:
+ raise ValueError("offset too large")
+ result = self._recv_bytes()
+ size = result.tell()
+ if bytesize < offset + size:
+ raise BufferTooShort(result.getvalue())
+ # Message can fit in dest
+ result.seek(0)
+ result.readinto(m[offset // itemsize :
+ (offset + size) // itemsize])
+ return size
+
+ def recv(self, sentinels=None):
+ """Receive a (picklable) object"""
+ self._check_closed()
+ self._check_readable()
+ buf = self._recv_bytes(sentinels=sentinels)
+ return pickle.loads(buf.getbuffer())
+
+ def poll(self, timeout=0.0):
+ """Whether there is any input available to be read"""
+ self._check_closed()
+ self._check_readable()
+ return self._poll(timeout)
+
+
+if win32:
+
+ class PipeConnection(_ConnectionBase):
+ """
+ Connection class based on a Windows named pipe.
+ Overlapped I/O is used, so the handles must have been created
+ with FILE_FLAG_OVERLAPPED.
+ """
+ _buffered = b''
+
+ def _close(self, _CloseHandle=win32.CloseHandle):
+ _CloseHandle(self._handle)
+
+ def _send_bytes(self, buf):
+ overlapped = win32.WriteFile(self._handle, buf, overlapped=True)
+ nwritten, complete = overlapped.GetOverlappedResult(True)
+ assert complete
+ assert nwritten == len(buf)
+
+ def _recv_bytes(self, maxsize=None, sentinels=()):
+ if sentinels:
+ self._poll(-1.0, sentinels)
+ buf = io.BytesIO()
+ firstchunk = self._buffered
+ if firstchunk:
+ lenfirstchunk = len(firstchunk)
+ buf.write(firstchunk)
+ self._buffered = b''
+ else:
+ # A reasonable size for the first chunk transfer
+ bufsize = 128
+ if maxsize is not None and maxsize < bufsize:
+ bufsize = maxsize
+ try:
+ overlapped = win32.ReadFile(self._handle, bufsize, overlapped=True)
+ lenfirstchunk, complete = overlapped.GetOverlappedResult(True)
+ firstchunk = overlapped.getbuffer()
+ assert lenfirstchunk == len(firstchunk)
+ except IOError as e:
+ if e.winerror == win32.ERROR_BROKEN_PIPE:
+ raise EOFError
+ raise
+ buf.write(firstchunk)
+ if complete:
+ return buf
+ navail, nleft = win32.PeekNamedPipe(self._handle)
+ if maxsize is not None and lenfirstchunk + nleft > maxsize:
+ return None
+ if nleft > 0:
+ overlapped = win32.ReadFile(self._handle, nleft, overlapped=True)
+ res, complete = overlapped.GetOverlappedResult(True)
+ assert res == nleft
+ assert complete
+ buf.write(overlapped.getbuffer())
+ return buf
+
+ def _poll(self, timeout, sentinels=()):
+ # Fast non-blocking path
+ navail, nleft = win32.PeekNamedPipe(self._handle)
+ if navail > 0:
+ return True
+ elif timeout == 0.0:
+ return False
+ # Blocking: use overlapped I/O
+ if timeout < 0.0:
+ timeout = INFINITE
+ else:
+ timeout = int(timeout * 1000 + 0.5)
+ overlapped = win32.ReadFile(self._handle, 1, overlapped=True)
+ try:
+ handles = [overlapped.event]
+ handles += sentinels
+ res = win32.WaitForMultipleObjects(handles, False, timeout)
+ finally:
+ # Always cancel overlapped I/O in the same thread
+ # (because CancelIoEx() appears only in Vista)
+ overlapped.cancel()
+ if res == WAIT_TIMEOUT:
+ return False
+ idx = res - WAIT_OBJECT_0
+ if idx == 0:
+ # I/O was successful, store received data
+ overlapped.GetOverlappedResult(True)
+ self._buffered += overlapped.getbuffer()
+ return True
+ assert 0 < idx < len(handles)
+ raise SentinelReady([handles[idx]])
+
+
+class Connection(_ConnectionBase):
+ """
+ Connection class based on an arbitrary file descriptor (Unix only), or
+ a socket handle (Windows).
+ """
+
+ if win32:
+ def _close(self, _close=win32.closesocket):
+ _close(self._handle)
+ _write = win32.send
+ _read = win32.recv
+ else:
+ def _close(self, _close=os.close):
+ _close(self._handle)
+ _write = os.write
+ _read = os.read
+
+ def _send(self, buf, write=_write):
+ remaining = len(buf)
+ while True:
+ n = write(self._handle, buf)
+ remaining -= n
+ if remaining == 0:
+ break
+ buf = buf[n:]
+
+ def _recv(self, size, sentinels=(), read=_read):
+ buf = io.BytesIO()
+ handle = self._handle
+ if sentinels:
+ handles = [handle] + sentinels
+ remaining = size
+ while remaining > 0:
+ if sentinels:
+ r = _select(handles, [], [])[0]
+ if handle not in r:
+ raise SentinelReady(r)
+ chunk = read(handle, remaining)
+ n = len(chunk)
+ if n == 0:
+ if remaining == size:
+ raise EOFError
+ else:
+ raise IOError("got end of file during message")
+ buf.write(chunk)
+ remaining -= n
+ return buf
+
+ def _send_bytes(self, buf):
+ # For wire compatibility with 3.2 and lower
+ n = len(buf)
+ self._send(struct.pack("!i", n))
+ # The condition is necessary to avoid "broken pipe" errors
+ # when sending a 0-length buffer if the other end closed the pipe.
+ if n > 0:
+ self._send(buf)
+
+ def _recv_bytes(self, maxsize=None, sentinels=()):
+ buf = self._recv(4, sentinels)
+ size, = struct.unpack("!i", buf.getvalue())
+ if maxsize is not None and size > maxsize:
+ return None
+ return self._recv(size, sentinels)
+
+ def _poll(self, timeout):
+ if timeout < 0.0:
+ timeout = None
+ r = _select([self._handle], [], [], timeout)[0]
+ return bool(r)
+
+
#
# Public functions
#
@@ -186,21 +518,17 @@ if sys.platform != 'win32':
'''
if duplex:
s1, s2 = socket.socketpair()
- c1 = _multiprocessing.Connection(os.dup(s1.fileno()))
- c2 = _multiprocessing.Connection(os.dup(s2.fileno()))
- s1.close()
- s2.close()
+ c1 = Connection(s1.detach())
+ c2 = Connection(s2.detach())
else:
fd1, fd2 = os.pipe()
- c1 = _multiprocessing.Connection(fd1, writable=False)
- c2 = _multiprocessing.Connection(fd2, readable=False)
+ c1 = Connection(fd1, writable=False)
+ c2 = Connection(fd2, readable=False)
return c1, c2
else:
- from _multiprocessing import win32
-
def Pipe(duplex=True):
'''
Returns pair of connection objects at either end of a pipe
@@ -216,26 +544,25 @@ else:
obsize, ibsize = 0, BUFSIZE
h1 = win32.CreateNamedPipe(
- address, openmode,
+ address, openmode | win32.FILE_FLAG_OVERLAPPED |
+ win32.FILE_FLAG_FIRST_PIPE_INSTANCE,
win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
win32.PIPE_WAIT,
1, obsize, ibsize, win32.NMPWAIT_WAIT_FOREVER, win32.NULL
)
h2 = win32.CreateFile(
- address, access, 0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
+ address, access, 0, win32.NULL, win32.OPEN_EXISTING,
+ win32.FILE_FLAG_OVERLAPPED, win32.NULL
)
win32.SetNamedPipeHandleState(
h2, win32.PIPE_READMODE_MESSAGE, None, None
)
- try:
- win32.ConnectNamedPipe(h1, win32.NULL)
- except WindowsError as e:
- if e.args[0] != win32.ERROR_PIPE_CONNECTED:
- raise
+ overlapped = win32.ConnectNamedPipe(h1, overlapped=True)
+ overlapped.GetOverlappedResult(True)
- c1 = _multiprocessing.PipeConnection(h1, writable=duplex)
- c2 = _multiprocessing.PipeConnection(h2, readable=duplex)
+ c1 = PipeConnection(h1, writable=duplex)
+ c2 = PipeConnection(h2, readable=duplex)
return c1, c2
@@ -250,11 +577,14 @@ class SocketListener(object):
def __init__(self, address, family, backlog=1):
self._socket = socket.socket(getattr(socket, family))
try:
- self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ # SO_REUSEADDR has different semantics on Windows (issue #2550).
+ if os.name == 'posix':
+ self._socket.setsockopt(socket.SOL_SOCKET,
+ socket.SO_REUSEADDR, 1)
self._socket.bind(address)
self._socket.listen(backlog)
self._address = self._socket.getsockname()
- except socket.error:
+ except OSError:
self._socket.close()
raise
self._family = family
@@ -270,7 +600,7 @@ class SocketListener(object):
def accept(self):
s, self._last_accepted = self._socket.accept()
fd = duplicate(s.fileno())
- conn = _multiprocessing.Connection(fd)
+ conn = Connection(fd)
s.close()
return conn
@@ -286,23 +616,9 @@ def SocketClient(address):
'''
family = address_type(address)
with socket.socket( getattr(socket, family) ) as s:
- t = _init_timeout()
-
- while 1:
- try:
- s.connect(address)
- except socket.error as e:
- if e.args[0] != errno.ECONNREFUSED or _check_timeout(t):
- debug('failed to connect to address %s', address)
- raise
- time.sleep(0.01)
- else:
- break
- else:
- raise
-
+ s.connect(address)
fd = duplicate(s.fileno())
- conn = _multiprocessing.Connection(fd)
+ conn = Connection(fd)
return conn
#
@@ -318,7 +634,8 @@ if sys.platform == 'win32':
def __init__(self, address, backlog=None):
self._address = address
handle = win32.CreateNamedPipe(
- address, win32.PIPE_ACCESS_DUPLEX,
+ address, win32.PIPE_ACCESS_DUPLEX |
+ win32.FILE_FLAG_FIRST_PIPE_INSTANCE,
win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
win32.PIPE_WAIT,
win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
@@ -347,9 +664,9 @@ if sys.platform == 'win32':
try:
win32.ConnectNamedPipe(handle, win32.NULL)
except WindowsError as e:
- if e.args[0] != win32.ERROR_PIPE_CONNECTED:
+ if e.winerror != win32.ERROR_PIPE_CONNECTED:
raise
- return _multiprocessing.PipeConnection(handle)
+ return PipeConnection(handle)
@staticmethod
def _finalize_pipe_listener(queue, address):
@@ -370,8 +687,8 @@ if sys.platform == 'win32':
0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
)
except WindowsError as e:
- if e.args[0] not in (win32.ERROR_SEM_TIMEOUT,
- win32.ERROR_PIPE_BUSY) or _check_timeout(t):
+ if e.winerror not in (win32.ERROR_SEM_TIMEOUT,
+ win32.ERROR_PIPE_BUSY) or _check_timeout(t):
raise
else:
break
@@ -381,7 +698,7 @@ if sys.platform == 'win32':
win32.SetNamedPipeHandleState(
h, win32.PIPE_READMODE_MESSAGE, None, None
)
- return _multiprocessing.PipeConnection(h)
+ return PipeConnection(h)
#
# Authentication stuff
@@ -438,10 +755,10 @@ class ConnectionWrapper(object):
return self._loads(s)
def _xml_dumps(obj):
- return xmlrpclib.dumps((obj,), None, None, None, 1).encode('utf8')
+ return xmlrpclib.dumps((obj,), None, None, None, 1).encode('utf-8')
def _xml_loads(s):
- (obj,), method = xmlrpclib.loads(s.decode('utf8'))
+ (obj,), method = xmlrpclib.loads(s.decode('utf-8'))
return obj
class XmlListener(Listener):
@@ -455,3 +772,7 @@ def XmlClient(*args, **kwds):
global xmlrpclib
import xmlrpc.client as xmlrpclib
return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
+
+
+# Late import because of circular import
+from multiprocessing.forking import duplicate, close