diff options
author | Kumar Aditya <59607654+kumaraditya303@users.noreply.github.com> | 2022-08-29 18:31:11 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-08-29 18:31:11 (GMT) |
commit | e5b2453e61ba5376831093236d598ef5f9f1de61 (patch) | |
tree | a4ad8da010853f287dbb68a31eac23384be25269 /Lib | |
parent | 3d180e3ab21c5d41d1c46e3ef349b30ba409f300 (diff) | |
download | cpython-e5b2453e61ba5376831093236d598ef5f9f1de61.zip cpython-e5b2453e61ba5376831093236d598ef5f9f1de61.tar.gz cpython-e5b2453e61ba5376831093236d598ef5f9f1de61.tar.bz2 |
GH-74116: Allow multiple drain waiters for asyncio.StreamWriter (GH-94705)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/streams.py | 35 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 19 |
2 files changed, 35 insertions, 19 deletions
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 614b2cd..c4d837a 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -2,6 +2,7 @@ __all__ = ( 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'open_connection', 'start_server') +import collections import socket import sys import weakref @@ -128,7 +129,7 @@ class FlowControlMixin(protocols.Protocol): else: self._loop = loop self._paused = False - self._drain_waiter = None + self._drain_waiters = collections.deque() self._connection_lost = False def pause_writing(self): @@ -143,38 +144,34 @@ class FlowControlMixin(protocols.Protocol): if self._loop.get_debug(): logger.debug("%r resumes writing", self) - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None + for waiter in self._drain_waiters: if not waiter.done(): waiter.set_result(None) def connection_lost(self, exc): self._connection_lost = True - # Wake up the writer if currently paused. + # Wake up the writer(s) if currently paused. if not self._paused: return - waiter = self._drain_waiter - if waiter is None: - return - self._drain_waiter = None - if waiter.done(): - return - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) + + for waiter in self._drain_waiters: + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) async def _drain_helper(self): if self._connection_lost: raise ConnectionResetError('Connection lost') if not self._paused: return - waiter = self._drain_waiter - assert waiter is None or waiter.cancelled() waiter = self._loop.create_future() - self._drain_waiter = waiter - await waiter + self._drain_waiters.append(waiter) + try: + await waiter + finally: + self._drain_waiters.remove(waiter) def _get_close_waiter(self, stream): raise NotImplementedError diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 098a0da..0c49099 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -864,6 +864,25 @@ os.close(fd) self.assertEqual(cm.filename, __file__) self.assertIs(protocol._loop, self.loop) + def test_multiple_drain(self): + # See https://github.com/python/cpython/issues/74116 + drained = 0 + + async def drainer(stream): + nonlocal drained + await stream._drain_helper() + drained += 1 + + async def main(): + loop = asyncio.get_running_loop() + stream = asyncio.streams.FlowControlMixin(loop) + stream.pause_writing() + loop.call_later(0.1, stream.resume_writing) + await asyncio.gather(*[drainer(stream) for _ in range(10)]) + self.assertEqual(drained, 10) + + self.loop.run_until_complete(main()) + def test_drain_raises(self): # See http://bugs.python.org/issue25441 |