diff options
-rw-r--r-- | Doc/library/asyncio-stream.rst | 15 | ||||
-rw-r--r-- | Lib/asyncio/streams.py | 36 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 31 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2018-01-24-15-20-12.bpo-32391.0f8MY9.rst | 1 |
4 files changed, 73 insertions, 10 deletions
diff --git a/Doc/library/asyncio-stream.rst b/Doc/library/asyncio-stream.rst index 6d5cbbc..099b59e 100644 --- a/Doc/library/asyncio-stream.rst +++ b/Doc/library/asyncio-stream.rst @@ -201,6 +201,21 @@ StreamWriter Close the transport: see :meth:`BaseTransport.close`. + .. method:: is_closing() + + Return ``True`` if the writer is closing or is closed. + + .. versionadded:: 3.7 + + .. coroutinemethod:: wait_closed() + + Wait until the writer is closed. + + Should be called after :meth:`close` to wait until the underlying + connection (and the associated transport/protocol pair) is closed. + + .. versionadded:: 3.7 + .. coroutinemethod:: drain() Let the write buffer of the underlying transport a chance to be flushed. diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index eef2b89..9a53ee4 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -224,6 +224,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): self._stream_writer = None self._client_connected_cb = client_connected_cb self._over_ssl = False + self._closed = self._loop.create_future() def connection_made(self, transport): self._stream_reader.set_transport(transport) @@ -243,6 +244,11 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): self._stream_reader.feed_eof() else: self._stream_reader.set_exception(exc) + if not self._closed.done(): + if exc is None: + self._closed.set_result(None) + else: + self._closed.set_exception(exc) super().connection_lost(exc) self._stream_reader = None self._stream_writer = None @@ -259,6 +265,13 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): return False return True + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._closed + if closed.done() and not closed.cancelled(): + closed.exception() + class StreamWriter: """Wraps a Transport. @@ -303,6 +316,12 @@ class StreamWriter: def close(self): return self._transport.close() + def is_closing(self): + return self._transport.is_closing() + + async def wait_closed(self): + await self._protocol._closed + def get_extra_info(self, name, default=None): return self._transport.get_extra_info(name, default) @@ -318,15 +337,14 @@ class StreamWriter: exc = self._reader.exception() if exc is not None: raise exc - if self._transport is not None: - if self._transport.is_closing(): - # Yield to the event loop so connection_lost() may be - # called. Without this, _drain_helper() would return - # immediately, and code that calls - # write(...); await drain() - # in a loop would never call connection_lost(), so it - # would not see an error when the socket is closed. - await sleep(0, loop=self._loop) + if self._transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); await drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await sleep(0, loop=self._loop) await self._protocol._drain_helper() diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 7a0762c..63fa13f 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -19,7 +19,7 @@ import asyncio from test.test_asyncio import utils as test_utils -class StreamReaderTests(test_utils.TestCase): +class StreamTests(test_utils.TestCase): DATA = b'line1\nline2\nline3\n' @@ -860,6 +860,35 @@ os.close(fd) self.assertEqual(str(e), str(e2)) self.assertEqual(e.consumed, e2.consumed) + def test_wait_closed_on_close(self): + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + f = rd.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = rd.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertFalse(wr.is_closing()) + wr.close() + self.assertTrue(wr.is_closing()) + self.loop.run_until_complete(wr.wait_closed()) + + def test_wait_closed_on_close_with_unread_data(self): + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) + + wr.write(b'GET / HTTP/1.0\r\n\r\n') + f = rd.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + wr.close() + self.loop.run_until_complete(wr.wait_closed()) + if __name__ == '__main__': unittest.main() diff --git a/Misc/NEWS.d/next/Library/2018-01-24-15-20-12.bpo-32391.0f8MY9.rst b/Misc/NEWS.d/next/Library/2018-01-24-15-20-12.bpo-32391.0f8MY9.rst new file mode 100644 index 0000000..6e09227 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-01-24-15-20-12.bpo-32391.0f8MY9.rst @@ -0,0 +1 @@ +Implement :meth:`asyncio.StreamWriter.wait_closed` and :meth:`asyncio.StreamWriter.is_closing` methods |