summaryrefslogtreecommitdiffstats
path: root/Lib/socket.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/socket.py')
-rw-r--r--Lib/socket.py120
1 files changed, 109 insertions, 11 deletions
diff --git a/Lib/socket.py b/Lib/socket.py
index 8d3508a..1b3920a 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -89,22 +89,67 @@ if sys.platform.lower().startswith("win"):
# True if os.dup() can duplicate socket descriptors.
# (On Windows at least, os.dup only works on files)
-_can_dup_socket = hasattr(_socket, "dup")
+_can_dup_socket = hasattr(_socket.socket, "dup")
if _can_dup_socket:
def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
nfd = os.dup(fd)
return socket(family, type, proto, fileno=nfd)
+class SocketCloser:
+
+ """Helper to manage socket close() logic for makefile().
+
+ The OS socket should not be closed until the socket and all
+ of its makefile-children are closed. If the refcount is zero
+ when socket.close() is called, this is easy: Just close the
+ socket. If the refcount is non-zero when socket.close() is
+ called, then the real close should not occur until the last
+ makefile-child is closed.
+ """
+
+ def __init__(self, sock):
+ self._sock = sock
+ self._makefile_refs = 0
+ # Test whether the socket is open.
+ try:
+ sock.fileno()
+ self._socket_open = True
+ except error:
+ self._socket_open = False
+
+ def socket_close(self):
+ self._socket_open = False
+ self.close()
+
+ def makefile_open(self):
+ self._makefile_refs += 1
+
+ def makefile_close(self):
+ self._makefile_refs -= 1
+ self.close()
+
+ def close(self):
+ if not (self._socket_open or self._makefile_refs):
+ self._sock._real_close()
+
class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method."""
- __slots__ = ["__weakref__"]
+ __slots__ = ["__weakref__", "_closer"]
if not _can_dup_socket:
__slots__.append("_base")
+ def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
+ if fileno is None:
+ _socket.socket.__init__(self, family, type, proto)
+ else:
+ _socket.socket.__init__(self, family, type, proto, fileno)
+ # Defer creating a SocketCloser until makefile() is actually called.
+ self._closer = None
+
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
@@ -128,14 +173,6 @@ class socket(_socket.socket):
conn.close()
return wrapper, addr
- if not _can_dup_socket:
- def close(self):
- """Wrap close() to close the _base as well."""
- _socket.socket.close(self)
- base = getattr(self, "_base", None)
- if base is not None:
- base.close()
-
def makefile(self, mode="r", buffering=None, *,
encoding=None, newline=None):
"""Return an I/O stream connected to the socket.
@@ -156,7 +193,9 @@ class socket(_socket.socket):
rawmode += "r"
if writing:
rawmode += "w"
- raw = io.SocketIO(self, rawmode)
+ if self._closer is None:
+ self._closer = SocketCloser(self)
+ raw = SocketIO(self, rawmode, self._closer)
if buffering is None:
buffering = -1
if buffering < 0:
@@ -183,6 +222,65 @@ class socket(_socket.socket):
text.mode = mode
return text
+ def close(self):
+ if self._closer is None:
+ self._real_close()
+ else:
+ self._closer.socket_close()
+
+ # _real_close calls close on the _socket.socket base class.
+
+ if not _can_dup_socket:
+ def _real_close(self):
+ _socket.socket.close(self)
+ base = getattr(self, "_base", None)
+ if base is not None:
+ self._base = None
+ base.close()
+ else:
+ def _real_close(self):
+ _socket.socket.close(self)
+
+
+class SocketIO(io.RawIOBase):
+
+ """Raw I/O implementation for stream sockets.
+
+ This class supports the makefile() method on sockets. It provides
+ the raw I/O interface on top of a socket object.
+ """
+
+ # XXX More docs
+
+ def __init__(self, sock, mode, closer):
+ assert mode in ("r", "w", "rw")
+ io.RawIOBase.__init__(self)
+ self._sock = sock
+ self._mode = mode
+ self._closer = closer
+ closer.makefile_open()
+
+ def readinto(self, b):
+ return self._sock.recv_into(b)
+
+ def write(self, b):
+ return self._sock.send(b)
+
+ def readable(self):
+ return "r" in self._mode
+
+ def writable(self):
+ return "w" in self._mode
+
+ def fileno(self):
+ return self._sock.fileno()
+
+ def close(self):
+ if self.closed:
+ return
+ self._closer.makefile_close()
+ io.RawIOBase.close(self)
+
def getfqdn(name=''):
"""Get fully qualified domain name from name.