diff options
Diffstat (limited to 'Lib/test/test_asyncio/test_streams.py')
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 361 |
1 files changed, 361 insertions, 0 deletions
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py new file mode 100644 index 0000000..011a09d --- /dev/null +++ b/Lib/test/test_asyncio/test_streams.py @@ -0,0 +1,361 @@ +"""Tests for streams.py.""" + +import gc +import ssl +import unittest +import unittest.mock + +from asyncio import events +from asyncio import streams +from asyncio import tasks +from asyncio import test_utils + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + @unittest.mock.patch('asyncio.streams.events') + def test_ctor_global_loop(self, m_events): + stream = streams.StreamReader() + self.assertIs(stream.loop, m_events.get_event_loop.return_value) + + def test_open_connection(self): + with test_utils.run_test_server() as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_open_connection_no_loop_ssl(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + try: + events.set_event_loop(self.loop) + f = streams.open_connection(*httpd.address, + ssl=test_utils.dummy_ssl_context()) + reader, writer = self.loop.run_until_complete(f) + finally: + events.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + def test_open_connection_error(self): + with test_utils.run_test_server() as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + + writer.close() + test_utils.run_briefly(self.loop) + + def test_feed_empty_data(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(30), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(1024), loop=self.loop) + + def cb(): + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(-1), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline(), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.loop.call_soon(cb) + + line = self.loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + stream = streams.StreamReader(7, loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + self.loop.run_until_complete(stream.readline()) + + data = self.loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader(loop=self.loop) + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader(loop=self.loop) + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def set_err(): + stream.set_exception(ValueError()) + + @tasks.coroutine + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + def test_exception_cancel(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def read_a_line(): + yield from stream.readline() + + t = tasks.Task(read_a_line(), loop=self.loop) + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + # The following line fails if set_exception() isn't careful. + stream.set_exception(RuntimeError('message')) + test_utils.run_briefly(self.loop) + self.assertIs(stream.waiter, None) + + +if __name__ == '__main__': + unittest.main() |