diff options
author | Andrew Svetlov <andrew.svetlov@gmail.com> | 2018-09-13 23:53:49 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-13 23:53:49 (GMT) |
commit | 11194c877c902a6c3b769d85be887c2272e0a541 (patch) | |
tree | 8181f75217256e9035a177cee53916e4fa6eacab /Lib | |
parent | 413118ebf3162418639a5c4af14b02d26571a02c (diff) | |
download | cpython-11194c877c902a6c3b769d85be887c2272e0a541.zip cpython-11194c877c902a6c3b769d85be887c2272e0a541.tar.gz cpython-11194c877c902a6c3b769d85be887c2272e0a541.tar.bz2 |
bpo-34666: Implement stream.awrite() and stream.aclose() (GH-9274)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/streams.py | 10 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 22 |
2 files changed, 31 insertions, 1 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index e7fb22e..0afc66a 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -348,7 +348,7 @@ class StreamWriter: # a reader can be garbage collected # after connection closing self._protocol._untrack_reader() - return self._transport.close() + self._transport.close() def is_closing(self): return self._transport.is_closing() @@ -381,6 +381,14 @@ class StreamWriter: await sleep(0, loop=self._loop) await self._protocol._drain_helper() + async def aclose(self): + self.close() + await self.wait_closed() + + async def awrite(self, data): + self.write(data) + await self.drain() + class StreamReader: diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 67ac9d9..d8e3715 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -964,6 +964,28 @@ os.close(fd) 'call "stream.close()" explicitly.', messages[0]['message']) + def test_async_writer_api(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with test_utils.run_test_server() as httpd: + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, + loop=self.loop)) + + f = wr.awrite(b'GET / HTTP/1.0\r\n\r\n') + self.loop.run_until_complete(f) + f = rd.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = rd.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + f = wr.aclose() + self.loop.run_until_complete(f) + + self.assertEqual(messages, []) + if __name__ == '__main__': unittest.main() |