summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/io.py28
-rw-r--r--Lib/socket.py120
-rw-r--r--Lib/test/test_socket.py39
3 files changed, 148 insertions, 39 deletions
diff --git a/Lib/io.py b/Lib/io.py
index 9a4f956..43695be 100644
--- a/Lib/io.py
+++ b/Lib/io.py
@@ -442,34 +442,6 @@ class FileIO(_fileio._FileIO, RawIOBase):
return self._mode
-class SocketIO(RawIOBase):
-
- """Raw I/O implementation for stream sockets."""
-
- # XXX More docs
-
- def __init__(self, sock, mode):
- assert mode in ("r", "w", "rw")
- RawIOBase.__init__(self)
- self._sock = sock
- self._mode = mode
-
- def readinto(self, b):
- return self._sock.recv_into(b)
-
- def write(self, b):
- return self._sock.send(b)
-
- def readable(self):
- return "r" in self._mode
-
- def writable(self):
- return "w" in self._mode
-
- def fileno(self):
- return self._sock.fileno()
-
-
class BufferedIOBase(IOBase):
"""Base class for buffered IO objects.
diff --git a/Lib/socket.py b/Lib/socket.py
index 8d3508a..1b3920a 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -89,22 +89,67 @@ if sys.platform.lower().startswith("win"):
# True if os.dup() can duplicate socket descriptors.
# (On Windows at least, os.dup only works on files)
-_can_dup_socket = hasattr(_socket, "dup")
+_can_dup_socket = hasattr(_socket.socket, "dup")
if _can_dup_socket:
def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
nfd = os.dup(fd)
return socket(family, type, proto, fileno=nfd)
+class SocketCloser:
+
+ """Helper to manage socket close() logic for makefile().
+
+ The OS socket should not be closed until the socket and all
+ of its makefile-children are closed. If the refcount is zero
+ when socket.close() is called, this is easy: Just close the
+ socket. If the refcount is non-zero when socket.close() is
+ called, then the real close should not occur until the last
+ makefile-child is closed.
+ """
+
+ def __init__(self, sock):
+ self._sock = sock
+ self._makefile_refs = 0
+ # Test whether the socket is open.
+ try:
+ sock.fileno()
+ self._socket_open = True
+ except error:
+ self._socket_open = False
+
+ def socket_close(self):
+ self._socket_open = False
+ self.close()
+
+ def makefile_open(self):
+ self._makefile_refs += 1
+
+ def makefile_close(self):
+ self._makefile_refs -= 1
+ self.close()
+
+ def close(self):
+ if not (self._socket_open or self._makefile_refs):
+ self._sock._real_close()
+
class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method."""
- __slots__ = ["__weakref__"]
+ __slots__ = ["__weakref__", "_closer"]
if not _can_dup_socket:
__slots__.append("_base")
+ def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None):
+ if fileno is None:
+ _socket.socket.__init__(self, family, type, proto)
+ else:
+ _socket.socket.__init__(self, family, type, proto, fileno)
+ # Defer creating a SocketCloser until makefile() is actually called.
+ self._closer = None
+
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
@@ -128,14 +173,6 @@ class socket(_socket.socket):
conn.close()
return wrapper, addr
- if not _can_dup_socket:
- def close(self):
- """Wrap close() to close the _base as well."""
- _socket.socket.close(self)
- base = getattr(self, "_base", None)
- if base is not None:
- base.close()
-
def makefile(self, mode="r", buffering=None, *,
encoding=None, newline=None):
"""Return an I/O stream connected to the socket.
@@ -156,7 +193,9 @@ class socket(_socket.socket):
rawmode += "r"
if writing:
rawmode += "w"
- raw = io.SocketIO(self, rawmode)
+ if self._closer is None:
+ self._closer = SocketCloser(self)
+ raw = SocketIO(self, rawmode, self._closer)
if buffering is None:
buffering = -1
if buffering < 0:
@@ -183,6 +222,65 @@ class socket(_socket.socket):
text.mode = mode
return text
+ def close(self):
+ if self._closer is None:
+ self._real_close()
+ else:
+ self._closer.socket_close()
+
+ # _real_close calls close on the _socket.socket base class.
+
+ if not _can_dup_socket:
+ def _real_close(self):
+ _socket.socket.close(self)
+ base = getattr(self, "_base", None)
+ if base is not None:
+ self._base = None
+ base.close()
+ else:
+ def _real_close(self):
+ _socket.socket.close(self)
+
+
+class SocketIO(io.RawIOBase):
+
+ """Raw I/O implementation for stream sockets.
+
+ This class supports the makefile() method on sockets. It provides
+ the raw I/O interface on top of a socket object.
+ """
+
+ # XXX More docs
+
+ def __init__(self, sock, mode, closer):
+ assert mode in ("r", "w", "rw")
+ io.RawIOBase.__init__(self)
+ self._sock = sock
+ self._mode = mode
+ self._closer = closer
+ closer.makefile_open()
+
+ def readinto(self, b):
+ return self._sock.recv_into(b)
+
+ def write(self, b):
+ return self._sock.send(b)
+
+ def readable(self):
+ return "r" in self._mode
+
+ def writable(self):
+ return "w" in self._mode
+
+ def fileno(self):
+ return self._sock.fileno()
+
+ def close(self):
+ if self.closed:
+ return
+ self._closer.makefile_close()
+ io.RawIOBase.close(self)
+
def getfqdn(name=''):
"""Get fully qualified domain name from name.
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index f2b74ee..a8b65c4 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -163,6 +163,11 @@ class ThreadedUDPSocketTest(SocketUDPTest, ThreadableTest):
self.cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
class SocketConnectedTest(ThreadedTCPSocketTest):
+ """Socket tests for client-server connection.
+
+ self.cli_conn is a client socket connected to the server. The
+ setUp() method guarantees that it is connected to the server.
+ """
def __init__(self, methodName='runTest'):
ThreadedTCPSocketTest.__init__(self, methodName=methodName)
@@ -618,6 +623,10 @@ class TCPCloserTest(ThreadedTCPSocketTest):
self.assertEqual(read, [sd])
self.assertEqual(sd.recv(1), b'')
+ # Calling close() many times should be safe.
+ conn.close()
+ conn.close()
+
def _testClose(self):
self.cli.connect((HOST, PORT))
time.sleep(1.0)
@@ -710,6 +719,16 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
self.cli.send(MSG)
class FileObjectClassTestCase(SocketConnectedTest):
+ """Unit tests for the object returned by socket.makefile()
+
+ self.serv_file is the io object returned by makefile() on
+ the client connection. You can read from this file to
+ get output from the server.
+
+ self.cli_file is the io object returned by makefile() on the
+ server connection. You can write to this file to send output
+ to the client.
+ """
bufsize = -1 # Use default buffer size
@@ -779,6 +798,26 @@ class FileObjectClassTestCase(SocketConnectedTest):
self.cli_file.write(MSG)
self.cli_file.flush()
+ def testCloseAfterMakefile(self):
+ # The file returned by makefile should keep the socket open.
+ self.cli_conn.close()
+ # read until EOF
+ msg = self.serv_file.read()
+ self.assertEqual(msg, MSG)
+
+ def _testCloseAfterMakefile(self):
+ self.cli_file.write(MSG)
+ self.cli_file.flush()
+
+ def testMakefileAfterMakefileClose(self):
+ self.serv_file.close()
+ msg = self.cli_conn.recv(len(MSG))
+ self.assertEqual(msg, MSG)
+
+ def _testMakefileAfterMakefileClose(self):
+ self.cli_file.write(MSG)
+ self.cli_file.flush()
+
def testClosedAttr(self):
self.assert_(not self.serv_file.closed)