summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/asyncio/__init__.py33
-rw-r--r--Lib/asyncio/base_events.py606
-rw-r--r--Lib/asyncio/constants.py4
-rw-r--r--Lib/asyncio/events.py395
-rw-r--r--Lib/asyncio/futures.py338
-rw-r--r--Lib/asyncio/locks.py401
-rw-r--r--Lib/asyncio/log.py6
-rw-r--r--Lib/asyncio/proactor_events.py352
-rw-r--r--Lib/asyncio/protocols.py98
-rw-r--r--Lib/asyncio/queues.py284
-rw-r--r--Lib/asyncio/selector_events.py769
-rw-r--r--Lib/asyncio/streams.py257
-rw-r--r--Lib/asyncio/tasks.py636
-rw-r--r--Lib/asyncio/test_utils.py246
-rw-r--r--Lib/asyncio/transports.py186
-rw-r--r--Lib/asyncio/unix_events.py541
-rw-r--r--Lib/asyncio/windows_events.py375
-rw-r--r--Lib/asyncio/windows_utils.py181
-rw-r--r--Lib/test/test_asyncio/__init__.py26
-rw-r--r--Lib/test/test_asyncio/__main__.py5
-rw-r--r--Lib/test/test_asyncio/echo.py6
-rw-r--r--Lib/test/test_asyncio/echo2.py6
-rw-r--r--Lib/test/test_asyncio/echo3.py9
-rw-r--r--Lib/test/test_asyncio/sample.crt14
-rw-r--r--Lib/test/test_asyncio/sample.key15
-rw-r--r--Lib/test/test_asyncio/test_base_events.py590
-rw-r--r--Lib/test/test_asyncio/test_events.py1573
-rw-r--r--Lib/test/test_asyncio/test_futures.py329
-rw-r--r--Lib/test/test_asyncio/test_locks.py765
-rw-r--r--Lib/test/test_asyncio/test_proactor_events.py480
-rw-r--r--Lib/test/test_asyncio/test_queues.py470
-rw-r--r--Lib/test/test_asyncio/test_selector_events.py1485
-rw-r--r--Lib/test/test_asyncio/test_selectors.py145
-rw-r--r--Lib/test/test_asyncio/test_streams.py361
-rw-r--r--Lib/test/test_asyncio/test_tasks.py1518
-rw-r--r--Lib/test/test_asyncio/test_transports.py55
-rw-r--r--Lib/test/test_asyncio/test_unix_events.py767
-rw-r--r--Lib/test/test_asyncio/test_windows_events.py95
-rw-r--r--Lib/test/test_asyncio/test_windows_utils.py136
-rw-r--r--Lib/test/test_asyncio/tests.txt14
40 files changed, 14572 insertions, 0 deletions
diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py
new file mode 100644
index 0000000..afc444d
--- /dev/null
+++ b/Lib/asyncio/__init__.py
@@ -0,0 +1,33 @@
+"""The asyncio package, tracking PEP 3156."""
+
+import sys
+
+# The selectors module is in the stdlib in Python 3.4 but not in 3.3.
+# Do this first, so the other submodules can use "from . import selectors".
+try:
+ import selectors # Will also be exported.
+except ImportError:
+ from . import selectors
+
+# This relies on each of the submodules having an __all__ variable.
+from .futures import *
+from .events import *
+from .locks import *
+from .transports import *
+from .protocols import *
+from .streams import *
+from .tasks import *
+
+if sys.platform == 'win32': # pragma: no cover
+ from .windows_events import *
+else:
+ from .unix_events import * # pragma: no cover
+
+
+__all__ = (futures.__all__ +
+ events.__all__ +
+ locks.__all__ +
+ transports.__all__ +
+ protocols.__all__ +
+ streams.__all__ +
+ tasks.__all__)
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
new file mode 100644
index 0000000..32457eb
--- /dev/null
+++ b/Lib/asyncio/base_events.py
@@ -0,0 +1,606 @@
+"""Base implementation of event loop.
+
+The event loop can be broken up into a multiplexer (the part
+responsible for notifying us of IO events) and the event loop proper,
+which wraps a multiplexer with functionality for scheduling callbacks,
+immediately or at a given time in the future.
+
+Whenever a public API takes a callback, subsequent positional
+arguments will be passed to the callback if/when it is called. This
+avoids the proliferation of trivial lambdas implementing closures.
+Keyword arguments for the callback are not supported; this is a
+conscious design decision, leaving the door open for keyword arguments
+to modify the meaning of the API call itself.
+"""
+
+
+import collections
+import concurrent.futures
+import heapq
+import logging
+import socket
+import subprocess
+import time
+import os
+import sys
+
+from . import events
+from . import futures
+from . import tasks
+from .log import asyncio_log
+
+
+__all__ = ['BaseEventLoop', 'Server']
+
+
+# Argument for default thread pool executor creation.
+_MAX_WORKERS = 5
+
+
+class _StopError(BaseException):
+ """Raised to stop the event loop."""
+
+
+def _raise_stop_error(*args):
+ raise _StopError
+
+
+class Server(events.AbstractServer):
+
+ def __init__(self, loop, sockets):
+ self.loop = loop
+ self.sockets = sockets
+ self.active_count = 0
+ self.waiters = []
+
+ def attach(self, transport):
+ assert self.sockets is not None
+ self.active_count += 1
+
+ def detach(self, transport):
+ assert self.active_count > 0
+ self.active_count -= 1
+ if self.active_count == 0 and self.sockets is None:
+ self._wakeup()
+
+ def close(self):
+ sockets = self.sockets
+ if sockets is not None:
+ self.sockets = None
+ for sock in sockets:
+ self.loop._stop_serving(sock)
+ if self.active_count == 0:
+ self._wakeup()
+
+ def _wakeup(self):
+ waiters = self.waiters
+ self.waiters = None
+ for waiter in waiters:
+ if not waiter.done():
+ waiter.set_result(waiter)
+
+ @tasks.coroutine
+ def wait_closed(self):
+ if self.sockets is None or self.waiters is None:
+ return
+ waiter = futures.Future(loop=self.loop)
+ self.waiters.append(waiter)
+ yield from waiter
+
+
+class BaseEventLoop(events.AbstractEventLoop):
+
+ def __init__(self):
+ self._ready = collections.deque()
+ self._scheduled = []
+ self._default_executor = None
+ self._internal_fds = 0
+ self._running = False
+
+ def _make_socket_transport(self, sock, protocol, waiter=None, *,
+ extra=None, server=None):
+ """Create socket transport."""
+ raise NotImplementedError
+
+ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *,
+ server_side=False, server_hostname=None,
+ extra=None, server=None):
+ """Create SSL transport."""
+ raise NotImplementedError
+
+ def _make_datagram_transport(self, sock, protocol,
+ address=None, extra=None):
+ """Create datagram transport."""
+ raise NotImplementedError
+
+ def _make_read_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ """Create read pipe transport."""
+ raise NotImplementedError
+
+ def _make_write_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ """Create write pipe transport."""
+ raise NotImplementedError
+
+ @tasks.coroutine
+ def _make_subprocess_transport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ extra=None, **kwargs):
+ """Create subprocess transport."""
+ raise NotImplementedError
+
+ def _read_from_self(self):
+ """XXX"""
+ raise NotImplementedError
+
+ def _write_to_self(self):
+ """XXX"""
+ raise NotImplementedError
+
+ def _process_events(self, event_list):
+ """Process selector events."""
+ raise NotImplementedError
+
+ def run_forever(self):
+ """Run until stop() is called."""
+ if self._running:
+ raise RuntimeError('Event loop is running.')
+ self._running = True
+ try:
+ while True:
+ try:
+ self._run_once()
+ except _StopError:
+ break
+ finally:
+ self._running = False
+
+ def run_until_complete(self, future):
+ """Run until the Future is done.
+
+ If the argument is a coroutine, it is wrapped in a Task.
+
+ XXX TBD: It would be disastrous to call run_until_complete()
+ with the same coroutine twice -- it would wrap it in two
+ different Tasks and that can't be good.
+
+ Return the Future's result, or raise its exception.
+ """
+ future = tasks.async(future, loop=self)
+ future.add_done_callback(_raise_stop_error)
+ self.run_forever()
+ future.remove_done_callback(_raise_stop_error)
+ if not future.done():
+ raise RuntimeError('Event loop stopped before Future completed.')
+
+ return future.result()
+
+ def stop(self):
+ """Stop running the event loop.
+
+ Every callback scheduled before stop() is called will run.
+ Callback scheduled after stop() is called won't. However,
+ those callbacks will run if run() is called again later.
+ """
+ self.call_soon(_raise_stop_error)
+
+ def is_running(self):
+ """Returns running status of event loop."""
+ return self._running
+
+ def time(self):
+ """Return the time according to the event loop's clock."""
+ return time.monotonic()
+
+ def call_later(self, delay, callback, *args):
+ """Arrange for a callback to be called at a given time.
+
+ Return a Handle: an opaque object with a cancel() method that
+ can be used to cancel the call.
+
+ The delay can be an int or float, expressed in seconds. It is
+ always a relative time.
+
+ Each callback will be called exactly once. If two callbacks
+ are scheduled for exactly the same time, it undefined which
+ will be called first.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ return self.call_at(self.time() + delay, callback, *args)
+
+ def call_at(self, when, callback, *args):
+ """Like call_later(), but uses an absolute time."""
+ timer = events.TimerHandle(when, callback, args)
+ heapq.heappush(self._scheduled, timer)
+ return timer
+
+ def call_soon(self, callback, *args):
+ """Arrange for a callback to be called as soon as possible.
+
+ This operates as a FIFO queue, callbacks are called in the
+ order in which they are registered. Each callback will be
+ called exactly once.
+
+ Any positional arguments after the callback will be passed to
+ the callback when it is called.
+ """
+ handle = events.make_handle(callback, args)
+ self._ready.append(handle)
+ return handle
+
+ def call_soon_threadsafe(self, callback, *args):
+ """XXX"""
+ handle = self.call_soon(callback, *args)
+ self._write_to_self()
+ return handle
+
+ def run_in_executor(self, executor, callback, *args):
+ if isinstance(callback, events.Handle):
+ assert not args
+ assert not isinstance(callback, events.TimerHandle)
+ if callback._cancelled:
+ f = futures.Future(loop=self)
+ f.set_result(None)
+ return f
+ callback, args = callback._callback, callback._args
+ if executor is None:
+ executor = self._default_executor
+ if executor is None:
+ executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS)
+ self._default_executor = executor
+ return futures.wrap_future(executor.submit(callback, *args), loop=self)
+
+ def set_default_executor(self, executor):
+ self._default_executor = executor
+
+ def getaddrinfo(self, host, port, *,
+ family=0, type=0, proto=0, flags=0):
+ return self.run_in_executor(None, socket.getaddrinfo,
+ host, port, family, type, proto, flags)
+
+ def getnameinfo(self, sockaddr, flags=0):
+ return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags)
+
+ @tasks.coroutine
+ def create_connection(self, protocol_factory, host=None, port=None, *,
+ ssl=None, family=0, proto=0, flags=0, sock=None,
+ local_addr=None):
+ """XXX"""
+ if host is not None or port is not None:
+ if sock is not None:
+ raise ValueError(
+ 'host/port and sock can not be specified at the same time')
+
+ f1 = self.getaddrinfo(
+ host, port, family=family,
+ type=socket.SOCK_STREAM, proto=proto, flags=flags)
+ fs = [f1]
+ if local_addr is not None:
+ f2 = self.getaddrinfo(
+ *local_addr, family=family,
+ type=socket.SOCK_STREAM, proto=proto, flags=flags)
+ fs.append(f2)
+ else:
+ f2 = None
+
+ yield from tasks.wait(fs, loop=self)
+
+ infos = f1.result()
+ if not infos:
+ raise OSError('getaddrinfo() returned empty list')
+ if f2 is not None:
+ laddr_infos = f2.result()
+ if not laddr_infos:
+ raise OSError('getaddrinfo() returned empty list')
+
+ exceptions = []
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ if f2 is not None:
+ for _, _, _, _, laddr in laddr_infos:
+ try:
+ sock.bind(laddr)
+ break
+ except OSError as exc:
+ exc = OSError(
+ exc.errno, 'error while '
+ 'attempting to bind on address '
+ '{!r}: {}'.format(
+ laddr, exc.strerror.lower()))
+ exceptions.append(exc)
+ else:
+ sock.close()
+ sock = None
+ continue
+ yield from self.sock_connect(sock, address)
+ except OSError as exc:
+ if sock is not None:
+ sock.close()
+ exceptions.append(exc)
+ else:
+ break
+ else:
+ if len(exceptions) == 1:
+ raise exceptions[0]
+ else:
+ # If they all have the same str(), raise one.
+ model = str(exceptions[0])
+ if all(str(exc) == model for exc in exceptions):
+ raise exceptions[0]
+ # Raise a combined exception so the user can see all
+ # the various error messages.
+ raise OSError('Multiple exceptions: {}'.format(
+ ', '.join(str(exc) for exc in exceptions)))
+
+ elif sock is None:
+ raise ValueError(
+ 'host and port was not specified and no sock specified')
+
+ sock.setblocking(False)
+
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ if ssl:
+ sslcontext = None if isinstance(ssl, bool) else ssl
+ transport = self._make_ssl_transport(
+ sock, protocol, sslcontext, waiter,
+ server_side=False, server_hostname=host)
+ else:
+ transport = self._make_socket_transport(sock, protocol, waiter)
+
+ yield from waiter
+ return transport, protocol
+
+ @tasks.coroutine
+ def create_datagram_endpoint(self, protocol_factory,
+ local_addr=None, remote_addr=None, *,
+ family=0, proto=0, flags=0):
+ """Create datagram connection."""
+ if not (local_addr or remote_addr):
+ if family == 0:
+ raise ValueError('unexpected address family')
+ addr_pairs_info = (((family, proto), (None, None)),)
+ else:
+ # join addresss by (family, protocol)
+ addr_infos = collections.OrderedDict()
+ for idx, addr in ((0, local_addr), (1, remote_addr)):
+ if addr is not None:
+ assert isinstance(addr, tuple) and len(addr) == 2, (
+ '2-tuple is expected')
+
+ infos = yield from self.getaddrinfo(
+ *addr, family=family, type=socket.SOCK_DGRAM,
+ proto=proto, flags=flags)
+ if not infos:
+ raise OSError('getaddrinfo() returned empty list')
+
+ for fam, _, pro, _, address in infos:
+ key = (fam, pro)
+ if key not in addr_infos:
+ addr_infos[key] = [None, None]
+ addr_infos[key][idx] = address
+
+ # each addr has to have info for each (family, proto) pair
+ addr_pairs_info = [
+ (key, addr_pair) for key, addr_pair in addr_infos.items()
+ if not ((local_addr and addr_pair[0] is None) or
+ (remote_addr and addr_pair[1] is None))]
+
+ if not addr_pairs_info:
+ raise ValueError('can not get address information')
+
+ exceptions = []
+
+ for ((family, proto),
+ (local_address, remote_address)) in addr_pairs_info:
+ sock = None
+ r_addr = None
+ try:
+ sock = socket.socket(
+ family=family, type=socket.SOCK_DGRAM, proto=proto)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.setblocking(False)
+
+ if local_addr:
+ sock.bind(local_address)
+ if remote_addr:
+ yield from self.sock_connect(sock, remote_address)
+ r_addr = remote_address
+ except OSError as exc:
+ if sock is not None:
+ sock.close()
+ exceptions.append(exc)
+ else:
+ break
+ else:
+ raise exceptions[0]
+
+ protocol = protocol_factory()
+ transport = self._make_datagram_transport(sock, protocol, r_addr)
+ return transport, protocol
+
+ @tasks.coroutine
+ def create_server(self, protocol_factory, host=None, port=None,
+ *,
+ family=socket.AF_UNSPEC,
+ flags=socket.AI_PASSIVE,
+ sock=None,
+ backlog=100,
+ ssl=None,
+ reuse_address=None):
+ """XXX"""
+ if host is not None or port is not None:
+ if sock is not None:
+ raise ValueError(
+ 'host/port and sock can not be specified at the same time')
+
+ AF_INET6 = getattr(socket, 'AF_INET6', 0)
+ if reuse_address is None:
+ reuse_address = os.name == 'posix' and sys.platform != 'cygwin'
+ sockets = []
+ if host == '':
+ host = None
+
+ infos = yield from self.getaddrinfo(
+ host, port, family=family,
+ type=socket.SOCK_STREAM, proto=0, flags=flags)
+ if not infos:
+ raise OSError('getaddrinfo() returned empty list')
+
+ completed = False
+ try:
+ for res in infos:
+ af, socktype, proto, canonname, sa = res
+ sock = socket.socket(af, socktype, proto)
+ sockets.append(sock)
+ if reuse_address:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+ True)
+ # Disable IPv4/IPv6 dual stack support (enabled by
+ # default on Linux) which makes a single socket
+ # listen on both address families.
+ if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'):
+ sock.setsockopt(socket.IPPROTO_IPV6,
+ socket.IPV6_V6ONLY,
+ True)
+ try:
+ sock.bind(sa)
+ except OSError as err:
+ raise OSError(err.errno, 'error while attempting '
+ 'to bind on address %r: %s'
+ % (sa, err.strerror.lower()))
+ completed = True
+ finally:
+ if not completed:
+ for sock in sockets:
+ sock.close()
+ else:
+ if sock is None:
+ raise ValueError(
+ 'host and port was not specified and no sock specified')
+ sockets = [sock]
+
+ server = Server(self, sockets)
+ for sock in sockets:
+ sock.listen(backlog)
+ sock.setblocking(False)
+ self._start_serving(protocol_factory, sock, ssl, server)
+ return server
+
+ @tasks.coroutine
+ def connect_read_pipe(self, protocol_factory, pipe):
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ transport = self._make_read_pipe_transport(pipe, protocol, waiter)
+ yield from waiter
+ return transport, protocol
+
+ @tasks.coroutine
+ def connect_write_pipe(self, protocol_factory, pipe):
+ protocol = protocol_factory()
+ waiter = futures.Future(loop=self)
+ transport = self._make_write_pipe_transport(pipe, protocol, waiter)
+ yield from waiter
+ return transport, protocol
+
+ @tasks.coroutine
+ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ universal_newlines=False, shell=True, bufsize=0,
+ **kwargs):
+ assert not universal_newlines, "universal_newlines must be False"
+ assert shell, "shell must be True"
+ assert isinstance(cmd, str), cmd
+ protocol = protocol_factory()
+ transport = yield from self._make_subprocess_transport(
+ protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs)
+ return transport, protocol
+
+ @tasks.coroutine
+ def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ universal_newlines=False, shell=False, bufsize=0,
+ **kwargs):
+ assert not universal_newlines, "universal_newlines must be False"
+ assert not shell, "shell must be False"
+ protocol = protocol_factory()
+ transport = yield from self._make_subprocess_transport(
+ protocol, args, False, stdin, stdout, stderr, bufsize, **kwargs)
+ return transport, protocol
+
+ def _add_callback(self, handle):
+ """Add a Handle to ready or scheduled."""
+ assert isinstance(handle, events.Handle), 'A Handle is required here'
+ if handle._cancelled:
+ return
+ if isinstance(handle, events.TimerHandle):
+ heapq.heappush(self._scheduled, handle)
+ else:
+ self._ready.append(handle)
+
+ def _add_callback_signalsafe(self, handle):
+ """Like _add_callback() but called from a signal handler."""
+ self._add_callback(handle)
+ self._write_to_self()
+
+ def _run_once(self):
+ """Run one full iteration of the event loop.
+
+ This calls all currently ready callbacks, polls for I/O,
+ schedules the resulting callbacks, and finally schedules
+ 'call_later' callbacks.
+ """
+ # Remove delayed calls that were cancelled from head of queue.
+ while self._scheduled and self._scheduled[0]._cancelled:
+ heapq.heappop(self._scheduled)
+
+ timeout = None
+ if self._ready:
+ timeout = 0
+ elif self._scheduled:
+ # Compute the desired timeout.
+ when = self._scheduled[0]._when
+ deadline = max(0, when - self.time())
+ if timeout is None:
+ timeout = deadline
+ else:
+ timeout = min(timeout, deadline)
+
+ # TODO: Instrumentation only in debug mode?
+ t0 = self.time()
+ event_list = self._selector.select(timeout)
+ t1 = self.time()
+ argstr = '' if timeout is None else '{:.3f}'.format(timeout)
+ if t1-t0 >= 1:
+ level = logging.INFO
+ else:
+ level = logging.DEBUG
+ asyncio_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0)
+ self._process_events(event_list)
+
+ # Handle 'later' callbacks that are ready.
+ now = self.time()
+ while self._scheduled:
+ handle = self._scheduled[0]
+ if handle._when > now:
+ break
+ handle = heapq.heappop(self._scheduled)
+ self._ready.append(handle)
+
+ # This is the only place where callbacks are actually *called*.
+ # All other places just add them to ready.
+ # Note: We run all currently scheduled callbacks, but not any
+ # callbacks scheduled by callbacks run this time around --
+ # they will be run the next time (after another I/O poll).
+ # Use an idiom that is threadsafe without using locks.
+ ntodo = len(self._ready)
+ for i in range(ntodo):
+ handle = self._ready.popleft()
+ if not handle._cancelled:
+ handle._run()
+ handle = None # Needed to break cycles when an exception occurs.
diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py
new file mode 100644
index 0000000..79c3b93
--- /dev/null
+++ b/Lib/asyncio/constants.py
@@ -0,0 +1,4 @@
+"""Constants."""
+
+
+LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
new file mode 100644
index 0000000..9724615
--- /dev/null
+++ b/Lib/asyncio/events.py
@@ -0,0 +1,395 @@
+"""Event loop and event loop policy."""
+
+__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy',
+ 'AbstractEventLoop', 'AbstractServer',
+ 'Handle', 'TimerHandle',
+ 'get_event_loop_policy', 'set_event_loop_policy',
+ 'get_event_loop', 'set_event_loop', 'new_event_loop',
+ ]
+
+import subprocess
+import sys
+import threading
+import socket
+
+from .log import asyncio_log
+
+
+class Handle:
+ """Object returned by callback registration methods."""
+
+ def __init__(self, callback, args):
+ self._callback = callback
+ self._args = args
+ self._cancelled = False
+
+ def __repr__(self):
+ res = 'Handle({}, {})'.format(self._callback, self._args)
+ if self._cancelled:
+ res += '<cancelled>'
+ return res
+
+ def cancel(self):
+ self._cancelled = True
+
+ def _run(self):
+ try:
+ self._callback(*self._args)
+ except Exception:
+ asyncio_log.exception('Exception in callback %s %r',
+ self._callback, self._args)
+ self = None # Needed to break cycles when an exception occurs.
+
+
+def make_handle(callback, args):
+ # TODO: Inline this?
+ assert not isinstance(callback, Handle), 'A Handle is not a callback'
+ return Handle(callback, args)
+
+
+class TimerHandle(Handle):
+ """Object returned by timed callback registration methods."""
+
+ def __init__(self, when, callback, args):
+ assert when is not None
+ super().__init__(callback, args)
+
+ self._when = when
+
+ def __repr__(self):
+ res = 'TimerHandle({}, {}, {})'.format(self._when,
+ self._callback,
+ self._args)
+ if self._cancelled:
+ res += '<cancelled>'
+
+ return res
+
+ def __hash__(self):
+ return hash(self._when)
+
+ def __lt__(self, other):
+ return self._when < other._when
+
+ def __le__(self, other):
+ if self._when < other._when:
+ return True
+ return self.__eq__(other)
+
+ def __gt__(self, other):
+ return self._when > other._when
+
+ def __ge__(self, other):
+ if self._when > other._when:
+ return True
+ return self.__eq__(other)
+
+ def __eq__(self, other):
+ if isinstance(other, TimerHandle):
+ return (self._when == other._when and
+ self._callback == other._callback and
+ self._args == other._args and
+ self._cancelled == other._cancelled)
+ return NotImplemented
+
+ def __ne__(self, other):
+ equal = self.__eq__(other)
+ return NotImplemented if equal is NotImplemented else not equal
+
+
+class AbstractServer:
+ """Abstract server returned by create_service()."""
+
+ def close(self):
+ """Stop serving. This leaves existing connections open."""
+ return NotImplemented
+
+ def wait_closed(self):
+ """Coroutine to wait until service is closed."""
+ return NotImplemented
+
+
+class AbstractEventLoop:
+ """Abstract event loop."""
+
+ # Running and stopping the event loop.
+
+ def run_forever(self):
+ """Run the event loop until stop() is called."""
+ raise NotImplementedError
+
+ def run_until_complete(self, future):
+ """Run the event loop until a Future is done.
+
+ Return the Future's result, or raise its exception.
+ """
+ raise NotImplementedError
+
+ def stop(self):
+ """Stop the event loop as soon as reasonable.
+
+ Exactly how soon that is may depend on the implementation, but
+ no more I/O callbacks should be scheduled.
+ """
+ raise NotImplementedError
+
+ def is_running(self):
+ """Return whether the event loop is currently running."""
+ raise NotImplementedError
+
+ # Methods scheduling callbacks. All these return Handles.
+
+ def call_soon(self, callback, *args):
+ return self.call_later(0, callback, *args)
+
+ def call_later(self, delay, callback, *args):
+ raise NotImplementedError
+
+ def call_at(self, when, callback, *args):
+ raise NotImplementedError
+
+ def time(self):
+ raise NotImplementedError
+
+ # Methods for interacting with threads.
+
+ def call_soon_threadsafe(self, callback, *args):
+ raise NotImplementedError
+
+ def run_in_executor(self, executor, callback, *args):
+ raise NotImplementedError
+
+ def set_default_executor(self, executor):
+ raise NotImplementedError
+
+ # Network I/O methods returning Futures.
+
+ def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0):
+ raise NotImplementedError
+
+ def getnameinfo(self, sockaddr, flags=0):
+ raise NotImplementedError
+
+ def create_connection(self, protocol_factory, host=None, port=None, *,
+ ssl=None, family=0, proto=0, flags=0, sock=None,
+ local_addr=None):
+ raise NotImplementedError
+
+ def create_server(self, protocol_factory, host=None, port=None, *,
+ family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE,
+ sock=None, backlog=100, ssl=None, reuse_address=None):
+ """A coroutine which creates a TCP server bound to host and port.
+
+ The return value is a Server object which can be used to stop
+ the service.
+
+ If host is an empty string or None all interfaces are assumed
+ and a list of multiple sockets will be returned (most likely
+ one for IPv4 and another one for IPv6).
+
+ family can be set to either AF_INET or AF_INET6 to force the
+ socket to use IPv4 or IPv6. If not set it will be determined
+ from host (defaults to AF_UNSPEC).
+
+ flags is a bitmask for getaddrinfo().
+
+ sock can optionally be specified in order to use a preexisting
+ socket object.
+
+ backlog is the maximum number of queued connections passed to
+ listen() (defaults to 100).
+
+ ssl can be set to an SSLContext to enable SSL over the
+ accepted connections.
+
+ reuse_address tells the kernel to reuse a local socket in
+ TIME_WAIT state, without waiting for its natural timeout to
+ expire. If not specified will automatically be set to True on
+ UNIX.
+ """
+ raise NotImplementedError
+
+ def create_datagram_endpoint(self, protocol_factory,
+ local_addr=None, remote_addr=None, *,
+ family=0, proto=0, flags=0):
+ raise NotImplementedError
+
+ def connect_read_pipe(self, protocol_factory, pipe):
+ """Register read pipe in eventloop.
+
+ protocol_factory should instantiate object with Protocol interface.
+ pipe is file-like object already switched to nonblocking.
+ Return pair (transport, protocol), where transport support
+ ReadTransport ABC"""
+ # The reason to accept file-like object instead of just file descriptor
+ # is: we need to own pipe and close it at transport finishing
+ # Can got complicated errors if pass f.fileno(),
+ # close fd in pipe transport then close f and vise versa.
+ raise NotImplementedError
+
+ def connect_write_pipe(self, protocol_factory, pipe):
+ """Register write pipe in eventloop.
+
+ protocol_factory should instantiate object with BaseProtocol interface.
+ Pipe is file-like object already switched to nonblocking.
+ Return pair (transport, protocol), where transport support
+ WriteTransport ABC"""
+ # The reason to accept file-like object instead of just file descriptor
+ # is: we need to own pipe and close it at transport finishing
+ # Can got complicated errors if pass f.fileno(),
+ # close fd in pipe transport then close f and vise versa.
+ raise NotImplementedError
+
+ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ **kwargs):
+ raise NotImplementedError
+
+ def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+ **kwargs):
+ raise NotImplementedError
+
+ # Ready-based callback registration methods.
+ # The add_*() methods return None.
+ # The remove_*() methods return True if something was removed,
+ # False if there was nothing to delete.
+
+ def add_reader(self, fd, callback, *args):
+ raise NotImplementedError
+
+ def remove_reader(self, fd):
+ raise NotImplementedError
+
+ def add_writer(self, fd, callback, *args):
+ raise NotImplementedError
+
+ def remove_writer(self, fd):
+ raise NotImplementedError
+
+ # Completion based I/O methods returning Futures.
+
+ def sock_recv(self, sock, nbytes):
+ raise NotImplementedError
+
+ def sock_sendall(self, sock, data):
+ raise NotImplementedError
+
+ def sock_connect(self, sock, address):
+ raise NotImplementedError
+
+ def sock_accept(self, sock):
+ raise NotImplementedError
+
+ # Signal handling.
+
+ def add_signal_handler(self, sig, callback, *args):
+ raise NotImplementedError
+
+ def remove_signal_handler(self, sig):
+ raise NotImplementedError
+
+
+class AbstractEventLoopPolicy:
+ """Abstract policy for accessing the event loop."""
+
+ def get_event_loop(self):
+ """XXX"""
+ raise NotImplementedError
+
+ def set_event_loop(self, loop):
+ """XXX"""
+ raise NotImplementedError
+
+ def new_event_loop(self):
+ """XXX"""
+ raise NotImplementedError
+
+
+class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy):
+ """Default policy implementation for accessing the event loop.
+
+ In this policy, each thread has its own event loop. However, we
+ only automatically create an event loop by default for the main
+ thread; other threads by default have no event loop.
+
+ Other policies may have different rules (e.g. a single global
+ event loop, or automatically creating an event loop per thread, or
+ using some other notion of context to which an event loop is
+ associated).
+ """
+
+ _loop = None
+ _set_called = False
+
+ def get_event_loop(self):
+ """Get the event loop.
+
+ This may be None or an instance of EventLoop.
+ """
+ if (self._loop is None and
+ not self._set_called and
+ isinstance(threading.current_thread(), threading._MainThread)):
+ self._loop = self.new_event_loop()
+ assert self._loop is not None, \
+ ('There is no current event loop in thread %r.' %
+ threading.current_thread().name)
+ return self._loop
+
+ def set_event_loop(self, loop):
+ """Set the event loop."""
+ # TODO: The isinstance() test violates the PEP.
+ self._set_called = True
+ assert loop is None or isinstance(loop, AbstractEventLoop)
+ self._loop = loop
+
+ def new_event_loop(self):
+ """Create a new event loop.
+
+ You must call set_event_loop() to make this the current event
+ loop.
+ """
+ if sys.platform == 'win32': # pragma: no cover
+ from . import windows_events
+ return windows_events.SelectorEventLoop()
+ else: # pragma: no cover
+ from . import unix_events
+ return unix_events.SelectorEventLoop()
+
+
+# Event loop policy. The policy itself is always global, even if the
+# policy's rules say that there is an event loop per thread (or other
+# notion of context). The default policy is installed by the first
+# call to get_event_loop_policy().
+_event_loop_policy = None
+
+
+def get_event_loop_policy():
+ """XXX"""
+ global _event_loop_policy
+ if _event_loop_policy is None:
+ _event_loop_policy = DefaultEventLoopPolicy()
+ return _event_loop_policy
+
+
+def set_event_loop_policy(policy):
+ """XXX"""
+ global _event_loop_policy
+ # TODO: The isinstance() test violates the PEP.
+ assert policy is None or isinstance(policy, AbstractEventLoopPolicy)
+ _event_loop_policy = policy
+
+
+def get_event_loop():
+ """XXX"""
+ return get_event_loop_policy().get_event_loop()
+
+
+def set_event_loop(loop):
+ """XXX"""
+ get_event_loop_policy().set_event_loop(loop)
+
+
+def new_event_loop():
+ """XXX"""
+ return get_event_loop_policy().new_event_loop()
diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py
new file mode 100644
index 0000000..99a043b
--- /dev/null
+++ b/Lib/asyncio/futures.py
@@ -0,0 +1,338 @@
+"""A Future class similar to the one in PEP 3148."""
+
+__all__ = ['CancelledError', 'TimeoutError',
+ 'InvalidStateError',
+ 'Future', 'wrap_future',
+ ]
+
+import concurrent.futures._base
+import logging
+import traceback
+
+from . import events
+from .log import asyncio_log
+
+# States for Future.
+_PENDING = 'PENDING'
+_CANCELLED = 'CANCELLED'
+_FINISHED = 'FINISHED'
+
+# TODO: Do we really want to depend on concurrent.futures internals?
+Error = concurrent.futures._base.Error
+CancelledError = concurrent.futures.CancelledError
+TimeoutError = concurrent.futures.TimeoutError
+
+STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging
+
+
+class InvalidStateError(Error):
+ """The operation is not allowed in this state."""
+ # TODO: Show the future, its state, the method, and the required state.
+
+
+class _TracebackLogger:
+ """Helper to log a traceback upon destruction if not cleared.
+
+ This solves a nasty problem with Futures and Tasks that have an
+ exception set: if nobody asks for the exception, the exception is
+ never logged. This violates the Zen of Python: 'Errors should
+ never pass silently. Unless explicitly silenced.'
+
+ However, we don't want to log the exception as soon as
+ set_exception() is called: if the calling code is written
+ properly, it will get the exception and handle it properly. But
+ we *do* want to log it if result() or exception() was never called
+ -- otherwise developers waste a lot of time wondering why their
+ buggy code fails silently.
+
+ An earlier attempt added a __del__() method to the Future class
+ itself, but this backfired because the presence of __del__()
+ prevents garbage collection from breaking cycles. A way out of
+ this catch-22 is to avoid having a __del__() method on the Future
+ class itself, but instead to have a reference to a helper object
+ with a __del__() method that logs the traceback, where we ensure
+ that the helper object doesn't participate in cycles, and only the
+ Future has a reference to it.
+
+ The helper object is added when set_exception() is called. When
+ the Future is collected, and the helper is present, the helper
+ object is also collected, and its __del__() method will log the
+ traceback. When the Future's result() or exception() method is
+ called (and a helper object is present), it removes the the helper
+ object, after calling its clear() method to prevent it from
+ logging.
+
+ One downside is that we do a fair amount of work to extract the
+ traceback from the exception, even when it is never logged. It
+ would seem cheaper to just store the exception object, but that
+ references the traceback, which references stack frames, which may
+ reference the Future, which references the _TracebackLogger, and
+ then the _TracebackLogger would be included in a cycle, which is
+ what we're trying to avoid! As an optimization, we don't
+ immediately format the exception; we only do the work when
+ activate() is called, which call is delayed until after all the
+ Future's callbacks have run. Since usually a Future has at least
+ one callback (typically set by 'yield from') and usually that
+ callback extracts the callback, thereby removing the need to
+ format the exception.
+
+ PS. I don't claim credit for this solution. I first heard of it
+ in a discussion about closing files when they are collected.
+ """
+
+ __slots__ = ['exc', 'tb']
+
+ def __init__(self, exc):
+ self.exc = exc
+ self.tb = None
+
+ def activate(self):
+ exc = self.exc
+ if exc is not None:
+ self.exc = None
+ self.tb = traceback.format_exception(exc.__class__, exc,
+ exc.__traceback__)
+
+ def clear(self):
+ self.exc = None
+ self.tb = None
+
+ def __del__(self):
+ if self.tb:
+ asyncio_log.error('Future/Task exception was never retrieved:\n%s',
+ ''.join(self.tb))
+
+
+class Future:
+ """This class is *almost* compatible with concurrent.futures.Future.
+
+ Differences:
+
+ - result() and exception() do not take a timeout argument and
+ raise an exception when the future isn't done yet.
+
+ - Callbacks registered with add_done_callback() are always called
+ via the event loop's call_soon_threadsafe().
+
+ - This class is not compatible with the wait() and as_completed()
+ methods in the concurrent.futures package.
+
+ (In Python 3.4 or later we may be able to unify the implementations.)
+ """
+
+ # Class variables serving as defaults for instance variables.
+ _state = _PENDING
+ _result = None
+ _exception = None
+ _loop = None
+
+ _blocking = False # proper use of future (yield vs yield from)
+
+ _tb_logger = None
+
+ def __init__(self, *, loop=None):
+ """Initialize the future.
+
+ The optional event_loop argument allows to explicitly set the event
+ loop object used by the future. If it's not provided, the future uses
+ the default event loop.
+ """
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._callbacks = []
+
+ def __repr__(self):
+ res = self.__class__.__name__
+ if self._state == _FINISHED:
+ if self._exception is not None:
+ res += '<exception={!r}>'.format(self._exception)
+ else:
+ res += '<result={!r}>'.format(self._result)
+ elif self._callbacks:
+ size = len(self._callbacks)
+ if size > 2:
+ res += '<{}, [{}, <{} more>, {}]>'.format(
+ self._state, self._callbacks[0],
+ size-2, self._callbacks[-1])
+ else:
+ res += '<{}, {}>'.format(self._state, self._callbacks)
+ else:
+ res += '<{}>'.format(self._state)
+ return res
+
+ def cancel(self):
+ """Cancel the future and schedule callbacks.
+
+ If the future is already done or cancelled, return False. Otherwise,
+ change the future's state to cancelled, schedule the callbacks and
+ return True.
+ """
+ if self._state != _PENDING:
+ return False
+ self._state = _CANCELLED
+ self._schedule_callbacks()
+ return True
+
+ def _schedule_callbacks(self):
+ """Internal: Ask the event loop to call all callbacks.
+
+ The callbacks are scheduled to be called as soon as possible. Also
+ clears the callback list.
+ """
+ callbacks = self._callbacks[:]
+ if not callbacks:
+ return
+
+ self._callbacks[:] = []
+ for callback in callbacks:
+ self._loop.call_soon(callback, self)
+
+ def cancelled(self):
+ """Return True if the future was cancelled."""
+ return self._state == _CANCELLED
+
+ # Don't implement running(); see http://bugs.python.org/issue18699
+
+ def done(self):
+ """Return True if the future is done.
+
+ Done means either that a result / exception are available, or that the
+ future was cancelled.
+ """
+ return self._state != _PENDING
+
+ def result(self):
+ """Return the result this future represents.
+
+ If the future has been cancelled, raises CancelledError. If the
+ future's result isn't yet available, raises InvalidStateError. If
+ the future is done and has an exception set, this exception is raised.
+ """
+ if self._state == _CANCELLED:
+ raise CancelledError
+ if self._state != _FINISHED:
+ raise InvalidStateError('Result is not ready.')
+ if self._tb_logger is not None:
+ self._tb_logger.clear()
+ self._tb_logger = None
+ if self._exception is not None:
+ raise self._exception
+ return self._result
+
+ def exception(self):
+ """Return the exception that was set on this future.
+
+ The exception (or None if no exception was set) is returned only if
+ the future is done. If the future has been cancelled, raises
+ CancelledError. If the future isn't done yet, raises
+ InvalidStateError.
+ """
+ if self._state == _CANCELLED:
+ raise CancelledError
+ if self._state != _FINISHED:
+ raise InvalidStateError('Exception is not set.')
+ if self._tb_logger is not None:
+ self._tb_logger.clear()
+ self._tb_logger = None
+ return self._exception
+
+ def add_done_callback(self, fn):
+ """Add a callback to be run when the future becomes done.
+
+ The callback is called with a single argument - the future object. If
+ the future is already done when this is called, the callback is
+ scheduled with call_soon.
+ """
+ if self._state != _PENDING:
+ self._loop.call_soon(fn, self)
+ else:
+ self._callbacks.append(fn)
+
+ # New method not in PEP 3148.
+
+ def remove_done_callback(self, fn):
+ """Remove all instances of a callback from the "call when done" list.
+
+ Returns the number of callbacks removed.
+ """
+ filtered_callbacks = [f for f in self._callbacks if f != fn]
+ removed_count = len(self._callbacks) - len(filtered_callbacks)
+ if removed_count:
+ self._callbacks[:] = filtered_callbacks
+ return removed_count
+
+ # So-called internal methods (note: no set_running_or_notify_cancel()).
+
+ def set_result(self, result):
+ """Mark the future done and set its result.
+
+ If the future is already done when this method is called, raises
+ InvalidStateError.
+ """
+ if self._state != _PENDING:
+ raise InvalidStateError('{}: {!r}'.format(self._state, self))
+ self._result = result
+ self._state = _FINISHED
+ self._schedule_callbacks()
+
+ def set_exception(self, exception):
+ """Mark the future done and set an exception.
+
+ If the future is already done when this method is called, raises
+ InvalidStateError.
+ """
+ if self._state != _PENDING:
+ raise InvalidStateError('{}: {!r}'.format(self._state, self))
+ self._exception = exception
+ self._tb_logger = _TracebackLogger(exception)
+ self._state = _FINISHED
+ self._schedule_callbacks()
+ # Arrange for the logger to be activated after all callbacks
+ # have had a chance to call result() or exception().
+ self._loop.call_soon(self._tb_logger.activate)
+
+ # Truly internal methods.
+
+ def _copy_state(self, other):
+ """Internal helper to copy state from another Future.
+
+ The other Future may be a concurrent.futures.Future.
+ """
+ assert other.done()
+ assert not self.done()
+ if other.cancelled():
+ self.cancel()
+ else:
+ exception = other.exception()
+ if exception is not None:
+ self.set_exception(exception)
+ else:
+ result = other.result()
+ self.set_result(result)
+
+ def __iter__(self):
+ if not self.done():
+ self._blocking = True
+ yield self # This tells Task to wait for completion.
+ assert self.done(), "yield from wasn't used with future"
+ return self.result() # May raise too.
+
+
+def wrap_future(fut, *, loop=None):
+ """Wrap concurrent.futures.Future object."""
+ if isinstance(fut, Future):
+ return fut
+
+ assert isinstance(fut, concurrent.futures.Future), \
+ 'concurrent.futures.Future is expected, got {!r}'.format(fut)
+
+ if loop is None:
+ loop = events.get_event_loop()
+
+ new_future = Future(loop=loop)
+ fut.add_done_callback(
+ lambda future: loop.call_soon_threadsafe(
+ new_future._copy_state, fut))
+ return new_future
diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py
new file mode 100644
index 0000000..06edbbc
--- /dev/null
+++ b/Lib/asyncio/locks.py
@@ -0,0 +1,401 @@
+"""Synchronization primitives."""
+
+__all__ = ['Lock', 'Event', 'Condition', 'Semaphore']
+
+import collections
+
+from . import events
+from . import futures
+from . import tasks
+
+
+class Lock:
+ """Primitive lock objects.
+
+ A primitive lock is a synchronization primitive that is not owned
+ by a particular coroutine when locked. A primitive lock is in one
+ of two states, 'locked' or 'unlocked'.
+
+ It is created in the unlocked state. It has two basic methods,
+ acquire() and release(). When the state is unlocked, acquire()
+ changes the state to locked and returns immediately. When the
+ state is locked, acquire() blocks until a call to release() in
+ another coroutine changes it to unlocked, then the acquire() call
+ resets it to locked and returns. The release() method should only
+ be called in the locked state; it changes the state to unlocked
+ and returns immediately. If an attempt is made to release an
+ unlocked lock, a RuntimeError will be raised.
+
+ When more than one coroutine is blocked in acquire() waiting for
+ the state to turn to unlocked, only one coroutine proceeds when a
+ release() call resets the state to unlocked; first coroutine which
+ is blocked in acquire() is being processed.
+
+ acquire() is a coroutine and should be called with 'yield from'.
+
+ Locks also support the context manager protocol. '(yield from lock)'
+ should be used as context manager expression.
+
+ Usage:
+
+ lock = Lock()
+ ...
+ yield from lock
+ try:
+ ...
+ finally:
+ lock.release()
+
+ Context manager usage:
+
+ lock = Lock()
+ ...
+ with (yield from lock):
+ ...
+
+ Lock objects can be tested for locking state:
+
+ if not lock.locked():
+ yield from lock
+ else:
+ # lock is acquired
+ ...
+
+ """
+
+ def __init__(self, *, loop=None):
+ self._waiters = collections.deque()
+ self._locked = False
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = events.get_event_loop()
+
+ def __repr__(self):
+ res = super().__repr__()
+ extra = 'locked' if self._locked else 'unlocked'
+ if self._waiters:
+ extra = '{},waiters:{}'.format(extra, len(self._waiters))
+ return '<{} [{}]>'.format(res[1:-1], extra)
+
+ def locked(self):
+ """Return true if lock is acquired."""
+ return self._locked
+
+ @tasks.coroutine
+ def acquire(self):
+ """Acquire a lock.
+
+ This method blocks until the lock is unlocked, then sets it to
+ locked and returns True.
+ """
+ if not self._waiters and not self._locked:
+ self._locked = True
+ return True
+
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ self._locked = True
+ return True
+ finally:
+ self._waiters.remove(fut)
+
+ def release(self):
+ """Release a lock.
+
+ When the lock is locked, reset it to unlocked, and return.
+ If any other coroutines are blocked waiting for the lock to become
+ unlocked, allow exactly one of them to proceed.
+
+ When invoked on an unlocked lock, a RuntimeError is raised.
+
+ There is no return value.
+ """
+ if self._locked:
+ self._locked = False
+ # Wake up the first waiter who isn't cancelled.
+ for fut in self._waiters:
+ if not fut.done():
+ fut.set_result(True)
+ break
+ else:
+ raise RuntimeError('Lock is not acquired.')
+
+ def __enter__(self):
+ if not self._locked:
+ raise RuntimeError(
+ '"yield from" should be used as context manager expression')
+ return True
+
+ def __exit__(self, *args):
+ self.release()
+
+ def __iter__(self):
+ yield from self.acquire()
+ return self
+
+
+class Event:
+ """An Event implementation, our equivalent to threading.Event.
+
+ Class implementing event objects. An event manages a flag that can be set
+ to true with the set() method and reset to false with the clear() method.
+ The wait() method blocks until the flag is true. The flag is initially
+ false.
+ """
+
+ def __init__(self, *, loop=None):
+ self._waiters = collections.deque()
+ self._value = False
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = events.get_event_loop()
+
+ def __repr__(self):
+ # TODO: add waiters:N if > 0.
+ res = super().__repr__()
+ return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset')
+
+ def is_set(self):
+ """Return true if and only if the internal flag is true."""
+ return self._value
+
+ def set(self):
+ """Set the internal flag to true. All coroutines waiting for it to
+ become true are awakened. Coroutine that call wait() once the flag is
+ true will not block at all.
+ """
+ if not self._value:
+ self._value = True
+
+ for fut in self._waiters:
+ if not fut.done():
+ fut.set_result(True)
+
+ def clear(self):
+ """Reset the internal flag to false. Subsequently, coroutines calling
+ wait() will block until set() is called to set the internal flag
+ to true again."""
+ self._value = False
+
+ @tasks.coroutine
+ def wait(self):
+ """Block until the internal flag is true.
+
+ If the internal flag is true on entry, return True
+ immediately. Otherwise, block until another coroutine calls
+ set() to set the flag to true, then return True.
+ """
+ if self._value:
+ return True
+
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ return True
+ finally:
+ self._waiters.remove(fut)
+
+
+# TODO: Why is this a Lock subclass? threading.Condition *has* a lock.
+class Condition(Lock):
+ """A Condition implementation.
+
+ This class implements condition variable objects. A condition variable
+ allows one or more coroutines to wait until they are notified by another
+ coroutine.
+ """
+
+ def __init__(self, *, loop=None):
+ super().__init__(loop=loop)
+ self._condition_waiters = collections.deque()
+
+ # TODO: Add __repr__() with len(_condition_waiters).
+
+ @tasks.coroutine
+ def wait(self):
+ """Wait until notified.
+
+ If the calling coroutine has not acquired the lock when this
+ method is called, a RuntimeError is raised.
+
+ This method releases the underlying lock, and then blocks
+ until it is awakened by a notify() or notify_all() call for
+ the same condition variable in another coroutine. Once
+ awakened, it re-acquires the lock and returns True.
+ """
+ if not self._locked:
+ raise RuntimeError('cannot wait on un-acquired lock')
+
+ keep_lock = True
+ self.release()
+ try:
+ fut = futures.Future(loop=self._loop)
+ self._condition_waiters.append(fut)
+ try:
+ yield from fut
+ return True
+ finally:
+ self._condition_waiters.remove(fut)
+
+ except GeneratorExit:
+ keep_lock = False # Prevent yield in finally clause.
+ raise
+ finally:
+ if keep_lock:
+ yield from self.acquire()
+
+ @tasks.coroutine
+ def wait_for(self, predicate):
+ """Wait until a predicate becomes true.
+
+ The predicate should be a callable which result will be
+ interpreted as a boolean value. The final predicate value is
+ the return value.
+ """
+ result = predicate()
+ while not result:
+ yield from self.wait()
+ result = predicate()
+ return result
+
+ def notify(self, n=1):
+ """By default, wake up one coroutine waiting on this condition, if any.
+ If the calling coroutine has not acquired the lock when this method
+ is called, a RuntimeError is raised.
+
+ This method wakes up at most n of the coroutines waiting for the
+ condition variable; it is a no-op if no coroutines are waiting.
+
+ Note: an awakened coroutine does not actually return from its
+ wait() call until it can reacquire the lock. Since notify() does
+ not release the lock, its caller should.
+ """
+ if not self._locked:
+ raise RuntimeError('cannot notify on un-acquired lock')
+
+ idx = 0
+ for fut in self._condition_waiters:
+ if idx >= n:
+ break
+
+ if not fut.done():
+ idx += 1
+ fut.set_result(False)
+
+ def notify_all(self):
+ """Wake up all threads waiting on this condition. This method acts
+ like notify(), but wakes up all waiting threads instead of one. If the
+ calling thread has not acquired the lock when this method is called,
+ a RuntimeError is raised.
+ """
+ self.notify(len(self._condition_waiters))
+
+
+class Semaphore:
+ """A Semaphore implementation.
+
+ A semaphore manages an internal counter which is decremented by each
+ acquire() call and incremented by each release() call. The counter
+ can never go below zero; when acquire() finds that it is zero, it blocks,
+ waiting until some other thread calls release().
+
+ Semaphores also support the context manager protocol.
+
+ The first optional argument gives the initial value for the internal
+ counter; it defaults to 1. If the value given is less than 0,
+ ValueError is raised.
+
+ The second optional argument determins can semophore be released more than
+ initial internal counter value; it defaults to False. If the value given
+ is True and number of release() is more than number of successfull
+ acquire() calls ValueError is raised.
+ """
+
+ def __init__(self, value=1, bound=False, *, loop=None):
+ if value < 0:
+ raise ValueError("Semaphore initial value must be > 0")
+ self._value = value
+ self._bound = bound
+ self._bound_value = value
+ self._waiters = collections.deque()
+ self._locked = False
+ if loop is not None:
+ self._loop = loop
+ else:
+ self._loop = events.get_event_loop()
+
+ def __repr__(self):
+ # TODO: add waiters:N if > 0.
+ res = super().__repr__()
+ return '<{} [{}]>'.format(
+ res[1:-1],
+ 'locked' if self._locked else 'unlocked,value:{}'.format(
+ self._value))
+
+ def locked(self):
+ """Returns True if semaphore can not be acquired immediately."""
+ return self._locked
+
+ @tasks.coroutine
+ def acquire(self):
+ """Acquire a semaphore.
+
+ If the internal counter is larger than zero on entry,
+ decrement it by one and return True immediately. If it is
+ zero on entry, block, waiting until some other coroutine has
+ called release() to make it larger than 0, and then return
+ True.
+ """
+ if not self._waiters and self._value > 0:
+ self._value -= 1
+ if self._value == 0:
+ self._locked = True
+ return True
+
+ fut = futures.Future(loop=self._loop)
+ self._waiters.append(fut)
+ try:
+ yield from fut
+ self._value -= 1
+ if self._value == 0:
+ self._locked = True
+ return True
+ finally:
+ self._waiters.remove(fut)
+
+ def release(self):
+ """Release a semaphore, incrementing the internal counter by one.
+ When it was zero on entry and another coroutine is waiting for it to
+ become larger than zero again, wake up that coroutine.
+
+ If Semaphore is create with "bound" paramter equals true, then
+ release() method checks to make sure its current value doesn't exceed
+ its initial value. If it does, ValueError is raised.
+ """
+ if self._bound and self._value >= self._bound_value:
+ raise ValueError('Semaphore released too many times')
+
+ self._value += 1
+ self._locked = False
+
+ for waiter in self._waiters:
+ if not waiter.done():
+ waiter.set_result(True)
+ break
+
+ def __enter__(self):
+ # TODO: This is questionable. How do we know the user actually
+ # wrote "with (yield from sema)" instead of "with sema"?
+ return True
+
+ def __exit__(self, *args):
+ self.release()
+
+ def __iter__(self):
+ yield from self.acquire()
+ return self
diff --git a/Lib/asyncio/log.py b/Lib/asyncio/log.py
new file mode 100644
index 0000000..54dc784
--- /dev/null
+++ b/Lib/asyncio/log.py
@@ -0,0 +1,6 @@
+"""Logging configuration."""
+
+import logging
+
+
+asyncio_log = logging.getLogger("asyncio")
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
new file mode 100644
index 0000000..348de03
--- /dev/null
+++ b/Lib/asyncio/proactor_events.py
@@ -0,0 +1,352 @@
+"""Event loop using a proactor and related classes.
+
+A proactor is a "notify-on-completion" multiplexer. Currently a
+proactor is only implemented on Windows with IOCP.
+"""
+
+import socket
+
+from . import base_events
+from . import constants
+from . import futures
+from . import transports
+from .log import asyncio_log
+
+
+class _ProactorBasePipeTransport(transports.BaseTransport):
+ """Base class for pipe and socket transports."""
+
+ def __init__(self, loop, sock, protocol, waiter=None,
+ extra=None, server=None):
+ super().__init__(extra)
+ self._set_extra(sock)
+ self._loop = loop
+ self._sock = sock
+ self._protocol = protocol
+ self._server = server
+ self._buffer = []
+ self._read_fut = None
+ self._write_fut = None
+ self._conn_lost = 0
+ self._closing = False # Set when close() called.
+ self._eof_written = False
+ if self._server is not None:
+ self._server.attach(self)
+ self._loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ self._loop.call_soon(waiter.set_result, None)
+
+ def _set_extra(self, sock):
+ self._extra['pipe'] = sock
+
+ def close(self):
+ if self._closing:
+ return
+ self._closing = True
+ self._conn_lost += 1
+ if not self._buffer and self._write_fut is None:
+ self._loop.call_soon(self._call_connection_lost, None)
+ if self._read_fut is not None:
+ self._read_fut.cancel()
+
+ def _fatal_error(self, exc):
+ asyncio_log.exception('Fatal error for %s', self)
+ self._force_close(exc)
+
+ def _force_close(self, exc):
+ if self._closing:
+ return
+ self._closing = True
+ self._conn_lost += 1
+ if self._write_fut:
+ self._write_fut.cancel()
+ if self._read_fut:
+ self._read_fut.cancel()
+ self._write_fut = self._read_fut = None
+ self._buffer = []
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ # XXX If there is a pending overlapped read on the other
+ # end then it may fail with ERROR_NETNAME_DELETED if we
+ # just close our end. First calling shutdown() seems to
+ # cure it, but maybe using DisconnectEx() would be better.
+ if hasattr(self._sock, 'shutdown'):
+ self._sock.shutdown(socket.SHUT_RDWR)
+ self._sock.close()
+ server = self._server
+ if server is not None:
+ server.detach(self)
+ self._server = None
+
+
+class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
+ transports.ReadTransport):
+ """Transport for read pipes."""
+
+ def __init__(self, loop, sock, protocol, waiter=None,
+ extra=None, server=None):
+ super().__init__(loop, sock, protocol, waiter, extra, server)
+ self._read_fut = None
+ self._paused = False
+ self._loop.call_soon(self._loop_reading)
+
+ def pause(self):
+ assert not self._closing, 'Cannot pause() when closing'
+ assert not self._paused, 'Already paused'
+ self._paused = True
+
+ def resume(self):
+ assert self._paused, 'Not paused'
+ self._paused = False
+ if self._closing:
+ return
+ self._loop.call_soon(self._loop_reading, self._read_fut)
+
+ def _loop_reading(self, fut=None):
+ if self._paused:
+ return
+ data = None
+
+ try:
+ if fut is not None:
+ assert self._read_fut is fut or (self._read_fut is None and
+ self._closing)
+ self._read_fut = None
+ data = fut.result() # deliver data later in "finally" clause
+
+ if self._closing:
+ # since close() has been called we ignore any read data
+ data = None
+ return
+
+ if data == b'':
+ # we got end-of-file so no need to reschedule a new read
+ return
+
+ # reschedule a new read
+ self._read_fut = self._loop._proactor.recv(self._sock, 4096)
+ except ConnectionAbortedError as exc:
+ if not self._closing:
+ self._fatal_error(exc)
+ except ConnectionResetError as exc:
+ self._force_close(exc)
+ except OSError as exc:
+ self._fatal_error(exc)
+ except futures.CancelledError:
+ if not self._closing:
+ raise
+ else:
+ self._read_fut.add_done_callback(self._loop_reading)
+ finally:
+ if data:
+ self._protocol.data_received(data)
+ elif data is not None:
+ keep_open = self._protocol.eof_received()
+ if not keep_open:
+ self.close()
+
+
+class _ProactorWritePipeTransport(_ProactorBasePipeTransport,
+ transports.WriteTransport):
+ """Transport for write pipes."""
+
+ def write(self, data):
+ assert isinstance(data, bytes), repr(data)
+ if self._eof_written:
+ raise IOError('write_eof() already called')
+
+ if not data:
+ return
+
+ if self._conn_lost:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ asyncio_log.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+ self._buffer.append(data)
+ if self._write_fut is None:
+ self._loop_writing()
+
+ def _loop_writing(self, f=None):
+ try:
+ assert f is self._write_fut
+ self._write_fut = None
+ if f:
+ f.result()
+ data = b''.join(self._buffer)
+ self._buffer = []
+ if not data:
+ if self._closing:
+ self._loop.call_soon(self._call_connection_lost, None)
+ if self._eof_written:
+ self._sock.shutdown(socket.SHUT_WR)
+ return
+ self._write_fut = self._loop._proactor.send(self._sock, data)
+ self._write_fut.add_done_callback(self._loop_writing)
+ except ConnectionResetError as exc:
+ self._force_close(exc)
+ except OSError as exc:
+ self._fatal_error(exc)
+
+ def can_write_eof(self):
+ return True
+
+ def write_eof(self):
+ self.close()
+
+ def abort(self):
+ self._force_close(None)
+
+
+class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport,
+ _ProactorWritePipeTransport,
+ transports.Transport):
+ """Transport for duplex pipes."""
+
+ def can_write_eof(self):
+ return False
+
+ def write_eof(self):
+ raise NotImplementedError
+
+
+class _ProactorSocketTransport(_ProactorReadPipeTransport,
+ _ProactorWritePipeTransport,
+ transports.Transport):
+ """Transport for connected sockets."""
+
+ def _set_extra(self, sock):
+ self._extra['socket'] = sock
+ try:
+ self._extra['sockname'] = sock.getsockname()
+ except (socket.error, AttributeError):
+ pass
+ if 'peername' not in self._extra:
+ try:
+ self._extra['peername'] = sock.getpeername()
+ except (socket.error, AttributeError):
+ pass
+
+ def can_write_eof(self):
+ return True
+
+ def write_eof(self):
+ if self._closing or self._eof_written:
+ return
+ self._eof_written = True
+ if self._write_fut is None:
+ self._sock.shutdown(socket.SHUT_WR)
+
+
+class BaseProactorEventLoop(base_events.BaseEventLoop):
+
+ def __init__(self, proactor):
+ super().__init__()
+ asyncio_log.debug('Using proactor: %s', proactor.__class__.__name__)
+ self._proactor = proactor
+ self._selector = proactor # convenient alias
+ proactor.set_loop(self)
+ self._make_self_pipe()
+
+ def _make_socket_transport(self, sock, protocol, waiter=None,
+ extra=None, server=None):
+ return _ProactorSocketTransport(self, sock, protocol, waiter,
+ extra, server)
+
+ def _make_duplex_pipe_transport(self, sock, protocol, waiter=None,
+ extra=None):
+ return _ProactorDuplexPipeTransport(self,
+ sock, protocol, waiter, extra)
+
+ def _make_read_pipe_transport(self, sock, protocol, waiter=None,
+ extra=None):
+ return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra)
+
+ def _make_write_pipe_transport(self, sock, protocol, waiter=None,
+ extra=None):
+ return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra)
+
+ def close(self):
+ if self._proactor is not None:
+ self._close_self_pipe()
+ self._proactor.close()
+ self._proactor = None
+ self._selector = None
+
+ def sock_recv(self, sock, n):
+ return self._proactor.recv(sock, n)
+
+ def sock_sendall(self, sock, data):
+ return self._proactor.send(sock, data)
+
+ def sock_connect(self, sock, address):
+ return self._proactor.connect(sock, address)
+
+ def sock_accept(self, sock):
+ return self._proactor.accept(sock)
+
+ def _socketpair(self):
+ raise NotImplementedError
+
+ def _close_self_pipe(self):
+ self._ssock.close()
+ self._ssock = None
+ self._csock.close()
+ self._csock = None
+ self._internal_fds -= 1
+
+ def _make_self_pipe(self):
+ # A self-socket, really. :-)
+ self._ssock, self._csock = self._socketpair()
+ self._ssock.setblocking(False)
+ self._csock.setblocking(False)
+ self._internal_fds += 1
+ self.call_soon(self._loop_self_reading)
+
+ def _loop_self_reading(self, f=None):
+ try:
+ if f is not None:
+ f.result() # may raise
+ f = self._proactor.recv(self._ssock, 4096)
+ except:
+ self.close()
+ raise
+ else:
+ f.add_done_callback(self._loop_self_reading)
+
+ def _write_to_self(self):
+ self._csock.send(b'x')
+
+ def _start_serving(self, protocol_factory, sock, ssl=None, server=None):
+ assert not ssl, 'IocpEventLoop is incompatible with SSL.'
+
+ def loop(f=None):
+ try:
+ if f is not None:
+ conn, addr = f.result()
+ protocol = protocol_factory()
+ self._make_socket_transport(
+ conn, protocol,
+ extra={'peername': addr}, server=server)
+ f = self._proactor.accept(sock)
+ except OSError:
+ if sock.fileno() != -1:
+ asyncio_log.exception('Accept failed')
+ sock.close()
+ except futures.CancelledError:
+ sock.close()
+ else:
+ f.add_done_callback(loop)
+
+ self.call_soon(loop)
+
+ def _process_events(self, event_list):
+ pass # XXX hard work currently done in poll
+
+ def _stop_serving(self, sock):
+ self._proactor._stop_serving(sock)
+ sock.close()
diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py
new file mode 100644
index 0000000..a94abbe
--- /dev/null
+++ b/Lib/asyncio/protocols.py
@@ -0,0 +1,98 @@
+"""Abstract Protocol class."""
+
+__all__ = ['Protocol', 'DatagramProtocol']
+
+
+class BaseProtocol:
+ """ABC for base protocol class.
+
+ Usually user implements protocols that derived from BaseProtocol
+ like Protocol or ProcessProtocol.
+
+ The only case when BaseProtocol should be implemented directly is
+ write-only transport like write pipe
+ """
+
+ def connection_made(self, transport):
+ """Called when a connection is made.
+
+ The argument is the transport representing the pipe connection.
+ To receive data, wait for data_received() calls.
+ When the connection is closed, connection_lost() is called.
+ """
+
+ def connection_lost(self, exc):
+ """Called when the connection is lost or closed.
+
+ The argument is an exception object or None (the latter
+ meaning a regular EOF is received or the connection was
+ aborted or closed).
+ """
+
+
+class Protocol(BaseProtocol):
+ """ABC representing a protocol.
+
+ The user should implement this interface. They can inherit from
+ this class but don't need to. The implementations here do
+ nothing (they don't raise exceptions).
+
+ When the user wants to requests a transport, they pass a protocol
+ factory to a utility function (e.g., EventLoop.create_connection()).
+
+ When the connection is made successfully, connection_made() is
+ called with a suitable transport object. Then data_received()
+ will be called 0 or more times with data (bytes) received from the
+ transport; finally, connection_lost() will be called exactly once
+ with either an exception object or None as an argument.
+
+ State machine of calls:
+
+ start -> CM [-> DR*] [-> ER?] -> CL -> end
+ """
+
+ def data_received(self, data):
+ """Called when some data is received.
+
+ The argument is a bytes object.
+ """
+
+ def eof_received(self):
+ """Called when the other end calls write_eof() or equivalent.
+
+ If this returns a false value (including None), the transport
+ will close itself. If it returns a true value, closing the
+ transport is up to the protocol.
+ """
+
+
+class DatagramProtocol(BaseProtocol):
+ """ABC representing a datagram protocol."""
+
+ def datagram_received(self, data, addr):
+ """Called when some datagram is received."""
+
+ def connection_refused(self, exc):
+ """Connection is refused."""
+
+
+class SubprocessProtocol(BaseProtocol):
+ """ABC representing a protocol for subprocess calls."""
+
+ def pipe_data_received(self, fd, data):
+ """Called when subprocess write a data into stdout/stderr pipes.
+
+ fd is int file dascriptor.
+ data is bytes object.
+ """
+
+ def pipe_connection_lost(self, fd, exc):
+ """Called when a file descriptor associated with the child process is
+ closed.
+
+ fd is the int file descriptor that was closed.
+ """
+
+ def process_exited(self):
+ """Called when subprocess has exited.
+ """
diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py
new file mode 100644
index 0000000..536de1c
--- /dev/null
+++ b/Lib/asyncio/queues.py
@@ -0,0 +1,284 @@
+"""Queues"""
+
+__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue',
+ 'Full', 'Empty']
+
+import collections
+import heapq
+import queue
+
+from . import events
+from . import futures
+from . import locks
+from .tasks import coroutine
+
+
+# Re-export queue.Full and .Empty exceptions.
+Full = queue.Full
+Empty = queue.Empty
+
+
+class Queue:
+ """A queue, useful for coordinating producer and consumer coroutines.
+
+ If maxsize is less than or equal to zero, the queue size is infinite. If it
+ is an integer greater than 0, then "yield from put()" will block when the
+ queue reaches maxsize, until an item is removed by get().
+
+ Unlike the standard library Queue, you can reliably know this Queue's size
+ with qsize(), since your single-threaded Tulip application won't be
+ interrupted between calling qsize() and doing an operation on the Queue.
+ """
+
+ def __init__(self, maxsize=0, *, loop=None):
+ if loop is None:
+ self._loop = events.get_event_loop()
+ else:
+ self._loop = loop
+ self._maxsize = maxsize
+
+ # Futures.
+ self._getters = collections.deque()
+ # Pairs of (item, Future).
+ self._putters = collections.deque()
+ self._init(maxsize)
+
+ def _init(self, maxsize):
+ self._queue = collections.deque()
+
+ def _get(self):
+ return self._queue.popleft()
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ def __repr__(self):
+ return '<{} at {:#x} {}>'.format(
+ type(self).__name__, id(self), self._format())
+
+ def __str__(self):
+ return '<{} {}>'.format(type(self).__name__, self._format())
+
+ def _format(self):
+ result = 'maxsize={!r}'.format(self._maxsize)
+ if getattr(self, '_queue', None):
+ result += ' _queue={!r}'.format(list(self._queue))
+ if self._getters:
+ result += ' _getters[{}]'.format(len(self._getters))
+ if self._putters:
+ result += ' _putters[{}]'.format(len(self._putters))
+ return result
+
+ def _consume_done_getters(self):
+ # Delete waiters at the head of the get() queue who've timed out.
+ while self._getters and self._getters[0].done():
+ self._getters.popleft()
+
+ def _consume_done_putters(self):
+ # Delete waiters at the head of the put() queue who've timed out.
+ while self._putters and self._putters[0][1].done():
+ self._putters.popleft()
+
+ def qsize(self):
+ """Number of items in the queue."""
+ return len(self._queue)
+
+ @property
+ def maxsize(self):
+ """Number of items allowed in the queue."""
+ return self._maxsize
+
+ def empty(self):
+ """Return True if the queue is empty, False otherwise."""
+ return not self._queue
+
+ def full(self):
+ """Return True if there are maxsize items in the queue.
+
+ Note: if the Queue was initialized with maxsize=0 (the default),
+ then full() is never True.
+ """
+ if self._maxsize <= 0:
+ return False
+ else:
+ return self.qsize() == self._maxsize
+
+ @coroutine
+ def put(self, item):
+ """Put an item into the queue.
+
+ If you yield from put(), wait until a free slot is available
+ before adding item.
+ """
+ self._consume_done_getters()
+ if self._getters:
+ assert not self._queue, (
+ 'queue non-empty, why are getters waiting?')
+
+ getter = self._getters.popleft()
+
+ # Use _put and _get instead of passing item straight to getter, in
+ # case a subclass has logic that must run (e.g. JoinableQueue).
+ self._put(item)
+ getter.set_result(self._get())
+
+ elif self._maxsize > 0 and self._maxsize == self.qsize():
+ waiter = futures.Future(loop=self._loop)
+
+ self._putters.append((item, waiter))
+ yield from waiter
+
+ else:
+ self._put(item)
+
+ def put_nowait(self, item):
+ """Put an item into the queue without blocking.
+
+ If no free slot is immediately available, raise Full.
+ """
+ self._consume_done_getters()
+ if self._getters:
+ assert not self._queue, (
+ 'queue non-empty, why are getters waiting?')
+
+ getter = self._getters.popleft()
+
+ # Use _put and _get instead of passing item straight to getter, in
+ # case a subclass has logic that must run (e.g. JoinableQueue).
+ self._put(item)
+ getter.set_result(self._get())
+
+ elif self._maxsize > 0 and self._maxsize == self.qsize():
+ raise Full
+ else:
+ self._put(item)
+
+ @coroutine
+ def get(self):
+ """Remove and return an item from the queue.
+
+ If you yield from get(), wait until a item is available.
+ """
+ self._consume_done_putters()
+ if self._putters:
+ assert self.full(), 'queue not full, why are putters waiting?'
+ item, putter = self._putters.popleft()
+ self._put(item)
+
+ # When a getter runs and frees up a slot so this putter can
+ # run, we need to defer the put for a tick to ensure that
+ # getters and putters alternate perfectly. See
+ # ChannelTest.test_wait.
+ self._loop.call_soon(putter.set_result, None)
+
+ return self._get()
+
+ elif self.qsize():
+ return self._get()
+ else:
+ waiter = futures.Future(loop=self._loop)
+
+ self._getters.append(waiter)
+ return (yield from waiter)
+
+ def get_nowait(self):
+ """Remove and return an item from the queue.
+
+ Return an item if one is immediately available, else raise Full.
+ """
+ self._consume_done_putters()
+ if self._putters:
+ assert self.full(), 'queue not full, why are putters waiting?'
+ item, putter = self._putters.popleft()
+ self._put(item)
+ # Wake putter on next tick.
+ putter.set_result(None)
+
+ return self._get()
+
+ elif self.qsize():
+ return self._get()
+ else:
+ raise Empty
+
+
+class PriorityQueue(Queue):
+ """A subclass of Queue; retrieves entries in priority order (lowest first).
+
+ Entries are typically tuples of the form: (priority number, data).
+ """
+
+ def _init(self, maxsize):
+ self._queue = []
+
+ def _put(self, item, heappush=heapq.heappush):
+ heappush(self._queue, item)
+
+ def _get(self, heappop=heapq.heappop):
+ return heappop(self._queue)
+
+
+class LifoQueue(Queue):
+ """A subclass of Queue that retrieves most recently added entries first."""
+
+ def _init(self, maxsize):
+ self._queue = []
+
+ def _put(self, item):
+ self._queue.append(item)
+
+ def _get(self):
+ return self._queue.pop()
+
+
+class JoinableQueue(Queue):
+ """A subclass of Queue with task_done() and join() methods."""
+
+ def __init__(self, maxsize=0, *, loop=None):
+ super().__init__(maxsize=maxsize, loop=loop)
+ self._unfinished_tasks = 0
+ self._finished = locks.Event(loop=self._loop)
+ self._finished.set()
+
+ def _format(self):
+ result = Queue._format(self)
+ if self._unfinished_tasks:
+ result += ' tasks={}'.format(self._unfinished_tasks)
+ return result
+
+ def _put(self, item):
+ super()._put(item)
+ self._unfinished_tasks += 1
+ self._finished.clear()
+
+ def task_done(self):
+ """Indicate that a formerly enqueued task is complete.
+
+ Used by queue consumers. For each get() used to fetch a task,
+ a subsequent call to task_done() tells the queue that the processing
+ on the task is complete.
+
+ If a join() is currently blocking, it will resume when all items have
+ been processed (meaning that a task_done() call was received for every
+ item that had been put() into the queue).
+
+ Raises ValueError if called more times than there were items placed in
+ the queue.
+ """
+ if self._unfinished_tasks <= 0:
+ raise ValueError('task_done() called too many times')
+ self._unfinished_tasks -= 1
+ if self._unfinished_tasks == 0:
+ self._finished.set()
+
+ @coroutine
+ def join(self):
+ """Block until all items in the queue have been gotten and processed.
+
+ The count of unfinished tasks goes up whenever an item is added to the
+ queue. The count goes down whenever a consumer thread calls task_done()
+ to indicate that the item was retrieved and all work on it is complete.
+ When the count of unfinished tasks drops to zero, join() unblocks.
+ """
+ if self._unfinished_tasks > 0:
+ yield from self._finished.wait()
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
new file mode 100644
index 0000000..bae9a49
--- /dev/null
+++ b/Lib/asyncio/selector_events.py
@@ -0,0 +1,769 @@
+"""Event loop using a selector and related classes.
+
+A selector is a "notify-when-ready" multiplexer. For a subclass which
+also includes support for signal handling, see the unix_events sub-module.
+"""
+
+import collections
+import socket
+try:
+ import ssl
+except ImportError: # pragma: no cover
+ ssl = None
+
+from . import base_events
+from . import constants
+from . import events
+from . import futures
+from . import selectors
+from . import transports
+from .log import asyncio_log
+
+
+class BaseSelectorEventLoop(base_events.BaseEventLoop):
+ """Selector event loop.
+
+ See events.EventLoop for API specification.
+ """
+
+ def __init__(self, selector=None):
+ super().__init__()
+
+ if selector is None:
+ selector = selectors.DefaultSelector()
+ asyncio_log.debug('Using selector: %s', selector.__class__.__name__)
+ self._selector = selector
+ self._make_self_pipe()
+
+ def _make_socket_transport(self, sock, protocol, waiter=None, *,
+ extra=None, server=None):
+ return _SelectorSocketTransport(self, sock, protocol, waiter,
+ extra, server)
+
+ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *,
+ server_side=False, server_hostname=None,
+ extra=None, server=None):
+ return _SelectorSslTransport(
+ self, rawsock, protocol, sslcontext, waiter,
+ server_side, server_hostname, extra, server)
+
+ def _make_datagram_transport(self, sock, protocol,
+ address=None, extra=None):
+ return _SelectorDatagramTransport(self, sock, protocol, address, extra)
+
+ def close(self):
+ if self._selector is not None:
+ self._close_self_pipe()
+ self._selector.close()
+ self._selector = None
+
+ def _socketpair(self):
+ raise NotImplementedError
+
+ def _close_self_pipe(self):
+ self.remove_reader(self._ssock.fileno())
+ self._ssock.close()
+ self._ssock = None
+ self._csock.close()
+ self._csock = None
+ self._internal_fds -= 1
+
+ def _make_self_pipe(self):
+ # A self-socket, really. :-)
+ self._ssock, self._csock = self._socketpair()
+ self._ssock.setblocking(False)
+ self._csock.setblocking(False)
+ self._internal_fds += 1
+ self.add_reader(self._ssock.fileno(), self._read_from_self)
+
+ def _read_from_self(self):
+ try:
+ self._ssock.recv(1)
+ except (BlockingIOError, InterruptedError):
+ pass
+
+ def _write_to_self(self):
+ try:
+ self._csock.send(b'x')
+ except (BlockingIOError, InterruptedError):
+ pass
+
+ def _start_serving(self, protocol_factory, sock, ssl=None, server=None):
+ self.add_reader(sock.fileno(), self._accept_connection,
+ protocol_factory, sock, ssl, server)
+
+ def _accept_connection(self, protocol_factory, sock, ssl=None,
+ server=None):
+ try:
+ conn, addr = sock.accept()
+ conn.setblocking(False)
+ except (BlockingIOError, InterruptedError):
+ pass # False alarm.
+ except Exception:
+ # Bad error. Stop serving.
+ self.remove_reader(sock.fileno())
+ sock.close()
+ # There's nowhere to send the error, so just log it.
+ # TODO: Someone will want an error handler for this.
+ asyncio_log.exception('Accept failed')
+ else:
+ if ssl:
+ self._make_ssl_transport(
+ conn, protocol_factory(), ssl, None,
+ server_side=True, extra={'peername': addr}, server=server)
+ else:
+ self._make_socket_transport(
+ conn, protocol_factory(), extra={'peername': addr},
+ server=server)
+ # It's now up to the protocol to handle the connection.
+
+ def add_reader(self, fd, callback, *args):
+ """Add a reader callback."""
+ handle = events.make_handle(callback, args)
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ self._selector.register(fd, selectors.EVENT_READ,
+ (handle, None))
+ else:
+ mask, (reader, writer) = key.events, key.data
+ self._selector.modify(fd, mask | selectors.EVENT_READ,
+ (handle, writer))
+ if reader is not None:
+ reader.cancel()
+
+ def remove_reader(self, fd):
+ """Remove a reader callback."""
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ return False
+ else:
+ mask, (reader, writer) = key.events, key.data
+ mask &= ~selectors.EVENT_READ
+ if not mask:
+ self._selector.unregister(fd)
+ else:
+ self._selector.modify(fd, mask, (None, writer))
+
+ if reader is not None:
+ reader.cancel()
+ return True
+ else:
+ return False
+
+ def add_writer(self, fd, callback, *args):
+ """Add a writer callback.."""
+ handle = events.make_handle(callback, args)
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ self._selector.register(fd, selectors.EVENT_WRITE,
+ (None, handle))
+ else:
+ mask, (reader, writer) = key.events, key.data
+ self._selector.modify(fd, mask | selectors.EVENT_WRITE,
+ (reader, handle))
+ if writer is not None:
+ writer.cancel()
+
+ def remove_writer(self, fd):
+ """Remove a writer callback."""
+ try:
+ key = self._selector.get_key(fd)
+ except KeyError:
+ return False
+ else:
+ mask, (reader, writer) = key.events, key.data
+ # Remove both writer and connector.
+ mask &= ~selectors.EVENT_WRITE
+ if not mask:
+ self._selector.unregister(fd)
+ else:
+ self._selector.modify(fd, mask, (reader, None))
+
+ if writer is not None:
+ writer.cancel()
+ return True
+ else:
+ return False
+
+ def sock_recv(self, sock, n):
+ """XXX"""
+ fut = futures.Future(loop=self)
+ self._sock_recv(fut, False, sock, n)
+ return fut
+
+ def _sock_recv(self, fut, registered, sock, n):
+ fd = sock.fileno()
+ if registered:
+ # Remove the callback early. It should be rare that the
+ # selector says the fd is ready but the call still returns
+ # EAGAIN, and I am willing to take a hit in that case in
+ # order to simplify the common case.
+ self.remove_reader(fd)
+ if fut.cancelled():
+ return
+ try:
+ data = sock.recv(n)
+ except (BlockingIOError, InterruptedError):
+ self.add_reader(fd, self._sock_recv, fut, True, sock, n)
+ except Exception as exc:
+ fut.set_exception(exc)
+ else:
+ fut.set_result(data)
+
+ def sock_sendall(self, sock, data):
+ """XXX"""
+ fut = futures.Future(loop=self)
+ if data:
+ self._sock_sendall(fut, False, sock, data)
+ else:
+ fut.set_result(None)
+ return fut
+
+ def _sock_sendall(self, fut, registered, sock, data):
+ fd = sock.fileno()
+
+ if registered:
+ self.remove_writer(fd)
+ if fut.cancelled():
+ return
+
+ try:
+ n = sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except Exception as exc:
+ fut.set_exception(exc)
+ return
+
+ if n == len(data):
+ fut.set_result(None)
+ else:
+ if n:
+ data = data[n:]
+ self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
+
+ def sock_connect(self, sock, address):
+ """XXX"""
+ # That address better not require a lookup! We're not calling
+ # self.getaddrinfo() for you here. But verifying this is
+ # complicated; the socket module doesn't have a pattern for
+ # IPv6 addresses (there are too many forms, apparently).
+ fut = futures.Future(loop=self)
+ self._sock_connect(fut, False, sock, address)
+ return fut
+
+ def _sock_connect(self, fut, registered, sock, address):
+ # TODO: Use getaddrinfo() to look up the address, to avoid the
+ # trap of hanging the entire event loop when the address
+ # requires doing a DNS lookup. (OTOH, the caller should
+ # already have done this, so it would be nice if we could
+ # easily tell whether the address needs looking up or not. I
+ # know how to do this for IPv4, but IPv6 addresses have many
+ # syntaxes.)
+ fd = sock.fileno()
+ if registered:
+ self.remove_writer(fd)
+ if fut.cancelled():
+ return
+ try:
+ if not registered:
+ # First time around.
+ sock.connect(address)
+ else:
+ err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
+ if err != 0:
+ # Jump to the except clause below.
+ raise OSError(err, 'Connect call failed')
+ except (BlockingIOError, InterruptedError):
+ self.add_writer(fd, self._sock_connect, fut, True, sock, address)
+ except Exception as exc:
+ fut.set_exception(exc)
+ else:
+ fut.set_result(None)
+
+ def sock_accept(self, sock):
+ """XXX"""
+ fut = futures.Future(loop=self)
+ self._sock_accept(fut, False, sock)
+ return fut
+
+ def _sock_accept(self, fut, registered, sock):
+ fd = sock.fileno()
+ if registered:
+ self.remove_reader(fd)
+ if fut.cancelled():
+ return
+ try:
+ conn, address = sock.accept()
+ conn.setblocking(False)
+ except (BlockingIOError, InterruptedError):
+ self.add_reader(fd, self._sock_accept, fut, True, sock)
+ except Exception as exc:
+ fut.set_exception(exc)
+ else:
+ fut.set_result((conn, address))
+
+ def _process_events(self, event_list):
+ for key, mask in event_list:
+ fileobj, (reader, writer) = key.fileobj, key.data
+ if mask & selectors.EVENT_READ and reader is not None:
+ if reader._cancelled:
+ self.remove_reader(fileobj)
+ else:
+ self._add_callback(reader)
+ if mask & selectors.EVENT_WRITE and writer is not None:
+ if writer._cancelled:
+ self.remove_writer(fileobj)
+ else:
+ self._add_callback(writer)
+
+ def _stop_serving(self, sock):
+ self.remove_reader(sock.fileno())
+ sock.close()
+
+
+class _SelectorTransport(transports.Transport):
+
+ max_size = 256 * 1024 # Buffer size passed to recv().
+
+ def __init__(self, loop, sock, protocol, extra, server=None):
+ super().__init__(extra)
+ self._extra['socket'] = sock
+ self._extra['sockname'] = sock.getsockname()
+ if 'peername' not in self._extra:
+ try:
+ self._extra['peername'] = sock.getpeername()
+ except socket.error:
+ self._extra['peername'] = None
+ self._loop = loop
+ self._sock = sock
+ self._sock_fd = sock.fileno()
+ self._protocol = protocol
+ self._server = server
+ self._buffer = collections.deque()
+ self._conn_lost = 0
+ self._closing = False # Set when close() called.
+ if server is not None:
+ server.attach(self)
+
+ def abort(self):
+ self._force_close(None)
+
+ def close(self):
+ if self._closing:
+ return
+ self._closing = True
+ self._conn_lost += 1
+ self._loop.remove_reader(self._sock_fd)
+ if not self._buffer:
+ self._loop.call_soon(self._call_connection_lost, None)
+
+ def _fatal_error(self, exc):
+ # should be called from exception handler only
+ asyncio_log.exception('Fatal error for %s', self)
+ self._force_close(exc)
+
+ def _force_close(self, exc):
+ if self._buffer:
+ self._buffer.clear()
+ self._loop.remove_writer(self._sock_fd)
+
+ if self._closing:
+ return
+
+ self._closing = True
+ self._conn_lost += 1
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._sock.close()
+ self._sock = None
+ self._protocol = None
+ self._loop = None
+ server = self._server
+ if server is not None:
+ server.detach(self)
+ self._server = None
+
+
+class _SelectorSocketTransport(_SelectorTransport):
+
+ def __init__(self, loop, sock, protocol, waiter=None,
+ extra=None, server=None):
+ super().__init__(loop, sock, protocol, extra, server)
+ self._eof = False
+ self._paused = False
+
+ self._loop.add_reader(self._sock_fd, self._read_ready)
+ self._loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ self._loop.call_soon(waiter.set_result, None)
+
+ def pause(self):
+ assert not self._closing, 'Cannot pause() when closing'
+ assert not self._paused, 'Already paused'
+ self._paused = True
+ self._loop.remove_reader(self._sock_fd)
+
+ def resume(self):
+ assert self._paused, 'Not paused'
+ self._paused = False
+ if self._closing:
+ return
+ self._loop.add_reader(self._sock_fd, self._read_ready)
+
+ def _read_ready(self):
+ try:
+ data = self._sock.recv(self.max_size)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except ConnectionResetError as exc:
+ self._force_close(exc)
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ keep_open = self._protocol.eof_received()
+ if not keep_open:
+ self.close()
+
+ def write(self, data):
+ assert isinstance(data, bytes), repr(type(data))
+ assert not self._eof, 'Cannot call write() after write_eof()'
+ if not data:
+ return
+
+ if self._conn_lost:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ asyncio_log.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ n = self._sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except (BrokenPipeError, ConnectionResetError) as exc:
+ self._force_close(exc)
+ return
+ except OSError as exc:
+ self._fatal_error(exc)
+ return
+ else:
+ data = data[n:]
+ if not data:
+ return
+ # Start async I/O.
+ self._loop.add_writer(self._sock_fd, self._write_ready)
+
+ self._buffer.append(data)
+
+ def _write_ready(self):
+ data = b''.join(self._buffer)
+ assert data, 'Data should not be empty'
+
+ self._buffer.clear()
+ try:
+ n = self._sock.send(data)
+ except (BlockingIOError, InterruptedError):
+ self._buffer.append(data)
+ except (BrokenPipeError, ConnectionResetError) as exc:
+ self._loop.remove_writer(self._sock_fd)
+ self._force_close(exc)
+ except Exception as exc:
+ self._loop.remove_writer(self._sock_fd)
+ self._fatal_error(exc)
+ else:
+ data = data[n:]
+ if not data:
+ self._loop.remove_writer(self._sock_fd)
+ if self._closing:
+ self._call_connection_lost(None)
+ elif self._eof:
+ self._sock.shutdown(socket.SHUT_WR)
+ return
+
+ self._buffer.append(data) # Try again later.
+
+ def write_eof(self):
+ if self._eof:
+ return
+ self._eof = True
+ if not self._buffer:
+ self._sock.shutdown(socket.SHUT_WR)
+
+ def can_write_eof(self):
+ return True
+
+
+class _SelectorSslTransport(_SelectorTransport):
+
+ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None,
+ server_side=False, server_hostname=None,
+ extra=None, server=None):
+ if server_side:
+ assert isinstance(
+ sslcontext, ssl.SSLContext), 'Must pass an SSLContext'
+ else:
+ # Client-side may pass ssl=True to use a default context.
+ # The default is the same as used by urllib.
+ if sslcontext is None:
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext.options |= ssl.OP_NO_SSLv2
+ sslcontext.set_default_verify_paths()
+ sslcontext.verify_mode = ssl.CERT_REQUIRED
+ wrap_kwargs = {
+ 'server_side': server_side,
+ 'do_handshake_on_connect': False,
+ }
+ if server_hostname is not None and not server_side and ssl.HAS_SNI:
+ wrap_kwargs['server_hostname'] = server_hostname
+ sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs)
+
+ super().__init__(loop, sslsock, protocol, extra, server)
+
+ self._server_hostname = server_hostname
+ self._waiter = waiter
+ self._rawsock = rawsock
+ self._sslcontext = sslcontext
+ self._paused = False
+
+ # SSL-specific extra info. (peercert is set later)
+ self._extra.update(sslcontext=sslcontext)
+
+ self._on_handshake()
+
+ def _on_handshake(self):
+ try:
+ self._sock.do_handshake()
+ except ssl.SSLWantReadError:
+ self._loop.add_reader(self._sock_fd, self._on_handshake)
+ return
+ except ssl.SSLWantWriteError:
+ self._loop.add_writer(self._sock_fd, self._on_handshake)
+ return
+ except Exception as exc:
+ self._sock.close()
+ if self._waiter is not None:
+ self._waiter.set_exception(exc)
+ return
+ except BaseException as exc:
+ self._sock.close()
+ if self._waiter is not None:
+ self._waiter.set_exception(exc)
+ raise
+
+ # Verify hostname if requested.
+ peercert = self._sock.getpeercert()
+ if (self._server_hostname is not None and
+ self._sslcontext.verify_mode == ssl.CERT_REQUIRED):
+ try:
+ ssl.match_hostname(peercert, self._server_hostname)
+ except Exception as exc:
+ self._sock.close()
+ if self._waiter is not None:
+ self._waiter.set_exception(exc)
+ return
+
+ # Add extra info that becomes available after handshake.
+ self._extra.update(peercert=peercert,
+ cipher=self._sock.cipher(),
+ compression=self._sock.compression(),
+ )
+
+ self._loop.remove_reader(self._sock_fd)
+ self._loop.remove_writer(self._sock_fd)
+ self._loop.add_reader(self._sock_fd, self._on_ready)
+ self._loop.add_writer(self._sock_fd, self._on_ready)
+ self._loop.call_soon(self._protocol.connection_made, self)
+ if self._waiter is not None:
+ self._loop.call_soon(self._waiter.set_result, None)
+
+ def pause(self):
+ # XXX This is a bit icky, given the comment at the top of
+ # _on_ready(). Is it possible to evoke a deadlock? I don't
+ # know, although it doesn't look like it; write() will still
+ # accept more data for the buffer and eventually the app will
+ # call resume() again, and things will flow again.
+
+ assert not self._closing, 'Cannot pause() when closing'
+ assert not self._paused, 'Already paused'
+ self._paused = True
+ self._loop.remove_reader(self._sock_fd)
+
+ def resume(self):
+ assert self._paused, 'Not paused'
+ self._paused = False
+ if self._closing:
+ return
+ self._loop.add_reader(self._sock_fd, self._on_ready)
+
+ def _on_ready(self):
+ # Because of renegotiations (?), there's no difference between
+ # readable and writable. We just try both. XXX This may be
+ # incorrect; we probably need to keep state about what we
+ # should do next.
+
+ # First try reading.
+ if not self._closing and not self._paused:
+ try:
+ data = self._sock.recv(self.max_size)
+ except (BlockingIOError, InterruptedError,
+ ssl.SSLWantReadError, ssl.SSLWantWriteError):
+ pass
+ except ConnectionResetError as exc:
+ self._force_close(exc)
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ try:
+ self._protocol.eof_received()
+ finally:
+ self.close()
+
+ # Now try writing, if there's anything to write.
+ if self._buffer:
+ data = b''.join(self._buffer)
+ self._buffer.clear()
+ try:
+ n = self._sock.send(data)
+ except (BlockingIOError, InterruptedError,
+ ssl.SSLWantReadError, ssl.SSLWantWriteError):
+ n = 0
+ except (BrokenPipeError, ConnectionResetError) as exc:
+ self._loop.remove_writer(self._sock_fd)
+ self._force_close(exc)
+ return
+ except Exception as exc:
+ self._loop.remove_writer(self._sock_fd)
+ self._fatal_error(exc)
+ return
+
+ if n < len(data):
+ self._buffer.append(data[n:])
+
+ if self._closing and not self._buffer:
+ self._loop.remove_writer(self._sock_fd)
+ self._call_connection_lost(None)
+
+ def write(self, data):
+ assert isinstance(data, bytes), repr(type(data))
+ if not data:
+ return
+
+ if self._conn_lost:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ asyncio_log.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+
+ self._buffer.append(data)
+ # We could optimize, but the callback can do this for now.
+
+ def can_write_eof(self):
+ return False
+
+ def close(self):
+ if self._closing:
+ return
+ self._closing = True
+ self._conn_lost += 1
+ self._loop.remove_reader(self._sock_fd)
+
+
+class _SelectorDatagramTransport(_SelectorTransport):
+
+ def __init__(self, loop, sock, protocol, address=None, extra=None):
+ super().__init__(loop, sock, protocol, extra)
+
+ self._address = address
+ self._loop.add_reader(self._sock_fd, self._read_ready)
+ self._loop.call_soon(self._protocol.connection_made, self)
+
+ def _read_ready(self):
+ try:
+ data, addr = self._sock.recvfrom(self.max_size)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception as exc:
+ self._fatal_error(exc)
+ else:
+ self._protocol.datagram_received(data, addr)
+
+ def sendto(self, data, addr=None):
+ assert isinstance(data, bytes), repr(type(data))
+ if not data:
+ return
+
+ if self._address:
+ assert addr in (None, self._address)
+
+ if self._conn_lost and self._address:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ asyncio_log.warning('socket.send() raised exception.')
+ self._conn_lost += 1
+ return
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ if self._address:
+ self._sock.send(data)
+ else:
+ self._sock.sendto(data, addr)
+ return
+ except ConnectionRefusedError as exc:
+ if self._address:
+ self._fatal_error(exc)
+ return
+ except (BlockingIOError, InterruptedError):
+ self._loop.add_writer(self._sock_fd, self._sendto_ready)
+ except Exception as exc:
+ self._fatal_error(exc)
+ return
+
+ self._buffer.append((data, addr))
+
+ def _sendto_ready(self):
+ while self._buffer:
+ data, addr = self._buffer.popleft()
+ try:
+ if self._address:
+ self._sock.send(data)
+ else:
+ self._sock.sendto(data, addr)
+ except ConnectionRefusedError as exc:
+ if self._address:
+ self._fatal_error(exc)
+ return
+ except (BlockingIOError, InterruptedError):
+ self._buffer.appendleft((data, addr)) # Try again later.
+ break
+ except Exception as exc:
+ self._fatal_error(exc)
+ return
+
+ if not self._buffer:
+ self._loop.remove_writer(self._sock_fd)
+ if self._closing:
+ self._call_connection_lost(None)
+
+ def _force_close(self, exc):
+ if self._address and isinstance(exc, ConnectionRefusedError):
+ self._protocol.connection_refused(exc)
+
+ super()._force_close(exc)
diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py
new file mode 100644
index 0000000..d0f12e8
--- /dev/null
+++ b/Lib/asyncio/streams.py
@@ -0,0 +1,257 @@
+"""Stream-related things."""
+
+__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection']
+
+import collections
+
+from . import events
+from . import futures
+from . import protocols
+from . import tasks
+
+
+_DEFAULT_LIMIT = 2**16
+
+
+@tasks.coroutine
+def open_connection(host=None, port=None, *,
+ loop=None, limit=_DEFAULT_LIMIT, **kwds):
+ """A wrapper for create_connection() returning a (reader, writer) pair.
+
+ The reader returned is a StreamReader instance; the writer is a
+ Transport.
+
+ The arguments are all the usual arguments to create_connection()
+ except protocol_factory; most common are positional host and port,
+ with various optional keyword arguments following.
+
+ Additional optional keyword arguments are loop (to set the event loop
+ instance to use) and limit (to set the buffer limit passed to the
+ StreamReader).
+
+ (If you want to customize the StreamReader and/or
+ StreamReaderProtocol classes, just copy the code -- there's
+ really nothing special here except some convenience.)
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+ reader = StreamReader(limit=limit, loop=loop)
+ protocol = StreamReaderProtocol(reader)
+ transport, _ = yield from loop.create_connection(
+ lambda: protocol, host, port, **kwds)
+ return reader, transport # (reader, writer)
+
+
+class StreamReaderProtocol(protocols.Protocol):
+ """Trivial helper class to adapt between Protocol and StreamReader.
+
+ (This is a helper class instead of making StreamReader itself a
+ Protocol subclass, because the StreamReader has other potential
+ uses, and to prevent the user of the StreamReader to accidentally
+ call inappropriate methods of the protocol.)
+ """
+
+ def __init__(self, stream_reader):
+ self.stream_reader = stream_reader
+
+ def connection_made(self, transport):
+ self.stream_reader.set_transport(transport)
+
+ def connection_lost(self, exc):
+ if exc is None:
+ self.stream_reader.feed_eof()
+ else:
+ self.stream_reader.set_exception(exc)
+
+ def data_received(self, data):
+ self.stream_reader.feed_data(data)
+
+ def eof_received(self):
+ self.stream_reader.feed_eof()
+
+
+class StreamReader:
+
+ def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
+ # The line length limit is a security feature;
+ # it also doubles as half the buffer limit.
+ self.limit = limit
+ if loop is None:
+ loop = events.get_event_loop()
+ self.loop = loop
+ self.buffer = collections.deque() # Deque of bytes objects.
+ self.byte_count = 0 # Bytes in buffer.
+ self.eof = False # Whether we're done.
+ self.waiter = None # A future.
+ self._exception = None
+ self._transport = None
+ self._paused = False
+
+ def exception(self):
+ return self._exception
+
+ def set_exception(self, exc):
+ self._exception = exc
+
+ waiter = self.waiter
+ if waiter is not None:
+ self.waiter = None
+ if not waiter.cancelled():
+ waiter.set_exception(exc)
+
+ def set_transport(self, transport):
+ assert self._transport is None, 'Transport already set'
+ self._transport = transport
+
+ def _maybe_resume_transport(self):
+ if self._paused and self.byte_count <= self.limit:
+ self._paused = False
+ self._transport.resume()
+
+ def feed_eof(self):
+ self.eof = True
+ waiter = self.waiter
+ if waiter is not None:
+ self.waiter = None
+ if not waiter.cancelled():
+ waiter.set_result(True)
+
+ def feed_data(self, data):
+ if not data:
+ return
+
+ self.buffer.append(data)
+ self.byte_count += len(data)
+
+ waiter = self.waiter
+ if waiter is not None:
+ self.waiter = None
+ if not waiter.cancelled():
+ waiter.set_result(False)
+
+ if (self._transport is not None and
+ not self._paused and
+ self.byte_count > 2*self.limit):
+ try:
+ self._transport.pause()
+ except NotImplementedError:
+ # The transport can't be paused.
+ # We'll just have to buffer all data.
+ # Forget the transport so we don't keep trying.
+ self._transport = None
+ else:
+ self._paused = True
+
+ @tasks.coroutine
+ def readline(self):
+ if self._exception is not None:
+ raise self._exception
+
+ parts = []
+ parts_size = 0
+ not_enough = True
+
+ while not_enough:
+ while self.buffer and not_enough:
+ data = self.buffer.popleft()
+ ichar = data.find(b'\n')
+ if ichar < 0:
+ parts.append(data)
+ parts_size += len(data)
+ else:
+ ichar += 1
+ head, tail = data[:ichar], data[ichar:]
+ if tail:
+ self.buffer.appendleft(tail)
+ not_enough = False
+ parts.append(head)
+ parts_size += len(head)
+
+ if parts_size > self.limit:
+ self.byte_count -= parts_size
+ self._maybe_resume_transport()
+ raise ValueError('Line is too long')
+
+ if self.eof:
+ break
+
+ if not_enough:
+ assert self.waiter is None
+ self.waiter = futures.Future(loop=self.loop)
+ try:
+ yield from self.waiter
+ finally:
+ self.waiter = None
+
+ line = b''.join(parts)
+ self.byte_count -= parts_size
+ self._maybe_resume_transport()
+
+ return line
+
+ @tasks.coroutine
+ def read(self, n=-1):
+ if self._exception is not None:
+ raise self._exception
+
+ if not n:
+ return b''
+
+ if n < 0:
+ while not self.eof:
+ assert not self.waiter
+ self.waiter = futures.Future(loop=self.loop)
+ try:
+ yield from self.waiter
+ finally:
+ self.waiter = None
+ else:
+ if not self.byte_count and not self.eof:
+ assert not self.waiter
+ self.waiter = futures.Future(loop=self.loop)
+ try:
+ yield from self.waiter
+ finally:
+ self.waiter = None
+
+ if n < 0 or self.byte_count <= n:
+ data = b''.join(self.buffer)
+ self.buffer.clear()
+ self.byte_count = 0
+ self._maybe_resume_transport()
+ return data
+
+ parts = []
+ parts_bytes = 0
+ while self.buffer and parts_bytes < n:
+ data = self.buffer.popleft()
+ data_bytes = len(data)
+ if n < parts_bytes + data_bytes:
+ data_bytes = n - parts_bytes
+ data, rest = data[:data_bytes], data[data_bytes:]
+ self.buffer.appendleft(rest)
+
+ parts.append(data)
+ parts_bytes += data_bytes
+ self.byte_count -= data_bytes
+ self._maybe_resume_transport()
+
+ return b''.join(parts)
+
+ @tasks.coroutine
+ def readexactly(self, n):
+ if self._exception is not None:
+ raise self._exception
+
+ if n <= 0:
+ return b''
+
+ while self.byte_count < n and not self.eof:
+ assert not self.waiter
+ self.waiter = futures.Future(loop=self.loop)
+ try:
+ yield from self.waiter
+ finally:
+ self.waiter = None
+
+ return (yield from self.read(n))
diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py
new file mode 100644
index 0000000..2c8579f
--- /dev/null
+++ b/Lib/asyncio/tasks.py
@@ -0,0 +1,636 @@
+"""Support for tasks, coroutines and the scheduler."""
+
+__all__ = ['coroutine', 'Task',
+ 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED',
+ 'wait', 'wait_for', 'as_completed', 'sleep', 'async',
+ 'gather',
+ ]
+
+import collections
+import concurrent.futures
+import functools
+import inspect
+import linecache
+import traceback
+import weakref
+
+from . import events
+from . import futures
+from .log import asyncio_log
+
+# If you set _DEBUG to true, @coroutine will wrap the resulting
+# generator objects in a CoroWrapper instance (defined below). That
+# instance will log a message when the generator is never iterated
+# over, which may happen when you forget to use "yield from" with a
+# coroutine call. Note that the value of the _DEBUG flag is taken
+# when the decorator is used, so to be of any use it must be set
+# before you define your coroutines. A downside of using this feature
+# is that tracebacks show entries for the CoroWrapper.__next__ method
+# when _DEBUG is true.
+_DEBUG = False
+
+
+class CoroWrapper:
+ """Wrapper for coroutine in _DEBUG mode."""
+
+ __slot__ = ['gen', 'func']
+
+ def __init__(self, gen, func):
+ assert inspect.isgenerator(gen), gen
+ self.gen = gen
+ self.func = func
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ return next(self.gen)
+
+ def send(self, value):
+ return self.gen.send(value)
+
+ def throw(self, exc):
+ return self.gen.throw(exc)
+
+ def close(self):
+ return self.gen.close()
+
+ def __del__(self):
+ frame = self.gen.gi_frame
+ if frame is not None and frame.f_lasti == -1:
+ func = self.func
+ code = func.__code__
+ filename = code.co_filename
+ lineno = code.co_firstlineno
+ asyncio_log.error('Coroutine %r defined at %s:%s was never yielded from',
+ func.__name__, filename, lineno)
+
+
+def coroutine(func):
+ """Decorator to mark coroutines.
+
+ If the coroutine is not yielded from before it is destroyed,
+ an error message is logged.
+ """
+ if inspect.isgeneratorfunction(func):
+ coro = func
+ else:
+ @functools.wraps(func)
+ def coro(*args, **kw):
+ res = func(*args, **kw)
+ if isinstance(res, futures.Future) or inspect.isgenerator(res):
+ res = yield from res
+ return res
+
+ if not _DEBUG:
+ wrapper = coro
+ else:
+ @functools.wraps(func)
+ def wrapper(*args, **kwds):
+ w = CoroWrapper(coro(*args, **kwds), func)
+ w.__name__ = coro.__name__
+ w.__doc__ = coro.__doc__
+ return w
+
+ wrapper._is_coroutine = True # For iscoroutinefunction().
+ return wrapper
+
+
+def iscoroutinefunction(func):
+ """Return True if func is a decorated coroutine function."""
+ return getattr(func, '_is_coroutine', False)
+
+
+def iscoroutine(obj):
+ """Return True if obj is a coroutine object."""
+ return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj)
+
+
+class Task(futures.Future):
+ """A coroutine wrapped in a Future."""
+
+ # An important invariant maintained while a Task not done:
+ #
+ # - Either _fut_waiter is None, and _step() is scheduled;
+ # - or _fut_waiter is some Future, and _step() is *not* scheduled.
+ #
+ # The only transition from the latter to the former is through
+ # _wakeup(). When _fut_waiter is not None, one of its callbacks
+ # must be _wakeup().
+
+ # Weak set containing all tasks alive.
+ _all_tasks = weakref.WeakSet()
+
+ @classmethod
+ def all_tasks(cls, loop=None):
+ """Return a set of all tasks for an event loop.
+
+ By default all tasks for the current event loop are returned.
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+ return {t for t in cls._all_tasks if t._loop is loop}
+
+ def __init__(self, coro, *, loop=None):
+ assert iscoroutine(coro), repr(coro) # Not a coroutine function!
+ super().__init__(loop=loop)
+ self._coro = iter(coro) # Use the iterator just in case.
+ self._fut_waiter = None
+ self._must_cancel = False
+ self._loop.call_soon(self._step)
+ self.__class__._all_tasks.add(self)
+
+ def __repr__(self):
+ res = super().__repr__()
+ if (self._must_cancel and
+ self._state == futures._PENDING and
+ '<PENDING' in res):
+ res = res.replace('<PENDING', '<CANCELLING', 1)
+ i = res.find('<')
+ if i < 0:
+ i = len(res)
+ res = res[:i] + '(<{}>)'.format(self._coro.__name__) + res[i:]
+ return res
+
+ def get_stack(self, *, limit=None):
+ """Return the list of stack frames for this task's coroutine.
+
+ If the coroutine is active, this returns the stack where it is
+ suspended. If the coroutine has completed successfully or was
+ cancelled, this returns an empty list. If the coroutine was
+ terminated by an exception, this returns the list of traceback
+ frames.
+
+ The frames are always ordered from oldest to newest.
+
+ The optional limit gives the maximum nummber of frames to
+ return; by default all available frames are returned. Its
+ meaning differs depending on whether a stack or a traceback is
+ returned: the newest frames of a stack are returned, but the
+ oldest frames of a traceback are returned. (This matches the
+ behavior of the traceback module.)
+
+ For reasons beyond our control, only one stack frame is
+ returned for a suspended coroutine.
+ """
+ frames = []
+ f = self._coro.gi_frame
+ if f is not None:
+ while f is not None:
+ if limit is not None:
+ if limit <= 0:
+ break
+ limit -= 1
+ frames.append(f)
+ f = f.f_back
+ frames.reverse()
+ elif self._exception is not None:
+ tb = self._exception.__traceback__
+ while tb is not None:
+ if limit is not None:
+ if limit <= 0:
+ break
+ limit -= 1
+ frames.append(tb.tb_frame)
+ tb = tb.tb_next
+ return frames
+
+ def print_stack(self, *, limit=None, file=None):
+ """Print the stack or traceback for this task's coroutine.
+
+ This produces output similar to that of the traceback module,
+ for the frames retrieved by get_stack(). The limit argument
+ is passed to get_stack(). The file argument is an I/O stream
+ to which the output goes; by default it goes to sys.stderr.
+ """
+ extracted_list = []
+ checked = set()
+ for f in self.get_stack(limit=limit):
+ lineno = f.f_lineno
+ co = f.f_code
+ filename = co.co_filename
+ name = co.co_name
+ if filename not in checked:
+ checked.add(filename)
+ linecache.checkcache(filename)
+ line = linecache.getline(filename, lineno, f.f_globals)
+ extracted_list.append((filename, lineno, name, line))
+ exc = self._exception
+ if not extracted_list:
+ print('No stack for %r' % self, file=file)
+ elif exc is not None:
+ print('Traceback for %r (most recent call last):' % self,
+ file=file)
+ else:
+ print('Stack for %r (most recent call last):' % self,
+ file=file)
+ traceback.print_list(extracted_list, file=file)
+ if exc is not None:
+ for line in traceback.format_exception_only(exc.__class__, exc):
+ print(line, file=file, end='')
+
+ def cancel(self):
+ if self.done():
+ return False
+ if self._fut_waiter is not None:
+ if self._fut_waiter.cancel():
+ # Leave self._fut_waiter; it may be a Task that
+ # catches and ignores the cancellation so we may have
+ # to cancel it again later.
+ return True
+ # It must be the case that self._step is already scheduled.
+ self._must_cancel = True
+ return True
+
+ def _step(self, value=None, exc=None):
+ assert not self.done(), \
+ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc)
+ if self._must_cancel:
+ if not isinstance(exc, futures.CancelledError):
+ exc = futures.CancelledError()
+ self._must_cancel = False
+ coro = self._coro
+ self._fut_waiter = None
+ # Call either coro.throw(exc) or coro.send(value).
+ try:
+ if exc is not None:
+ result = coro.throw(exc)
+ elif value is not None:
+ result = coro.send(value)
+ else:
+ result = next(coro)
+ except StopIteration as exc:
+ self.set_result(exc.value)
+ except futures.CancelledError as exc:
+ super().cancel() # I.e., Future.cancel(self).
+ except Exception as exc:
+ self.set_exception(exc)
+ except BaseException as exc:
+ self.set_exception(exc)
+ raise
+ else:
+ if isinstance(result, futures.Future):
+ # Yielded Future must come from Future.__iter__().
+ if result._blocking:
+ result._blocking = False
+ result.add_done_callback(self._wakeup)
+ self._fut_waiter = result
+ if self._must_cancel:
+ if self._fut_waiter.cancel():
+ self._must_cancel = False
+ else:
+ self._loop.call_soon(
+ self._step, None,
+ RuntimeError(
+ 'yield was used instead of yield from '
+ 'in task {!r} with {!r}'.format(self, result)))
+ elif result is None:
+ # Bare yield relinquishes control for one event loop iteration.
+ self._loop.call_soon(self._step)
+ elif inspect.isgenerator(result):
+ # Yielding a generator is just wrong.
+ self._loop.call_soon(
+ self._step, None,
+ RuntimeError(
+ 'yield was used instead of yield from for '
+ 'generator in task {!r} with {}'.format(
+ self, result)))
+ else:
+ # Yielding something else is an error.
+ self._loop.call_soon(
+ self._step, None,
+ RuntimeError(
+ 'Task got bad yield: {!r}'.format(result)))
+ self = None
+
+ def _wakeup(self, future):
+ try:
+ value = future.result()
+ except Exception as exc:
+ # This may also be a cancellation.
+ self._step(None, exc)
+ else:
+ self._step(value, None)
+ self = None # Needed to break cycles when an exception occurs.
+
+
+# wait() and as_completed() similar to those in PEP 3148.
+
+FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED
+FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION
+ALL_COMPLETED = concurrent.futures.ALL_COMPLETED
+
+
+@coroutine
+def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED):
+ """Wait for the Futures and coroutines given by fs to complete.
+
+ Coroutines will be wrapped in Tasks.
+
+ Returns two sets of Future: (done, pending).
+
+ Usage:
+
+ done, pending = yield from asyncio.wait(fs)
+
+ Note: This does not raise TimeoutError! Futures that aren't done
+ when the timeout occurs are returned in the second set.
+ """
+ if not fs:
+ raise ValueError('Set of coroutines/Futures is empty.')
+
+ if loop is None:
+ loop = events.get_event_loop()
+
+ fs = set(async(f, loop=loop) for f in fs)
+
+ if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED):
+ raise ValueError('Invalid return_when value: {}'.format(return_when))
+ return (yield from _wait(fs, timeout, return_when, loop))
+
+
+def _release_waiter(waiter, value=True, *args):
+ if not waiter.done():
+ waiter.set_result(value)
+
+
+@coroutine
+def wait_for(fut, timeout, *, loop=None):
+ """Wait for the single Future or coroutine to complete, with timeout.
+
+ Coroutine will be wrapped in Task.
+
+ Returns result of the Future or coroutine. Raises TimeoutError when
+ timeout occurs.
+
+ Usage:
+
+ result = yield from asyncio.wait_for(fut, 10.0)
+
+ """
+ if loop is None:
+ loop = events.get_event_loop()
+
+ waiter = futures.Future(loop=loop)
+ timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False)
+ cb = functools.partial(_release_waiter, waiter, True)
+
+ fut = async(fut, loop=loop)
+ fut.add_done_callback(cb)
+
+ try:
+ if (yield from waiter):
+ return fut.result()
+ else:
+ fut.remove_done_callback(cb)
+ raise futures.TimeoutError()
+ finally:
+ timeout_handle.cancel()
+
+
+@coroutine
+def _wait(fs, timeout, return_when, loop):
+ """Internal helper for wait() and _wait_for().
+
+ The fs argument must be a collection of Futures.
+ """
+ assert fs, 'Set of Futures is empty.'
+ waiter = futures.Future(loop=loop)
+ timeout_handle = None
+ if timeout is not None:
+ timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
+ counter = len(fs)
+
+ def _on_completion(f):
+ nonlocal counter
+ counter -= 1
+ if (counter <= 0 or
+ return_when == FIRST_COMPLETED or
+ return_when == FIRST_EXCEPTION and (not f.cancelled() and
+ f.exception() is not None)):
+ if timeout_handle is not None:
+ timeout_handle.cancel()
+ if not waiter.done():
+ waiter.set_result(False)
+
+ for f in fs:
+ f.add_done_callback(_on_completion)
+
+ try:
+ yield from waiter
+ finally:
+ if timeout_handle is not None:
+ timeout_handle.cancel()
+
+ done, pending = set(), set()
+ for f in fs:
+ f.remove_done_callback(_on_completion)
+ if f.done():
+ done.add(f)
+ else:
+ pending.add(f)
+ return done, pending
+
+
+# This is *not* a @coroutine! It is just an iterator (yielding Futures).
+def as_completed(fs, *, loop=None, timeout=None):
+ """Return an iterator whose values, when waited for, are Futures.
+
+ This differs from PEP 3148; the proper way to use this is:
+
+ for f in as_completed(fs):
+ result = yield from f # The 'yield from' may raise.
+ # Use result.
+
+ Raises TimeoutError if the timeout occurs before all Futures are
+ done.
+
+ Note: The futures 'f' are not necessarily members of fs.
+ """
+ loop = loop if loop is not None else events.get_event_loop()
+ deadline = None if timeout is None else loop.time() + timeout
+ todo = set(async(f, loop=loop) for f in fs)
+ completed = collections.deque()
+
+ @coroutine
+ def _wait_for_one():
+ while not completed:
+ timeout = None
+ if deadline is not None:
+ timeout = deadline - loop.time()
+ if timeout < 0:
+ raise futures.TimeoutError()
+ done, pending = yield from _wait(
+ todo, timeout, FIRST_COMPLETED, loop)
+ # Multiple callers might be waiting for the same events
+ # and getting the same outcome. Dedupe by updating todo.
+ for f in done:
+ if f in todo:
+ todo.remove(f)
+ completed.append(f)
+ f = completed.popleft()
+ return f.result() # May raise.
+
+ for _ in range(len(todo)):
+ yield _wait_for_one()
+
+
+@coroutine
+def sleep(delay, result=None, *, loop=None):
+ """Coroutine that completes after a given time (in seconds)."""
+ future = futures.Future(loop=loop)
+ h = future._loop.call_later(delay, future.set_result, result)
+ try:
+ return (yield from future)
+ finally:
+ h.cancel()
+
+
+def async(coro_or_future, *, loop=None):
+ """Wrap a coroutine in a future.
+
+ If the argument is a Future, it is returned directly.
+ """
+ if isinstance(coro_or_future, futures.Future):
+ if loop is not None and loop is not coro_or_future._loop:
+ raise ValueError('loop argument must agree with Future')
+ return coro_or_future
+ elif iscoroutine(coro_or_future):
+ return Task(coro_or_future, loop=loop)
+ else:
+ raise TypeError('A Future or coroutine is required')
+
+
+class _GatheringFuture(futures.Future):
+ """Helper for gather().
+
+ This overrides cancel() to cancel all the children and act more
+ like Task.cancel(), which doesn't immediately mark itself as
+ cancelled.
+ """
+
+ def __init__(self, children, *, loop=None):
+ super().__init__(loop=loop)
+ self._children = children
+
+ def cancel(self):
+ if self.done():
+ return False
+ for child in self._children:
+ child.cancel()
+ return True
+
+
+def gather(*coros_or_futures, loop=None, return_exceptions=False):
+ """Return a future aggregating results from the given coroutines
+ or futures.
+
+ All futures must share the same event loop. If all the tasks are
+ done successfully, the returned future's result is the list of
+ results (in the order of the original sequence, not necessarily
+ the order of results arrival). If *result_exception* is True,
+ exceptions in the tasks are treated the same as successful
+ results, and gathered in the result list; otherwise, the first
+ raised exception will be immediately propagated to the returned
+ future.
+
+ Cancellation: if the outer Future is cancelled, all children (that
+ have not completed yet) are also cancelled. If any child is
+ cancelled, this is treated as if it raised CancelledError --
+ the outer Future is *not* cancelled in this case. (This is to
+ prevent the cancellation of one child to cause other children to
+ be cancelled.)
+ """
+ children = [async(fut, loop=loop) for fut in coros_or_futures]
+ n = len(children)
+ if n == 0:
+ outer = futures.Future(loop=loop)
+ outer.set_result([])
+ return outer
+ if loop is None:
+ loop = children[0]._loop
+ for fut in children:
+ if fut._loop is not loop:
+ raise ValueError("futures are tied to different event loops")
+ outer = _GatheringFuture(children, loop=loop)
+ nfinished = 0
+ results = [None] * n
+
+ def _done_callback(i, fut):
+ nonlocal nfinished
+ if outer._state != futures._PENDING:
+ if fut._exception is not None:
+ # Mark exception retrieved.
+ fut.exception()
+ return
+ if fut._state == futures._CANCELLED:
+ res = futures.CancelledError()
+ if not return_exceptions:
+ outer.set_exception(res)
+ return
+ elif fut._exception is not None:
+ res = fut.exception() # Mark exception retrieved.
+ if not return_exceptions:
+ outer.set_exception(res)
+ return
+ else:
+ res = fut._result
+ results[i] = res
+ nfinished += 1
+ if nfinished == n:
+ outer.set_result(results)
+
+ for i, fut in enumerate(children):
+ fut.add_done_callback(functools.partial(_done_callback, i))
+ return outer
+
+
+def shield(arg, *, loop=None):
+ """Wait for a future, shielding it from cancellation.
+
+ The statement
+
+ res = yield from shield(something())
+
+ is exactly equivalent to the statement
+
+ res = yield from something()
+
+ *except* that if the coroutine containing it is cancelled, the
+ task running in something() is not cancelled. From the POV of
+ something(), the cancellation did not happen. But its caller is
+ still cancelled, so the yield-from expression still raises
+ CancelledError. Note: If something() is cancelled by other means
+ this will still cancel shield().
+
+ If you want to completely ignore cancellation (not recommended)
+ you can combine shield() with a try/except clause, as follows:
+
+ try:
+ res = yield from shield(something())
+ except CancelledError:
+ res = None
+ """
+ inner = async(arg, loop=loop)
+ if inner.done():
+ # Shortcut.
+ return inner
+ loop = inner._loop
+ outer = futures.Future(loop=loop)
+
+ def _done_callback(inner):
+ if outer.cancelled():
+ # Mark inner's result as retrieved.
+ inner.cancelled() or inner.exception()
+ return
+ if inner.cancelled():
+ outer.cancel()
+ else:
+ exc = inner.exception()
+ if exc is not None:
+ outer.set_exception(exc)
+ else:
+ outer.set_result(inner.result())
+
+ inner.add_done_callback(_done_callback)
+ return outer
diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py
new file mode 100644
index 0000000..91bbedb
--- /dev/null
+++ b/Lib/asyncio/test_utils.py
@@ -0,0 +1,246 @@
+"""Utilities shared by tests."""
+
+import collections
+import contextlib
+import io
+import unittest.mock
+import os
+import sys
+import threading
+import unittest
+import unittest.mock
+from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer
+try:
+ import ssl
+except ImportError: # pragma: no cover
+ ssl = None
+
+from . import tasks
+from . import base_events
+from . import events
+from . import selectors
+
+
+if sys.platform == 'win32': # pragma: no cover
+ from .windows_utils import socketpair
+else:
+ from socket import socketpair # pragma: no cover
+
+
+def dummy_ssl_context():
+ if ssl is None:
+ return None
+ else:
+ return ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+
+
+def run_briefly(loop):
+ @tasks.coroutine
+ def once():
+ pass
+ gen = once()
+ t = tasks.Task(gen, loop=loop)
+ try:
+ loop.run_until_complete(t)
+ finally:
+ gen.close()
+
+
+def run_once(loop):
+ """loop.stop() schedules _raise_stop_error()
+ and run_forever() runs until _raise_stop_error() callback.
+ this wont work if test waits for some IO events, because
+ _raise_stop_error() runs before any of io events callbacks.
+ """
+ loop.stop()
+ loop.run_forever()
+
+
+@contextlib.contextmanager
+def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
+
+ class SilentWSGIRequestHandler(WSGIRequestHandler):
+ def get_stderr(self):
+ return io.StringIO()
+
+ def log_message(self, format, *args):
+ pass
+
+ class SilentWSGIServer(WSGIServer):
+ def handle_error(self, request, client_address):
+ pass
+
+ class SSLWSGIServer(SilentWSGIServer):
+ def finish_request(self, request, client_address):
+ # The relative location of our test directory (which
+ # contains the sample key and certificate files) differs
+ # between the stdlib and stand-alone Tulip/asyncio.
+ # Prefer our own if we can find it.
+ here = os.path.join(os.path.dirname(__file__), '..', 'tests')
+ if not os.path.isdir(here):
+ here = os.path.join(os.path.dirname(os.__file__),
+ 'test', 'test_asyncio')
+ keyfile = os.path.join(here, 'sample.key')
+ certfile = os.path.join(here, 'sample.crt')
+ ssock = ssl.wrap_socket(request,
+ keyfile=keyfile,
+ certfile=certfile,
+ server_side=True)
+ try:
+ self.RequestHandlerClass(ssock, client_address, self)
+ ssock.close()
+ except OSError:
+ # maybe socket has been closed by peer
+ pass
+
+ def app(environ, start_response):
+ status = '200 OK'
+ headers = [('Content-type', 'text/plain')]
+ start_response(status, headers)
+ return [b'Test message']
+
+ # Run the test WSGI server in a separate thread in order not to
+ # interfere with event handling in the main thread
+ server_class = SSLWSGIServer if use_ssl else SilentWSGIServer
+ httpd = make_server(host, port, app,
+ server_class, SilentWSGIRequestHandler)
+ httpd.address = httpd.server_address
+ server_thread = threading.Thread(target=httpd.serve_forever)
+ server_thread.start()
+ try:
+ yield httpd
+ finally:
+ httpd.shutdown()
+ server_thread.join()
+
+
+def make_test_protocol(base):
+ dct = {}
+ for name in dir(base):
+ if name.startswith('__') and name.endswith('__'):
+ # skip magic names
+ continue
+ dct[name] = unittest.mock.Mock(return_value=None)
+ return type('TestProtocol', (base,) + base.__bases__, dct)()
+
+
+class TestSelector(selectors.BaseSelector):
+
+ def select(self, timeout):
+ return []
+
+
+class TestLoop(base_events.BaseEventLoop):
+ """Loop for unittests.
+
+ It manages self time directly.
+ If something scheduled to be executed later then
+ on next loop iteration after all ready handlers done
+ generator passed to __init__ is calling.
+
+ Generator should be like this:
+
+ def gen():
+ ...
+ when = yield ...
+ ... = yield time_advance
+
+ Value retuned by yield is absolute time of next scheduled handler.
+ Value passed to yield is time advance to move loop's time forward.
+ """
+
+ def __init__(self, gen=None):
+ super().__init__()
+
+ if gen is None:
+ def gen():
+ yield
+ self._check_on_close = False
+ else:
+ self._check_on_close = True
+
+ self._gen = gen()
+ next(self._gen)
+ self._time = 0
+ self._timers = []
+ self._selector = TestSelector()
+
+ self.readers = {}
+ self.writers = {}
+ self.reset_counters()
+
+ def time(self):
+ return self._time
+
+ def advance_time(self, advance):
+ """Move test time forward."""
+ if advance:
+ self._time += advance
+
+ def close(self):
+ if self._check_on_close:
+ try:
+ self._gen.send(0)
+ except StopIteration:
+ pass
+ else: # pragma: no cover
+ raise AssertionError("Time generator is not finished")
+
+ def add_reader(self, fd, callback, *args):
+ self.readers[fd] = events.make_handle(callback, args)
+
+ def remove_reader(self, fd):
+ self.remove_reader_count[fd] += 1
+ if fd in self.readers:
+ del self.readers[fd]
+ return True
+ else:
+ return False
+
+ def assert_reader(self, fd, callback, *args):
+ assert fd in self.readers, 'fd {} is not registered'.format(fd)
+ handle = self.readers[fd]
+ assert handle._callback == callback, '{!r} != {!r}'.format(
+ handle._callback, callback)
+ assert handle._args == args, '{!r} != {!r}'.format(
+ handle._args, args)
+
+ def add_writer(self, fd, callback, *args):
+ self.writers[fd] = events.make_handle(callback, args)
+
+ def remove_writer(self, fd):
+ self.remove_writer_count[fd] += 1
+ if fd in self.writers:
+ del self.writers[fd]
+ return True
+ else:
+ return False
+
+ def assert_writer(self, fd, callback, *args):
+ assert fd in self.writers, 'fd {} is not registered'.format(fd)
+ handle = self.writers[fd]
+ assert handle._callback == callback, '{!r} != {!r}'.format(
+ handle._callback, callback)
+ assert handle._args == args, '{!r} != {!r}'.format(
+ handle._args, args)
+
+ def reset_counters(self):
+ self.remove_reader_count = collections.defaultdict(int)
+ self.remove_writer_count = collections.defaultdict(int)
+
+ def _run_once(self):
+ super()._run_once()
+ for when in self._timers:
+ advance = self._gen.send(when)
+ self.advance_time(advance)
+ self._timers = []
+
+ def call_at(self, when, callback, *args):
+ self._timers.append(when)
+ return super().call_at(when, callback, *args)
+
+ def _process_events(self, event_list):
+ return
+
+ def _write_to_self(self):
+ pass
diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py
new file mode 100644
index 0000000..bf3adee
--- /dev/null
+++ b/Lib/asyncio/transports.py
@@ -0,0 +1,186 @@
+"""Abstract Transport class."""
+
+__all__ = ['ReadTransport', 'WriteTransport', 'Transport']
+
+
+class BaseTransport:
+ """Base ABC for transports."""
+
+ def __init__(self, extra=None):
+ if extra is None:
+ extra = {}
+ self._extra = extra
+
+ def get_extra_info(self, name, default=None):
+ """Get optional transport information."""
+ return self._extra.get(name, default)
+
+ def close(self):
+ """Closes the transport.
+
+ Buffered data will be flushed asynchronously. No more data
+ will be received. After all buffered data is flushed, the
+ protocol's connection_lost() method will (eventually) called
+ with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class ReadTransport(BaseTransport):
+ """ABC for read-only transports."""
+
+ def pause(self):
+ """Pause the receiving end.
+
+ No data will be passed to the protocol's data_received()
+ method until resume() is called.
+ """
+ raise NotImplementedError
+
+ def resume(self):
+ """Resume the receiving end.
+
+ Data received will once again be passed to the protocol's
+ data_received() method.
+ """
+ raise NotImplementedError
+
+
+class WriteTransport(BaseTransport):
+ """ABC for write-only transports."""
+
+ def write(self, data):
+ """Write some data bytes to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+ """
+ raise NotImplementedError
+
+ def writelines(self, list_of_data):
+ """Write a list (or any iterable) of data bytes to the transport.
+
+ The default implementation just calls write() for each item in
+ the list/iterable.
+ """
+ for data in list_of_data:
+ self.write(data)
+
+ def write_eof(self):
+ """Closes the write end after flushing buffered data.
+
+ (This is like typing ^D into a UNIX program reading from stdin.)
+
+ Data may still be received.
+ """
+ raise NotImplementedError
+
+ def can_write_eof(self):
+ """Return True if this protocol supports write_eof(), False if not."""
+ raise NotImplementedError
+
+ def abort(self):
+ """Closes the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method will (eventually) be
+ called with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class Transport(ReadTransport, WriteTransport):
+ """ABC representing a bidirectional transport.
+
+ There may be several implementations, but typically, the user does
+ not implement new transports; rather, the platform provides some
+ useful transports that are implemented using the platform's best
+ practices.
+
+ The user never instantiates a transport directly; they call a
+ utility function, passing it a protocol factory and other
+ information necessary to create the transport and protocol. (E.g.
+ EventLoop.create_connection() or EventLoop.create_server().)
+
+ The utility function will asynchronously create a transport and a
+ protocol and hook them up by calling the protocol's
+ connection_made() method, passing it the transport.
+
+ The implementation here raises NotImplemented for every method
+ except writelines(), which calls write() in a loop.
+ """
+
+
+class DatagramTransport(BaseTransport):
+ """ABC for datagram (UDP) transports."""
+
+ def sendto(self, data, addr=None):
+ """Send data to the transport.
+
+ This does not block; it buffers the data and arranges for it
+ to be sent out asynchronously.
+ addr is target socket address.
+ If addr is None use target address pointed on transport creation.
+ """
+ raise NotImplementedError
+
+ def abort(self):
+ """Closes the transport immediately.
+
+ Buffered data will be lost. No more data will be received.
+ The protocol's connection_lost() method will (eventually) be
+ called with None as its argument.
+ """
+ raise NotImplementedError
+
+
+class SubprocessTransport(BaseTransport):
+
+ def get_pid(self):
+ """Get subprocess id."""
+ raise NotImplementedError
+
+ def get_returncode(self):
+ """Get subprocess returncode.
+
+ See also
+ http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode
+ """
+ raise NotImplementedError
+
+ def get_pipe_transport(self, fd):
+ """Get transport for pipe with number fd."""
+ raise NotImplementedError
+
+ def send_signal(self, signal):
+ """Send signal to subprocess.
+
+ See also:
+ docs.python.org/3/library/subprocess#subprocess.Popen.send_signal
+ """
+ raise NotImplementedError
+
+ def terminate(self):
+ """Stop the subprocess.
+
+ Alias for close() method.
+
+ On Posix OSs the method sends SIGTERM to the subprocess.
+ On Windows the Win32 API function TerminateProcess()
+ is called to stop the subprocess.
+
+ See also:
+ http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate
+ """
+ raise NotImplementedError
+
+ def kill(self):
+ """Kill the subprocess.
+
+ On Posix OSs the function sends SIGKILL to the subprocess.
+ On Windows kill() is an alias for terminate().
+
+ See also:
+ http://docs.python.org/3/library/subprocess#subprocess.Popen.kill
+ """
+ raise NotImplementedError
diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py
new file mode 100644
index 0000000..a3a8e11
--- /dev/null
+++ b/Lib/asyncio/unix_events.py
@@ -0,0 +1,541 @@
+"""Selector eventloop for Unix with signal handling."""
+
+import collections
+import errno
+import fcntl
+import functools
+import os
+import signal
+import socket
+import stat
+import subprocess
+import sys
+
+
+from . import constants
+from . import events
+from . import protocols
+from . import selector_events
+from . import tasks
+from . import transports
+from .log import asyncio_log
+
+
+__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR']
+
+STDIN = 0
+STDOUT = 1
+STDERR = 2
+
+
+if sys.platform == 'win32': # pragma: no cover
+ raise ImportError('Signals are not really supported on Windows')
+
+
+class SelectorEventLoop(selector_events.BaseSelectorEventLoop):
+ """Unix event loop
+
+ Adds signal handling to SelectorEventLoop
+ """
+
+ def __init__(self, selector=None):
+ super().__init__(selector)
+ self._signal_handlers = {}
+ self._subprocesses = {}
+
+ def _socketpair(self):
+ return socket.socketpair()
+
+ def close(self):
+ handler = self._signal_handlers.get(signal.SIGCHLD)
+ if handler is not None:
+ self.remove_signal_handler(signal.SIGCHLD)
+ super().close()
+
+ def add_signal_handler(self, sig, callback, *args):
+ """Add a handler for a signal. UNIX only.
+
+ Raise ValueError if the signal number is invalid or uncatchable.
+ Raise RuntimeError if there is a problem setting up the handler.
+ """
+ self._check_signal(sig)
+ try:
+ # set_wakeup_fd() raises ValueError if this is not the
+ # main thread. By calling it early we ensure that an
+ # event loop running in another thread cannot add a signal
+ # handler.
+ signal.set_wakeup_fd(self._csock.fileno())
+ except ValueError as exc:
+ raise RuntimeError(str(exc))
+
+ handle = events.make_handle(callback, args)
+ self._signal_handlers[sig] = handle
+
+ try:
+ signal.signal(sig, self._handle_signal)
+ except OSError as exc:
+ del self._signal_handlers[sig]
+ if not self._signal_handlers:
+ try:
+ signal.set_wakeup_fd(-1)
+ except ValueError as nexc:
+ asyncio_log.info('set_wakeup_fd(-1) failed: %s', nexc)
+
+ if exc.errno == errno.EINVAL:
+ raise RuntimeError('sig {} cannot be caught'.format(sig))
+ else:
+ raise
+
+ def _handle_signal(self, sig, arg):
+ """Internal helper that is the actual signal handler."""
+ handle = self._signal_handlers.get(sig)
+ if handle is None:
+ return # Assume it's some race condition.
+ if handle._cancelled:
+ self.remove_signal_handler(sig) # Remove it properly.
+ else:
+ self._add_callback_signalsafe(handle)
+
+ def remove_signal_handler(self, sig):
+ """Remove a handler for a signal. UNIX only.
+
+ Return True if a signal handler was removed, False if not.
+ """
+ self._check_signal(sig)
+ try:
+ del self._signal_handlers[sig]
+ except KeyError:
+ return False
+
+ if sig == signal.SIGINT:
+ handler = signal.default_int_handler
+ else:
+ handler = signal.SIG_DFL
+
+ try:
+ signal.signal(sig, handler)
+ except OSError as exc:
+ if exc.errno == errno.EINVAL:
+ raise RuntimeError('sig {} cannot be caught'.format(sig))
+ else:
+ raise
+
+ if not self._signal_handlers:
+ try:
+ signal.set_wakeup_fd(-1)
+ except ValueError as exc:
+ asyncio_log.info('set_wakeup_fd(-1) failed: %s', exc)
+
+ return True
+
+ def _check_signal(self, sig):
+ """Internal helper to validate a signal.
+
+ Raise ValueError if the signal number is invalid or uncatchable.
+ Raise RuntimeError if there is a problem setting up the handler.
+ """
+ if not isinstance(sig, int):
+ raise TypeError('sig must be an int, not {!r}'.format(sig))
+
+ if not (1 <= sig < signal.NSIG):
+ raise ValueError(
+ 'sig {} out of range(1, {})'.format(sig, signal.NSIG))
+
+ def _make_read_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra)
+
+ def _make_write_pipe_transport(self, pipe, protocol, waiter=None,
+ extra=None):
+ return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra)
+
+ @tasks.coroutine
+ def _make_subprocess_transport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ extra=None, **kwargs):
+ self._reg_sigchld()
+ transp = _UnixSubprocessTransport(self, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ extra=None, **kwargs)
+ self._subprocesses[transp.get_pid()] = transp
+ yield from transp._post_init()
+ return transp
+
+ def _reg_sigchld(self):
+ if signal.SIGCHLD not in self._signal_handlers:
+ self.add_signal_handler(signal.SIGCHLD, self._sig_chld)
+
+ def _sig_chld(self):
+ try:
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG)
+ except ChildProcessError:
+ return
+ if pid == 0:
+ self.call_soon(self._sig_chld)
+ return
+ elif os.WIFSIGNALED(status):
+ returncode = -os.WTERMSIG(status)
+ elif os.WIFEXITED(status):
+ returncode = os.WEXITSTATUS(status)
+ else:
+ self.call_soon(self._sig_chld)
+ return
+ transp = self._subprocesses.get(pid)
+ if transp is not None:
+ transp._process_exited(returncode)
+ except Exception:
+ asyncio_log.exception('Unknown exception in SIGCHLD handler')
+
+ def _subprocess_closed(self, transport):
+ pid = transport.get_pid()
+ self._subprocesses.pop(pid, None)
+
+
+def _set_nonblocking(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL)
+ flags = flags | os.O_NONBLOCK
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+
+
+class _UnixReadPipeTransport(transports.ReadTransport):
+
+ max_size = 256 * 1024 # max bytes we read in one eventloop iteration
+
+ def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
+ super().__init__(extra)
+ self._extra['pipe'] = pipe
+ self._loop = loop
+ self._pipe = pipe
+ self._fileno = pipe.fileno()
+ _set_nonblocking(self._fileno)
+ self._protocol = protocol
+ self._closing = False
+ self._loop.add_reader(self._fileno, self._read_ready)
+ self._loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ self._loop.call_soon(waiter.set_result, None)
+
+ def _read_ready(self):
+ try:
+ data = os.read(self._fileno, self.max_size)
+ except (BlockingIOError, InterruptedError):
+ pass
+ except OSError as exc:
+ self._fatal_error(exc)
+ else:
+ if data:
+ self._protocol.data_received(data)
+ else:
+ self._closing = True
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._protocol.eof_received)
+ self._loop.call_soon(self._call_connection_lost, None)
+
+ def pause(self):
+ self._loop.remove_reader(self._fileno)
+
+ def resume(self):
+ self._loop.add_reader(self._fileno, self._read_ready)
+
+ def close(self):
+ if not self._closing:
+ self._close(None)
+
+ def _fatal_error(self, exc):
+ # should be called by exception handler only
+ asyncio_log.exception('Fatal error for %s', self)
+ self._close(exc)
+
+ def _close(self, exc):
+ self._closing = True
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._pipe.close()
+ self._pipe = None
+ self._protocol = None
+ self._loop = None
+
+
+class _UnixWritePipeTransport(transports.WriteTransport):
+
+ def __init__(self, loop, pipe, protocol, waiter=None, extra=None):
+ super().__init__(extra)
+ self._extra['pipe'] = pipe
+ self._loop = loop
+ self._pipe = pipe
+ self._fileno = pipe.fileno()
+ if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode):
+ raise ValueError("Pipe transport is for pipes only.")
+ _set_nonblocking(self._fileno)
+ self._protocol = protocol
+ self._buffer = []
+ self._conn_lost = 0
+ self._closing = False # Set when close() or write_eof() called.
+ self._loop.add_reader(self._fileno, self._read_ready)
+
+ self._loop.call_soon(self._protocol.connection_made, self)
+ if waiter is not None:
+ self._loop.call_soon(waiter.set_result, None)
+
+ def _read_ready(self):
+ # pipe was closed by peer
+ self._close()
+
+ def write(self, data):
+ assert isinstance(data, bytes), repr(data)
+ if not data:
+ return
+
+ if self._conn_lost or self._closing:
+ if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+ asyncio_log.warning('pipe closed by peer or '
+ 'os.write(pipe, data) raised exception.')
+ self._conn_lost += 1
+ return
+
+ if not self._buffer:
+ # Attempt to send it right away first.
+ try:
+ n = os.write(self._fileno, data)
+ except (BlockingIOError, InterruptedError):
+ n = 0
+ except Exception as exc:
+ self._conn_lost += 1
+ self._fatal_error(exc)
+ return
+ if n == len(data):
+ return
+ elif n > 0:
+ data = data[n:]
+ self._loop.add_writer(self._fileno, self._write_ready)
+
+ self._buffer.append(data)
+
+ def _write_ready(self):
+ data = b''.join(self._buffer)
+ assert data, 'Data should not be empty'
+
+ self._buffer.clear()
+ try:
+ n = os.write(self._fileno, data)
+ except (BlockingIOError, InterruptedError):
+ self._buffer.append(data)
+ except Exception as exc:
+ self._conn_lost += 1
+ # Remove writer here, _fatal_error() doesn't it
+ # because _buffer is empty.
+ self._loop.remove_writer(self._fileno)
+ self._fatal_error(exc)
+ else:
+ if n == len(data):
+ self._loop.remove_writer(self._fileno)
+ if self._closing:
+ self._loop.remove_reader(self._fileno)
+ self._call_connection_lost(None)
+ return
+ elif n > 0:
+ data = data[n:]
+
+ self._buffer.append(data) # Try again later.
+
+ def can_write_eof(self):
+ return True
+
+ # TODO: Make the relationships between write_eof(), close(),
+ # abort(), _fatal_error() and _close() more straightforward.
+
+ def write_eof(self):
+ if self._closing:
+ return
+ assert self._pipe
+ self._closing = True
+ if not self._buffer:
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._call_connection_lost, None)
+
+ def close(self):
+ if not self._closing:
+ # write_eof is all what we needed to close the write pipe
+ self.write_eof()
+
+ def abort(self):
+ self._close(None)
+
+ def _fatal_error(self, exc):
+ # should be called by exception handler only
+ asyncio_log.exception('Fatal error for %s', self)
+ self._close(exc)
+
+ def _close(self, exc=None):
+ self._closing = True
+ if self._buffer:
+ self._loop.remove_writer(self._fileno)
+ self._buffer.clear()
+ self._loop.remove_reader(self._fileno)
+ self._loop.call_soon(self._call_connection_lost, exc)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._pipe.close()
+ self._pipe = None
+ self._protocol = None
+ self._loop = None
+
+
+class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol):
+ pipe = None
+
+ def __init__(self, proc, fd):
+ self.proc = proc
+ self.fd = fd
+ self.connected = False
+ self.disconnected = False
+ proc._pipes[fd] = self
+
+ def connection_made(self, transport):
+ self.connected = True
+ self.pipe = transport
+ self.proc._try_connected()
+
+ def connection_lost(self, exc):
+ self.disconnected = True
+ self.proc._pipe_connection_lost(self.fd, exc)
+
+
+class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto,
+ protocols.Protocol):
+
+ def data_received(self, data):
+ self.proc._pipe_data_received(self.fd, data)
+
+ def eof_received(self):
+ pass
+
+
+class _UnixSubprocessTransport(transports.SubprocessTransport):
+
+ def __init__(self, loop, protocol, args, shell,
+ stdin, stdout, stderr, bufsize,
+ extra=None, **kwargs):
+ super().__init__(extra)
+ self._protocol = protocol
+ self._loop = loop
+
+ self._pipes = {}
+ if stdin == subprocess.PIPE:
+ self._pipes[STDIN] = None
+ if stdout == subprocess.PIPE:
+ self._pipes[STDOUT] = None
+ if stderr == subprocess.PIPE:
+ self._pipes[STDERR] = None
+ self._pending_calls = collections.deque()
+ self._finished = False
+ self._returncode = None
+
+ self._proc = subprocess.Popen(
+ args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr,
+ universal_newlines=False, bufsize=bufsize, **kwargs)
+ self._extra['subprocess'] = self._proc
+
+ def close(self):
+ for proto in self._pipes.values():
+ proto.pipe.close()
+ if self._returncode is None:
+ self.terminate()
+
+ def get_pid(self):
+ return self._proc.pid
+
+ def get_returncode(self):
+ return self._returncode
+
+ def get_pipe_transport(self, fd):
+ if fd in self._pipes:
+ return self._pipes[fd].pipe
+ else:
+ return None
+
+ def send_signal(self, signal):
+ self._proc.send_signal(signal)
+
+ def terminate(self):
+ self._proc.terminate()
+
+ def kill(self):
+ self._proc.kill()
+
+ @tasks.coroutine
+ def _post_init(self):
+ proc = self._proc
+ loop = self._loop
+ if proc.stdin is not None:
+ transp, proto = yield from loop.connect_write_pipe(
+ functools.partial(
+ _UnixWriteSubprocessPipeProto, self, STDIN),
+ proc.stdin)
+ if proc.stdout is not None:
+ transp, proto = yield from loop.connect_read_pipe(
+ functools.partial(
+ _UnixReadSubprocessPipeProto, self, STDOUT),
+ proc.stdout)
+ if proc.stderr is not None:
+ transp, proto = yield from loop.connect_read_pipe(
+ functools.partial(
+ _UnixReadSubprocessPipeProto, self, STDERR),
+ proc.stderr)
+ if not self._pipes:
+ self._try_connected()
+
+ def _call(self, cb, *data):
+ if self._pending_calls is not None:
+ self._pending_calls.append((cb, data))
+ else:
+ self._loop.call_soon(cb, *data)
+
+ def _try_connected(self):
+ assert self._pending_calls is not None
+ if all(p is not None and p.connected for p in self._pipes.values()):
+ self._loop.call_soon(self._protocol.connection_made, self)
+ for callback, data in self._pending_calls:
+ self._loop.call_soon(callback, *data)
+ self._pending_calls = None
+
+ def _pipe_connection_lost(self, fd, exc):
+ self._call(self._protocol.pipe_connection_lost, fd, exc)
+ self._try_finish()
+
+ def _pipe_data_received(self, fd, data):
+ self._call(self._protocol.pipe_data_received, fd, data)
+
+ def _process_exited(self, returncode):
+ assert returncode is not None, returncode
+ assert self._returncode is None, self._returncode
+ self._returncode = returncode
+ self._loop._subprocess_closed(self)
+ self._call(self._protocol.process_exited)
+ self._try_finish()
+
+ def _try_finish(self):
+ assert not self._finished
+ if self._returncode is None:
+ return
+ if all(p is not None and p.disconnected
+ for p in self._pipes.values()):
+ self._finished = True
+ self._loop.call_soon(self._call_connection_lost, None)
+
+ def _call_connection_lost(self, exc):
+ try:
+ self._protocol.connection_lost(exc)
+ finally:
+ self._proc = None
+ self._protocol = None
+ self._loop = None
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
new file mode 100644
index 0000000..1d0ad26
--- /dev/null
+++ b/Lib/asyncio/windows_events.py
@@ -0,0 +1,375 @@
+"""Selector and proactor eventloops for Windows."""
+
+import errno
+import socket
+import weakref
+import struct
+import _winapi
+
+from . import futures
+from . import proactor_events
+from . import selector_events
+from . import tasks
+from . import windows_utils
+from .log import asyncio_log
+
+try:
+ import _overlapped
+except ImportError:
+ from . import _overlapped
+
+
+__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor']
+
+
+NULL = 0
+INFINITE = 0xffffffff
+ERROR_CONNECTION_REFUSED = 1225
+ERROR_CONNECTION_ABORTED = 1236
+
+
+class _OverlappedFuture(futures.Future):
+ """Subclass of Future which represents an overlapped operation.
+
+ Cancelling it will immediately cancel the overlapped operation.
+ """
+
+ def __init__(self, ov, *, loop=None):
+ super().__init__(loop=loop)
+ self.ov = ov
+
+ def cancel(self):
+ try:
+ self.ov.cancel()
+ except OSError:
+ pass
+ return super().cancel()
+
+
+class PipeServer(object):
+ """Class representing a pipe server.
+
+ This is much like a bound, listening socket.
+ """
+ def __init__(self, address):
+ self._address = address
+ self._free_instances = weakref.WeakSet()
+ self._pipe = self._server_pipe_handle(True)
+
+ def _get_unconnected_pipe(self):
+ # Create new instance and return previous one. This ensures
+ # that (until the server is closed) there is always at least
+ # one pipe handle for address. Therefore if a client attempt
+ # to connect it will not fail with FileNotFoundError.
+ tmp, self._pipe = self._pipe, self._server_pipe_handle(False)
+ return tmp
+
+ def _server_pipe_handle(self, first):
+ # Return a wrapper for a new pipe handle.
+ if self._address is None:
+ return None
+ flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED
+ if first:
+ flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
+ h = _winapi.CreateNamedPipe(
+ self._address, flags,
+ _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE |
+ _winapi.PIPE_WAIT,
+ _winapi.PIPE_UNLIMITED_INSTANCES,
+ windows_utils.BUFSIZE, windows_utils.BUFSIZE,
+ _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL)
+ pipe = windows_utils.PipeHandle(h)
+ self._free_instances.add(pipe)
+ return pipe
+
+ def close(self):
+ # Close all instances which have not been connected to by a client.
+ if self._address is not None:
+ for pipe in self._free_instances:
+ pipe.close()
+ self._pipe = None
+ self._address = None
+ self._free_instances.clear()
+
+ __del__ = close
+
+
+class SelectorEventLoop(selector_events.BaseSelectorEventLoop):
+ """Windows version of selector event loop."""
+
+ def _socketpair(self):
+ return windows_utils.socketpair()
+
+
+class ProactorEventLoop(proactor_events.BaseProactorEventLoop):
+ """Windows version of proactor event loop using IOCP."""
+
+ def __init__(self, proactor=None):
+ if proactor is None:
+ proactor = IocpProactor()
+ super().__init__(proactor)
+
+ def _socketpair(self):
+ return windows_utils.socketpair()
+
+ @tasks.coroutine
+ def create_pipe_connection(self, protocol_factory, address):
+ f = self._proactor.connect_pipe(address)
+ pipe = yield from f
+ protocol = protocol_factory()
+ trans = self._make_duplex_pipe_transport(pipe, protocol,
+ extra={'addr': address})
+ return trans, protocol
+
+ @tasks.coroutine
+ def start_serving_pipe(self, protocol_factory, address):
+ server = PipeServer(address)
+ def loop(f=None):
+ pipe = None
+ try:
+ if f:
+ pipe = f.result()
+ server._free_instances.discard(pipe)
+ protocol = protocol_factory()
+ self._make_duplex_pipe_transport(
+ pipe, protocol, extra={'addr': address})
+ pipe = server._get_unconnected_pipe()
+ if pipe is None:
+ return
+ f = self._proactor.accept_pipe(pipe)
+ except OSError:
+ if pipe and pipe.fileno() != -1:
+ asyncio_log.exception('Pipe accept failed')
+ pipe.close()
+ except futures.CancelledError:
+ if pipe:
+ pipe.close()
+ else:
+ f.add_done_callback(loop)
+ self.call_soon(loop)
+ return [server]
+
+ def _stop_serving(self, server):
+ server.close()
+
+
+class IocpProactor:
+ """Proactor implementation using IOCP."""
+
+ def __init__(self, concurrency=0xffffffff):
+ self._loop = None
+ self._results = []
+ self._iocp = _overlapped.CreateIoCompletionPort(
+ _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency)
+ self._cache = {}
+ self._registered = weakref.WeakSet()
+ self._stopped_serving = weakref.WeakSet()
+
+ def set_loop(self, loop):
+ self._loop = loop
+
+ def select(self, timeout=None):
+ if not self._results:
+ self._poll(timeout)
+ tmp = self._results
+ self._results = []
+ return tmp
+
+ def recv(self, conn, nbytes, flags=0):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+ if isinstance(conn, socket.socket):
+ ov.WSARecv(conn.fileno(), nbytes, flags)
+ else:
+ ov.ReadFile(conn.fileno(), nbytes)
+ def finish(trans, key, ov):
+ try:
+ return ov.getresult()
+ except OSError as exc:
+ if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+ raise ConnectionResetError(*exc.args)
+ else:
+ raise
+ return self._register(ov, conn, finish)
+
+ def send(self, conn, buf, flags=0):
+ self._register_with_iocp(conn)
+ ov = _overlapped.Overlapped(NULL)
+ if isinstance(conn, socket.socket):
+ ov.WSASend(conn.fileno(), buf, flags)
+ else:
+ ov.WriteFile(conn.fileno(), buf)
+ def finish(trans, key, ov):
+ try:
+ return ov.getresult()
+ except OSError as exc:
+ if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+ raise ConnectionResetError(*exc.args)
+ else:
+ raise
+ return self._register(ov, conn, finish)
+
+ def accept(self, listener):
+ self._register_with_iocp(listener)
+ conn = self._get_accept_socket(listener.family)
+ ov = _overlapped.Overlapped(NULL)
+ ov.AcceptEx(listener.fileno(), conn.fileno())
+ def finish_accept(trans, key, ov):
+ ov.getresult()
+ # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work.
+ buf = struct.pack('@P', listener.fileno())
+ conn.setsockopt(socket.SOL_SOCKET,
+ _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf)
+ conn.settimeout(listener.gettimeout())
+ return conn, conn.getpeername()
+ return self._register(ov, listener, finish_accept)
+
+ def connect(self, conn, address):
+ self._register_with_iocp(conn)
+ # The socket needs to be locally bound before we call ConnectEx().
+ try:
+ _overlapped.BindLocal(conn.fileno(), conn.family)
+ except OSError as e:
+ if e.winerror != errno.WSAEINVAL:
+ raise
+ # Probably already locally bound; check using getsockname().
+ if conn.getsockname()[1] == 0:
+ raise
+ ov = _overlapped.Overlapped(NULL)
+ ov.ConnectEx(conn.fileno(), address)
+ def finish_connect(trans, key, ov):
+ ov.getresult()
+ # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work.
+ conn.setsockopt(socket.SOL_SOCKET,
+ _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0)
+ return conn
+ return self._register(ov, conn, finish_connect)
+
+ def accept_pipe(self, pipe):
+ self._register_with_iocp(pipe)
+ ov = _overlapped.Overlapped(NULL)
+ ov.ConnectNamedPipe(pipe.fileno())
+ def finish(trans, key, ov):
+ ov.getresult()
+ return pipe
+ return self._register(ov, pipe, finish)
+
+ def connect_pipe(self, address):
+ ov = _overlapped.Overlapped(NULL)
+ ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address)
+ def finish(err, handle, ov):
+ # err, handle were arguments passed to PostQueuedCompletionStatus()
+ # in a function run in a thread pool.
+ if err == _overlapped.ERROR_SEM_TIMEOUT:
+ # Connection did not succeed within time limit.
+ msg = _overlapped.FormatMessage(err)
+ raise ConnectionRefusedError(0, msg, None, err)
+ elif err != 0:
+ msg = _overlapped.FormatMessage(err)
+ raise OSError(0, msg, None, err)
+ else:
+ return windows_utils.PipeHandle(handle)
+ return self._register(ov, None, finish, wait_for_post=True)
+
+ def _register_with_iocp(self, obj):
+ # To get notifications of finished ops on this objects sent to the
+ # completion port, were must register the handle.
+ if obj not in self._registered:
+ self._registered.add(obj)
+ _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0)
+ # XXX We could also use SetFileCompletionNotificationModes()
+ # to avoid sending notifications to completion port of ops
+ # that succeed immediately.
+
+ def _register(self, ov, obj, callback, wait_for_post=False):
+ # Return a future which will be set with the result of the
+ # operation when it completes. The future's value is actually
+ # the value returned by callback().
+ f = _OverlappedFuture(ov, loop=self._loop)
+ if ov.pending or wait_for_post:
+ # Register the overlapped operation for later. Note that
+ # we only store obj to prevent it from being garbage
+ # collected too early.
+ self._cache[ov.address] = (f, ov, obj, callback)
+ else:
+ # The operation has completed, so no need to postpone the
+ # work. We cannot take this short cut if we need the
+ # NumberOfBytes, CompletionKey values returned by
+ # PostQueuedCompletionStatus().
+ try:
+ value = callback(None, None, ov)
+ except OSError as e:
+ f.set_exception(e)
+ else:
+ f.set_result(value)
+ return f
+
+ def _get_accept_socket(self, family):
+ s = socket.socket(family)
+ s.settimeout(0)
+ return s
+
+ def _poll(self, timeout=None):
+ if timeout is None:
+ ms = INFINITE
+ elif timeout < 0:
+ raise ValueError("negative timeout")
+ else:
+ ms = int(timeout * 1000 + 0.5)
+ if ms >= INFINITE:
+ raise ValueError("timeout too big")
+ while True:
+ status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms)
+ if status is None:
+ return
+ err, transferred, key, address = status
+ try:
+ f, ov, obj, callback = self._cache.pop(address)
+ except KeyError:
+ # key is either zero, or it is used to return a pipe
+ # handle which should be closed to avoid a leak.
+ if key not in (0, _overlapped.INVALID_HANDLE_VALUE):
+ _winapi.CloseHandle(key)
+ ms = 0
+ continue
+ if obj in self._stopped_serving:
+ f.cancel()
+ elif not f.cancelled():
+ try:
+ value = callback(transferred, key, ov)
+ except OSError as e:
+ f.set_exception(e)
+ self._results.append(f)
+ else:
+ f.set_result(value)
+ self._results.append(f)
+ ms = 0
+
+ def _stop_serving(self, obj):
+ # obj is a socket or pipe handle. It will be closed in
+ # BaseProactorEventLoop._stop_serving() which will make any
+ # pending operations fail quickly.
+ self._stopped_serving.add(obj)
+
+ def close(self):
+ # Cancel remaining registered operations.
+ for address, (f, ov, obj, callback) in list(self._cache.items()):
+ if obj is None:
+ # The operation was started with connect_pipe() which
+ # queues a task to Windows' thread pool. This cannot
+ # be cancelled, so just forget it.
+ del self._cache[address]
+ else:
+ try:
+ ov.cancel()
+ except OSError:
+ pass
+
+ while self._cache:
+ if not self._poll(1):
+ asyncio_log.debug('taking long time to close proactor')
+
+ self._results = []
+ if self._iocp is not None:
+ _winapi.CloseHandle(self._iocp)
+ self._iocp = None
diff --git a/Lib/asyncio/windows_utils.py b/Lib/asyncio/windows_utils.py
new file mode 100644
index 0000000..04b43e9
--- /dev/null
+++ b/Lib/asyncio/windows_utils.py
@@ -0,0 +1,181 @@
+"""
+Various Windows specific bits and pieces
+"""
+
+import sys
+
+if sys.platform != 'win32': # pragma: no cover
+ raise ImportError('win32 only')
+
+import socket
+import itertools
+import msvcrt
+import os
+import subprocess
+import tempfile
+import _winapi
+
+
+__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle']
+
+#
+# Constants/globals
+#
+
+BUFSIZE = 8192
+PIPE = subprocess.PIPE
+_mmap_counter = itertools.count()
+
+#
+# Replacement for socket.socketpair()
+#
+
+def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0):
+ """A socket pair usable as a self-pipe, for Windows.
+
+ Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain.
+ """
+ # We create a connected TCP socket. Note the trick with setblocking(0)
+ # that prevents us from having to create a thread.
+ lsock = socket.socket(family, type, proto)
+ lsock.bind(('localhost', 0))
+ lsock.listen(1)
+ addr, port = lsock.getsockname()
+ csock = socket.socket(family, type, proto)
+ csock.setblocking(False)
+ try:
+ csock.connect((addr, port))
+ except (BlockingIOError, InterruptedError):
+ pass
+ except Exception:
+ lsock.close()
+ csock.close()
+ raise
+ ssock, _ = lsock.accept()
+ csock.setblocking(True)
+ lsock.close()
+ return (ssock, csock)
+
+#
+# Replacement for os.pipe() using handles instead of fds
+#
+
+def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
+ """Like os.pipe() but with overlapped support and using handles not fds."""
+ address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' %
+ (os.getpid(), next(_mmap_counter)))
+
+ if duplex:
+ openmode = _winapi.PIPE_ACCESS_DUPLEX
+ access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
+ obsize, ibsize = bufsize, bufsize
+ else:
+ openmode = _winapi.PIPE_ACCESS_INBOUND
+ access = _winapi.GENERIC_WRITE
+ obsize, ibsize = 0, bufsize
+
+ openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
+
+ if overlapped[0]:
+ openmode |= _winapi.FILE_FLAG_OVERLAPPED
+
+ if overlapped[1]:
+ flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED
+ else:
+ flags_and_attribs = 0
+
+ h1 = h2 = None
+ try:
+ h1 = _winapi.CreateNamedPipe(
+ address, openmode, _winapi.PIPE_WAIT,
+ 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL)
+
+ h2 = _winapi.CreateFile(
+ address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
+ flags_and_attribs, _winapi.NULL)
+
+ ov = _winapi.ConnectNamedPipe(h1, overlapped=True)
+ ov.GetOverlappedResult(True)
+ return h1, h2
+ except:
+ if h1 is not None:
+ _winapi.CloseHandle(h1)
+ if h2 is not None:
+ _winapi.CloseHandle(h2)
+ raise
+
+#
+# Wrapper for a pipe handle
+#
+
+class PipeHandle:
+ """Wrapper for an overlapped pipe handle which is vaguely file-object like.
+
+ The IOCP event loop can use these instead of socket objects.
+ """
+ def __init__(self, handle):
+ self._handle = handle
+
+ @property
+ def handle(self):
+ return self._handle
+
+ def fileno(self):
+ return self._handle
+
+ def close(self, *, CloseHandle=_winapi.CloseHandle):
+ if self._handle != -1:
+ CloseHandle(self._handle)
+ self._handle = -1
+
+ __del__ = close
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, t, v, tb):
+ self.close()
+
+#
+# Replacement for subprocess.Popen using overlapped pipe handles
+#
+
+class Popen(subprocess.Popen):
+ """Replacement for subprocess.Popen using overlapped pipe handles.
+
+ The stdin, stdout, stderr are None or instances of PipeHandle.
+ """
+ def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds):
+ stdin_rfd = stdout_wfd = stderr_wfd = None
+ stdin_wh = stdout_rh = stderr_rh = None
+ if stdin == PIPE:
+ stdin_rh, stdin_wh = pipe(overlapped=(False, True))
+ stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY)
+ if stdout == PIPE:
+ stdout_rh, stdout_wh = pipe(overlapped=(True, False))
+ stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0)
+ if stderr == PIPE:
+ stderr_rh, stderr_wh = pipe(overlapped=(True, False))
+ stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0)
+ try:
+ super().__init__(args, bufsize=0, universal_newlines=False,
+ stdin=stdin_rfd, stdout=stdout_wfd,
+ stderr=stderr_wfd, **kwds)
+ except:
+ for h in (stdin_wh, stdout_rh, stderr_rh):
+ _winapi.CloseHandle(h)
+ raise
+ else:
+ if stdin_wh is not None:
+ self.stdin = PipeHandle(stdin_wh)
+ if stdout_rh is not None:
+ self.stdout = PipeHandle(stdout_rh)
+ if stderr_rh is not None:
+ self.stderr = PipeHandle(stderr_rh)
+ finally:
+ if stdin == PIPE:
+ os.close(stdin_rfd)
+ if stdout == PIPE:
+ os.close(stdout_wfd)
+ if stderr == PIPE:
+ os.close(stderr_wfd)
diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py
new file mode 100644
index 0000000..ec483ea
--- /dev/null
+++ b/Lib/test/test_asyncio/__init__.py
@@ -0,0 +1,26 @@
+import os
+import sys
+import unittest
+from test.support import run_unittest
+
+
+def suite():
+ tests_file = os.path.join(os.path.dirname(__file__), 'tests.txt')
+ with open(tests_file) as fp:
+ test_names = fp.read().splitlines()
+ tests = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ for test_name in test_names:
+ mod_name = 'test.' + test_name
+ try:
+ __import__(mod_name)
+ except unittest.SkipTest:
+ pass
+ else:
+ mod = sys.modules[mod_name]
+ tests.addTests(loader.loadTestsFromModule(mod))
+ return tests
+
+
+def test_main():
+ run_unittest(suite())
diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py
new file mode 100644
index 0000000..b549492
--- /dev/null
+++ b/Lib/test/test_asyncio/__main__.py
@@ -0,0 +1,5 @@
+from . import test_main
+
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py
new file mode 100644
index 0000000..f6ac0a3
--- /dev/null
+++ b/Lib/test/test_asyncio/echo.py
@@ -0,0 +1,6 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ os.write(1, buf)
diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py
new file mode 100644
index 0000000..e83ca09
--- /dev/null
+++ b/Lib/test/test_asyncio/echo2.py
@@ -0,0 +1,6 @@
+import os
+
+if __name__ == '__main__':
+ buf = os.read(0, 1024)
+ os.write(1, b'OUT:'+buf)
+ os.write(2, b'ERR:'+buf)
diff --git a/Lib/test/test_asyncio/echo3.py b/Lib/test/test_asyncio/echo3.py
new file mode 100644
index 0000000..f1f7ea7
--- /dev/null
+++ b/Lib/test/test_asyncio/echo3.py
@@ -0,0 +1,9 @@
+import os
+
+if __name__ == '__main__':
+ while True:
+ buf = os.read(0, 1024)
+ try:
+ os.write(1, b'OUT:'+buf)
+ except OSError as ex:
+ os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii'))
diff --git a/Lib/test/test_asyncio/sample.crt b/Lib/test/test_asyncio/sample.crt
new file mode 100644
index 0000000..6a1e3f3
--- /dev/null
+++ b/Lib/test/test_asyncio/sample.crt
@@ -0,0 +1,14 @@
+-----BEGIN CERTIFICATE-----
+MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV
+UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi
+MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3
+MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp
+Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g
+U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn
+t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7
+gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg
+Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB
+BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB
+MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I
+I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB
+-----END CERTIFICATE-----
diff --git a/Lib/test/test_asyncio/sample.key b/Lib/test/test_asyncio/sample.key
new file mode 100644
index 0000000..edfea8d
--- /dev/null
+++ b/Lib/test/test_asyncio/sample.key
@@ -0,0 +1,15 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx
+/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi
+qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB
+AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y
+bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1
+iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb
+DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP
+lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL
+21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF
+ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0
+zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u
+GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq
+V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof
+-----END RSA PRIVATE KEY-----
diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py
new file mode 100644
index 0000000..d48d12c
--- /dev/null
+++ b/Lib/test/test_asyncio/test_base_events.py
@@ -0,0 +1,590 @@
+"""Tests for base_events.py"""
+
+import logging
+import socket
+import time
+import unittest
+import unittest.mock
+
+from asyncio import base_events
+from asyncio import events
+from asyncio import futures
+from asyncio import protocols
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class BaseEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = base_events.BaseEventLoop()
+ self.loop._selector = unittest.mock.Mock()
+ events.set_event_loop(None)
+
+ def test_not_implemented(self):
+ m = unittest.mock.Mock()
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_socket_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_ssl_transport, m, m, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_datagram_transport, m, m)
+ self.assertRaises(
+ NotImplementedError, self.loop._process_events, [])
+ self.assertRaises(
+ NotImplementedError, self.loop._write_to_self)
+ self.assertRaises(
+ NotImplementedError, self.loop._read_from_self)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_read_pipe_transport, m, m)
+ self.assertRaises(
+ NotImplementedError,
+ self.loop._make_write_pipe_transport, m, m)
+ gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m)
+ self.assertRaises(NotImplementedError, next, iter(gen))
+
+ def test__add_callback_handle(self):
+ h = events.Handle(lambda: False, ())
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertIn(h, self.loop._ready)
+
+ def test__add_callback_timer(self):
+ h = events.TimerHandle(time.monotonic()+10, lambda: False, ())
+
+ self.loop._add_callback(h)
+ self.assertIn(h, self.loop._scheduled)
+
+ def test__add_callback_cancelled_handle(self):
+ h = events.Handle(lambda: False, ())
+ h.cancel()
+
+ self.loop._add_callback(h)
+ self.assertFalse(self.loop._scheduled)
+ self.assertFalse(self.loop._ready)
+
+ def test_set_default_executor(self):
+ executor = unittest.mock.Mock()
+ self.loop.set_default_executor(executor)
+ self.assertIs(executor, self.loop._default_executor)
+
+ def test_getnameinfo(self):
+ sockaddr = unittest.mock.Mock()
+ self.loop.run_in_executor = unittest.mock.Mock()
+ self.loop.getnameinfo(sockaddr)
+ self.assertEqual(
+ (None, socket.getnameinfo, sockaddr, 0),
+ self.loop.run_in_executor.call_args[0])
+
+ def test_call_soon(self):
+ def cb():
+ pass
+
+ h = self.loop.call_soon(cb)
+ self.assertEqual(h._callback, cb)
+ self.assertIsInstance(h, events.Handle)
+ self.assertIn(h, self.loop._ready)
+
+ def test_call_later(self):
+ def cb():
+ pass
+
+ h = self.loop.call_later(10.0, cb)
+ self.assertIsInstance(h, events.TimerHandle)
+ self.assertIn(h, self.loop._scheduled)
+ self.assertNotIn(h, self.loop._ready)
+
+ def test_call_later_negative_delays(self):
+ calls = []
+
+ def cb(arg):
+ calls.append(arg)
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop.call_later(-1, cb, 'a')
+ self.loop.call_later(-2, cb, 'b')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(calls, ['b', 'a'])
+
+ def test_time_and_call_at(self):
+ def cb():
+ self.loop.stop()
+
+ self.loop._process_events = unittest.mock.Mock()
+ when = self.loop.time() + 0.1
+ self.loop.call_at(when, cb)
+ t0 = self.loop.time()
+ self.loop.run_forever()
+ t1 = self.loop.time()
+ self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0)
+
+ def test_run_once_in_executor_handle(self):
+ def cb():
+ pass
+
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, events.Handle(cb, ()), ('',))
+ self.assertRaises(
+ AssertionError, self.loop.run_in_executor,
+ None, events.TimerHandle(10, cb, ()))
+
+ def test_run_once_in_executor_cancelled(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ h.cancel()
+
+ f = self.loop.run_in_executor(None, h)
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test_run_once_in_executor_plain(self):
+ def cb():
+ pass
+ h = events.Handle(cb, ())
+ f = futures.Future(loop=self.loop)
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+
+ self.loop.set_default_executor(executor)
+
+ res = self.loop.run_in_executor(None, h)
+ self.assertIs(f, res)
+
+ executor = unittest.mock.Mock()
+ executor.submit.return_value = f
+ res = self.loop.run_in_executor(executor, h)
+ self.assertIs(f, res)
+ self.assertTrue(executor.submit.called)
+
+ f.cancel() # Don't complain about abandoned Future.
+
+ def test__run_once(self):
+ h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ())
+ h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ())
+
+ h1.cancel()
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._scheduled.append(h1)
+ self.loop._scheduled.append(h2)
+ self.loop._run_once()
+
+ t = self.loop._selector.select.call_args[0][0]
+ self.assertTrue(9.99 < t < 10.1, t)
+ self.assertEqual([h2], self.loop._scheduled)
+ self.assertTrue(self.loop._process_events.called)
+
+ @unittest.mock.patch('asyncio.base_events.time')
+ @unittest.mock.patch('asyncio.base_events.asyncio_log')
+ def test__run_once_logging(self, m_logging, m_time):
+ # Log to INFO level if timeout > 1.0 sec.
+ idx = -1
+ data = [10.0, 10.0, 12.0, 13.0]
+
+ def monotonic():
+ nonlocal data, idx
+ idx += 1
+ return data[idx]
+
+ m_time.monotonic = monotonic
+ m_logging.INFO = logging.INFO
+ m_logging.DEBUG = logging.DEBUG
+
+ self.loop._scheduled.append(
+ events.TimerHandle(11.0, lambda: True, ()))
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._run_once()
+ self.assertEqual(logging.INFO, m_logging.log.call_args[0][0])
+
+ idx = -1
+ data = [10.0, 10.0, 10.3, 13.0]
+ self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())]
+ self.loop._run_once()
+ self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0])
+
+ def test__run_once_schedule_handle(self):
+ handle = None
+ processed = False
+
+ def cb(loop):
+ nonlocal processed, handle
+ processed = True
+ handle = loop.call_soon(lambda: True)
+
+ h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,))
+
+ self.loop._process_events = unittest.mock.Mock()
+ self.loop._scheduled.append(h)
+ self.loop._run_once()
+
+ self.assertTrue(processed)
+ self.assertEqual([handle], list(self.loop._ready))
+
+ def test_run_until_complete_type_error(self):
+ self.assertRaises(
+ TypeError, self.loop.run_until_complete, 'blah')
+
+
+class MyProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, create_future=False):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = futures.Future()
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyDatagramProto(protocols.DatagramProtocol):
+ done = None
+
+ def __init__(self, create_future=False):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if create_future:
+ self.done = futures.Future()
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def connection_refused(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class BaseEventLoopWithSelectorTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = events.new_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_connection_multiple_errors(self, m_socket):
+
+ class MyProto(protocols.Protocol):
+ pass
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ idx = -1
+ errors = ['err1', 'err2']
+
+ def _socket(*args, **kw):
+ nonlocal idx, errors
+ idx += 1
+ raise OSError(errors[idx])
+
+ m_socket.socket = _socket
+
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2')
+
+ def test_create_connection_host_port_sock(self):
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_host_port_sock(self):
+ coro = self.loop.create_connection(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_no_getaddrinfo(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_connect_err(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ yield from []
+ return [(2, 1, 6, '', ('107.6.106.82', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(MyProto, 'example.com', 80)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_connection_multiple(self):
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET)
+ with self.assertRaises(OSError):
+ self.loop.run_until_complete(coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_connection_multiple_errors_local_addr(self, m_socket):
+
+ def bind(addr):
+ if addr[0] == '0.0.0.1':
+ err = OSError('Err')
+ err.strerror = 'Err'
+ raise err
+
+ m_socket.socket.return_value.bind = bind
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ return [(2, 1, 6, '', ('0.0.0.1', 80)),
+ (2, 1, 6, '', ('0.0.0.2', 80))]
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError('Err2')
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(coro)
+
+ self.assertTrue(str(cm.exception).startswith('Multiple exceptions: '))
+ self.assertTrue(m_socket.socket.return_value.close.called)
+
+ def test_create_connection_no_local_addr(self):
+ @tasks.coroutine
+ def getaddrinfo(host, *args, **kw):
+ if host == 'example.com':
+ return [(2, 1, 6, '', ('107.6.106.82', 80)),
+ (2, 1, 6, '', ('107.6.106.82', 80))]
+ else:
+ return []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+ self.loop.getaddrinfo = getaddrinfo_task
+
+ coro = self.loop.create_connection(
+ MyProto, 'example.com', 80, family=socket.AF_INET,
+ local_addr=(None, 8080))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_server_empty_host(self):
+ # if host is empty string use None instead
+ host = object()
+
+ @tasks.coroutine
+ def getaddrinfo(*args, **kw):
+ nonlocal host
+ host = args[0]
+ yield from []
+
+ def getaddrinfo_task(*args, **kwds):
+ return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop)
+
+ self.loop.getaddrinfo = getaddrinfo_task
+ fut = self.loop.create_server(MyProto, '', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertIsNone(host)
+
+ def test_create_server_host_port_sock(self):
+ fut = self.loop.create_server(
+ MyProto, '0.0.0.0', 0, sock=object())
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_host_port_sock(self):
+ fut = self.loop.create_server(MyProto)
+ self.assertRaises(ValueError, self.loop.run_until_complete, fut)
+
+ def test_create_server_no_getaddrinfo(self):
+ getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock()
+ getaddrinfo.return_value = []
+
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, f)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_server_cant_bind(self, m_socket):
+
+ class Err(OSError):
+ strerror = 'error'
+
+ m_socket.getaddrinfo.return_value = [
+ (2, 1, 6, '', ('127.0.0.1', 10100))]
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ self.assertRaises(OSError, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_no_addrinfo(self, m_socket):
+ m_socket.getaddrinfo.return_value = []
+
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_addr_error(self):
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr='localhost')
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+ coro = self.loop.create_datagram_endpoint(
+ MyDatagramProto, local_addr=('localhost', 1, 2, 3))
+ self.assertRaises(
+ AssertionError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_connect_err(self):
+ self.loop.sock_connect = unittest.mock.Mock()
+ self.loop.sock_connect.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_socket_err(self, m_socket):
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_socket.socket.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, local_addr=('127.0.0.1', 0))
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+
+ def test_create_datagram_endpoint_no_matching_family(self):
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol,
+ remote_addr=('127.0.0.1', 0), local_addr=('::1', 0))
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_setblk_err(self, m_socket):
+ m_socket.socket.return_value.setblocking.side_effect = OSError
+
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol, family=socket.AF_INET)
+ self.assertRaises(
+ OSError, self.loop.run_until_complete, coro)
+ self.assertTrue(
+ m_socket.socket.return_value.close.called)
+
+ def test_create_datagram_endpoint_noaddr_nofamily(self):
+ coro = self.loop.create_datagram_endpoint(
+ protocols.DatagramProtocol)
+ self.assertRaises(ValueError, self.loop.run_until_complete, coro)
+
+ @unittest.mock.patch('asyncio.base_events.socket')
+ def test_create_datagram_endpoint_cant_bind(self, m_socket):
+ class Err(OSError):
+ pass
+
+ m_socket.AF_INET6 = socket.AF_INET6
+ m_socket.getaddrinfo = socket.getaddrinfo
+ m_sock = m_socket.socket.return_value = unittest.mock.Mock()
+ m_sock.bind.side_effect = Err
+
+ fut = self.loop.create_datagram_endpoint(
+ MyDatagramProto,
+ local_addr=('127.0.0.1', 0), family=socket.AF_INET)
+ self.assertRaises(Err, self.loop.run_until_complete, fut)
+ self.assertTrue(m_sock.close.called)
+
+ def test_accept_connection_retry(self):
+ sock = unittest.mock.Mock()
+ sock.accept.side_effect = BlockingIOError()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertFalse(sock.close.called)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_accept_connection_exception(self, m_log):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = OSError()
+
+ self.loop._accept_connection(MyProto, sock)
+ self.assertTrue(sock.close.called)
+ self.assertTrue(m_log.exception.called)
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
new file mode 100644
index 0000000..243f400
--- /dev/null
+++ b/Lib/test/test_asyncio/test_events.py
@@ -0,0 +1,1573 @@
+"""Tests for events.py."""
+
+import functools
+import gc
+import io
+import os
+import signal
+import socket
+try:
+ import ssl
+except ImportError:
+ ssl = None
+import subprocess
+import sys
+import threading
+import time
+import errno
+import unittest
+import unittest.mock
+from test.support import find_unused_port
+
+
+from asyncio import futures
+from asyncio import events
+from asyncio import transports
+from asyncio import protocols
+from asyncio import selector_events
+from asyncio import tasks
+from asyncio import test_utils
+from asyncio import locks
+
+
+class MyProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
+
+ def data_received(self, data):
+ assert self.state == 'CONNECTED', self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'EOF'
+
+ def connection_lost(self, exc):
+ assert self.state in ('CONNECTED', 'EOF'), self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyDatagramProto(protocols.DatagramProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.nbytes = 0
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'INITIALIZED'
+
+ def datagram_received(self, data, addr):
+ assert self.state == 'INITIALIZED', self.state
+ self.nbytes += len(data)
+
+ def connection_refused(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+
+ def connection_lost(self, exc):
+ assert self.state == 'INITIALIZED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyReadPipeProto(protocols.Protocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = ['INITIAL']
+ self.nbytes = 0
+ self.transport = None
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == ['INITIAL'], self.state
+ self.state.append('CONNECTED')
+
+ def data_received(self, data):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.nbytes += len(data)
+
+ def eof_received(self):
+ assert self.state == ['INITIAL', 'CONNECTED'], self.state
+ self.state.append('EOF')
+
+ def connection_lost(self, exc):
+ assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state
+ self.state.append('CLOSED')
+ if self.done:
+ self.done.set_result(None)
+
+
+class MyWritePipeProto(protocols.BaseProtocol):
+ done = None
+
+ def __init__(self, loop=None):
+ self.state = 'INITIAL'
+ self.transport = None
+ if loop is not None:
+ self.done = futures.Future(loop=loop)
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ if self.done:
+ self.done.set_result(None)
+
+
+class MySubprocessProtocol(protocols.SubprocessProtocol):
+
+ def __init__(self, loop):
+ self.state = 'INITIAL'
+ self.transport = None
+ self.connected = futures.Future(loop=loop)
+ self.completed = futures.Future(loop=loop)
+ self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)}
+ self.data = {1: b'', 2: b''}
+ self.returncode = None
+ self.got_data = {1: locks.Event(loop=loop),
+ 2: locks.Event(loop=loop)}
+
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+ self.connected.set_result(None)
+
+ def connection_lost(self, exc):
+ assert self.state == 'CONNECTED', self.state
+ self.state = 'CLOSED'
+ self.completed.set_result(None)
+
+ def pipe_data_received(self, fd, data):
+ assert self.state == 'CONNECTED', self.state
+ self.data[fd] += data
+ self.got_data[fd].set()
+
+ def pipe_connection_lost(self, fd, exc):
+ assert self.state == 'CONNECTED', self.state
+ if exc:
+ self.disconnects[fd].set_exception(exc)
+ else:
+ self.disconnects[fd].set_result(exc)
+
+ def process_exited(self):
+ assert self.state == 'CONNECTED', self.state
+ self.returncode = self.transport.get_returncode()
+
+
+class EventLoopTestsMixin:
+
+ def setUp(self):
+ super().setUp()
+ self.loop = self.create_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+ super().tearDown()
+
+ def test_run_until_complete_nesting(self):
+ @tasks.coroutine
+ def coro1():
+ yield
+
+ @tasks.coroutine
+ def coro2():
+ self.assertTrue(self.loop.is_running())
+ self.loop.run_until_complete(coro1())
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, coro2())
+
+ # Note: because of the default Windows timing granularity of
+ # 15.6 msec, we use fairly long sleep times here (~100 msec).
+
+ def test_run_until_complete(self):
+ t0 = self.loop.time()
+ self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop))
+ t1 = self.loop.time()
+ self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0)
+
+ def test_run_until_complete_stopped(self):
+ @tasks.coroutine
+ def cb():
+ self.loop.stop()
+ yield from tasks.sleep(0.1, loop=self.loop)
+ task = cb()
+ self.assertRaises(RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_call_later(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ self.loop.stop()
+
+ self.loop.call_later(0.1, callback, 'hello world')
+ t0 = time.monotonic()
+ self.loop.run_forever()
+ t1 = time.monotonic()
+ self.assertEqual(results, ['hello world'])
+ self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0)
+
+ def test_call_soon(self):
+ results = []
+
+ def callback(arg1, arg2):
+ results.append((arg1, arg2))
+ self.loop.stop()
+
+ self.loop.call_soon(callback, 'hello', 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, [('hello', 'world')])
+
+ def test_call_soon_threadsafe(self):
+ results = []
+ lock = threading.Lock()
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ def run_in_thread():
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ lock.release()
+
+ lock.acquire()
+ t = threading.Thread(target=run_in_thread)
+ t.start()
+
+ with lock:
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ t.join()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_call_soon_threadsafe_same_thread(self):
+ results = []
+
+ def callback(arg):
+ results.append(arg)
+ if len(results) >= 2:
+ self.loop.stop()
+
+ self.loop.call_soon_threadsafe(callback, 'hello')
+ self.loop.call_soon(callback, 'world')
+ self.loop.run_forever()
+ self.assertEqual(results, ['hello', 'world'])
+
+ def test_run_in_executor(self):
+ def run(arg):
+ return (arg, threading.get_ident())
+ f2 = self.loop.run_in_executor(None, run, 'yo')
+ res, thread_id = self.loop.run_until_complete(f2)
+ self.assertEqual(res, 'yo')
+ self.assertNotEqual(thread_id, threading.get_ident())
+
+ def test_reader_callback(self):
+ r, w = test_utils.socketpair()
+ bytes_read = []
+
+ def reader():
+ try:
+ data = r.recv(1024)
+ except BlockingIOError:
+ # Spurious readiness notifications are possible
+ # at least on Linux -- see man select.
+ return
+ if data:
+ bytes_read.append(data)
+ else:
+ self.assertTrue(self.loop.remove_reader(r.fileno()))
+ r.close()
+
+ self.loop.add_reader(r.fileno(), reader)
+ self.loop.call_soon(w.send, b'abc')
+ test_utils.run_briefly(self.loop)
+ self.loop.call_soon(w.send, b'def')
+ test_utils.run_briefly(self.loop)
+ self.loop.call_soon(w.close)
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(b''.join(bytes_read), b'abcdef')
+
+ def test_writer_callback(self):
+ r, w = test_utils.socketpair()
+ w.setblocking(False)
+ self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024))
+ test_utils.run_briefly(self.loop)
+
+ def remove_writer():
+ self.assertTrue(self.loop.remove_writer(w.fileno()))
+
+ self.loop.call_soon(remove_writer)
+ self.loop.call_soon(self.loop.stop)
+ self.loop.run_forever()
+ w.close()
+ data = r.recv(256*1024)
+ r.close()
+ self.assertGreaterEqual(len(data), 200)
+
+ def test_sock_client_ops(self):
+ with test_utils.run_test_server() as httpd:
+ sock = socket.socket()
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, httpd.address))
+ self.loop.run_until_complete(
+ self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
+ data = self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ # consume data
+ self.loop.run_until_complete(
+ self.loop.sock_recv(sock, 1024))
+ sock.close()
+
+ self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
+
+ def test_sock_client_fail(self):
+ # Make sure that we will get an unused port
+ address = None
+ try:
+ s = socket.socket()
+ s.bind(('127.0.0.1', 0))
+ address = s.getsockname()
+ finally:
+ s.close()
+
+ sock = socket.socket()
+ sock.setblocking(False)
+ with self.assertRaises(ConnectionRefusedError):
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ sock.close()
+
+ def test_sock_accept(self):
+ listener = socket.socket()
+ listener.setblocking(False)
+ listener.bind(('127.0.0.1', 0))
+ listener.listen(1)
+ client = socket.socket()
+ client.connect(listener.getsockname())
+
+ f = self.loop.sock_accept(listener)
+ conn, addr = self.loop.run_until_complete(f)
+ self.assertEqual(conn.gettimeout(), 0)
+ self.assertEqual(addr, client.getsockname())
+ self.assertEqual(client.getpeername(), listener.getsockname())
+ client.close()
+ conn.close()
+ listener.close()
+
+ @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL')
+ def test_add_signal_handler(self):
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+
+ # Check error behavior first.
+ self.assertRaises(
+ TypeError, self.loop.add_signal_handler, 'boom', my_handler)
+ self.assertRaises(
+ TypeError, self.loop.remove_signal_handler, 'boom')
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, signal.NSIG+1,
+ my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, signal.NSIG+1)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, 0, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, 0)
+ self.assertRaises(
+ ValueError, self.loop.add_signal_handler, -1, my_handler)
+ self.assertRaises(
+ ValueError, self.loop.remove_signal_handler, -1)
+ self.assertRaises(
+ RuntimeError, self.loop.add_signal_handler, signal.SIGKILL,
+ my_handler)
+ # Removing SIGKILL doesn't raise, since we don't call signal().
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL))
+ # Now set a handler and handle it.
+ self.loop.add_signal_handler(signal.SIGINT, my_handler)
+ test_utils.run_briefly(self.loop)
+ os.kill(os.getpid(), signal.SIGINT)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(caught, 1)
+ # Removing it should restore the default handler.
+ self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertEqual(signal.getsignal(signal.SIGINT),
+ signal.default_int_handler)
+ # Removing again returns False.
+ self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT))
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_while_selecting(self):
+ # Test with a signal actually arriving during a select() call.
+ caught = 0
+
+ def my_handler():
+ nonlocal caught
+ caught += 1
+ self.loop.stop()
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once.
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM')
+ def test_signal_handling_args(self):
+ some_args = (42,)
+ caught = 0
+
+ def my_handler(*args):
+ nonlocal caught
+ caught += 1
+ self.assertEqual(args, some_args)
+
+ self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args)
+
+ signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once.
+ self.loop.call_later(0.015, self.loop.stop)
+ self.loop.run_forever()
+ self.assertEqual(caught, 1)
+
+ def test_create_connection(self):
+ with test_utils.run_test_server() as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), *httpd.address)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def test_create_connection_sock(self):
+ with test_utils.run_test_server() as httpd:
+ sock = None
+ infos = self.loop.run_until_complete(
+ self.loop.getaddrinfo(
+ *httpd.address, type=socket.SOCK_STREAM))
+ for family, type, proto, cname, address in infos:
+ try:
+ sock = socket.socket(family=family, type=type, proto=proto)
+ sock.setblocking(False)
+ self.loop.run_until_complete(
+ self.loop.sock_connect(sock, address))
+ except:
+ pass
+ else:
+ break
+ else:
+ assert False, 'Can not create socket.'
+
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), sock=sock)
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_ssl_connection(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop), *httpd.address,
+ ssl=test_utils.dummy_ssl_context())
+ tr, pr = self.loop.run_until_complete(f)
+ self.assertTrue(isinstance(tr, transports.Transport))
+ self.assertTrue(isinstance(pr, protocols.Protocol))
+ self.assertTrue('ssl' in tr.__class__.__name__.lower())
+ self.assertIsNotNone(tr.get_extra_info('sockname'))
+ self.loop.run_until_complete(pr.done)
+ self.assertGreater(pr.nbytes, 0)
+ tr.close()
+
+ def test_create_connection_local_addr(self):
+ with test_utils.run_test_server() as httpd:
+ port = find_unused_port()
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=(httpd.address[0], port))
+ tr, pr = self.loop.run_until_complete(f)
+ expected = pr.transport.get_extra_info('sockname')[1]
+ self.assertEqual(port, expected)
+ tr.close()
+
+ def test_create_connection_local_addr_in_use(self):
+ with test_utils.run_test_server() as httpd:
+ f = self.loop.create_connection(
+ lambda: MyProto(loop=self.loop),
+ *httpd.address, local_addr=httpd.address)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+ self.assertIn(str(httpd.address), cm.exception.strerror)
+
+ def test_create_server(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyProto()
+ return proto
+
+ f = self.loop.create_server(factory, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ self.assertEqual(len(server.sockets), 1)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(proto, MyProto)
+ self.assertEqual('INITIAL', proto.state)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ test_utils.run_briefly(self.loop) # windows iocp
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
+ self.assertEqual('127.0.0.1',
+ proto.transport.get_extra_info('peername')[0])
+
+ # close connection
+ proto.transport.close()
+ test_utils.run_briefly(self.loop) # windows iocp
+
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # close server
+ server.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_create_server_ssl(self):
+ proto = None
+
+ class ClientMyProto(MyProto):
+ def connection_made(self, transport):
+ self.transport = transport
+ assert self.state == 'INITIAL', self.state
+ self.state = 'CONNECTED'
+
+ def factory():
+ nonlocal proto
+ proto = MyProto(loop=self.loop)
+ return proto
+
+ here = os.path.dirname(__file__)
+ sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ sslcontext.load_cert_chain(
+ certfile=os.path.join(here, 'sample.crt'),
+ keyfile=os.path.join(here, 'sample.key'))
+
+ f = self.loop.create_server(
+ factory, '127.0.0.1', 0, ssl=sslcontext)
+
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+ self.assertEqual(host, '127.0.0.1')
+
+ f_c = self.loop.create_connection(ClientMyProto, host, port,
+ ssl=test_utils.dummy_ssl_context())
+ client, pr = self.loop.run_until_complete(f_c)
+
+ client.write(b'xxx')
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(proto, MyProto)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual('CONNECTED', proto.state)
+ self.assertEqual(3, proto.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('sockname'))
+ self.assertEqual('127.0.0.1',
+ proto.transport.get_extra_info('peername')[0])
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ # the client socket must be closed after to avoid ECONNRESET upon
+ # recv()/send() on the serving socket
+ client.close()
+
+ # stop serving
+ server.close()
+
+ def test_create_server_sock(self):
+ proto = futures.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ proto.set_result(self)
+
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(TestMyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ self.assertIs(sock, sock_ob)
+
+ host, port = sock.getsockname()
+ self.assertEqual(host, '0.0.0.0')
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+ server.close()
+
+ def test_create_server_addr_in_use(self):
+ sock_ob = socket.socket(type=socket.SOCK_STREAM)
+ sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock_ob.bind(('0.0.0.0', 0))
+
+ f = self.loop.create_server(MyProto, sock=sock_ob)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ f = self.loop.create_server(MyProto, host=host, port=port)
+ with self.assertRaises(OSError) as cm:
+ self.loop.run_until_complete(f)
+ self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
+
+ server.close()
+
+ @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported')
+ def test_create_server_dual_stack(self):
+ f_proto = futures.Future(loop=self.loop)
+
+ class TestMyProto(MyProto):
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ f_proto.set_result(self)
+
+ try_count = 0
+ while True:
+ try:
+ port = find_unused_port()
+ f = self.loop.create_server(TestMyProto, host=None, port=port)
+ server = self.loop.run_until_complete(f)
+ except OSError as ex:
+ if ex.errno == errno.EADDRINUSE:
+ try_count += 1
+ self.assertGreaterEqual(5, try_count)
+ continue
+ else:
+ raise
+ else:
+ break
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ f_proto = futures.Future(loop=self.loop)
+ client = socket.socket(socket.AF_INET6)
+ client.connect(('::1', port))
+ client.send(b'xxx')
+ proto = self.loop.run_until_complete(f_proto)
+ proto.transport.close()
+ client.close()
+
+ server.close()
+
+ def test_server_close(self):
+ f = self.loop.create_server(MyProto, '0.0.0.0', 0)
+ server = self.loop.run_until_complete(f)
+ sock = server.sockets[0]
+ host, port = sock.getsockname()
+
+ client = socket.socket()
+ client.connect(('127.0.0.1', port))
+ client.send(b'xxx')
+ client.close()
+
+ server.close()
+
+ client = socket.socket()
+ self.assertRaises(
+ ConnectionRefusedError, client.connect, ('127.0.0.1', port))
+ client.close()
+
+ def test_create_datagram_endpoint(self):
+ class TestMyDatagramProto(MyDatagramProto):
+ def __init__(inner_self):
+ super().__init__(loop=self.loop)
+
+ def datagram_received(self, data, addr):
+ super().datagram_received(data, addr)
+ self.transport.sendto(b'resp:'+data, addr)
+
+ coro = self.loop.create_datagram_endpoint(
+ TestMyDatagramProto, local_addr=('127.0.0.1', 0))
+ s_transport, server = self.loop.run_until_complete(coro)
+ host, port = s_transport.get_extra_info('sockname')
+
+ coro = self.loop.create_datagram_endpoint(
+ lambda: MyDatagramProto(loop=self.loop),
+ remote_addr=(host, port))
+ transport, client = self.loop.run_until_complete(coro)
+
+ self.assertEqual('INITIALIZED', client.state)
+ transport.sendto(b'xxx')
+ for _ in range(1000):
+ if server.nbytes:
+ break
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(3, server.nbytes)
+ for _ in range(1000):
+ if client.nbytes:
+ break
+ test_utils.run_briefly(self.loop)
+
+ # received
+ self.assertEqual(8, client.nbytes)
+
+ # extra info is available
+ self.assertIsNotNone(transport.get_extra_info('sockname'))
+
+ # close connection
+ transport.close()
+ self.loop.run_until_complete(client.done)
+ self.assertEqual('CLOSED', client.state)
+ server.transport.close()
+
+ def test_internal_fds(self):
+ loop = self.create_event_loop()
+ if not isinstance(loop, selector_events.BaseSelectorEventLoop):
+ return
+
+ self.assertEqual(1, loop._internal_fds)
+ loop.close()
+ self.assertEqual(0, loop._internal_fds)
+ self.assertIsNone(loop._csock)
+ self.assertIsNone(loop._ssock)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_read_pipe(self):
+ proto = None
+
+ def factory():
+ nonlocal proto
+ proto = MyReadPipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(rpipe, 'rb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ t, p = yield from self.loop.connect_read_pipe(factory, pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(0, proto.nbytes)
+
+ self.loop.run_until_complete(connect())
+
+ os.write(wpipe, b'1')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(1, proto.nbytes)
+
+ os.write(wpipe, b'2345')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(['INITIAL', 'CONNECTED'], proto.state)
+ self.assertEqual(5, proto.nbytes)
+
+ os.close(wpipe)
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual(
+ ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state)
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe(self):
+ proto = None
+ transport = None
+
+ def factory():
+ nonlocal proto
+ proto = MyWritePipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ nonlocal transport
+ t, p = yield from self.loop.connect_write_pipe(factory, pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+ transport = t
+
+ self.loop.run_until_complete(connect())
+
+ transport.write(b'1')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'1', data)
+
+ transport.write(b'2345')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'2345', data)
+ self.assertEqual('CONNECTED', proto.state)
+
+ os.close(rpipe)
+
+ # extra info is available
+ self.assertIsNotNone(proto.transport.get_extra_info('pipe'))
+
+ # close connection
+ proto.transport.close()
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ @unittest.skipUnless(sys.platform != 'win32',
+ "Don't support pipes for Windows")
+ def test_write_pipe_disconnect_on_close(self):
+ proto = None
+ transport = None
+
+ def factory():
+ nonlocal proto
+ proto = MyWritePipeProto(loop=self.loop)
+ return proto
+
+ rpipe, wpipe = os.pipe()
+ pipeobj = io.open(wpipe, 'wb', 1024)
+
+ @tasks.coroutine
+ def connect():
+ nonlocal transport
+ t, p = yield from self.loop.connect_write_pipe(factory,
+ pipeobj)
+ self.assertIs(p, proto)
+ self.assertIs(t, proto.transport)
+ self.assertEqual('CONNECTED', proto.state)
+ transport = t
+
+ self.loop.run_until_complete(connect())
+ self.assertEqual('CONNECTED', proto.state)
+
+ transport.write(b'1')
+ test_utils.run_briefly(self.loop)
+ data = os.read(rpipe, 1024)
+ self.assertEqual(b'1', data)
+
+ os.close(rpipe)
+
+ self.loop.run_until_complete(proto.done)
+ self.assertEqual('CLOSED', proto.state)
+
+ def test_prompt_cancellation(self):
+ r, w = test_utils.socketpair()
+ r.setblocking(False)
+ f = self.loop.sock_recv(r, 1)
+ ov = getattr(f, 'ov', None)
+ self.assertTrue(ov is None or ov.pending)
+
+ @tasks.coroutine
+ def main():
+ try:
+ self.loop.call_soon(f.cancel)
+ yield from f
+ except futures.CancelledError:
+ res = 'cancelled'
+ else:
+ res = None
+ finally:
+ self.loop.stop()
+ return res
+
+ start = time.monotonic()
+ t = tasks.Task(main(), loop=self.loop)
+ self.loop.run_forever()
+ elapsed = time.monotonic() - start
+
+ self.assertLess(elapsed, 0.1)
+ self.assertEqual(t.result(), 'cancelled')
+ self.assertRaises(futures.CancelledError, f.result)
+ self.assertTrue(ov is None or not ov.pending)
+ self.loop._stop_serving(r)
+
+ r.close()
+ w.close()
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_exec(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGTERM, proto.returncode)
+ self.assertEqual(b'Python The Winner', proto.data[1])
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_interactive(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+ self.assertEqual('CONNECTED', proto.state)
+
+ try:
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'Python ')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ proto.got_data[1].clear()
+ self.assertEqual(b'Python ', proto.data[1])
+
+ stdin.write(b'The Winner')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'Python The Winner', proto.data[1])
+ finally:
+ transp.close()
+
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGTERM, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_shell(self):
+ proto = None
+ transp = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'echo "Python"')
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.get_pipe_transport(0).close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(0, proto.returncode)
+ self.assertTrue(all(f.done() for f in proto.disconnects.values()))
+ self.assertEqual({1: b'Python\n', 2: b''}, proto.data)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_exitcode(self):
+ proto = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_close_after_finish(self):
+ proto = None
+ transp = None
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_shell(
+ functools.partial(MySubprocessProtocol, self.loop),
+ 'exit 7', stdin=None, stdout=None, stderr=None)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.assertIsNone(transp.get_pipe_transport(0))
+ self.assertIsNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(7, proto.returncode)
+ self.assertIsNone(transp.close())
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_kill(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.kill()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGKILL, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_send_signal(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ transp.send_signal(signal.SIGHUP)
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGHUP, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_stderr(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdin.write(b'test')
+
+ self.loop.run_until_complete(proto.completed)
+
+ transp.close()
+ self.assertEqual(b'OUT:test', proto.data[1])
+ self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2])
+ self.assertEqual(0, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_stderr_redirect_to_stdout(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo2.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog, stderr=subprocess.STDOUT)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ self.assertIsNotNone(transp.get_pipe_transport(1))
+ self.assertIsNone(transp.get_pipe_transport(2))
+
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.completed)
+ self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'),
+ proto.data[1])
+ self.assertEqual(b'', proto.data[2])
+
+ transp.close()
+ self.assertEqual(0, proto.returncode)
+
+ @unittest.skipIf(sys.platform == 'win32',
+ "Don't support subprocess for Windows yet")
+ def test_subprocess_close_client_stream(self):
+ proto = None
+ transp = None
+
+ prog = os.path.join(os.path.dirname(__file__), 'echo3.py')
+
+ @tasks.coroutine
+ def connect():
+ nonlocal proto, transp
+ transp, proto = yield from self.loop.subprocess_exec(
+ functools.partial(MySubprocessProtocol, self.loop),
+ sys.executable, prog)
+ self.assertIsInstance(proto, MySubprocessProtocol)
+
+ self.loop.run_until_complete(connect())
+ self.loop.run_until_complete(proto.connected)
+
+ stdin = transp.get_pipe_transport(0)
+ stdout = transp.get_pipe_transport(1)
+ stdin.write(b'test')
+ self.loop.run_until_complete(proto.got_data[1].wait())
+ self.assertEqual(b'OUT:test', proto.data[1])
+
+ stdout.close()
+ self.loop.run_until_complete(proto.disconnects[1])
+ stdin.write(b'xxx')
+ self.loop.run_until_complete(proto.got_data[2].wait())
+ self.assertEqual(b'ERR:BrokenPipeError', proto.data[2])
+
+ transp.close()
+ self.loop.run_until_complete(proto.completed)
+ self.assertEqual(-signal.SIGTERM, proto.returncode)
+
+
+if sys.platform == 'win32':
+ from asyncio import windows_events
+
+ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return windows_events.SelectorEventLoop()
+
+ class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return windows_events.ProactorEventLoop()
+
+ def test_create_ssl_connection(self):
+ raise unittest.SkipTest("IocpEventLoop imcompatible with SSL")
+
+ def test_create_server_ssl(self):
+ raise unittest.SkipTest("IocpEventLoop imcompatible with SSL")
+
+ def test_reader_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_reader_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
+
+ def test_writer_callback(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_writer_callback_cancel(self):
+ raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
+
+ def test_create_datagram_endpoint(self):
+ raise unittest.SkipTest(
+ "IocpEventLoop does not have create_datagram_endpoint()")
+else:
+ from asyncio import selectors
+ from asyncio import unix_events
+
+ if hasattr(selectors, 'KqueueSelector'):
+ class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(
+ selectors.KqueueSelector())
+
+ if hasattr(selectors, 'EpollSelector'):
+ class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.EpollSelector())
+
+ if hasattr(selectors, 'PollSelector'):
+ class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.PollSelector())
+
+ # Should always exist.
+ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase):
+
+ def create_event_loop(self):
+ return unix_events.SelectorEventLoop(selectors.SelectSelector())
+
+
+class HandleTests(unittest.TestCase):
+
+ def test_handle(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ h = events.Handle(callback, args)
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.startswith(
+ 'Handle('
+ '<function HandleTests.test_handle.<locals>.callback'))
+ self.assertTrue(r.endswith('())<cancelled>'), r)
+
+ def test_make_handle(self):
+ def callback(*args):
+ return args
+ h1 = events.Handle(callback, ())
+ self.assertRaises(
+ AssertionError, events.make_handle, h1, ())
+
+ @unittest.mock.patch('asyncio.events.asyncio_log')
+ def test_callback_with_exception(self, log):
+ def callback():
+ raise ValueError()
+
+ h = events.Handle(callback, ())
+ h._run()
+ self.assertTrue(log.exception.called)
+
+
+class TimerTests(unittest.TestCase):
+
+ def test_hash(self):
+ when = time.monotonic()
+ h = events.TimerHandle(when, lambda: False, ())
+ self.assertEqual(hash(h), hash(when))
+
+ def test_timer(self):
+ def callback(*args):
+ return args
+
+ args = ()
+ when = time.monotonic()
+ h = events.TimerHandle(when, callback, args)
+ self.assertIs(h._callback, callback)
+ self.assertIs(h._args, args)
+ self.assertFalse(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())'))
+
+ h.cancel()
+ self.assertTrue(h._cancelled)
+
+ r = repr(h)
+ self.assertTrue(r.endswith('())<cancelled>'), r)
+
+ self.assertRaises(AssertionError,
+ events.TimerHandle, None, callback, args)
+
+ def test_timer_comparison(self):
+ def callback(*args):
+ return args
+
+ when = time.monotonic()
+
+ h1 = events.TimerHandle(when, callback, ())
+ h2 = events.TimerHandle(when, callback, ())
+ # TODO: Use assertLess etc.
+ self.assertFalse(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertTrue(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertFalse(h2 > h1)
+ self.assertTrue(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertTrue(h1 == h2)
+ self.assertFalse(h1 != h2)
+
+ h2.cancel()
+ self.assertFalse(h1 == h2)
+
+ h1 = events.TimerHandle(when, callback, ())
+ h2 = events.TimerHandle(when + 10.0, callback, ())
+ self.assertTrue(h1 < h2)
+ self.assertFalse(h2 < h1)
+ self.assertTrue(h1 <= h2)
+ self.assertFalse(h2 <= h1)
+ self.assertFalse(h1 > h2)
+ self.assertTrue(h2 > h1)
+ self.assertFalse(h1 >= h2)
+ self.assertTrue(h2 >= h1)
+ self.assertFalse(h1 == h2)
+ self.assertTrue(h1 != h2)
+
+ h3 = events.Handle(callback, ())
+ self.assertIs(NotImplemented, h1.__eq__(h3))
+ self.assertIs(NotImplemented, h1.__ne__(h3))
+
+
+class AbstractEventLoopTests(unittest.TestCase):
+
+ def test_not_implemented(self):
+ f = unittest.mock.Mock()
+ loop = events.AbstractEventLoop()
+ self.assertRaises(
+ NotImplementedError, loop.run_forever)
+ self.assertRaises(
+ NotImplementedError, loop.run_until_complete, None)
+ self.assertRaises(
+ NotImplementedError, loop.stop)
+ self.assertRaises(
+ NotImplementedError, loop.is_running)
+ self.assertRaises(
+ NotImplementedError, loop.call_later, None, None)
+ self.assertRaises(
+ NotImplementedError, loop.call_at, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon, None)
+ self.assertRaises(
+ NotImplementedError, loop.time)
+ self.assertRaises(
+ NotImplementedError, loop.call_soon_threadsafe, None)
+ self.assertRaises(
+ NotImplementedError, loop.run_in_executor, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.set_default_executor, f)
+ self.assertRaises(
+ NotImplementedError, loop.getaddrinfo, 'localhost', 8080)
+ self.assertRaises(
+ NotImplementedError, loop.getnameinfo, ('localhost', 8080))
+ self.assertRaises(
+ NotImplementedError, loop.create_connection, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_server, f)
+ self.assertRaises(
+ NotImplementedError, loop.create_datagram_endpoint, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_reader, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_reader, 1)
+ self.assertRaises(
+ NotImplementedError, loop.add_writer, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_writer, 1)
+ self.assertRaises(
+ NotImplementedError, loop.sock_recv, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_sendall, f, 10)
+ self.assertRaises(
+ NotImplementedError, loop.sock_connect, f, f)
+ self.assertRaises(
+ NotImplementedError, loop.sock_accept, f)
+ self.assertRaises(
+ NotImplementedError, loop.add_signal_handler, 1, f)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.remove_signal_handler, 1)
+ self.assertRaises(
+ NotImplementedError, loop.connect_read_pipe, f,
+ unittest.mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.connect_write_pipe, f,
+ unittest.mock.sentinel.pipe)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_shell, f,
+ unittest.mock.sentinel)
+ self.assertRaises(
+ NotImplementedError, loop.subprocess_exec, f)
+
+
+class ProtocolsAbsTests(unittest.TestCase):
+
+ def test_empty(self):
+ f = unittest.mock.Mock()
+ p = protocols.Protocol()
+ self.assertIsNone(p.connection_made(f))
+ self.assertIsNone(p.connection_lost(f))
+ self.assertIsNone(p.data_received(f))
+ self.assertIsNone(p.eof_received())
+
+ dp = protocols.DatagramProtocol()
+ self.assertIsNone(dp.connection_made(f))
+ self.assertIsNone(dp.connection_lost(f))
+ self.assertIsNone(dp.connection_refused(f))
+ self.assertIsNone(dp.datagram_received(f, f))
+
+ sp = protocols.SubprocessProtocol()
+ self.assertIsNone(sp.connection_made(f))
+ self.assertIsNone(sp.connection_lost(f))
+ self.assertIsNone(sp.pipe_data_received(1, f))
+ self.assertIsNone(sp.pipe_connection_lost(1, f))
+ self.assertIsNone(sp.process_exited())
+
+
+class PolicyTests(unittest.TestCase):
+
+ def test_event_loop_policy(self):
+ policy = events.AbstractEventLoopPolicy()
+ self.assertRaises(NotImplementedError, policy.get_event_loop)
+ self.assertRaises(NotImplementedError, policy.set_event_loop, object())
+ self.assertRaises(NotImplementedError, policy.new_event_loop)
+
+ def test_get_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+ self.assertIsNone(policy._loop)
+
+ loop = policy.get_event_loop()
+ self.assertIsInstance(loop, events.AbstractEventLoop)
+
+ self.assertIs(policy._loop, loop)
+ self.assertIs(loop, policy.get_event_loop())
+ loop.close()
+
+ def test_get_event_loop_after_set_none(self):
+ policy = events.DefaultEventLoopPolicy()
+ policy.set_event_loop(None)
+ self.assertRaises(AssertionError, policy.get_event_loop)
+
+ @unittest.mock.patch('asyncio.events.threading.current_thread')
+ def test_get_event_loop_thread(self, m_current_thread):
+
+ def f():
+ policy = events.DefaultEventLoopPolicy()
+ self.assertRaises(AssertionError, policy.get_event_loop)
+
+ th = threading.Thread(target=f)
+ th.start()
+ th.join()
+
+ def test_new_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+
+ loop = policy.new_event_loop()
+ self.assertIsInstance(loop, events.AbstractEventLoop)
+ loop.close()
+
+ def test_set_event_loop(self):
+ policy = events.DefaultEventLoopPolicy()
+ old_loop = policy.get_event_loop()
+
+ self.assertRaises(AssertionError, policy.set_event_loop, object())
+
+ loop = policy.new_event_loop()
+ policy.set_event_loop(loop)
+ self.assertIs(loop, policy.get_event_loop())
+ self.assertIsNot(old_loop, policy.get_event_loop())
+ loop.close()
+ old_loop.close()
+
+ def test_get_event_loop_policy(self):
+ policy = events.get_event_loop_policy()
+ self.assertIsInstance(policy, events.AbstractEventLoopPolicy)
+ self.assertIs(policy, events.get_event_loop_policy())
+
+ def test_set_event_loop_policy(self):
+ self.assertRaises(
+ AssertionError, events.set_event_loop_policy, object())
+
+ old_policy = events.get_event_loop_policy()
+
+ policy = events.DefaultEventLoopPolicy()
+ events.set_event_loop_policy(policy)
+ self.assertIs(policy, events.get_event_loop_policy())
+ self.assertIsNot(policy, old_policy)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py
new file mode 100644
index 0000000..9b5108c
--- /dev/null
+++ b/Lib/test/test_asyncio/test_futures.py
@@ -0,0 +1,329 @@
+"""Tests for futures.py."""
+
+import concurrent.futures
+import threading
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import test_utils
+
+
+def _fakefunc(f):
+ return f
+
+
+class FutureTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_initial_state(self):
+ f = futures.Future(loop=self.loop)
+ self.assertFalse(f.cancelled())
+ self.assertFalse(f.done())
+ f.cancel()
+ self.assertTrue(f.cancelled())
+
+ def test_init_constructor_default_loop(self):
+ try:
+ events.set_event_loop(self.loop)
+ f = futures.Future()
+ self.assertIs(f._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_constructor_positional(self):
+ # Make sure Future does't accept a positional argument
+ self.assertRaises(TypeError, futures.Future, 42)
+
+ def test_cancel(self):
+ f = futures.Future(loop=self.loop)
+ self.assertTrue(f.cancel())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(futures.CancelledError, f.result)
+ self.assertRaises(futures.CancelledError, f.exception)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_result(self):
+ f = futures.Future(loop=self.loop)
+ self.assertRaises(futures.InvalidStateError, f.result)
+
+ f.set_result(42)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 42)
+ self.assertEqual(f.exception(), None)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_exception(self):
+ exc = RuntimeError()
+ f = futures.Future(loop=self.loop)
+ self.assertRaises(futures.InvalidStateError, f.exception)
+
+ f.set_exception(exc)
+ self.assertFalse(f.cancelled())
+ self.assertTrue(f.done())
+ self.assertRaises(RuntimeError, f.result)
+ self.assertEqual(f.exception(), exc)
+ self.assertRaises(futures.InvalidStateError, f.set_result, None)
+ self.assertRaises(futures.InvalidStateError, f.set_exception, None)
+ self.assertFalse(f.cancel())
+
+ def test_yield_from_twice(self):
+ f = futures.Future(loop=self.loop)
+
+ def fixture():
+ yield 'A'
+ x = yield from f
+ yield 'B', x
+ y = yield from f
+ yield 'C', y
+
+ g = fixture()
+ self.assertEqual(next(g), 'A') # yield 'A'.
+ self.assertEqual(next(g), f) # First yield from f.
+ f.set_result(42)
+ self.assertEqual(next(g), ('B', 42)) # yield 'B', x.
+ # The second "yield from f" does not yield f.
+ self.assertEqual(next(g), ('C', 42)) # yield 'C', y.
+
+ def test_repr(self):
+ f_pending = futures.Future(loop=self.loop)
+ self.assertEqual(repr(f_pending), 'Future<PENDING>')
+ f_pending.cancel()
+
+ f_cancelled = futures.Future(loop=self.loop)
+ f_cancelled.cancel()
+ self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>')
+
+ f_result = futures.Future(loop=self.loop)
+ f_result.set_result(4)
+ self.assertEqual(repr(f_result), 'Future<result=4>')
+ self.assertEqual(f_result.result(), 4)
+
+ exc = RuntimeError()
+ f_exception = futures.Future(loop=self.loop)
+ f_exception.set_exception(exc)
+ self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>')
+ self.assertIs(f_exception.exception(), exc)
+
+ f_few_callbacks = futures.Future(loop=self.loop)
+ f_few_callbacks.add_done_callback(_fakefunc)
+ self.assertIn('Future<PENDING, [<function _fakefunc',
+ repr(f_few_callbacks))
+ f_few_callbacks.cancel()
+
+ f_many_callbacks = futures.Future(loop=self.loop)
+ for i in range(20):
+ f_many_callbacks.add_done_callback(_fakefunc)
+ r = repr(f_many_callbacks)
+ self.assertIn('Future<PENDING, [<function _fakefunc', r)
+ self.assertIn('<18 more>', r)
+ f_many_callbacks.cancel()
+
+ def test_copy_state(self):
+ # Test the internal _copy_state method since it's being directly
+ # invoked in other modules.
+ f = futures.Future(loop=self.loop)
+ f.set_result(10)
+
+ newf = futures.Future(loop=self.loop)
+ newf._copy_state(f)
+ self.assertTrue(newf.done())
+ self.assertEqual(newf.result(), 10)
+
+ f_exception = futures.Future(loop=self.loop)
+ f_exception.set_exception(RuntimeError())
+
+ newf_exception = futures.Future(loop=self.loop)
+ newf_exception._copy_state(f_exception)
+ self.assertTrue(newf_exception.done())
+ self.assertRaises(RuntimeError, newf_exception.result)
+
+ f_cancelled = futures.Future(loop=self.loop)
+ f_cancelled.cancel()
+
+ newf_cancelled = futures.Future(loop=self.loop)
+ newf_cancelled._copy_state(f_cancelled)
+ self.assertTrue(newf_cancelled.cancelled())
+
+ def test_iter(self):
+ fut = futures.Future(loop=self.loop)
+
+ def coro():
+ yield from fut
+
+ def test():
+ arg1, arg2 = coro()
+
+ self.assertRaises(AssertionError, test)
+ fut.cancel()
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_abandoned(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_result_unretrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_result_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ fut.result()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_exception_unretrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ del fut
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_exception_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ fut.exception()
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ @unittest.mock.patch('asyncio.futures.asyncio_log')
+ def test_tb_logger_exception_result_retrieved(self, m_log):
+ fut = futures.Future(loop=self.loop)
+ fut.set_exception(RuntimeError('boom'))
+ self.assertRaises(RuntimeError, fut.result)
+ del fut
+ self.assertFalse(m_log.error.called)
+
+ def test_wrap_future(self):
+
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = futures.wrap_future(f1, loop=self.loop)
+ res, ident = self.loop.run_until_complete(f2)
+ self.assertIsInstance(f2, futures.Future)
+ self.assertEqual(res, 'oi')
+ self.assertNotEqual(ident, threading.get_ident())
+
+ def test_wrap_future_future(self):
+ f1 = futures.Future(loop=self.loop)
+ f2 = futures.wrap_future(f1)
+ self.assertIs(f1, f2)
+
+ @unittest.mock.patch('asyncio.futures.events')
+ def test_wrap_future_use_global_loop(self, m_events):
+ def run(arg):
+ return (arg, threading.get_ident())
+ ex = concurrent.futures.ThreadPoolExecutor(1)
+ f1 = ex.submit(run, 'oi')
+ f2 = futures.wrap_future(f1)
+ self.assertIs(m_events.get_event_loop.return_value, f2._loop)
+
+
+class FutureDoneCallbackTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def run_briefly(self):
+ test_utils.run_briefly(self.loop)
+
+ def _make_callback(self, bag, thing):
+ # Create a callback function that appends thing to bag.
+ def bag_appender(future):
+ bag.append(thing)
+ return bag_appender
+
+ def _new_future(self):
+ return futures.Future(loop=self.loop)
+
+ def test_callbacks_invoked_on_set_result(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 42))
+ f.add_done_callback(self._make_callback(bag, 17))
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [42, 17])
+ self.assertEqual(f.result(), 'foo')
+
+ def test_callbacks_invoked_on_set_exception(self):
+ bag = []
+ f = self._new_future()
+ f.add_done_callback(self._make_callback(bag, 100))
+
+ self.assertEqual(bag, [])
+ exc = RuntimeError()
+ f.set_exception(exc)
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [100])
+ self.assertEqual(f.exception(), exc)
+
+ def test_remove_done_callback(self):
+ bag = []
+ f = self._new_future()
+ cb1 = self._make_callback(bag, 1)
+ cb2 = self._make_callback(bag, 2)
+ cb3 = self._make_callback(bag, 3)
+
+ # Add one cb1 and one cb2.
+ f.add_done_callback(cb1)
+ f.add_done_callback(cb2)
+
+ # One instance of cb2 removed. Now there's only one cb1.
+ self.assertEqual(f.remove_done_callback(cb2), 1)
+
+ # Never had any cb3 in there.
+ self.assertEqual(f.remove_done_callback(cb3), 0)
+
+ # After this there will be 6 instances of cb1 and one of cb2.
+ f.add_done_callback(cb2)
+ for i in range(5):
+ f.add_done_callback(cb1)
+
+ # Remove all instances of cb1. One cb2 remains.
+ self.assertEqual(f.remove_done_callback(cb1), 6)
+
+ self.assertEqual(bag, [])
+ f.set_result('foo')
+
+ self.run_briefly()
+
+ self.assertEqual(bag, [2])
+ self.assertEqual(f.result(), 'foo')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py
new file mode 100644
index 0000000..31b4d64
--- /dev/null
+++ b/Lib/test/test_asyncio/test_locks.py
@@ -0,0 +1,765 @@
+"""Tests for lock.py"""
+
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import locks
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class LockTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ lock = locks.Lock(loop=loop)
+ self.assertIs(lock._loop, loop)
+
+ lock = locks.Lock(loop=self.loop)
+ self.assertIs(lock._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ lock = locks.Lock()
+ self.assertIs(lock._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ lock = locks.Lock(loop=self.loop)
+ self.assertTrue(repr(lock).endswith('[unlocked]>'))
+
+ @tasks.coroutine
+ def acquire_lock():
+ yield from lock
+
+ self.loop.run_until_complete(acquire_lock())
+ self.assertTrue(repr(lock).endswith('[locked]>'))
+
+ def test_lock(self):
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_acquire(self):
+ lock = locks.Lock(loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from lock.acquire()):
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from lock.acquire()):
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from lock.acquire()):
+ result.append(3)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ lock.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_acquire_cancel(self):
+ lock = locks.Lock(loop=self.loop)
+ self.assertTrue(self.loop.run_until_complete(lock.acquire()))
+
+ task = tasks.Task(lock.acquire(), loop=self.loop)
+ self.loop.call_soon(task.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, task)
+ self.assertFalse(lock._waiters)
+
+ def test_cancel_race(self):
+ # Several tasks:
+ # - A acquires the lock
+ # - B is blocked in aqcuire()
+ # - C is blocked in aqcuire()
+ #
+ # Now, concurrently:
+ # - B is cancelled
+ # - A releases the lock
+ #
+ # If B's waiter is marked cancelled but not yet removed from
+ # _waiters, A's release() call will crash when trying to set
+ # B's waiter; instead, it should move on to C's waiter.
+
+ # Setup: A has the lock, b and c are waiting.
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def lockit(name, blocker):
+ yield from lock.acquire()
+ try:
+ if blocker is not None:
+ yield from blocker
+ finally:
+ lock.release()
+
+ fa = futures.Future(loop=self.loop)
+ ta = tasks.Task(lockit('A', fa), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(lock.locked())
+ tb = tasks.Task(lockit('B', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 1)
+ tc = tasks.Task(lockit('C', None), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(len(lock._waiters), 2)
+
+ # Create the race and check.
+ # Without the fix this failed at the last assert.
+ fa.set_result(None)
+ tb.cancel()
+ self.assertTrue(lock._waiters[0].cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(lock.locked())
+ self.assertTrue(ta.done())
+ self.assertTrue(tb.cancelled())
+ self.assertTrue(tc.done())
+
+ def test_release_not_acquired(self):
+ lock = locks.Lock(loop=self.loop)
+
+ self.assertRaises(RuntimeError, lock.release)
+
+ def test_release_no_waiters(self):
+ lock = locks.Lock(loop=self.loop)
+ self.loop.run_until_complete(lock.acquire())
+ self.assertTrue(lock.locked())
+
+ lock.release()
+ self.assertFalse(lock.locked())
+
+ def test_context_manager(self):
+ lock = locks.Lock(loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from lock)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(lock.locked())
+
+ self.assertFalse(lock.locked())
+
+ def test_context_manager_no_yield(self):
+ lock = locks.Lock(loop=self.loop)
+
+ try:
+ with lock:
+ self.fail('RuntimeError is not raised in with expression')
+ except RuntimeError as err:
+ self.assertEqual(
+ str(err),
+ '"yield from" should be used as context manager expression')
+
+
+class EventTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ ev = locks.Event(loop=loop)
+ self.assertIs(ev._loop, loop)
+
+ ev = locks.Event(loop=self.loop)
+ self.assertIs(ev._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ ev = locks.Event()
+ self.assertIs(ev._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertTrue(repr(ev).endswith('[unset]>'))
+
+ ev.set()
+ self.assertTrue(repr(ev).endswith('[set]>'))
+
+ def test_wait(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+
+ @tasks.coroutine
+ def c2(result):
+ if (yield from ev.wait()):
+ result.append(2)
+
+ @tasks.coroutine
+ def c3(result):
+ if (yield from ev.wait()):
+ result.append(3)
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ ev.set()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([3, 1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertIsNone(t1.result())
+ self.assertTrue(t2.done())
+ self.assertIsNone(t2.result())
+ self.assertTrue(t3.done())
+ self.assertIsNone(t3.result())
+
+ def test_wait_on_set(self):
+ ev = locks.Event(loop=self.loop)
+ ev.set()
+
+ res = self.loop.run_until_complete(ev.wait())
+ self.assertTrue(res)
+
+ def test_wait_cancel(self):
+ ev = locks.Event(loop=self.loop)
+
+ wait = tasks.Task(ev.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(ev._waiters)
+
+ def test_clear(self):
+ ev = locks.Event(loop=self.loop)
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ self.assertTrue(ev.is_set())
+
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ def test_clear_with_waiters(self):
+ ev = locks.Event(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ if (yield from ev.wait()):
+ result.append(1)
+ return True
+
+ t = tasks.Task(c1(result), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ ev.set()
+ ev.clear()
+ self.assertFalse(ev.is_set())
+
+ ev.set()
+ ev.set()
+ self.assertEqual(1, len(ev._waiters))
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertEqual(0, len(ev._waiters))
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+
+class ConditionTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ cond = locks.Condition(loop=loop)
+ self.assertIs(cond._loop, loop)
+
+ cond = locks.Condition(loop=self.loop)
+ self.assertIs(cond._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ cond = locks.Condition()
+ self.assertIs(cond._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_wait(self):
+ cond = locks.Condition(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertFalse(cond.locked())
+
+ self.assertTrue(self.loop.run_until_complete(cond.acquire()))
+ cond.notify()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.notify(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+ self.assertTrue(cond.locked())
+
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(cond.locked())
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_wait_cancel(self):
+ cond = locks.Condition(loop=self.loop)
+ self.loop.run_until_complete(cond.acquire())
+
+ wait = tasks.Task(cond.wait(), loop=self.loop)
+ self.loop.call_soon(wait.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, wait)
+ self.assertFalse(cond._condition_waiters)
+ self.assertTrue(cond.locked())
+
+ def test_wait_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, cond.wait())
+
+ def test_wait_for(self):
+ cond = locks.Condition(loop=self.loop)
+ presult = False
+
+ def predicate():
+ return presult
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait_for(predicate)):
+ result.append(1)
+ cond.release()
+ return True
+
+ t = tasks.Task(c1(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ presult = True
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_wait_for_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+
+ # predicate can return true immediately
+ res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3]))
+ self.assertEqual([1, 2, 3], res)
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete,
+ cond.wait_for(lambda: False))
+
+ def test_notify(self):
+ cond = locks.Condition(loop=self.loop)
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(3)
+ cond.release()
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify(1)
+ cond.notify(2048)
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2, 3], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+
+ def test_notify_all(self):
+ cond = locks.Condition(loop=self.loop)
+
+ result = []
+
+ @tasks.coroutine
+ def c1(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(1)
+ cond.release()
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from cond.acquire()
+ if (yield from cond.wait()):
+ result.append(2)
+ cond.release()
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([], result)
+
+ self.loop.run_until_complete(cond.acquire())
+ cond.notify_all()
+ cond.release()
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1, 2], result)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+
+ def test_notify_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify)
+
+ def test_notify_all_unacquired(self):
+ cond = locks.Condition(loop=self.loop)
+ self.assertRaises(RuntimeError, cond.notify_all)
+
+
+class SemaphoreTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ sem = locks.Semaphore(loop=loop)
+ self.assertIs(sem._loop, loop)
+
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertIs(sem._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ sem = locks.Semaphore()
+ self.assertIs(sem._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertTrue(repr(sem).endswith('[unlocked,value:1]>'))
+
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(repr(sem).endswith('[locked]>'))
+
+ def test_semaphore(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.assertEqual(1, sem._value)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ res = self.loop.run_until_complete(acquire_lock())
+
+ self.assertTrue(res)
+ self.assertTrue(sem.locked())
+ self.assertEqual(0, sem._value)
+
+ sem.release()
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ def test_semaphore_value(self):
+ self.assertRaises(ValueError, locks.Semaphore, -1)
+
+ def test_acquire(self):
+ sem = locks.Semaphore(3, loop=self.loop)
+ result = []
+
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertTrue(self.loop.run_until_complete(sem.acquire()))
+ self.assertFalse(sem.locked())
+
+ @tasks.coroutine
+ def c1(result):
+ yield from sem.acquire()
+ result.append(1)
+ return True
+
+ @tasks.coroutine
+ def c2(result):
+ yield from sem.acquire()
+ result.append(2)
+ return True
+
+ @tasks.coroutine
+ def c3(result):
+ yield from sem.acquire()
+ result.append(3)
+ return True
+
+ @tasks.coroutine
+ def c4(result):
+ yield from sem.acquire()
+ result.append(4)
+ return True
+
+ t1 = tasks.Task(c1(result), loop=self.loop)
+ t2 = tasks.Task(c2(result), loop=self.loop)
+ t3 = tasks.Task(c3(result), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual([1], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(2, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ t4 = tasks.Task(c4(result), loop=self.loop)
+
+ sem.release()
+ sem.release()
+ self.assertEqual(2, sem._value)
+
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(0, sem._value)
+ self.assertEqual([1, 2, 3], result)
+ self.assertTrue(sem.locked())
+ self.assertEqual(1, len(sem._waiters))
+ self.assertEqual(0, sem._value)
+
+ self.assertTrue(t1.done())
+ self.assertTrue(t1.result())
+ self.assertTrue(t2.done())
+ self.assertTrue(t2.result())
+ self.assertTrue(t3.done())
+ self.assertTrue(t3.result())
+ self.assertFalse(t4.done())
+
+ # cleanup locked semaphore
+ sem.release()
+
+ def test_acquire_cancel(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+
+ acquire = tasks.Task(sem.acquire(), loop=self.loop)
+ self.loop.call_soon(acquire.cancel)
+ self.assertRaises(
+ futures.CancelledError,
+ self.loop.run_until_complete, acquire)
+ self.assertFalse(sem._waiters)
+
+ def test_release_not_acquired(self):
+ sem = locks.Semaphore(bound=True, loop=self.loop)
+
+ self.assertRaises(ValueError, sem.release)
+
+ def test_release_no_waiters(self):
+ sem = locks.Semaphore(loop=self.loop)
+ self.loop.run_until_complete(sem.acquire())
+ self.assertTrue(sem.locked())
+
+ sem.release()
+ self.assertFalse(sem.locked())
+
+ def test_context_manager(self):
+ sem = locks.Semaphore(2, loop=self.loop)
+
+ @tasks.coroutine
+ def acquire_lock():
+ return (yield from sem)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertFalse(sem.locked())
+ self.assertEqual(1, sem._value)
+
+ with self.loop.run_until_complete(acquire_lock()):
+ self.assertTrue(sem.locked())
+
+ self.assertEqual(2, sem._value)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
new file mode 100644
index 0000000..c52ade0
--- /dev/null
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -0,0 +1,480 @@
+"""Tests for proactor_events.py"""
+
+import socket
+import unittest
+import unittest.mock
+
+import asyncio
+from asyncio.proactor_events import BaseProactorEventLoop
+from asyncio.proactor_events import _ProactorSocketTransport
+from asyncio.proactor_events import _ProactorWritePipeTransport
+from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from asyncio import test_utils
+
+
+class ProactorSocketTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.proactor = unittest.mock.Mock()
+ self.loop._proactor = self.proactor
+ self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+
+ def test_ctor(self):
+ fut = asyncio.Future(loop=self.loop)
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+ self.protocol.connection_made(tr)
+ self.proactor.recv.assert_called_with(self.sock, 4096)
+
+ def test_loop_reading(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._loop_reading()
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.assertFalse(self.protocol.data_received.called)
+ self.assertFalse(self.protocol.eof_received.called)
+
+ def test_loop_reading_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'data')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.loop._proactor.recv.assert_called_with(self.sock, 4096)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_loop_reading_no_data(self):
+ res = asyncio.Future(loop=self.loop)
+ res.set_result(b'')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+
+ self.assertRaises(AssertionError, tr._loop_reading, res)
+
+ tr.close = unittest.mock.Mock()
+ tr._read_fut = res
+ tr._loop_reading(res)
+ self.assertFalse(self.loop._proactor.recv.called)
+ self.assertTrue(self.protocol.eof_received.called)
+ self.assertTrue(tr.close.called)
+
+ def test_loop_reading_aborted(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(err)
+
+ def test_loop_reading_aborted_closing(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = True
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+
+ def test_loop_reading_aborted_is_fatal(self):
+ self.loop._proactor.recv.side_effect = ConnectionAbortedError()
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = False
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertTrue(tr._fatal_error.called)
+
+ def test_loop_reading_conn_reset_lost(self):
+ err = self.loop._proactor.recv.side_effect = ConnectionResetError()
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = False
+ tr._fatal_error = unittest.mock.Mock()
+ tr._force_close = unittest.mock.Mock()
+ tr._loop_reading()
+ self.assertFalse(tr._fatal_error.called)
+ tr._force_close.assert_called_with(err)
+
+ def test_loop_reading_exception(self):
+ err = self.loop._proactor.recv.side_effect = (OSError())
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._loop_reading()
+ tr._fatal_error.assert_called_with(err)
+
+ def test_write(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._loop_writing = unittest.mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [b'data'])
+ self.assertTrue(tr._loop_writing.called)
+
+ def test_write_no_data(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr.write(b'')
+ self.assertFalse(tr._buffer)
+
+ def test_write_more(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = unittest.mock.Mock()
+ tr._loop_writing = unittest.mock.Mock()
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [b'data'])
+ self.assertFalse(tr._loop_writing.called)
+
+ def test_loop_writing(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'da', b'ta']
+ tr._loop_writing()
+ self.loop._proactor.send.assert_called_with(self.sock, b'data')
+ self.loop._proactor.send.return_value.add_done_callback.\
+ assert_called_with(tr._loop_writing)
+
+ @unittest.mock.patch('asyncio.proactor_events.asyncio_log')
+ def test_loop_writing_err(self, m_log):
+ err = self.loop._proactor.send.side_effect = OSError()
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._fatal_error = unittest.mock.Mock()
+ tr._buffer = [b'da', b'ta']
+ tr._loop_writing()
+ tr._fatal_error.assert_called_with(err)
+ tr._conn_lost = 1
+
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ self.assertEqual(tr._buffer, [])
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_loop_writing_stop(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(b'data')
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = fut
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+
+ def test_loop_writing_closing(self):
+ fut = asyncio.Future(loop=self.loop)
+ fut.set_result(1)
+
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = fut
+ tr.close()
+ tr._loop_writing(fut)
+ self.assertIsNone(tr._write_fut)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_abort(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._force_close = unittest.mock.Mock()
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertTrue(tr._closing)
+ self.assertEqual(tr._conn_lost, 1)
+
+ self.protocol.connection_lost.reset_mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_write_fut(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._write_fut = unittest.mock.Mock()
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_close_buffer(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ tr.close()
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ @unittest.mock.patch('asyncio.proactor_events.asyncio_log')
+ def test_fatal_error(self, m_logging):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._force_close = unittest.mock.Mock()
+ tr._fatal_error(None)
+ self.assertTrue(tr._force_close.called)
+ self.assertTrue(m_logging.exception.called)
+
+ def test_force_close(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ read_fut = tr._read_fut = unittest.mock.Mock()
+ write_fut = tr._write_fut = unittest.mock.Mock()
+ tr._force_close(None)
+
+ read_fut.cancel.assert_called_with()
+ write_fut.cancel.assert_called_with()
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual([], tr._buffer)
+ self.assertEqual(tr._conn_lost, 1)
+
+ def test_force_close_idempotent(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._closing = True
+ tr._force_close(None)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_fatal_error_2(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._buffer = [b'data']
+ tr._force_close(None)
+
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.assertEqual([], tr._buffer)
+
+ def test_call_connection_lost(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ tr._call_connection_lost(None)
+ self.assertTrue(self.protocol.connection_lost.called)
+ self.assertTrue(self.sock.close.called)
+
+ def test_write_eof(self):
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol)
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr._eof_written)
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+ def test_write_eof_write_pipe(self):
+ tr = _ProactorWritePipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_buffer_write_pipe(self):
+ tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol)
+ f = asyncio.Future(loop=self.loop)
+ tr._loop._proactor.send.return_value = f
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.sock.shutdown.called)
+ tr._loop._proactor.send.assert_called_with(self.sock, b'data')
+ f.set_result(4)
+ self.loop._run_once()
+ self.loop._run_once()
+ self.assertTrue(self.sock.close.called)
+ tr.close()
+
+ def test_write_eof_duplex_pipe(self):
+ tr = _ProactorDuplexPipeTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertFalse(tr.can_write_eof())
+ with self.assertRaises(NotImplementedError):
+ tr.write_eof()
+ tr.close()
+
+ def test_pause_resume(self):
+ tr = _ProactorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ futures = []
+ for msg in [b'data1', b'data2', b'data3', b'data4', b'']:
+ f = asyncio.Future(loop=self.loop)
+ f.set_result(msg)
+ futures.append(f)
+ self.loop._proactor.recv.side_effect = futures
+ self.loop._run_once()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data1')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.pause()
+ self.assertTrue(tr._paused)
+ for i in range(10):
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data2')
+ tr.resume()
+ self.assertFalse(tr._paused)
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data3')
+ self.loop._run_once()
+ self.protocol.data_received.assert_called_with(b'data4')
+ tr.close()
+
+
+class BaseProactorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.proactor = unittest.mock.Mock()
+
+ self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock()
+
+ class EventLoop(BaseProactorEventLoop):
+ def _socketpair(s):
+ return (self.ssock, self.csock)
+
+ self.loop = EventLoop(self.proactor)
+
+ @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon')
+ @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair')
+ def test_ctor(self, socketpair, call_soon):
+ ssock, csock = socketpair.return_value = (
+ unittest.mock.Mock(), unittest.mock.Mock())
+ loop = BaseProactorEventLoop(self.proactor)
+ self.assertIs(loop._ssock, ssock)
+ self.assertIs(loop._csock, csock)
+ self.assertEqual(loop._internal_fds, 1)
+ call_soon.assert_called_with(loop._loop_self_reading)
+
+ def test_close_self_pipe(self):
+ self.loop._close_self_pipe()
+ self.assertEqual(self.loop._internal_fds, 0)
+ self.assertTrue(self.ssock.close.called)
+ self.assertTrue(self.csock.close.called)
+ self.assertIsNone(self.loop._ssock)
+ self.assertIsNone(self.loop._csock)
+
+ def test_close(self):
+ self.loop._close_self_pipe = unittest.mock.Mock()
+ self.loop.close()
+ self.assertTrue(self.loop._close_self_pipe.called)
+ self.assertTrue(self.proactor.close.called)
+ self.assertIsNone(self.loop._proactor)
+
+ self.loop._close_self_pipe.reset_mock()
+ self.loop.close()
+ self.assertFalse(self.loop._close_self_pipe.called)
+
+ def test_sock_recv(self):
+ self.loop.sock_recv(self.sock, 1024)
+ self.proactor.recv.assert_called_with(self.sock, 1024)
+
+ def test_sock_sendall(self):
+ self.loop.sock_sendall(self.sock, b'data')
+ self.proactor.send.assert_called_with(self.sock, b'data')
+
+ def test_sock_connect(self):
+ self.loop.sock_connect(self.sock, 123)
+ self.proactor.connect.assert_called_with(self.sock, 123)
+
+ def test_sock_accept(self):
+ self.loop.sock_accept(self.sock)
+ self.proactor.accept.assert_called_with(self.sock)
+
+ def test_socketpair(self):
+ self.assertRaises(
+ NotImplementedError, BaseProactorEventLoop, self.proactor)
+
+ def test_make_socket_transport(self):
+ tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock())
+ self.assertIsInstance(tr, _ProactorSocketTransport)
+
+ def test_loop_self_reading(self):
+ self.loop._loop_self_reading()
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_fut(self):
+ fut = unittest.mock.Mock()
+ self.loop._loop_self_reading(fut)
+ self.assertTrue(fut.result.called)
+ self.proactor.recv.assert_called_with(self.ssock, 4096)
+ self.proactor.recv.return_value.add_done_callback.assert_called_with(
+ self.loop._loop_self_reading)
+
+ def test_loop_self_reading_exception(self):
+ self.loop.close = unittest.mock.Mock()
+ self.proactor.recv.side_effect = OSError()
+ self.assertRaises(OSError, self.loop._loop_self_reading)
+ self.assertTrue(self.loop.close.called)
+
+ def test_write_to_self(self):
+ self.loop._write_to_self()
+ self.csock.send.assert_called_with(b'x')
+
+ def test_process_events(self):
+ self.loop._process_events([])
+
+ @unittest.mock.patch('asyncio.proactor_events.asyncio_log')
+ def test_create_server(self, m_log):
+ pf = unittest.mock.Mock()
+ call_soon = self.loop.call_soon = unittest.mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ self.assertTrue(call_soon.called)
+
+ # callback
+ loop = call_soon.call_args[0][0]
+ loop()
+ self.proactor.accept.assert_called_with(self.sock)
+
+ # conn
+ fut = unittest.mock.Mock()
+ fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock())
+
+ make_tr = self.loop._make_socket_transport = unittest.mock.Mock()
+ loop(fut)
+ self.assertTrue(fut.result.called)
+ self.assertTrue(make_tr.called)
+
+ # exception
+ fut.result.side_effect = OSError()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+ self.assertTrue(m_log.exception.called)
+
+ def test_create_server_cancel(self):
+ pf = unittest.mock.Mock()
+ call_soon = self.loop.call_soon = unittest.mock.Mock()
+
+ self.loop._start_serving(pf, self.sock)
+ loop = call_soon.call_args[0][0]
+
+ # cancelled
+ fut = asyncio.Future(loop=self.loop)
+ fut.cancel()
+ loop(fut)
+ self.assertTrue(self.sock.close.called)
+
+ def test_stop_serving(self):
+ sock = unittest.mock.Mock()
+ self.loop._stop_serving(sock)
+ self.assertTrue(sock.close.called)
+ self.proactor._stop_serving.assert_called_with(sock)
diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py
new file mode 100644
index 0000000..8af4ee7
--- /dev/null
+++ b/Lib/test/test_asyncio/test_queues.py
@@ -0,0 +1,470 @@
+"""Tests for queues.py"""
+
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import locks
+from asyncio import queues
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class _QueueTestBase(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+
+class QueueBasicTests(_QueueTestBase):
+
+ def _test_repr_or_str(self, fn, expect_id):
+ """Test Queue's repr or str.
+
+ fn is repr or str. expect_id is True if we expect the Queue's id to
+ appear in fn(Queue()).
+ """
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+ self.assertTrue(fn(q).startswith('<Queue'), fn(q))
+ id_is_present = hex(id(q)) in fn(q)
+ self.assertEqual(expect_id, id_is_present)
+
+ @tasks.coroutine
+ def add_getter():
+ q = queues.Queue(loop=loop)
+ # Start a task that waits to get.
+ tasks.Task(q.get(), loop=loop)
+ # Let it start waiting.
+ yield from tasks.sleep(0.1, loop=loop)
+ self.assertTrue('_getters[1]' in fn(q))
+ # resume q.get coroutine to finish generator
+ q.put_nowait(0)
+
+ loop.run_until_complete(add_getter())
+
+ @tasks.coroutine
+ def add_putter():
+ q = queues.Queue(maxsize=1, loop=loop)
+ q.put_nowait(1)
+ # Start a task that waits to put.
+ tasks.Task(q.put(2), loop=loop)
+ # Let it start waiting.
+ yield from tasks.sleep(0.1, loop=loop)
+ self.assertTrue('_putters[1]' in fn(q))
+ # resume q.put coroutine to finish generator
+ q.get_nowait()
+
+ loop.run_until_complete(add_putter())
+
+ q = queues.Queue(loop=loop)
+ q.put_nowait(1)
+ self.assertTrue('_queue=[1]' in fn(q))
+
+ def test_ctor_loop(self):
+ loop = unittest.mock.Mock()
+ q = queues.Queue(loop=loop)
+ self.assertIs(q._loop, loop)
+
+ q = queues.Queue(loop=self.loop)
+ self.assertIs(q._loop, self.loop)
+
+ def test_ctor_noloop(self):
+ try:
+ events.set_event_loop(self.loop)
+ q = queues.Queue()
+ self.assertIs(q._loop, self.loop)
+ finally:
+ events.set_event_loop(None)
+
+ def test_repr(self):
+ self._test_repr_or_str(repr, True)
+
+ def test_str(self):
+ self._test_repr_or_str(str, False)
+
+ def test_empty(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertTrue(q.empty())
+ q.put_nowait(1)
+ self.assertFalse(q.empty())
+ self.assertEqual(1, q.get_nowait())
+ self.assertTrue(q.empty())
+
+ def test_full(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertFalse(q.full())
+
+ q = queues.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertTrue(q.full())
+
+ def test_order(self):
+ q = queues.Queue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 3, 2], items)
+
+ def test_maxsize(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.02, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(maxsize=2, loop=loop)
+ self.assertEqual(2, q.maxsize)
+ have_been_put = []
+
+ @tasks.coroutine
+ def putter():
+ for i in range(3):
+ yield from q.put(i)
+ have_been_put.append(i)
+ return True
+
+ @tasks.coroutine
+ def test():
+ t = tasks.Task(putter(), loop=loop)
+ yield from tasks.sleep(0.01, loop=loop)
+
+ # The putter is blocked after putting two items.
+ self.assertEqual([0, 1], have_been_put)
+ self.assertEqual(0, q.get_nowait())
+
+ # Let the putter resume and put last item.
+ yield from tasks.sleep(0.01, loop=loop)
+ self.assertEqual([0, 1, 2], have_been_put)
+ self.assertEqual(1, q.get_nowait())
+ self.assertEqual(2, q.get_nowait())
+
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ loop.run_until_complete(test())
+ self.assertAlmostEqual(0.02, loop.time())
+
+
+class QueueGetTests(_QueueTestBase):
+
+ def test_blocking_get(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from q.get())
+
+ res = self.loop.run_until_complete(queue_get())
+ self.assertEqual(1, res)
+
+ def test_get_with_putters(self):
+ q = queues.Queue(1, loop=self.loop)
+ q.put_nowait(1)
+
+ waiter = futures.Future(loop=self.loop)
+ q._putters.append((2, waiter))
+
+ res = self.loop.run_until_complete(q.get())
+ self.assertEqual(1, res)
+ self.assertTrue(waiter.done())
+ self.assertIsNone(waiter.result())
+
+ def test_blocking_get_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+ started = locks.Event(loop=loop)
+ finished = False
+
+ @tasks.coroutine
+ def queue_get():
+ nonlocal finished
+ started.set()
+ res = yield from q.get()
+ finished = True
+ return res
+
+ @tasks.coroutine
+ def queue_put():
+ loop.call_later(0.01, q.put_nowait, 1)
+ queue_get_task = tasks.Task(queue_get(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ res = yield from queue_get_task
+ self.assertTrue(finished)
+ return res
+
+ res = loop.run_until_complete(queue_put())
+ self.assertEqual(1, res)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_get(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_get_exception(self):
+ q = queues.Queue(loop=self.loop)
+ self.assertRaises(queues.Empty, q.get_nowait)
+
+ def test_get_cancelled(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0.01
+ self.assertAlmostEqual(0.061, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(loop=loop)
+
+ @tasks.coroutine
+ def queue_get():
+ return (yield from tasks.wait_for(q.get(), 0.051, loop=loop))
+
+ @tasks.coroutine
+ def test():
+ get_task = tasks.Task(queue_get(), loop=loop)
+ yield from tasks.sleep(0.01, loop=loop) # let the task start
+ q.put_nowait(1)
+ return (yield from get_task)
+
+ self.assertEqual(1, loop.run_until_complete(test()))
+ self.assertAlmostEqual(0.06, loop.time())
+
+ def test_get_cancelled_race(self):
+ q = queues.Queue(loop=self.loop)
+
+ t1 = tasks.Task(q.get(), loop=self.loop)
+ t2 = tasks.Task(q.get(), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ t1.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(t1.done())
+ q.put_nowait('a')
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(t2.result(), 'a')
+
+ def test_get_with_waiting_putters(self):
+ q = queues.Queue(loop=self.loop, maxsize=1)
+ tasks.Task(q.put('a'), loop=self.loop)
+ tasks.Task(q.put('b'), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'a')
+ self.assertEqual(self.loop.run_until_complete(q.get()), 'b')
+
+
+class QueuePutTests(_QueueTestBase):
+
+ def test_blocking_put(self):
+ q = queues.Queue(loop=self.loop)
+
+ @tasks.coroutine
+ def queue_put():
+ # No maxsize, won't block.
+ yield from q.put(1)
+
+ self.loop.run_until_complete(queue_put())
+
+ def test_blocking_put_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ q = queues.Queue(maxsize=1, loop=loop)
+ started = locks.Event(loop=loop)
+ finished = False
+
+ @tasks.coroutine
+ def queue_put():
+ nonlocal finished
+ started.set()
+ yield from q.put(1)
+ yield from q.put(2)
+ finished = True
+
+ @tasks.coroutine
+ def queue_get():
+ loop.call_later(0.01, q.get_nowait)
+ queue_put_task = tasks.Task(queue_put(), loop=loop)
+ yield from started.wait()
+ self.assertFalse(finished)
+ yield from queue_put_task
+ self.assertTrue(finished)
+
+ loop.run_until_complete(queue_get())
+ self.assertAlmostEqual(0.01, loop.time())
+
+ def test_nonblocking_put(self):
+ q = queues.Queue(loop=self.loop)
+ q.put_nowait(1)
+ self.assertEqual(1, q.get_nowait())
+
+ def test_nonblocking_put_exception(self):
+ q = queues.Queue(maxsize=1, loop=self.loop)
+ q.put_nowait(1)
+ self.assertRaises(queues.Full, q.put_nowait, 2)
+
+ def test_put_cancelled(self):
+ q = queues.Queue(loop=self.loop)
+
+ @tasks.coroutine
+ def queue_put():
+ yield from q.put(1)
+ return True
+
+ @tasks.coroutine
+ def test():
+ return (yield from q.get())
+
+ t = tasks.Task(queue_put(), loop=self.loop)
+ self.assertEqual(1, self.loop.run_until_complete(test()))
+ self.assertTrue(t.done())
+ self.assertTrue(t.result())
+
+ def test_put_cancelled_race(self):
+ q = queues.Queue(loop=self.loop, maxsize=1)
+
+ tasks.Task(q.put('a'), loop=self.loop)
+ tasks.Task(q.put('c'), loop=self.loop)
+ t = tasks.Task(q.put('b'), loop=self.loop)
+
+ test_utils.run_briefly(self.loop)
+ t.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(t.done())
+ self.assertEqual(q.get_nowait(), 'a')
+ self.assertEqual(q.get_nowait(), 'c')
+
+ def test_put_with_waiting_getters(self):
+ q = queues.Queue(loop=self.loop)
+ t = tasks.Task(q.get(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.loop.run_until_complete(q.put('a'))
+ self.assertEqual(self.loop.run_until_complete(t), 'a')
+
+
+class LifoQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.LifoQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([2, 3, 1], items)
+
+
+class PriorityQueueTests(_QueueTestBase):
+
+ def test_order(self):
+ q = queues.PriorityQueue(loop=self.loop)
+ for i in [1, 3, 2]:
+ q.put_nowait(i)
+
+ items = [q.get_nowait() for _ in range(3)]
+ self.assertEqual([1, 2, 3], items)
+
+
+class JoinableQueueTests(_QueueTestBase):
+
+ def test_task_done_underflow(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ self.assertRaises(ValueError, q.task_done)
+
+ def test_task_done(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ for i in range(100):
+ q.put_nowait(i)
+
+ accumulator = 0
+
+ # Two workers get items from the queue and call task_done after each.
+ # Join the queue and assert all items have been processed.
+ running = True
+
+ @tasks.coroutine
+ def worker():
+ nonlocal accumulator
+
+ while running:
+ item = yield from q.get()
+ accumulator += item
+ q.task_done()
+
+ @tasks.coroutine
+ def test():
+ for _ in range(2):
+ tasks.Task(worker(), loop=self.loop)
+
+ yield from q.join()
+
+ self.loop.run_until_complete(test())
+ self.assertEqual(sum(range(100)), accumulator)
+
+ # close running generators
+ running = False
+ for i in range(2):
+ q.put_nowait(0)
+
+ def test_join_empty_queue(self):
+ q = queues.JoinableQueue(loop=self.loop)
+
+ # Test that a queue join()s successfully, and before anything else
+ # (done twice for insurance).
+
+ @tasks.coroutine
+ def join():
+ yield from q.join()
+ yield from q.join()
+
+ self.loop.run_until_complete(join())
+
+ def test_format(self):
+ q = queues.JoinableQueue(loop=self.loop)
+ self.assertEqual(q._format(), 'maxsize=0')
+
+ q._unfinished_tasks = 2
+ self.assertEqual(q._format(), 'maxsize=0 tasks=2')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py
new file mode 100644
index 0000000..0225e13
--- /dev/null
+++ b/Lib/test/test_asyncio/test_selector_events.py
@@ -0,0 +1,1485 @@
+"""Tests for selector_events.py"""
+
+import collections
+import errno
+import gc
+import pprint
+import socket
+import sys
+import unittest
+import unittest.mock
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from asyncio import futures
+from asyncio import selectors
+from asyncio import test_utils
+from asyncio.protocols import DatagramProtocol, Protocol
+from asyncio.selector_events import BaseSelectorEventLoop
+from asyncio.selector_events import _SelectorTransport
+from asyncio.selector_events import _SelectorSslTransport
+from asyncio.selector_events import _SelectorSocketTransport
+from asyncio.selector_events import _SelectorDatagramTransport
+
+
+class TestBaseSelectorEventLoop(BaseSelectorEventLoop):
+
+ def _make_self_pipe(self):
+ self._ssock = unittest.mock.Mock()
+ self._csock = unittest.mock.Mock()
+ self._internal_fds += 1
+
+
+class BaseSelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = TestBaseSelectorEventLoop(unittest.mock.Mock())
+
+ def test_make_socket_transport(self):
+ m = unittest.mock.Mock()
+ self.loop.add_reader = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.loop._make_socket_transport(m, m), _SelectorSocketTransport)
+
+ def test_make_ssl_transport(self):
+ m = unittest.mock.Mock()
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.assertIsInstance(
+ self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport)
+
+ def test_close(self):
+ ssock = self.loop._ssock
+ ssock.fileno.return_value = 7
+ csock = self.loop._csock
+ csock.fileno.return_value = 1
+ remove_reader = self.loop.remove_reader = unittest.mock.Mock()
+
+ self.loop._selector.close()
+ self.loop._selector = selector = unittest.mock.Mock()
+ self.loop.close()
+ self.assertIsNone(self.loop._selector)
+ self.assertIsNone(self.loop._csock)
+ self.assertIsNone(self.loop._ssock)
+ selector.close.assert_called_with()
+ ssock.close.assert_called_with()
+ csock.close.assert_called_with()
+ remove_reader.assert_called_with(7)
+
+ self.loop.close()
+ self.loop.close()
+
+ def test_close_no_selector(self):
+ ssock = self.loop._ssock
+ csock = self.loop._csock
+ remove_reader = self.loop.remove_reader = unittest.mock.Mock()
+
+ self.loop._selector.close()
+ self.loop._selector = None
+ self.loop.close()
+ self.assertIsNone(self.loop._selector)
+ self.assertFalse(ssock.close.called)
+ self.assertFalse(csock.close.called)
+ self.assertFalse(remove_reader.called)
+
+ def test_socketpair(self):
+ self.assertRaises(NotImplementedError, self.loop._socketpair)
+
+ def test_read_from_self_tryagain(self):
+ self.loop._ssock.recv.side_effect = BlockingIOError
+ self.assertIsNone(self.loop._read_from_self())
+
+ def test_read_from_self_exception(self):
+ self.loop._ssock.recv.side_effect = OSError
+ self.assertRaises(OSError, self.loop._read_from_self)
+
+ def test_write_to_self_tryagain(self):
+ self.loop._csock.send.side_effect = BlockingIOError
+ self.assertIsNone(self.loop._write_to_self())
+
+ def test_write_to_self_exception(self):
+ self.loop._csock.send.side_effect = OSError()
+ self.assertRaises(OSError, self.loop._write_to_self)
+
+ def test_sock_recv(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_recv = unittest.mock.Mock()
+
+ f = self.loop.sock_recv(sock, 1024)
+ self.assertIsInstance(f, futures.Future)
+ self.loop._sock_recv.assert_called_with(f, False, sock, 1024)
+
+ def test__sock_recv_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertFalse(sock.recv.called)
+
+ def test__sock_recv_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._sock_recv(f, True, sock, 1024)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_recv_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.recv.side_effect = BlockingIOError
+
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_recv_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.recv.side_effect = OSError()
+
+ self.loop._sock_recv(f, False, sock, 1024)
+ self.assertIs(err, f.exception())
+
+ def test_sock_sendall(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'data')
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, b'data'),
+ self.loop._sock_sendall.call_args[0])
+
+ def test_sock_sendall_nodata(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_sendall = unittest.mock.Mock()
+
+ f = self.loop.sock_sendall(sock, b'')
+ self.assertIsInstance(f, futures.Future)
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertFalse(self.loop._sock_sendall.called)
+
+ def test__sock_sendall_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(sock.send.called)
+
+ def test__sock_sendall_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, True, sock, b'data')
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_sendall_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = BlockingIOError
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_interrupted(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.send.side_effect = InterruptedError
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.send.side_effect = OSError()
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertIs(f.exception(), err)
+
+ def test__sock_sendall(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 4
+
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+
+ def test__sock_sendall_partial(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 2
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'ta'),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_sendall_none(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ sock.fileno.return_value = 10
+ sock.send.return_value = 0
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop._sock_sendall(f, False, sock, b'data')
+ self.assertFalse(f.done())
+ self.assertEqual(
+ (10, self.loop._sock_sendall, f, True, sock, b'data'),
+ self.loop.add_writer.call_args[0])
+
+ def test_sock_connect(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_connect = unittest.mock.Mock()
+
+ f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock, ('127.0.0.1', 8080)),
+ self.loop._sock_connect.call_args[0])
+
+ def test__sock_connect(self):
+ f = futures.Future(loop=self.loop)
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertTrue(f.done())
+ self.assertIsNone(f.result())
+ self.assertTrue(sock.connect.called)
+
+ def test__sock_connect_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080))
+ self.assertFalse(sock.connect.called)
+
+ def test__sock_connect_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual((10,), self.loop.remove_writer.call_args[0])
+
+ def test__sock_connect_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.EAGAIN
+
+ self.loop.add_writer = unittest.mock.Mock()
+ self.loop.remove_writer = unittest.mock.Mock()
+
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertEqual(
+ (10, self.loop._sock_connect, f,
+ True, sock, ('127.0.0.1', 8080)),
+ self.loop.add_writer.call_args[0])
+
+ def test__sock_connect_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.getsockopt.return_value = errno.ENOTCONN
+
+ self.loop.remove_writer = unittest.mock.Mock()
+ self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080))
+ self.assertIsInstance(f.exception(), OSError)
+
+ def test_sock_accept(self):
+ sock = unittest.mock.Mock()
+ self.loop._sock_accept = unittest.mock.Mock()
+
+ f = self.loop.sock_accept(sock)
+ self.assertIsInstance(f, futures.Future)
+ self.assertEqual(
+ (f, False, sock), self.loop._sock_accept.call_args[0])
+
+ def test__sock_accept(self):
+ f = futures.Future(loop=self.loop)
+
+ conn = unittest.mock.Mock()
+
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.return_value = conn, ('127.0.0.1', 1000)
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertTrue(f.done())
+ self.assertEqual((conn, ('127.0.0.1', 1000)), f.result())
+ self.assertEqual((False,), conn.setblocking.call_args[0])
+
+ def test__sock_accept_canceled_fut(self):
+ sock = unittest.mock.Mock()
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertFalse(sock.accept.called)
+
+ def test__sock_accept_unregister(self):
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+
+ f = futures.Future(loop=self.loop)
+ f.cancel()
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._sock_accept(f, True, sock)
+ self.assertEqual((10,), self.loop.remove_reader.call_args[0])
+
+ def test__sock_accept_tryagain(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ sock.accept.side_effect = BlockingIOError
+
+ self.loop.add_reader = unittest.mock.Mock()
+ self.loop._sock_accept(f, False, sock)
+ self.assertEqual(
+ (10, self.loop._sock_accept, f, True, sock),
+ self.loop.add_reader.call_args[0])
+
+ def test__sock_accept_exception(self):
+ f = futures.Future(loop=self.loop)
+ sock = unittest.mock.Mock()
+ sock.fileno.return_value = 10
+ err = sock.accept.side_effect = OSError()
+
+ self.loop._sock_accept(f, False, sock)
+ self.assertIs(err, f.exception())
+
+ def test_add_reader(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertIsNone(w)
+
+ def test_add_reader_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (reader, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertTrue(reader.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_add_reader_existing_writer(self):
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, writer))
+ cb = lambda: True
+ self.loop.add_reader(1, cb)
+
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(cb, r._callback)
+ self.assertEqual(writer, w)
+
+ def test_remove_reader(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (None, None))
+ self.assertFalse(self.loop.remove_reader(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_reader_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_reader(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_WRITE, (None, writer)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_reader_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_reader(1))
+
+ def test_add_writer(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(self.loop._selector.register.called)
+ fd, mask, (r, w) = self.loop._selector.register.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE, mask)
+ self.assertIsNone(r)
+ self.assertEqual(cb, w._callback)
+
+ def test_add_writer_existing(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, writer))
+ cb = lambda: True
+ self.loop.add_writer(1, cb)
+
+ self.assertTrue(writer.cancel.called)
+ self.assertFalse(self.loop._selector.register.called)
+ self.assertTrue(self.loop._selector.modify.called)
+ fd, mask, (r, w) = self.loop._selector.modify.call_args[0]
+ self.assertEqual(1, fd)
+ self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask)
+ self.assertEqual(reader, r)
+ self.assertEqual(cb, w._callback)
+
+ def test_remove_writer(self):
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_WRITE, (None, None))
+ self.assertFalse(self.loop.remove_writer(1))
+
+ self.assertTrue(self.loop._selector.unregister.called)
+
+ def test_remove_writer_read_write(self):
+ reader = unittest.mock.Mock()
+ writer = unittest.mock.Mock()
+ self.loop._selector.get_key.return_value = selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE,
+ (reader, writer))
+ self.assertTrue(
+ self.loop.remove_writer(1))
+
+ self.assertFalse(self.loop._selector.unregister.called)
+ self.assertEqual(
+ (1, selectors.EVENT_READ, (reader, None)),
+ self.loop._selector.modify.call_args[0])
+
+ def test_remove_writer_unknown(self):
+ self.loop._selector.get_key.side_effect = KeyError
+ self.assertFalse(
+ self.loop.remove_writer(1))
+
+ def test_process_events_read(self):
+ reader = unittest.mock.Mock()
+ reader._cancelled = False
+
+ self.loop._add_callback = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.assertTrue(self.loop._add_callback.called)
+ self.loop._add_callback.assert_called_with(reader)
+
+ def test_process_events_read_cancelled(self):
+ reader = unittest.mock.Mock()
+ reader.cancelled = True
+
+ self.loop.remove_reader = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(
+ 1, 1, selectors.EVENT_READ, (reader, None)),
+ selectors.EVENT_READ)])
+ self.loop.remove_reader.assert_called_with(1)
+
+ def test_process_events_write(self):
+ writer = unittest.mock.Mock()
+ writer._cancelled = False
+
+ self.loop._add_callback = unittest.mock.Mock()
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop._add_callback.assert_called_with(writer)
+
+ def test_process_events_write_cancelled(self):
+ writer = unittest.mock.Mock()
+ writer.cancelled = True
+ self.loop.remove_writer = unittest.mock.Mock()
+
+ self.loop._process_events(
+ [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE,
+ (None, writer)),
+ selectors.EVENT_WRITE)])
+ self.loop.remove_writer.assert_called_with(1)
+
+
+class SelectorTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def test_ctor(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ self.assertIs(tr._loop, self.loop)
+ self.assertIs(tr._sock, self.sock)
+ self.assertIs(tr._sock_fd, 7)
+
+ def test_abort(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._force_close = unittest.mock.Mock()
+
+ tr.abort()
+ tr._force_close.assert_called_with(None)
+
+ def test_close(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr.close()
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+ self.protocol.connection_lost(None)
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ def test_close_write_buffer(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._buffer.append(b'data')
+ tr.close()
+
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.connection_lost.called)
+
+ def test_force_close(self):
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._buffer.append(b'1')
+ self.loop.add_reader(7, unittest.mock.sentinel)
+ self.loop.add_writer(7, unittest.mock.sentinel)
+ tr._force_close(None)
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(tr._buffer, collections.deque())
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+
+ # second close should not remove reader
+ tr._force_close(None)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual(1, self.loop.remove_reader_count[7])
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ def test_fatal_error(self, m_exc):
+ exc = OSError()
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._force_close = unittest.mock.Mock()
+ tr._fatal_error(exc)
+
+ m_exc.assert_called_with('Fatal error for %s', tr)
+ tr._force_close.assert_called_with(exc)
+
+ def test_connection_lost(self):
+ exc = OSError()
+ tr = _SelectorTransport(self.loop, self.sock, self.protocol, None)
+ tr._call_connection_lost(exc)
+
+ self.protocol.connection_lost.assert_called_with(exc)
+ self.sock.close.assert_called_with()
+ self.assertIsNone(tr._sock)
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+
+class SelectorSocketTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock_fd = self.sock.fileno.return_value = 7
+
+ def test_ctor(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.loop.assert_reader(7, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+
+ _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+
+ def test_pause_resume(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+ tr.pause()
+ self.assertTrue(tr._paused)
+ self.assertFalse(7 in self.loop.readers)
+ tr.resume()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(7, tr._read_ready)
+
+ def test_read_ready(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recv.return_value = b'data'
+ transport._read_ready()
+
+ self.protocol.data_received.assert_called_with(b'data')
+
+ def test_read_ready_eof(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close = unittest.mock.Mock()
+
+ self.sock.recv.return_value = b''
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ transport.close.assert_called_with()
+
+ def test_read_ready_eof_keep_open(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close = unittest.mock.Mock()
+
+ self.sock.recv.return_value = b''
+ self.protocol.eof_received.return_value = True
+ transport._read_ready()
+
+ self.protocol.eof_received.assert_called_with()
+ self.assertFalse(transport.close.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_tryagain(self, m_exc):
+ self.sock.recv.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_tryagain_interrupted(self, m_exc):
+ self.sock.recv.side_effect = InterruptedError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_conn_reset(self, m_exc):
+ err = self.sock.recv.side_effect = ConnectionResetError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._force_close = unittest.mock.Mock()
+ transport._read_ready()
+ transport._force_close.assert_called_with(err)
+
+ @unittest.mock.patch('logging.exception')
+ def test_read_ready_err(self, m_exc):
+ err = self.sock.recv.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_write(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+ self.sock.send.assert_called_with(data)
+
+ def test_write_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(b'data')
+ transport.write(b'')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_buffer(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(b'data1')
+ transport.write(b'data2')
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(collections.deque([b'data1', b'data2']),
+ transport._buffer)
+
+ def test_write_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'ta']), transport._buffer)
+
+ def test_write_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+ self.sock.fileno.return_value = 7
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.write(data)
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_write_exception(self, m_log):
+ err = self.sock.send.side_effect = OSError()
+
+ data = b'data'
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.write(data)
+ transport._fatal_error.assert_called_with(err)
+ transport._conn_lost = 1
+
+ self.sock.reset_mock()
+ transport.write(data)
+ self.assertFalse(self.sock.send.called)
+ self.assertEqual(transport._conn_lost, 2)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ transport.write(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_write_str(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_write_ready(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.assertTrue(self.sock.send.called)
+ self.assertEqual(self.sock.send.call_args[0], (data,))
+ self.assertFalse(self.loop.writers)
+
+ def test_write_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.sock.send.assert_called_with(data)
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_ready_no_data(self):
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport._write_ready)
+
+ def test_write_ready_partial(self):
+ data = b'data'
+ self.sock.send.return_value = 2
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'ta']), transport._buffer)
+
+ def test_write_ready_partial_none(self):
+ data = b'data'
+ self.sock.send.return_value = 0
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append(data)
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_ready_tryagain(self):
+ self.sock.send.side_effect = BlockingIOError
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ self.loop.add_writer(7, transport._write_ready)
+ transport._write_ready()
+
+ self.loop.assert_writer(7, transport._write_ready)
+ self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+
+ def test_write_ready_exception(self):
+ err = self.sock.send.side_effect = OSError()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append(b'data')
+ transport._write_ready()
+ transport._fatal_error.assert_called_with(err)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_write_ready_exception_and_close(self, m_log):
+ self.sock.send.side_effect = OSError()
+ remove_writer = self.loop.remove_writer = unittest.mock.Mock()
+
+ transport = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ transport.close()
+ transport._buffer.append(b'data')
+ transport._write_ready()
+ remove_writer.assert_called_with(self.sock_fd)
+
+ def test_write_eof(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+ tr.write_eof()
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.write_eof()
+ self.assertEqual(self.sock.shutdown.call_count, 1)
+ tr.close()
+
+ def test_write_eof_buffer(self):
+ tr = _SelectorSocketTransport(
+ self.loop, self.sock, self.protocol)
+ self.sock.send.side_effect = BlockingIOError
+ tr.write(b'data')
+ tr.write_eof()
+ self.assertEqual(tr._buffer, collections.deque([b'data']))
+ self.assertTrue(tr._eof)
+ self.assertFalse(self.sock.shutdown.called)
+ self.sock.send.side_effect = lambda _: 4
+ tr._write_ready()
+ self.sock.send.assert_called_with(b'data')
+ self.sock.shutdown.assert_called_with(socket.SHUT_WR)
+ tr.close()
+
+
+@unittest.skipIf(ssl is None, 'No ssl module')
+class SelectorSslTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(Protocol)
+ self.sock = unittest.mock.Mock(socket.socket)
+ self.sock.fileno.return_value = 7
+ self.sslsock = unittest.mock.Mock()
+ self.sslsock.fileno.return_value = 1
+ self.sslcontext = unittest.mock.Mock()
+ self.sslcontext.wrap_socket.return_value = self.sslsock
+
+ def _make_one(self, create_waiter=None):
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ self.sock.reset_mock()
+ self.sslsock.reset_mock()
+ self.sslcontext.reset_mock()
+ self.loop.reset_counters()
+ return transport
+
+ def test_on_handshake(self):
+ waiter = futures.Future(loop=self.loop)
+ tr = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext,
+ waiter=waiter)
+ self.assertTrue(self.sslsock.do_handshake.called)
+ self.loop.assert_reader(1, tr._on_ready)
+ self.loop.assert_writer(1, tr._on_ready)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(waiter.result())
+
+ def test_on_handshake_reader_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._on_handshake()
+ self.loop.assert_reader(1, transport._on_handshake)
+
+ def test_on_handshake_writer_retry(self):
+ self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._on_handshake()
+ self.loop.assert_writer(1, transport._on_handshake)
+
+ def test_on_handshake_exc(self):
+ exc = ValueError()
+ self.sslsock.do_handshake.side_effect = exc
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._waiter = futures.Future(loop=self.loop)
+ transport._on_handshake()
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(transport._waiter.done())
+ self.assertIs(exc, transport._waiter.exception())
+
+ def test_on_handshake_base_exc(self):
+ transport = _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext)
+ transport._waiter = futures.Future(loop=self.loop)
+ exc = BaseException()
+ self.sslsock.do_handshake.side_effect = exc
+ self.assertRaises(BaseException, transport._on_handshake)
+ self.assertTrue(self.sslsock.close.called)
+ self.assertTrue(transport._waiter.done())
+ self.assertIs(exc, transport._waiter.exception())
+
+ def test_pause_resume(self):
+ tr = self._make_one()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._on_ready)
+ tr.pause()
+ self.assertTrue(tr._paused)
+ self.assertFalse(1 in self.loop.readers)
+ tr.resume()
+ self.assertFalse(tr._paused)
+ self.loop.assert_reader(1, tr._on_ready)
+
+ def test_write_no_data(self):
+ transport = self._make_one()
+ transport._buffer.append(b'data')
+ transport.write(b'')
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_write_str(self):
+ transport = self._make_one()
+ self.assertRaises(AssertionError, transport.write, 'str')
+
+ def test_write_closing(self):
+ transport = self._make_one()
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.write(b'data')
+ self.assertEqual(transport._conn_lost, 2)
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_write_exception(self, m_log):
+ transport = self._make_one()
+ transport._conn_lost = 1
+ transport.write(b'data')
+ self.assertEqual(transport._buffer, collections.deque())
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ transport.write(b'data')
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_on_ready_recv(self):
+ self.sslsock.recv.return_value = b'data'
+ transport = self._make_one()
+ transport._on_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertEqual((b'data',), self.protocol.data_received.call_args[0])
+
+ def test_on_ready_recv_eof(self):
+ self.sslsock.recv.return_value = b''
+ transport = self._make_one()
+ transport.close = unittest.mock.Mock()
+ transport._on_ready()
+ transport.close.assert_called_with()
+ self.protocol.eof_received.assert_called_with()
+
+ def test_on_ready_recv_conn_reset(self):
+ err = self.sslsock.recv.side_effect = ConnectionResetError()
+ transport = self._make_one()
+ transport._force_close = unittest.mock.Mock()
+ transport._on_ready()
+ transport._force_close.assert_called_with(err)
+
+ def test_on_ready_recv_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ transport = self._make_one()
+ transport._on_ready()
+ self.assertTrue(self.sslsock.recv.called)
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = ssl.SSLWantWriteError
+ transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = BlockingIOError
+ transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ self.sslsock.recv.side_effect = InterruptedError
+ transport._on_ready()
+ self.assertFalse(self.protocol.data_received.called)
+
+ def test_on_ready_recv_exc(self):
+ err = self.sslsock.recv.side_effect = OSError()
+ transport = self._make_one()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._on_ready()
+ transport._fatal_error.assert_called_with(err)
+
+ def test_on_ready_send(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+ transport._on_ready()
+ self.assertEqual(collections.deque(), transport._buffer)
+ self.assertTrue(self.sslsock.send.called)
+
+ def test_on_ready_send_none(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 0
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'data1data2']), transport._buffer)
+
+ def test_on_ready_send_partial(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'ta1data2']), transport._buffer)
+
+ def test_on_ready_send_closing_partial(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 2
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data1', b'data2'])
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertFalse(self.sslsock.close.called)
+
+ def test_on_ready_send_closing(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = collections.deque([b'data'])
+ transport._on_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_on_ready_send_closing_empty_buffer(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ self.sslsock.send.return_value = 4
+ transport = self._make_one()
+ transport.close()
+ transport._buffer = collections.deque()
+ transport._on_ready()
+ self.assertFalse(self.loop.writers)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_on_ready_send_retry(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+
+ self.sslsock.send.side_effect = ssl.SSLWantReadError
+ transport._on_ready()
+ self.assertTrue(self.sslsock.send.called)
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ self.sslsock.send.side_effect = ssl.SSLWantWriteError
+ transport._on_ready()
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ self.sslsock.send.side_effect = BlockingIOError()
+ transport._on_ready()
+ self.assertEqual(collections.deque([b'data']), transport._buffer)
+
+ def test_on_ready_send_exc(self):
+ self.sslsock.recv.side_effect = ssl.SSLWantReadError
+ err = self.sslsock.send.side_effect = OSError()
+
+ transport = self._make_one()
+ transport._buffer = collections.deque([b'data'])
+ transport._fatal_error = unittest.mock.Mock()
+ transport._on_ready()
+ transport._fatal_error.assert_called_with(err)
+ self.assertEqual(collections.deque(), transport._buffer)
+
+ def test_write_eof(self):
+ tr = self._make_one()
+ self.assertFalse(tr.can_write_eof())
+ self.assertRaises(NotImplementedError, tr.write_eof)
+
+ def test_close(self):
+ tr = self._make_one()
+ tr.close()
+
+ self.assertTrue(tr._closing)
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+ self.assertEqual(tr._conn_lost, 1)
+
+ tr.close()
+ self.assertEqual(tr._conn_lost, 1)
+ self.assertEqual(1, self.loop.remove_reader_count[1])
+
+ @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support')
+ def test_server_hostname(self):
+ _SelectorSslTransport(
+ self.loop, self.sock, self.protocol, self.sslcontext,
+ server_hostname='localhost')
+ self.sslcontext.wrap_socket.assert_called_with(
+ self.sock, do_handshake_on_connect=False, server_side=False,
+ server_hostname='localhost')
+
+
+class SelectorDatagramTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(DatagramProtocol)
+ self.sock = unittest.mock.Mock(spec_set=socket.socket)
+ self.sock.fileno.return_value = 7
+
+ def test_read_ready(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234))
+ transport._read_ready()
+
+ self.protocol.datagram_received.assert_called_with(
+ b'data', ('0.0.0.0', 1234))
+
+ def test_read_ready_tryagain(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ self.sock.recvfrom.side_effect = BlockingIOError
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_read_ready_err(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+
+ err = self.sock.recvfrom.side_effect = OSError()
+ transport._fatal_error = unittest.mock.Mock()
+ transport._read_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_sendto(self):
+ data = b'data'
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 1234))
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234)))
+
+ def test_sendto_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+ transport.sendto(b'', ())
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ def test_sendto_buffer(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+ transport.sendto(b'data2', ('0.0.0.0', 12345))
+ self.assertFalse(self.sock.sendto.called)
+ self.assertEqual(
+ [(b'data1', ('0.0.0.0', 12345)),
+ (b'data2', ('0.0.0.0', 12345))],
+ list(transport._buffer))
+
+ def test_sendto_tryagain(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport.sendto(data, ('0.0.0.0', 12345))
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+ @unittest.mock.patch('asyncio.selector_events.asyncio_log')
+ def test_sendto_exception(self, m_log):
+ data = b'data'
+ err = self.sock.sendto.side_effect = OSError()
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertTrue(transport._fatal_error.called)
+ transport._fatal_error.assert_called_with(err)
+ transport._conn_lost = 1
+
+ transport._address = ('123',)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ transport.sendto(data)
+ m_log.warning.assert_called_with('socket.send() raised exception.')
+
+ def test_sendto_connection_refused(self):
+ data = b'data'
+
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data, ())
+
+ self.assertEqual(transport._conn_lost, 0)
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_connection_refused_connected(self):
+ data = b'data'
+
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport.sendto(data)
+
+ self.assertTrue(transport._fatal_error.called)
+
+ def test_sendto_str(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ self.assertRaises(AssertionError, transport.sendto, 'str', ())
+
+ def test_sendto_connected_addr(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ self.assertRaises(
+ AssertionError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+ def test_sendto_closing(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, address=(1,))
+ transport.close()
+ self.assertEqual(transport._conn_lost, 1)
+ transport.sendto(b'data', (1,))
+ self.assertEqual(transport._conn_lost, 2)
+
+ def test_sendto_ready(self):
+ data = b'data'
+ self.sock.sendto.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.append((data, ('0.0.0.0', 12345)))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertTrue(self.sock.sendto.called)
+ self.assertEqual(
+ self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345)))
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_closing(self):
+ data = b'data'
+ self.sock.send.return_value = len(data)
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._closing = True
+ transport._buffer.append((data, ()))
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.sock.sendto.assert_called_with(data, ())
+ self.assertFalse(self.loop.writers)
+ self.sock.close.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_sendto_ready_no_data(self):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+ self.assertFalse(self.sock.sendto.called)
+ self.assertFalse(self.loop.writers)
+
+ def test_sendto_ready_tryagain(self):
+ self.sock.sendto.side_effect = BlockingIOError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._buffer.extend([(b'data1', ()), (b'data2', ())])
+ self.loop.add_writer(7, transport._sendto_ready)
+ transport._sendto_ready()
+
+ self.loop.assert_writer(7, transport._sendto_ready)
+ self.assertEqual(
+ [(b'data1', ()), (b'data2', ())],
+ list(transport._buffer))
+
+ def test_sendto_ready_exception(self):
+ err = self.sock.sendto.side_effect = OSError()
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ transport._fatal_error.assert_called_with(err)
+
+ def test_sendto_ready_connection_refused(self):
+ self.sock.sendto.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol)
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertFalse(transport._fatal_error.called)
+
+ def test_sendto_ready_connection_refused_connection(self):
+ self.sock.send.side_effect = ConnectionRefusedError
+
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ transport._fatal_error = unittest.mock.Mock()
+ transport._buffer.append((b'data', ()))
+ transport._sendto_ready()
+
+ self.assertTrue(transport._fatal_error.called)
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ def test_fatal_error_connected(self, m_exc):
+ transport = _SelectorDatagramTransport(
+ self.loop, self.sock, self.protocol, ('0.0.0.0', 1))
+ err = ConnectionRefusedError()
+ transport._fatal_error(err)
+ self.protocol.connection_refused.assert_called_with(err)
+ m_exc.assert_called_with('Fatal error for %s', transport)
diff --git a/Lib/test/test_asyncio/test_selectors.py b/Lib/test/test_asyncio/test_selectors.py
new file mode 100644
index 0000000..2f7dc69
--- /dev/null
+++ b/Lib/test/test_asyncio/test_selectors.py
@@ -0,0 +1,145 @@
+"""Tests for selectors.py."""
+
+import unittest
+import unittest.mock
+
+from asyncio import selectors
+
+
+class FakeSelector(selectors.BaseSelector):
+ """Trivial non-abstract subclass of BaseSelector."""
+
+ def select(self, timeout=None):
+ raise NotImplementedError
+
+
+class BaseSelectorTests(unittest.TestCase):
+
+ def test_fileobj_to_fd(self):
+ self.assertEqual(10, selectors._fileobj_to_fd(10))
+
+ f = unittest.mock.Mock()
+ f.fileno.return_value = 10
+ self.assertEqual(10, selectors._fileobj_to_fd(f))
+
+ f.fileno.side_effect = AttributeError
+ self.assertRaises(ValueError, selectors._fileobj_to_fd, f)
+
+ def test_selector_key_repr(self):
+ key = selectors.SelectorKey(10, 10, selectors.EVENT_READ, None)
+ self.assertEqual(
+ "SelectorKey(fileobj=10, fd=10, events=1, data=None)", repr(key))
+
+ def test_register(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ)
+ self.assertIsInstance(key, selectors.SelectorKey)
+ self.assertEqual(key.fd, 10)
+ self.assertIs(key, s._fd_to_key[10])
+
+ def test_register_unknown_event(self):
+ s = FakeSelector()
+ self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999)
+
+ def test_register_already_registered(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ s.register(fobj, selectors.EVENT_READ)
+ self.assertRaises(KeyError, s.register, fobj, selectors.EVENT_READ)
+
+ def test_unregister(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ s.register(fobj, selectors.EVENT_READ)
+ s.unregister(fobj)
+ self.assertFalse(s._fd_to_key)
+
+ def test_unregister_unknown(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ self.assertRaises(KeyError, s.unregister, fobj)
+
+ def test_modify_unknown(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ self.assertRaises(KeyError, s.modify, fobj, 1)
+
+ def test_modify(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ)
+ key2 = s.modify(fobj, selectors.EVENT_WRITE)
+ self.assertNotEqual(key.events, key2.events)
+ self.assertEqual(
+ selectors.SelectorKey(fobj, 10, selectors.EVENT_WRITE, None),
+ s.get_key(fobj))
+
+ def test_modify_data(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ d1 = object()
+ d2 = object()
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ, d1)
+ key2 = s.modify(fobj, selectors.EVENT_READ, d2)
+ self.assertEqual(key.events, key2.events)
+ self.assertNotEqual(key.data, key2.data)
+ self.assertEqual(
+ selectors.SelectorKey(fobj, 10, selectors.EVENT_READ, d2),
+ s.get_key(fobj))
+
+ def test_modify_same(self):
+ fobj = unittest.mock.Mock()
+ fobj.fileno.return_value = 10
+
+ data = object()
+
+ s = FakeSelector()
+ key = s.register(fobj, selectors.EVENT_READ, data)
+ key2 = s.modify(fobj, selectors.EVENT_READ, data)
+ self.assertIs(key, key2)
+
+ def test_select(self):
+ s = FakeSelector()
+ self.assertRaises(NotImplementedError, s.select)
+
+ def test_close(self):
+ s = FakeSelector()
+ s.register(1, selectors.EVENT_READ)
+
+ s.close()
+ self.assertFalse(s._fd_to_key)
+
+ def test_context_manager(self):
+ s = FakeSelector()
+
+ with s as sel:
+ sel.register(1, selectors.EVENT_READ)
+
+ self.assertFalse(s._fd_to_key)
+
+ def test_key_from_fd(self):
+ s = FakeSelector()
+ key = s.register(1, selectors.EVENT_READ)
+
+ self.assertIs(key, s._key_from_fd(1))
+ self.assertIsNone(s._key_from_fd(10))
+
+ if hasattr(selectors.DefaultSelector, 'fileno'):
+ def test_fileno(self):
+ self.assertIsInstance(selectors.DefaultSelector().fileno(), int)
diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py
new file mode 100644
index 0000000..011a09d
--- /dev/null
+++ b/Lib/test/test_asyncio/test_streams.py
@@ -0,0 +1,361 @@
+"""Tests for streams.py."""
+
+import gc
+import ssl
+import unittest
+import unittest.mock
+
+from asyncio import events
+from asyncio import streams
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class StreamReaderTests(unittest.TestCase):
+
+ DATA = b'line1\nline2\nline3\n'
+
+ def setUp(self):
+ self.loop = events.new_event_loop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ # just in case if we have transport close callbacks
+ test_utils.run_briefly(self.loop)
+
+ self.loop.close()
+ gc.collect()
+
+ @unittest.mock.patch('asyncio.streams.events')
+ def test_ctor_global_loop(self, m_events):
+ stream = streams.StreamReader()
+ self.assertIs(stream.loop, m_events.get_event_loop.return_value)
+
+ def test_open_connection(self):
+ with test_utils.run_test_server() as httpd:
+ f = streams.open_connection(*httpd.address, loop=self.loop)
+ reader, writer = self.loop.run_until_complete(f)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.readline()
+ data = self.loop.run_until_complete(f)
+ self.assertEqual(data, b'HTTP/1.0 200 OK\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
+
+ @unittest.skipIf(ssl is None, 'No ssl module')
+ def test_open_connection_no_loop_ssl(self):
+ with test_utils.run_test_server(use_ssl=True) as httpd:
+ try:
+ events.set_event_loop(self.loop)
+ f = streams.open_connection(*httpd.address,
+ ssl=test_utils.dummy_ssl_context())
+ reader, writer = self.loop.run_until_complete(f)
+ finally:
+ events.set_event_loop(None)
+ writer.write(b'GET / HTTP/1.0\r\n\r\n')
+ f = reader.read()
+ data = self.loop.run_until_complete(f)
+ self.assertTrue(data.endswith(b'\r\n\r\nTest message'))
+
+ writer.close()
+
+ def test_open_connection_error(self):
+ with test_utils.run_test_server() as httpd:
+ f = streams.open_connection(*httpd.address, loop=self.loop)
+ reader, writer = self.loop.run_until_complete(f)
+ writer._protocol.connection_lost(ZeroDivisionError())
+ f = reader.read()
+ with self.assertRaises(ZeroDivisionError):
+ self.loop.run_until_complete(f)
+
+ writer.close()
+ test_utils.run_briefly(self.loop)
+
+ def test_feed_empty_data(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ stream.feed_data(b'')
+ self.assertEqual(0, stream.byte_count)
+
+ def test_feed_data_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ stream.feed_data(self.DATA)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_read_zero(self):
+ # Read zero bytes.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.read(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_read(self):
+ # Read bytes.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(30), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_line_breaks(self):
+ # Read bytes without line breaks.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line1')
+ stream.feed_data(b'line2')
+
+ data = self.loop.run_until_complete(stream.read(5))
+
+ self.assertEqual(b'line1', data)
+ self.assertEqual(5, stream.byte_count)
+
+ def test_read_eof(self):
+ # Read bytes, stop at eof.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(1024), loop=self.loop)
+
+ def cb():
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'', data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_until_eof(self):
+ # Read all bytes until eof.
+ stream = streams.StreamReader(loop=self.loop)
+ read_task = tasks.Task(stream.read(-1), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk1\n')
+ stream.feed_data(b'chunk2')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+
+ self.assertEqual(b'chunk1\nchunk2', data)
+ self.assertFalse(stream.byte_count)
+
+ def test_read_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.read(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.read(2))
+
+ def test_readline(self):
+ # Read one line.
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'chunk1 ')
+ read_task = tasks.Task(stream.readline(), loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk2 ')
+ stream.feed_data(b'chunk3 ')
+ stream.feed_data(b'\n chunk4')
+ self.loop.call_soon(cb)
+
+ line = self.loop.run_until_complete(read_task)
+ self.assertEqual(b'chunk1 chunk2 chunk3 \n', line)
+ self.assertEqual(len(b'\n chunk4')-1, stream.byte_count)
+
+ def test_readline_limit_with_existing_data(self):
+ stream = streams.StreamReader(3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1\nline2\n')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'line2\n'], list(stream.buffer))
+
+ stream = streams.StreamReader(3, loop=self.loop)
+ stream.feed_data(b'li')
+ stream.feed_data(b'ne1')
+ stream.feed_data(b'li')
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'li'], list(stream.buffer))
+ self.assertEqual(2, stream.byte_count)
+
+ def test_readline_limit(self):
+ stream = streams.StreamReader(7, loop=self.loop)
+
+ def cb():
+ stream.feed_data(b'chunk1')
+ stream.feed_data(b'chunk2')
+ stream.feed_data(b'chunk3\n')
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+ self.assertEqual([b'chunk3\n'], list(stream.buffer))
+ self.assertEqual(7, stream.byte_count)
+
+ def test_readline_line_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA[:6])
+ stream.feed_data(self.DATA[6:])
+
+ line = self.loop.run_until_complete(stream.readline())
+
+ self.assertEqual(b'line1\n', line)
+ self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count)
+
+ def test_readline_eof(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'some data')
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'some data', line)
+
+ def test_readline_empty_eof(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_eof()
+
+ line = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'', line)
+
+ def test_readline_read_byte_count(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ self.loop.run_until_complete(stream.readline())
+
+ data = self.loop.run_until_complete(stream.read(7))
+
+ self.assertEqual(b'line2\nl', data)
+ self.assertEqual(
+ len(self.DATA) - len(b'line1\n') - len(b'line2\nl'),
+ stream.byte_count)
+
+ def test_readline_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readline())
+ self.assertEqual(b'line\n', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readline())
+
+ def test_readexactly_zero_or_less(self):
+ # Read exact number of bytes (zero or less).
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(self.DATA)
+
+ data = self.loop.run_until_complete(stream.readexactly(0))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ data = self.loop.run_until_complete(stream.readexactly(-1))
+ self.assertEqual(b'', data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_readexactly(self):
+ # Read exact number of bytes.
+ stream = streams.StreamReader(loop=self.loop)
+
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ stream.feed_data(self.DATA)
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA + self.DATA, data)
+ self.assertEqual(len(self.DATA), stream.byte_count)
+
+ def test_readexactly_eof(self):
+ # Read exact number of bytes (eof).
+ stream = streams.StreamReader(loop=self.loop)
+ n = 2 * len(self.DATA)
+ read_task = tasks.Task(stream.readexactly(n), loop=self.loop)
+
+ def cb():
+ stream.feed_data(self.DATA)
+ stream.feed_eof()
+ self.loop.call_soon(cb)
+
+ data = self.loop.run_until_complete(read_task)
+ self.assertEqual(self.DATA, data)
+ self.assertFalse(stream.byte_count)
+
+ def test_readexactly_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ stream.feed_data(b'line\n')
+
+ data = self.loop.run_until_complete(stream.readexactly(2))
+ self.assertEqual(b'li', data)
+
+ stream.set_exception(ValueError())
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete, stream.readexactly(2))
+
+ def test_exception(self):
+ stream = streams.StreamReader(loop=self.loop)
+ self.assertIsNone(stream.exception())
+
+ exc = ValueError()
+ stream.set_exception(exc)
+ self.assertIs(stream.exception(), exc)
+
+ def test_exception_waiter(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ @tasks.coroutine
+ def set_err():
+ stream.set_exception(ValueError())
+
+ @tasks.coroutine
+ def readline():
+ yield from stream.readline()
+
+ t1 = tasks.Task(stream.readline(), loop=self.loop)
+ t2 = tasks.Task(set_err(), loop=self.loop)
+
+ self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop))
+
+ self.assertRaises(ValueError, t1.result)
+
+ def test_exception_cancel(self):
+ stream = streams.StreamReader(loop=self.loop)
+
+ @tasks.coroutine
+ def read_a_line():
+ yield from stream.readline()
+
+ t = tasks.Task(read_a_line(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ t.cancel()
+ test_utils.run_briefly(self.loop)
+ # The following line fails if set_exception() isn't careful.
+ stream.set_exception(RuntimeError('message'))
+ test_utils.run_briefly(self.loop)
+ self.assertIs(stream.waiter, None)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py
new file mode 100644
index 0000000..57fb053
--- /dev/null
+++ b/Lib/test/test_asyncio/test_tasks.py
@@ -0,0 +1,1518 @@
+"""Tests for tasks.py."""
+
+import gc
+import unittest
+import unittest.mock
+from unittest.mock import Mock
+
+from asyncio import events
+from asyncio import futures
+from asyncio import tasks
+from asyncio import test_utils
+
+
+class Dummy:
+
+ def __repr__(self):
+ return 'Dummy()'
+
+ def __call__(self, *args):
+ pass
+
+
+class TaskTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ gc.collect()
+
+ def test_task_class(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t = tasks.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = events.new_event_loop()
+ t = tasks.Task(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.close()
+
+ def test_async_coroutine(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t = tasks.async(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t._loop, self.loop)
+
+ loop = events.new_event_loop()
+ t = tasks.async(notmuch(), loop=loop)
+ self.assertIs(t._loop, loop)
+ loop.close()
+
+ def test_async_future(self):
+ f_orig = futures.Future(loop=self.loop)
+ f_orig.set_result('ko')
+
+ f = tasks.async(f_orig)
+ self.loop.run_until_complete(f)
+ self.assertTrue(f.done())
+ self.assertEqual(f.result(), 'ko')
+ self.assertIs(f, f_orig)
+
+ loop = events.new_event_loop()
+
+ with self.assertRaises(ValueError):
+ f = tasks.async(f_orig, loop=loop)
+
+ loop.close()
+
+ f = tasks.async(f_orig, loop=self.loop)
+ self.assertIs(f, f_orig)
+
+ def test_async_task(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ok'
+ t_orig = tasks.Task(notmuch(), loop=self.loop)
+ t = tasks.async(t_orig)
+ self.loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'ok')
+ self.assertIs(t, t_orig)
+
+ loop = events.new_event_loop()
+
+ with self.assertRaises(ValueError):
+ t = tasks.async(t_orig, loop=loop)
+
+ loop.close()
+
+ t = tasks.async(t_orig, loop=self.loop)
+ self.assertIs(t, t_orig)
+
+ def test_async_neither(self):
+ with self.assertRaises(TypeError):
+ tasks.async('ok')
+
+ def test_task_repr(self):
+ @tasks.coroutine
+ def notmuch():
+ yield from []
+ return 'abc'
+
+ t = tasks.Task(notmuch(), loop=self.loop)
+ t.add_done_callback(Dummy())
+ self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>')
+ t.cancel() # Does not take immediate effect!
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>')
+ self.assertRaises(futures.CancelledError,
+ self.loop.run_until_complete, t)
+ self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>')
+ t = tasks.Task(notmuch(), loop=self.loop)
+ self.loop.run_until_complete(t)
+ self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>")
+
+ def test_task_repr_custom(self):
+ @tasks.coroutine
+ def coro():
+ pass
+
+ class T(futures.Future):
+ def __repr__(self):
+ return 'T[]'
+
+ class MyTask(tasks.Task, T):
+ def __repr__(self):
+ return super().__repr__()
+
+ gen = coro()
+ t = MyTask(gen, loop=self.loop)
+ self.assertEqual(repr(t), 'T[](<coro>)')
+ gen.close()
+
+ def test_task_basics(self):
+ @tasks.coroutine
+ def outer():
+ a = yield from inner1()
+ b = yield from inner2()
+ return a+b
+
+ @tasks.coroutine
+ def inner1():
+ return 42
+
+ @tasks.coroutine
+ def inner2():
+ return 1000
+
+ t = outer()
+ self.assertEqual(self.loop.run_until_complete(t), 1042)
+
+ def test_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def task():
+ yield from tasks.sleep(10.0, loop=loop)
+ return 12
+
+ t = tasks.Task(task(), loop=loop)
+ loop.call_soon(t.cancel)
+ with self.assertRaises(futures.CancelledError):
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_yield(self):
+ @tasks.coroutine
+ def task():
+ yield
+ yield
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start coro
+ t.cancel()
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertTrue(t.cancelled())
+ self.assertFalse(t.cancel())
+
+ def test_cancel_inner_future(self):
+ f = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop) # start task
+ f.cancel()
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(t)
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_both_task_and_inner_future(self):
+ f = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from f
+ return 12
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+
+ f.cancel()
+ t.cancel()
+
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(t)
+
+ self.assertTrue(t.done())
+ self.assertTrue(f.cancelled())
+ self.assertTrue(t.cancelled())
+
+ def test_cancel_task_catching(self):
+ fut1 = futures.Future(loop=self.loop)
+ fut2 = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except futures.CancelledError:
+ return 42
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_task_ignoring(self):
+ fut1 = futures.Future(loop=self.loop)
+ fut2 = futures.Future(loop=self.loop)
+ fut3 = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def task():
+ yield from fut1
+ try:
+ yield from fut2
+ except futures.CancelledError:
+ pass
+ res = yield from fut3
+ return res
+
+ t = tasks.Task(task(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut1) # White-box test.
+ fut1.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut2) # White-box test.
+ t.cancel()
+ self.assertTrue(fut2.cancelled())
+ test_utils.run_briefly(self.loop)
+ self.assertIs(t._fut_waiter, fut3) # White-box test.
+ fut3.set_result(42)
+ res = self.loop.run_until_complete(t)
+ self.assertEqual(res, 42)
+ self.assertFalse(fut3.cancelled())
+ self.assertFalse(t.cancelled())
+
+ def test_cancel_current_task(self):
+ loop = events.new_event_loop()
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def task():
+ t.cancel()
+ self.assertTrue(t._must_cancel) # White-box test.
+ # The sleep should be cancelled immediately.
+ yield from tasks.sleep(100, loop=loop)
+ return 12
+
+ t = tasks.Task(task(), loop=loop)
+ self.assertRaises(
+ futures.CancelledError, loop.run_until_complete, t)
+ self.assertTrue(t.done())
+ self.assertFalse(t._must_cancel) # White-box test.
+ self.assertFalse(t.cancel())
+
+ def test_stop_while_run_in_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.3, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ x = 0
+ waiters = []
+
+ @tasks.coroutine
+ def task():
+ nonlocal x
+ while x < 10:
+ waiters.append(tasks.sleep(0.1, loop=loop))
+ yield from waiters[-1]
+ x += 1
+ if x == 2:
+ loop.stop()
+
+ t = tasks.Task(task(), loop=loop)
+ self.assertRaises(
+ RuntimeError, loop.run_until_complete, t)
+ self.assertFalse(t.done())
+ self.assertEqual(x, 2)
+ self.assertAlmostEqual(0.3, loop.time())
+
+ # close generators
+ for w in waiters:
+ w.close()
+
+ def test_wait_for(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.4, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def foo():
+ yield from tasks.sleep(0.2, loop=loop)
+ return 'done'
+
+ fut = tasks.Task(foo(), loop=loop)
+
+ with self.assertRaises(futures.TimeoutError):
+ loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop))
+
+ self.assertFalse(fut.done())
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # wait for result
+ res = loop.run_until_complete(
+ tasks.wait_for(fut, 0.3, loop=loop))
+ self.assertEqual(res, 'done')
+ self.assertAlmostEqual(0.2, loop.time())
+
+ def test_wait_for_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.2, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def foo():
+ yield from tasks.sleep(0.2, loop=loop)
+ return 'done'
+
+ events.set_event_loop(loop)
+ try:
+ fut = tasks.Task(foo(), loop=loop)
+ with self.assertRaises(futures.TimeoutError):
+ loop.run_until_complete(tasks.wait_for(fut, 0.01))
+ finally:
+ events.set_event_loop(None)
+
+ self.assertAlmostEqual(0.01, loop.time())
+ self.assertFalse(fut.done())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(fut)
+
+ def test_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], loop=loop)
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertEqual(res, 42)
+ self.assertAlmostEqual(0.15, loop.time())
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertEqual(res, 42)
+
+ def test_wait_with_global_loop(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.01, when)
+ when = yield 0
+ self.assertAlmostEqual(0.015, when)
+ yield 0.015
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a])
+ self.assertEqual(done, set([a, b]))
+ self.assertEqual(pending, set())
+ return 42
+
+ events.set_event_loop(loop)
+ try:
+ res = loop.run_until_complete(
+ tasks.Task(foo(), loop=loop))
+ finally:
+ events.set_event_loop(None)
+
+ self.assertEqual(res, 42)
+
+ def test_wait_errors(self):
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete,
+ tasks.wait(set(), loop=self.loop))
+
+ self.assertRaises(
+ ValueError, self.loop.run_until_complete,
+ tasks.wait([tasks.sleep(10.0, loop=self.loop)],
+ return_when=-1, loop=self.loop))
+
+ def test_wait_first_completed(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertFalse(a.done())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_really_done(self):
+ # there is possibility that some tasks in the pending list
+ # became done but their callbacks haven't all been called yet
+
+ @tasks.coroutine
+ def coro1():
+ yield
+
+ @tasks.coroutine
+ def coro2():
+ yield
+ yield
+
+ a = tasks.Task(coro1(), loop=self.loop)
+ b = tasks.Task(coro2(), loop=self.loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED,
+ loop=self.loop),
+ loop=self.loop)
+
+ done, pending = self.loop.run_until_complete(task)
+ self.assertEqual({a, b}, done)
+ self.assertTrue(a.done())
+ self.assertIsNone(a.result())
+ self.assertTrue(b.done())
+ self.assertIsNone(b.result())
+
+ def test_wait_first_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ # first_exception, task already has exception
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def exc():
+ raise ZeroDivisionError('err')
+
+ b = tasks.Task(exc(), loop=loop)
+ task = tasks.Task(
+ tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION,
+ loop=loop),
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_first_exception_in_wait(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ when = yield 0
+ self.assertAlmostEqual(0.01, when)
+ yield 0.01
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ # first_exception, exception during waiting
+ a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def exc():
+ yield from tasks.sleep(0.01, loop=loop)
+ raise ZeroDivisionError('err')
+
+ b = tasks.Task(exc(), loop=loop)
+ task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION,
+ loop=loop)
+
+ done, pending = loop.run_until_complete(task)
+ self.assertEqual({b}, done)
+ self.assertEqual({a}, pending)
+ self.assertAlmostEqual(0.01, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_with_exception(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ yield 0.15
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(0.15, loop=loop)
+ raise ZeroDivisionError('really')
+
+ b = tasks.Task(sleeper(), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], loop=loop)
+ self.assertEqual(len(done), 2)
+ self.assertEqual(pending, set())
+ errors = set(f for f in done if f.exception() is not None)
+ self.assertEqual(len(errors), 1)
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_wait_with_timeout(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.11, when)
+ yield 0.11
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ done, pending = yield from tasks.wait([b, a], timeout=0.11,
+ loop=loop)
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+
+ loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.11, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_wait_concurrent_complete(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop)
+ b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop)
+
+ done, pending = loop.run_until_complete(
+ tasks.wait([b, a], timeout=0.1, loop=loop))
+
+ self.assertEqual(done, set([a]))
+ self.assertEqual(pending, set([b]))
+ self.assertAlmostEqual(0.1, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_as_completed(self):
+
+ def gen():
+ yield 0
+ yield 0
+ yield 0.01
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+ completed = set()
+ time_shifted = False
+
+ @tasks.coroutine
+ def sleeper(dt, x):
+ nonlocal time_shifted
+ yield from tasks.sleep(dt, loop=loop)
+ completed.add(x)
+ if not time_shifted and 'a' in completed and 'b' in completed:
+ time_shifted = True
+ loop.advance_time(0.14)
+ return x
+
+ a = sleeper(0.01, 'a')
+ b = sleeper(0.01, 'b')
+ c = sleeper(0.15, 'c')
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([b, c, a], loop=loop):
+ values.append((yield from f))
+ return values
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+ self.assertTrue('a' in res[:2])
+ self.assertTrue('b' in res[:2])
+ self.assertEqual(res[2], 'c')
+
+ # Doing it again should take no time and exercise a different path.
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertAlmostEqual(0.15, loop.time())
+
+ def test_as_completed_with_timeout(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.12, when)
+ when = yield 0
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(0.15, when)
+ when = yield 0.1
+ self.assertAlmostEqual(0.12, when)
+ yield 0.02
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.1, 'a', loop=loop)
+ b = tasks.sleep(0.15, 'b', loop=loop)
+
+ @tasks.coroutine
+ def foo():
+ values = []
+ for f in tasks.as_completed([a, b], timeout=0.12, loop=loop):
+ try:
+ v = yield from f
+ values.append((1, v))
+ except futures.TimeoutError as exc:
+ values.append((2, exc))
+ return values
+
+ res = loop.run_until_complete(tasks.Task(foo(), loop=loop))
+ self.assertEqual(len(res), 2, res)
+ self.assertEqual(res[0], (1, 'a'))
+ self.assertEqual(res[1][0], 2)
+ self.assertTrue(isinstance(res[1][1], futures.TimeoutError))
+ self.assertAlmostEqual(0.12, loop.time())
+
+ # move forward to close generator
+ loop.advance_time(10)
+ loop.run_until_complete(tasks.wait([a, b], loop=loop))
+
+ def test_as_completed_reverse_wait(self):
+
+ def gen():
+ yield 0
+ yield 0.05
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.05, 'a', loop=loop)
+ b = tasks.sleep(0.10, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(tasks.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+
+ x = loop.run_until_complete(futs[1])
+ self.assertEqual(x, 'a')
+ self.assertAlmostEqual(0.05, loop.time())
+ loop.advance_time(0.05)
+ y = loop.run_until_complete(futs[0])
+ self.assertEqual(y, 'b')
+ self.assertAlmostEqual(0.10, loop.time())
+
+ def test_as_completed_concurrent(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0
+ self.assertAlmostEqual(0.05, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ a = tasks.sleep(0.05, 'a', loop=loop)
+ b = tasks.sleep(0.05, 'b', loop=loop)
+ fs = {a, b}
+ futs = list(tasks.as_completed(fs, loop=loop))
+ self.assertEqual(len(futs), 2)
+ waiter = tasks.wait(futs, loop=loop)
+ done, pending = loop.run_until_complete(waiter)
+ self.assertEqual(set(f.result() for f in done), {'a', 'b'})
+
+ def test_sleep(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.05, when)
+ when = yield 0.05
+ self.assertAlmostEqual(0.1, when)
+ yield 0.05
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def sleeper(dt, arg):
+ yield from tasks.sleep(dt/2, loop=loop)
+ res = yield from tasks.sleep(dt/2, arg, loop=loop)
+ return res
+
+ t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop)
+ loop.run_until_complete(t)
+ self.assertTrue(t.done())
+ self.assertEqual(t.result(), 'yeah')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_sleep_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop),
+ loop=loop)
+
+ handle = None
+ orig_call_later = loop.call_later
+
+ def call_later(self, delay, callback, *args):
+ nonlocal handle
+ handle = orig_call_later(self, delay, callback, *args)
+ return handle
+
+ loop.call_later = call_later
+ test_utils.run_briefly(loop)
+
+ self.assertFalse(handle._cancelled)
+
+ t.cancel()
+ test_utils.run_briefly(loop)
+ self.assertTrue(handle._cancelled)
+
+ def test_task_cancel_sleeping_task(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(0.1, when)
+ when = yield 0
+ self.assertAlmostEqual(5000, when)
+ yield 0.1
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ sleepfut = None
+
+ @tasks.coroutine
+ def sleep(dt):
+ nonlocal sleepfut
+ sleepfut = tasks.sleep(dt, loop=loop)
+ yield from sleepfut
+
+ @tasks.coroutine
+ def doit():
+ sleeper = tasks.Task(sleep(5000), loop=loop)
+ loop.call_later(0.1, sleeper.cancel)
+ try:
+ yield from sleeper
+ except futures.CancelledError:
+ return 'cancelled'
+ else:
+ return 'slept in'
+
+ doer = doit()
+ self.assertEqual(loop.run_until_complete(doer), 'cancelled')
+ self.assertAlmostEqual(0.1, loop.time())
+
+ def test_task_cancel_waiter_future(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def coro():
+ yield from fut
+
+ task = tasks.Task(coro(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(task._fut_waiter, fut)
+
+ task.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, task)
+ self.assertIsNone(task._fut_waiter)
+ self.assertTrue(fut.cancelled())
+
+ def test_step_in_completed_task(self):
+ @tasks.coroutine
+ def notmuch():
+ return 'ko'
+
+ gen = notmuch()
+ task = tasks.Task(gen, loop=self.loop)
+ task.set_result('ok')
+
+ self.assertRaises(AssertionError, task._step)
+ gen.close()
+
+ def test_step_result(self):
+ @tasks.coroutine
+ def notmuch():
+ yield None
+ yield 1
+ return 'ko'
+
+ self.assertRaises(
+ RuntimeError, self.loop.run_until_complete, notmuch())
+
+ def test_step_result_future(self):
+ # If coroutine returns future, task waits on this future.
+
+ class Fut(futures.Future):
+ def __init__(self, *args, **kwds):
+ self.cb_added = False
+ super().__init__(*args, **kwds)
+
+ def add_done_callback(self, fn):
+ self.cb_added = True
+ super().add_done_callback(fn)
+
+ fut = Fut(loop=self.loop)
+ result = None
+
+ @tasks.coroutine
+ def wait_for_future():
+ nonlocal result
+ result = yield from fut
+
+ t = tasks.Task(wait_for_future(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(fut.cb_added)
+
+ res = object()
+ fut.set_result(res)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(res, result)
+ self.assertTrue(t.done())
+ self.assertIsNone(t.result())
+
+ def test_step_with_baseexception(self):
+ @tasks.coroutine
+ def notmutch():
+ raise BaseException()
+
+ task = tasks.Task(notmutch(), loop=self.loop)
+ self.assertRaises(BaseException, task._step)
+
+ self.assertTrue(task.done())
+ self.assertIsInstance(task.exception(), BaseException)
+
+ def test_baseexception_during_cancel(self):
+
+ def gen():
+ when = yield
+ self.assertAlmostEqual(10.0, when)
+ yield 0
+
+ loop = test_utils.TestLoop(gen)
+ self.addCleanup(loop.close)
+
+ @tasks.coroutine
+ def sleeper():
+ yield from tasks.sleep(10, loop=loop)
+
+ base_exc = BaseException()
+
+ @tasks.coroutine
+ def notmutch():
+ try:
+ yield from sleeper()
+ except futures.CancelledError:
+ raise base_exc
+
+ task = tasks.Task(notmutch(), loop=loop)
+ test_utils.run_briefly(loop)
+
+ task.cancel()
+ self.assertFalse(task.done())
+
+ self.assertRaises(BaseException, test_utils.run_briefly, loop)
+
+ self.assertTrue(task.done())
+ self.assertFalse(task.cancelled())
+ self.assertIs(task.exception(), base_exc)
+
+ def test_iscoroutinefunction(self):
+ def fn():
+ pass
+
+ self.assertFalse(tasks.iscoroutinefunction(fn))
+
+ def fn1():
+ yield
+ self.assertFalse(tasks.iscoroutinefunction(fn1))
+
+ @tasks.coroutine
+ def fn2():
+ yield
+ self.assertTrue(tasks.iscoroutinefunction(fn2))
+
+ def test_yield_vs_yield_from(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def wait_for_future():
+ yield fut
+
+ task = wait_for_future()
+ with self.assertRaises(RuntimeError):
+ self.loop.run_until_complete(task)
+
+ self.assertFalse(fut.done())
+
+ def test_yield_vs_yield_from_generator(self):
+ @tasks.coroutine
+ def coro():
+ yield
+
+ @tasks.coroutine
+ def wait_for_future():
+ gen = coro()
+ try:
+ yield gen
+ finally:
+ gen.close()
+
+ task = wait_for_future()
+ self.assertRaises(
+ RuntimeError,
+ self.loop.run_until_complete, task)
+
+ def test_coroutine_non_gen_function(self):
+ @tasks.coroutine
+ def func():
+ return 'test'
+
+ self.assertTrue(tasks.iscoroutinefunction(func))
+
+ coro = func()
+ self.assertTrue(tasks.iscoroutine(coro))
+
+ res = self.loop.run_until_complete(coro)
+ self.assertEqual(res, 'test')
+
+ def test_coroutine_non_gen_function_return_future(self):
+ fut = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def func():
+ return fut
+
+ @tasks.coroutine
+ def coro():
+ fut.set_result('test')
+
+ t1 = tasks.Task(func(), loop=self.loop)
+ t2 = tasks.Task(coro(), loop=self.loop)
+ res = self.loop.run_until_complete(t1)
+ self.assertEqual(res, 'test')
+ self.assertIsNone(t2.result())
+
+ # Some thorough tests for cancellation propagation through
+ # coroutines, tasks and wait().
+
+ def test_yield_future_passes_cancel(self):
+ # Cancelling outer() cancels inner() cancels waiter.
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ try:
+ yield from waiter
+ except futures.CancelledError:
+ proof += 1
+ raise
+ else:
+ self.fail('got past sleep() in inner()')
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ try:
+ yield from inner()
+ except futures.CancelledError:
+ proof += 100 # Expect this path.
+ else:
+ proof += 10
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.loop.run_until_complete(f)
+ self.assertEqual(proof, 101)
+ self.assertTrue(waiter.cancelled())
+
+ def test_yield_wait_does_not_shield_cancel(self):
+ # Cancelling outer() makes wait() return early, leaves inner()
+ # running.
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ d, p = yield from tasks.wait([inner()], loop=self.loop)
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ self.assertRaises(
+ futures.CancelledError, self.loop.run_until_complete, f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_result(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ inner.set_result(42)
+ res = self.loop.run_until_complete(outer)
+ self.assertEqual(res, 42)
+
+ def test_shield_exception(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ test_utils.run_briefly(self.loop)
+ exc = RuntimeError('expected')
+ inner.set_exception(exc)
+ test_utils.run_briefly(self.loop)
+ self.assertIs(outer.exception(), exc)
+
+ def test_shield_cancel(self):
+ inner = futures.Future(loop=self.loop)
+ outer = tasks.shield(inner)
+ test_utils.run_briefly(self.loop)
+ inner.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+
+ def test_shield_shortcut(self):
+ fut = futures.Future(loop=self.loop)
+ fut.set_result(42)
+ res = self.loop.run_until_complete(tasks.shield(fut))
+ self.assertEqual(res, 42)
+
+ def test_shield_effect(self):
+ # Cancelling outer() does not affect inner().
+ proof = 0
+ waiter = futures.Future(loop=self.loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof
+ yield from tasks.shield(inner(), loop=self.loop)
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ f.cancel()
+ with self.assertRaises(futures.CancelledError):
+ self.loop.run_until_complete(f)
+ waiter.set_result(None)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(proof, 1)
+
+ def test_shield_gather(self):
+ child1 = futures.Future(loop=self.loop)
+ child2 = futures.Future(loop=self.loop)
+ parent = tasks.gather(child1, child2, loop=self.loop)
+ outer = tasks.shield(parent, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ outer.cancel()
+ test_utils.run_briefly(self.loop)
+ self.assertTrue(outer.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(parent.result(), [1, 2])
+
+ def test_gather_shield(self):
+ child1 = futures.Future(loop=self.loop)
+ child2 = futures.Future(loop=self.loop)
+ inner1 = tasks.shield(child1, loop=self.loop)
+ inner2 = tasks.shield(child2, loop=self.loop)
+ parent = tasks.gather(inner1, inner2, loop=self.loop)
+ test_utils.run_briefly(self.loop)
+ parent.cancel()
+ # This should cancel inner1 and inner2 but bot child1 and child2.
+ test_utils.run_briefly(self.loop)
+ self.assertIsInstance(parent.exception(), futures.CancelledError)
+ self.assertTrue(inner1.cancelled())
+ self.assertTrue(inner2.cancelled())
+ child1.set_result(1)
+ child2.set_result(2)
+ test_utils.run_briefly(self.loop)
+
+
+class GatherTestsBase:
+
+ def setUp(self):
+ self.one_loop = test_utils.TestLoop()
+ self.other_loop = test_utils.TestLoop()
+
+ def tearDown(self):
+ self.one_loop.close()
+ self.other_loop.close()
+
+ def _run_loop(self, loop):
+ while loop._ready:
+ test_utils.run_briefly(loop)
+
+ def _check_success(self, **kwargs):
+ a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ b.set_result(1)
+ a.set_result(2)
+ self._run_loop(self.one_loop)
+ self.assertEqual(cb.called, False)
+ self.assertFalse(fut.done())
+ c.set_result(3)
+ self._run_loop(self.one_loop)
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [2, 1, 3])
+
+ def test_success(self):
+ self._check_success()
+ self._check_success(return_exceptions=False)
+
+ def test_result_exception_success(self):
+ self._check_success(return_exceptions=True)
+
+ def test_one_exception(self):
+ a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c, d, e))
+ cb = Mock()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ a.set_result(1)
+ b.set_exception(exc)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertIs(fut.exception(), exc)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+
+ def test_return_exceptions(self):
+ a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)]
+ fut = tasks.gather(*self.wrap_futures(a, b, c, d),
+ return_exceptions=True)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ exc = ZeroDivisionError()
+ exc2 = RuntimeError()
+ b.set_result(1)
+ c.set_exception(exc)
+ a.set_result(3)
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_exception(exc2)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertEqual(fut.result(), [3, 1, exc, exc2])
+
+
+class FutureGatherTests(GatherTestsBase, unittest.TestCase):
+
+ def wrap_futures(self, *futures):
+ return futures
+
+ def _check_empty_sequence(self, seq_or_iter):
+ events.set_event_loop(self.one_loop)
+ self.addCleanup(events.set_event_loop, None)
+ fut = tasks.gather(*seq_or_iter)
+ self.assertIsInstance(fut, futures.Future)
+ self.assertIs(fut._loop, self.one_loop)
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ self.assertEqual(fut.result(), [])
+ fut = tasks.gather(*seq_or_iter, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+
+ def test_constructor_empty_sequence(self):
+ self._check_empty_sequence([])
+ self._check_empty_sequence(())
+ self._check_empty_sequence(set())
+ self._check_empty_sequence(iter(""))
+
+ def test_constructor_heterogenous_futures(self):
+ fut1 = futures.Future(loop=self.one_loop)
+ fut2 = futures.Future(loop=self.other_loop)
+ with self.assertRaises(ValueError):
+ tasks.gather(fut1, fut2)
+ with self.assertRaises(ValueError):
+ tasks.gather(fut1, loop=self.other_loop)
+
+ def test_constructor_homogenous_futures(self):
+ children = [futures.Future(loop=self.other_loop) for i in range(3)]
+ fut = tasks.gather(*children)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+ fut = tasks.gather(*children, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+ self._run_loop(self.other_loop)
+ self.assertFalse(fut.done())
+
+ def test_one_cancellation(self):
+ a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)]
+ fut = tasks.gather(a, b, c, d, e)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ b.cancel()
+ self._run_loop(self.one_loop)
+ self.assertTrue(fut.done())
+ cb.assert_called_once_with(fut)
+ self.assertFalse(fut.cancelled())
+ self.assertIsInstance(fut.exception(), futures.CancelledError)
+ # Does nothing
+ c.set_result(3)
+ d.cancel()
+ e.set_exception(RuntimeError())
+
+ def test_result_exception_one_cancellation(self):
+ a, b, c, d, e, f = [futures.Future(loop=self.one_loop)
+ for i in range(6)]
+ fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True)
+ cb = Mock()
+ fut.add_done_callback(cb)
+ a.set_result(1)
+ zde = ZeroDivisionError()
+ b.set_exception(zde)
+ c.cancel()
+ self._run_loop(self.one_loop)
+ self.assertFalse(fut.done())
+ d.set_result(3)
+ e.cancel()
+ rte = RuntimeError()
+ f.set_exception(rte)
+ res = self.one_loop.run_until_complete(fut)
+ self.assertIsInstance(res[2], futures.CancelledError)
+ self.assertIsInstance(res[4], futures.CancelledError)
+ res[2] = res[4] = None
+ self.assertEqual(res, [1, zde, None, 3, None, rte])
+ cb.assert_called_once_with(fut)
+
+
+class CoroutineGatherTests(GatherTestsBase, unittest.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ events.set_event_loop(self.one_loop)
+
+ def tearDown(self):
+ events.set_event_loop(None)
+ super().tearDown()
+
+ def wrap_futures(self, *futures):
+ coros = []
+ for fut in futures:
+ @tasks.coroutine
+ def coro(fut=fut):
+ return (yield from fut)
+ coros.append(coro())
+ return coros
+
+ def test_constructor_loop_selection(self):
+ @tasks.coroutine
+ def coro():
+ return 'abc'
+ gen1 = coro()
+ gen2 = coro()
+ fut = tasks.gather(gen1, gen2)
+ self.assertIs(fut._loop, self.one_loop)
+ gen1.close()
+ gen2.close()
+ gen3 = coro()
+ gen4 = coro()
+ fut = tasks.gather(gen3, gen4, loop=self.other_loop)
+ self.assertIs(fut._loop, self.other_loop)
+ gen3.close()
+ gen4.close()
+
+ def test_cancellation_broadcast(self):
+ # Cancelling outer() cancels all children.
+ proof = 0
+ waiter = futures.Future(loop=self.one_loop)
+
+ @tasks.coroutine
+ def inner():
+ nonlocal proof
+ yield from waiter
+ proof += 1
+
+ child1 = tasks.async(inner(), loop=self.one_loop)
+ child2 = tasks.async(inner(), loop=self.one_loop)
+ gatherer = None
+
+ @tasks.coroutine
+ def outer():
+ nonlocal proof, gatherer
+ gatherer = tasks.gather(child1, child2, loop=self.one_loop)
+ yield from gatherer
+ proof += 100
+
+ f = tasks.async(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ self.assertTrue(f.cancel())
+ with self.assertRaises(futures.CancelledError):
+ self.one_loop.run_until_complete(f)
+ self.assertFalse(gatherer.cancel())
+ self.assertTrue(waiter.cancelled())
+ self.assertTrue(child1.cancelled())
+ self.assertTrue(child2.cancelled())
+ test_utils.run_briefly(self.one_loop)
+ self.assertEqual(proof, 0)
+
+ def test_exception_marking(self):
+ # Test for the first line marked "Mark exception retrieved."
+
+ @tasks.coroutine
+ def inner(f):
+ yield from f
+ raise RuntimeError('should not be ignored')
+
+ a = futures.Future(loop=self.one_loop)
+ b = futures.Future(loop=self.one_loop)
+
+ @tasks.coroutine
+ def outer():
+ yield from tasks.gather(inner(a), inner(b), loop=self.one_loop)
+
+ f = tasks.async(outer(), loop=self.one_loop)
+ test_utils.run_briefly(self.one_loop)
+ a.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ b.set_result(None)
+ test_utils.run_briefly(self.one_loop)
+ self.assertIsInstance(f.exception(), RuntimeError)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py
new file mode 100644
index 0000000..fce2e6f
--- /dev/null
+++ b/Lib/test/test_asyncio/test_transports.py
@@ -0,0 +1,55 @@
+"""Tests for transports.py."""
+
+import unittest
+import unittest.mock
+
+from asyncio import transports
+
+
+class TransportTests(unittest.TestCase):
+
+ def test_ctor_extra_is_none(self):
+ transport = transports.Transport()
+ self.assertEqual(transport._extra, {})
+
+ def test_get_extra_info(self):
+ transport = transports.Transport({'extra': 'info'})
+ self.assertEqual('info', transport.get_extra_info('extra'))
+ self.assertIsNone(transport.get_extra_info('unknown'))
+
+ default = object()
+ self.assertIs(default, transport.get_extra_info('unknown', default))
+
+ def test_writelines(self):
+ transport = transports.Transport()
+ transport.write = unittest.mock.Mock()
+
+ transport.writelines(['line1', 'line2', 'line3'])
+ self.assertEqual(3, transport.write.call_count)
+
+ def test_not_implemented(self):
+ transport = transports.Transport()
+
+ self.assertRaises(NotImplementedError, transport.write, 'data')
+ self.assertRaises(NotImplementedError, transport.write_eof)
+ self.assertRaises(NotImplementedError, transport.can_write_eof)
+ self.assertRaises(NotImplementedError, transport.pause)
+ self.assertRaises(NotImplementedError, transport.resume)
+ self.assertRaises(NotImplementedError, transport.close)
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_dgram_not_implemented(self):
+ transport = transports.DatagramTransport()
+
+ self.assertRaises(NotImplementedError, transport.sendto, 'data')
+ self.assertRaises(NotImplementedError, transport.abort)
+
+ def test_subprocess_transport_not_implemented(self):
+ transport = transports.SubprocessTransport()
+
+ self.assertRaises(NotImplementedError, transport.get_pid)
+ self.assertRaises(NotImplementedError, transport.get_returncode)
+ self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
+ self.assertRaises(NotImplementedError, transport.send_signal, 1)
+ self.assertRaises(NotImplementedError, transport.terminate)
+ self.assertRaises(NotImplementedError, transport.kill)
diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py
new file mode 100644
index 0000000..ea67862
--- /dev/null
+++ b/Lib/test/test_asyncio/test_unix_events.py
@@ -0,0 +1,767 @@
+"""Tests for unix_events.py."""
+
+import gc
+import errno
+import io
+import pprint
+import signal
+import stat
+import sys
+import unittest
+import unittest.mock
+
+
+from asyncio import events
+from asyncio import futures
+from asyncio import protocols
+from asyncio import test_utils
+from asyncio import unix_events
+
+
+@unittest.skipUnless(signal, 'Signals are not supported')
+class SelectorEventLoopTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = unix_events.SelectorEventLoop()
+ events.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+
+ def test_check_signal(self):
+ self.assertRaises(
+ TypeError, self.loop._check_signal, '1')
+ self.assertRaises(
+ ValueError, self.loop._check_signal, signal.NSIG + 1)
+
+ def test_handle_signal_no_handler(self):
+ self.loop._handle_signal(signal.NSIG + 1, ())
+
+ def test_handle_signal_cancelled_handler(self):
+ h = events.Handle(unittest.mock.Mock(), ())
+ h.cancel()
+ self.loop._signal_handlers[signal.NSIG + 1] = h
+ self.loop.remove_signal_handler = unittest.mock.Mock()
+ self.loop._handle_signal(signal.NSIG + 1, ())
+ self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_setup_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ cb = lambda: True
+ self.loop.add_signal_handler(signal.SIGHUP, cb)
+ h = self.loop._signal_handlers.get(signal.SIGHUP)
+ self.assertTrue(isinstance(h, events.Handle))
+ self.assertEqual(h._callback, cb)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_add_signal_handler_install_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ def set_wakeup_fd(fd):
+ if fd == -1:
+ raise ValueError()
+ m_signal.set_wakeup_fd = set_wakeup_fd
+
+ class Err(OSError):
+ errno = errno.EFAULT
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ Err,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ def test_add_signal_handler_install_error2(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.loop._signal_handlers[signal.SIGHUP] = lambda: True
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(1, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ def test_add_signal_handler_install_error3(self, m_logging, m_signal):
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+ m_signal.NSIG = signal.NSIG
+
+ self.assertRaises(
+ RuntimeError,
+ self.loop.add_signal_handler,
+ signal.SIGINT, lambda: True)
+ self.assertFalse(m_logging.info.called)
+ self.assertEqual(2, m_signal.set_wakeup_fd.call_count)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGHUP))
+ self.assertTrue(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ m_signal.SIGINT = signal.SIGINT
+
+ self.loop.add_signal_handler(signal.SIGINT, lambda: True)
+ self.loop._signal_handlers[signal.SIGHUP] = object()
+ m_signal.set_wakeup_fd.reset_mock()
+
+ self.assertTrue(
+ self.loop.remove_signal_handler(signal.SIGINT))
+ self.assertFalse(m_signal.set_wakeup_fd.called)
+ self.assertTrue(m_signal.signal.called)
+ self.assertEqual(
+ (signal.SIGINT, m_signal.default_int_handler),
+ m_signal.signal.call_args[0])
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.set_wakeup_fd.side_effect = ValueError
+
+ self.loop.remove_signal_handler(signal.SIGHUP)
+ self.assertTrue(m_logging.info)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ m_signal.signal.side_effect = OSError
+
+ self.assertRaises(
+ OSError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @unittest.mock.patch('asyncio.unix_events.signal')
+ def test_remove_signal_handler_error2(self, m_signal):
+ m_signal.NSIG = signal.NSIG
+ self.loop.add_signal_handler(signal.SIGHUP, lambda: True)
+
+ class Err(OSError):
+ errno = errno.EINVAL
+ m_signal.signal.side_effect = Err
+
+ self.assertRaises(
+ RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = True
+ m_WIFSIGNALED.return_value = False
+ m_WEXITSTATUS.return_value = 3
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ transp._process_exited.assert_called_with(3)
+ self.assertFalse(m_WTERMSIG.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = False
+ m_WIFSIGNALED.return_value = True
+ m_WTERMSIG.return_value = 1
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ transp._process_exited.assert_called_with(-1)
+ self.assertFalse(m_WEXITSTATUS.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(0, object()), ChildProcessError]
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ self.assertFalse(transp._process_exited.called)
+ self.assertFalse(m_WIFSIGNALED.called)
+ self.assertFalse(m_WIFEXITED.called)
+ self.assertFalse(m_WTERMSIG.called)
+ self.assertFalse(m_WEXITSTATUS.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_not_registered_subprocess(self, m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = True
+ m_WIFSIGNALED.return_value = False
+ m_WEXITSTATUS.return_value = 3
+
+ self.loop._sig_chld()
+ self.assertFalse(m_WTERMSIG.called)
+
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_unknown_status(self, m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG):
+ m_waitpid.side_effect = [(7, object()), ChildProcessError]
+ m_WIFEXITED.return_value = False
+ m_WIFSIGNALED.return_value = False
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ self.assertFalse(transp._process_exited.called)
+ self.assertFalse(m_WEXITSTATUS.called)
+ self.assertFalse(m_WTERMSIG.called)
+
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ @unittest.mock.patch('os.WTERMSIG')
+ @unittest.mock.patch('os.WEXITSTATUS')
+ @unittest.mock.patch('os.WIFSIGNALED')
+ @unittest.mock.patch('os.WIFEXITED')
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_unknown_status_in_handler(self, m_waitpid,
+ m_WIFEXITED, m_WIFSIGNALED,
+ m_WEXITSTATUS, m_WTERMSIG,
+ m_log):
+ m_waitpid.side_effect = Exception
+ transp = unittest.mock.Mock()
+ self.loop._subprocesses[7] = transp
+
+ self.loop._sig_chld()
+ self.assertFalse(transp._process_exited.called)
+ self.assertFalse(m_WIFSIGNALED.called)
+ self.assertFalse(m_WIFEXITED.called)
+ self.assertFalse(m_WTERMSIG.called)
+ self.assertFalse(m_WEXITSTATUS.called)
+ m_log.exception.assert_called_with(
+ 'Unknown exception in SIGCHLD handler')
+
+ @unittest.mock.patch('os.waitpid')
+ def test__sig_chld_process_error(self, m_waitpid):
+ m_waitpid.side_effect = ChildProcessError
+ self.loop._sig_chld()
+ self.assertTrue(m_waitpid.called)
+
+
+class UnixReadPipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(protocols.Protocol)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ fcntl_patcher = unittest.mock.patch('fcntl.fcntl')
+ fcntl_patcher.start()
+ self.addCleanup(fcntl_patcher.stop)
+
+ def test_ctor(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+ unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol, fut)
+ test_utils.run_briefly(self.loop)
+ self.assertIsNone(fut.result())
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.return_value = b'data'
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.protocol.data_received.assert_called_with(b'data')
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready_eof(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.return_value = b''
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.eof_received.assert_called_with()
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('os.read')
+ def test__read_ready_blocked(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ m_read.side_effect = BlockingIOError
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ test_utils.run_briefly(self.loop)
+ self.assertFalse(self.protocol.data_received.called)
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ @unittest.mock.patch('os.read')
+ def test__read_ready_error(self, m_read, m_logexc):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+ err = OSError()
+ m_read.side_effect = err
+ tr._close = unittest.mock.Mock()
+ tr._read_ready()
+
+ m_read.assert_called_with(5, tr.max_size)
+ tr._close.assert_called_with(err)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+
+ @unittest.mock.patch('os.read')
+ def test_pause(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m = unittest.mock.Mock()
+ self.loop.add_reader(5, m)
+ tr.pause()
+ self.assertFalse(self.loop.readers)
+
+ @unittest.mock.patch('os.read')
+ def test_resume(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.resume()
+ self.loop.assert_reader(5, tr._read_ready)
+
+ @unittest.mock.patch('os.read')
+ def test_close(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ tr._close.assert_called_with(None)
+
+ @unittest.mock.patch('os.read')
+ def test_close_already_closing(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr._closing = True
+ tr._close = unittest.mock.Mock()
+ tr.close()
+ self.assertFalse(tr._close.called)
+
+ @unittest.mock.patch('os.read')
+ def test__close(self, m_read):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = object()
+ tr._close(err)
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ def test__call_connection_lost(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test__call_connection_lost_with_err(self):
+ tr = unix_events._UnixReadPipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+
+class UnixWritePipeTransportTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = test_utils.TestLoop()
+ self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol)
+ self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase)
+ self.pipe.fileno.return_value = 5
+
+ fcntl_patcher = unittest.mock.patch('fcntl.fcntl')
+ fcntl_patcher.start()
+ self.addCleanup(fcntl_patcher.stop)
+
+ fstat_patcher = unittest.mock.patch('os.fstat')
+ m_fstat = fstat_patcher.start()
+ st = unittest.mock.Mock()
+ st.st_mode = stat.S_IFIFO
+ m_fstat.return_value = st
+ self.addCleanup(fstat_patcher.stop)
+
+ def test_ctor(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_made.assert_called_with(tr)
+
+ def test_ctor_with_waiter(self):
+ fut = futures.Future(loop=self.loop)
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol, fut)
+ self.loop.assert_reader(5, tr._read_ready)
+ test_utils.run_briefly(self.loop)
+ self.assertEqual(None, fut.result())
+
+ def test_can_write_eof(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.assertTrue(tr.can_write_eof())
+
+ @unittest.mock.patch('os.write')
+ def test_write(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.return_value = 4
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_no_data(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write(b'')
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_partial(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.return_value = 2
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'ta'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_buffer(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'previous']
+ tr.write(b'data')
+ self.assertFalse(m_write.called)
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'previous', b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test_write_again(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ m_write.side_effect = BlockingIOError()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('asyncio.unix_events.asyncio_log')
+ @unittest.mock.patch('os.write')
+ def test_write_err(self, m_write, m_log):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ m_write.side_effect = err
+ tr._fatal_error = unittest.mock.Mock()
+ tr.write(b'data')
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ tr._fatal_error.assert_called_with(err)
+ self.assertEqual(1, tr._conn_lost)
+
+ tr.write(b'data')
+ self.assertEqual(2, tr._conn_lost)
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ tr.write(b'data')
+ # This is a bit overspecified. :-(
+ m_log.warning.assert_called_with(
+ 'pipe closed by peer or os.write(pipe, data) raised exception.')
+
+ @unittest.mock.patch('os.write')
+ def test_write_close(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ tr._read_ready() # pipe was closed by peer
+
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 1)
+ tr.write(b'data')
+ self.assertEqual(tr._conn_lost, 2)
+
+ def test__read_ready(self):
+ tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe,
+ self.protocol)
+ tr._read_ready()
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertTrue(tr._closing)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_partial(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 3
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'a'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_again(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = BlockingIOError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_empty(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 0
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.loop.assert_writer(5, tr._write_ready)
+ self.assertEqual([b'data'], tr._buffer)
+
+ @unittest.mock.patch('asyncio.log.asyncio_log.exception')
+ @unittest.mock.patch('os.write')
+ def test__write_ready_err(self, m_write, m_logexc):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._buffer = [b'da', b'ta']
+ m_write.side_effect = err = OSError()
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ m_logexc.assert_called_with('Fatal error for %s', tr)
+ self.assertEqual(1, tr._conn_lost)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(err)
+
+ @unittest.mock.patch('os.write')
+ def test__write_ready_closing(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ tr._closing = True
+ tr._buffer = [b'da', b'ta']
+ m_write.return_value = 4
+ tr._write_ready()
+ m_write.assert_called_with(5, b'data')
+ self.assertFalse(self.loop.writers)
+ self.assertFalse(self.loop.readers)
+ self.assertEqual([], tr._buffer)
+ self.protocol.connection_lost.assert_called_with(None)
+ self.pipe.close.assert_called_with()
+
+ @unittest.mock.patch('os.write')
+ def test_abort(self, m_write):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ self.loop.add_writer(5, tr._write_ready)
+ self.loop.add_reader(5, tr._read_ready)
+ tr._buffer = [b'da', b'ta']
+ tr.abort()
+ self.assertFalse(m_write.called)
+ self.assertFalse(self.loop.readers)
+ self.assertFalse(self.loop.writers)
+ self.assertEqual([], tr._buffer)
+ self.assertTrue(tr._closing)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test__call_connection_lost(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = None
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test__call_connection_lost_with_err(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ err = OSError()
+ tr._call_connection_lost(err)
+ self.protocol.connection_lost.assert_called_with(err)
+ self.pipe.close.assert_called_with()
+
+ self.assertIsNone(tr._protocol)
+ self.assertEqual(2, sys.getrefcount(self.protocol),
+ pprint.pformat(gc.get_referrers(self.protocol)))
+ self.assertIsNone(tr._loop)
+ self.assertEqual(2, sys.getrefcount(self.loop),
+ pprint.pformat(gc.get_referrers(self.loop)))
+
+ def test_close(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr.close()
+ tr.write_eof.assert_called_with()
+
+ def test_close_closing(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof = unittest.mock.Mock()
+ tr._closing = True
+ tr.close()
+ self.assertFalse(tr.write_eof.called)
+
+ def test_write_eof(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.loop.readers)
+ test_utils.run_briefly(self.loop)
+ self.protocol.connection_lost.assert_called_with(None)
+
+ def test_write_eof_pending(self):
+ tr = unix_events._UnixWritePipeTransport(
+ self.loop, self.pipe, self.protocol)
+ tr._buffer = [b'data']
+ tr.write_eof()
+ self.assertTrue(tr._closing)
+ self.assertFalse(self.protocol.connection_lost.called)
diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py
new file mode 100644
index 0000000..4b04073
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_events.py
@@ -0,0 +1,95 @@
+import os
+import sys
+import unittest
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import asyncio
+
+from asyncio import windows_events
+from asyncio import protocols
+from asyncio import streams
+from asyncio import transports
+from asyncio import test_utils
+
+
+class UpperProto(protocols.Protocol):
+ def __init__(self):
+ self.buf = []
+
+ def connection_made(self, trans):
+ self.trans = trans
+
+ def data_received(self, data):
+ self.buf.append(data)
+ if b'\n' in data:
+ self.trans.write(b''.join(self.buf).upper())
+ self.trans.close()
+
+
+class ProactorTests(unittest.TestCase):
+
+ def setUp(self):
+ self.loop = windows_events.ProactorEventLoop()
+ asyncio.set_event_loop(None)
+
+ def tearDown(self):
+ self.loop.close()
+ self.loop = None
+
+ def test_close(self):
+ a, b = self.loop._socketpair()
+ trans = self.loop._make_socket_transport(a, protocols.Protocol())
+ f = asyncio.async(self.loop.sock_recv(b, 100))
+ trans.close()
+ self.loop.run_until_complete(f)
+ self.assertEqual(f.result(), b'')
+
+ def test_double_bind(self):
+ ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid()
+ server1 = windows_events.PipeServer(ADDRESS)
+ with self.assertRaises(PermissionError):
+ server2 = windows_events.PipeServer(ADDRESS)
+ server1.close()
+
+ def test_pipe(self):
+ res = self.loop.run_until_complete(self._test_pipe())
+ self.assertEqual(res, 'done')
+
+ def _test_pipe(self):
+ ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ protocols.Protocol, ADDRESS)
+
+ [server] = yield from self.loop.start_serving_pipe(
+ UpperProto, ADDRESS)
+ self.assertIsInstance(server, windows_events.PipeServer)
+
+ clients = []
+ for i in range(5):
+ stream_reader = streams.StreamReader(loop=self.loop)
+ protocol = streams.StreamReaderProtocol(stream_reader)
+ trans, proto = yield from self.loop.create_pipe_connection(
+ lambda:protocol, ADDRESS)
+ self.assertIsInstance(trans, transports.Transport)
+ self.assertEqual(protocol, proto)
+ clients.append((stream_reader, trans))
+
+ for i, (r, w) in enumerate(clients):
+ w.write('lower-{}\n'.format(i).encode())
+
+ for i, (r, w) in enumerate(clients):
+ response = yield from r.readline()
+ self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
+ w.close()
+
+ server.close()
+
+ with self.assertRaises(FileNotFoundError):
+ yield from self.loop.create_pipe_connection(
+ protocols.Protocol, ADDRESS)
+
+ return 'done'
diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py
new file mode 100644
index 0000000..4b96086
--- /dev/null
+++ b/Lib/test/test_asyncio/test_windows_utils.py
@@ -0,0 +1,136 @@
+"""Tests for window_utils"""
+
+import sys
+import test.support
+import unittest
+import unittest.mock
+
+if sys.platform != 'win32':
+ raise unittest.SkipTest('Windows only')
+
+import _winapi
+
+from asyncio import windows_utils
+from asyncio import _overlapped
+
+
+class WinsocketpairTests(unittest.TestCase):
+
+ def test_winsocketpair(self):
+ ssock, csock = windows_utils.socketpair()
+
+ csock.send(b'xxx')
+ self.assertEqual(b'xxx', ssock.recv(1024))
+
+ csock.close()
+ ssock.close()
+
+ @unittest.mock.patch('asyncio.windows_utils.socket')
+ def test_winsocketpair_exc(self, m_socket):
+ m_socket.socket.return_value.getsockname.return_value = ('', 12345)
+ m_socket.socket.return_value.accept.return_value = object(), object()
+ m_socket.socket.return_value.connect.side_effect = OSError()
+
+ self.assertRaises(OSError, windows_utils.socketpair)
+
+
+class PipeTests(unittest.TestCase):
+
+ def test_pipe_overlapped(self):
+ h1, h2 = windows_utils.pipe(overlapped=(True, True))
+ try:
+ ov1 = _overlapped.Overlapped()
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, 0)
+
+ ov1.ReadFile(h1, 100)
+ self.assertTrue(ov1.pending)
+ self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING)
+ ERROR_IO_INCOMPLETE = 996
+ try:
+ ov1.getresult()
+ except OSError as e:
+ self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE)
+ else:
+ raise RuntimeError('expected ERROR_IO_INCOMPLETE')
+
+ ov2 = _overlapped.Overlapped()
+ self.assertFalse(ov2.pending)
+ self.assertEqual(ov2.error, 0)
+
+ ov2.WriteFile(h2, b"hello")
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+
+ res = _winapi.WaitForMultipleObjects([ov2.event], False, 100)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+
+ self.assertFalse(ov1.pending)
+ self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE)
+ self.assertFalse(ov2.pending)
+ self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
+ self.assertEqual(ov1.getresult(), b"hello")
+ finally:
+ _winapi.CloseHandle(h1)
+ _winapi.CloseHandle(h2)
+
+ def test_pipe_handle(self):
+ h, _ = windows_utils.pipe(overlapped=(True, True))
+ _winapi.CloseHandle(_)
+ p = windows_utils.PipeHandle(h)
+ self.assertEqual(p.fileno(), h)
+ self.assertEqual(p.handle, h)
+
+ # check garbage collection of p closes handle
+ del p
+ test.support.gc_collect()
+ try:
+ _winapi.CloseHandle(h)
+ except OSError as e:
+ self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE
+ else:
+ raise RuntimeError('expected ERROR_INVALID_HANDLE')
+
+
+class PopenTests(unittest.TestCase):
+
+ def test_popen(self):
+ command = r"""if 1:
+ import sys
+ s = sys.stdin.readline()
+ sys.stdout.write(s.upper())
+ sys.stderr.write('stderr')
+ """
+ msg = b"blah\n"
+
+ p = windows_utils.Popen([sys.executable, '-c', command],
+ stdin=windows_utils.PIPE,
+ stdout=windows_utils.PIPE,
+ stderr=windows_utils.PIPE)
+
+ for f in [p.stdin, p.stdout, p.stderr]:
+ self.assertIsInstance(f, windows_utils.PipeHandle)
+
+ ovin = _overlapped.Overlapped()
+ ovout = _overlapped.Overlapped()
+ overr = _overlapped.Overlapped()
+
+ ovin.WriteFile(p.stdin.handle, msg)
+ ovout.ReadFile(p.stdout.handle, 100)
+ overr.ReadFile(p.stderr.handle, 100)
+
+ events = [ovin.event, ovout.event, overr.event]
+ res = _winapi.WaitForMultipleObjects(events, True, 2000)
+ self.assertEqual(res, _winapi.WAIT_OBJECT_0)
+ self.assertFalse(ovout.pending)
+ self.assertFalse(overr.pending)
+ self.assertFalse(ovin.pending)
+
+ self.assertEqual(ovin.getresult(), len(msg))
+ out = ovout.getresult().rstrip()
+ err = overr.getresult().rstrip()
+
+ self.assertGreater(len(out), 0)
+ self.assertGreater(len(err), 0)
+ # allow for partial reads...
+ self.assertTrue(msg.upper().rstrip().startswith(out))
+ self.assertTrue(b"stderr".startswith(err))
diff --git a/Lib/test/test_asyncio/tests.txt b/Lib/test/test_asyncio/tests.txt
new file mode 100644
index 0000000..e947721
--- /dev/null
+++ b/Lib/test/test_asyncio/tests.txt
@@ -0,0 +1,14 @@
+test_asyncio.test_base_events
+test_asyncio.test_events
+test_asyncio.test_futures
+test_asyncio.test_locks
+test_asyncio.test_proactor_events
+test_asyncio.test_queues
+test_asyncio.test_selector_events
+test_asyncio.test_selectors
+test_asyncio.test_streams
+test_asyncio.test_tasks
+test_asyncio.test_transports
+test_asyncio.test_unix_events
+test_asyncio.test_windows_events
+test_asyncio.test_windows_utils