diff options
Diffstat (limited to 'Lib/socket.py')
-rw-r--r-- | Lib/socket.py | 120 |
1 files changed, 109 insertions, 11 deletions
diff --git a/Lib/socket.py b/Lib/socket.py index 8d3508a..1b3920a 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -89,22 +89,67 @@ if sys.platform.lower().startswith("win"): # True if os.dup() can duplicate socket descriptors. # (On Windows at least, os.dup only works on files) -_can_dup_socket = hasattr(_socket, "dup") +_can_dup_socket = hasattr(_socket.socket, "dup") if _can_dup_socket: def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0): nfd = os.dup(fd) return socket(family, type, proto, fileno=nfd) +class SocketCloser: + + """Helper to manage socket close() logic for makefile(). + + The OS socket should not be closed until the socket and all + of its makefile-children are closed. If the refcount is zero + when socket.close() is called, this is easy: Just close the + socket. If the refcount is non-zero when socket.close() is + called, then the real close should not occur until the last + makefile-child is closed. + """ + + def __init__(self, sock): + self._sock = sock + self._makefile_refs = 0 + # Test whether the socket is open. + try: + sock.fileno() + self._socket_open = True + except error: + self._socket_open = False + + def socket_close(self): + self._socket_open = False + self.close() + + def makefile_open(self): + self._makefile_refs += 1 + + def makefile_close(self): + self._makefile_refs -= 1 + self.close() + + def close(self): + if not (self._socket_open or self._makefile_refs): + self._sock._real_close() + class socket(_socket.socket): """A subclass of _socket.socket adding the makefile() method.""" - __slots__ = ["__weakref__"] + __slots__ = ["__weakref__", "_closer"] if not _can_dup_socket: __slots__.append("_base") + def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None): + if fileno is None: + _socket.socket.__init__(self, family, type, proto) + else: + _socket.socket.__init__(self, family, type, proto, fileno) + # Defer creating a SocketCloser until makefile() is actually called. + self._closer = None + def __repr__(self): """Wrap __repr__() to reveal the real class name.""" s = _socket.socket.__repr__(self) @@ -128,14 +173,6 @@ class socket(_socket.socket): conn.close() return wrapper, addr - if not _can_dup_socket: - def close(self): - """Wrap close() to close the _base as well.""" - _socket.socket.close(self) - base = getattr(self, "_base", None) - if base is not None: - base.close() - def makefile(self, mode="r", buffering=None, *, encoding=None, newline=None): """Return an I/O stream connected to the socket. @@ -156,7 +193,9 @@ class socket(_socket.socket): rawmode += "r" if writing: rawmode += "w" - raw = io.SocketIO(self, rawmode) + if self._closer is None: + self._closer = SocketCloser(self) + raw = SocketIO(self, rawmode, self._closer) if buffering is None: buffering = -1 if buffering < 0: @@ -183,6 +222,65 @@ class socket(_socket.socket): text.mode = mode return text + def close(self): + if self._closer is None: + self._real_close() + else: + self._closer.socket_close() + + # _real_close calls close on the _socket.socket base class. + + if not _can_dup_socket: + def _real_close(self): + _socket.socket.close(self) + base = getattr(self, "_base", None) + if base is not None: + self._base = None + base.close() + else: + def _real_close(self): + _socket.socket.close(self) + + +class SocketIO(io.RawIOBase): + + """Raw I/O implementation for stream sockets. + + This class supports the makefile() method on sockets. It provides + the raw I/O interface on top of a socket object. + """ + + # XXX More docs + + def __init__(self, sock, mode, closer): + assert mode in ("r", "w", "rw") + io.RawIOBase.__init__(self) + self._sock = sock + self._mode = mode + self._closer = closer + closer.makefile_open() + + def readinto(self, b): + return self._sock.recv_into(b) + + def write(self, b): + return self._sock.send(b) + + def readable(self): + return "r" in self._mode + + def writable(self): + return "w" in self._mode + + def fileno(self): + return self._sock.fileno() + + def close(self): + if self.closed: + return + self._closer.makefile_close() + io.RawIOBase.close(self) + def getfqdn(name=''): """Get fully qualified domain name from name. |