summaryrefslogtreecommitdiffstats
path: root/Lib/socket.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/socket.py')
-rw-r--r--Lib/socket.py395
1 files changed, 94 insertions, 301 deletions
diff --git a/Lib/socket.py b/Lib/socket.py
index 8dd2383..03cdc65 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -54,7 +54,7 @@ try:
except ImportError:
pass
-import os, sys
+import os, sys, io
try:
from errno import EBADF
@@ -66,14 +66,6 @@ __all__.extend(os._get_exports_list(_socket))
if _have_ssl:
__all__.extend(os._get_exports_list(_ssl))
-_realsocket = socket
-if _have_ssl:
- _realssl = ssl
- def ssl(sock, keyfile=None, certfile=None):
- if hasattr(sock, "_sock"):
- sock = sock._sock
- return _realssl(sock, keyfile, certfile)
-
# WSA error codes
if sys.platform.lower().startswith("win"):
errorTab = {}
@@ -95,6 +87,99 @@ if sys.platform.lower().startswith("win"):
__all__.append("errorTab")
+_os_has_dup = hasattr(os, "dup")
+if _os_has_dup:
+ def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
+ nfd = os.dup(fd)
+ return socket(family, type, proto, fileno=nfd)
+
+
+class socket(_socket.socket):
+
+ """A subclass of _socket.socket adding the makefile() method."""
+
+ __slots__ = ["__weakref__"]
+ if not _os_has_dup:
+ __slots__.append("_base")
+
+ def __repr__(self):
+ """Wrap __repr__() to reveal the real class name."""
+ s = _socket.socket.__repr__(self)
+ if s.startswith("<socket object"):
+ s = "<%s.%s%s" % (self.__class__.__module__,
+ self.__class__.__name__,
+ s[7:])
+ return s
+
+ def accept(self):
+ """Wrap accept() to give the connection the right type."""
+ conn, addr = _socket.socket.accept(self)
+ fd = conn.fileno()
+ nfd = fd
+ if _os_has_dup:
+ nfd = os.dup(fd)
+ wrapper = socket(self.family, self.type, self.proto, fileno=nfd)
+ if fd == nfd:
+ wrapper._base = conn # Keep the base alive
+ else:
+ conn.close()
+ return wrapper, addr
+
+ if not _os_has_dup:
+ def close(self):
+ """Wrap close() to close the _base as well."""
+ _socket.socket.close(self)
+ base = getattr(self, "_base", None)
+ if base is not None:
+ base.close()
+
+ def makefile(self, mode="r", buffering=None, *,
+ encoding=None, newline=None):
+ """Return an I/O stream connected to the socket.
+
+ The arguments are as for io.open() after the filename,
+ except the only mode characters supported are 'r', 'w' and 'b'.
+ The semantics are similar too. (XXX refactor to share code?)
+ """
+ for c in mode:
+ if c not in {"r", "w", "b"}:
+ raise ValueError("invalid mode %r (only r, w, b allowed)")
+ writing = "w" in mode
+ reading = "r" in mode or not writing
+ assert reading or writing
+ binary = "b" in mode
+ rawmode = ""
+ if reading:
+ rawmode += "r"
+ if writing:
+ rawmode += "w"
+ raw = io.SocketIO(self, rawmode)
+ if buffering is None:
+ buffering = -1
+ if buffering < 0:
+ buffering = io.DEFAULT_BUFFER_SIZE
+ if buffering == 0:
+ if not binary:
+ raise ValueError("unbuffered streams must be binary")
+ raw.name = self.fileno()
+ raw.mode = mode
+ return raw
+ if reading and writing:
+ buffer = io.BufferedRWPair(raw, raw, buffering)
+ elif reading:
+ buffer = io.BufferedReader(raw, buffering)
+ else:
+ assert writing
+ buffer = io.BufferedWriter(raw, buffering)
+ if binary:
+ buffer.name = self.fileno()
+ buffer.mode = mode
+ return buffer
+ text = io.TextIOWrapper(buffer, encoding, newline)
+ text.name = self.fileno()
+ self.mode = mode
+ return text
+
def getfqdn(name=''):
"""Get fully qualified domain name from name.
@@ -122,298 +207,6 @@ def getfqdn(name=''):
return name
-_socketmethods = (
- 'bind', 'connect', 'connect_ex', 'fileno', 'listen',
- 'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
- 'sendall', 'setblocking',
- 'settimeout', 'gettimeout', 'shutdown')
-
-if sys.platform == "riscos":
- _socketmethods = _socketmethods + ('sleeptaskw',)
-
-# All the method names that must be delegated to either the real socket
-# object or the _closedsocket object.
-_delegate_methods = ("recv", "recvfrom", "recv_into", "recvfrom_into",
- "send", "sendto")
-
-class _closedsocket(object):
- __slots__ = []
- def _dummy(*args):
- raise error(EBADF, 'Bad file descriptor')
- # All _delegate_methods must also be initialized here.
- send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
- __getattr__ = _dummy
-
-class _socketobject(object):
-
- __doc__ = _realsocket.__doc__
-
- __slots__ = ["_sock", "__weakref__"] + list(_delegate_methods)
-
- def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
- if _sock is None:
- _sock = _realsocket(family, type, proto)
- self._sock = _sock
- for method in _delegate_methods:
- setattr(self, method, getattr(_sock, method))
-
- def close(self):
- self._sock = _closedsocket()
- dummy = self._sock._dummy
- for method in _delegate_methods:
- setattr(self, method, dummy)
- close.__doc__ = _realsocket.close.__doc__
-
- def accept(self):
- sock, addr = self._sock.accept()
- return _socketobject(_sock=sock), addr
- accept.__doc__ = _realsocket.accept.__doc__
-
- def dup(self):
- """dup() -> socket object
-
- Return a new socket object connected to the same system resource."""
- return _socketobject(_sock=self._sock)
-
- def makefile(self, mode='r', bufsize=-1):
- """makefile([mode[, bufsize]]) -> file object
-
- Return a regular file object corresponding to the socket. The mode
- and bufsize arguments are as for the built-in open() function."""
- return _fileobject(self._sock, mode, bufsize)
-
- family = property(lambda self: self._sock.family, doc="the socket family")
- type = property(lambda self: self._sock.type, doc="the socket type")
- proto = property(lambda self: self._sock.proto, doc="the socket protocol")
-
- _s = ("def %s(self, *args): return self._sock.%s(*args)\n\n"
- "%s.__doc__ = _realsocket.%s.__doc__\n")
- for _m in _socketmethods:
- exec(_s % (_m, _m, _m, _m))
- del _m, _s
-
-socket = SocketType = _socketobject
-
-class _fileobject(object):
- """Faux file object attached to a socket object."""
-
- default_bufsize = 8192
- name = "<socket>"
-
- __slots__ = ["mode", "bufsize",
- # "closed" is a property, see below
- "_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
- "_close"]
-
- def __init__(self, sock, mode='rb', bufsize=-1, close=False):
- self._sock = sock
- self.mode = mode # Not actually used in this version
- if bufsize < 0:
- bufsize = self.default_bufsize
- self.bufsize = bufsize
- if bufsize == 0:
- self._rbufsize = 1
- elif bufsize == 1:
- self._rbufsize = self.default_bufsize
- else:
- self._rbufsize = bufsize
- self._wbufsize = bufsize
- self._rbuf = "" # A string
- self._wbuf = [] # A list of strings
- self._close = close
-
- def _getclosed(self):
- return self._sock is None
- closed = property(_getclosed, doc="True if the file is closed")
-
- def close(self):
- try:
- if self._sock:
- self.flush()
- finally:
- if self._close:
- self._sock.close()
- self._sock = None
-
- def __del__(self):
- try:
- self.close()
- except:
- # close() may fail if __init__ didn't complete
- pass
-
- def flush(self):
- if self._wbuf:
- buffer = "".join(self._wbuf)
- self._wbuf = []
- self._sock.sendall(buffer)
-
- def fileno(self):
- return self._sock.fileno()
-
- def write(self, data):
- data = str(data) # XXX Should really reject non-string non-buffers
- if not data:
- return
- self._wbuf.append(data)
- if (self._wbufsize == 0 or
- self._wbufsize == 1 and '\n' in data or
- self._get_wbuf_len() >= self._wbufsize):
- self.flush()
-
- def writelines(self, list):
- # XXX We could do better here for very long lists
- # XXX Should really reject non-string non-buffers
- self._wbuf.extend(filter(None, map(str, list)))
- if (self._wbufsize <= 1 or
- self._get_wbuf_len() >= self._wbufsize):
- self.flush()
-
- def _get_wbuf_len(self):
- buf_len = 0
- for x in self._wbuf:
- buf_len += len(x)
- return buf_len
-
- def read(self, size=-1):
- data = self._rbuf
- if size < 0:
- # Read until EOF
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- if self._rbufsize <= 1:
- recv_size = self.default_bufsize
- else:
- recv_size = self._rbufsize
- while True:
- data = self._sock.recv(recv_size)
- if not data:
- break
- buffers.append(data)
- return "".join(buffers)
- else:
- # Read until size bytes or EOF seen, whichever comes first
- buf_len = len(data)
- if buf_len >= size:
- self._rbuf = data[size:]
- return data[:size]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- left = size - buf_len
- recv_size = max(self._rbufsize, left)
- data = self._sock.recv(recv_size)
- if not data:
- break
- buffers.append(data)
- n = len(data)
- if n >= left:
- self._rbuf = data[left:]
- buffers[-1] = data[:left]
- break
- buf_len += n
- return "".join(buffers)
-
- def readline(self, size=-1):
- data = self._rbuf
- if size < 0:
- # Read until \n or EOF, whichever comes first
- if self._rbufsize <= 1:
- # Speed up unbuffered case
- assert data == ""
- buffers = []
- recv = self._sock.recv
- while data != "\n":
- data = recv(1)
- if not data:
- break
- buffers.append(data)
- return "".join(buffers)
- nl = data.find('\n')
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- return data[:nl]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- data = self._sock.recv(self._rbufsize)
- if not data:
- break
- buffers.append(data)
- nl = data.find('\n')
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- buffers[-1] = data[:nl]
- break
- return "".join(buffers)
- else:
- # Read until size bytes or \n or EOF seen, whichever comes first
- nl = data.find('\n', 0, size)
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- return data[:nl]
- buf_len = len(data)
- if buf_len >= size:
- self._rbuf = data[size:]
- return data[:size]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- data = self._sock.recv(self._rbufsize)
- if not data:
- break
- buffers.append(data)
- left = size - buf_len
- nl = data.find('\n', 0, left)
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- buffers[-1] = data[:nl]
- break
- n = len(data)
- if n >= left:
- self._rbuf = data[left:]
- buffers[-1] = data[:left]
- break
- buf_len += n
- return "".join(buffers)
-
- def readlines(self, sizehint=0):
- total = 0
- list = []
- while True:
- line = self.readline()
- if not line:
- break
- list.append(line)
- total += len(line)
- if sizehint and total >= sizehint:
- break
- return list
-
- # Iterator protocols
-
- def __iter__(self):
- return self
-
- def __next__(self):
- line = self.readline()
- if not line:
- raise StopIteration
- return line
-
-
def create_connection(address, timeout=None):
"""Connect to address (host, port) with an optional timeout.