summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAndrew Svetlov <andrew.svetlov@gmail.com>2018-09-13 23:53:49 (GMT)
committerGitHub <noreply@github.com>2018-09-13 23:53:49 (GMT)
commit11194c877c902a6c3b769d85be887c2272e0a541 (patch)
tree8181f75217256e9035a177cee53916e4fa6eacab /Lib
parent413118ebf3162418639a5c4af14b02d26571a02c (diff)
downloadcpython-11194c877c902a6c3b769d85be887c2272e0a541.zip
cpython-11194c877c902a6c3b769d85be887c2272e0a541.tar.gz
cpython-11194c877c902a6c3b769d85be887c2272e0a541.tar.bz2
bpo-34666: Implement stream.awrite() and stream.aclose() (GH-9274)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/streams.py10
-rw-r--r--Lib/test/test_asyncio/test_streams.py22
2 files changed, 31 insertions, 1 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index e7fb22e..0afc66a 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -348,7 +348,7 @@ class StreamWriter:
# a reader can be garbage collected
# after connection closing
self._protocol._untrack_reader()
- return self._transport.close()
+ self._transport.close()
def is_closing(self):
return self._transport.is_closing()
@@ -381,6 +381,14 @@ class StreamWriter:
await sleep(0, loop=self._loop)
await self._protocol._drain_helper()
+ async def aclose(self):
+ self.close()
+ await self.wait_closed()
+
+ async def awrite(self, data):
+ self.write(data)
+ await self.drain()
+
class StreamReader:
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 67ac9d9..d8e3715 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -964,6 +964,28 @@ os.close(fd)
'call "stream.close()" explicitly.',
messages[0]['message'])
+ def test_async_writer_api(self):
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ with test_utils.run_test_server() as httpd:
+ rd, wr = self.loop.run_until_complete(
+ asyncio.open_connection(*httpd.address,
+ loop=self.loop))
+
+ f = wr.awrite(b'GET / HTTP/1.0\r\n\r\n')
+ self.loop.run_until_complete(f)
+ f = rd.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = rd.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ f = wr.aclose()
+ self.loop.run_until_complete(f)
+
+ self.assertEqual(messages, [])
+
if __name__ == '__main__':
unittest.main()