diff options
Diffstat (limited to 'Lib')
40 files changed, 14572 insertions, 0 deletions
diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py new file mode 100644 index 0000000..afc444d --- /dev/null +++ b/Lib/asyncio/__init__.py @@ -0,0 +1,33 @@ +"""The asyncio package, tracking PEP 3156.""" + +import sys + +# The selectors module is in the stdlib in Python 3.4 but not in 3.3. +# Do this first, so the other submodules can use "from . import selectors". +try: + import selectors # Will also be exported. +except ImportError: + from . import selectors + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py new file mode 100644 index 0000000..32457eb --- /dev/null +++ b/Lib/asyncio/base_events.py @@ -0,0 +1,606 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import subprocess +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import asyncio_log + + +__all__ = ['BaseEventLoop', 'Server'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class Server(events.AbstractServer): + + def __init__(self, loop, sockets): + self.loop = loop + self.sockets = sockets + self.active_count = 0 + self.waiters = [] + + def attach(self, transport): + assert self.sockets is not None + self.active_count += 1 + + def detach(self, transport): + assert self.active_count > 0 + self.active_count -= 1 + if self.active_count == 0 and self.sockets is None: + self._wakeup() + + def close(self): + sockets = self.sockets + if sockets is not None: + self.sockets = None + for sock in sockets: + self.loop._stop_serving(sock) + if self.active_count == 0: + self._wakeup() + + def _wakeup(self): + waiters = self.waiters + self.waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + @tasks.coroutine + def wait_closed(self): + if self.sockets is None or self.waiters is None: + return + waiter = futures.Future(loop=self.loop) + self.waiters.append(waiter) + yield from waiter + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + future = tasks.async(future, loop=self) + future.add_done_callback(_raise_stop_error) + self.run_forever() + future.remove_done_callback(_raise_stop_error) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + return self.call_at(self.time() + delay, callback, *args) + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, + server_side=False, server_hostname=host) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport(sock, protocol, r_addr) + return transport, protocol + + @tasks.coroutine + def create_server(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + server = Server(self, sockets) + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter) + yield from waiter + return transport, protocol + + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, **kwargs) + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + heapq.heappop(self._scheduled) + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + deadline = max(0, when - self.time()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + # TODO: Instrumentation only in debug mode? + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + asyncio_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = self.time() + while self._scheduled: + handle = self._scheduled[0] + if handle._when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle._cancelled: + handle._run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py new file mode 100644 index 0000000..79c3b93 --- /dev/null +++ b/Lib/asyncio/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py new file mode 100644 index 0000000..9724615 --- /dev/null +++ b/Lib/asyncio/events.py @@ -0,0 +1,395 @@ +"""Event loop and event loop policy.""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import subprocess +import sys +import threading +import socket + +from .log import asyncio_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '<cancelled>' + return res + + def cancel(self): + self._cancelled = True + + def _run(self): + try: + self._callback(*self._args) + except Exception: + asyncio_log.exception('Exception in callback %s %r', + self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. + + +def make_handle(callback, args): + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '<cancelled>' + + return res + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractServer: + """Abstract server returned by create_service().""" + + def close(self): + """Stop serving. This leaves existing connections open.""" + return NotImplemented + + def wait_closed(self): + """Coroutine to wait until service is closed.""" + return NotImplemented + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + raise NotImplementedError + + def create_server(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. + + The return value is a Server object which can be used to stop + the service. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop = None + _set_called = False + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._loop is None and + not self._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + self._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """XXX""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py new file mode 100644 index 0000000..99a043b --- /dev/null +++ b/Lib/asyncio/futures.py @@ -0,0 +1,338 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import traceback + +from . import events +from .log import asyncio_log + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ['exc', 'tb'] + + def __init__(self, exc): + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + asyncio_log.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + + _blocking = False # proper use of future (yield vs yield from) + + _tb_logger = None + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += '<exception={!r}>'.format(self._exception) + else: + res += '<result={!r}>'.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._exception = exception + self._tb_logger = _TracebackLogger(exception) + self._state = _FINISHED + self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py new file mode 100644 index 0000000..06edbbc --- /dev/null +++ b/Lib/asyncio/locks.py @@ -0,0 +1,401 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore'] + +import collections + +from . import events +from . import futures +from . import tasks + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class Event: + """An Event implementation, our equivalent to threading.Event. + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self, *, loop=None): + super().__init__(loop=loop) + self._condition_waiters = collections.deque() + + # TODO: Add __repr__() with len(_condition_waiters). + + @tasks.coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + keep_lock = True + self.release() + try: + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) + + except GeneratorExit: + keep_lock = False # Prevent yield in finally clause. + raise + finally: + if keep_lock: + yield from self.acquire() + + @tasks.coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/Lib/asyncio/log.py b/Lib/asyncio/log.py new file mode 100644 index 0000000..54dc784 --- /dev/null +++ b/Lib/asyncio/log.py @@ -0,0 +1,6 @@ +"""Logging configuration.""" + +import logging + + +asyncio_log = logging.getLogger("asyncio") diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py new file mode 100644 index 0000000..348de03 --- /dev/null +++ b/Lib/asyncio/proactor_events.py @@ -0,0 +1,352 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import asyncio_log + + +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(extra) + self._set_extra(sock) + self._loop = loop + self._sock = sock + self._protocol = protocol + self._server = server + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._eof_written = False + if self._server is not None: + self._server.attach(self) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + asyncio_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + server = self._server + if server is not None: + server.detach(self) + self._server = None + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) + self._read_fut = None + self._paused = False + self._loop.call_soon(self._loop_reading) + + def pause(self): + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.call_soon(self._loop_reading, self._read_fut) + + def _loop_reading(self, fut=None): + if self._paused: + return + data = None + + try: + if fut is not None: + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if self._eof_written: + raise IOError('write_eof() already called') + + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if self._write_fut is None: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + self._write_fut = None + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + if self._eof_written: + self._sock.shutdown(socket.SHUT_WR) + return + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + + def can_write_eof(self): + return True + + def write_eof(self): + self.close() + + def abort(self): + self._force_close(None) + + +class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for duplex pipes.""" + + def can_write_eof(self): + return False + + def write_eof(self): + raise NotImplementedError + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + try: + self._extra['sockname'] = sock.getsockname() + except (socket.error, AttributeError): + pass + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except (socket.error, AttributeError): + pass + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing or self._eof_written: + return + self._eof_written = True + if self._write_fut is None: + self._sock.shutdown(socket.SHUT_WR) + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + asyncio_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, + extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + assert not ssl, 'IocpEventLoop is incompatible with SSL.' + + def loop(f=None): + try: + if f is not None: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) + f = self._proactor.accept(sock) + except OSError: + if sock.fileno() != -1: + asyncio_log.exception('Accept failed') + sock.close() + except futures.CancelledError: + sock.close() + else: + f.add_done_callback(loop) + + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def _stop_serving(self, sock): + self._proactor._stop_serving(sock) + sock.close() diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py new file mode 100644 index 0000000..a94abbe --- /dev/null +++ b/Lib/asyncio/protocols.py @@ -0,0 +1,98 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py new file mode 100644 index 0000000..536de1c --- /dev/null +++ b/Lib/asyncio/queues.py @@ -0,0 +1,284 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] + +import collections +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + If you yield from put(), wait until a free slot is available + before adding item. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise Full. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise Full + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If you yield from get(), wait until a item is available. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py new file mode 100644 index 0000000..bae9a49 --- /dev/null +++ b/Lib/asyncio/selector_events.py @@ -0,0 +1,769 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import asyncio_log + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + asyncio_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, + server_side, server_hostname, extra, server) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl, server) + + def _accept_connection(self, protocol_factory, sock, ssl=None, + server=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + asyncio_log.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'peername': addr}, + server=server) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future(loop=self) + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed') + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def _stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports.Transport): + + max_size = 256 * 1024 # Buffer size passed to recv(). + + def __init__(self, loop, sock, protocol, extra, server=None): + super().__init__(extra) + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None + self._loop = loop + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._server = server + self._buffer = collections.deque() + self._conn_lost = 0 + self._closing = False # Set when close() called. + if server is not None: + server.attach(self) + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + asyncio_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._buffer: + self._buffer.clear() + self._loop.remove_writer(self._sock_fd) + + if self._closing: + return + + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + server = self._server + if server is not None: + server.detach(self) + self._server = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, extra, server) + self._eof = False + self._paused = False + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def pause(self): + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + + def _read_ready(self): + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + + def write(self, data): + assert isinstance(data, bytes), repr(type(data)) + assert not self._eof, 'Cannot call write() after write_eof()' + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except (BrokenPipeError, ConnectionResetError) as exc: + self._force_close(exc) + return + except OSError as exc: + self._fatal_error(exc) + return + else: + data = data[n:] + if not data: + return + # Start async I/O. + self._loop.add_writer(self._sock_fd, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except (BrokenPipeError, ConnectionResetError) as exc: + self._loop.remove_writer(self._sock_fd) + self._force_close(exc) + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._fatal_error(exc) + else: + data = data[n:] + if not data: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) + return + + self._buffer.append(data) # Try again later. + + def write_eof(self): + if self._eof: + return + self._eof = True + if not self._buffer: + self._sock.shutdown(socket.SHUT_WR) + + def can_write_eof(self): + return True + + +class _SelectorSslTransport(_SelectorTransport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, server_hostname=None, + extra=None, server=None): + if server_side: + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + # The default is the same as used by urllib. + if sslcontext is None: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + wrap_kwargs = { + 'server_side': server_side, + 'do_handshake_on_connect': False, + } + if server_hostname is not None and not server_side and ssl.HAS_SNI: + wrap_kwargs['server_hostname'] = server_hostname + sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) + + super().__init__(loop, sslsock, protocol, extra, server) + + self._server_hostname = server_hostname + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext + self._paused = False + + # SSL-specific extra info. (peercert is set later) + self._extra.update(sslcontext=sslcontext) + + self._on_handshake() + + def _on_handshake(self): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, self._on_handshake) + return + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + + # Verify hostname if requested. + peercert = self._sock.getpeercert() + if (self._server_hostname is not None and + self._sslcontext.verify_mode == ssl.CERT_REQUIRED): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=self._sock.cipher(), + compression=self._sock.compression(), + ) + + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def pause(self): + # XXX This is a bit icky, given the comment at the top of + # _on_ready(). Is it possible to evoke a deadlock? I don't + # know, although it doesn't look like it; write() will still + # accept more data for the buffer and eventually the app will + # call resume() again, and things will flow again. + + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + if not self._closing and not self._paused: + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + # Now try writing, if there's anything to write. + if self._buffer: + data = b''.join(self._buffer) + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except (BrokenPipeError, ConnectionResetError) as exc: + self._loop.remove_writer(self._sock_fd) + self._force_close(exc) + return + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) + self._call_connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(type(data)) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def can_write_eof(self): + return False + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + + +class _SelectorDatagramTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, address=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._address = address + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes), repr(type(data)) + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def _force_close(self, exc): + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + + super()._force_close(exc) diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py new file mode 100644 index 0000000..d0f12e8 --- /dev/null +++ b/Lib/asyncio/streams.py @@ -0,0 +1,257 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] + +import collections + +from . import events +from . import futures +from . import protocols +from . import tasks + + +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_made(self, transport): + self.stream_reader.set_transport(transport) + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + self.limit = limit + if loop is None: + loop = events.get_event_loop() + self.loop = loop + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + self._transport = None + self._paused = False + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and self.byte_count <= self.limit: + self._paused = False + self._transport.resume() + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_result(False) + + if (self._transport is not None and + not self._paused and + self.byte_count > 2*self.limit): + try: + self._transport.pause() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + self._maybe_resume_transport() + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + line = b''.join(parts) + self.byte_count -= parts_size + self._maybe_resume_transport() + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self._maybe_resume_transport() + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + self._maybe_resume_transport() + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + return (yield from self.read(n)) diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py new file mode 100644 index 0000000..2c8579f --- /dev/null +++ b/Lib/asyncio/tasks.py @@ -0,0 +1,636 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', + ] + +import collections +import concurrent.futures +import functools +import inspect +import linecache +import traceback +import weakref + +from . import events +from . import futures +from .log import asyncio_log + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = False + + +class CoroWrapper: + """Wrapper for coroutine in _DEBUG mode.""" + + __slot__ = ['gen', 'func'] + + def __init__(self, gen, func): + assert inspect.isgenerator(gen), gen + self.gen = gen + self.func = func + + def __iter__(self): + return self + + def __next__(self): + return next(self.gen) + + def send(self, value): + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + def __del__(self): + frame = self.gen.gi_frame + if frame is not None and frame.f_lasti == -1: + func = self.func + code = func.__code__ + filename = code.co_filename + lineno = code.co_firstlineno + asyncio_log.error('Coroutine %r defined at %s:%s was never yielded from', + func.__name__, filename, lineno) + + +def coroutine(func): + """Decorator to mark coroutines. + + If the coroutine is not yielded from before it is destroyed, + an error message is logged. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + if not _DEBUG: + wrapper = coro + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func) + w.__name__ = coro.__name__ + w.__doc__ = coro.__doc__ + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper + + +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return getattr(func, '_is_coroutine', False) + + +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + # An important invariant maintained while a Task not done: + # + # - Either _fut_waiter is None, and _step() is scheduled; + # - or _fut_waiter is some Future, and _step() is *not* scheduled. + # + # The only transition from the latter to the former is through + # _wakeup(). When _fut_waiter is not None, one of its callbacks + # must be _wakeup(). + + # Weak set containing all tasks alive. + _all_tasks = weakref.WeakSet() + + @classmethod + def all_tasks(cls, loop=None): + """Return a set of all tasks for an event loop. + + By default all tasks for the current event loop are returned. + """ + if loop is None: + loop = events.get_event_loop() + return {t for t in cls._all_tasks if t._loop is loop} + + def __init__(self, coro, *, loop=None): + assert iscoroutine(coro), repr(coro) # Not a coroutine function! + super().__init__(loop=loop) + self._coro = iter(coro) # Use the iterator just in case. + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + self.__class__._all_tasks.add(self) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + '<PENDING' in res): + res = res.replace('<PENDING', '<CANCELLING', 1) + i = res.find('<') + if i < 0: + i = len(res) + res = res[:i] + '(<{}>)'.format(self._coro.__name__) + res[i:] + return res + + def get_stack(self, *, limit=None): + """Return the list of stack frames for this task's coroutine. + + If the coroutine is active, this returns the stack where it is + suspended. If the coroutine has completed successfully or was + cancelled, this returns an empty list. If the coroutine was + terminated by an exception, this returns the list of traceback + frames. + + The frames are always ordered from oldest to newest. + + The optional limit gives the maximum nummber of frames to + return; by default all available frames are returned. Its + meaning differs depending on whether a stack or a traceback is + returned: the newest frames of a stack are returned, but the + oldest frames of a traceback are returned. (This matches the + behavior of the traceback module.) + + For reasons beyond our control, only one stack frame is + returned for a suspended coroutine. + """ + frames = [] + f = self._coro.gi_frame + if f is not None: + while f is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(f) + f = f.f_back + frames.reverse() + elif self._exception is not None: + tb = self._exception.__traceback__ + while tb is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(tb.tb_frame) + tb = tb.tb_next + return frames + + def print_stack(self, *, limit=None, file=None): + """Print the stack or traceback for this task's coroutine. + + This produces output similar to that of the traceback module, + for the frames retrieved by get_stack(). The limit argument + is passed to get_stack(). The file argument is an I/O stream + to which the output goes; by default it goes to sys.stderr. + """ + extracted_list = [] + checked = set() + for f in self.get_stack(limit=limit): + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + if filename not in checked: + checked.add(filename) + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, f.f_globals) + extracted_list.append((filename, lineno, name, line)) + exc = self._exception + if not extracted_list: + print('No stack for %r' % self, file=file) + elif exc is not None: + print('Traceback for %r (most recent call last):' % self, + file=file) + else: + print('Stack for %r (most recent call last):' % self, + file=file) + traceback.print_list(extracted_list, file=file) + if exc is not None: + for line in traceback.format_exception_only(exc.__class__, exc): + print(line, file=file, end='') + + def cancel(self): + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False + coro = self._coro + self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + self = None + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from asyncio.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + + if loop is None: + loop = events.get_event_loop() + + fs = set(async(f, loop=loop) for f in fs) + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) + + +def _release_waiter(waiter, value=True, *args): + if not waiter.done(): + waiter.set_result(value) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from asyncio.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + waiter = futures.Future(loop=loop) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False) + cb = functools.partial(_release_waiter, waiter, True) + + fut = async(fut, loop=loop) + fut.add_done_callback(cb) + + try: + if (yield from waiter): + return fut.result() + else: + fut.remove_done_callback(cb) + raise futures.TimeoutError() + finally: + timeout_handle.cancel() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait() and _wait_for(). + + The fs argument must be a collection of Futures. + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + if not waiter.done(): + waiter.set_result(False) + + for f in fs: + f.add_done_callback(_on_completion) + + try: + yield from waiter + finally: + if timeout_handle is not None: + timeout_handle.cancel() + + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + loop = loop if loop is not None else events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop) + else: + raise TypeError('A Future or coroutine is required') + + +class _GatheringFuture(futures.Future): + """Helper for gather(). + + This overrides cancel() to cancel all the children and act more + like Task.cancel(), which doesn't immediately mark itself as + cancelled. + """ + + def __init__(self, children, *, loop=None): + super().__init__(loop=loop) + self._children = children + + def cancel(self): + if self.done(): + return False + for child in self._children: + child.cancel() + return True + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks are + done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily + the order of results arrival). If *result_exception* is True, + exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first + raised exception will be immediately propagated to the returned + future. + + Cancellation: if the outer Future is cancelled, all children (that + have not completed yet) are also cancelled. If any child is + cancelled, this is treated as if it raised CancelledError -- + the outer Future is *not* cancelled in this case. (This is to + prevent the cancellation of one child to cause other children to + be cancelled.) + """ + children = [async(fut, loop=loop) for fut in coros_or_futures] + n = len(children) + if n == 0: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + if loop is None: + loop = children[0]._loop + for fut in children: + if fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + outer = _GatheringFuture(children, loop=loop) + nfinished = 0 + results = [None] * n + + def _done_callback(i, fut): + nonlocal nfinished + if outer._state != futures._PENDING: + if fut._exception is not None: + # Mark exception retrieved. + fut.exception() + return + if fut._state == futures._CANCELLED: + res = futures.CancelledError() + if not return_exceptions: + outer.set_exception(res) + return + elif fut._exception is not None: + res = fut.exception() # Mark exception retrieved. + if not return_exceptions: + outer.set_exception(res) + return + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == n: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer + + +def shield(arg, *, loop=None): + """Wait for a future, shielding it from cancellation. + + The statement + + res = yield from shield(something()) + + is exactly equivalent to the statement + + res = yield from something() + + *except* that if the coroutine containing it is cancelled, the + task running in something() is not cancelled. From the POV of + something(), the cancellation did not happen. But its caller is + still cancelled, so the yield-from expression still raises + CancelledError. Note: If something() is cancelled by other means + this will still cancel shield(). + + If you want to completely ignore cancellation (not recommended) + you can combine shield() with a try/except clause, as follows: + + try: + res = yield from shield(something()) + except CancelledError: + res = None + """ + inner = async(arg, loop=loop) + if inner.done(): + # Shortcut. + return inner + loop = inner._loop + outer = futures.Future(loop=loop) + + def _done_callback(inner): + if outer.cancelled(): + # Mark inner's result as retrieved. + inner.cancelled() or inner.exception() + return + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + inner.add_done_callback(_done_callback) + return outer diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py new file mode 100644 index 0000000..91bbedb --- /dev/null +++ b/Lib/asyncio/test_utils.py @@ -0,0 +1,246 @@ +"""Utilities shared by tests.""" + +import collections +import contextlib +import io +import unittest.mock +import os +import sys +import threading +import unittest +import unittest.mock +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import tasks +from . import base_events +from . import events +from . import selectors + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def dummy_ssl_context(): + if ssl is None: + return None + else: + return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + +def run_briefly(loop): + @tasks.coroutine + def once(): + pass + gen = once() + t = tasks.Task(gen, loop=loop) + try: + loop.run_until_complete(t) + finally: + gen.close() + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the sample key and certificate files) differs + # between the stdlib and stand-alone Tulip/asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '200 OK' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server(host, port, app, + server_class, SilentWSGIRequestHandler) + httpd.address = httpd.server_address + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + def gen(): + yield + self._check_on_close = False + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py new file mode 100644 index 0000000..bf3adee --- /dev/null +++ b/Lib/asyncio/transports.py @@ -0,0 +1,186 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.create_server().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py new file mode 100644 index 0000000..a3a8e11 --- /dev/null +++ b/Lib/asyncio/unix_events.py @@ -0,0 +1,541 @@ +"""Selector eventloop for Unix with signal handling.""" + +import collections +import errno +import fcntl +import functools +import os +import signal +import socket +import stat +import subprocess +import sys + + +from . import constants +from . import events +from . import protocols +from . import selector_events +from . import tasks +from . import transports +from .log import asyncio_log + + +__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] + +STDIN = 0 +STDOUT = 1 +STDERR = 2 + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + self._subprocesses = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + asyncio_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + asyncio_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + def _sig_chld(self): + try: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + return + if pid == 0: + self.call_soon(self._sig_chld) + return + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + self.call_soon(self._sig_chld) + return + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + asyncio_log.exception('Unknown exception in SIGCHLD handler') + + def _subprocess_closed(self, transport): + pid = transport.get_pid() + self._subprocesses.pop(pid, None) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause(self): + self._loop.remove_reader(self._fileno) + + def resume(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + asyncio_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + self._close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost or self._closing: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + if self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + # TODO: Make the relationships between write_eof(), close(), + # abort(), _fatal_error() and _close() more straightforward. + + def write_eof(self): + if self._closing: + return + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + asyncio_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[STDIN] = None + if stdout == subprocess.PIPE: + self._pipes[STDOUT] = None + if stderr == subprocess.PIPE: + self._pipes[STDERR] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial( + _UnixWriteSubprocessPipeProto, self, STDIN), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial( + _UnixReadSubprocessPipeProto, self, STDOUT), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial( + _UnixReadSubprocessPipeProto, self, STDERR), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py new file mode 100644 index 0000000..1d0ad26 --- /dev/null +++ b/Lib/asyncio/windows_events.py @@ -0,0 +1,375 @@ +"""Selector and proactor eventloops for Windows.""" + +import errno +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import tasks +from . import windows_utils +from .log import asyncio_log + +try: + import _overlapped +except ImportError: + from . import _overlapped + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + +class PipeServer(object): + """Class representing a pipe server. + + This is much like a bound, listening socket. + """ + def __init__(self, address): + self._address = address + self._free_instances = weakref.WeakSet() + self._pipe = self._server_pipe_handle(True) + + def _get_unconnected_pipe(self): + # Create new instance and return previous one. This ensures + # that (until the server is closed) there is always at least + # one pipe handle for address. Therefore if a client attempt + # to connect it will not fail with FileNotFoundError. + tmp, self._pipe = self._pipe, self._server_pipe_handle(False) + return tmp + + def _server_pipe_handle(self, first): + # Return a wrapper for a new pipe handle. + if self._address is None: + return None + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + h = _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, + windows_utils.BUFSIZE, windows_utils.BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + pipe = windows_utils.PipeHandle(h) + self._free_instances.add(pipe) + return pipe + + def close(self): + # Close all instances which have not been connected to by a client. + if self._address is not None: + for pipe in self._free_instances: + pipe.close() + self._pipe = None + self._address = None + self._free_instances.clear() + + __del__ = close + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Windows version of selector event loop.""" + + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + """Windows version of proactor event loop using IOCP.""" + + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + @tasks.coroutine + def create_pipe_connection(self, protocol_factory, address): + f = self._proactor.connect_pipe(address) + pipe = yield from f + protocol = protocol_factory() + trans = self._make_duplex_pipe_transport(pipe, protocol, + extra={'addr': address}) + return trans, protocol + + @tasks.coroutine + def start_serving_pipe(self, protocol_factory, address): + server = PipeServer(address) + def loop(f=None): + pipe = None + try: + if f: + pipe = f.result() + server._free_instances.discard(pipe) + protocol = protocol_factory() + self._make_duplex_pipe_transport( + pipe, protocol, extra={'addr': address}) + pipe = server._get_unconnected_pipe() + if pipe is None: + return + f = self._proactor.accept_pipe(pipe) + except OSError: + if pipe and pipe.fileno() != -1: + asyncio_log.exception('Pipe accept failed') + pipe.close() + except futures.CancelledError: + if pipe: + pipe.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + return [server] + + def _stop_serving(self, server): + server.close() + + +class IocpProactor: + """Proactor implementation using IOCP.""" + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(conn.fileno(), nbytes) + def finish(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + return self._register(ov, conn, finish) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(conn.fileno(), buf) + def finish(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + return self._register(ov, conn, finish) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket(listener.family) + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + # The socket needs to be locally bound before we call ConnectEx(). + try: + _overlapped.BindLocal(conn.fileno(), conn.family) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # Probably already locally bound; check using getsockname(). + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + def finish_connect(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) + return conn + return self._register(ov, conn, finish_connect) + + def accept_pipe(self, pipe): + self._register_with_iocp(pipe) + ov = _overlapped.Overlapped(NULL) + ov.ConnectNamedPipe(pipe.fileno()) + def finish(trans, key, ov): + ov.getresult() + return pipe + return self._register(ov, pipe, finish) + + def connect_pipe(self, address): + ov = _overlapped.Overlapped(NULL) + ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) + def finish(err, handle, ov): + # err, handle were arguments passed to PostQueuedCompletionStatus() + # in a function run in a thread pool. + if err == _overlapped.ERROR_SEM_TIMEOUT: + # Connection did not succeed within time limit. + msg = _overlapped.FormatMessage(err) + raise ConnectionRefusedError(0, msg, None, err) + elif err != 0: + msg = _overlapped.FormatMessage(err) + raise OSError(0, msg, None, err) + else: + return windows_utils.PipeHandle(handle) + return self._register(ov, None, finish, wait_for_post=True) + + def _register_with_iocp(self, obj): + # To get notifications of finished ops on this objects sent to the + # completion port, were must register the handle. + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + # XXX We could also use SetFileCompletionNotificationModes() + # to avoid sending notifications to completion port of ops + # that succeed immediately. + + def _register(self, ov, obj, callback, wait_for_post=False): + # Return a future which will be set with the result of the + # operation when it completes. The future's value is actually + # the value returned by callback(). + f = _OverlappedFuture(ov, loop=self._loop) + if ov.pending or wait_for_post: + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) + else: + # The operation has completed, so no need to postpone the + # work. We cannot take this short cut if we need the + # NumberOfBytes, CompletionKey values returned by + # PostQueuedCompletionStatus(). + try: + value = callback(None, None, ov) + except OSError as e: + f.set_exception(e) + else: + f.set_result(value) + return f + + def _get_accept_socket(self, family): + s = socket.socket(family) + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + err, transferred, key, address = status + try: + f, ov, obj, callback = self._cache.pop(address) + except KeyError: + # key is either zero, or it is used to return a pipe + # handle which should be closed to avoid a leak. + if key not in (0, _overlapped.INVALID_HANDLE_VALUE): + _winapi.CloseHandle(key) + ms = 0 + continue + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): + try: + value = callback(transferred, key, ov) + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def _stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop._stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + # Cancel remaining registered operations. + for address, (f, ov, obj, callback) in list(self._cache.items()): + if obj is None: + # The operation was started with connect_pipe() which + # queues a task to Windows' thread pool. This cannot + # be cancelled, so just forget it. + del self._cache[address] + else: + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + asyncio_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/Lib/asyncio/windows_utils.py b/Lib/asyncio/windows_utils.py new file mode 100644 index 0000000..04b43e9 --- /dev/null +++ b/Lib/asyncio/windows_utils.py @@ -0,0 +1,181 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter = itertools.count() + +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle != -1: + CloseHandle(self._handle) + self._handle = -1 + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py new file mode 100644 index 0000000..ec483ea --- /dev/null +++ b/Lib/test/test_asyncio/__init__.py @@ -0,0 +1,26 @@ +import os +import sys +import unittest +from test.support import run_unittest + + +def suite(): + tests_file = os.path.join(os.path.dirname(__file__), 'tests.txt') + with open(tests_file) as fp: + test_names = fp.read().splitlines() + tests = unittest.TestSuite() + loader = unittest.TestLoader() + for test_name in test_names: + mod_name = 'test.' + test_name + try: + __import__(mod_name) + except unittest.SkipTest: + pass + else: + mod = sys.modules[mod_name] + tests.addTests(loader.loadTestsFromModule(mod)) + return tests + + +def test_main(): + run_unittest(suite()) diff --git a/Lib/test/test_asyncio/__main__.py b/Lib/test/test_asyncio/__main__.py new file mode 100644 index 0000000..b549492 --- /dev/null +++ b/Lib/test/test_asyncio/__main__.py @@ -0,0 +1,5 @@ +from . import test_main + + +if __name__ == '__main__': + test_main() diff --git a/Lib/test/test_asyncio/echo.py b/Lib/test/test_asyncio/echo.py new file mode 100644 index 0000000..f6ac0a3 --- /dev/null +++ b/Lib/test/test_asyncio/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/Lib/test/test_asyncio/echo2.py b/Lib/test/test_asyncio/echo2.py new file mode 100644 index 0000000..e83ca09 --- /dev/null +++ b/Lib/test/test_asyncio/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/Lib/test/test_asyncio/echo3.py b/Lib/test/test_asyncio/echo3.py new file mode 100644 index 0000000..f1f7ea7 --- /dev/null +++ b/Lib/test/test_asyncio/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/Lib/test/test_asyncio/sample.crt b/Lib/test/test_asyncio/sample.crt new file mode 100644 index 0000000..6a1e3f3 --- /dev/null +++ b/Lib/test/test_asyncio/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- diff --git a/Lib/test/test_asyncio/sample.key b/Lib/test/test_asyncio/sample.key new file mode 100644 index 0000000..edfea8d --- /dev/null +++ b/Lib/test/test_asyncio/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py new file mode 100644 index 0000000..d48d12c --- /dev/null +++ b/Lib/test/test_asyncio/test_base_events.py @@ -0,0 +1,590 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from asyncio import base_events +from asyncio import events +from asyncio import futures +from asyncio import protocols +from asyncio import tasks +from asyncio import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('asyncio.base_events.time') + @unittest.mock.patch('asyncio.base_events.asyncio_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py new file mode 100644 index 0000000..243f400 --- /dev/null +++ b/Lib/test/test_asyncio/test_events.py @@ -0,0 +1,1573 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from asyncio import futures +from asyncio import events +from asyncio import transports +from asyncio import protocols +from asyncio import selector_events +from asyncio import tasks +from asyncio import test_utils +from asyncio import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.Event(loop=loop), + 2: locks.Event(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, + ssl=test_utils.dummy_ssl_context()) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_create_server(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.create_server(factory, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.create_server( + factory, '127.0.0.1', 0, ssl=sslcontext) + + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, + ssl=test_utils.dummy_ssl_context()) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + def test_create_server_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + server.close() + + def test_create_server_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + f = self.loop.create_server(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + server.close() + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_create_server_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + server.close() + + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + server.close() + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('sockname') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + for _ in range(1000): + if server.nbytes: + break + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + for _ in range(1000): + if client.nbytes: + break + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('sockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + @tasks.coroutine + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop._stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from asyncio import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from asyncio import selectors + from asyncio import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '<function HandleTests.test_handle.<locals>.callback')) + self.assertTrue(r.endswith('())<cancelled>'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('asyncio.events.asyncio_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())<cancelled>'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.create_server, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('asyncio.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_futures.py b/Lib/test/test_asyncio/test_futures.py new file mode 100644 index 0000000..9b5108c --- /dev/null +++ b/Lib/test/test_asyncio/test_futures.py @@ -0,0 +1,329 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future<PENDING>') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future<CANCELLED>') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future<result=4>') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future<exception=RuntimeError()>') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future<PENDING, [<function _fakefunc', + repr(f_few_callbacks)) + f_few_callbacks.cancel() + + f_many_callbacks = futures.Future(loop=self.loop) + for i in range(20): + f_many_callbacks.add_done_callback(_fakefunc) + r = repr(f_many_callbacks) + self.assertIn('Future<PENDING, [<function _fakefunc', r) + self.assertIn('<18 more>', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('asyncio.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=self.loop) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py new file mode 100644 index 0000000..31b4d64 --- /dev/null +++ b/Lib/test/test_asyncio/test_locks.py @@ -0,0 +1,765 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import locks +from asyncio import tasks +from asyncio import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.Event(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.Event(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.Event() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.Event(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.Event(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = locks.Event(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.Event(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py new file mode 100644 index 0000000..c52ade0 --- /dev/null +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -0,0 +1,480 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import asyncio +from asyncio.proactor_events import BaseProactorEventLoop +from asyncio.proactor_events import _ProactorSocketTransport +from asyncio.proactor_events import _ProactorWritePipeTransport +from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = asyncio.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + tr.close() + + def test_pause_resume(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = asyncio.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data1') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.pause() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.resume() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') + tr.close() + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_create_server(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_create_server_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = asyncio.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop._stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor._stop_serving.assert_called_with(sock) diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py new file mode 100644 index 0000000..8af4ee7 --- /dev/null +++ b/Lib/test/test_asyncio/test_queues.py @@ -0,0 +1,470 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import locks +from asyncio import queues +from asyncio import tasks +from asyncio import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith('<Queue'), fn(q)) + id_is_present = hex(id(q)) in fn(q) + self.assertEqual(expect_id, id_is_present) + + @tasks.coroutine + def add_getter(): + q = queues.Queue(loop=loop) + # Start a task that waits to get. + tasks.Task(q.get(), loop=loop) + # Let it start waiting. + yield from tasks.sleep(0.1, loop=loop) + self.assertTrue('_getters[1]' in fn(q)) + # resume q.get coroutine to finish generator + q.put_nowait(0) + + loop.run_until_complete(add_getter()) + + @tasks.coroutine + def add_putter(): + q = queues.Queue(maxsize=1, loop=loop) + q.put_nowait(1) + # Start a task that waits to put. + tasks.Task(q.put(2), loop=loop) + # Let it start waiting. + yield from tasks.sleep(0.1, loop=loop) + self.assertTrue('_putters[1]' in fn(q)) + # resume q.put coroutine to finish generator + q.get_nowait() + + loop.run_until_complete(add_putter()) + + q = queues.Queue(loop=loop) + q.put_nowait(1) + self.assertTrue('_queue=[1]' in fn(q)) + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + q = queues.Queue(loop=loop) + self.assertIs(q._loop, loop) + + q = queues.Queue(loop=self.loop) + self.assertIs(q._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + q = queues.Queue() + self.assertIs(q._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + self._test_repr_or_str(repr, True) + + def test_str(self): + self._test_repr_or_str(str, False) + + def test_empty(self): + q = queues.Queue(loop=self.loop) + self.assertTrue(q.empty()) + q.put_nowait(1) + self.assertFalse(q.empty()) + self.assertEqual(1, q.get_nowait()) + self.assertTrue(q.empty()) + + def test_full(self): + q = queues.Queue(loop=self.loop) + self.assertFalse(q.full()) + + q = queues.Queue(maxsize=1, loop=self.loop) + q.put_nowait(1) + self.assertTrue(q.full()) + + def test_order(self): + q = queues.Queue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 3, 2], items) + + def test_maxsize(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0.01 + self.assertAlmostEqual(0.02, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(maxsize=2, loop=loop) + self.assertEqual(2, q.maxsize) + have_been_put = [] + + @tasks.coroutine + def putter(): + for i in range(3): + yield from q.put(i) + have_been_put.append(i) + return True + + @tasks.coroutine + def test(): + t = tasks.Task(putter(), loop=loop) + yield from tasks.sleep(0.01, loop=loop) + + # The putter is blocked after putting two items. + self.assertEqual([0, 1], have_been_put) + self.assertEqual(0, q.get_nowait()) + + # Let the putter resume and put last item. + yield from tasks.sleep(0.01, loop=loop) + self.assertEqual([0, 1, 2], have_been_put) + self.assertEqual(1, q.get_nowait()) + self.assertEqual(2, q.get_nowait()) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + loop.run_until_complete(test()) + self.assertAlmostEqual(0.02, loop.time()) + + +class QueueGetTests(_QueueTestBase): + + def test_blocking_get(self): + q = queues.Queue(loop=self.loop) + q.put_nowait(1) + + @tasks.coroutine + def queue_get(): + return (yield from q.get()) + + res = self.loop.run_until_complete(queue_get()) + self.assertEqual(1, res) + + def test_get_with_putters(self): + q = queues.Queue(1, loop=self.loop) + q.put_nowait(1) + + waiter = futures.Future(loop=self.loop) + q._putters.append((2, waiter)) + + res = self.loop.run_until_complete(q.get()) + self.assertEqual(1, res) + self.assertTrue(waiter.done()) + self.assertIsNone(waiter.result()) + + def test_blocking_get_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + started = locks.Event(loop=loop) + finished = False + + @tasks.coroutine + def queue_get(): + nonlocal finished + started.set() + res = yield from q.get() + finished = True + return res + + @tasks.coroutine + def queue_put(): + loop.call_later(0.01, q.put_nowait, 1) + queue_get_task = tasks.Task(queue_get(), loop=loop) + yield from started.wait() + self.assertFalse(finished) + res = yield from queue_get_task + self.assertTrue(finished) + return res + + res = loop.run_until_complete(queue_put()) + self.assertEqual(1, res) + self.assertAlmostEqual(0.01, loop.time()) + + def test_nonblocking_get(self): + q = queues.Queue(loop=self.loop) + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_get_exception(self): + q = queues.Queue(loop=self.loop) + self.assertRaises(queues.Empty, q.get_nowait) + + def test_get_cancelled(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0.01 + self.assertAlmostEqual(0.061, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + + @tasks.coroutine + def queue_get(): + return (yield from tasks.wait_for(q.get(), 0.051, loop=loop)) + + @tasks.coroutine + def test(): + get_task = tasks.Task(queue_get(), loop=loop) + yield from tasks.sleep(0.01, loop=loop) # let the task start + q.put_nowait(1) + return (yield from get_task) + + self.assertEqual(1, loop.run_until_complete(test())) + self.assertAlmostEqual(0.06, loop.time()) + + def test_get_cancelled_race(self): + q = queues.Queue(loop=self.loop) + + t1 = tasks.Task(q.get(), loop=self.loop) + t2 = tasks.Task(q.get(), loop=self.loop) + + test_utils.run_briefly(self.loop) + t1.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t1.done()) + q.put_nowait('a') + test_utils.run_briefly(self.loop) + self.assertEqual(t2.result(), 'a') + + def test_get_with_waiting_putters(self): + q = queues.Queue(loop=self.loop, maxsize=1) + tasks.Task(q.put('a'), loop=self.loop) + tasks.Task(q.put('b'), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(self.loop.run_until_complete(q.get()), 'a') + self.assertEqual(self.loop.run_until_complete(q.get()), 'b') + + +class QueuePutTests(_QueueTestBase): + + def test_blocking_put(self): + q = queues.Queue(loop=self.loop) + + @tasks.coroutine + def queue_put(): + # No maxsize, won't block. + yield from q.put(1) + + self.loop.run_until_complete(queue_put()) + + def test_blocking_put_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(maxsize=1, loop=loop) + started = locks.Event(loop=loop) + finished = False + + @tasks.coroutine + def queue_put(): + nonlocal finished + started.set() + yield from q.put(1) + yield from q.put(2) + finished = True + + @tasks.coroutine + def queue_get(): + loop.call_later(0.01, q.get_nowait) + queue_put_task = tasks.Task(queue_put(), loop=loop) + yield from started.wait() + self.assertFalse(finished) + yield from queue_put_task + self.assertTrue(finished) + + loop.run_until_complete(queue_get()) + self.assertAlmostEqual(0.01, loop.time()) + + def test_nonblocking_put(self): + q = queues.Queue(loop=self.loop) + q.put_nowait(1) + self.assertEqual(1, q.get_nowait()) + + def test_nonblocking_put_exception(self): + q = queues.Queue(maxsize=1, loop=self.loop) + q.put_nowait(1) + self.assertRaises(queues.Full, q.put_nowait, 2) + + def test_put_cancelled(self): + q = queues.Queue(loop=self.loop) + + @tasks.coroutine + def queue_put(): + yield from q.put(1) + return True + + @tasks.coroutine + def test(): + return (yield from q.get()) + + t = tasks.Task(queue_put(), loop=self.loop) + self.assertEqual(1, self.loop.run_until_complete(test())) + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_put_cancelled_race(self): + q = queues.Queue(loop=self.loop, maxsize=1) + + tasks.Task(q.put('a'), loop=self.loop) + tasks.Task(q.put('c'), loop=self.loop) + t = tasks.Task(q.put('b'), loop=self.loop) + + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t.done()) + self.assertEqual(q.get_nowait(), 'a') + self.assertEqual(q.get_nowait(), 'c') + + def test_put_with_waiting_getters(self): + q = queues.Queue(loop=self.loop) + t = tasks.Task(q.get(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.loop.run_until_complete(q.put('a')) + self.assertEqual(self.loop.run_until_complete(t), 'a') + + +class LifoQueueTests(_QueueTestBase): + + def test_order(self): + q = queues.LifoQueue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([2, 3, 1], items) + + +class PriorityQueueTests(_QueueTestBase): + + def test_order(self): + q = queues.PriorityQueue(loop=self.loop) + for i in [1, 3, 2]: + q.put_nowait(i) + + items = [q.get_nowait() for _ in range(3)] + self.assertEqual([1, 2, 3], items) + + +class JoinableQueueTests(_QueueTestBase): + + def test_task_done_underflow(self): + q = queues.JoinableQueue(loop=self.loop) + self.assertRaises(ValueError, q.task_done) + + def test_task_done(self): + q = queues.JoinableQueue(loop=self.loop) + for i in range(100): + q.put_nowait(i) + + accumulator = 0 + + # Two workers get items from the queue and call task_done after each. + # Join the queue and assert all items have been processed. + running = True + + @tasks.coroutine + def worker(): + nonlocal accumulator + + while running: + item = yield from q.get() + accumulator += item + q.task_done() + + @tasks.coroutine + def test(): + for _ in range(2): + tasks.Task(worker(), loop=self.loop) + + yield from q.join() + + self.loop.run_until_complete(test()) + self.assertEqual(sum(range(100)), accumulator) + + # close running generators + running = False + for i in range(2): + q.put_nowait(0) + + def test_join_empty_queue(self): + q = queues.JoinableQueue(loop=self.loop) + + # Test that a queue join()s successfully, and before anything else + # (done twice for insurance). + + @tasks.coroutine + def join(): + yield from q.join() + yield from q.join() + + self.loop.run_until_complete(join()) + + def test_format(self): + q = queues.JoinableQueue(loop=self.loop) + self.assertEqual(q._format(), 'maxsize=0') + + q._unfinished_tasks = 2 + self.assertEqual(q._format(), 'maxsize=0 tasks=2') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py new file mode 100644 index 0000000..0225e13 --- /dev/null +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -0,0 +1,1485 @@ +"""Tests for selector_events.py""" + +import collections +import errno +import gc +import pprint +import socket +import sys +import unittest +import unittest.mock +try: + import ssl +except ImportError: + ssl = None + +from asyncio import futures +from asyncio import selectors +from asyncio import test_utils +from asyncio.protocols import DatagramProtocol, Protocol +from asyncio.selector_events import BaseSelectorEventLoop +from asyncio.selector_events import _SelectorTransport +from asyncio.selector_events import _SelectorSslTransport +from asyncio.selector_events import _SelectorSocketTransport +from asyncio.selector_events import _SelectorDatagramTransport + + +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + self._ssock = unittest.mock.Mock() + self._csock = unittest.mock.Mock() + self._internal_fds += 1 + + +class BaseSelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) + + def test_make_socket_transport(self): + m = unittest.mock.Mock() + self.loop.add_reader = unittest.mock.Mock() + self.assertIsInstance( + self.loop._make_socket_transport(m, m), _SelectorSocketTransport) + + def test_make_ssl_transport(self): + m = unittest.mock.Mock() + self.loop.add_reader = unittest.mock.Mock() + self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_reader = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() + self.assertIsInstance( + self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) + + def test_close(self): + ssock = self.loop._ssock + ssock.fileno.return_value = 7 + csock = self.loop._csock + csock.fileno.return_value = 1 + remove_reader = self.loop.remove_reader = unittest.mock.Mock() + + self.loop._selector.close() + self.loop._selector = selector = unittest.mock.Mock() + self.loop.close() + self.assertIsNone(self.loop._selector) + self.assertIsNone(self.loop._csock) + self.assertIsNone(self.loop._ssock) + selector.close.assert_called_with() + ssock.close.assert_called_with() + csock.close.assert_called_with() + remove_reader.assert_called_with(7) + + self.loop.close() + self.loop.close() + + def test_close_no_selector(self): + ssock = self.loop._ssock + csock = self.loop._csock + remove_reader = self.loop.remove_reader = unittest.mock.Mock() + + self.loop._selector.close() + self.loop._selector = None + self.loop.close() + self.assertIsNone(self.loop._selector) + self.assertFalse(ssock.close.called) + self.assertFalse(csock.close.called) + self.assertFalse(remove_reader.called) + + def test_socketpair(self): + self.assertRaises(NotImplementedError, self.loop._socketpair) + + def test_read_from_self_tryagain(self): + self.loop._ssock.recv.side_effect = BlockingIOError + self.assertIsNone(self.loop._read_from_self()) + + def test_read_from_self_exception(self): + self.loop._ssock.recv.side_effect = OSError + self.assertRaises(OSError, self.loop._read_from_self) + + def test_write_to_self_tryagain(self): + self.loop._csock.send.side_effect = BlockingIOError + self.assertIsNone(self.loop._write_to_self()) + + def test_write_to_self_exception(self): + self.loop._csock.send.side_effect = OSError() + self.assertRaises(OSError, self.loop._write_to_self) + + def test_sock_recv(self): + sock = unittest.mock.Mock() + self.loop._sock_recv = unittest.mock.Mock() + + f = self.loop.sock_recv(sock, 1024) + self.assertIsInstance(f, futures.Future) + self.loop._sock_recv.assert_called_with(f, False, sock, 1024) + + def test__sock_recv_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_recv(f, False, sock, 1024) + self.assertFalse(sock.recv.called) + + def test__sock_recv_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_reader = unittest.mock.Mock() + self.loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) + + def test__sock_recv_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = BlockingIOError + + self.loop.add_reader = unittest.mock.Mock() + self.loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024), + self.loop.add_reader.call_args[0]) + + def test__sock_recv_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.recv.side_effect = OSError() + + self.loop._sock_recv(f, False, sock, 1024) + self.assertIs(err, f.exception()) + + def test_sock_sendall(self): + sock = unittest.mock.Mock() + self.loop._sock_sendall = unittest.mock.Mock() + + f = self.loop.sock_sendall(sock, b'data') + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, b'data'), + self.loop._sock_sendall.call_args[0]) + + def test_sock_sendall_nodata(self): + sock = unittest.mock.Mock() + self.loop._sock_sendall = unittest.mock.Mock() + + f = self.loop.sock_sendall(sock, b'') + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + self.assertFalse(self.loop._sock_sendall.called) + + def test__sock_sendall_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(sock.send.called) + + def test__sock_sendall_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) + + def test__sock_sendall_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = BlockingIOError + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_interrupted(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = InterruptedError + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.send.side_effect = OSError() + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertIs(f.exception(), err) + + def test__sock_sendall(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 4 + + self.loop._sock_sendall(f, False, sock, b'data') + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test__sock_sendall_partial(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 2 + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'ta'), + self.loop.add_writer.call_args[0]) + + def test__sock_sendall_none(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + sock.fileno.return_value = 10 + sock.send.return_value = 0 + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + + def test_sock_connect(self): + sock = unittest.mock.Mock() + self.loop._sock_connect = unittest.mock.Mock() + + f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, ('127.0.0.1', 8080)), + self.loop._sock_connect.call_args[0]) + + def test__sock_connect(self): + f = futures.Future(loop=self.loop) + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + self.assertTrue(sock.connect.called) + + def test__sock_connect_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.connect.called) + + def test__sock_connect_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) + + def test__sock_connect_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.EAGAIN + + self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() + + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual( + (10, self.loop._sock_connect, f, + True, sock, ('127.0.0.1', 8080)), + self.loop.add_writer.call_args[0]) + + def test__sock_connect_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.ENOTCONN + + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f.exception(), OSError) + + def test_sock_accept(self): + sock = unittest.mock.Mock() + self.loop._sock_accept = unittest.mock.Mock() + + f = self.loop.sock_accept(sock) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock), self.loop._sock_accept.call_args[0]) + + def test__sock_accept(self): + f = futures.Future(loop=self.loop) + + conn = unittest.mock.Mock() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.return_value = conn, ('127.0.0.1', 1000) + + self.loop._sock_accept(f, False, sock) + self.assertTrue(f.done()) + self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) + self.assertEqual((False,), conn.setblocking.call_args[0]) + + def test__sock_accept_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop._sock_accept(f, False, sock) + self.assertFalse(sock.accept.called) + + def test__sock_accept_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future(loop=self.loop) + f.cancel() + + self.loop.remove_reader = unittest.mock.Mock() + self.loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) + + def test__sock_accept_tryagain(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = BlockingIOError + + self.loop.add_reader = unittest.mock.Mock() + self.loop._sock_accept(f, False, sock) + self.assertEqual( + (10, self.loop._sock_accept, f, True, sock), + self.loop.add_reader.call_args[0]) + + def test__sock_accept_exception(self): + f = futures.Future(loop=self.loop) + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + err = sock.accept.side_effect = OSError() + + self.loop._sock_accept(f, False, sock) + self.assertIs(err, f.exception()) + + def test_add_reader(self): + self.loop._selector.get_key.side_effect = KeyError + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertIsNone(w) + + def test_add_reader_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (reader, writer)) + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + def test_add_reader_existing_writer(self): + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, writer)) + cb = lambda: True + self.loop.add_reader(1, cb) + + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r._callback) + self.assertEqual(writer, w) + + def test_remove_reader(self): + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (None, None)) + self.assertFalse(self.loop.remove_reader(1)) + + self.assertTrue(self.loop._selector.unregister.called) + + def test_remove_reader_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) + self.assertTrue( + self.loop.remove_reader(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.loop._selector.modify.call_args[0]) + + def test_remove_reader_unknown(self): + self.loop._selector.get_key.side_effect = KeyError + self.assertFalse( + self.loop.remove_reader(1)) + + def test_add_writer(self): + self.loop._selector.get_key.side_effect = KeyError + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE, mask) + self.assertIsNone(r) + self.assertEqual(cb, w._callback) + + def test_add_writer_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, writer)) + cb = lambda: True + self.loop.add_writer(1, cb) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(reader, r) + self.assertEqual(cb, w._callback) + + def test_remove_writer(self): + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, None)) + self.assertFalse(self.loop.remove_writer(1)) + + self.assertTrue(self.loop._selector.unregister.called) + + def test_remove_writer_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) + self.assertTrue( + self.loop.remove_writer(1)) + + self.assertFalse(self.loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.loop._selector.modify.call_args[0]) + + def test_remove_writer_unknown(self): + self.loop._selector.get_key.side_effect = KeyError + self.assertFalse( + self.loop.remove_writer(1)) + + def test_process_events_read(self): + reader = unittest.mock.Mock() + reader._cancelled = False + + self.loop._add_callback = unittest.mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.assertTrue(self.loop._add_callback.called) + self.loop._add_callback.assert_called_with(reader) + + def test_process_events_read_cancelled(self): + reader = unittest.mock.Mock() + reader.cancelled = True + + self.loop.remove_reader = unittest.mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) + self.loop.remove_reader.assert_called_with(1) + + def test_process_events_write(self): + writer = unittest.mock.Mock() + writer._cancelled = False + + self.loop._add_callback = unittest.mock.Mock() + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop._add_callback.assert_called_with(writer) + + def test_process_events_write_cancelled(self): + writer = unittest.mock.Mock() + writer.cancelled = True + self.loop.remove_writer = unittest.mock.Mock() + + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) + self.loop.remove_writer.assert_called_with(1) + + +class SelectorTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + + def test_ctor(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + self.assertIs(tr._loop, self.loop) + self.assertIs(tr._sock, self.sock) + self.assertIs(tr._sock_fd, 7) + + def test_abort(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = unittest.mock.Mock() + + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr.close() + + self.assertTrue(tr._closing) + self.assertEqual(1, self.loop.remove_reader_count[7]) + self.protocol.connection_lost(None) + self.assertEqual(tr._conn_lost, 1) + + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + def test_close_write_buffer(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._buffer.append(b'data') + tr.close() + + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_force_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._buffer.append(b'1') + self.loop.add_reader(7, unittest.mock.sentinel) + self.loop.add_writer(7, unittest.mock.sentinel) + tr._force_close(None) + + self.assertTrue(tr._closing) + self.assertEqual(tr._buffer, collections.deque()) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + + # second close should not remove reader + tr._force_close(None) + self.assertFalse(self.loop.readers) + self.assertEqual(1, self.loop.remove_reader_count[7]) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + def test_fatal_error(self, m_exc): + exc = OSError() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(exc) + + m_exc.assert_called_with('Fatal error for %s', tr) + tr._force_close.assert_called_with(exc) + + def test_connection_lost(self): + exc = OSError() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._call_connection_lost(exc) + + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() + self.assertIsNone(tr._sock) + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class SelectorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) + self.sock = unittest.mock.Mock(socket.socket) + self.sock_fd = self.sock.fileno.return_value = 7 + + def test_ctor(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.loop.assert_reader(7, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + + _SelectorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + def test_pause_resume(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + tr.pause() + self.assertTrue(tr._paused) + self.assertFalse(7 in self.loop.readers) + tr.resume() + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + + def test_read_ready(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.protocol.data_received.assert_called_with(b'data') + + def test_read_ready_eof(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = unittest.mock.Mock() + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + transport.close.assert_called_with() + + def test_read_ready_eof_keep_open(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = unittest.mock.Mock() + + self.sock.recv.return_value = b'' + self.protocol.eof_received.return_value = True + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertFalse(transport.close.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_tryagain(self, m_exc): + self.sock.recv.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_tryagain_interrupted(self, m_exc): + self.sock.recv.side_effect = InterruptedError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_conn_reset(self, m_exc): + err = self.sock.recv.side_effect = ConnectionResetError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._force_close = unittest.mock.Mock() + transport._read_ready() + transport._force_close.assert_called_with(err) + + @unittest.mock.patch('logging.exception') + def test_read_ready_err(self, m_exc): + err = self.sock.recv.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with(err) + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + + def test_write_no_data(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_buffer(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual(collections.deque([b'data1', b'data2']), + transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'ta']), transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + self.sock.fileno.return_value = 7 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + data = b'data' + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_write_exception(self, m_log): + err = self.sock.send.side_effect = OSError() + + data = b'data' + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.write(data) + transport._fatal_error.assert_called_with(err) + transport._conn_lost = 1 + + self.sock.reset_mock() + transport.write(data) + self.assertFalse(self.sock.send.called) + self.assertEqual(transport._conn_lost, 2) + transport.write(data) + transport.write(data) + transport.write(data) + transport.write(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_write_str(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertFalse(self.loop.writers) + + def test_write_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.sock.send.assert_called_with(data) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test_write_ready_no_data(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport._write_ready) + + def test_write_ready_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'ta']), transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_ready_tryagain(self): + self.sock.send.side_effect = BlockingIOError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer = collections.deque([b'data1', b'data2']) + self.loop.add_writer(7, transport._write_ready) + transport._write_ready() + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(collections.deque([b'data1data2']), transport._buffer) + + def test_write_ready_exception(self): + err = self.sock.send.side_effect = OSError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._write_ready() + transport._fatal_error.assert_called_with(err) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_write_ready_exception_and_close(self, m_log): + self.sock.send.side_effect = OSError() + remove_writer = self.loop.remove_writer = unittest.mock.Mock() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close() + transport._buffer.append(b'data') + transport._write_ready() + remove_writer.assert_called_with(self.sock_fd) + + def test_write_eof(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.sock.send.side_effect = BlockingIOError + tr.write(b'data') + tr.write_eof() + self.assertEqual(tr._buffer, collections.deque([b'data'])) + self.assertTrue(tr._eof) + self.assertFalse(self.sock.shutdown.called) + self.sock.send.side_effect = lambda _: 4 + tr._write_ready() + self.sock.send.assert_called_with(b'data') + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorSslTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.sslsock = unittest.mock.Mock() + self.sslsock.fileno.return_value = 1 + self.sslcontext = unittest.mock.Mock() + self.sslcontext.wrap_socket.return_value = self.sslsock + + def _make_one(self, create_waiter=None): + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + self.sock.reset_mock() + self.sslsock.reset_mock() + self.sslcontext.reset_mock() + self.loop.reset_counters() + return transport + + def test_on_handshake(self): + waiter = futures.Future(loop=self.loop) + tr = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, + waiter=waiter) + self.assertTrue(self.sslsock.do_handshake.called) + self.loop.assert_reader(1, tr._on_ready) + self.loop.assert_writer(1, tr._on_ready) + test_utils.run_briefly(self.loop) + self.assertIsNone(waiter.result()) + + def test_on_handshake_reader_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._on_handshake() + self.loop.assert_reader(1, transport._on_handshake) + + def test_on_handshake_writer_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._on_handshake() + self.loop.assert_writer(1, transport._on_handshake) + + def test_on_handshake_exc(self): + exc = ValueError() + self.sslsock.do_handshake.side_effect = exc + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._waiter = futures.Future(loop=self.loop) + transport._on_handshake() + self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) + + def test_on_handshake_base_exc(self): + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + transport._waiter = futures.Future(loop=self.loop) + exc = BaseException() + self.sslsock.do_handshake.side_effect = exc + self.assertRaises(BaseException, transport._on_handshake) + self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) + + def test_pause_resume(self): + tr = self._make_one() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._on_ready) + tr.pause() + self.assertTrue(tr._paused) + self.assertFalse(1 in self.loop.readers) + tr.resume() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._on_ready) + + def test_write_no_data(self): + transport = self._make_one() + transport._buffer.append(b'data') + transport.write(b'') + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_write_str(self): + transport = self._make_one() + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = self._make_one() + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_write_exception(self, m_log): + transport = self._make_one() + transport._conn_lost = 1 + transport.write(b'data') + self.assertEqual(transport._buffer, collections.deque()) + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_on_ready_recv(self): + self.sslsock.recv.return_value = b'data' + transport = self._make_one() + transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) + + def test_on_ready_recv_eof(self): + self.sslsock.recv.return_value = b'' + transport = self._make_one() + transport.close = unittest.mock.Mock() + transport._on_ready() + transport.close.assert_called_with() + self.protocol.eof_received.assert_called_with() + + def test_on_ready_recv_conn_reset(self): + err = self.sslsock.recv.side_effect = ConnectionResetError() + transport = self._make_one() + transport._force_close = unittest.mock.Mock() + transport._on_ready() + transport._force_close.assert_called_with(err) + + def test_on_ready_recv_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + transport = self._make_one() + transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = BlockingIOError + transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = InterruptedError + transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_on_ready_recv_exc(self): + err = self.sslsock.recv.side_effect = OSError() + transport = self._make_one() + transport._fatal_error = unittest.mock.Mock() + transport._on_ready() + transport._fatal_error.assert_called_with(err) + + def test_on_ready_send(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + transport._on_ready() + self.assertEqual(collections.deque(), transport._buffer) + self.assertTrue(self.sslsock.send.called) + + def test_on_ready_send_none(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 0 + transport = self._make_one() + transport._buffer = collections.deque([b'data1', b'data2']) + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(collections.deque([b'data1data2']), transport._buffer) + + def test_on_ready_send_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + transport = self._make_one() + transport._buffer = collections.deque([b'data1', b'data2']) + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(collections.deque([b'ta1data2']), transport._buffer) + + def test_on_ready_send_closing_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + transport = self._make_one() + transport._buffer = collections.deque([b'data1', b'data2']) + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertFalse(self.sslsock.close.called) + + def test_on_ready_send_closing(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport.close() + transport._buffer = collections.deque([b'data']) + transport._on_ready() + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) + + def test_on_ready_send_closing_empty_buffer(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport.close() + transport._buffer = collections.deque() + transport._on_ready() + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) + + def test_on_ready_send_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + + self.sslsock.send.side_effect = ssl.SSLWantReadError + transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual(collections.deque([b'data']), transport._buffer) + + self.sslsock.send.side_effect = ssl.SSLWantWriteError + transport._on_ready() + self.assertEqual(collections.deque([b'data']), transport._buffer) + + self.sslsock.send.side_effect = BlockingIOError() + transport._on_ready() + self.assertEqual(collections.deque([b'data']), transport._buffer) + + def test_on_ready_send_exc(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + err = self.sslsock.send.side_effect = OSError() + + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + transport._fatal_error = unittest.mock.Mock() + transport._on_ready() + transport._fatal_error.assert_called_with(err) + self.assertEqual(collections.deque(), transport._buffer) + + def test_write_eof(self): + tr = self._make_one() + self.assertFalse(tr.can_write_eof()) + self.assertRaises(NotImplementedError, tr.write_eof) + + def test_close(self): + tr = self._make_one() + tr.close() + + self.assertTrue(tr._closing) + self.assertEqual(1, self.loop.remove_reader_count[1]) + self.assertEqual(tr._conn_lost, 1) + + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertEqual(1, self.loop.remove_reader_count[1]) + + @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support') + def test_server_hostname(self): + _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, + server_hostname='localhost') + self.sslcontext.wrap_socket.assert_called_with( + self.sock, do_handshake_on_connect=False, server_side=False, + server_hostname='localhost') + + +class SelectorDatagramTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(DatagramProtocol) + self.sock = unittest.mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + + def test_read_ready(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.protocol.datagram_received.assert_called_with( + b'data', ('0.0.0.0', 1234)) + + def test_read_ready_tryagain(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + self.sock.recvfrom.side_effect = BlockingIOError + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + err = self.sock.recvfrom.side_effect = OSError() + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + transport._fatal_error.assert_called_with(err) + + def test_sendto(self): + data = b'data' + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_no_data(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_tryagain(self): + data = b'data' + + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 12345)) + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_sendto_exception(self, m_log): + data = b'data' + err = self.sock.sendto.side_effect = OSError() + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + transport._fatal_error.assert_called_with(err) + transport._conn_lost = 1 + + transport._address = ('123',) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_sendto_connection_refused(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertEqual(transport._conn_lost, 0) + self.assertFalse(transport._fatal_error.called) + + def test_sendto_connection_refused_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data) + + self.assertTrue(transport._fatal_error.called) + + def test_sendto_str(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.assertRaises( + AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, address=(1,)) + transport.close() + self.assertEqual(transport._conn_lost, 1) + transport.sendto(b'data', (1,)) + self.assertEqual(transport._conn_lost, 2) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((data, ('0.0.0.0', 12345))) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append((data, ())) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.sock.sendto.assert_called_with(data, ()) + self.assertFalse(self.loop.writers) + self.sock.close.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + def test_sendto_ready_no_data(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertFalse(self.loop.writers) + + def test_sendto_ready_tryagain(self): + self.sock.sendto.side_effect = BlockingIOError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + self.loop.add_writer(7, transport._sendto_ready) + transport._sendto_ready() + + self.loop.assert_writer(7, transport._sendto_ready) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + err = self.sock.sendto.side_effect = OSError() + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + transport._fatal_error.assert_called_with(err) + + def test_sendto_ready_connection_refused(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_connection_refused_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + def test_fatal_error_connected(self, m_exc): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.protocol.connection_refused.assert_called_with(err) + m_exc.assert_called_with('Fatal error for %s', transport) diff --git a/Lib/test/test_asyncio/test_selectors.py b/Lib/test/test_asyncio/test_selectors.py new file mode 100644 index 0000000..2f7dc69 --- /dev/null +++ b/Lib/test/test_asyncio/test_selectors.py @@ -0,0 +1,145 @@ +"""Tests for selectors.py.""" + +import unittest +import unittest.mock + +from asyncio import selectors + + +class FakeSelector(selectors.BaseSelector): + """Trivial non-abstract subclass of BaseSelector.""" + + def select(self, timeout=None): + raise NotImplementedError + + +class BaseSelectorTests(unittest.TestCase): + + def test_fileobj_to_fd(self): + self.assertEqual(10, selectors._fileobj_to_fd(10)) + + f = unittest.mock.Mock() + f.fileno.return_value = 10 + self.assertEqual(10, selectors._fileobj_to_fd(f)) + + f.fileno.side_effect = AttributeError + self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + + def test_selector_key_repr(self): + key = selectors.SelectorKey(10, 10, selectors.EVENT_READ, None) + self.assertEqual( + "SelectorKey(fileobj=10, fd=10, events=1, data=None)", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = FakeSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(KeyError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + + def test_unregister_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + self.assertRaises(KeyError, s.unregister, fobj) + + def test_modify_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + self.assertRaises(KeyError, s.modify, fobj, 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual( + selectors.SelectorKey(fobj, 10, selectors.EVENT_WRITE, None), + s.get_key(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual( + selectors.SelectorKey(fobj, 10, selectors.EVENT_READ, d2), + s.get_key(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = FakeSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = FakeSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + + def test_context_manager(self): + s = FakeSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + + def test_key_from_fd(self): + s = FakeSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) + + if hasattr(selectors.DefaultSelector, 'fileno'): + def test_fileno(self): + self.assertIsInstance(selectors.DefaultSelector().fileno(), int) diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py new file mode 100644 index 0000000..011a09d --- /dev/null +++ b/Lib/test/test_asyncio/test_streams.py @@ -0,0 +1,361 @@ +"""Tests for streams.py.""" + +import gc +import ssl +import unittest +import unittest.mock + +from asyncio import events +from asyncio import streams +from asyncio import tasks +from asyncio import test_utils + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + @unittest.mock.patch('asyncio.streams.events') + def test_ctor_global_loop(self, m_events): + stream = streams.StreamReader() + self.assertIs(stream.loop, m_events.get_event_loop.return_value) + + def test_open_connection(self): + with test_utils.run_test_server() as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_open_connection_no_loop_ssl(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + try: + events.set_event_loop(self.loop) + f = streams.open_connection(*httpd.address, + ssl=test_utils.dummy_ssl_context()) + reader, writer = self.loop.run_until_complete(f) + finally: + events.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + def test_open_connection_error(self): + with test_utils.run_test_server() as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + + writer.close() + test_utils.run_briefly(self.loop) + + def test_feed_empty_data(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(30), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(1024), loop=self.loop) + + def cb(): + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(-1), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline(), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.loop.call_soon(cb) + + line = self.loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + stream = streams.StreamReader(7, loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + self.loop.run_until_complete(stream.readline()) + + data = self.loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader(loop=self.loop) + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader(loop=self.loop) + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def set_err(): + stream.set_exception(ValueError()) + + @tasks.coroutine + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + def test_exception_cancel(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def read_a_line(): + yield from stream.readline() + + t = tasks.Task(read_a_line(), loop=self.loop) + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + # The following line fails if set_exception() isn't careful. + stream.set_exception(RuntimeError('message')) + test_utils.run_briefly(self.loop) + self.assertIs(stream.waiter, None) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_tasks.py b/Lib/test/test_asyncio/test_tasks.py new file mode 100644 index 0000000..57fb053 --- /dev/null +++ b/Lib/test/test_asyncio/test_tasks.py @@ -0,0 +1,1518 @@ +"""Tests for tasks.py.""" + +import gc +import unittest +import unittest.mock +from unittest.mock import Mock + +from asyncio import events +from asyncio import futures +from asyncio import tasks +from asyncio import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + gc.collect() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = events.new_event_loop() + t = tasks.Task(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_async_coroutine(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.async(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = events.new_event_loop() + t = tasks.async(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_async_future(self): + f_orig = futures.Future(loop=self.loop) + f_orig.set_result('ko') + + f = tasks.async(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + loop = events.new_event_loop() + + with self.assertRaises(ValueError): + f = tasks.async(f_orig, loop=loop) + + loop.close() + + f = tasks.async(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_async_task(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t_orig = tasks.Task(notmuch(), loop=self.loop) + t = tasks.async(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + loop = events.new_event_loop() + + with self.assertRaises(ValueError): + t = tasks.async(t_orig, loop=loop) + + loop.close() + + t = tasks.async(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_async_neither(self): + with self.assertRaises(TypeError): + tasks.async('ok') + + def test_task_repr(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'abc' + + t = tasks.Task(notmuch(), loop=self.loop) + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task(<notmuch>)<PENDING, [Dummy()]>') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLING, [Dummy()]>') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task(<notmuch>)<CANCELLED>') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task(<notmuch>)<result='abc'>") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + gen = coro() + t = MyTask(gen, loop=self.loop) + self.assertEqual(repr(t), 'T[](<coro>)') + gen.close() + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(futures.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + return 42 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + res = yield from fut3 + return res + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = events.new_event_loop() + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from tasks.sleep(100, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + yield from fut + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + gen = notmuch() + task = tasks.Task(gen, loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + gen.close() + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + base_exc = BaseException() + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise base_exc + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + gen = coro() + try: + yield gen + finally: + gen.close() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except futures.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @tasks.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except futures.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + d, p = yield from tasks.wait([inner()], loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(tasks.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + yield from tasks.shield(inner(), loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + parent = tasks.gather(child1, child2, loop=self.loop) + outer = tasks.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + inner1 = tasks.shield(child1, loop=self.loop) + inner2 = tasks.shield(child2, loop=self.loop) + parent = tasks.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), futures.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + +class GatherTestsBase: + + def setUp(self): + self.one_loop = test_utils.TestLoop() + self.other_loop = test_utils.TestLoop() + + def tearDown(self): + self.one_loop.close() + self.other_loop.close() + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] + fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = Mock() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_return_exceptions(self): + a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + +class FutureGatherTests(GatherTestsBase, unittest.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + events.set_event_loop(self.one_loop) + self.addCleanup(events.set_event_loop, None) + fut = tasks.gather(*seq_or_iter) + self.assertIsInstance(fut, futures.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = tasks.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = futures.Future(loop=self.one_loop) + fut2 = futures.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + tasks.gather(fut1, fut2) + with self.assertRaises(ValueError): + tasks.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [futures.Future(loop=self.other_loop) for i in range(3)] + fut = tasks.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = tasks.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(a, b, c, d, e) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), futures.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [futures.Future(loop=self.one_loop) + for i in range(6)] + fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], futures.CancelledError) + self.assertIsInstance(res[4], futures.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): + + def setUp(self): + super().setUp() + events.set_event_loop(self.one_loop) + + def tearDown(self): + events.set_event_loop(None) + super().tearDown() + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + @tasks.coroutine + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @tasks.coroutine + def coro(): + return 'abc' + gen1 = coro() + gen2 = coro() + fut = tasks.gather(gen1, gen2) + self.assertIs(fut._loop, self.one_loop) + gen1.close() + gen2.close() + gen3 = coro() + gen4 = coro() + fut = tasks.gather(gen3, gen4, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + gen3.close() + gen4.close() + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = tasks.async(inner(), loop=self.one_loop) + child2 = tasks.async(inner(), loop=self.one_loop) + gatherer = None + + @tasks.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = tasks.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(futures.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @tasks.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = futures.Future(loop=self.one_loop) + b = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def outer(): + yield from tasks.gather(inner(a), inner(b), loop=self.one_loop) + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncio/test_transports.py b/Lib/test/test_asyncio/test_transports.py new file mode 100644 index 0000000..fce2e6f --- /dev/null +++ b/Lib/test/test_asyncio/test_transports.py @@ -0,0 +1,55 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from asyncio import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py new file mode 100644 index 0000000..ea67862 --- /dev/null +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -0,0 +1,767 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import unittest +import unittest.mock + + +from asyncio import events +from asyncio import futures +from asyncio import protocols +from asyncio import test_utils +from asyncio import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + @unittest.mock.patch('os.waitpid') + def test__sig_chld_process_error(self, m_waitpid): + m_waitpid.side_effect = ChildProcessError + self.loop._sig_chld() + self.assertTrue(m_waitpid.called) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + # This is a bit overspecified. :-( + m_log.warning.assert_called_with( + 'pipe closed by peer or os.write(pipe, data) raised exception.') + + @unittest.mock.patch('os.write') + def test_write_close(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/Lib/test/test_asyncio/test_windows_events.py b/Lib/test/test_asyncio/test_windows_events.py new file mode 100644 index 0000000..4b04073 --- /dev/null +++ b/Lib/test/test_asyncio/test_windows_events.py @@ -0,0 +1,95 @@ +import os +import sys +import unittest + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import asyncio + +from asyncio import windows_events +from asyncio import protocols +from asyncio import streams +from asyncio import transports +from asyncio import test_utils + + +class UpperProto(protocols.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = asyncio.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + server2 = windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = streams.StreamReader(loop=self.loop) + protocol = streams.StreamReaderProtocol(stream_reader) + trans, proto = yield from self.loop.create_pipe_connection( + lambda:protocol, ADDRESS) + self.assertIsInstance(trans, transports.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + return 'done' diff --git a/Lib/test/test_asyncio/test_windows_utils.py b/Lib/test/test_asyncio/test_windows_utils.py new file mode 100644 index 0000000..4b96086 --- /dev/null +++ b/Lib/test/test_asyncio/test_windows_utils.py @@ -0,0 +1,136 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +from asyncio import windows_utils +from asyncio import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/Lib/test/test_asyncio/tests.txt b/Lib/test/test_asyncio/tests.txt new file mode 100644 index 0000000..e947721 --- /dev/null +++ b/Lib/test/test_asyncio/tests.txt @@ -0,0 +1,14 @@ +test_asyncio.test_base_events +test_asyncio.test_events +test_asyncio.test_futures +test_asyncio.test_locks +test_asyncio.test_proactor_events +test_asyncio.test_queues +test_asyncio.test_selector_events +test_asyncio.test_selectors +test_asyncio.test_streams +test_asyncio.test_tasks +test_asyncio.test_transports +test_asyncio.test_unix_events +test_asyncio.test_windows_events +test_asyncio.test_windows_utils |