"""Tests for asyncio.BufferedProtocol used with loop.create_connection()."""

import asyncio
import unittest
from test.test_asyncio import functional as func_tests
class ReceiveStuffProto(asyncio.BufferedProtocol):
    """Buffered protocol that hands every received chunk to a callback.

    Parameters:
        cb: callable invoked with the freshly received bytes (a slice of
            the receive buffer) each time the transport reports new data.
        con_lost_fut: future resolved when the connection is lost -- with
            ``None`` on a clean close, or with the exception otherwise.
    """

    def __init__(self, cb, con_lost_fut):
        self.cb = cb
        self.con_lost_fut = con_lost_fut

    def get_buffer(self, sizehint=-1):
        """Allocate and return a new 100-byte receive buffer.

        The BufferedProtocol API calls ``get_buffer(sizehint)``; the hint
        is accepted but ignored, since the buffer size here is fixed.  The
        default keeps zero-argument calls working as before.
        """
        self.buffer = bytearray(100)
        return self.buffer

    def buffer_updated(self, nbytes):
        """Pass the first *nbytes* bytes of the buffer to the callback."""
        self.cb(self.buffer[:nbytes])

    def connection_lost(self, exc):
        """Resolve ``con_lost_fut`` with ``None`` or with *exc*."""
        if exc is None:
            self.con_lost_fut.set_result(None)
        else:
            self.con_lost_fut.set_exception(exc)
class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
    """Event-loop-agnostic tests for buffered protocols.

    Concrete test classes mix this in and override ``new_loop()`` to pick
    the event loop implementation under test.
    """

    def new_loop(self):
        """Return the event loop to test; must be overridden."""
        raise NotImplementedError

    def test_buffered_proto_create_connection(self):
        """Round-trip data through a ReceiveStuffProto client connection.

        A local server sends NOISE, the client accumulates it through the
        buffered protocol and replies with a single byte once everything
        has arrived, after which the server closes the connection.
        """
        # 9 * 1024 bytes of payload.
        NOISE = b'12345678+' * 1024

        async def client(addr):
            data = b''

            def on_buf(buf):
                nonlocal data
                data += buf
                if data == NOISE:
                    # `tr` is bound by the create_connection() call below.
                    tr.write(b'1')

            conn_lost_fut = self.loop.create_future()

            # Only the transport is needed; the protocol instance is not.
            tr, _ = await self.loop.create_connection(
                lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr)

            await conn_lost_fut

        async def on_server_client(reader, writer):
            writer.write(NOISE)
            # Wait for the client's one-byte acknowledgement.
            await reader.readexactly(1)
            writer.close()
            await writer.wait_closed()

        # Port 0: let the OS pick a free port, then read it back.
        srv = self.loop.run_until_complete(
            asyncio.start_server(
                on_server_client, '127.0.0.1', 0))

        addr = srv.sockets[0].getsockname()

        # No `loop=` argument: wait_for() no longer accepts it (removed in
        # Python 3.10) and uses the loop that is running it instead.
        self.loop.run_until_complete(
            asyncio.wait_for(client(addr), 5))

        srv.close()
        self.loop.run_until_complete(srv.wait_closed())
class BufferedProtocolSelectorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the buffered-protocol tests on asyncio.SelectorEventLoop."""

    def new_loop(self):
        """Create the selector event loop used by the tests."""
        selector_loop = asyncio.SelectorEventLoop()
        return selector_loop
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class BufferedProtocolProactorTests(BaseTestBufferedProtocol,
                                    unittest.TestCase):
    """Run the buffered-protocol tests on asyncio.ProactorEventLoop.

    Skipped when asyncio does not provide ProactorEventLoop.
    """

    def new_loop(self):
        """Create the proactor event loop used by the tests."""
        proactor_loop = asyncio.ProactorEventLoop()
        return proactor_loop
# Run the test suite when this module is executed as a script.
if __name__ == '__main__':
    unittest.main()
|