diff options
Diffstat (limited to 'Lib/test/test_asyncio/test_streams.py')
-rw-r--r-- | Lib/test/test_asyncio/test_streams.py | 1008 |
1 files changed, 828 insertions, 180 deletions
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index fed6098..df3d7e7 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -1,6 +1,8 @@ """Tests for streams.py.""" +import contextlib import gc +import io import os import queue import pickle @@ -16,6 +18,7 @@ except ImportError: ssl = None import asyncio +from asyncio.streams import _StreamProtocol, _ensure_can_read, _ensure_can_write from test.test_asyncio import utils as test_utils @@ -23,6 +26,24 @@ def tearDownModule(): asyncio.set_event_loop_policy(None) +class StreamModeTests(unittest.TestCase): + def test__ensure_can_read_ok(self): + self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READ)) + self.assertIsNone(_ensure_can_read(asyncio.StreamMode.READWRITE)) + + def test__ensure_can_read_fail(self): + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + _ensure_can_read(asyncio.StreamMode.WRITE) + + def test__ensure_can_write_ok(self): + self.assertIsNone(_ensure_can_write(asyncio.StreamMode.WRITE)) + self.assertIsNone(_ensure_can_write(asyncio.StreamMode.READWRITE)) + + def test__ensure_can_write_fail(self): + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + _ensure_can_write(asyncio.StreamMode.READ) + + class StreamTests(test_utils.TestCase): DATA = b'line1\nline2\nline3\n' @@ -42,13 +63,15 @@ class StreamTests(test_utils.TestCase): @mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): - stream = asyncio.StreamReader(_asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) self.assertIs(stream._loop, m_events.get_event_loop.return_value) def _basetest_open_connection(self, open_connection_fut): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete(open_connection_fut) writer.write(b'GET / HTTP/1.0\r\n\r\n') f = reader.readline() data = self.loop.run_until_complete(f) @@ -76,7 +99,9 @@ class StreamTests(test_utils.TestCase): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) try: - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete( + open_connection_fut) finally: asyncio.set_event_loop(None) writer.write(b'GET / HTTP/1.0\r\n\r\n') @@ -112,7 +137,8 @@ class StreamTests(test_utils.TestCase): def _basetest_open_connection_error(self, open_connection_fut): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - reader, writer = self.loop.run_until_complete(open_connection_fut) + with self.assertWarns(DeprecationWarning): + reader, writer = self.loop.run_until_complete(open_connection_fut) writer._protocol.connection_lost(ZeroDivisionError()) f = reader.read() with self.assertRaises(ZeroDivisionError): @@ -135,23 +161,26 @@ class StreamTests(test_utils.TestCase): self._basetest_open_connection_error(conn_fut) def test_feed_empty_data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'') self.assertEqual(b'', stream._buffer) def test_feed_nonempty_data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) self.assertEqual(self.DATA, stream._buffer) def test_read_zero(self): # Read zero bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.read(0)) @@ -160,8 +189,9 @@ class StreamTests(test_utils.TestCase): def test_read(self): # Read bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(30), loop=self.loop) def cb(): @@ -174,8 +204,9 @@ class StreamTests(test_utils.TestCase): def test_read_line_breaks(self): # Read bytes without line breaks. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -186,8 +217,9 @@ class StreamTests(test_utils.TestCase): def test_read_eof(self): # Read bytes, stop at eof. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(1024), loop=self.loop) def cb(): @@ -200,8 +232,9 @@ class StreamTests(test_utils.TestCase): def test_read_until_eof(self): # Read all bytes until eof. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) read_task = asyncio.Task(stream.read(-1), loop=self.loop) def cb(): @@ -216,8 +249,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_read_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.read(2)) @@ -229,16 +263,19 @@ class StreamTests(test_utils.TestCase): def test_invalid_limit(self): with self.assertRaisesRegex(ValueError, 'imit'): - asyncio.StreamReader(limit=0, loop=self.loop, - _asyncio_internal=True) + asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=0, loop=self.loop, + _asyncio_internal=True) with self.assertRaisesRegex(ValueError, 'imit'): - asyncio.StreamReader(limit=-1, loop=self.loop, - _asyncio_internal=True) + asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=-1, loop=self.loop, + _asyncio_internal=True) def test_read_limit(self): - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk') data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'chunk', data) @@ -247,8 +284,9 @@ class StreamTests(test_utils.TestCase): def test_readline(self): # Read one line. 'readline' will need to wait for the data # to come from 'cb' - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk1 ') read_task = asyncio.Task(stream.readline(), loop=self.loop) @@ -263,11 +301,12 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b' chunk4', stream._buffer) def test_readline_limit_with_existing_data(self): - # Read one line. The data is in StreamReader's buffer + # Read one line. The data is in Stream's buffer # before the event loop is run. - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -276,8 +315,9 @@ class StreamTests(test_utils.TestCase): # The buffer should contain the remaining data after exception self.assertEqual(b'line2\n', stream._buffer) - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') @@ -292,8 +332,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_at_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) self.assertFalse(stream.at_eof()) stream.feed_data(b'some data\n') @@ -308,11 +349,12 @@ class StreamTests(test_utils.TestCase): self.assertTrue(stream.at_eof()) def test_readline_limit(self): - # Read one line. StreamReaders are fed with data after + # Read one line. Streams are fed with data after # their 'readline' methods are called. - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) def cb(): stream.feed_data(b'chunk1') stream.feed_data(b'chunk2') @@ -326,8 +368,9 @@ class StreamTests(test_utils.TestCase): # a ValueError it should be empty. self.assertEqual(b'', stream._buffer) - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) def cb(): stream.feed_data(b'chunk1') stream.feed_data(b'chunk2\n') @@ -340,8 +383,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'chunk3\n', stream._buffer) # check strictness of the limit - stream = asyncio.StreamReader(limit=7, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=7, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'1234567\n') line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'1234567\n', line) @@ -360,8 +404,9 @@ class StreamTests(test_utils.TestCase): def test_readline_nolimit_nowait(self): # All needed data for the first 'readline' call will be # in the buffer. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -371,8 +416,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'line2\nline3\n', stream._buffer) def test_readline_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'some data') stream.feed_eof() @@ -380,16 +426,18 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'some data', line) def test_readline_empty_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_eof() line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'', line) def test_readline_read_byte_count(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) self.loop.run_until_complete(stream.readline()) @@ -400,8 +448,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'ine3\n', stream._buffer) def test_readline_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readline()) @@ -413,14 +462,16 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readuntil_separator(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) with self.assertRaisesRegex(ValueError, 'Separator should be'): self.loop.run_until_complete(stream.readuntil(separator=b'')) def test_readuntil_multi_chunks(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'lineAAA') data = self.loop.run_until_complete(stream.readuntil(separator=b'AAA')) @@ -438,8 +489,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'xxx', stream._buffer) def test_readuntil_multi_chunks_1(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'QWEaa') stream.feed_data(b'XYaa') @@ -474,8 +526,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readuntil_eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'some dataAA') stream.feed_eof() @@ -486,8 +539,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readuntil_limit_found_sep(self): - stream = asyncio.StreamReader(loop=self.loop, limit=3, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=3, + _asyncio_internal=True) stream.feed_data(b'some dataAA') with self.assertRaisesRegex(asyncio.LimitOverrunError, @@ -505,8 +559,9 @@ class StreamTests(test_utils.TestCase): def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.readexactly(0)) @@ -519,8 +574,9 @@ class StreamTests(test_utils.TestCase): def test_readexactly(self): # Read exact number of bytes. - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) n = 2 * len(self.DATA) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) @@ -536,8 +592,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(self.DATA, stream._buffer) def test_readexactly_limit(self): - stream = asyncio.StreamReader(limit=3, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + limit=3, loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'chunk') data = self.loop.run_until_complete(stream.readexactly(5)) self.assertEqual(b'chunk', data) @@ -545,8 +602,9 @@ class StreamTests(test_utils.TestCase): def test_readexactly_eof(self): # Read exact number of bytes (eof). - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) n = 2 * len(self.DATA) read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) @@ -564,8 +622,9 @@ class StreamTests(test_utils.TestCase): self.assertEqual(b'', stream._buffer) def test_readexactly_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readexactly(2)) @@ -576,8 +635,9 @@ class StreamTests(test_utils.TestCase): ValueError, self.loop.run_until_complete, stream.readexactly(2)) def test_exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) self.assertIsNone(stream.exception()) exc = ValueError() @@ -585,8 +645,9 @@ class StreamTests(test_utils.TestCase): self.assertIs(stream.exception(), exc) def test_exception_waiter(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) async def set_err(): stream.set_exception(ValueError()) @@ -599,8 +660,9 @@ class StreamTests(test_utils.TestCase): self.assertRaises(ValueError, t1.result) def test_exception_cancel(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) t = asyncio.Task(stream.readline(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -655,8 +717,9 @@ class StreamTests(test_utils.TestCase): self.server = None async def client(addr): - reader, writer = await asyncio.open_connection( - *addr, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + *addr, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -670,7 +733,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a coroutine as client handler server = MyServer(self.loop) - addr = server.start() + with self.assertWarns(DeprecationWarning): + addr = server.start() msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() @@ -678,7 +742,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a callback as client handler server = MyServer(self.loop) - addr = server.start_callback() + with self.assertWarns(DeprecationWarning): + addr = server.start_callback() msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() @@ -726,8 +791,9 @@ class StreamTests(test_utils.TestCase): self.server = None async def client(path): - reader, writer = await asyncio.open_unix_connection( - path, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_unix_connection( + path, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -742,7 +808,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a coroutine as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) - server.start() + with self.assertWarns(DeprecationWarning): + server.start() msg = self.loop.run_until_complete(asyncio.Task(client(path), loop=self.loop)) server.stop() @@ -751,7 +818,8 @@ class StreamTests(test_utils.TestCase): # test the server variant with a callback as client handler with test_utils.unix_socket_path() as path: server = MyServer(self.loop, path) - server.start_callback() + with self.assertWarns(DeprecationWarning): + server.start_callback() msg = self.loop.run_until_complete(asyncio.Task(client(path), loop=self.loop)) server.stop() @@ -763,7 +831,7 @@ class StreamTests(test_utils.TestCase): def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example # subprocess_attach_read_pipe.py, but we configure the - # StreamReader's limit so that twice it is less than the size + # Stream's limit so that twice it is less than the size # of the data writter. Also we must explicitly attach a child # watcher to the event loop. @@ -777,10 +845,11 @@ os.close(fd) args = [sys.executable, '-c', code, str(wfd)] pipe = open(rfd, 'rb', 0) - reader = asyncio.StreamReader(loop=self.loop, limit=1, - _asyncio_internal=True) - protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=1, + _asyncio_internal=True) + protocol = _StreamProtocol(stream, loop=self.loop, + _asyncio_internal=True) transport, _ = self.loop.run_until_complete( self.loop.connect_read_pipe(lambda: protocol, pipe)) @@ -797,29 +866,30 @@ os.close(fd) asyncio.set_child_watcher(None) os.close(wfd) - data = self.loop.run_until_complete(reader.read(-1)) + data = self.loop.run_until_complete(stream.read(-1)) self.assertEqual(data, b'data') def test_streamreader_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that _StreamProtocol constructor # retrieves the current loop if the loop parameter is not set - reader = asyncio.StreamReader(_asyncio_internal=True) + reader = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) self.assertIs(reader._loop, self.loop) def test_streamreaderprotocol_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # asyncio issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that _StreamProtocol constructor # retrieves the current loop if the loop parameter is not set - reader = mock.Mock() - protocol = asyncio.StreamReaderProtocol(reader, _asyncio_internal=True) + stream = mock.Mock() + protocol = _StreamProtocol(stream, _asyncio_internal=True) self.assertIs(protocol._loop, self.loop) - def test_drain_raises(self): + def test_drain_raises_deprecated(self): # See http://bugs.python.org/issue25441 # This test should not use asyncio for the mock server; the @@ -833,15 +903,16 @@ os.close(fd) def server(): # Runs in a separate thread. - with socket.create_server(('localhost', 0)) as sock: + with socket.create_server(('127.0.0.1', 0)) as sock: addr = sock.getsockname() q.put(addr) clt, _ = sock.accept() clt.close() async def client(host, port): - reader, writer = await asyncio.open_connection( - host, port, loop=self.loop) + with self.assertWarns(DeprecationWarning): + reader, writer = await asyncio.open_connection( + host, port, loop=self.loop) while True: writer.write(b"foo\n") @@ -863,55 +934,106 @@ os.close(fd) thread.join() self.assertEqual([], messages) + def test_drain_raises(self): + # See http://bugs.python.org/issue25441 + + # This test should not use asyncio for the mock server; the + # whole point of the test is to test for a bug in drain() + # where it never gives up the event loop but the socket is + # closed on the server side. + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + q = queue.Queue() + + def server(): + # Runs in a separate thread. + with socket.create_server(('localhost', 0)) as sock: + addr = sock.getsockname() + q.put(addr) + clt, _ = sock.accept() + clt.close() + + async def client(host, port): + stream = await asyncio.connect(host, port) + + while True: + stream.write(b"foo\n") + await stream.drain() + + # Start the server thread and wait for it to be listening. + thread = threading.Thread(target=server) + thread.setDaemon(True) + thread.start() + addr = q.get() + + # Should not be stuck in an infinite loop. + with self.assertRaises((ConnectionResetError, ConnectionAbortedError, + BrokenPipeError)): + self.loop.run_until_complete(client(*addr)) + + # Clean up the thread. (Only on success; on failure, it may + # be stuck in accept().) + thread.join() + self.assertEqual([], messages) + def test___repr__(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - self.assertEqual("<StreamReader>", repr(stream)) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) + self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream)) def test___repr__nondefault_limit(self): - stream = asyncio.StreamReader(loop=self.loop, limit=123, - _asyncio_internal=True) - self.assertEqual("<StreamReader limit=123>", repr(stream)) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, limit=123, + _asyncio_internal=True) + self.assertEqual("<Stream mode=StreamMode.READ limit=123>", repr(stream)) def test___repr__eof(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_eof() - self.assertEqual("<StreamReader eof>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ eof>", repr(stream)) def test___repr__data(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream.feed_data(b'data') - self.assertEqual("<StreamReader 4 bytes>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ 4 bytes>", repr(stream)) def test___repr__exception(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) exc = RuntimeError() stream.set_exception(exc) - self.assertEqual("<StreamReader exception=RuntimeError()>", + self.assertEqual("<Stream mode=StreamMode.READ exception=RuntimeError()>", repr(stream)) def test___repr__waiter(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream._waiter = asyncio.Future(loop=self.loop) self.assertRegex( repr(stream), - r"<StreamReader waiter=<Future pending[\S ]*>>") + r"<Stream .+ waiter=<Future pending[\S ]*>>") stream._waiter.set_result(None) self.loop.run_until_complete(stream._waiter) stream._waiter = None - self.assertEqual("<StreamReader>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ>", repr(stream)) def test___repr__transport(self): - stream = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) stream._transport = mock.Mock() stream._transport.__repr__ = mock.Mock() stream._transport.__repr__.return_value = "<Transport>" - self.assertEqual("<StreamReader transport=<Transport>>", repr(stream)) + self.assertEqual("<Stream mode=StreamMode.READ transport=<Transport>>", + repr(stream)) def test_IncompleteReadError_pickleable(self): e = asyncio.IncompleteReadError(b'abc', 10) @@ -930,10 +1052,11 @@ os.close(fd) self.assertEqual(str(e), str(e2)) self.assertEqual(e.consumed, e2.consumed) - def test_wait_closed_on_close(self): + def test_wait_closed_on_close_deprecated(self): with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -947,10 +1070,28 @@ os.close(fd) self.assertTrue(wr.is_closing()) self.loop.run_until_complete(wr.wait_closed()) - def test_wait_closed_on_close_with_unread_data(self): + def test_wait_closed_on_close(self): with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = stream.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertFalse(stream.is_closing()) + stream.close() + self.assertTrue(stream.is_closing()) + self.loop.run_until_complete(stream.wait_closed()) + + def test_wait_closed_on_close_with_unread_data_deprecated(self): + with test_utils.run_test_server() as httpd: + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, loop=self.loop)) wr.write(b'GET / HTTP/1.0\r\n\r\n') f = rd.readline() @@ -959,32 +1100,44 @@ os.close(fd) wr.close() self.loop.run_until_complete(wr.wait_closed()) + def test_wait_closed_on_close_with_unread_data(self): + with test_utils.run_test_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + stream.close() + self.loop.run_until_complete(stream.wait_closed()) + def test_del_stream_before_sock_closing(self): messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) - with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, loop=self.loop)) - sock = wr.get_extra_info('socket') - self.assertNotEqual(sock.fileno(), -1) + async def test(): - wr.write(b'GET / HTTP/1.0\r\n\r\n') - f = rd.readline() - data = self.loop.run_until_complete(f) - self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + with test_utils.run_test_server() as httpd: + stream = await asyncio.connect(*httpd.address) + sock = stream.get_extra_info('socket') + self.assertNotEqual(sock.fileno(), -1) - # drop refs to reader/writer - del rd - del wr - gc.collect() - # make a chance to close the socket - test_utils.run_briefly(self.loop) + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - self.assertEqual(1, len(messages)) - self.assertEqual(sock.fileno(), -1) + # drop refs to reader/writer + del stream + gc.collect() + # make a chance to close the socket + await asyncio.sleep(0) - self.assertEqual(1, len(messages)) + self.assertEqual(1, len(messages), messages) + self.assertEqual(sock.fileno(), -1) + + self.loop.run_until_complete(test()) + self.assertEqual(1, len(messages), messages) self.assertEqual('An open stream object is being garbage ' 'collected; call "stream.close()" explicitly.', messages[0]['message']) @@ -994,11 +1147,12 @@ os.close(fd) self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) with test_utils.run_test_server() as httpd: - rd = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - pr = asyncio.StreamReaderProtocol(rd, loop=self.loop, - _asyncio_internal=True) - del rd + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + loop=self.loop, + _asyncio_internal=True) + pr = _StreamProtocol(stream, loop=self.loop, + _asyncio_internal=True) + del stream gc.collect() tr, _ = self.loop.run_until_complete( self.loop.create_connection( @@ -1015,14 +1169,14 @@ os.close(fd) def test_async_writer_api(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + stream = await asyncio.connect(*httpd.address) - await wr.write(b'GET / HTTP/1.0\r\n\r\n') - data = await rd.readline() + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - data = await rd.read() + data = await stream.read() self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - await wr.close() + await stream.close() messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1032,18 +1186,18 @@ os.close(fd) self.assertEqual(messages, []) - def test_async_writer_api(self): + def test_async_writer_api_exception_after_close(self): async def inner(httpd): - rd, wr = await asyncio.open_connection(*httpd.address) + stream = await asyncio.connect(*httpd.address) - await wr.write(b'GET / HTTP/1.0\r\n\r\n') - data = await rd.readline() + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - data = await rd.read() + data = await stream.read() self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - wr.close() + stream.close() with self.assertRaises(ConnectionResetError): - await wr.write(b'data') + await stream.write(b'data') messages = [] self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) @@ -1059,11 +1213,13 @@ os.close(fd) self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) with test_utils.run_test_server() as httpd: - rd, wr = self.loop.run_until_complete( - asyncio.open_connection(*httpd.address, - loop=self.loop)) + with self.assertWarns(DeprecationWarning): + rd, wr = self.loop.run_until_complete( + asyncio.open_connection(*httpd.address, + loop=self.loop)) - f = wr.close() + wr.close() + f = wr.wait_closed() self.loop.run_until_complete(f) assert rd.at_eof() f = rd.read() @@ -1074,22 +1230,514 @@ os.close(fd) def test_stream_reader_create_warning(self): with self.assertWarns(DeprecationWarning): - asyncio.StreamReader(loop=self.loop) - - def test_stream_reader_protocol_create_warning(self): - reader = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - with self.assertWarns(DeprecationWarning): - asyncio.StreamReaderProtocol(reader, loop=self.loop) + asyncio.StreamReader def test_stream_writer_create_warning(self): - reader = asyncio.StreamReader(loop=self.loop, - _asyncio_internal=True) - proto = asyncio.StreamReaderProtocol(reader, loop=self.loop, - _asyncio_internal=True) with self.assertWarns(DeprecationWarning): - asyncio.StreamWriter('transport', proto, reader, self.loop) + asyncio.StreamWriter + + def test_stream_reader_forbidden_ops(self): + async def inner(): + stream = asyncio.Stream(mode=asyncio.StreamMode.READ, + _asyncio_internal=True) + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.write(b'data') + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.writelines([b'data', b'other']) + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + stream.write_eof() + with self.assertRaisesRegex(RuntimeError, "The stream is read-only"): + await stream.drain() + + self.loop.run_until_complete(inner()) + + def test_stream_writer_forbidden_ops(self): + async def inner(): + stream = asyncio.Stream(mode=asyncio.StreamMode.WRITE, + _asyncio_internal=True) + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + stream.feed_data(b'data') + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readline() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readuntil() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.read() + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + await stream.readexactly(10) + with self.assertRaisesRegex(RuntimeError, "The stream is write-only"): + async for chunk in stream: + pass + + self.loop.run_until_complete(inner()) + + def _basetest_connect(self, stream): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + stream.write(b'GET / HTTP/1.0\r\n\r\n') + f = stream.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = stream.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + stream.close() + self.loop.run_until_complete(stream.wait_closed()) + + self.assertEqual([], messages) + + def test_connect(self): + with test_utils.run_test_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + self.assertFalse(stream.is_server_side()) + self._basetest_connect(stream) + + @support.skip_unless_bind_unix_socket + def test_connect_unix(self): + with test_utils.run_test_unix_server() as httpd: + stream = self.loop.run_until_complete( + asyncio.connect_unix(httpd.address)) + self._basetest_connect(stream) + + def test_stream_async_context_manager(self): + async def test(httpd): + stream = await asyncio.connect(*httpd.address) + async with stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(test(httpd)) + + def test_connect_async_context_manager(self): + async def test(httpd): + async with asyncio.connect(*httpd.address) as stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_server() as httpd: + self.loop.run_until_complete(test(httpd)) + + @support.skip_unless_bind_unix_socket + def test_connect_unix_async_context_manager(self): + async def test(httpd): + async with asyncio.connect_unix(httpd.address) as stream: + await stream.write(b'GET / HTTP/1.0\r\n\r\n') + data = await stream.readline() + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + data = await stream.read() + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + self.assertTrue(stream.is_closing()) + + with test_utils.run_test_unix_server() as httpd: + self.loop.run_until_complete(test(httpd)) + + def test_stream_server(self): + + async def handle_client(stream): + self.assertTrue(stream.is_server_side()) + data = await stream.readline() + await stream.write(data) + await stream.close() + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + # send a line + await stream.write(b"hello world!\n") + # read it back + msgback = await stream.readline() + await stream.close() + self.assertEqual(msgback, b"hello world!\n") + await srv.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + with contextlib.suppress(asyncio.CancelledError): + await server.serve_forever() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + @support.skip_unless_bind_unix_socket + def test_unix_stream_server(self): + + async def handle_client(stream): + data = await stream.readline() + await stream.write(data) + await stream.close() + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect_unix(addr) + # send a line + await stream.write(b"hello world!\n") + # read it back + msgback = await stream.readline() + await stream.close() + self.assertEqual(msgback, b"hello world!\n") + await srv.close() + + async def test(): + with test_utils.unix_socket_path() as path: + async with asyncio.UnixStreamServer(handle_client, path) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + with contextlib.suppress(asyncio.CancelledError): + await server.serve_forever() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_inheritance_forbidden(self): + with self.assertRaises(TypeError): + class MyServer(asyncio.StreamServer): + pass + + @support.skip_unless_bind_unix_socket + def test_unix_stream_server_inheritance_forbidden(self): + with self.assertRaises(TypeError): + class MyServer(asyncio.UnixStreamServer): + pass + + def test_stream_server_bind(self): + async def handle_client(stream): + await stream.close() + + async def test(): + srv = asyncio.StreamServer(handle_client, '127.0.0.1', 0) + self.assertFalse(srv.is_bound()) + self.assertEqual(0, len(srv.sockets)) + await srv.bind() + self.assertTrue(srv.is_bound()) + self.assertEqual(1, len(srv.sockets)) + await srv.close() + self.assertFalse(srv.is_bound()) + self.assertEqual(0, len(srv.sockets)) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_bind_async_with(self): + async def handle_client(stream): + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: + self.assertTrue(srv.is_bound()) + self.assertEqual(1, len(srv.sockets)) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_start_serving(self): + async def handle_client(stream): + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as srv: + self.assertFalse(srv.is_serving()) + await srv.start_serving() + self.assertTrue(srv.is_serving()) + await srv.close() + self.assertFalse(srv.is_serving()) + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + + def test_stream_server_close(self): + server_stream_aborted = False + fut = self.loop.create_future() + + async def handle_client(stream): + await fut + self.assertEqual(b'', await stream.readline()) + nonlocal server_stream_aborted + server_stream_aborted = True + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut.set_result(None) + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut + await server.close() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut.done()) + self.assertTrue(server_stream_aborted) + + def test_stream_server_abort(self): + server_stream_aborted = False + fut = self.loop.create_future() + + async def handle_client(stream): + await fut + self.assertEqual(b'', await stream.readline()) + nonlocal server_stream_aborted + server_stream_aborted = True + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut.set_result(None) + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, '127.0.0.1', 0) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut + await server.abort() + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut.done()) + self.assertTrue(server_stream_aborted) + + def test_stream_shutdown_hung_task(self): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + + async def handle_client(stream): + while True: + await asyncio.sleep(0.01) + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut1.set_result(None) + await fut2 + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, + '127.0.0.1', + 0, + shutdown_timeout=0.3) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut1 + await server.close() + fut2.set_result(None) + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(messages, []) + self.assertTrue(fut1.done()) + self.assertTrue(fut2.done()) + + def test_stream_shutdown_hung_task_prevents_cancellation(self): + fut1 = self.loop.create_future() + fut2 = self.loop.create_future() + do_handle_client = True + + async def handle_client(stream): + while do_handle_client: + with contextlib.suppress(asyncio.CancelledError): + await asyncio.sleep(0.01) + + async def client(srv): + addr = srv.sockets[0].getsockname() + stream = await asyncio.connect(*addr) + fut1.set_result(None) + await fut2 + self.assertEqual(b'', await stream.readline()) + await stream.close() + + async def test(): + async with asyncio.StreamServer(handle_client, + '127.0.0.1', + 0, + shutdown_timeout=0.3) as server: + await server.start_serving() + task = asyncio.create_task(client(server)) + await fut1 + await server.close() + nonlocal do_handle_client + do_handle_client = False + fut2.set_result(None) + await task + + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + self.loop.run_until_complete(test()) + self.assertEqual(1, len(messages)) + self.assertRegex(messages[0]['message'], + "<Task pending .+ ignored cancellation request") + self.assertTrue(fut1.done()) + self.assertTrue(fut2.done()) + + def test_sendfile(self): + messages = [] + self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx)) + + with open(support.TESTFN, 'wb') as fp: + fp.write(b'data\n') + self.addCleanup(support.unlink, support.TESTFN) + + async def serve_callback(stream): + data = await stream.readline() + self.assertEqual(data, b'begin\n') + data = await stream.readline() + self.assertEqual(data, b'data\n') + data = await stream.readline() + self.assertEqual(data, b'end\n') + await stream.write(b'done\n') + await stream.close() + + async def do_connect(host, port): + stream = await asyncio.connect(host, port) + await stream.write(b'begin\n') + with open(support.TESTFN, 'rb') as fp: + await stream.sendfile(fp) + await stream.write(b'end\n') + data = await stream.readline() + self.assertEqual(data, b'done\n') + await stream.close() + + async def test(): + async with asyncio.StreamServer(serve_callback, '127.0.0.1', 0) as srv: + await srv.start_serving() + await do_connect(*srv.sockets[0].getsockname()) + + self.loop.run_until_complete(test()) + + self.assertEqual([], messages) + + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_connect_start_tls(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + # connect without SSL but upgrade to TLS just after + # connection is established + stream = self.loop.run_until_complete( + asyncio.connect(*httpd.address)) + + self.loop.run_until_complete( + stream.start_tls( + sslcontext=test_utils.dummy_ssl_context())) + self._basetest_connect(stream) + + def test_repr_unbound(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve) + self.assertEqual('<StreamServer>', repr(srv)) + await srv.close() + + self.loop.run_until_complete(test()) + + def test_repr_bound(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve, '127.0.0.1', 0) + await srv.bind() + self.assertRegex(repr(srv), r'<StreamServer sockets=\(.+\)>') + await srv.close() + + self.loop.run_until_complete(test()) + + def test_repr_serving(self): + async def serve(stream): + pass + + async def test(): + srv = asyncio.StreamServer(serve, '127.0.0.1', 0) + await srv.start_serving() + self.assertRegex(repr(srv), r'<StreamServer serving sockets=\(.+\)>') + await srv.close() + + self.loop.run_until_complete(test()) + + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + async def test(): + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + async with asyncio.connect_read_pipe(pipeobj) as stream: + self.assertEqual(stream.mode, asyncio.StreamMode.READ) + + os.write(wpipe, b'1') + data = await stream.readexactly(1) + self.assertEqual(data, b'1') + + os.write(wpipe, b'2345') + data = await stream.readexactly(4) + self.assertEqual(data, b'2345') + os.close(wpipe) + + self.loop.run_until_complete(test()) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + async def test(): + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + async with asyncio.connect_write_pipe(pipeobj) as stream: + self.assertEqual(stream.mode, asyncio.StreamMode.WRITE) + + await stream.write(b'1') + data = os.read(rpipe, 1024) + self.assertEqual(data, b'1') + + await stream.write(b'2345') + data = os.read(rpipe, 1024) + self.assertEqual(data, b'2345') + + os.close(rpipe) + self.loop.run_until_complete(test()) if __name__ == '__main__': |