summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/streams.py35
-rw-r--r--Lib/test/test_asyncio/test_streams.py42
2 files changed, 57 insertions, 20 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
index 79adf02..d9a9f5e 100644
--- a/Lib/asyncio/streams.py
+++ b/Lib/asyncio/streams.py
@@ -352,6 +352,8 @@ class StreamWriter:
assert reader is None or isinstance(reader, StreamReader)
self._reader = reader
self._loop = loop
+ self._complete_fut = self._loop.create_future()
+ self._complete_fut.set_result(None)
def __repr__(self):
info = [self.__class__.__name__, f'transport={self._transport!r}']
@@ -365,9 +367,33 @@ class StreamWriter:
def write(self, data):
self._transport.write(data)
+ return self._fast_drain()
def writelines(self, data):
self._transport.writelines(data)
+ return self._fast_drain()
+
+ def _fast_drain(self):
+ # The helper tries to use fast-path to return already existing complete future
+ # object if underlying transport is not paused and actual waiting for writing
+ # resume is not needed
+ if self._reader is not None:
+ # this branch will be simplified after merging reader with writer
+ exc = self._reader.exception()
+ if exc is not None:
+ fut = self._loop.create_future()
+ fut.set_exception(exc)
+ return fut
+ if not self._transport.is_closing():
+ if self._protocol._connection_lost:
+ fut = self._loop.create_future()
+ fut.set_exception(ConnectionResetError('Connection lost'))
+ return fut
+ if not self._protocol._paused:
+ # fast path, the stream is not paused
+ # no need to wait for resume signal
+ return self._complete_fut
+ return self._loop.create_task(self.drain())
def write_eof(self):
return self._transport.write_eof()
@@ -377,6 +403,7 @@ class StreamWriter:
def close(self):
self._transport.close()
+ return self._protocol._get_close_waiter(self)
def is_closing(self):
return self._transport.is_closing()
@@ -408,14 +435,6 @@ class StreamWriter:
raise ConnectionResetError('Connection lost')
await self._protocol._drain_helper()
- async def aclose(self):
- self.close()
- await self.wait_closed()
-
- async def awrite(self, data):
- self.write(data)
- await self.drain()
-
class StreamReader:
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
index 905141c..bf93f30 100644
--- a/Lib/test/test_asyncio/test_streams.py
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -1035,24 +1035,42 @@ os.close(fd)
messages[0]['message'])
def test_async_writer_api(self):
+ async def inner(httpd):
+ rd, wr = await asyncio.open_connection(*httpd.address)
+
+ await wr.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await rd.readline()
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ data = await rd.read()
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+ await wr.close()
+
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
with test_utils.run_test_server() as httpd:
- rd, wr = self.loop.run_until_complete(
- asyncio.open_connection(*httpd.address,
- loop=self.loop))
+ self.loop.run_until_complete(inner(httpd))
- f = wr.awrite(b'GET / HTTP/1.0\r\n\r\n')
- self.loop.run_until_complete(f)
- f = rd.readline()
- data = self.loop.run_until_complete(f)
+ self.assertEqual(messages, [])
+
+ def test_async_writer_api(self):
+ async def inner(httpd):
+ rd, wr = await asyncio.open_connection(*httpd.address)
+
+ await wr.write(b'GET / HTTP/1.0\r\n\r\n')
+ data = await rd.readline()
self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
- f = rd.read()
- data = self.loop.run_until_complete(f)
+ data = await rd.read()
self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
- f = wr.aclose()
- self.loop.run_until_complete(f)
+ wr.close()
+ with self.assertRaises(ConnectionResetError):
+ await wr.write(b'data')
+
+ messages = []
+ self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
+
+ with test_utils.run_test_server() as httpd:
+ self.loop.run_until_complete(inner(httpd))
self.assertEqual(messages, [])
@@ -1066,7 +1084,7 @@ os.close(fd)
asyncio.open_connection(*httpd.address,
loop=self.loop))
- f = wr.aclose()
+ f = wr.close()
self.loop.run_until_complete(f)
assert rd.at_eof()
f = rd.read()