From 15899209770970c0320b47d1ae5749b998152c9f Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 19 Feb 2014 11:10:52 -0500 Subject: asyncio: WriteTransport.set_write_buffer_size to call _maybe_pause_protocol --- Lib/asyncio/transports.py | 8 ++++++-- Lib/test/test_asyncio/test_transports.py | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py index 5b975aa..5f674f9 100644 --- a/Lib/asyncio/transports.py +++ b/Lib/asyncio/transports.py @@ -241,7 +241,7 @@ class _FlowControlMixin(Transport): def __init__(self, extra=None): super().__init__(extra) self._protocol_paused = False - self.set_write_buffer_limits() + self._set_write_buffer_limits() def _maybe_pause_protocol(self): size = self.get_write_buffer_size() @@ -273,7 +273,7 @@ class _FlowControlMixin(Transport): 'protocol': self._protocol, }) - def set_write_buffer_limits(self, high=None, low=None): + def _set_write_buffer_limits(self, high=None, low=None): if high is None: if low is None: high = 64*1024 @@ -287,5 +287,9 @@ class _FlowControlMixin(Transport): self._high_water = high self._low_water = low + def set_write_buffer_limits(self, high=None, low=None): + self._set_write_buffer_limits(high=high, low=low) + self._maybe_pause_protocol() + def get_write_buffer_size(self): raise NotImplementedError diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py index d16db80..4c64526 100644 --- a/Lib/test/test_asyncio/test_transports.py +++ b/Lib/test/test_asyncio/test_transports.py @@ -4,6 +4,7 @@ import unittest import unittest.mock import asyncio +from asyncio import transports class TransportTests(unittest.TestCase): @@ -60,6 +61,28 @@ class TransportTests(unittest.TestCase): self.assertRaises(NotImplementedError, transport.terminate) self.assertRaises(NotImplementedError, transport.kill) + def test_flowcontrol_mixin_set_write_limits(self): + + class MyTransport(transports._FlowControlMixin, + transports.Transport): + + def get_write_buffer_size(self): + return 512 + + transport = MyTransport() + transport._protocol = unittest.mock.Mock() + + self.assertFalse(transport._protocol_paused) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + transport.set_write_buffer_limits(high=0, low=1) + + transport.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(transport._protocol_paused) + + transport.set_write_buffer_limits(high=256, low=128) + self.assertTrue(transport._protocol_paused) + if __name__ == '__main__': unittest.main() -- cgit v0.12