summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/socket.rst3
-rw-r--r--Doc/whatsnew/3.2.rst6
-rw-r--r--Lib/socket.py7
-rw-r--r--Lib/test/test_socket.py44
4 files changed, 60 insertions, 0 deletions
diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst
index 581756f..a7656c1 100644
--- a/Doc/library/socket.rst
+++ b/Doc/library/socket.rst
@@ -213,6 +213,9 @@ The module :mod:`socket` exports the following constants and functions:
.. versionchanged:: 3.2
*source_address* was added.
+ .. versionchanged:: 3.2
+ support for the :keyword:`with` statement was added.
+
.. function:: getaddrinfo(host, port, family=0, type=0, proto=0, flags=0)
diff --git a/Doc/whatsnew/3.2.rst b/Doc/whatsnew/3.2.rst
index 4969623..7d8970b 100644
--- a/Doc/whatsnew/3.2.rst
+++ b/Doc/whatsnew/3.2.rst
@@ -389,6 +389,12 @@ New, Improved, and Deprecated Modules
(Contributed by Giampaolo RodolĂ ; :issue:`8807`.)
+* :func:`socket.create_connection` now supports the context manager protocol
+ to unconditionally consume :exc:`socket.error` exceptions and to close the
+ socket when done.
+
+ (Contributed by Giampaolo RodolĂ ; :issue:`9794`.)
+
Multi-threading
===============
diff --git a/Lib/socket.py b/Lib/socket.py
index 004d6a9..bfc9a726 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -93,6 +93,13 @@ class socket(_socket.socket):
self._io_refs = 0
self._closed = False
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ if not self._closed:
+ self.close()
+
def __repr__(self):
"""Wrap __repr__() to reveal the real class name."""
s = _socket.socket.__repr__(self)
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 19c494b..81f9cdf 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -1595,6 +1595,49 @@ class TIPCThreadableTest (unittest.TestCase, ThreadableTest):
self.cli.close()
+@unittest.skipUnless(thread, 'Threading required for this test.')
+class ContextManagersTest(ThreadedTCPSocketTest):
+
+ def _testSocketClass(self):
+ # base test
+ with socket.socket() as sock:
+ self.assertFalse(sock._closed)
+ self.assertTrue(sock._closed)
+ # close inside with block
+ with socket.socket() as sock:
+ sock.close()
+ self.assertTrue(sock._closed)
+ # exception inside with block
+ with socket.socket() as sock:
+ self.assertRaises(socket.error, sock.sendall, b'foo')
+ self.assertTrue(sock._closed)
+
+ def testCreateConnectionBase(self):
+ conn, addr = self.serv.accept()
+ data = conn.recv(1024)
+ conn.sendall(data)
+
+ def _testCreateConnectionBase(self):
+ address = self.serv.getsockname()
+ with socket.create_connection(address) as sock:
+ self.assertFalse(sock._closed)
+ sock.sendall(b'foo')
+ self.assertEqual(sock.recv(1024), b'foo')
+ self.assertTrue(sock._closed)
+
+ def testCreateConnectionClose(self):
+ conn, addr = self.serv.accept()
+ data = conn.recv(1024)
+ conn.sendall(data)
+
+ def _testCreateConnectionClose(self):
+ address = self.serv.getsockname()
+ with socket.create_connection(address) as sock:
+ sock.close()
+ self.assertTrue(sock._closed)
+ self.assertRaises(socket.error, sock.sendall, b'foo')
+
+
def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
@@ -1609,6 +1652,7 @@ def test_main():
NetworkConnectionNoServer,
NetworkConnectionAttributesTest,
NetworkConnectionBehaviourTest,
+ ContextManagersTest,
])
if hasattr(socket, "socketpair"):
tests.append(BasicSocketPairTest)