diff options
author | Antoine Pitrou <pitrou@free.fr> | 2017-09-07 16:56:24 (GMT) |
---|---|---|
committer | Victor Stinner <victor.stinner@gmail.com> | 2017-09-07 16:56:24 (GMT) |
commit | a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344 (patch) | |
tree | 1c31738009bee903417cea928e705a112aea2392 /Lib | |
parent | 1f06a680de465be0c24a78ea3b610053955daa99 (diff) | |
download | cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.zip cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.tar.gz cpython-a6a4dc816d68df04a7d592e0b6af8c7ecc4d4344.tar.bz2 |
bpo-31370: Remove support for threads-less builds (#3385)
* Remove Setup.config
* Always define WITH_THREAD for compatibility.
Diffstat (limited to 'Lib')
81 files changed, 2372 insertions, 3387 deletions
diff --git a/Lib/_dummy_thread.py b/Lib/_dummy_thread.py deleted file mode 100644 index a2cae54..0000000 --- a/Lib/_dummy_thread.py +++ /dev/null @@ -1,163 +0,0 @@ -"""Drop-in replacement for the thread module. - -Meant to be used as a brain-dead substitute so that threaded code does -not need to be rewritten for when the thread module is not present. - -Suggested usage is:: - - try: - import _thread - except ImportError: - import _dummy_thread as _thread - -""" -# Exports only things specified by thread documentation; -# skipping obsolete synonyms allocate(), start_new(), exit_thread(). -__all__ = ['error', 'start_new_thread', 'exit', 'get_ident', 'allocate_lock', - 'interrupt_main', 'LockType'] - -# A dummy value -TIMEOUT_MAX = 2**31 - -# NOTE: this module can be imported early in the extension building process, -# and so top level imports of other modules should be avoided. Instead, all -# imports are done when needed on a function-by-function basis. Since threads -# are disabled, the import lock should not be an issue anyway (??). - -error = RuntimeError - -def start_new_thread(function, args, kwargs={}): - """Dummy implementation of _thread.start_new_thread(). - - Compatibility is maintained by making sure that ``args`` is a - tuple and ``kwargs`` is a dictionary. If an exception is raised - and it is SystemExit (which can be done by _thread.exit()) it is - caught and nothing is done; all other exceptions are printed out - by using traceback.print_exc(). - - If the executed function calls interrupt_main the KeyboardInterrupt will be - raised when the function returns. - - """ - if type(args) != type(tuple()): - raise TypeError("2nd arg must be a tuple") - if type(kwargs) != type(dict()): - raise TypeError("3rd arg must be a dict") - global _main - _main = False - try: - function(*args, **kwargs) - except SystemExit: - pass - except: - import traceback - traceback.print_exc() - _main = True - global _interrupt - if _interrupt: - _interrupt = False - raise KeyboardInterrupt - -def exit(): - """Dummy implementation of _thread.exit().""" - raise SystemExit - -def get_ident(): - """Dummy implementation of _thread.get_ident(). - - Since this module should only be used when _threadmodule is not - available, it is safe to assume that the current process is the - only thread. Thus a constant can be safely returned. - """ - return 1 - -def allocate_lock(): - """Dummy implementation of _thread.allocate_lock().""" - return LockType() - -def stack_size(size=None): - """Dummy implementation of _thread.stack_size().""" - if size is not None: - raise error("setting thread stack size not supported") - return 0 - -def _set_sentinel(): - """Dummy implementation of _thread._set_sentinel().""" - return LockType() - -class LockType(object): - """Class implementing dummy implementation of _thread.LockType. - - Compatibility is maintained by maintaining self.locked_status - which is a boolean that stores the state of the lock. Pickling of - the lock, though, should not be done since if the _thread module is - then used with an unpickled ``lock()`` from here problems could - occur from this class not having atomic methods. - - """ - - def __init__(self): - self.locked_status = False - - def acquire(self, waitflag=None, timeout=-1): - """Dummy implementation of acquire(). - - For blocking calls, self.locked_status is automatically set to - True and returned appropriately based on value of - ``waitflag``. If it is non-blocking, then the value is - actually checked and not set if it is already acquired. This - is all done so that threading.Condition's assert statements - aren't triggered and throw a little fit. - - """ - if waitflag is None or waitflag: - self.locked_status = True - return True - else: - if not self.locked_status: - self.locked_status = True - return True - else: - if timeout > 0: - import time - time.sleep(timeout) - return False - - __enter__ = acquire - - def __exit__(self, typ, val, tb): - self.release() - - def release(self): - """Release the dummy lock.""" - # XXX Perhaps shouldn't actually bother to test? Could lead - # to problems for complex, threaded code. - if not self.locked_status: - raise error - self.locked_status = False - return True - - def locked(self): - return self.locked_status - - def __repr__(self): - return "<%s %s.%s object at %s>" % ( - "locked" if self.locked_status else "unlocked", - self.__class__.__module__, - self.__class__.__qualname__, - hex(id(self)) - ) - -# Used to signal that interrupt_main was called in a "thread" -_interrupt = False -# True when not executing in a "thread" -_main = True - -def interrupt_main(): - """Set _interrupt flag to True to have start_new_thread raise - KeyboardInterrupt upon exiting.""" - if _main: - raise KeyboardInterrupt - else: - global _interrupt - _interrupt = True diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py index edabf72..a43c75f 100644 --- a/Lib/_pydecimal.py +++ b/Lib/_pydecimal.py @@ -436,75 +436,34 @@ _rounding_modes = (ROUND_DOWN, ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING, # work for older Pythons. If threads are not part of the build, create a # mock threading object with threading.local() returning the module namespace. -try: - import threading -except ImportError: - # Python was compiled without threads; create a mock object instead - class MockThreading(object): - def local(self, sys=sys): - return sys.modules[__xname__] - threading = MockThreading() - del MockThreading - -try: - threading.local - -except AttributeError: - - # To fix reloading, force it to create a new context - # Old contexts have different exceptions in their dicts, making problems. - if hasattr(threading.current_thread(), '__decimal_context__'): - del threading.current_thread().__decimal_context__ - - def setcontext(context): - """Set this thread's context to context.""" - if context in (DefaultContext, BasicContext, ExtendedContext): - context = context.copy() - context.clear_flags() - threading.current_thread().__decimal_context__ = context - - def getcontext(): - """Returns this thread's context. - - If this thread does not yet have a context, returns - a new context and sets this thread's context. - New contexts are copies of DefaultContext. - """ - try: - return threading.current_thread().__decimal_context__ - except AttributeError: - context = Context() - threading.current_thread().__decimal_context__ = context - return context +import threading -else: +local = threading.local() +if hasattr(local, '__decimal_context__'): + del local.__decimal_context__ - local = threading.local() - if hasattr(local, '__decimal_context__'): - del local.__decimal_context__ +def getcontext(_local=local): + """Returns this thread's context. - def getcontext(_local=local): - """Returns this thread's context. - - If this thread does not yet have a context, returns - a new context and sets this thread's context. - New contexts are copies of DefaultContext. - """ - try: - return _local.__decimal_context__ - except AttributeError: - context = Context() - _local.__decimal_context__ = context - return context - - def setcontext(context, _local=local): - """Set this thread's context to context.""" - if context in (DefaultContext, BasicContext, ExtendedContext): - context = context.copy() - context.clear_flags() + If this thread does not yet have a context, returns + a new context and sets this thread's context. + New contexts are copies of DefaultContext. + """ + try: + return _local.__decimal_context__ + except AttributeError: + context = Context() _local.__decimal_context__ = context + return context + +def setcontext(context, _local=local): + """Set this thread's context to context.""" + if context in (DefaultContext, BasicContext, ExtendedContext): + context = context.copy() + context.clear_flags() + _local.__decimal_context__ = context - del threading, local # Don't contaminate the namespace +del threading, local # Don't contaminate the namespace def localcontext(ctx=None): """Return a context manager for a copy of the supplied context diff --git a/Lib/_pyio.py b/Lib/_pyio.py index 4653847..1e105f2 100644 --- a/Lib/_pyio.py +++ b/Lib/_pyio.py @@ -9,10 +9,7 @@ import errno import stat import sys # Import _thread instead of threading to reduce startup cost -try: - from _thread import allocate_lock as Lock -except ImportError: - from _dummy_thread import allocate_lock as Lock +from _thread import allocate_lock as Lock if sys.platform in {'win32', 'cygwin'}: from msvcrt import setmode as _setmode else: diff --git a/Lib/_strptime.py b/Lib/_strptime.py index fe94361..284175d 100644 --- a/Lib/_strptime.py +++ b/Lib/_strptime.py @@ -19,10 +19,7 @@ from re import escape as re_escape from datetime import (date as datetime_date, timedelta as datetime_timedelta, timezone as datetime_timezone) -try: - from _thread import allocate_lock as _thread_allocate_lock -except ImportError: - from _dummy_thread import allocate_lock as _thread_allocate_lock +from _thread import allocate_lock as _thread_allocate_lock __all__ = [] @@ -14,11 +14,7 @@ import io import os import warnings import _compression - -try: - from threading import RLock -except ImportError: - from dummy_threading import RLock +from threading import RLock from _bz2 import BZ2Compressor, BZ2Decompressor diff --git a/Lib/ctypes/test/test_errno.py b/Lib/ctypes/test/test_errno.py index 4690a0d..3685164 100644 --- a/Lib/ctypes/test/test_errno.py +++ b/Lib/ctypes/test/test_errno.py @@ -1,10 +1,8 @@ import unittest, os, errno +import threading + from ctypes import * from ctypes.util import find_library -try: - import threading -except ImportError: - threading = None class Test(unittest.TestCase): def test_open(self): @@ -25,25 +23,24 @@ class Test(unittest.TestCase): self.assertEqual(set_errno(32), errno.ENOENT) self.assertEqual(get_errno(), 32) - if threading: - def _worker(): - set_errno(0) + def _worker(): + set_errno(0) - libc = CDLL(libc_name, use_errno=False) - if os.name == "nt": - libc_open = libc._open - else: - libc_open = libc.open - libc_open.argtypes = c_char_p, c_int - self.assertEqual(libc_open(b"", 0), -1) - self.assertEqual(get_errno(), 0) + libc = CDLL(libc_name, use_errno=False) + if os.name == "nt": + libc_open = libc._open + else: + libc_open = libc.open + libc_open.argtypes = c_char_p, c_int + self.assertEqual(libc_open(b"", 0), -1) + self.assertEqual(get_errno(), 0) - t = threading.Thread(target=_worker) - t.start() - t.join() + t = threading.Thread(target=_worker) + t.start() + t.join() - self.assertEqual(get_errno(), 32) - set_errno(0) + self.assertEqual(get_errno(), 32) + set_errno(0) @unittest.skipUnless(os.name == "nt", 'Test specific to Windows') def test_GetLastError(self): diff --git a/Lib/dummy_threading.py b/Lib/dummy_threading.py deleted file mode 100644 index 1bb7eee..0000000 --- a/Lib/dummy_threading.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Faux ``threading`` version using ``dummy_thread`` instead of ``thread``. - -The module ``_dummy_threading`` is added to ``sys.modules`` in order -to not have ``threading`` considered imported. Had ``threading`` been -directly imported it would have made all subsequent imports succeed -regardless of whether ``_thread`` was available which is not desired. - -""" -from sys import modules as sys_modules - -import _dummy_thread - -# Declaring now so as to not have to nest ``try``s to get proper clean-up. -holding_thread = False -holding_threading = False -holding__threading_local = False - -try: - # Could have checked if ``_thread`` was not in sys.modules and gone - # a different route, but decided to mirror technique used with - # ``threading`` below. - if '_thread' in sys_modules: - held_thread = sys_modules['_thread'] - holding_thread = True - # Must have some module named ``_thread`` that implements its API - # in order to initially import ``threading``. - sys_modules['_thread'] = sys_modules['_dummy_thread'] - - if 'threading' in sys_modules: - # If ``threading`` is already imported, might as well prevent - # trying to import it more than needed by saving it if it is - # already imported before deleting it. - held_threading = sys_modules['threading'] - holding_threading = True - del sys_modules['threading'] - - if '_threading_local' in sys_modules: - # If ``_threading_local`` is already imported, might as well prevent - # trying to import it more than needed by saving it if it is - # already imported before deleting it. - held__threading_local = sys_modules['_threading_local'] - holding__threading_local = True - del sys_modules['_threading_local'] - - import threading - # Need a copy of the code kept somewhere... - sys_modules['_dummy_threading'] = sys_modules['threading'] - del sys_modules['threading'] - sys_modules['_dummy__threading_local'] = sys_modules['_threading_local'] - del sys_modules['_threading_local'] - from _dummy_threading import * - from _dummy_threading import __all__ - -finally: - # Put back ``threading`` if we overwrote earlier - - if holding_threading: - sys_modules['threading'] = held_threading - del held_threading - del holding_threading - - # Put back ``_threading_local`` if we overwrote earlier - - if holding__threading_local: - sys_modules['_threading_local'] = held__threading_local - del held__threading_local - del holding__threading_local - - # Put back ``thread`` if we overwrote, else del the entry we made - if holding_thread: - sys_modules['_thread'] = held_thread - del held_thread - else: - del sys_modules['_thread'] - del holding_thread - - del _dummy_thread - del sys_modules diff --git a/Lib/functools.py b/Lib/functools.py index 23ad160..25075de 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -22,13 +22,7 @@ from collections import namedtuple from types import MappingProxyType from weakref import WeakKeyDictionary from reprlib import recursive_repr -try: - from _thread import RLock -except ImportError: - class RLock: - 'Dummy reentrant lock for builds without threads' - def __enter__(self): pass - def __exit__(self, exctype, excinst, exctb): pass +from _thread import RLock ################################################################################ diff --git a/Lib/http/cookiejar.py b/Lib/http/cookiejar.py index adf956d..e0f1032 100644 --- a/Lib/http/cookiejar.py +++ b/Lib/http/cookiejar.py @@ -33,10 +33,7 @@ import datetime import re import time import urllib.parse, urllib.request -try: - import threading as _threading -except ImportError: - import dummy_threading as _threading +import threading as _threading import http.client # only for the default HTTP port from calendar import timegm diff --git a/Lib/logging/__init__.py b/Lib/logging/__init__.py index 83db827..19b96b8 100644 --- a/Lib/logging/__init__.py +++ b/Lib/logging/__init__.py @@ -37,10 +37,7 @@ __all__ = ['BASIC_FORMAT', 'BufferingFormatter', 'CRITICAL', 'DEBUG', 'ERROR', 'warn', 'warning', 'getLogRecordFactory', 'setLogRecordFactory', 'lastResort', 'raiseExceptions'] -try: - import threading -except ImportError: #pragma: no cover - threading = None +import threading __author__ = "Vinay Sajip <vinay_sajip@red-dove.com>" __status__ = "production" @@ -210,11 +207,7 @@ def _checkLevel(level): #the lock would already have been acquired - so we need an RLock. #The same argument applies to Loggers and Manager.loggerDict. # -if threading: - _lock = threading.RLock() -else: #pragma: no cover - _lock = None - +_lock = threading.RLock() def _acquireLock(): """ @@ -295,7 +288,7 @@ class LogRecord(object): self.created = ct self.msecs = (ct - int(ct)) * 1000 self.relativeCreated = (self.created - _startTime) * 1000 - if logThreads and threading: + if logThreads: self.thread = threading.get_ident() self.threadName = threading.current_thread().name else: # pragma: no cover @@ -799,10 +792,7 @@ class Handler(Filterer): """ Acquire a thread lock for serializing access to the underlying I/O. """ - if threading: - self.lock = threading.RLock() - else: #pragma: no cover - self.lock = None + self.lock = threading.RLock() def acquire(self): """ diff --git a/Lib/logging/config.py b/Lib/logging/config.py index b3f4e28..c16a75a 100644 --- a/Lib/logging/config.py +++ b/Lib/logging/config.py @@ -31,14 +31,9 @@ import logging.handlers import re import struct import sys +import threading import traceback -try: - import _thread as thread - import threading -except ImportError: #pragma: no cover - thread = None - from socketserver import ThreadingTCPServer, StreamRequestHandler @@ -816,8 +811,6 @@ def listen(port=DEFAULT_LOGGING_CONFIG_PORT, verify=None): normal. Note that you can return transformed bytes, e.g. by decrypting the bytes passed in. """ - if not thread: #pragma: no cover - raise NotImplementedError("listen() needs threading to work") class ConfigStreamHandler(StreamRequestHandler): """ diff --git a/Lib/logging/handlers.py b/Lib/logging/handlers.py index b5fdfbc..a815f03 100644 --- a/Lib/logging/handlers.py +++ b/Lib/logging/handlers.py @@ -26,10 +26,7 @@ To use, simply 'import logging.handlers' and log away! import logging, socket, os, pickle, struct, time, re from stat import ST_DEV, ST_INO, ST_MTIME import queue -try: - import threading -except ImportError: #pragma: no cover - threading = None +import threading # # Some constants... @@ -1395,110 +1392,110 @@ class QueueHandler(logging.Handler): except Exception: self.handleError(record) -if threading: - class QueueListener(object): - """ - This class implements an internal threaded listener which watches for - LogRecords being added to a queue, removes them and passes them to a - list of handlers for processing. - """ - _sentinel = None - - def __init__(self, queue, *handlers, respect_handler_level=False): - """ - Initialise an instance with the specified queue and - handlers. - """ - self.queue = queue - self.handlers = handlers - self._thread = None - self.respect_handler_level = respect_handler_level - - def dequeue(self, block): - """ - Dequeue a record and return it, optionally blocking. - - The base implementation uses get. You may want to override this method - if you want to use timeouts or work with custom queue implementations. - """ - return self.queue.get(block) - - def start(self): - """ - Start the listener. - - This starts up a background thread to monitor the queue for - LogRecords to process. - """ - self._thread = t = threading.Thread(target=self._monitor) - t.daemon = True - t.start() - - def prepare(self , record): - """ - Prepare a record for handling. - - This method just returns the passed-in record. You may want to - override this method if you need to do any custom marshalling or - manipulation of the record before passing it to the handlers. - """ - return record - - def handle(self, record): - """ - Handle a record. - - This just loops through the handlers offering them the record - to handle. - """ - record = self.prepare(record) - for handler in self.handlers: - if not self.respect_handler_level: - process = True - else: - process = record.levelno >= handler.level - if process: - handler.handle(record) - - def _monitor(self): - """ - Monitor the queue for records, and ask the handler - to deal with them. - - This method runs on a separate, internal thread. - The thread will terminate if it sees a sentinel object in the queue. - """ - q = self.queue - has_task_done = hasattr(q, 'task_done') - while True: - try: - record = self.dequeue(True) - if record is self._sentinel: - break - self.handle(record) - if has_task_done: - q.task_done() - except queue.Empty: + +class QueueListener(object): + """ + This class implements an internal threaded listener which watches for + LogRecords being added to a queue, removes them and passes them to a + list of handlers for processing. + """ + _sentinel = None + + def __init__(self, queue, *handlers, respect_handler_level=False): + """ + Initialise an instance with the specified queue and + handlers. + """ + self.queue = queue + self.handlers = handlers + self._thread = None + self.respect_handler_level = respect_handler_level + + def dequeue(self, block): + """ + Dequeue a record and return it, optionally blocking. + + The base implementation uses get. You may want to override this method + if you want to use timeouts or work with custom queue implementations. + """ + return self.queue.get(block) + + def start(self): + """ + Start the listener. + + This starts up a background thread to monitor the queue for + LogRecords to process. + """ + self._thread = t = threading.Thread(target=self._monitor) + t.daemon = True + t.start() + + def prepare(self , record): + """ + Prepare a record for handling. + + This method just returns the passed-in record. You may want to + override this method if you need to do any custom marshalling or + manipulation of the record before passing it to the handlers. + """ + return record + + def handle(self, record): + """ + Handle a record. + + This just loops through the handlers offering them the record + to handle. + """ + record = self.prepare(record) + for handler in self.handlers: + if not self.respect_handler_level: + process = True + else: + process = record.levelno >= handler.level + if process: + handler.handle(record) + + def _monitor(self): + """ + Monitor the queue for records, and ask the handler + to deal with them. + + This method runs on a separate, internal thread. + The thread will terminate if it sees a sentinel object in the queue. + """ + q = self.queue + has_task_done = hasattr(q, 'task_done') + while True: + try: + record = self.dequeue(True) + if record is self._sentinel: break + self.handle(record) + if has_task_done: + q.task_done() + except queue.Empty: + break + + def enqueue_sentinel(self): + """ + This is used to enqueue the sentinel record. + + The base implementation uses put_nowait. You may want to override this + method if you want to use timeouts or work with custom queue + implementations. + """ + self.queue.put_nowait(self._sentinel) - def enqueue_sentinel(self): - """ - This is used to enqueue the sentinel record. - - The base implementation uses put_nowait. You may want to override this - method if you want to use timeouts or work with custom queue - implementations. - """ - self.queue.put_nowait(self._sentinel) - - def stop(self): - """ - Stop the listener. - - This asks the thread to terminate, and then waits for it to do so. - Note that if you don't call this before your application exits, there - may be some records still left on the queue, which won't be processed. - """ - self.enqueue_sentinel() - self._thread.join() - self._thread = None + def stop(self): + """ + Stop the listener. + + This asks the thread to terminate, and then waits for it to do so. + Note that if you don't call this before your application exits, there + may be some records still left on the queue, which won't be processed. + """ + self.enqueue_sentinel() + self._thread.join() + self._thread = None diff --git a/Lib/queue.py b/Lib/queue.py index 572425e..c803b96 100644 --- a/Lib/queue.py +++ b/Lib/queue.py @@ -1,9 +1,6 @@ '''A multi-producer, multi-consumer queue.''' -try: - import threading -except ImportError: - import dummy_threading as threading +import threading from collections import deque from heapq import heappush, heappop from time import monotonic as time diff --git a/Lib/reprlib.py b/Lib/reprlib.py index 40d991f..616b343 100644 --- a/Lib/reprlib.py +++ b/Lib/reprlib.py @@ -4,10 +4,7 @@ __all__ = ["Repr", "repr", "recursive_repr"] import builtins from itertools import islice -try: - from _thread import get_ident -except ImportError: - from _dummy_thread import get_ident +from _thread import get_ident def recursive_repr(fillvalue='...'): 'Decorator to make a repr function return fillvalue for a recursive call' diff --git a/Lib/sched.py b/Lib/sched.py index 3d8c011..ff87874 100644 --- a/Lib/sched.py +++ b/Lib/sched.py @@ -26,10 +26,7 @@ has another way to reference private data (besides global variables). import time import heapq from collections import namedtuple -try: - import threading -except ImportError: - import dummy_threading as threading +import threading from time import monotonic as _time __all__ = ["scheduler"] diff --git a/Lib/socketserver.py b/Lib/socketserver.py index df17114..721eb50 100644 --- a/Lib/socketserver.py +++ b/Lib/socketserver.py @@ -127,10 +127,7 @@ import socket import selectors import os import sys -try: - import threading -except ImportError: - import dummy_threading as threading +import threading from io import BufferedIOBase from time import monotonic as time diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index cb85814..12c9fb4 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -21,12 +21,9 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import threading import unittest import sqlite3 as sqlite -try: - import threading -except ImportError: - threading = None from test.support import TESTFN, unlink @@ -503,7 +500,6 @@ class CursorTests(unittest.TestCase): self.assertEqual(results, expected) -@unittest.skipUnless(threading, 'This test requires threading.') class ThreadTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff --git a/Lib/subprocess.py b/Lib/subprocess.py index 25e5698..f61ff0c 100644 --- a/Lib/subprocess.py +++ b/Lib/subprocess.py @@ -138,10 +138,7 @@ else: import _posixsubprocess import select import selectors - try: - import threading - except ImportError: - import dummy_threading as threading + import threading # When select or poll has indicated that the file is writable, # we can write up to _PIPE_BUF bytes without risk of blocking. diff --git a/Lib/tempfile.py b/Lib/tempfile.py index 6146235..71ecafa 100644 --- a/Lib/tempfile.py +++ b/Lib/tempfile.py @@ -44,11 +44,7 @@ import shutil as _shutil import errno as _errno from random import Random as _Random import weakref as _weakref - -try: - import _thread -except ImportError: - import _dummy_thread as _thread +import _thread _allocate_lock = _thread.allocate_lock _text_openflags = _os.O_RDWR | _os.O_CREAT | _os.O_EXCL diff --git a/Lib/test/fork_wait.py b/Lib/test/fork_wait.py index 6af79ad..9850b06 100644 --- a/Lib/test/fork_wait.py +++ b/Lib/test/fork_wait.py @@ -10,9 +10,9 @@ active threads survive in the child after a fork(); this is an error. """ import os, sys, time, unittest +import threading import test.support as support -threading = support.import_module('threading') LONGSLEEP = 2 SHORTSLEEP = 0.5 diff --git a/Lib/test/libregrtest/runtest_mp.py b/Lib/test/libregrtest/runtest_mp.py index 779ff01..31b830d 100644 --- a/Lib/test/libregrtest/runtest_mp.py +++ b/Lib/test/libregrtest/runtest_mp.py @@ -3,15 +3,11 @@ import json import os import queue import sys +import threading import time import traceback import types from test import support -try: - import threading -except ImportError: - print("Multiprocess option requires thread support") - sys.exit(2) from test.libregrtest.runtest import ( runtest, INTERRUPTED, CHILD_ERROR, PROGRESS_MIN_TIME, diff --git a/Lib/test/libregrtest/save_env.py b/Lib/test/libregrtest/save_env.py index 3c45621..45b365d 100644 --- a/Lib/test/libregrtest/save_env.py +++ b/Lib/test/libregrtest/save_env.py @@ -5,13 +5,10 @@ import os import shutil import sys import sysconfig +import threading import warnings from test import support try: - import threading -except ImportError: - threading = None -try: import _multiprocessing, multiprocessing.process except ImportError: multiprocessing = None @@ -181,13 +178,9 @@ class saved_test_environment: # Controlling dangling references to Thread objects can make it easier # to track reference leaks. def get_threading__dangling(self): - if not threading: - return None # This copies the weakrefs without making any strong reference return threading._dangling.copy() def restore_threading__dangling(self, saved): - if not threading: - return threading._dangling.clear() threading._dangling.update(saved) diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 0235498..bfceba1 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -25,6 +25,8 @@ import subprocess import sys import sysconfig import tempfile +import _thread +import threading import time import types import unittest @@ -32,11 +34,6 @@ import urllib.error import warnings try: - import _thread, threading -except ImportError: - _thread = None - threading = None -try: import multiprocessing.process except ImportError: multiprocessing = None @@ -2028,16 +2025,11 @@ environment_altered = False # at the end of a test run. def threading_setup(): - if _thread: - return _thread._count(), threading._dangling.copy() - else: - return 1, () + return _thread._count(), threading._dangling.copy() def threading_cleanup(*original_values): global environment_altered - if not _thread: - return _MAX_COUNT = 100 t0 = time.monotonic() for count in range(_MAX_COUNT): @@ -2061,9 +2053,6 @@ def reap_threads(func): ensure that the threads are cleaned up even when the test fails. If threading is unavailable this function does nothing. """ - if not _thread: - return func - @functools.wraps(func) def decorator(*args): key = threading_setup() diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py index 0eba76d..2362834 100644 --- a/Lib/test/test_asynchat.py +++ b/Lib/test/test_asynchat.py @@ -2,110 +2,104 @@ from test import support -# If this fails, the test will be skipped. -thread = support.import_module('_thread') - import asynchat import asyncore import errno import socket import sys +import _thread as thread +import threading import time import unittest import unittest.mock -try: - import threading -except ImportError: - threading = None HOST = support.HOST SERVER_QUIT = b'QUIT\n' TIMEOUT = 3.0 -if threading: - class echo_server(threading.Thread): - # parameter to determine the number of bytes passed back to the - # client each send - chunk_size = 1 - - def __init__(self, event): - threading.Thread.__init__(self) - self.event = event - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = support.bind_port(self.sock) - # This will be set if the client wants us to wait before echoing - # data back. - self.start_resend_event = None - - def run(self): - self.sock.listen() - self.event.set() - conn, client = self.sock.accept() - self.buffer = b"" - # collect data until quit message is seen - while SERVER_QUIT not in self.buffer: - data = conn.recv(1) - if not data: - break - self.buffer = self.buffer + data - - # remove the SERVER_QUIT message - self.buffer = self.buffer.replace(SERVER_QUIT, b'') - - if self.start_resend_event: - self.start_resend_event.wait() - - # re-send entire set of collected data - try: - # this may fail on some tests, such as test_close_when_done, - # since the client closes the channel when it's done sending - while self.buffer: - n = conn.send(self.buffer[:self.chunk_size]) - time.sleep(0.001) - self.buffer = self.buffer[n:] - except: - pass - - conn.close() - self.sock.close() - class echo_client(asynchat.async_chat): - - def __init__(self, terminator, server_port): - asynchat.async_chat.__init__(self) - self.contents = [] - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) - self.connect((HOST, server_port)) - self.set_terminator(terminator) - self.buffer = b"" - - def handle_connect(self): +class echo_server(threading.Thread): + # parameter to determine the number of bytes passed back to the + # client each send + chunk_size = 1 + + def __init__(self, event): + threading.Thread.__init__(self) + self.event = event + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = support.bind_port(self.sock) + # This will be set if the client wants us to wait before echoing + # data back. + self.start_resend_event = None + + def run(self): + self.sock.listen() + self.event.set() + conn, client = self.sock.accept() + self.buffer = b"" + # collect data until quit message is seen + while SERVER_QUIT not in self.buffer: + data = conn.recv(1) + if not data: + break + self.buffer = self.buffer + data + + # remove the SERVER_QUIT message + self.buffer = self.buffer.replace(SERVER_QUIT, b'') + + if self.start_resend_event: + self.start_resend_event.wait() + + # re-send entire set of collected data + try: + # this may fail on some tests, such as test_close_when_done, + # since the client closes the channel when it's done sending + while self.buffer: + n = conn.send(self.buffer[:self.chunk_size]) + time.sleep(0.001) + self.buffer = self.buffer[n:] + except: + pass + + conn.close() + self.sock.close() + +class echo_client(asynchat.async_chat): + + def __init__(self, terminator, server_port): + asynchat.async_chat.__init__(self) + self.contents = [] + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.connect((HOST, server_port)) + self.set_terminator(terminator) + self.buffer = b"" + + def handle_connect(self): + pass + + if sys.platform == 'darwin': + # select.poll returns a select.POLLHUP at the end of the tests + # on darwin, so just ignore it + def handle_expt(self): pass - if sys.platform == 'darwin': - # select.poll returns a select.POLLHUP at the end of the tests - # on darwin, so just ignore it - def handle_expt(self): - pass - - def collect_incoming_data(self, data): - self.buffer += data + def collect_incoming_data(self, data): + self.buffer += data - def found_terminator(self): - self.contents.append(self.buffer) - self.buffer = b"" + def found_terminator(self): + self.contents.append(self.buffer) + self.buffer = b"" - def start_echo_server(): - event = threading.Event() - s = echo_server(event) - s.start() - event.wait() - event.clear() - time.sleep(0.01) # Give server time to start accepting. - return s, event +def start_echo_server(): + event = threading.Event() + s = echo_server(event) + s.start() + event.wait() + event.clear() + time.sleep(0.01) # Give server time to start accepting. + return s, event -@unittest.skipUnless(threading, 'Threading required for this test.') class TestAsynchat(unittest.TestCase): usepoll = False diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py index 80a9eea..c77c7a8 100644 --- a/Lib/test/test_asyncio/__init__.py +++ b/Lib/test/test_asyncio/__init__.py @@ -1,8 +1,6 @@ import os from test.support import load_package_tests, import_module -# Skip tests if we don't have threading. -import_module('threading') # Skip tests if we don't have concurrent.futures. import_module('concurrent.futures') diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py index 07edf22..c8e9727 100644 --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -7,6 +7,7 @@ import sys import time import errno import struct +import threading from test import support from io import BytesIO @@ -14,10 +15,6 @@ from io import BytesIO if support.PGO: raise unittest.SkipTest("test is not helpful for PGO") -try: - import threading -except ImportError: - threading = None TIMEOUT = 3 HAS_UNIX_SOCKETS = hasattr(socket, 'AF_UNIX') @@ -326,7 +323,6 @@ class DispatcherWithSendTests(unittest.TestCase): def tearDown(self): asyncore.close_all() - @unittest.skipUnless(threading, 'Threading required for this test.') @support.reap_threads def test_send(self): evt = threading.Event() @@ -776,7 +772,6 @@ class BaseTestAPI: self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR)) - @unittest.skipUnless(threading, 'Threading required for this test.') @support.reap_threads def test_quick_connect(self): # see: http://bugs.python.org/issue10340 diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index e76b283..58dbf96 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -10,13 +10,10 @@ import pathlib import random import shutil import subprocess +import threading from test.support import unlink import _compression -try: - import threading -except ImportError: - threading = None # Skip tests if the bz2 module doesn't exist. bz2 = support.import_module('bz2') @@ -491,7 +488,6 @@ class BZ2FileTest(BaseTest): else: self.fail("1/0 didn't raise an exception") - @unittest.skipUnless(threading, 'Threading required for this test.') def testThreading(self): # Issue #7205: Using a BZ2File from several threads shouldn't deadlock. data = b"1" * 2**20 @@ -504,13 +500,6 @@ class BZ2FileTest(BaseTest): with support.start_threads(threads): pass - def testWithoutThreading(self): - module = support.import_fresh_module("bz2", blocked=("threading",)) - with module.BZ2File(self.filename, "wb") as f: - f.write(b"abc") - with module.BZ2File(self.filename, "rb") as f: - self.assertEqual(f.read(), b"abc") - def testMixedIterationAndReads(self): self.createTempFile() linelen = len(self.TEXT_LINES[0]) diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py index c3a04b4..1b826ee 100644 --- a/Lib/test/test_capi.py +++ b/Lib/test/test_capi.py @@ -11,6 +11,7 @@ import subprocess import sys import sysconfig import textwrap +import threading import time import unittest from test import support @@ -20,10 +21,7 @@ try: import _posixsubprocess except ImportError: _posixsubprocess = None -try: - import threading -except ImportError: - threading = None + # Skip this test if the _testcapi module isn't available. _testcapi = support.import_module('_testcapi') @@ -52,7 +50,6 @@ class CAPITest(unittest.TestCase): self.assertEqual(testfunction.attribute, "test") self.assertRaises(AttributeError, setattr, inst.testfunction, "attribute", "test") - @unittest.skipUnless(threading, 'Threading required for this test.') def test_no_FatalError_infinite_loop(self): with support.SuppressCrashReport(): p = subprocess.Popen([sys.executable, "-c", @@ -276,7 +273,6 @@ class CAPITest(unittest.TestCase): self.assertIn(b'MemoryError 3 30', out) -@unittest.skipUnless(threading, 'Threading required for this test.') class TestPendingCalls(unittest.TestCase): def pendingcalls_submit(self, l, n): @@ -685,7 +681,6 @@ class SkipitemTest(unittest.TestCase): parse((1,), {}, 'O|OO', ['', 'a', '']) -@unittest.skipUnless(threading, 'Threading required for this test.') class TestThreadState(unittest.TestCase): @support.reap_threads @@ -762,7 +757,6 @@ class PyMemDebugTests(unittest.TestCase): regex = regex.format(ptr=self.PTR_REGEX) self.assertRegex(out, regex) - @unittest.skipUnless(threading, 'Test requires a GIL (multithreading)') def check_malloc_without_gil(self, code): out = self.check(code) expected = ('Fatal Python error: Python memory allocator called ' diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py index 03f8d1d..a888dca 100644 --- a/Lib/test/test_concurrent_futures.py +++ b/Lib/test/test_concurrent_futures.py @@ -4,10 +4,6 @@ import test.support test.support.import_module('_multiprocessing') # Skip tests if sem_open implementation is broken. test.support.import_module('multiprocessing.synchronize') -# import threading after _multiprocessing to raise a more relevant error -# message: "No module named _multiprocessing". _multiprocessing is not compiled -# without thread support. -test.support.import_module('threading') from test.support.script_helper import assert_python_ok diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 2301f75..64b6578 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -3,13 +3,10 @@ import io import sys import tempfile +import threading import unittest from contextlib import * # Tests __all__ from test import support -try: - import threading -except ImportError: - threading = None class TestAbstractContextManager(unittest.TestCase): @@ -275,7 +272,6 @@ class FileContextTestCase(unittest.TestCase): finally: support.unlink(tfn) -@unittest.skipUnless(threading, 'Threading required for this test.') class LockContextTestCase(unittest.TestCase): def boilerPlate(self, lock, locked): diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 53f71ca..5d9da0e 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -38,10 +38,7 @@ from test.support import (import_fresh_module, TestFailed, run_with_locale, cpython_only) import random import inspect -try: - import threading -except ImportError: - threading = None +import threading C = import_fresh_module('decimal', fresh=['_decimal']) @@ -1625,10 +1622,10 @@ class ThreadingTest(unittest.TestCase): DefaultContext.Emax = save_emax DefaultContext.Emin = save_emin -@unittest.skipUnless(threading, 'threading required') + class CThreadingTest(ThreadingTest): decimal = C -@unittest.skipUnless(threading, 'threading required') + class PyThreadingTest(ThreadingTest): decimal = P diff --git a/Lib/test/test_docxmlrpc.py b/Lib/test/test_docxmlrpc.py index 0090333..f077f05 100644 --- a/Lib/test/test_docxmlrpc.py +++ b/Lib/test/test_docxmlrpc.py @@ -1,8 +1,8 @@ from xmlrpc.server import DocXMLRPCServer import http.client import sys +import threading from test import support -threading = support.import_module('threading') import unittest def make_request_and_skipIf(condition, reason): diff --git a/Lib/test/test_dummy_thread.py b/Lib/test/test_dummy_thread.py deleted file mode 100644 index 0840be6..0000000 --- a/Lib/test/test_dummy_thread.py +++ /dev/null @@ -1,255 +0,0 @@ -import _dummy_thread as _thread -import time -import queue -import random -import unittest -from test import support -from unittest import mock - -DELAY = 0 - - -class LockTests(unittest.TestCase): - """Test lock objects.""" - - def setUp(self): - # Create a lock - self.lock = _thread.allocate_lock() - - def test_initlock(self): - #Make sure locks start locked - self.assertFalse(self.lock.locked(), - "Lock object is not initialized unlocked.") - - def test_release(self): - # Test self.lock.release() - self.lock.acquire() - self.lock.release() - self.assertFalse(self.lock.locked(), - "Lock object did not release properly.") - - def test_LockType_context_manager(self): - with _thread.LockType(): - pass - self.assertFalse(self.lock.locked(), - "Acquired Lock was not released") - - def test_improper_release(self): - #Make sure release of an unlocked thread raises RuntimeError - self.assertRaises(RuntimeError, self.lock.release) - - def test_cond_acquire_success(self): - #Make sure the conditional acquiring of the lock works. - self.assertTrue(self.lock.acquire(0), - "Conditional acquiring of the lock failed.") - - def test_cond_acquire_fail(self): - #Test acquiring locked lock returns False - self.lock.acquire(0) - self.assertFalse(self.lock.acquire(0), - "Conditional acquiring of a locked lock incorrectly " - "succeeded.") - - def test_uncond_acquire_success(self): - #Make sure unconditional acquiring of a lock works. - self.lock.acquire() - self.assertTrue(self.lock.locked(), - "Uncondional locking failed.") - - def test_uncond_acquire_return_val(self): - #Make sure that an unconditional locking returns True. - self.assertIs(self.lock.acquire(1), True, - "Unconditional locking did not return True.") - self.assertIs(self.lock.acquire(), True) - - def test_uncond_acquire_blocking(self): - #Make sure that unconditional acquiring of a locked lock blocks. - def delay_unlock(to_unlock, delay): - """Hold on to lock for a set amount of time before unlocking.""" - time.sleep(delay) - to_unlock.release() - - self.lock.acquire() - start_time = int(time.time()) - _thread.start_new_thread(delay_unlock,(self.lock, DELAY)) - if support.verbose: - print() - print("*** Waiting for thread to release the lock "\ - "(approx. %s sec.) ***" % DELAY) - self.lock.acquire() - end_time = int(time.time()) - if support.verbose: - print("done") - self.assertGreaterEqual(end_time - start_time, DELAY, - "Blocking by unconditional acquiring failed.") - - @mock.patch('time.sleep') - def test_acquire_timeout(self, mock_sleep): - """Test invoking acquire() with a positive timeout when the lock is - already acquired. Ensure that time.sleep() is invoked with the given - timeout and that False is returned.""" - - self.lock.acquire() - retval = self.lock.acquire(waitflag=0, timeout=1) - self.assertTrue(mock_sleep.called) - mock_sleep.assert_called_once_with(1) - self.assertEqual(retval, False) - - def test_lock_representation(self): - self.lock.acquire() - self.assertIn("locked", repr(self.lock)) - self.lock.release() - self.assertIn("unlocked", repr(self.lock)) - - -class MiscTests(unittest.TestCase): - """Miscellaneous tests.""" - - def test_exit(self): - self.assertRaises(SystemExit, _thread.exit) - - def test_ident(self): - self.assertIsInstance(_thread.get_ident(), int, - "_thread.get_ident() returned a non-integer") - self.assertGreater(_thread.get_ident(), 0) - - def test_LockType(self): - self.assertIsInstance(_thread.allocate_lock(), _thread.LockType, - "_thread.LockType is not an instance of what " - "is returned by _thread.allocate_lock()") - - def test_set_sentinel(self): - self.assertIsInstance(_thread._set_sentinel(), _thread.LockType, - "_thread._set_sentinel() did not return a " - "LockType instance.") - - def test_interrupt_main(self): - #Calling start_new_thread with a function that executes interrupt_main - # should raise KeyboardInterrupt upon completion. - def call_interrupt(): - _thread.interrupt_main() - - self.assertRaises(KeyboardInterrupt, - _thread.start_new_thread, - call_interrupt, - tuple()) - - def test_interrupt_in_main(self): - self.assertRaises(KeyboardInterrupt, _thread.interrupt_main) - - def test_stack_size_None(self): - retval = _thread.stack_size(None) - self.assertEqual(retval, 0) - - def test_stack_size_not_None(self): - with self.assertRaises(_thread.error) as cm: - _thread.stack_size("") - self.assertEqual(cm.exception.args[0], - "setting thread stack size not supported") - - -class ThreadTests(unittest.TestCase): - """Test thread creation.""" - - def test_arg_passing(self): - #Make sure that parameter passing works. - def arg_tester(queue, arg1=False, arg2=False): - """Use to test _thread.start_new_thread() passes args properly.""" - queue.put((arg1, arg2)) - - testing_queue = queue.Queue(1) - _thread.start_new_thread(arg_tester, (testing_queue, True, True)) - result = testing_queue.get() - self.assertTrue(result[0] and result[1], - "Argument passing for thread creation " - "using tuple failed") - - _thread.start_new_thread( - arg_tester, - tuple(), - {'queue':testing_queue, 'arg1':True, 'arg2':True}) - - result = testing_queue.get() - self.assertTrue(result[0] and result[1], - "Argument passing for thread creation " - "using kwargs failed") - - _thread.start_new_thread( - arg_tester, - (testing_queue, True), - {'arg2':True}) - - result = testing_queue.get() - self.assertTrue(result[0] and result[1], - "Argument passing for thread creation using both tuple" - " and kwargs failed") - - def test_multi_thread_creation(self): - def queue_mark(queue, delay): - time.sleep(delay) - queue.put(_thread.get_ident()) - - thread_count = 5 - testing_queue = queue.Queue(thread_count) - - if support.verbose: - print() - print("*** Testing multiple thread creation " - "(will take approx. %s to %s sec.) ***" % ( - DELAY, thread_count)) - - for count in range(thread_count): - if DELAY: - local_delay = round(random.random(), 1) - else: - local_delay = 0 - _thread.start_new_thread(queue_mark, - (testing_queue, local_delay)) - time.sleep(DELAY) - if support.verbose: - print('done') - self.assertEqual(testing_queue.qsize(), thread_count, - "Not all %s threads executed properly " - "after %s sec." % (thread_count, DELAY)) - - def test_args_not_tuple(self): - """ - Test invoking start_new_thread() with a non-tuple value for "args". - Expect TypeError with a meaningful error message to be raised. - """ - with self.assertRaises(TypeError) as cm: - _thread.start_new_thread(mock.Mock(), []) - self.assertEqual(cm.exception.args[0], "2nd arg must be a tuple") - - def test_kwargs_not_dict(self): - """ - Test invoking start_new_thread() with a non-dict value for "kwargs". - Expect TypeError with a meaningful error message to be raised. - """ - with self.assertRaises(TypeError) as cm: - _thread.start_new_thread(mock.Mock(), tuple(), kwargs=[]) - self.assertEqual(cm.exception.args[0], "3rd arg must be a dict") - - def test_SystemExit(self): - """ - Test invoking start_new_thread() with a function that raises - SystemExit. - The exception should be discarded. - """ - func = mock.Mock(side_effect=SystemExit()) - try: - _thread.start_new_thread(func, tuple()) - except SystemExit: - self.fail("start_new_thread raised SystemExit.") - - @mock.patch('traceback.print_exc') - def test_RaiseException(self, mock_print_exc): - """ - Test invoking start_new_thread() with a function that raises exception. - - The exception should be discarded and the traceback should be printed - via traceback.print_exc() - """ - func = mock.Mock(side_effect=Exception) - _thread.start_new_thread(func, tuple()) - self.assertTrue(mock_print_exc.called) diff --git a/Lib/test/test_dummy_threading.py b/Lib/test/test_dummy_threading.py deleted file mode 100644 index a0c2972..0000000 --- a/Lib/test/test_dummy_threading.py +++ /dev/null @@ -1,60 +0,0 @@ -from test import support -import unittest -import dummy_threading as _threading -import time - -class DummyThreadingTestCase(unittest.TestCase): - - class TestThread(_threading.Thread): - - def run(self): - global running - global sema - global mutex - # Uncomment if testing another module, such as the real 'threading' - # module. - #delay = random.random() * 2 - delay = 0 - if support.verbose: - print('task', self.name, 'will run for', delay, 'sec') - sema.acquire() - mutex.acquire() - running += 1 - if support.verbose: - print(running, 'tasks are running') - mutex.release() - time.sleep(delay) - if support.verbose: - print('task', self.name, 'done') - mutex.acquire() - running -= 1 - if support.verbose: - print(self.name, 'is finished.', running, 'tasks are running') - mutex.release() - sema.release() - - def setUp(self): - self.numtasks = 10 - global sema - sema = _threading.BoundedSemaphore(value=3) - global mutex - mutex = _threading.RLock() - global running - running = 0 - self.threads = [] - - def test_tasks(self): - for i in range(self.numtasks): - t = self.TestThread(name="<thread %d>"%i) - self.threads.append(t) - t.start() - - if support.verbose: - print('waiting for all tasks to complete') - for t in self.threads: - t.join() - if support.verbose: - print('all tasks done') - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py index f97ccc6..621754c 100644 --- a/Lib/test/test_email/test_email.py +++ b/Lib/test/test_email/test_email.py @@ -12,10 +12,7 @@ from io import StringIO, BytesIO from itertools import chain from random import choice from socket import getfqdn -try: - from threading import Thread -except ImportError: - from dummy_threading import Thread +from threading import Thread import email import email.policy diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index ea52de7..e6324d4 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -2,15 +2,12 @@ import enum import inspect import pydoc import unittest +import threading from collections import OrderedDict from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL from test import support -try: - import threading -except ImportError: - threading = None # for pickle tests @@ -1988,7 +1985,6 @@ class TestFlag(unittest.TestCase): d = 6 self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>') - @unittest.skipUnless(threading, 'Threading required for this test.') @support.reap_threads def test_unique_composite(self): # override __eq__ to be identity only @@ -2339,7 +2335,6 @@ class TestIntFlag(unittest.TestCase): for f in Open: self.assertEqual(bool(f.value), bool(f)) - @unittest.skipUnless(threading, 'Threading required for this test.') @support.reap_threads def test_unique_composite(self): # override __eq__ to be identity only diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py index e2fcb2b..889e641 100644 --- a/Lib/test/test_faulthandler.py +++ b/Lib/test/test_faulthandler.py @@ -8,15 +8,11 @@ import sys from test import support from test.support import script_helper, is_android, requires_android_level import tempfile +import threading import unittest from textwrap import dedent try: - import threading - HAVE_THREADS = True -except ImportError: - HAVE_THREADS = False -try: import _testcapi except ImportError: _testcapi = None @@ -154,7 +150,6 @@ class FaultHandlerTests(unittest.TestCase): 3, 'Segmentation fault') - @unittest.skipIf(not HAVE_THREADS, 'need threads') def test_fatal_error_c_thread(self): self.check_fatal_error(""" import faulthandler @@ -231,7 +226,7 @@ class FaultHandlerTests(unittest.TestCase): 2, 'xyz') - @unittest.skipIf(sys.platform.startswith('openbsd') and HAVE_THREADS, + @unittest.skipIf(sys.platform.startswith('openbsd'), "Issue #12868: sigaltstack() doesn't work on " "OpenBSD if Python is compiled with pthread") @unittest.skipIf(not hasattr(faulthandler, '_stack_overflow'), @@ -456,7 +451,6 @@ class FaultHandlerTests(unittest.TestCase): self.assertEqual(trace, expected) self.assertEqual(exitcode, 0) - @unittest.skipIf(not HAVE_THREADS, 'need threads') def check_dump_traceback_threads(self, filename): """ Call explicitly dump_traceback(all_threads=True) and check the output. diff --git a/Lib/test/test_fork1.py b/Lib/test/test_fork1.py index da46fe5..9ca9724 100644 --- a/Lib/test/test_fork1.py +++ b/Lib/test/test_fork1.py @@ -5,6 +5,7 @@ import _imp as imp import os import signal import sys +import threading import time import unittest @@ -12,7 +13,6 @@ from test.fork_wait import ForkWait from test.support import (reap_children, get_attribute, import_module, verbose) -threading = import_module('threading') # Skip test if fork does not exist. get_attribute(os, 'fork') diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index 24ea382..151e091 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -10,6 +10,7 @@ import socket import io import errno import os +import threading import time try: import ssl @@ -19,7 +20,6 @@ except ImportError: from unittest import TestCase, skipUnless from test import support from test.support import HOST, HOSTv6 -threading = support.import_module('threading') TIMEOUT = 3 # the dummy data returned by server over the data channel when diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 3acfb92..f7a1166 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -8,15 +8,12 @@ import pickle from random import choice import sys from test import support +import threading import time import unittest import unittest.mock from weakref import proxy import contextlib -try: - import threading -except ImportError: - threading = None import functools @@ -1406,7 +1403,6 @@ class TestLRU: for attr in self.module.WRAPPER_ASSIGNMENTS: self.assertEqual(getattr(g, attr), getattr(f, attr)) - @unittest.skipUnless(threading, 'This test requires threading.') def test_lru_cache_threaded(self): n, m = 5, 11 def orig(x, y): @@ -1455,7 +1451,6 @@ class TestLRU: finally: sys.setswitchinterval(orig_si) - @unittest.skipUnless(threading, 'This test requires threading.') def test_lru_cache_threaded2(self): # Simultaneous call with the same arguments n, m = 5, 7 @@ -1483,7 +1478,6 @@ class TestLRU: pause.reset() self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) - @unittest.skipUnless(threading, 'This test requires threading.') def test_lru_cache_threaded3(self): @self.module.lru_cache(maxsize=2) def f(x): diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py index e727499..914efec 100644 --- a/Lib/test/test_gc.py +++ b/Lib/test/test_gc.py @@ -8,11 +8,7 @@ import sys import time import gc import weakref - -try: - import threading -except ImportError: - threading = None +import threading try: from _testcapi import with_tp_del @@ -352,7 +348,6 @@ class GCTests(unittest.TestCase): v = {1: v, 2: Ouch()} gc.disable() - @unittest.skipUnless(threading, "test meaningless on builds without threads") def test_trashcan_threads(self): # Issue #13992: trashcan mechanism should be thread-safe NESTING = 60 diff --git a/Lib/test/test_gdb.py b/Lib/test/test_gdb.py index 46736f6..9e0eaea 100644 --- a/Lib/test/test_gdb.py +++ b/Lib/test/test_gdb.py @@ -12,12 +12,6 @@ import sysconfig import textwrap import unittest -# Is this Python configured to support threads? -try: - import _thread -except ImportError: - _thread = None - from test import support from test.support import run_unittest, findfile, python_is_optimized @@ -755,8 +749,6 @@ Traceback \(most recent call first\): foo\(1, 2, 3\) ''') - @unittest.skipUnless(_thread, - "Python was compiled without thread support") def test_threads(self): 'Verify that "py-bt" indicates threads that are waiting for the GIL' cmd = ''' @@ -794,8 +786,6 @@ id(42) # Some older versions of gdb will fail with # "Cannot find new threads: generic error" # unless we add LD_PRELOAD=PATH-TO-libpthread.so.1 as a workaround - @unittest.skipUnless(_thread, - "Python was compiled without thread support") def test_gc(self): 'Verify that "py-bt" indicates if a thread is garbage-collecting' cmd = ('from gc import collect\n' @@ -822,8 +812,6 @@ id(42) # Some older versions of gdb will fail with # "Cannot find new threads: generic error" # unless we add LD_PRELOAD=PATH-TO-libpthread.so.1 as a workaround - @unittest.skipUnless(_thread, - "Python was compiled without thread support") def test_pycfunction(self): 'Verify that "py-bt" displays invocations of PyCFunction instances' # Tested function must not be defined with METH_NOARGS or METH_O, diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py index f748b46..5c8f090 100644 --- a/Lib/test/test_hashlib.py +++ b/Lib/test/test_hashlib.py @@ -12,10 +12,7 @@ import hashlib import itertools import os import sys -try: - import threading -except ImportError: - threading = None +import threading import unittest import warnings from test import support @@ -738,7 +735,6 @@ class HashLibTestCase(unittest.TestCase): m = hashlib.md5(b'x' * gil_minsize) self.assertEqual(m.hexdigest(), 'cfb767f225d58469c5de3632a8803958') - @unittest.skipUnless(threading, 'Threading required for this test.') @support.reap_threads def test_threaded_hashing(self): # Updating the same hash object from several threads at once diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 8cddcdc..20e6f66 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -22,12 +22,13 @@ import urllib.parse import tempfile import time import datetime +import threading from unittest import mock from io import BytesIO import unittest from test import support -threading = support.import_module('threading') + class NoLogRequestHandler: def log_message(self, *args): diff --git a/Lib/test/test_idle.py b/Lib/test/test_idle.py index da05da5..b7ef70d 100644 --- a/Lib/test/test_idle.py +++ b/Lib/test/test_idle.py @@ -3,7 +3,6 @@ from test.support import import_module # Skip test if _thread or _tkinter wasn't built, if idlelib is missing, # or if tcl/tk is not the 8.5+ needed for ttk widgets. -import_module('threading') # imported by PyShell, imports _thread tk = import_module('tkinter') # imports _tkinter if tk.TkVersion < 8.5: raise unittest.SkipTest("IDLE requires tk 8.5 or later.") diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py index 7c12438..132c586 100644 --- a/Lib/test/test_imaplib.py +++ b/Lib/test/test_imaplib.py @@ -1,8 +1,4 @@ from test import support -# If we end up with a significant number of tests that don't require -# threading, this test module should be split. Right now we skip -# them all if we don't have threading. -threading = support.import_module('threading') from contextlib import contextmanager import imaplib @@ -10,6 +6,7 @@ import os.path import socketserver import time import calendar +import threading from test.support import (reap_threads, verbose, transient_internet, run_with_tz, run_with_locale, cpython_only) diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py index 6f35f49..3bfcb09 100644 --- a/Lib/test/test_imp.py +++ b/Lib/test/test_imp.py @@ -1,7 +1,3 @@ -try: - import _thread -except ImportError: - _thread = None import importlib import importlib.util import os @@ -23,7 +19,6 @@ def requires_load_dynamic(meth): 'imp.load_dynamic() required')(meth) -@unittest.skipIf(_thread is None, '_thread module is required') class LockTests(unittest.TestCase): """Very basic test of import lock functions.""" diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py index dbce9c2..d86172a 100644 --- a/Lib/test/test_importlib/test_locks.py +++ b/Lib/test/test_importlib/test_locks.py @@ -3,134 +3,112 @@ from . import util as test_util init = test_util.import_importlib('importlib') import sys +import threading import unittest import weakref from test import support - -try: - import threading -except ImportError: - threading = None -else: - from test import lock_tests - -if threading is not None: - class ModuleLockAsRLockTests: - locktype = classmethod(lambda cls: cls.LockType("some_lock")) - - # _is_owned() unsupported - test__is_owned = None - # acquire(blocking=False) unsupported - test_try_acquire = None - test_try_acquire_contended = None - # `with` unsupported - test_with = None - # acquire(timeout=...) unsupported - test_timeout = None - # _release_save() unsupported - test_release_save_unacquired = None - # lock status in repr unsupported - test_repr = None - test_locked_repr = None - - LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock - for kind, splitinit in init.items()} - - (Frozen_ModuleLockAsRLockTests, - Source_ModuleLockAsRLockTests - ) = test_util.test_both(ModuleLockAsRLockTests, lock_tests.RLockTests, - LockType=LOCK_TYPES) -else: - LOCK_TYPES = {} - - class Frozen_ModuleLockAsRLockTests(unittest.TestCase): - pass - - class Source_ModuleLockAsRLockTests(unittest.TestCase): - pass - - -if threading is not None: - class DeadlockAvoidanceTests: - - def setUp(self): +from test import lock_tests + + +class ModuleLockAsRLockTests: + locktype = classmethod(lambda cls: cls.LockType("some_lock")) + + # _is_owned() unsupported + test__is_owned = None + # acquire(blocking=False) unsupported + test_try_acquire = None + test_try_acquire_contended = None + # `with` unsupported + test_with = None + # acquire(timeout=...) unsupported + test_timeout = None + # _release_save() unsupported + test_release_save_unacquired = None + # lock status in repr unsupported + test_repr = None + test_locked_repr = None + +LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock + for kind, splitinit in init.items()} + +(Frozen_ModuleLockAsRLockTests, + Source_ModuleLockAsRLockTests + ) = test_util.test_both(ModuleLockAsRLockTests, lock_tests.RLockTests, + LockType=LOCK_TYPES) + + +class DeadlockAvoidanceTests: + + def setUp(self): + try: + self.old_switchinterval = sys.getswitchinterval() + support.setswitchinterval(0.000001) + except AttributeError: + self.old_switchinterval = None + + def tearDown(self): + if self.old_switchinterval is not None: + sys.setswitchinterval(self.old_switchinterval) + + def run_deadlock_avoidance_test(self, create_deadlock): + NLOCKS = 10 + locks = [self.LockType(str(i)) for i in range(NLOCKS)] + pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)] + if create_deadlock: + NTHREADS = NLOCKS + else: + NTHREADS = NLOCKS - 1 + barrier = threading.Barrier(NTHREADS) + results = [] + + def _acquire(lock): + """Try to acquire the lock. Return True on success, + False on deadlock.""" try: - self.old_switchinterval = sys.getswitchinterval() - support.setswitchinterval(0.000001) - except AttributeError: - self.old_switchinterval = None - - def tearDown(self): - if self.old_switchinterval is not None: - sys.setswitchinterval(self.old_switchinterval) - - def run_deadlock_avoidance_test(self, create_deadlock): - NLOCKS = 10 - locks = [self.LockType(str(i)) for i in range(NLOCKS)] - pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)] - if create_deadlock: - NTHREADS = NLOCKS + lock.acquire() + except self.DeadlockError: + return False else: - NTHREADS = NLOCKS - 1 - barrier = threading.Barrier(NTHREADS) - results = [] - - def _acquire(lock): - """Try to acquire the lock. Return True on success, - False on deadlock.""" - try: - lock.acquire() - except self.DeadlockError: - return False - else: - return True - - def f(): - a, b = pairs.pop() - ra = _acquire(a) - barrier.wait() - rb = _acquire(b) - results.append((ra, rb)) - if rb: - b.release() - if ra: - a.release() - lock_tests.Bunch(f, NTHREADS).wait_for_finished() - self.assertEqual(len(results), NTHREADS) - return results - - def test_deadlock(self): - results = self.run_deadlock_avoidance_test(True) - # At least one of the threads detected a potential deadlock on its - # second acquire() call. It may be several of them, because the - # deadlock avoidance mechanism is conservative. - nb_deadlocks = results.count((True, False)) - self.assertGreaterEqual(nb_deadlocks, 1) - self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks) - - def test_no_deadlock(self): - results = self.run_deadlock_avoidance_test(False) - self.assertEqual(results.count((True, False)), 0) - self.assertEqual(results.count((True, True)), len(results)) - - - DEADLOCK_ERRORS = {kind: splitinit._bootstrap._DeadlockError - for kind, splitinit in init.items()} - - (Frozen_DeadlockAvoidanceTests, - Source_DeadlockAvoidanceTests - ) = test_util.test_both(DeadlockAvoidanceTests, - LockType=LOCK_TYPES, - DeadlockError=DEADLOCK_ERRORS) -else: - DEADLOCK_ERRORS = {} - - class Frozen_DeadlockAvoidanceTests(unittest.TestCase): - pass - - class Source_DeadlockAvoidanceTests(unittest.TestCase): - pass + return True + + def f(): + a, b = pairs.pop() + ra = _acquire(a) + barrier.wait() + rb = _acquire(b) + results.append((ra, rb)) + if rb: + b.release() + if ra: + a.release() + lock_tests.Bunch(f, NTHREADS).wait_for_finished() + self.assertEqual(len(results), NTHREADS) + return results + + def test_deadlock(self): + results = self.run_deadlock_avoidance_test(True) + # At least one of the threads detected a potential deadlock on its + # second acquire() call. It may be several of them, because the + # deadlock avoidance mechanism is conservative. + nb_deadlocks = results.count((True, False)) + self.assertGreaterEqual(nb_deadlocks, 1) + self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks) + + def test_no_deadlock(self): + results = self.run_deadlock_avoidance_test(False) + self.assertEqual(results.count((True, False)), 0) + self.assertEqual(results.count((True, True)), len(results)) + + +DEADLOCK_ERRORS = {kind: splitinit._bootstrap._DeadlockError + for kind, splitinit in init.items()} + +(Frozen_DeadlockAvoidanceTests, + Source_DeadlockAvoidanceTests + ) = test_util.test_both(DeadlockAvoidanceTests, + LockType=LOCK_TYPES, + DeadlockError=DEADLOCK_ERRORS) class LifetimeTests: diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index 48270c8..d4685dd 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -28,6 +28,7 @@ import pickle import random import signal import sys +import threading import time import unittest import warnings @@ -40,10 +41,6 @@ from test.support.script_helper import assert_python_ok, run_python_until_end import codecs import io # C implementation of io import _pyio as pyio # Python implementation of io -try: - import threading -except ImportError: - threading = None try: import ctypes @@ -443,8 +440,6 @@ class IOTest(unittest.TestCase): (self.BytesIO, "rws"), (self.StringIO, "rws"), ) for [test, abilities] in tests: - if test is pipe_writer and not threading: - continue # Skip subtest that uses a background thread with self.subTest(test), test() as obj: readable = "r" in abilities self.assertEqual(obj.readable(), readable) @@ -1337,7 +1332,6 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests): self.assertEqual(b"abcdefg", bufio.read()) - @unittest.skipUnless(threading, 'Threading required for this test.') @support.requires_resource('cpu') def test_threads(self): try: @@ -1664,7 +1658,6 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests): with self.open(support.TESTFN, "rb", buffering=0) as f: self.assertEqual(f.read(), b"abc") - @unittest.skipUnless(threading, 'Threading required for this test.') @support.requires_resource('cpu') def test_threads(self): try: @@ -3053,7 +3046,6 @@ class TextIOWrapperTest(unittest.TestCase): self.assertEqual(f.errors, "replace") @support.no_tracing - @unittest.skipUnless(threading, 'Threading required for this test.') def test_threads_write(self): # Issue6750: concurrent writes could duplicate data event = threading.Event() @@ -3804,7 +3796,6 @@ class CMiscIOTest(MiscIOTest): b = bytearray(2) self.assertRaises(ValueError, bufio.readinto, b) - @unittest.skipUnless(threading, 'Threading required for this test.') def check_daemon_threads_shutdown_deadlock(self, stream_name): # Issue #23309: deadlocks at shutdown should be avoided when a # daemon thread and the main thread both write to a file. @@ -3868,7 +3859,6 @@ class SignalsTest(unittest.TestCase): def alarm_interrupt(self, sig, frame): 1/0 - @unittest.skipUnless(threading, 'Threading required for this test.') def check_interrupted_write(self, item, bytes, **fdopen_kwargs): """Check that a partial write, when it gets interrupted, properly invokes the signal handler, and bubbles up the exception raised @@ -3990,7 +3980,6 @@ class SignalsTest(unittest.TestCase): self.check_interrupted_read_retry(lambda x: x, mode="r") - @unittest.skipUnless(threading, 'Threading required for this test.') def check_interrupted_write_retry(self, item, **fdopen_kwargs): """Check that a buffered write, when it gets interrupted (either returning a partial result or EINTR), properly invokes the signal diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index f4aef9f..9c3816a 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -42,22 +42,19 @@ import tempfile from test.support.script_helper import assert_python_ok from test import support import textwrap +import threading import time import unittest import warnings import weakref -try: - import threading - # The following imports are needed only for tests which - # require threading - import asyncore - from http.server import HTTPServer, BaseHTTPRequestHandler - import smtpd - from urllib.parse import urlparse, parse_qs - from socketserver import (ThreadingUDPServer, DatagramRequestHandler, - ThreadingTCPServer, StreamRequestHandler) -except ImportError: - threading = None + +import asyncore +from http.server import HTTPServer, BaseHTTPRequestHandler +import smtpd +from urllib.parse import urlparse, parse_qs +from socketserver import (ThreadingUDPServer, DatagramRequestHandler, + ThreadingTCPServer, StreamRequestHandler) + try: import win32evtlog, win32evtlogutil, pywintypes except ImportError: @@ -625,7 +622,6 @@ class HandlerTest(BaseTest): os.unlink(fn) @unittest.skipIf(os.name == 'nt', 'WatchedFileHandler not appropriate for Windows.') - @unittest.skipUnless(threading, 'Threading required for this test.') def test_race(self): # Issue #14632 refers. def remove_loop(fname, tries): @@ -719,276 +715,274 @@ class StreamHandlerTest(BaseTest): # -- The following section could be moved into a server_helper.py module # -- if it proves to be of wider utility than just test_logging -if threading: - class TestSMTPServer(smtpd.SMTPServer): +class TestSMTPServer(smtpd.SMTPServer): + """ + This class implements a test SMTP server. + + :param addr: A (host, port) tuple which the server listens on. + You can specify a port value of zero: the server's + *port* attribute will hold the actual port number + used, which can be used in client connections. + :param handler: A callable which will be called to process + incoming messages. The handler will be passed + the client address tuple, who the message is from, + a list of recipients and the message data. + :param poll_interval: The interval, in seconds, used in the underlying + :func:`select` or :func:`poll` call by + :func:`asyncore.loop`. + :param sockmap: A dictionary which will be used to hold + :class:`asyncore.dispatcher` instances used by + :func:`asyncore.loop`. This avoids changing the + :mod:`asyncore` module's global state. + """ + + def __init__(self, addr, handler, poll_interval, sockmap): + smtpd.SMTPServer.__init__(self, addr, None, map=sockmap, + decode_data=True) + self.port = self.socket.getsockname()[1] + self._handler = handler + self._thread = None + self.poll_interval = poll_interval + + def process_message(self, peer, mailfrom, rcpttos, data): + """ + Delegates to the handler passed in to the server's constructor. + + Typically, this will be a test case method. + :param peer: The client (host, port) tuple. + :param mailfrom: The address of the sender. + :param rcpttos: The addresses of the recipients. + :param data: The message. """ - This class implements a test SMTP server. - - :param addr: A (host, port) tuple which the server listens on. - You can specify a port value of zero: the server's - *port* attribute will hold the actual port number - used, which can be used in client connections. - :param handler: A callable which will be called to process - incoming messages. The handler will be passed - the client address tuple, who the message is from, - a list of recipients and the message data. + self._handler(peer, mailfrom, rcpttos, data) + + def start(self): + """ + Start the server running on a separate daemon thread. + """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.setDaemon(True) + t.start() + + def serve_forever(self, poll_interval): + """ + Run the :mod:`asyncore` loop until normal termination + conditions arise. :param poll_interval: The interval, in seconds, used in the underlying :func:`select` or :func:`poll` call by :func:`asyncore.loop`. - :param sockmap: A dictionary which will be used to hold - :class:`asyncore.dispatcher` instances used by - :func:`asyncore.loop`. This avoids changing the - :mod:`asyncore` module's global state. """ + try: + asyncore.loop(poll_interval, map=self._map) + except OSError: + # On FreeBSD 8, closing the server repeatably + # raises this error. We swallow it if the + # server has been closed. + if self.connected or self.accepting: + raise - def __init__(self, addr, handler, poll_interval, sockmap): - smtpd.SMTPServer.__init__(self, addr, None, map=sockmap, - decode_data=True) - self.port = self.socket.getsockname()[1] - self._handler = handler - self._thread = None - self.poll_interval = poll_interval + def stop(self, timeout=None): + """ + Stop the thread by closing the server instance. + Wait for the server thread to terminate. - def process_message(self, peer, mailfrom, rcpttos, data): - """ - Delegates to the handler passed in to the server's constructor. + :param timeout: How long to wait for the server thread + to terminate. + """ + self.close() + self._thread.join(timeout) + asyncore.close_all(map=self._map, ignore_all=True) - Typically, this will be a test case method. - :param peer: The client (host, port) tuple. - :param mailfrom: The address of the sender. - :param rcpttos: The addresses of the recipients. - :param data: The message. - """ - self._handler(peer, mailfrom, rcpttos, data) + alive = self._thread.is_alive() + self._thread = None + if alive: + self.fail("join() timed out") - def start(self): - """ - Start the server running on a separate daemon thread. - """ - self._thread = t = threading.Thread(target=self.serve_forever, - args=(self.poll_interval,)) - t.setDaemon(True) - t.start() +class ControlMixin(object): + """ + This mixin is used to start a server on a separate thread, and + shut it down programmatically. Request handling is simplified - instead + of needing to derive a suitable RequestHandler subclass, you just + provide a callable which will be passed each received request to be + processed. + + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. This handler is called on the + server thread, effectively meaning that requests are + processed serially. While not quite Web scale ;-), + this should be fine for testing applications. + :param poll_interval: The polling interval in seconds. + """ + def __init__(self, handler, poll_interval): + self._thread = None + self.poll_interval = poll_interval + self._handler = handler + self.ready = threading.Event() - def serve_forever(self, poll_interval): - """ - Run the :mod:`asyncore` loop until normal termination - conditions arise. - :param poll_interval: The interval, in seconds, used in the underlying - :func:`select` or :func:`poll` call by - :func:`asyncore.loop`. - """ - try: - asyncore.loop(poll_interval, map=self._map) - except OSError: - # On FreeBSD 8, closing the server repeatably - # raises this error. We swallow it if the - # server has been closed. - if self.connected or self.accepting: - raise - - def stop(self, timeout=None): - """ - Stop the thread by closing the server instance. - Wait for the server thread to terminate. + def start(self): + """ + Create a daemon thread to run the server, and start it. + """ + self._thread = t = threading.Thread(target=self.serve_forever, + args=(self.poll_interval,)) + t.setDaemon(True) + t.start() - :param timeout: How long to wait for the server thread - to terminate. - """ - self.close() - self._thread.join(timeout) - asyncore.close_all(map=self._map, ignore_all=True) + def serve_forever(self, poll_interval): + """ + Run the server. Set the ready flag before entering the + service loop. + """ + self.ready.set() + super(ControlMixin, self).serve_forever(poll_interval) + def stop(self, timeout=None): + """ + Tell the server thread to stop, and wait for it to do so. + + :param timeout: How long to wait for the server thread + to terminate. + """ + self.shutdown() + if self._thread is not None: + self._thread.join(timeout) alive = self._thread.is_alive() self._thread = None if alive: self.fail("join() timed out") + self.server_close() + self.ready.clear() - class ControlMixin(object): - """ - This mixin is used to start a server on a separate thread, and - shut it down programmatically. Request handling is simplified - instead - of needing to derive a suitable RequestHandler subclass, you just - provide a callable which will be passed each received request to be - processed. - - :param handler: A handler callable which will be called with a - single parameter - the request - in order to - process the request. This handler is called on the - server thread, effectively meaning that requests are - processed serially. While not quite Web scale ;-), - this should be fine for testing applications. - :param poll_interval: The polling interval in seconds. - """ - def __init__(self, handler, poll_interval): - self._thread = None - self.poll_interval = poll_interval - self._handler = handler - self.ready = threading.Event() +class TestHTTPServer(ControlMixin, HTTPServer): + """ + An HTTP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval in seconds. + :param log: Pass ``True`` to enable log messages. + """ + def __init__(self, addr, handler, poll_interval=0.5, + log=False, sslctx=None): + class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler): + def __getattr__(self, name, default=None): + if name.startswith('do_'): + return self.process_request + raise AttributeError(name) + + def process_request(self): + self.server._handler(self) + + def log_message(self, format, *args): + if log: + super(DelegatingHTTPRequestHandler, + self).log_message(format, *args) + HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler) + ControlMixin.__init__(self, handler, poll_interval) + self.sslctx = sslctx + + def get_request(self): + try: + sock, addr = self.socket.accept() + if self.sslctx: + sock = self.sslctx.wrap_socket(sock, server_side=True) + except OSError as e: + # socket errors are silenced by the caller, print them here + sys.stderr.write("Got an error:\n%s\n" % e) + raise + return sock, addr - def start(self): - """ - Create a daemon thread to run the server, and start it. - """ - self._thread = t = threading.Thread(target=self.serve_forever, - args=(self.poll_interval,)) - t.setDaemon(True) - t.start() +class TestTCPServer(ControlMixin, ThreadingTCPServer): + """ + A TCP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a single + parameter - the request - in order to process the request. + :param poll_interval: The polling interval in seconds. + :bind_and_activate: If True (the default), binds the server and starts it + listening. If False, you need to call + :meth:`server_bind` and :meth:`server_activate` at + some later time before calling :meth:`start`, so that + the server will set up the socket and listen on it. + """ - def serve_forever(self, poll_interval): - """ - Run the server. Set the ready flag before entering the - service loop. - """ - self.ready.set() - super(ControlMixin, self).serve_forever(poll_interval) + allow_reuse_address = True - def stop(self, timeout=None): - """ - Tell the server thread to stop, and wait for it to do so. + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingTCPRequestHandler(StreamRequestHandler): - :param timeout: How long to wait for the server thread - to terminate. - """ - self.shutdown() - if self._thread is not None: - self._thread.join(timeout) - alive = self._thread.is_alive() - self._thread = None - if alive: - self.fail("join() timed out") - self.server_close() - self.ready.clear() - - class TestHTTPServer(ControlMixin, HTTPServer): - """ - An HTTP server which is controllable using :class:`ControlMixin`. - - :param addr: A tuple with the IP address and port to listen on. - :param handler: A handler callable which will be called with a - single parameter - the request - in order to - process the request. - :param poll_interval: The polling interval in seconds. - :param log: Pass ``True`` to enable log messages. - """ - def __init__(self, addr, handler, poll_interval=0.5, - log=False, sslctx=None): - class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler): - def __getattr__(self, name, default=None): - if name.startswith('do_'): - return self.process_request - raise AttributeError(name) - - def process_request(self): - self.server._handler(self) - - def log_message(self, format, *args): - if log: - super(DelegatingHTTPRequestHandler, - self).log_message(format, *args) - HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler) - ControlMixin.__init__(self, handler, poll_interval) - self.sslctx = sslctx - - def get_request(self): - try: - sock, addr = self.socket.accept() - if self.sslctx: - sock = self.sslctx.wrap_socket(sock, server_side=True) - except OSError as e: - # socket errors are silenced by the caller, print them here - sys.stderr.write("Got an error:\n%s\n" % e) - raise - return sock, addr + def handle(self): + self.server._handler(self) + ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) - class TestTCPServer(ControlMixin, ThreadingTCPServer): - """ - A TCP server which is controllable using :class:`ControlMixin`. - - :param addr: A tuple with the IP address and port to listen on. - :param handler: A handler callable which will be called with a single - parameter - the request - in order to process the request. - :param poll_interval: The polling interval in seconds. - :bind_and_activate: If True (the default), binds the server and starts it - listening. If False, you need to call - :meth:`server_bind` and :meth:`server_activate` at - some later time before calling :meth:`start`, so that - the server will set up the socket and listen on it. - """ + def server_bind(self): + super(TestTCPServer, self).server_bind() + self.port = self.socket.getsockname()[1] - allow_reuse_address = True +class TestUDPServer(ControlMixin, ThreadingUDPServer): + """ + A UDP server which is controllable using :class:`ControlMixin`. + + :param addr: A tuple with the IP address and port to listen on. + :param handler: A handler callable which will be called with a + single parameter - the request - in order to + process the request. + :param poll_interval: The polling interval for shutdown requests, + in seconds. + :bind_and_activate: If True (the default), binds the server and + starts it listening. If False, you need to + call :meth:`server_bind` and + :meth:`server_activate` at some later time + before calling :meth:`start`, so that the server will + set up the socket and listen on it. + """ + def __init__(self, addr, handler, poll_interval=0.5, + bind_and_activate=True): + class DelegatingUDPRequestHandler(DatagramRequestHandler): - def __init__(self, addr, handler, poll_interval=0.5, - bind_and_activate=True): - class DelegatingTCPRequestHandler(StreamRequestHandler): + def handle(self): + self.server._handler(self) - def handle(self): - self.server._handler(self) - ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler, - bind_and_activate) - ControlMixin.__init__(self, handler, poll_interval) + def finish(self): + data = self.wfile.getvalue() + if data: + try: + super(DelegatingUDPRequestHandler, self).finish() + except OSError: + if not self.server._closed: + raise - def server_bind(self): - super(TestTCPServer, self).server_bind() - self.port = self.socket.getsockname()[1] + ThreadingUDPServer.__init__(self, addr, + DelegatingUDPRequestHandler, + bind_and_activate) + ControlMixin.__init__(self, handler, poll_interval) + self._closed = False - class TestUDPServer(ControlMixin, ThreadingUDPServer): - """ - A UDP server which is controllable using :class:`ControlMixin`. - - :param addr: A tuple with the IP address and port to listen on. - :param handler: A handler callable which will be called with a - single parameter - the request - in order to - process the request. - :param poll_interval: The polling interval for shutdown requests, - in seconds. - :bind_and_activate: If True (the default), binds the server and - starts it listening. If False, you need to - call :meth:`server_bind` and - :meth:`server_activate` at some later time - before calling :meth:`start`, so that the server will - set up the socket and listen on it. - """ - def __init__(self, addr, handler, poll_interval=0.5, - bind_and_activate=True): - class DelegatingUDPRequestHandler(DatagramRequestHandler): - - def handle(self): - self.server._handler(self) - - def finish(self): - data = self.wfile.getvalue() - if data: - try: - super(DelegatingUDPRequestHandler, self).finish() - except OSError: - if not self.server._closed: - raise - - ThreadingUDPServer.__init__(self, addr, - DelegatingUDPRequestHandler, - bind_and_activate) - ControlMixin.__init__(self, handler, poll_interval) - self._closed = False - - def server_bind(self): - super(TestUDPServer, self).server_bind() - self.port = self.socket.getsockname()[1] - - def server_close(self): - super(TestUDPServer, self).server_close() - self._closed = True + def server_bind(self): + super(TestUDPServer, self).server_bind() + self.port = self.socket.getsockname()[1] - if hasattr(socket, "AF_UNIX"): - class TestUnixStreamServer(TestTCPServer): - address_family = socket.AF_UNIX + def server_close(self): + super(TestUDPServer, self).server_close() + self._closed = True - class TestUnixDatagramServer(TestUDPServer): - address_family = socket.AF_UNIX +if hasattr(socket, "AF_UNIX"): + class TestUnixStreamServer(TestTCPServer): + address_family = socket.AF_UNIX + + class TestUnixDatagramServer(TestUDPServer): + address_family = socket.AF_UNIX # - end of server_helper section -@unittest.skipUnless(threading, 'Threading required for this test.') class SMTPHandlerTest(BaseTest): TIMEOUT = 8.0 @@ -1472,14 +1466,12 @@ class ConfigFileTest(BaseTest): @unittest.skipIf(True, "FIXME: bpo-30830") -@unittest.skipUnless(threading, 'Threading required for this test.') class SocketHandlerTest(BaseTest): """Test for SocketHandler objects.""" - if threading: - server_class = TestTCPServer - address = ('localhost', 0) + server_class = TestTCPServer + address = ('localhost', 0) def setUp(self): """Set up a TCP server to receive log messages, and a SocketHandler @@ -1573,12 +1565,11 @@ def _get_temp_domain_socket(): @unittest.skipIf(True, "FIXME: bpo-30830") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") -@unittest.skipUnless(threading, 'Threading required for this test.') class UnixSocketHandlerTest(SocketHandlerTest): """Test for SocketHandler with unix sockets.""" - if threading and hasattr(socket, "AF_UNIX"): + if hasattr(socket, "AF_UNIX"): server_class = TestUnixStreamServer def setUp(self): @@ -1591,14 +1582,12 @@ class UnixSocketHandlerTest(SocketHandlerTest): support.unlink(self.address) @unittest.skipIf(True, "FIXME: bpo-30830") -@unittest.skipUnless(threading, 'Threading required for this test.') class DatagramHandlerTest(BaseTest): """Test for DatagramHandler.""" - if threading: - server_class = TestUDPServer - address = ('localhost', 0) + server_class = TestUDPServer + address = ('localhost', 0) def setUp(self): """Set up a UDP server to receive log messages, and a DatagramHandler @@ -1659,12 +1648,11 @@ class DatagramHandlerTest(BaseTest): @unittest.skipIf(True, "FIXME: bpo-30830") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") -@unittest.skipUnless(threading, 'Threading required for this test.') class UnixDatagramHandlerTest(DatagramHandlerTest): """Test for DatagramHandler using Unix sockets.""" - if threading and hasattr(socket, "AF_UNIX"): + if hasattr(socket, "AF_UNIX"): server_class = TestUnixDatagramServer def setUp(self): @@ -1676,14 +1664,12 @@ class UnixDatagramHandlerTest(DatagramHandlerTest): DatagramHandlerTest.tearDown(self) support.unlink(self.address) -@unittest.skipUnless(threading, 'Threading required for this test.') class SysLogHandlerTest(BaseTest): """Test for SysLogHandler using UDP.""" - if threading: - server_class = TestUDPServer - address = ('localhost', 0) + server_class = TestUDPServer + address = ('localhost', 0) def setUp(self): """Set up a UDP server to receive log messages, and a SysLogHandler @@ -1747,12 +1733,11 @@ class SysLogHandlerTest(BaseTest): @unittest.skipIf(True, "FIXME: bpo-30830") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") -@unittest.skipUnless(threading, 'Threading required for this test.') class UnixSysLogHandlerTest(SysLogHandlerTest): """Test for SysLogHandler with Unix sockets.""" - if threading and hasattr(socket, "AF_UNIX"): + if hasattr(socket, "AF_UNIX"): server_class = TestUnixDatagramServer def setUp(self): @@ -1767,7 +1752,6 @@ class UnixSysLogHandlerTest(SysLogHandlerTest): @unittest.skipIf(True, "FIXME: bpo-30830") @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 support required for this test.') -@unittest.skipUnless(threading, 'Threading required for this test.') class IPv6SysLogHandlerTest(SysLogHandlerTest): """Test for SysLogHandler with IPv6 host.""" @@ -1783,7 +1767,6 @@ class IPv6SysLogHandlerTest(SysLogHandlerTest): self.server_class.address_family = socket.AF_INET super(IPv6SysLogHandlerTest, self).tearDown() -@unittest.skipUnless(threading, 'Threading required for this test.') class HTTPHandlerTest(BaseTest): """Test for HTTPHandler.""" @@ -2892,7 +2875,6 @@ class ConfigDictTest(BaseTest): # listen() uses ConfigSocketReceiver which is based # on socketserver.ThreadingTCPServer @unittest.skipIf(True, "FIXME: bpo-30830") - @unittest.skipUnless(threading, 'listen() needs threading to work') def setup_via_listener(self, text, verify=None): text = text.encode("utf-8") # Ask for a randomly assigned port (by using port 0) @@ -2923,7 +2905,6 @@ class ConfigDictTest(BaseTest): if t.is_alive(): self.fail("join() timed out") - @unittest.skipUnless(threading, 'Threading required for this test.') def test_listen_config_10_ok(self): with support.captured_stdout() as output: self.setup_via_listener(json.dumps(self.config10)) @@ -2943,7 +2924,6 @@ class ConfigDictTest(BaseTest): ('ERROR', '4'), ], stream=output) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_listen_config_1_ok(self): with support.captured_stdout() as output: self.setup_via_listener(textwrap.dedent(ConfigFileTest.config1)) @@ -2958,7 +2938,6 @@ class ConfigDictTest(BaseTest): # Original logger output is empty. self.assert_log_lines([]) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_listen_verify(self): def verify_fail(stuff): @@ -3713,9 +3692,8 @@ class LogRecordTest(BaseTest): def test_optional(self): r = logging.makeLogRecord({}) NOT_NONE = self.assertIsNotNone - if threading: - NOT_NONE(r.thread) - NOT_NONE(r.threadName) + NOT_NONE(r.thread) + NOT_NONE(r.threadName) NOT_NONE(r.process) NOT_NONE(r.processName) log_threads = logging.logThreads diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index e2cd36a..bb780bf 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -6,6 +6,8 @@ import unittest import functools import contextlib import os.path +import threading + from test import support from nntplib import NNTP, GroupInfo import nntplib @@ -14,10 +16,7 @@ try: import ssl except ImportError: ssl = None -try: - import threading -except ImportError: - threading = None + TIMEOUT = 30 certfile = os.path.join(os.path.dirname(__file__), 'keycert3.pem') @@ -1520,7 +1519,7 @@ class MockSslTests(MockSocketTests): def nntp_class(*pos, **kw): return nntplib.NNTP_SSL(*pos, ssl_context=bypass_context, **kw) -@unittest.skipUnless(threading, 'requires multithreading') + class LocalServerTests(unittest.TestCase): def setUp(self): sock = socket.socket() diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 234f701..c3c8238 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -22,15 +22,13 @@ import stat import subprocess import sys import sysconfig +import threading import time import unittest import uuid import warnings from test import support -try: - import threading -except ImportError: - threading = None + try: import resource except ImportError: @@ -2516,92 +2514,90 @@ class ProgramPriorityTests(unittest.TestCase): raise -if threading is not None: - class SendfileTestServer(asyncore.dispatcher, threading.Thread): - - class Handler(asynchat.async_chat): +class SendfileTestServer(asyncore.dispatcher, threading.Thread): - def __init__(self, conn): - asynchat.async_chat.__init__(self, conn) - self.in_buffer = [] - self.closed = False - self.push(b"220 ready\r\n") + class Handler(asynchat.async_chat): - def handle_read(self): - data = self.recv(4096) - self.in_buffer.append(data) + def __init__(self, conn): + asynchat.async_chat.__init__(self, conn) + self.in_buffer = [] + self.closed = False + self.push(b"220 ready\r\n") - def get_data(self): - return b''.join(self.in_buffer) + def handle_read(self): + data = self.recv(4096) + self.in_buffer.append(data) - def handle_close(self): - self.close() - self.closed = True + def get_data(self): + return b''.join(self.in_buffer) - def handle_error(self): - raise - - def __init__(self, address): - threading.Thread.__init__(self) - asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) - self.bind(address) - self.listen(5) - self.host, self.port = self.socket.getsockname()[:2] - self.handler_instance = None - self._active = False - self._active_lock = threading.Lock() - - # --- public API - - @property - def running(self): - return self._active - - def start(self): - assert not self.running - self.__flag = threading.Event() - threading.Thread.start(self) - self.__flag.wait() - - def stop(self): - assert self.running - self._active = False - self.join() - - def wait(self): - # wait for handler connection to be closed, then stop the server - while not getattr(self.handler_instance, "closed", False): - time.sleep(0.001) - self.stop() - - # --- internals - - def run(self): - self._active = True - self.__flag.set() - while self._active and asyncore.socket_map: - self._active_lock.acquire() - asyncore.loop(timeout=0.001, count=1) - self._active_lock.release() - asyncore.close_all() - - def handle_accept(self): - conn, addr = self.accept() - self.handler_instance = self.Handler(conn) - - def handle_connect(self): + def handle_close(self): self.close() - handle_read = handle_connect - - def writable(self): - return 0 + self.closed = True def handle_error(self): raise + def __init__(self, address): + threading.Thread.__init__(self) + asyncore.dispatcher.__init__(self) + self.create_socket(socket.AF_INET, socket.SOCK_STREAM) + self.bind(address) + self.listen(5) + self.host, self.port = self.socket.getsockname()[:2] + self.handler_instance = None + self._active = False + self._active_lock = threading.Lock() + + # --- public API + + @property + def running(self): + return self._active + + def start(self): + assert not self.running + self.__flag = threading.Event() + threading.Thread.start(self) + self.__flag.wait() + + def stop(self): + assert self.running + self._active = False + self.join() + + def wait(self): + # wait for handler connection to be closed, then stop the server + while not getattr(self.handler_instance, "closed", False): + time.sleep(0.001) + self.stop() + + # --- internals + + def run(self): + self._active = True + self.__flag.set() + while self._active and asyncore.socket_map: + self._active_lock.acquire() + asyncore.loop(timeout=0.001, count=1) + self._active_lock.release() + asyncore.close_all() + + def handle_accept(self): + conn, addr = self.accept() + self.handler_instance = self.Handler(conn) + + def handle_connect(self): + self.close() + handle_read = handle_connect + + def writable(self): + return 0 + + def handle_error(self): + raise + -@unittest.skipUnless(threading is not None, "test needs threading module") @unittest.skipUnless(hasattr(os, 'sendfile'), "test needs os.sendfile()") class TestSendfile(unittest.TestCase): diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py index 0ea2af5..755d265 100644 --- a/Lib/test/test_pdb.py +++ b/Lib/test/test_pdb.py @@ -1040,9 +1040,6 @@ class PdbTestCase(unittest.TestCase): # invoking "continue" on a non-main thread triggered an exception # inside signal.signal - # raises SkipTest if python was built without threads - support.import_module('threading') - with open(support.TESTFN, 'wb') as f: f.write(textwrap.dedent(""" import threading diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py index 6a2bf6e..16c2d2e 100644 --- a/Lib/test/test_poll.py +++ b/Lib/test/test_poll.py @@ -4,10 +4,7 @@ import os import subprocess import random import select -try: - import threading -except ImportError: - threading = None +import threading import time import unittest from test.support import TESTFN, run_unittest, reap_threads, cpython_only @@ -179,7 +176,6 @@ class PollTests(unittest.TestCase): self.assertRaises(OverflowError, pollster.poll, INT_MAX + 1) self.assertRaises(OverflowError, pollster.poll, UINT_MAX + 1) - @unittest.skipUnless(threading, 'Threading required for this test.') @reap_threads def test_threaded_poll(self): r, w = os.pipe() diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py index e5b16dc..92febbf 100644 --- a/Lib/test/test_poplib.py +++ b/Lib/test/test_poplib.py @@ -9,10 +9,10 @@ import asynchat import socket import os import errno +import threading from unittest import TestCase, skipUnless from test import support as test_support -threading = test_support.import_module('threading') HOST = test_support.HOST PORT = 0 diff --git a/Lib/test/test_pydoc.py b/Lib/test/test_pydoc.py index d68ab55..1ac08ed 100644 --- a/Lib/test/test_pydoc.py +++ b/Lib/test/test_pydoc.py @@ -20,6 +20,7 @@ import urllib.parse import xml.etree import xml.etree.ElementTree import textwrap +import threading from io import StringIO from collections import namedtuple from test.support.script_helper import assert_python_ok @@ -30,10 +31,6 @@ from test.support import ( ) from test import pydoc_mod -try: - import threading -except ImportError: - threading = None class nonascii: 'Це не латиниця' @@ -902,7 +899,6 @@ class TestDescriptions(unittest.TestCase): "stat(path, *, dir_fd=None, follow_symlinks=True)") -@unittest.skipUnless(threading, 'Threading required for this test.') class PydocServerTest(unittest.TestCase): """Tests for pydoc._start_server""" diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py index 4ccaa39..718ed67 100644 --- a/Lib/test/test_queue.py +++ b/Lib/test/test_queue.py @@ -1,10 +1,11 @@ # Some simple queue module tests, plus some failure conditions # to ensure the Queue locks remain stable. import queue +import threading import time import unittest from test import support -threading = support.import_module('threading') + QUEUE_SIZE = 5 diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py index b756839..8364767 100644 --- a/Lib/test/test_regrtest.py +++ b/Lib/test/test_regrtest.py @@ -15,6 +15,7 @@ import sys import sysconfig import tempfile import textwrap +import threading import unittest from test import libregrtest from test import support @@ -741,12 +742,7 @@ class ArgsTestCase(BaseTestCase): code = TEST_INTERRUPTED test = self.create_test("sigint", code=code) - try: - import threading - tests = (False, True) - except ImportError: - tests = (False,) - for multiprocessing in tests: + for multiprocessing in (False, True): if multiprocessing: args = ("--slowest", "-j2", test) else: diff --git a/Lib/test/test_robotparser.py b/Lib/test/test_robotparser.py index 0f64ba8..5c1a571 100644 --- a/Lib/test/test_robotparser.py +++ b/Lib/test/test_robotparser.py @@ -1,14 +1,11 @@ import io import os +import threading import unittest import urllib.robotparser from collections import namedtuple from test import support from http.server import BaseHTTPRequestHandler, HTTPServer -try: - import threading -except ImportError: - threading = None class BaseRobotTest: @@ -255,7 +252,6 @@ class RobotHandler(BaseHTTPRequestHandler): pass -@unittest.skipUnless(threading, 'threading required for this test') class PasswordProtectedSiteTestCase(unittest.TestCase): def setUp(self): diff --git a/Lib/test/test_sched.py b/Lib/test/test_sched.py index ebf8856..794c637 100644 --- a/Lib/test/test_sched.py +++ b/Lib/test/test_sched.py @@ -1,11 +1,9 @@ import queue import sched +import threading import time import unittest -try: - import threading -except ImportError: - threading = None + TIMEOUT = 10 @@ -58,7 +56,6 @@ class TestCase(unittest.TestCase): scheduler.run() self.assertEqual(l, [0.01, 0.02, 0.03, 0.04, 0.05]) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_enter_concurrent(self): q = queue.Queue() fun = q.put @@ -113,7 +110,6 @@ class TestCase(unittest.TestCase): scheduler.run() self.assertEqual(l, [0.02, 0.03, 0.04]) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_cancel_concurrent(self): q = queue.Queue() fun = q.put diff --git a/Lib/test/test_signal.py b/Lib/test/test_signal.py index 0e1d067..dc048e5 100644 --- a/Lib/test/test_signal.py +++ b/Lib/test/test_signal.py @@ -5,15 +5,12 @@ import socket import statistics import subprocess import sys +import threading import time import unittest from test import support from test.support.script_helper import assert_python_ok, spawn_python try: - import threading -except ImportError: - threading = None -try: import _testcapi except ImportError: _testcapi = None @@ -21,7 +18,6 @@ except ImportError: class GenericTests(unittest.TestCase): - @unittest.skipIf(threading is None, "test needs threading module") def test_enums(self): for name in dir(signal): sig = getattr(signal, name) @@ -807,7 +803,6 @@ class PendingSignalsTests(unittest.TestCase): 'need signal.sigwait()') @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'), 'need signal.pthread_sigmask()') - @unittest.skipIf(threading is None, "test needs threading module") def test_sigwait_thread(self): # Check that calling sigwait() from a thread doesn't suspend the whole # process. A new interpreter is spawned to avoid problems when mixing diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index 28539f3..42f4266 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -15,14 +15,11 @@ import time import select import errno import textwrap +import threading import unittest from test import support, mock_socket -try: - import threading -except ImportError: - threading = None HOST = support.HOST @@ -191,7 +188,6 @@ MSG_END = '------------ END MESSAGE ------------\n' # test server times out, causing the test to fail. # Test behavior of smtpd.DebuggingServer -@unittest.skipUnless(threading, 'Threading required for this test.') class DebuggingServerTests(unittest.TestCase): maxDiff = None @@ -570,7 +566,6 @@ class NonConnectingTests(unittest.TestCase): # test response of client to a non-successful HELO message -@unittest.skipUnless(threading, 'Threading required for this test.') class BadHELOServerTests(unittest.TestCase): def setUp(self): @@ -590,7 +585,6 @@ class BadHELOServerTests(unittest.TestCase): HOST, self.port, 'localhost', 3) -@unittest.skipUnless(threading, 'Threading required for this test.') class TooLongLineTests(unittest.TestCase): respdata = b'250 OK' + (b'.' * smtplib._MAXLINE * 2) + b'\n' @@ -835,7 +829,6 @@ class SimSMTPServer(smtpd.SMTPServer): # Test various SMTP & ESMTP commands/behaviors that require a simulated server # (i.e., something with more features than DebuggingServer) -@unittest.skipUnless(threading, 'Threading required for this test.') class SMTPSimTests(unittest.TestCase): def setUp(self): @@ -1091,7 +1084,6 @@ class SimSMTPUTF8Server(SimSMTPServer): self.last_rcpt_options = rcpt_options -@unittest.skipUnless(threading, 'Threading required for this test.') class SMTPUTF8SimTests(unittest.TestCase): maxDiff = None @@ -1227,7 +1219,6 @@ class SimSMTPAUTHInitialResponseServer(SimSMTPServer): channel_class = SimSMTPAUTHInitialResponseChannel -@unittest.skipUnless(threading, 'Threading required for this test.') class SMTPAUTHInitialResponseSimTests(unittest.TestCase): def setUp(self): self.real_getfqdn = socket.getfqdn diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 50016ab..27d9d49 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -21,6 +21,8 @@ import pickle import struct import random import string +import _thread as thread +import threading try: import multiprocessing except ImportError: @@ -36,12 +38,6 @@ MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string VSOCKPORT = 1234 try: - import _thread as thread - import threading -except ImportError: - thread = None - threading = None -try: import _socket except ImportError: _socket = None @@ -143,18 +139,17 @@ class ThreadSafeCleanupTestCase(unittest.TestCase): with a recursive lock. """ - if threading: - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._cleanup_lock = threading.RLock() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._cleanup_lock = threading.RLock() - def addCleanup(self, *args, **kwargs): - with self._cleanup_lock: - return super().addCleanup(*args, **kwargs) + def addCleanup(self, *args, **kwargs): + with self._cleanup_lock: + return super().addCleanup(*args, **kwargs) - def doCleanups(self, *args, **kwargs): - with self._cleanup_lock: - return super().doCleanups(*args, **kwargs) + def doCleanups(self, *args, **kwargs): + with self._cleanup_lock: + return super().doCleanups(*args, **kwargs) class SocketCANTest(unittest.TestCase): @@ -407,7 +402,6 @@ class ThreadedRDSSocketTest(SocketRDSTest, ThreadableTest): ThreadableTest.clientTearDown(self) @unittest.skipIf(fcntl is None, "need fcntl") -@unittest.skipUnless(thread, 'Threading required for this test.') @unittest.skipUnless(HAVE_SOCKET_VSOCK, 'VSOCK sockets required for this test.') @unittest.skipUnless(get_cid() != 2, @@ -1684,7 +1678,6 @@ class BasicCANTest(unittest.TestCase): @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') -@unittest.skipUnless(thread, 'Threading required for this test.') class CANTest(ThreadedCANSocketTest): def __init__(self, methodName='runTest'): @@ -1838,7 +1831,6 @@ class BasicRDSTest(unittest.TestCase): @unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.') -@unittest.skipUnless(thread, 'Threading required for this test.') class RDSTest(ThreadedRDSSocketTest): def __init__(self, methodName='runTest'): @@ -1977,7 +1969,7 @@ class BasicVSOCKTest(unittest.TestCase): s.getsockopt(socket.AF_VSOCK, socket.SO_VM_SOCKETS_BUFFER_MIN_SIZE)) -@unittest.skipUnless(thread, 'Threading required for this test.') + class BasicTCPTest(SocketConnectedTest): def __init__(self, methodName='runTest'): @@ -2100,7 +2092,7 @@ class BasicTCPTest(SocketConnectedTest): def _testDetach(self): self.serv_conn.send(MSG) -@unittest.skipUnless(thread, 'Threading required for this test.') + class BasicUDPTest(ThreadedUDPSocketTest): def __init__(self, methodName='runTest'): @@ -3697,17 +3689,14 @@ class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase, pass @requireAttrs(socket.socket, "sendmsg") -@unittest.skipUnless(thread, 'Threading required for this test.') class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase): pass @requireAttrs(socket.socket, "recvmsg") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase): pass @requireAttrs(socket.socket, "recvmsg_into") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase): pass @@ -3724,21 +3713,18 @@ class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase, @requireAttrs(socket.socket, "sendmsg") @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') @requireSocket("AF_INET6", "SOCK_DGRAM") -@unittest.skipUnless(thread, 'Threading required for this test.') class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase): pass @requireAttrs(socket.socket, "recvmsg") @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') @requireSocket("AF_INET6", "SOCK_DGRAM") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase): pass @requireAttrs(socket.socket, "recvmsg_into") @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') @requireSocket("AF_INET6", "SOCK_DGRAM") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase): pass @@ -3746,7 +3732,6 @@ class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase): @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') @requireAttrs(socket, "IPPROTO_IPV6") @requireSocket("AF_INET6", "SOCK_DGRAM") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest, SendrecvmsgUDP6TestBase): pass @@ -3755,7 +3740,6 @@ class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest, @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.') @requireAttrs(socket, "IPPROTO_IPV6") @requireSocket("AF_INET6", "SOCK_DGRAM") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin, RFC3542AncillaryTest, SendrecvmsgUDP6TestBase): @@ -3767,18 +3751,15 @@ class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase, pass @requireAttrs(socket.socket, "sendmsg") -@unittest.skipUnless(thread, 'Threading required for this test.') class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase): pass @requireAttrs(socket.socket, "recvmsg") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests, SendrecvmsgTCPTestBase): pass @requireAttrs(socket.socket, "recvmsg_into") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, SendrecvmsgTCPTestBase): pass @@ -3791,13 +3772,11 @@ class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase, @requireAttrs(socket.socket, "sendmsg") @requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") -@unittest.skipUnless(thread, 'Threading required for this test.') class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase): pass @requireAttrs(socket.socket, "recvmsg") @requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, SendrecvmsgSCTPStreamTestBase): @@ -3811,7 +3790,6 @@ class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, @requireAttrs(socket.socket, "recvmsg_into") @requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, SendrecvmsgSCTPStreamTestBase): @@ -3830,33 +3808,28 @@ class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase, @requireAttrs(socket.socket, "sendmsg") @requireAttrs(socket, "AF_UNIX") -@unittest.skipUnless(thread, 'Threading required for this test.') class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase): pass @requireAttrs(socket.socket, "recvmsg") @requireAttrs(socket, "AF_UNIX") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests, SendrecvmsgUnixStreamTestBase): pass @requireAttrs(socket.socket, "recvmsg_into") @requireAttrs(socket, "AF_UNIX") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests, SendrecvmsgUnixStreamTestBase): pass @requireAttrs(socket.socket, "sendmsg", "recvmsg") @requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase): pass @requireAttrs(socket.socket, "sendmsg", "recvmsg_into") @requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS") -@unittest.skipUnless(thread, 'Threading required for this test.') class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest, SendrecvmsgUnixStreamTestBase): pass @@ -3944,7 +3917,6 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase): @requireAttrs(signal, "siginterrupt") @unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"), "Don't have signal.alarm or signal.setitimer") -@unittest.skipUnless(thread, 'Threading required for this test.') class InterruptedSendTimeoutTest(InterruptedTimeoutBase, ThreadSafeCleanupTestCase, SocketListeningTestMixin, TCPTestBase): @@ -3997,7 +3969,6 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase, self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512]) -@unittest.skipUnless(thread, 'Threading required for this test.') class TCPCloserTest(ThreadedTCPSocketTest): def testClose(self): @@ -4017,7 +3988,7 @@ class TCPCloserTest(ThreadedTCPSocketTest): self.cli.connect((HOST, self.port)) time.sleep(1.0) -@unittest.skipUnless(thread, 'Threading required for this test.') + class BasicSocketPairTest(SocketPairTest): def __init__(self, methodName='runTest'): @@ -4052,7 +4023,7 @@ class BasicSocketPairTest(SocketPairTest): msg = self.cli.recv(1024) self.assertEqual(msg, MSG) -@unittest.skipUnless(thread, 'Threading required for this test.') + class NonBlockingTCPTests(ThreadedTCPSocketTest): def __init__(self, methodName='runTest'): @@ -4180,7 +4151,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): time.sleep(0.1) self.cli.send(MSG) -@unittest.skipUnless(thread, 'Threading required for this test.') + class FileObjectClassTestCase(SocketConnectedTest): """Unit tests for the object returned by socket.makefile() @@ -4564,7 +4535,6 @@ class NetworkConnectionNoServer(unittest.TestCase): socket.create_connection((HOST, 1234)) -@unittest.skipUnless(thread, 'Threading required for this test.') class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): def __init__(self, methodName='runTest'): @@ -4633,7 +4603,7 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): self.addCleanup(self.cli.close) self.assertEqual(self.cli.gettimeout(), 30) -@unittest.skipUnless(thread, 'Threading required for this test.') + class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest): def __init__(self, methodName='runTest'): @@ -4877,7 +4847,7 @@ class TestUnixDomain(unittest.TestCase): self.addCleanup(support.unlink, path) self.assertEqual(self.sock.getsockname(), path) -@unittest.skipUnless(thread, 'Threading required for this test.') + class BufferIOTest(SocketConnectedTest): """ Test the buffer versions of socket.recv() and socket.send(). @@ -5050,7 +5020,6 @@ class TIPCThreadableTest(unittest.TestCase, ThreadableTest): self.cli.close() -@unittest.skipUnless(thread, 'Threading required for this test.') class ContextManagersTest(ThreadedTCPSocketTest): def _testSocketClass(self): @@ -5312,7 +5281,6 @@ class TestSocketSharing(SocketTCPTest): source.close() -@unittest.skipUnless(thread, 'Threading required for this test.') class SendfileUsingSendTest(ThreadedTCPSocketTest): """ Test the send() implementation of socket.sendfile(). @@ -5570,7 +5538,6 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest): meth, file, count=-1) -@unittest.skipUnless(thread, 'Threading required for this test.') @unittest.skipUnless(hasattr(os, "sendfile"), 'os.sendfile() required for this test.') class SendfileUsingSendfileTest(SendfileUsingSendTest): diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 3d93566..a23373f 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -9,15 +9,13 @@ import select import signal import socket import tempfile +import threading import unittest import socketserver import test.support from test.support import reap_children, reap_threads, verbose -try: - import threading -except ImportError: - threading = None + test.support.requires("network") @@ -68,7 +66,6 @@ def simple_subprocess(testcase): testcase.assertEqual(72 << 8, status) -@unittest.skipUnless(threading, 'Threading required for this test.') class SocketServerTest(unittest.TestCase): """Test all socket servers.""" @@ -306,12 +303,10 @@ class ErrorHandlerTest(unittest.TestCase): BaseErrorTestServer(SystemExit) self.check_result(handled=False) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_threading_handled(self): ThreadingErrorTestServer(ValueError) self.check_result(handled=True) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_threading_not_handled(self): ThreadingErrorTestServer(SystemExit) self.check_result(handled=False) @@ -396,7 +391,6 @@ class SocketWriterTest(unittest.TestCase): self.assertIsInstance(server.wfile, io.BufferedIOBase) self.assertEqual(server.wfile_fileno, server.request_fileno) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_write(self): # Test that wfile.write() sends data immediately, and that it does # not truncate sends when interrupted by a Unix signal diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 16cad9d..89b4609 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -12,6 +12,7 @@ import os import errno import pprint import urllib.request +import threading import traceback import asyncore import weakref @@ -20,12 +21,6 @@ import functools ssl = support.import_module("ssl") -try: - import threading -except ImportError: - _have_threads = False -else: - _have_threads = True PROTOCOLS = sorted(ssl._PROTOCOL_NAMES) HOST = support.HOST @@ -1468,7 +1463,6 @@ class MemoryBIOTests(unittest.TestCase): self.assertRaises(TypeError, bio.write, 1) -@unittest.skipUnless(_have_threads, "Needs threading module") class SimpleBackgroundTests(unittest.TestCase): """Tests that connect to a simple server running in the background""" @@ -1828,1744 +1822,1743 @@ def _test_get_server_certificate_fail(test, host, port): test.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) -if _have_threads: - from test.ssl_servers import make_https_server +from test.ssl_servers import make_https_server - class ThreadedEchoServer(threading.Thread): +class ThreadedEchoServer(threading.Thread): - class ConnectionHandler(threading.Thread): + class ConnectionHandler(threading.Thread): - """A mildly complicated class, because we want it to work both - with and without the SSL wrapper around the socket connection, so - that we can test the STARTTLS functionality.""" + """A mildly complicated class, because we want it to work both + with and without the SSL wrapper around the socket connection, so + that we can test the STARTTLS functionality.""" - def __init__(self, server, connsock, addr): - self.server = server + def __init__(self, server, connsock, addr): + self.server = server + self.running = False + self.sock = connsock + self.addr = addr + self.sock.setblocking(1) + self.sslconn = None + threading.Thread.__init__(self) + self.daemon = True + + def wrap_conn(self): + try: + self.sslconn = self.server.context.wrap_socket( + self.sock, server_side=True) + self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) + self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol()) + except (ssl.SSLError, ConnectionResetError, OSError) as e: + # We treat ConnectionResetError as though it were an + # SSLError - OpenSSL on Ubuntu abruptly closes the + # connection when asked to use an unsupported protocol. + # + # OSError may occur with wrong protocols, e.g. both + # sides use PROTOCOL_TLS_SERVER. + # + # XXX Various errors can have happened here, for example + # a mismatching protocol version, an invalid certificate, + # or a low-level bug. This should be made more discriminating. + # + # bpo-31323: Store the exception as string to prevent + # a reference leak: server -> conn_errors -> exception + # -> traceback -> self (ConnectionHandler) -> server + self.server.conn_errors.append(str(e)) + if self.server.chatty: + handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n") self.running = False - self.sock = connsock - self.addr = addr - self.sock.setblocking(1) - self.sslconn = None - threading.Thread.__init__(self) - self.daemon = True - - def wrap_conn(self): - try: - self.sslconn = self.server.context.wrap_socket( - self.sock, server_side=True) - self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) - self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol()) - except (ssl.SSLError, ConnectionResetError, OSError) as e: - # We treat ConnectionResetError as though it were an - # SSLError - OpenSSL on Ubuntu abruptly closes the - # connection when asked to use an unsupported protocol. - # - # OSError may occur with wrong protocols, e.g. both - # sides use PROTOCOL_TLS_SERVER. - # - # XXX Various errors can have happened here, for example - # a mismatching protocol version, an invalid certificate, - # or a low-level bug. This should be made more discriminating. - # - # bpo-31323: Store the exception as string to prevent - # a reference leak: server -> conn_errors -> exception - # -> traceback -> self (ConnectionHandler) -> server - self.server.conn_errors.append(str(e)) - if self.server.chatty: - handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n") - self.running = False - self.server.stop() - self.close() - return False - else: - self.server.shared_ciphers.append(self.sslconn.shared_ciphers()) - if self.server.context.verify_mode == ssl.CERT_REQUIRED: - cert = self.sslconn.getpeercert() - if support.verbose and self.server.chatty: - sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") - cert_binary = self.sslconn.getpeercert(True) - if support.verbose and self.server.chatty: - sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") - cipher = self.sslconn.cipher() + self.server.stop() + self.close() + return False + else: + self.server.shared_ciphers.append(self.sslconn.shared_ciphers()) + if self.server.context.verify_mode == ssl.CERT_REQUIRED: + cert = self.sslconn.getpeercert() if support.verbose and self.server.chatty: - sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") - sys.stdout.write(" server: selected protocol is now " - + str(self.sslconn.selected_npn_protocol()) + "\n") - return True - - def read(self): - if self.sslconn: - return self.sslconn.read() - else: - return self.sock.recv(1024) + sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n") + cert_binary = self.sslconn.getpeercert(True) + if support.verbose and self.server.chatty: + sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n") + cipher = self.sslconn.cipher() + if support.verbose and self.server.chatty: + sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n") + sys.stdout.write(" server: selected protocol is now " + + str(self.sslconn.selected_npn_protocol()) + "\n") + return True + + def read(self): + if self.sslconn: + return self.sslconn.read() + else: + return self.sock.recv(1024) - def write(self, bytes): - if self.sslconn: - return self.sslconn.write(bytes) - else: - return self.sock.send(bytes) + def write(self, bytes): + if self.sslconn: + return self.sslconn.write(bytes) + else: + return self.sock.send(bytes) - def close(self): - if self.sslconn: - self.sslconn.close() - else: - self.sock.close() + def close(self): + if self.sslconn: + self.sslconn.close() + else: + self.sock.close() - def run(self): - self.running = True - if not self.server.starttls_server: - if not self.wrap_conn(): - return - while self.running: - try: - msg = self.read() - stripped = msg.strip() - if not stripped: - # eof, so quit this handler - self.running = False - try: - self.sock = self.sslconn.unwrap() - except OSError: - # Many tests shut the TCP connection down - # without an SSL shutdown. This causes - # unwrap() to raise OSError with errno=0! - pass - else: - self.sslconn = None - self.close() - elif stripped == b'over': - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: client closed connection\n") - self.close() - return - elif (self.server.starttls_server and - stripped == b'STARTTLS'): - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") - self.write(b"OK\n") - if not self.wrap_conn(): - return - elif (self.server.starttls_server and self.sslconn - and stripped == b'ENDTLS'): - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: read ENDTLS from client, sending OK...\n") - self.write(b"OK\n") + def run(self): + self.running = True + if not self.server.starttls_server: + if not self.wrap_conn(): + return + while self.running: + try: + msg = self.read() + stripped = msg.strip() + if not stripped: + # eof, so quit this handler + self.running = False + try: self.sock = self.sslconn.unwrap() - self.sslconn = None - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: connection is now unencrypted...\n") - elif stripped == b'CB tls-unique': - if support.verbose and self.server.connectionchatty: - sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") - data = self.sslconn.get_channel_binding("tls-unique") - self.write(repr(data).encode("us-ascii") + b"\n") + except OSError: + # Many tests shut the TCP connection down + # without an SSL shutdown. This causes + # unwrap() to raise OSError with errno=0! + pass else: - if (support.verbose and - self.server.connectionchatty): - ctype = (self.sslconn and "encrypted") or "unencrypted" - sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" - % (msg, ctype, msg.lower(), ctype)) - self.write(msg.lower()) - except OSError: - if self.server.chatty: - handle_error("Test server failure:\n") + self.sslconn = None self.close() - self.running = False - # normally, we'd just stop here, but for the test - # harness, we want to stop the server - self.server.stop() - - def __init__(self, certificate=None, ssl_version=None, - certreqs=None, cacerts=None, - chatty=True, connectionchatty=False, starttls_server=False, - npn_protocols=None, alpn_protocols=None, - ciphers=None, context=None): - if context: - self.context = context - else: - self.context = ssl.SSLContext(ssl_version - if ssl_version is not None - else ssl.PROTOCOL_TLSv1) - self.context.verify_mode = (certreqs if certreqs is not None - else ssl.CERT_NONE) - if cacerts: - self.context.load_verify_locations(cacerts) - if certificate: - self.context.load_cert_chain(certificate) - if npn_protocols: - self.context.set_npn_protocols(npn_protocols) - if alpn_protocols: - self.context.set_alpn_protocols(alpn_protocols) - if ciphers: - self.context.set_ciphers(ciphers) - self.chatty = chatty - self.connectionchatty = connectionchatty - self.starttls_server = starttls_server - self.sock = socket.socket() - self.port = support.bind_port(self.sock) - self.flag = None - self.active = False - self.selected_npn_protocols = [] - self.selected_alpn_protocols = [] - self.shared_ciphers = [] - self.conn_errors = [] - threading.Thread.__init__(self) - self.daemon = True - - def __enter__(self): - self.start(threading.Event()) - self.flag.wait() - return self - - def __exit__(self, *args): - self.stop() - self.join() - - def start(self, flag=None): - self.flag = flag - threading.Thread.start(self) + elif stripped == b'over': + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: client closed connection\n") + self.close() + return + elif (self.server.starttls_server and + stripped == b'STARTTLS'): + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read STARTTLS from client, sending OK...\n") + self.write(b"OK\n") + if not self.wrap_conn(): + return + elif (self.server.starttls_server and self.sslconn + and stripped == b'ENDTLS'): + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read ENDTLS from client, sending OK...\n") + self.write(b"OK\n") + self.sock = self.sslconn.unwrap() + self.sslconn = None + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: connection is now unencrypted...\n") + elif stripped == b'CB tls-unique': + if support.verbose and self.server.connectionchatty: + sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") + data = self.sslconn.get_channel_binding("tls-unique") + self.write(repr(data).encode("us-ascii") + b"\n") + else: + if (support.verbose and + self.server.connectionchatty): + ctype = (self.sslconn and "encrypted") or "unencrypted" + sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n" + % (msg, ctype, msg.lower(), ctype)) + self.write(msg.lower()) + except OSError: + if self.server.chatty: + handle_error("Test server failure:\n") + self.close() + self.running = False + # normally, we'd just stop here, but for the test + # harness, we want to stop the server + self.server.stop() - def run(self): - self.sock.settimeout(0.05) - self.sock.listen() - self.active = True - if self.flag: - # signal an event - self.flag.set() - while self.active: - try: - newconn, connaddr = self.sock.accept() - if support.verbose and self.chatty: - sys.stdout.write(' server: new connection from ' - + repr(connaddr) + '\n') - handler = self.ConnectionHandler(self, newconn, connaddr) - handler.start() - handler.join() - except socket.timeout: - pass - except KeyboardInterrupt: - self.stop() - self.sock.close() + def __init__(self, certificate=None, ssl_version=None, + certreqs=None, cacerts=None, + chatty=True, connectionchatty=False, starttls_server=False, + npn_protocols=None, alpn_protocols=None, + ciphers=None, context=None): + if context: + self.context = context + else: + self.context = ssl.SSLContext(ssl_version + if ssl_version is not None + else ssl.PROTOCOL_TLSv1) + self.context.verify_mode = (certreqs if certreqs is not None + else ssl.CERT_NONE) + if cacerts: + self.context.load_verify_locations(cacerts) + if certificate: + self.context.load_cert_chain(certificate) + if npn_protocols: + self.context.set_npn_protocols(npn_protocols) + if alpn_protocols: + self.context.set_alpn_protocols(alpn_protocols) + if ciphers: + self.context.set_ciphers(ciphers) + self.chatty = chatty + self.connectionchatty = connectionchatty + self.starttls_server = starttls_server + self.sock = socket.socket() + self.port = support.bind_port(self.sock) + self.flag = None + self.active = False + self.selected_npn_protocols = [] + self.selected_alpn_protocols = [] + self.shared_ciphers = [] + self.conn_errors = [] + threading.Thread.__init__(self) + self.daemon = True + + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + return self + + def __exit__(self, *args): + self.stop() + self.join() + + def start(self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.sock.settimeout(0.05) + self.sock.listen() + self.active = True + if self.flag: + # signal an event + self.flag.set() + while self.active: + try: + newconn, connaddr = self.sock.accept() + if support.verbose and self.chatty: + sys.stdout.write(' server: new connection from ' + + repr(connaddr) + '\n') + handler = self.ConnectionHandler(self, newconn, connaddr) + handler.start() + handler.join() + except socket.timeout: + pass + except KeyboardInterrupt: + self.stop() + self.sock.close() - def stop(self): - self.active = False + def stop(self): + self.active = False - class AsyncoreEchoServer(threading.Thread): +class AsyncoreEchoServer(threading.Thread): - # this one's based on asyncore.dispatcher + # this one's based on asyncore.dispatcher - class EchoServer (asyncore.dispatcher): + class EchoServer (asyncore.dispatcher): - class ConnectionHandler(asyncore.dispatcher_with_send): + class ConnectionHandler(asyncore.dispatcher_with_send): - def __init__(self, conn, certfile): - self.socket = test_wrap_socket(conn, server_side=True, - certfile=certfile, - do_handshake_on_connect=False) - asyncore.dispatcher_with_send.__init__(self, self.socket) - self._ssl_accepting = True - self._do_ssl_handshake() + def __init__(self, conn, certfile): + self.socket = test_wrap_socket(conn, server_side=True, + certfile=certfile, + do_handshake_on_connect=False) + asyncore.dispatcher_with_send.__init__(self, self.socket) + self._ssl_accepting = True + self._do_ssl_handshake() - def readable(self): - if isinstance(self.socket, ssl.SSLSocket): - while self.socket.pending() > 0: - self.handle_read_event() - return True + def readable(self): + if isinstance(self.socket, ssl.SSLSocket): + while self.socket.pending() > 0: + self.handle_read_event() + return True - def _do_ssl_handshake(self): - try: - self.socket.do_handshake() - except (ssl.SSLWantReadError, ssl.SSLWantWriteError): - return - except ssl.SSLEOFError: + def _do_ssl_handshake(self): + try: + self.socket.do_handshake() + except (ssl.SSLWantReadError, ssl.SSLWantWriteError): + return + except ssl.SSLEOFError: + return self.handle_close() + except ssl.SSLError: + raise + except OSError as err: + if err.args[0] == errno.ECONNABORTED: return self.handle_close() - except ssl.SSLError: - raise - except OSError as err: - if err.args[0] == errno.ECONNABORTED: - return self.handle_close() - else: - self._ssl_accepting = False - - def handle_read(self): - if self._ssl_accepting: - self._do_ssl_handshake() - else: - data = self.recv(1024) - if support.verbose: - sys.stdout.write(" server: read %s from client\n" % repr(data)) - if not data: - self.close() - else: - self.send(data.lower()) + else: + self._ssl_accepting = False - def handle_close(self): - self.close() + def handle_read(self): + if self._ssl_accepting: + self._do_ssl_handshake() + else: + data = self.recv(1024) if support.verbose: - sys.stdout.write(" server: closed connection %s\n" % self.socket) - - def handle_error(self): - raise - - def __init__(self, certfile): - self.certfile = certfile - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.port = support.bind_port(sock, '') - asyncore.dispatcher.__init__(self, sock) - self.listen(5) + sys.stdout.write(" server: read %s from client\n" % repr(data)) + if not data: + self.close() + else: + self.send(data.lower()) - def handle_accepted(self, sock_obj, addr): + def handle_close(self): + self.close() if support.verbose: - sys.stdout.write(" server: new connection from %s:%s\n" %addr) - self.ConnectionHandler(sock_obj, self.certfile) + sys.stdout.write(" server: closed connection %s\n" % self.socket) def handle_error(self): raise def __init__(self, certfile): - self.flag = None - self.active = False - self.server = self.EchoServer(certfile) - self.port = self.server.port - threading.Thread.__init__(self) - self.daemon = True + self.certfile = certfile + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = support.bind_port(sock, '') + asyncore.dispatcher.__init__(self, sock) + self.listen(5) - def __str__(self): - return "<%s %s>" % (self.__class__.__name__, self.server) + def handle_accepted(self, sock_obj, addr): + if support.verbose: + sys.stdout.write(" server: new connection from %s:%s\n" %addr) + self.ConnectionHandler(sock_obj, self.certfile) - def __enter__(self): - self.start(threading.Event()) - self.flag.wait() - return self + def handle_error(self): + raise - def __exit__(self, *args): - if support.verbose: - sys.stdout.write(" cleanup: stopping server.\n") - self.stop() - if support.verbose: - sys.stdout.write(" cleanup: joining server thread.\n") - self.join() - if support.verbose: - sys.stdout.write(" cleanup: successfully joined.\n") - # make sure that ConnectionHandler is removed from socket_map - asyncore.close_all(ignore_all=True) + def __init__(self, certfile): + self.flag = None + self.active = False + self.server = self.EchoServer(certfile) + self.port = self.server.port + threading.Thread.__init__(self) + self.daemon = True - def start (self, flag=None): - self.flag = flag - threading.Thread.start(self) + def __str__(self): + return "<%s %s>" % (self.__class__.__name__, self.server) - def run(self): - self.active = True - if self.flag: - self.flag.set() - while self.active: - try: - asyncore.loop(1) - except: - pass + def __enter__(self): + self.start(threading.Event()) + self.flag.wait() + return self - def stop(self): - self.active = False - self.server.close() + def __exit__(self, *args): + if support.verbose: + sys.stdout.write(" cleanup: stopping server.\n") + self.stop() + if support.verbose: + sys.stdout.write(" cleanup: joining server thread.\n") + self.join() + if support.verbose: + sys.stdout.write(" cleanup: successfully joined.\n") + # make sure that ConnectionHandler is removed from socket_map + asyncore.close_all(ignore_all=True) + + def start (self, flag=None): + self.flag = flag + threading.Thread.start(self) + + def run(self): + self.active = True + if self.flag: + self.flag.set() + while self.active: + try: + asyncore.loop(1) + except: + pass - def server_params_test(client_context, server_context, indata=b"FOO\n", - chatty=True, connectionchatty=False, sni_name=None, - session=None): - """ - Launch a server, connect a client to it and try various reads - and writes. - """ - stats = {} - server = ThreadedEchoServer(context=server_context, - chatty=chatty, - connectionchatty=False) - with server: - with client_context.wrap_socket(socket.socket(), - server_hostname=sni_name, session=session) as s: - s.connect((HOST, server.port)) - for arg in [indata, bytearray(indata), memoryview(indata)]: - if connectionchatty: - if support.verbose: - sys.stdout.write( - " client: sending %r...\n" % indata) - s.write(arg) - outdata = s.read() - if connectionchatty: - if support.verbose: - sys.stdout.write(" client: read %r\n" % outdata) - if outdata != indata.lower(): - raise AssertionError( - "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" - % (outdata[:20], len(outdata), - indata[:20].lower(), len(indata))) - s.write(b"over\n") + def stop(self): + self.active = False + self.server.close() + +def server_params_test(client_context, server_context, indata=b"FOO\n", + chatty=True, connectionchatty=False, sni_name=None, + session=None): + """ + Launch a server, connect a client to it and try various reads + and writes. + """ + stats = {} + server = ThreadedEchoServer(context=server_context, + chatty=chatty, + connectionchatty=False) + with server: + with client_context.wrap_socket(socket.socket(), + server_hostname=sni_name, session=session) as s: + s.connect((HOST, server.port)) + for arg in [indata, bytearray(indata), memoryview(indata)]: if connectionchatty: if support.verbose: - sys.stdout.write(" client: closing connection.\n") - stats.update({ - 'compression': s.compression(), - 'cipher': s.cipher(), - 'peercert': s.getpeercert(), - 'client_alpn_protocol': s.selected_alpn_protocol(), - 'client_npn_protocol': s.selected_npn_protocol(), - 'version': s.version(), - 'session_reused': s.session_reused, - 'session': s.session, - }) - s.close() - stats['server_alpn_protocols'] = server.selected_alpn_protocols - stats['server_npn_protocols'] = server.selected_npn_protocols - stats['server_shared_ciphers'] = server.shared_ciphers - return stats + sys.stdout.write( + " client: sending %r...\n" % indata) + s.write(arg) + outdata = s.read() + if connectionchatty: + if support.verbose: + sys.stdout.write(" client: read %r\n" % outdata) + if outdata != indata.lower(): + raise AssertionError( + "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" + % (outdata[:20], len(outdata), + indata[:20].lower(), len(indata))) + s.write(b"over\n") + if connectionchatty: + if support.verbose: + sys.stdout.write(" client: closing connection.\n") + stats.update({ + 'compression': s.compression(), + 'cipher': s.cipher(), + 'peercert': s.getpeercert(), + 'client_alpn_protocol': s.selected_alpn_protocol(), + 'client_npn_protocol': s.selected_npn_protocol(), + 'version': s.version(), + 'session_reused': s.session_reused, + 'session': s.session, + }) + s.close() + stats['server_alpn_protocols'] = server.selected_alpn_protocols + stats['server_npn_protocols'] = server.selected_npn_protocols + stats['server_shared_ciphers'] = server.shared_ciphers + return stats + +def try_protocol_combo(server_protocol, client_protocol, expect_success, + certsreqs=None, server_options=0, client_options=0): + """ + Try to SSL-connect using *client_protocol* to *server_protocol*. + If *expect_success* is true, assert that the connection succeeds, + if it's false, assert that the connection fails. + Also, if *expect_success* is a string, assert that it is the protocol + version actually used by the connection. + """ + if certsreqs is None: + certsreqs = ssl.CERT_NONE + certtype = { + ssl.CERT_NONE: "CERT_NONE", + ssl.CERT_OPTIONAL: "CERT_OPTIONAL", + ssl.CERT_REQUIRED: "CERT_REQUIRED", + }[certsreqs] + if support.verbose: + formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" + sys.stdout.write(formatstr % + (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol), + certtype)) + client_context = ssl.SSLContext(client_protocol) + client_context.options |= client_options + server_context = ssl.SSLContext(server_protocol) + server_context.options |= server_options + + # NOTE: we must enable "ALL" ciphers on the client, otherwise an + # SSLv23 client will send an SSLv3 hello (rather than SSLv2) + # starting from OpenSSL 1.0.0 (see issue #8322). + if client_context.protocol == ssl.PROTOCOL_SSLv23: + client_context.set_ciphers("ALL") + + for ctx in (client_context, server_context): + ctx.verify_mode = certsreqs + ctx.load_cert_chain(CERTFILE) + ctx.load_verify_locations(CERTFILE) + try: + stats = server_params_test(client_context, server_context, + chatty=False, connectionchatty=False) + # Protocol mismatch can result in either an SSLError, or a + # "Connection reset by peer" error. + except ssl.SSLError: + if expect_success: + raise + except OSError as e: + if expect_success or e.errno != errno.ECONNRESET: + raise + else: + if not expect_success: + raise AssertionError( + "Client protocol %s succeeded with server protocol %s!" + % (ssl.get_protocol_name(client_protocol), + ssl.get_protocol_name(server_protocol))) + elif (expect_success is not True + and expect_success != stats['version']): + raise AssertionError("version mismatch: expected %r, got %r" + % (expect_success, stats['version'])) - def try_protocol_combo(server_protocol, client_protocol, expect_success, - certsreqs=None, server_options=0, client_options=0): - """ - Try to SSL-connect using *client_protocol* to *server_protocol*. - If *expect_success* is true, assert that the connection succeeds, - if it's false, assert that the connection fails. - Also, if *expect_success* is a string, assert that it is the protocol - version actually used by the connection. - """ - if certsreqs is None: - certsreqs = ssl.CERT_NONE - certtype = { - ssl.CERT_NONE: "CERT_NONE", - ssl.CERT_OPTIONAL: "CERT_OPTIONAL", - ssl.CERT_REQUIRED: "CERT_REQUIRED", - }[certsreqs] - if support.verbose: - formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n" - sys.stdout.write(formatstr % - (ssl.get_protocol_name(client_protocol), - ssl.get_protocol_name(server_protocol), - certtype)) - client_context = ssl.SSLContext(client_protocol) - client_context.options |= client_options - server_context = ssl.SSLContext(server_protocol) - server_context.options |= server_options - - # NOTE: we must enable "ALL" ciphers on the client, otherwise an - # SSLv23 client will send an SSLv3 hello (rather than SSLv2) - # starting from OpenSSL 1.0.0 (see issue #8322). - if client_context.protocol == ssl.PROTOCOL_SSLv23: - client_context.set_ciphers("ALL") - - for ctx in (client_context, server_context): - ctx.verify_mode = certsreqs - ctx.load_cert_chain(CERTFILE) - ctx.load_verify_locations(CERTFILE) - try: - stats = server_params_test(client_context, server_context, - chatty=False, connectionchatty=False) - # Protocol mismatch can result in either an SSLError, or a - # "Connection reset by peer" error. - except ssl.SSLError: - if expect_success: - raise - except OSError as e: - if expect_success or e.errno != errno.ECONNRESET: - raise - else: - if not expect_success: - raise AssertionError( - "Client protocol %s succeeded with server protocol %s!" - % (ssl.get_protocol_name(client_protocol), - ssl.get_protocol_name(server_protocol))) - elif (expect_success is not True - and expect_success != stats['version']): - raise AssertionError("version mismatch: expected %r, got %r" - % (expect_success, stats['version'])) - - - class ThreadedTests(unittest.TestCase): - - @skip_if_broken_ubuntu_ssl - def test_echo(self): - """Basic test of an SSL client connecting to a server""" - if support.verbose: - sys.stdout.write("\n") - for protocol in PROTOCOLS: - if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}: - continue - with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): - context = ssl.SSLContext(protocol) - context.load_cert_chain(CERTFILE) - server_params_test(context, context, - chatty=True, connectionchatty=True) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - client_context.load_verify_locations(SIGNING_CA) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - # server_context.load_verify_locations(SIGNING_CA) - server_context.load_cert_chain(SIGNED_CERTFILE2) +class ThreadedTests(unittest.TestCase): - with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER): - server_params_test(client_context=client_context, - server_context=server_context, + @skip_if_broken_ubuntu_ssl + def test_echo(self): + """Basic test of an SSL client connecting to a server""" + if support.verbose: + sys.stdout.write("\n") + for protocol in PROTOCOLS: + if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}: + continue + with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]): + context = ssl.SSLContext(protocol) + context.load_cert_chain(CERTFILE) + server_params_test(context, context, + chatty=True, connectionchatty=True) + + client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_context.load_verify_locations(SIGNING_CA) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + # server_context.load_verify_locations(SIGNING_CA) + server_context.load_cert_chain(SIGNED_CERTFILE2) + + with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER): + server_params_test(client_context=client_context, + server_context=server_context, + chatty=True, connectionchatty=True, + sni_name='fakehostname') + + client_context.check_hostname = False + with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT): + with self.assertRaises(ssl.SSLError) as e: + server_params_test(client_context=server_context, + server_context=client_context, chatty=True, connectionchatty=True, sni_name='fakehostname') + self.assertIn('called a function you should not call', + str(e.exception)) - client_context.check_hostname = False - with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT): - with self.assertRaises(ssl.SSLError) as e: - server_params_test(client_context=server_context, - server_context=client_context, - chatty=True, connectionchatty=True, - sni_name='fakehostname') - self.assertIn('called a function you should not call', - str(e.exception)) - - with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER): - with self.assertRaises(ssl.SSLError) as e: - server_params_test(client_context=server_context, - server_context=server_context, - chatty=True, connectionchatty=True) - self.assertIn('called a function you should not call', - str(e.exception)) + with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER): + with self.assertRaises(ssl.SSLError) as e: + server_params_test(client_context=server_context, + server_context=server_context, + chatty=True, connectionchatty=True) + self.assertIn('called a function you should not call', + str(e.exception)) - with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT): - with self.assertRaises(ssl.SSLError) as e: - server_params_test(client_context=server_context, - server_context=client_context, - chatty=True, connectionchatty=True) - self.assertIn('called a function you should not call', - str(e.exception)) + with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT): + with self.assertRaises(ssl.SSLError) as e: + server_params_test(client_context=server_context, + server_context=client_context, + chatty=True, connectionchatty=True) + self.assertIn('called a function you should not call', + str(e.exception)) - def test_getpeercert(self): + def test_getpeercert(self): + if support.verbose: + sys.stdout.write("\n") + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + s = context.wrap_socket(socket.socket(), + do_handshake_on_connect=False) + s.connect((HOST, server.port)) + # getpeercert() raise ValueError while the handshake isn't + # done. + with self.assertRaises(ValueError): + s.getpeercert() + s.do_handshake() + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + cipher = s.cipher() if support.verbose: - sys.stdout.write("\n") - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) - with server: - s = context.wrap_socket(socket.socket(), - do_handshake_on_connect=False) + sys.stdout.write(pprint.pformat(cert) + '\n') + sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') + if 'subject' not in cert: + self.fail("No subject field in certificate: %s." % + pprint.pformat(cert)) + if ((('organizationName', 'Python Software Foundation'),) + not in cert['subject']): + self.fail( + "Missing or invalid 'organizationName' field in certificate subject; " + "should be 'Python Software Foundation'.") + self.assertIn('notBefore', cert) + self.assertIn('notAfter', cert) + before = ssl.cert_time_to_seconds(cert['notBefore']) + after = ssl.cert_time_to_seconds(cert['notAfter']) + self.assertLess(before, after) + s.close() + + @unittest.skipUnless(have_verify_flags(), + "verify_flags need OpenSSL > 0.9.8") + def test_crl_check(self): + if support.verbose: + sys.stdout.write("\n") + + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(SIGNING_CA) + tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) + self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf) + + # VERIFY_DEFAULT should pass + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: s.connect((HOST, server.port)) - # getpeercert() raise ValueError while the handshake isn't - # done. - with self.assertRaises(ValueError): - s.getpeercert() - s.do_handshake() cert = s.getpeercert() self.assertTrue(cert, "Can't get peer certificate.") - cipher = s.cipher() - if support.verbose: - sys.stdout.write(pprint.pformat(cert) + '\n') - sys.stdout.write("Connection cipher is " + str(cipher) + '.\n') - if 'subject' not in cert: - self.fail("No subject field in certificate: %s." % - pprint.pformat(cert)) - if ((('organizationName', 'Python Software Foundation'),) - not in cert['subject']): - self.fail( - "Missing or invalid 'organizationName' field in certificate subject; " - "should be 'Python Software Foundation'.") - self.assertIn('notBefore', cert) - self.assertIn('notAfter', cert) - before = ssl.cert_time_to_seconds(cert['notBefore']) - after = ssl.cert_time_to_seconds(cert['notAfter']) - self.assertLess(before, after) - s.close() - @unittest.skipUnless(have_verify_flags(), - "verify_flags need OpenSSL > 0.9.8") - def test_crl_check(self): - if support.verbose: - sys.stdout.write("\n") + # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails + context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(SIGNING_CA) - tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0) - self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf) - - # VERIFY_DEFAULT should pass - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket()) as s: + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + with self.assertRaisesRegex(ssl.SSLError, + "certificate verify failed"): s.connect((HOST, server.port)) - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") - # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails - context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + # now load a CRL file. The CRL file is signed by the CA. + context.load_verify_locations(CRLFILE) - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket()) as s: - with self.assertRaisesRegex(ssl.SSLError, - "certificate verify failed"): - s.connect((HOST, server.port)) + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") - # now load a CRL file. The CRL file is signed by the CA. - context.load_verify_locations(CRLFILE) + def test_check_hostname(self): + if support.verbose: + sys.stdout.write("\n") - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket()) as s: - s.connect((HOST, server.port)) - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) - def test_check_hostname(self): - if support.verbose: - sys.stdout.write("\n") + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = True + context.load_verify_locations(SIGNING_CA) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.verify_mode = ssl.CERT_REQUIRED - context.check_hostname = True - context.load_verify_locations(SIGNING_CA) - - # correct hostname should verify - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket(), - server_hostname="localhost") as s: + # correct hostname should verify + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="localhost") as s: + s.connect((HOST, server.port)) + cert = s.getpeercert() + self.assertTrue(cert, "Can't get peer certificate.") + + # incorrect hostname should raise an exception + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with context.wrap_socket(socket.socket(), + server_hostname="invalid") as s: + with self.assertRaisesRegex(ssl.CertificateError, + "hostname 'invalid' doesn't match 'localhost'"): s.connect((HOST, server.port)) - cert = s.getpeercert() - self.assertTrue(cert, "Can't get peer certificate.") - - # incorrect hostname should raise an exception - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with context.wrap_socket(socket.socket(), - server_hostname="invalid") as s: - with self.assertRaisesRegex(ssl.CertificateError, - "hostname 'invalid' doesn't match 'localhost'"): - s.connect((HOST, server.port)) - - # missing server_hostname arg should cause an exception, too - server = ThreadedEchoServer(context=server_context, chatty=True) - with server: - with socket.socket() as s: - with self.assertRaisesRegex(ValueError, - "check_hostname requires server_hostname"): - context.wrap_socket(s) - - def test_wrong_cert(self): - """Connecting when the server rejects the client's certificate - - Launch a server with CERT_REQUIRED, and check that trying to - connect to it with a wrong client certificate fails. - """ - certfile = os.path.join(os.path.dirname(__file__) or os.curdir, - "wrongcert.pem") - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_REQUIRED, - cacerts=CERTFILE, chatty=False, - connectionchatty=False) - with server, \ - socket.socket() as sock, \ - test_wrap_socket(sock, - certfile=certfile, - ssl_version=ssl.PROTOCOL_TLSv1) as s: + + # missing server_hostname arg should cause an exception, too + server = ThreadedEchoServer(context=server_context, chatty=True) + with server: + with socket.socket() as s: + with self.assertRaisesRegex(ValueError, + "check_hostname requires server_hostname"): + context.wrap_socket(s) + + def test_wrong_cert(self): + """Connecting when the server rejects the client's certificate + + Launch a server with CERT_REQUIRED, and check that trying to + connect to it with a wrong client certificate fails. + """ + certfile = os.path.join(os.path.dirname(__file__) or os.curdir, + "wrongcert.pem") + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_REQUIRED, + cacerts=CERTFILE, chatty=False, + connectionchatty=False) + with server, \ + socket.socket() as sock, \ + test_wrap_socket(sock, + certfile=certfile, + ssl_version=ssl.PROTOCOL_TLSv1) as s: + try: + # Expect either an SSL error about the server rejecting + # the connection, or a low-level connection reset (which + # sometimes happens on Windows) + s.connect((HOST, server.port)) + except ssl.SSLError as e: + if support.verbose: + sys.stdout.write("\nSSLError is %r\n" % e) + except OSError as e: + if e.errno != errno.ECONNRESET: + raise + if support.verbose: + sys.stdout.write("\nsocket.error is %r\n" % e) + else: + self.fail("Use of invalid cert should have failed!") + + def test_rude_shutdown(self): + """A brutal shutdown of an SSL server should raise an OSError + in the client when attempting handshake. + """ + listener_ready = threading.Event() + listener_gone = threading.Event() + + s = socket.socket() + port = support.bind_port(s, HOST) + + # `listener` runs in a thread. It sits in an accept() until + # the main thread connects. Then it rudely closes the socket, + # and sets Event `listener_gone` to let the main thread know + # the socket is gone. + def listener(): + s.listen() + listener_ready.set() + newsock, addr = s.accept() + newsock.close() + s.close() + listener_gone.set() + + def connector(): + listener_ready.wait() + with socket.socket() as c: + c.connect((HOST, port)) + listener_gone.wait() try: - # Expect either an SSL error about the server rejecting - # the connection, or a low-level connection reset (which - # sometimes happens on Windows) - s.connect((HOST, server.port)) - except ssl.SSLError as e: - if support.verbose: - sys.stdout.write("\nSSLError is %r\n" % e) - except OSError as e: - if e.errno != errno.ECONNRESET: - raise - if support.verbose: - sys.stdout.write("\nsocket.error is %r\n" % e) + ssl_sock = test_wrap_socket(c) + except OSError: + pass else: - self.fail("Use of invalid cert should have failed!") + self.fail('connecting to closed SSL socket should have failed') - def test_rude_shutdown(self): - """A brutal shutdown of an SSL server should raise an OSError - in the client when attempting handshake. - """ - listener_ready = threading.Event() - listener_gone = threading.Event() + t = threading.Thread(target=listener) + t.start() + try: + connector() + finally: + t.join() - s = socket.socket() - port = support.bind_port(s, HOST) - - # `listener` runs in a thread. It sits in an accept() until - # the main thread connects. Then it rudely closes the socket, - # and sets Event `listener_gone` to let the main thread know - # the socket is gone. - def listener(): - s.listen() - listener_ready.set() - newsock, addr = s.accept() - newsock.close() - s.close() - listener_gone.set() - - def connector(): - listener_ready.wait() - with socket.socket() as c: - c.connect((HOST, port)) - listener_gone.wait() - try: - ssl_sock = test_wrap_socket(c) - except OSError: - pass - else: - self.fail('connecting to closed SSL socket should have failed') + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), + "OpenSSL is compiled without SSLv2 support") + def test_protocol_sslv2(self): + """Connecting to an SSLv2 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) + # SSLv23 client with specific SSL options + if no_sslv2_implies_sslv3_hello(): + # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_SSLv2) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_SSLv3) + try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1) - t = threading.Thread(target=listener) - t.start() + @skip_if_broken_ubuntu_ssl + def test_protocol_sslv23(self): + """Connecting to an SSLv23 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + if hasattr(ssl, 'PROTOCOL_SSLv2'): try: - connector() - finally: - t.join() + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) + except OSError as x: + # this fails on some older versions of OpenSSL (0.9.7l, for instance) + if support.verbose: + sys.stdout.write( + " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" + % str(x)) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') + + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) + + # Server with specific SSL options + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, + server_options=ssl.OP_NO_SSLv3) + # Will choose TLSv1 + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, + server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False, + server_options=ssl.OP_NO_TLSv1) + + + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'), + "OpenSSL is compiled without SSLv3 support") + def test_protocol_sslv3(self): + """Connecting to an SSLv3 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_SSLv3) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) + if no_sslv2_implies_sslv3_hello(): + # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, + False, client_options=ssl.OP_NO_SSLv2) + + @skip_if_broken_ubuntu_ssl + def test_protocol_tlsv1(self): + """Connecting to a TLSv1 server with various client options""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1) + + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), + "TLS version 1.1 not supported.") + def test_protocol_tlsv1_1(self): + """Connecting to a TLSv1.1 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_1) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), - "OpenSSL is compiled without SSLv2 support") - def test_protocol_sslv2(self): - """Connecting to an SSLv2 server with various client options""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False) - # SSLv23 client with specific SSL options - if no_sslv2_implies_sslv3_hello(): - # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_SSLv2) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1) - @skip_if_broken_ubuntu_ssl - def test_protocol_sslv23(self): - """Connecting to an SSLv23 server with various client options""" + @skip_if_broken_ubuntu_ssl + @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), + "TLS version 1.2 not supported.") + def test_protocol_tlsv1_2(self): + """Connecting to a TLSv1.2 server with various client options. + Testing against older TLS versions.""" + if support.verbose: + sys.stdout.write("\n") + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', + server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, + client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) + if hasattr(ssl, 'PROTOCOL_SSLv2'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) + if hasattr(ssl, 'PROTOCOL_SSLv3'): + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, + client_options=ssl.OP_NO_TLSv1_2) + + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) + + def test_starttls(self): + """Switching from clear text to encrypted and back again.""" + msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") + + server = ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + starttls_server=True, + chatty=True, + connectionchatty=True) + wrapped = False + with server: + s = socket.socket() + s.setblocking(1) + s.connect((HOST, server.port)) if support.verbose: sys.stdout.write("\n") - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try: - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True) - except OSError as x: - # this fails on some older versions of OpenSSL (0.9.7l, for instance) + for indata in msgs: + if support.verbose: + sys.stdout.write( + " client: sending %r...\n" % indata) + if wrapped: + conn.write(indata) + outdata = conn.read() + else: + s.send(indata) + outdata = s.recv(1024) + msg = outdata.strip().lower() + if indata == b"STARTTLS" and msg.startswith(b"ok"): + # STARTTLS ok, switch to secure mode if support.verbose: sys.stdout.write( - " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" - % str(x)) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') - - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) - - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) - - # Server with specific SSL options - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, - server_options=ssl.OP_NO_SSLv3) - # Will choose TLSv1 - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, - server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False, - server_options=ssl.OP_NO_TLSv1) - - - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'), - "OpenSSL is compiled without SSLv3 support") - def test_protocol_sslv3(self): - """Connecting to an SSLv3 server with various client options""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_SSLv3) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) - if no_sslv2_implies_sslv3_hello(): - # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, - False, client_options=ssl.OP_NO_SSLv2) - - @skip_if_broken_ubuntu_ssl - def test_protocol_tlsv1(self): - """Connecting to a TLSv1 server with various client options""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1) - - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"), - "TLS version 1.1 not supported.") - def test_protocol_tlsv1_1(self): - """Connecting to a TLSv1.1 server with various client options. - Testing against older TLS versions.""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1_1) - - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) - - - @skip_if_broken_ubuntu_ssl - @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"), - "TLS version 1.2 not supported.") - def test_protocol_tlsv1_2(self): - """Connecting to a TLSv1.2 server with various client options. - Testing against older TLS versions.""" - if support.verbose: - sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', - server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, - client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) - if hasattr(ssl, 'PROTOCOL_SSLv2'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False) - if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, - client_options=ssl.OP_NO_TLSv1_2) - - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False) - - def test_starttls(self): - """Switching from clear text to encrypted and back again.""" - msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6") - - server = ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, - starttls_server=True, - chatty=True, - connectionchatty=True) - wrapped = False - with server: - s = socket.socket() - s.setblocking(1) - s.connect((HOST, server.port)) - if support.verbose: - sys.stdout.write("\n") - for indata in msgs: + " client: read %r from server, starting TLS...\n" + % msg) + conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) + wrapped = True + elif indata == b"ENDTLS" and msg.startswith(b"ok"): + # ENDTLS ok, switch back to clear text if support.verbose: sys.stdout.write( - " client: sending %r...\n" % indata) - if wrapped: - conn.write(indata) - outdata = conn.read() - else: - s.send(indata) - outdata = s.recv(1024) - msg = outdata.strip().lower() - if indata == b"STARTTLS" and msg.startswith(b"ok"): - # STARTTLS ok, switch to secure mode - if support.verbose: - sys.stdout.write( - " client: read %r from server, starting TLS...\n" - % msg) - conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1) - wrapped = True - elif indata == b"ENDTLS" and msg.startswith(b"ok"): - # ENDTLS ok, switch back to clear text - if support.verbose: - sys.stdout.write( - " client: read %r from server, ending TLS...\n" - % msg) - s = conn.unwrap() - wrapped = False - else: - if support.verbose: - sys.stdout.write( - " client: read %r from server\n" % msg) - if support.verbose: - sys.stdout.write(" client: closing connection.\n") - if wrapped: - conn.write(b"over\n") + " client: read %r from server, ending TLS...\n" + % msg) + s = conn.unwrap() + wrapped = False else: - s.send(b"over\n") - if wrapped: - conn.close() - else: - s.close() - - def test_socketserver(self): - """Using socketserver to create and manage SSL connections.""" - server = make_https_server(self, certfile=CERTFILE) - # try to connect - if support.verbose: - sys.stdout.write('\n') - with open(CERTFILE, 'rb') as f: - d1 = f.read() - d2 = '' - # now fetch the same data from the HTTPS server - url = 'https://localhost:%d/%s' % ( - server.port, os.path.split(CERTFILE)[1]) - context = ssl.create_default_context(cafile=CERTFILE) - f = urllib.request.urlopen(url, context=context) - try: - dlen = f.info().get("content-length") - if dlen and (int(dlen) > 0): - d2 = f.read(int(dlen)) if support.verbose: sys.stdout.write( - " client: read %d bytes from remote server '%s'\n" - % (len(d2), server)) - finally: - f.close() - self.assertEqual(d1, d2) - - def test_asyncore_server(self): - """Check the example asyncore integration.""" + " client: read %r from server\n" % msg) if support.verbose: - sys.stdout.write("\n") + sys.stdout.write(" client: closing connection.\n") + if wrapped: + conn.write(b"over\n") + else: + s.send(b"over\n") + if wrapped: + conn.close() + else: + s.close() - indata = b"FOO\n" - server = AsyncoreEchoServer(CERTFILE) - with server: - s = test_wrap_socket(socket.socket()) - s.connect(('127.0.0.1', server.port)) + def test_socketserver(self): + """Using socketserver to create and manage SSL connections.""" + server = make_https_server(self, certfile=CERTFILE) + # try to connect + if support.verbose: + sys.stdout.write('\n') + with open(CERTFILE, 'rb') as f: + d1 = f.read() + d2 = '' + # now fetch the same data from the HTTPS server + url = 'https://localhost:%d/%s' % ( + server.port, os.path.split(CERTFILE)[1]) + context = ssl.create_default_context(cafile=CERTFILE) + f = urllib.request.urlopen(url, context=context) + try: + dlen = f.info().get("content-length") + if dlen and (int(dlen) > 0): + d2 = f.read(int(dlen)) if support.verbose: sys.stdout.write( - " client: sending %r...\n" % indata) - s.write(indata) - outdata = s.read() - if support.verbose: - sys.stdout.write(" client: read %r\n" % outdata) - if outdata != indata.lower(): - self.fail( - "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" - % (outdata[:20], len(outdata), - indata[:20].lower(), len(indata))) - s.write(b"over\n") - if support.verbose: - sys.stdout.write(" client: closing connection.\n") - s.close() - if support.verbose: - sys.stdout.write(" client: connection closed.\n") + " client: read %d bytes from remote server '%s'\n" + % (len(d2), server)) + finally: + f.close() + self.assertEqual(d1, d2) - def test_recv_send(self): - """Test recv(), send() and friends.""" + def test_asyncore_server(self): + """Check the example asyncore integration.""" + if support.verbose: + sys.stdout.write("\n") + + indata = b"FOO\n" + server = AsyncoreEchoServer(CERTFILE) + with server: + s = test_wrap_socket(socket.socket()) + s.connect(('127.0.0.1', server.port)) if support.verbose: - sys.stdout.write("\n") + sys.stdout.write( + " client: sending %r...\n" % indata) + s.write(indata) + outdata = s.read() + if support.verbose: + sys.stdout.write(" client: read %r\n" % outdata) + if outdata != indata.lower(): + self.fail( + "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n" + % (outdata[:20], len(outdata), + indata[:20].lower(), len(indata))) + s.write(b"over\n") + if support.verbose: + sys.stdout.write(" client: closing connection.\n") + s.close() + if support.verbose: + sys.stdout.write(" client: connection closed.\n") - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, - cacerts=CERTFILE, - chatty=True, - connectionchatty=False) - with server: - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - # helper methods for standardising recv* method signatures - def _recv_into(): - b = bytearray(b"\0"*100) - count = s.recv_into(b) - return b[:count] - - def _recvfrom_into(): - b = bytearray(b"\0"*100) - count, addr = s.recvfrom_into(b) - return b[:count] - - # (name, method, expect success?, *args, return value func) - send_methods = [ - ('send', s.send, True, [], len), - ('sendto', s.sendto, False, ["some.address"], len), - ('sendall', s.sendall, True, [], lambda x: None), - ] - # (name, method, whether to expect success, *args) - recv_methods = [ - ('recv', s.recv, True, []), - ('recvfrom', s.recvfrom, False, ["some.address"]), - ('recv_into', _recv_into, True, []), - ('recvfrom_into', _recvfrom_into, False, []), - ] - data_prefix = "PREFIX_" - - for (meth_name, send_meth, expect_success, args, - ret_val_meth) in send_methods: - indata = (data_prefix + meth_name).encode('ascii') - try: - ret = send_meth(indata, *args) - msg = "sending with {}".format(meth_name) - self.assertEqual(ret, ret_val_meth(indata), msg=msg) - outdata = s.read() - if outdata != indata.lower(): - self.fail( - "While sending with <<{name:s}>> bad data " - "<<{outdata:r}>> ({nout:d}) received; " - "expected <<{indata:r}>> ({nin:d})\n".format( - name=meth_name, outdata=outdata[:20], - nout=len(outdata), - indata=indata[:20], nin=len(indata) - ) - ) - except ValueError as e: - if expect_success: - self.fail( - "Failed to send with method <<{name:s}>>; " - "expected to succeed.\n".format(name=meth_name) + def test_recv_send(self): + """Test recv(), send() and friends.""" + if support.verbose: + sys.stdout.write("\n") + + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + # helper methods for standardising recv* method signatures + def _recv_into(): + b = bytearray(b"\0"*100) + count = s.recv_into(b) + return b[:count] + + def _recvfrom_into(): + b = bytearray(b"\0"*100) + count, addr = s.recvfrom_into(b) + return b[:count] + + # (name, method, expect success?, *args, return value func) + send_methods = [ + ('send', s.send, True, [], len), + ('sendto', s.sendto, False, ["some.address"], len), + ('sendall', s.sendall, True, [], lambda x: None), + ] + # (name, method, whether to expect success, *args) + recv_methods = [ + ('recv', s.recv, True, []), + ('recvfrom', s.recvfrom, False, ["some.address"]), + ('recv_into', _recv_into, True, []), + ('recvfrom_into', _recvfrom_into, False, []), + ] + data_prefix = "PREFIX_" + + for (meth_name, send_meth, expect_success, args, + ret_val_meth) in send_methods: + indata = (data_prefix + meth_name).encode('ascii') + try: + ret = send_meth(indata, *args) + msg = "sending with {}".format(meth_name) + self.assertEqual(ret, ret_val_meth(indata), msg=msg) + outdata = s.read() + if outdata != indata.lower(): + self.fail( + "While sending with <<{name:s}>> bad data " + "<<{outdata:r}>> ({nout:d}) received; " + "expected <<{indata:r}>> ({nin:d})\n".format( + name=meth_name, outdata=outdata[:20], + nout=len(outdata), + indata=indata[:20], nin=len(indata) ) - if not str(e).startswith(meth_name): - self.fail( - "Method <<{name:s}>> failed with unexpected " - "exception message: {exp:s}\n".format( - name=meth_name, exp=e - ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to send with method <<{name:s}>>; " + "expected to succeed.\n".format(name=meth_name) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<{name:s}>> failed with unexpected " + "exception message: {exp:s}\n".format( + name=meth_name, exp=e ) + ) - for meth_name, recv_meth, expect_success, args in recv_methods: - indata = (data_prefix + meth_name).encode('ascii') - try: - s.send(indata) - outdata = recv_meth(*args) - if outdata != indata.lower(): - self.fail( - "While receiving with <<{name:s}>> bad data " - "<<{outdata:r}>> ({nout:d}) received; " - "expected <<{indata:r}>> ({nin:d})\n".format( - name=meth_name, outdata=outdata[:20], - nout=len(outdata), - indata=indata[:20], nin=len(indata) - ) - ) - except ValueError as e: - if expect_success: - self.fail( - "Failed to receive with method <<{name:s}>>; " - "expected to succeed.\n".format(name=meth_name) + for meth_name, recv_meth, expect_success, args in recv_methods: + indata = (data_prefix + meth_name).encode('ascii') + try: + s.send(indata) + outdata = recv_meth(*args) + if outdata != indata.lower(): + self.fail( + "While receiving with <<{name:s}>> bad data " + "<<{outdata:r}>> ({nout:d}) received; " + "expected <<{indata:r}>> ({nin:d})\n".format( + name=meth_name, outdata=outdata[:20], + nout=len(outdata), + indata=indata[:20], nin=len(indata) ) - if not str(e).startswith(meth_name): - self.fail( - "Method <<{name:s}>> failed with unexpected " - "exception message: {exp:s}\n".format( - name=meth_name, exp=e - ) + ) + except ValueError as e: + if expect_success: + self.fail( + "Failed to receive with method <<{name:s}>>; " + "expected to succeed.\n".format(name=meth_name) + ) + if not str(e).startswith(meth_name): + self.fail( + "Method <<{name:s}>> failed with unexpected " + "exception message: {exp:s}\n".format( + name=meth_name, exp=e ) - # consume data - s.read() + ) + # consume data + s.read() - # read(-1, buffer) is supported, even though read(-1) is not - data = b"data" - s.send(data) - buffer = bytearray(len(data)) - self.assertEqual(s.read(-1, buffer), len(data)) - self.assertEqual(buffer, data) + # read(-1, buffer) is supported, even though read(-1) is not + data = b"data" + s.send(data) + buffer = bytearray(len(data)) + self.assertEqual(s.read(-1, buffer), len(data)) + self.assertEqual(buffer, data) - # Make sure sendmsg et al are disallowed to avoid - # inadvertent disclosure of data and/or corruption - # of the encrypted data stream - self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) - self.assertRaises(NotImplementedError, s.recvmsg, 100) - self.assertRaises(NotImplementedError, - s.recvmsg_into, bytearray(100)) + # Make sure sendmsg et al are disallowed to avoid + # inadvertent disclosure of data and/or corruption + # of the encrypted data stream + self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) + self.assertRaises(NotImplementedError, s.recvmsg, 100) + self.assertRaises(NotImplementedError, + s.recvmsg_into, bytearray(100)) - s.write(b"over\n") + s.write(b"over\n") - self.assertRaises(ValueError, s.recv, -1) - self.assertRaises(ValueError, s.read, -1) + self.assertRaises(ValueError, s.recv, -1) + self.assertRaises(ValueError, s.read, -1) - s.close() + s.close() - def test_recv_zero(self): - server = ThreadedEchoServer(CERTFILE) - server.__enter__() - self.addCleanup(server.__exit__, None, None) - s = socket.create_connection((HOST, server.port)) - self.addCleanup(s.close) - s = test_wrap_socket(s, suppress_ragged_eofs=False) - self.addCleanup(s.close) + def test_recv_zero(self): + server = ThreadedEchoServer(CERTFILE) + server.__enter__() + self.addCleanup(server.__exit__, None, None) + s = socket.create_connection((HOST, server.port)) + self.addCleanup(s.close) + s = test_wrap_socket(s, suppress_ragged_eofs=False) + self.addCleanup(s.close) - # recv/read(0) should return no data - s.send(b"data") - self.assertEqual(s.recv(0), b"") - self.assertEqual(s.read(0), b"") - self.assertEqual(s.read(), b"data") + # recv/read(0) should return no data + s.send(b"data") + self.assertEqual(s.recv(0), b"") + self.assertEqual(s.read(0), b"") + self.assertEqual(s.read(), b"data") + + # Should not block if the other end sends no data + s.setblocking(False) + self.assertEqual(s.recv(0), b"") + self.assertEqual(s.recv_into(bytearray()), 0) - # Should not block if the other end sends no data + def test_nonblocking_send(self): + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) s.setblocking(False) - self.assertEqual(s.recv(0), b"") - self.assertEqual(s.recv_into(bytearray()), 0) - - def test_nonblocking_send(self): - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, - cacerts=CERTFILE, - chatty=True, - connectionchatty=False) - with server: - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) - s.connect((HOST, server.port)) - s.setblocking(False) - - # If we keep sending data, at some point the buffers - # will be full and the call will block - buf = bytearray(8192) - def fill_buffer(): - while True: - s.send(buf) - self.assertRaises((ssl.SSLWantWriteError, - ssl.SSLWantReadError), fill_buffer) - - # Now read all the output and discard it - s.setblocking(True) - s.close() - def test_handshake_timeout(self): - # Issue #5103: SSL handshake must respect the socket timeout - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = support.bind_port(server) - started = threading.Event() - finish = False - - def serve(): - server.listen() - started.set() - conns = [] - while not finish: - r, w, e = select.select([server], [], [], 0.1) - if server in r: - # Let the socket hang around rather than having - # it closed by garbage collection. - conns.append(server.accept()[0]) - for sock in conns: - sock.close() - - t = threading.Thread(target=serve) - t.start() - started.wait() + # If we keep sending data, at some point the buffers + # will be full and the call will block + buf = bytearray(8192) + def fill_buffer(): + while True: + s.send(buf) + self.assertRaises((ssl.SSLWantWriteError, + ssl.SSLWantReadError), fill_buffer) + + # Now read all the output and discard it + s.setblocking(True) + s.close() + + def test_handshake_timeout(self): + # Issue #5103: SSL handshake must respect the socket timeout + server = socket.socket(socket.AF_INET) + host = "127.0.0.1" + port = support.bind_port(server) + started = threading.Event() + finish = False + + def serve(): + server.listen() + started.set() + conns = [] + while not finish: + r, w, e = select.select([server], [], [], 0.1) + if server in r: + # Let the socket hang around rather than having + # it closed by garbage collection. + conns.append(server.accept()[0]) + for sock in conns: + sock.close() + + t = threading.Thread(target=serve) + t.start() + started.wait() + try: try: - try: - c = socket.socket(socket.AF_INET) - c.settimeout(0.2) - c.connect((host, port)) - # Will attempt handshake and time out - self.assertRaisesRegex(socket.timeout, "timed out", - test_wrap_socket, c) - finally: - c.close() - try: - c = socket.socket(socket.AF_INET) - c = test_wrap_socket(c) - c.settimeout(0.2) - # Will attempt handshake and time out - self.assertRaisesRegex(socket.timeout, "timed out", - c.connect, (host, port)) - finally: - c.close() + c = socket.socket(socket.AF_INET) + c.settimeout(0.2) + c.connect((host, port)) + # Will attempt handshake and time out + self.assertRaisesRegex(socket.timeout, "timed out", + test_wrap_socket, c) finally: - finish = True - t.join() - server.close() - - def test_server_accept(self): - # Issue #16357: accept() on a SSLSocket created through - # SSLContext.wrap_socket(). - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = socket.socket(socket.AF_INET) - host = "127.0.0.1" - port = support.bind_port(server) - server = context.wrap_socket(server, server_side=True) - self.assertTrue(server.server_side) - - evt = threading.Event() - remote = None - peer = None - def serve(): - nonlocal remote, peer - server.listen() - # Block on the accept and wait on the connection to close. - evt.set() - remote, peer = server.accept() - remote.recv(1) - - t = threading.Thread(target=serve) - t.start() - # Client wait until server setup and perform a connect. - evt.wait() - client = context.wrap_socket(socket.socket()) - client.connect((host, port)) - client_addr = client.getsockname() - client.close() + c.close() + try: + c = socket.socket(socket.AF_INET) + c = test_wrap_socket(c) + c.settimeout(0.2) + # Will attempt handshake and time out + self.assertRaisesRegex(socket.timeout, "timed out", + c.connect, (host, port)) + finally: + c.close() + finally: + finish = True t.join() - remote.close() server.close() - # Sanity checks. - self.assertIsInstance(remote, ssl.SSLSocket) - self.assertEqual(peer, client_addr) - - def test_getpeercert_enotconn(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - with context.wrap_socket(socket.socket()) as sock: - with self.assertRaises(OSError) as cm: - sock.getpeercert() - self.assertEqual(cm.exception.errno, errno.ENOTCONN) - - def test_do_handshake_enotconn(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - with context.wrap_socket(socket.socket()) as sock: - with self.assertRaises(OSError) as cm: - sock.do_handshake() - self.assertEqual(cm.exception.errno, errno.ENOTCONN) - - def test_default_ciphers(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - try: - # Force a set of weak ciphers on our client context - context.set_ciphers("DES") - except ssl.SSLError: - self.skipTest("no DES cipher available") - with ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_SSLv23, - chatty=False) as server: - with context.wrap_socket(socket.socket()) as s: - with self.assertRaises(OSError): - s.connect((HOST, server.port)) - self.assertIn("no shared cipher", server.conn_errors[0]) - - def test_version_basic(self): - """ - Basic tests for SSLSocket.version(). - More tests are done in the test_protocol_*() methods. - """ - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - with ThreadedEchoServer(CERTFILE, - ssl_version=ssl.PROTOCOL_TLSv1, - chatty=False) as server: - with context.wrap_socket(socket.socket()) as s: - self.assertIs(s.version(), None) - s.connect((HOST, server.port)) - self.assertEqual(s.version(), 'TLSv1') - self.assertIs(s.version(), None) - @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") - def test_default_ecdh_curve(self): - # Issue #21015: elliptic curve-based Diffie Hellman key exchange - # should be enabled by default on SSL contexts. - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.load_cert_chain(CERTFILE) - # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled - # explicitly using the 'ECCdraft' cipher alias. Otherwise, - # our default cipher list should prefer ECDH-based ciphers - # automatically. - if ssl.OPENSSL_VERSION_INFO < (1, 0, 0): - context.set_ciphers("ECCdraft:ECDH") - with ThreadedEchoServer(context=context) as server: - with context.wrap_socket(socket.socket()) as s: + def test_server_accept(self): + # Issue #16357: accept() on a SSLSocket created through + # SSLContext.wrap_socket(). + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = socket.socket(socket.AF_INET) + host = "127.0.0.1" + port = support.bind_port(server) + server = context.wrap_socket(server, server_side=True) + self.assertTrue(server.server_side) + + evt = threading.Event() + remote = None + peer = None + def serve(): + nonlocal remote, peer + server.listen() + # Block on the accept and wait on the connection to close. + evt.set() + remote, peer = server.accept() + remote.recv(1) + + t = threading.Thread(target=serve) + t.start() + # Client wait until server setup and perform a connect. + evt.wait() + client = context.wrap_socket(socket.socket()) + client.connect((host, port)) + client_addr = client.getsockname() + client.close() + t.join() + remote.close() + server.close() + # Sanity checks. + self.assertIsInstance(remote, ssl.SSLSocket) + self.assertEqual(peer, client_addr) + + def test_getpeercert_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.getpeercert() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_do_handshake_enotconn(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + with context.wrap_socket(socket.socket()) as sock: + with self.assertRaises(OSError) as cm: + sock.do_handshake() + self.assertEqual(cm.exception.errno, errno.ENOTCONN) + + def test_default_ciphers(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + try: + # Force a set of weak ciphers on our client context + context.set_ciphers("DES") + except ssl.SSLError: + self.skipTest("no DES cipher available") + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_SSLv23, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + with self.assertRaises(OSError): s.connect((HOST, server.port)) - self.assertIn("ECDH", s.cipher()[0]) + self.assertIn("no shared cipher", server.conn_errors[0]) - @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, - "'tls-unique' channel binding not available") - def test_tls_unique_channel_binding(self): - """Test tls-unique channel binding.""" - if support.verbose: - sys.stdout.write("\n") - - server = ThreadedEchoServer(CERTFILE, - certreqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1, - cacerts=CERTFILE, - chatty=True, - connectionchatty=False) - with server: - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + def test_version_basic(self): + """ + Basic tests for SSLSocket.version(). + More tests are done in the test_protocol_*() methods. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + self.assertIs(s.version(), None) s.connect((HOST, server.port)) - # get the data - cb_data = s.get_channel_binding("tls-unique") - if support.verbose: - sys.stdout.write(" got channel binding data: {0!r}\n" - .format(cb_data)) - - # check if it is sane - self.assertIsNotNone(cb_data) - self.assertEqual(len(cb_data), 12) # True for TLSv1 - - # and compare with the peers version - s.write(b"CB tls-unique\n") - peer_data_repr = s.read().strip() - self.assertEqual(peer_data_repr, - repr(cb_data).encode("us-ascii")) - s.close() - - # now, again - s = test_wrap_socket(socket.socket(), - server_side=False, - certfile=CERTFILE, - ca_certs=CERTFILE, - cert_reqs=ssl.CERT_NONE, - ssl_version=ssl.PROTOCOL_TLSv1) + self.assertEqual(s.version(), 'TLSv1') + self.assertIs(s.version(), None) + + @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") + def test_default_ecdh_curve(self): + # Issue #21015: elliptic curve-based Diffie Hellman key exchange + # should be enabled by default on SSL contexts. + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.load_cert_chain(CERTFILE) + # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled + # explicitly using the 'ECCdraft' cipher alias. Otherwise, + # our default cipher list should prefer ECDH-based ciphers + # automatically. + if ssl.OPENSSL_VERSION_INFO < (1, 0, 0): + context.set_ciphers("ECCdraft:ECDH") + with ThreadedEchoServer(context=context) as server: + with context.wrap_socket(socket.socket()) as s: s.connect((HOST, server.port)) - new_cb_data = s.get_channel_binding("tls-unique") - if support.verbose: - sys.stdout.write(" got another channel binding data: {0!r}\n" - .format(new_cb_data)) - # is it really unique - self.assertNotEqual(cb_data, new_cb_data) - self.assertIsNotNone(cb_data) - self.assertEqual(len(cb_data), 12) # True for TLSv1 - s.write(b"CB tls-unique\n") - peer_data_repr = s.read().strip() - self.assertEqual(peer_data_repr, - repr(new_cb_data).encode("us-ascii")) - s.close() + self.assertIn("ECDH", s.cipher()[0]) - def test_compression(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - if support.verbose: - sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) - self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) - - @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), - "ssl.OP_NO_COMPRESSION needed for this test") - def test_compression_disabled(self): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.options |= ssl.OP_NO_COMPRESSION - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - self.assertIs(stats['compression'], None) - - def test_dh_params(self): - # Check we can get a connection with ephemeral Diffie-Hellman - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - context.load_dh_params(DHFILE) - context.set_ciphers("kEDH") - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - cipher = stats["cipher"][0] - parts = cipher.split("-") - if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: - self.fail("Non-DH cipher: " + cipher[0]) - - def test_selected_alpn_protocol(self): - # selected_alpn_protocol() is None unless ALPN is used. - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - self.assertIs(stats['client_alpn_protocol'], None) + @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, + "'tls-unique' channel binding not available") + def test_tls_unique_channel_binding(self): + """Test tls-unique channel binding.""" + if support.verbose: + sys.stdout.write("\n") - @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") - def test_selected_alpn_protocol_if_server_uses_alpn(self): - # selected_alpn_protocol() is None unless ALPN is used by the client. - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.load_verify_locations(CERTFILE) - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + # get the data + cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + sys.stdout.write(" got channel binding data: {0!r}\n" + .format(cb_data)) + + # check if it is sane + self.assertIsNotNone(cb_data) + self.assertEqual(len(cb_data), 12) # True for TLSv1 + + # and compare with the peers version + s.write(b"CB tls-unique\n") + peer_data_repr = s.read().strip() + self.assertEqual(peer_data_repr, + repr(cb_data).encode("us-ascii")) + s.close() + + # now, again + s = test_wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + new_cb_data = s.get_channel_binding("tls-unique") + if support.verbose: + sys.stdout.write(" got another channel binding data: {0!r}\n" + .format(new_cb_data)) + # is it really unique + self.assertNotEqual(cb_data, new_cb_data) + self.assertIsNotNone(cb_data) + self.assertEqual(len(cb_data), 12) # True for TLSv1 + s.write(b"CB tls-unique\n") + peer_data_repr = s.read().strip() + self.assertEqual(peer_data_repr, + repr(new_cb_data).encode("us-ascii")) + s.close() + + def test_compression(self): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + if support.verbose: + sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) + self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) + + @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), + "ssl.OP_NO_COMPRESSION needed for this test") + def test_compression_disabled(self): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + context.options |= ssl.OP_NO_COMPRESSION + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['compression'], None) + + def test_dh_params(self): + # Check we can get a connection with ephemeral Diffie-Hellman + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + context.load_dh_params(DHFILE) + context.set_ciphers("kEDH") + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + cipher = stats["cipher"][0] + parts = cipher.split("-") + if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: + self.fail("Non-DH cipher: " + cipher[0]) + + def test_selected_alpn_protocol(self): + # selected_alpn_protocol() is None unless ALPN is used. + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") + def test_selected_alpn_protocol_if_server_uses_alpn(self): + # selected_alpn_protocol() is None unless ALPN is used by the client. + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_verify_locations(CERTFILE) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_alpn_protocols(['foo', 'bar']) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") + def test_alpn_protocols(self): + server_protocols = ['foo', 'bar', 'milkshake'] + protocol_tests = [ + (['foo', 'bar'], 'foo'), + (['bar', 'foo'], 'foo'), + (['milkshake'], 'milkshake'), + (['http/3.0', 'http/4.0'], None) + ] + for client_protocols, expected in protocol_tests: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) server_context.load_cert_chain(CERTFILE) - server_context.set_alpn_protocols(['foo', 'bar']) - stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) - self.assertIs(stats['client_alpn_protocol'], None) - - @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") - def test_alpn_protocols(self): - server_protocols = ['foo', 'bar', 'milkshake'] - protocol_tests = [ - (['foo', 'bar'], 'foo'), - (['bar', 'foo'], 'foo'), - (['milkshake'], 'milkshake'), - (['http/3.0', 'http/4.0'], None) - ] - for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - server_context.load_cert_chain(CERTFILE) - server_context.set_alpn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - client_context.load_cert_chain(CERTFILE) - client_context.set_alpn_protocols(client_protocols) + server_context.set_alpn_protocols(server_protocols) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + client_context.load_cert_chain(CERTFILE) + client_context.set_alpn_protocols(client_protocols) - try: - stats = server_params_test(client_context, - server_context, - chatty=True, - connectionchatty=True) - except ssl.SSLError as e: - stats = e - - if (expected is None and IS_OPENSSL_1_1 - and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)): - # OpenSSL 1.1.0 to 1.1.0e raises handshake error - self.assertIsInstance(stats, ssl.SSLError) - else: - msg = "failed trying %s (s) and %s (c).\n" \ - "was expecting %s, but got %%s from the %%s" \ - % (str(server_protocols), str(client_protocols), - str(expected)) - client_result = stats['client_alpn_protocol'] - self.assertEqual(client_result, expected, - msg % (client_result, "client")) - server_result = stats['server_alpn_protocols'][-1] \ - if len(stats['server_alpn_protocols']) else 'nothing' - self.assertEqual(server_result, expected, - msg % (server_result, "server")) - - def test_selected_npn_protocol(self): - # selected_npn_protocol() is None unless NPN is used - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - context.load_cert_chain(CERTFILE) - stats = server_params_test(context, context, - chatty=True, connectionchatty=True) - self.assertIs(stats['client_npn_protocol'], None) - - @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") - def test_npn_protocols(self): - server_protocols = ['http/1.1', 'spdy/2'] - protocol_tests = [ - (['http/1.1', 'spdy/2'], 'http/1.1'), - (['spdy/2', 'http/1.1'], 'http/1.1'), - (['spdy/2', 'test'], 'spdy/2'), - (['abc', 'def'], 'abc') - ] - for client_protocols, expected in protocol_tests: - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(CERTFILE) - server_context.set_npn_protocols(server_protocols) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.load_cert_chain(CERTFILE) - client_context.set_npn_protocols(client_protocols) - stats = server_params_test(client_context, server_context, - chatty=True, connectionchatty=True) + try: + stats = server_params_test(client_context, + server_context, + chatty=True, + connectionchatty=True) + except ssl.SSLError as e: + stats = e + if (expected is None and IS_OPENSSL_1_1 + and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)): + # OpenSSL 1.1.0 to 1.1.0e raises handshake error + self.assertIsInstance(stats, ssl.SSLError) + else: msg = "failed trying %s (s) and %s (c).\n" \ - "was expecting %s, but got %%s from the %%s" \ - % (str(server_protocols), str(client_protocols), - str(expected)) - client_result = stats['client_npn_protocol'] - self.assertEqual(client_result, expected, msg % (client_result, "client")) - server_result = stats['server_npn_protocols'][-1] \ - if len(stats['server_npn_protocols']) else 'nothing' - self.assertEqual(server_result, expected, msg % (server_result, "server")) - - def sni_contexts(self): + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_alpn_protocol'] + self.assertEqual(client_result, expected, + msg % (client_result, "client")) + server_result = stats['server_alpn_protocols'][-1] \ + if len(stats['server_alpn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, + msg % (server_result, "server")) + + def test_selected_npn_protocol(self): + # selected_npn_protocol() is None unless NPN is used + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_npn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test") + def test_npn_protocols(self): + server_protocols = ['http/1.1', 'spdy/2'] + protocol_tests = [ + (['http/1.1', 'spdy/2'], 'http/1.1'), + (['spdy/2', 'http/1.1'], 'http/1.1'), + (['spdy/2', 'test'], 'spdy/2'), + (['abc', 'def'], 'abc') + ] + for client_protocols, expected in protocol_tests: server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - other_context.load_cert_chain(SIGNED_CERTFILE2) + server_context.load_cert_chain(CERTFILE) + server_context.set_npn_protocols(server_protocols) client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) - return server_context, other_context, client_context + client_context.load_cert_chain(CERTFILE) + client_context.set_npn_protocols(client_protocols) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) - def check_common_name(self, stats, name): - cert = stats['peercert'] - self.assertIn((('commonName', name),), cert['subject']) + msg = "failed trying %s (s) and %s (c).\n" \ + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_npn_protocol'] + self.assertEqual(client_result, expected, msg % (client_result, "client")) + server_result = stats['server_npn_protocols'][-1] \ + if len(stats['server_npn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, msg % (server_result, "server")) + + def sni_contexts(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + other_context.load_cert_chain(SIGNED_CERTFILE2) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + return server_context, other_context, client_context + + def check_common_name(self, stats, name): + cert = stats['peercert'] + self.assertIn((('commonName', name),), cert['subject']) + + @needs_sni + def test_sni_callback(self): + calls = [] + server_context, other_context, client_context = self.sni_contexts() + + def servername_cb(ssl_sock, server_name, initial_context): + calls.append((server_name, initial_context)) + if server_name is not None: + ssl_sock.context = other_context + server_context.set_servername_callback(servername_cb) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='supermessage') + # The hostname was fetched properly, and the certificate was + # changed for the connection. + self.assertEqual(calls, [("supermessage", server_context)]) + # CERTFILE4 was selected + self.check_common_name(stats, 'fakehostname') + + calls = [] + # The callback is called with server_name=None + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name=None) + self.assertEqual(calls, [(None, server_context)]) + self.check_common_name(stats, 'localhost') + + # Check disabling the callback + calls = [] + server_context.set_servername_callback(None) + + stats = server_params_test(client_context, server_context, + chatty=True, + sni_name='notfunny') + # Certificate didn't change + self.check_common_name(stats, 'localhost') + self.assertEqual(calls, []) - @needs_sni - def test_sni_callback(self): - calls = [] - server_context, other_context, client_context = self.sni_contexts() + @needs_sni + def test_sni_callback_alert(self): + # Returning a TLS alert is reflected to the connecting client + server_context, other_context, client_context = self.sni_contexts() - def servername_cb(ssl_sock, server_name, initial_context): - calls.append((server_name, initial_context)) - if server_name is not None: - ssl_sock.context = other_context - server_context.set_servername_callback(servername_cb) + def cb_returning_alert(ssl_sock, server_name, initial_context): + return ssl.ALERT_DESCRIPTION_ACCESS_DENIED + server_context.set_servername_callback(cb_returning_alert) + with self.assertRaises(ssl.SSLError) as cm: stats = server_params_test(client_context, server_context, - chatty=True, + chatty=False, sni_name='supermessage') - # The hostname was fetched properly, and the certificate was - # changed for the connection. - self.assertEqual(calls, [("supermessage", server_context)]) - # CERTFILE4 was selected - self.check_common_name(stats, 'fakehostname') - - calls = [] - # The callback is called with server_name=None + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED') + + @needs_sni + def test_sni_callback_raising(self): + # Raising fails the connection with a TLS handshake failure alert. + server_context, other_context, client_context = self.sni_contexts() + + def cb_raising(ssl_sock, server_name, initial_context): + 1/0 + server_context.set_servername_callback(cb_raising) + + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: stats = server_params_test(client_context, server_context, - chatty=True, - sni_name=None) - self.assertEqual(calls, [(None, server_context)]) - self.check_common_name(stats, 'localhost') + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE') + self.assertIn("ZeroDivisionError", stderr.getvalue()) + + @needs_sni + def test_sni_callback_wrong_return_type(self): + # Returning the wrong return type terminates the TLS connection + # with an internal error alert. + server_context, other_context, client_context = self.sni_contexts() - # Check disabling the callback - calls = [] - server_context.set_servername_callback(None) + def cb_wrong_return_type(ssl_sock, server_name, initial_context): + return "foo" + server_context.set_servername_callback(cb_wrong_return_type) + with self.assertRaises(ssl.SSLError) as cm, \ + support.captured_stderr() as stderr: stats = server_params_test(client_context, server_context, - chatty=True, - sni_name='notfunny') - # Certificate didn't change - self.check_common_name(stats, 'localhost') - self.assertEqual(calls, []) - - @needs_sni - def test_sni_callback_alert(self): - # Returning a TLS alert is reflected to the connecting client - server_context, other_context, client_context = self.sni_contexts() - - def cb_returning_alert(ssl_sock, server_name, initial_context): - return ssl.ALERT_DESCRIPTION_ACCESS_DENIED - server_context.set_servername_callback(cb_returning_alert) - - with self.assertRaises(ssl.SSLError) as cm: - stats = server_params_test(client_context, server_context, - chatty=False, - sni_name='supermessage') - self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED') - - @needs_sni - def test_sni_callback_raising(self): - # Raising fails the connection with a TLS handshake failure alert. - server_context, other_context, client_context = self.sni_contexts() - - def cb_raising(ssl_sock, server_name, initial_context): - 1/0 - server_context.set_servername_callback(cb_raising) - - with self.assertRaises(ssl.SSLError) as cm, \ - support.captured_stderr() as stderr: - stats = server_params_test(client_context, server_context, - chatty=False, - sni_name='supermessage') - self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE') - self.assertIn("ZeroDivisionError", stderr.getvalue()) - - @needs_sni - def test_sni_callback_wrong_return_type(self): - # Returning the wrong return type terminates the TLS connection - # with an internal error alert. - server_context, other_context, client_context = self.sni_contexts() - - def cb_wrong_return_type(ssl_sock, server_name, initial_context): - return "foo" - server_context.set_servername_callback(cb_wrong_return_type) - - with self.assertRaises(ssl.SSLError) as cm, \ - support.captured_stderr() as stderr: - stats = server_params_test(client_context, server_context, - chatty=False, - sni_name='supermessage') - self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') - self.assertIn("TypeError", stderr.getvalue()) - - def test_shared_ciphers(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) - if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): - client_context.set_ciphers("AES128:AES256") - server_context.set_ciphers("AES256") - alg1 = "AES256" - alg2 = "AES-256" - else: - client_context.set_ciphers("AES:3DES") - server_context.set_ciphers("3DES") - alg1 = "3DES" - alg2 = "DES-CBC3" - - stats = server_params_test(client_context, server_context) - ciphers = stats['server_shared_ciphers'][0] - self.assertGreater(len(ciphers), 0) - for name, tls_version, bits in ciphers: - if not alg1 in name.split("-") and alg2 not in name: - self.fail(name) - - def test_read_write_after_close_raises_valuerror(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) - - with server: - s = context.wrap_socket(socket.socket()) + chatty=False, + sni_name='supermessage') + self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') + self.assertIn("TypeError", stderr.getvalue()) + + def test_shared_ciphers(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2): + client_context.set_ciphers("AES128:AES256") + server_context.set_ciphers("AES256") + alg1 = "AES256" + alg2 = "AES-256" + else: + client_context.set_ciphers("AES:3DES") + server_context.set_ciphers("3DES") + alg1 = "3DES" + alg2 = "DES-CBC3" + + stats = server_params_test(client_context, server_context) + ciphers = stats['server_shared_ciphers'][0] + self.assertGreater(len(ciphers), 0) + for name, tls_version, bits in ciphers: + if not alg1 in name.split("-") and alg2 not in name: + self.fail(name) + + def test_read_write_after_close_raises_valuerror(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + + with server: + s = context.wrap_socket(socket.socket()) + s.connect((HOST, server.port)) + s.close() + + self.assertRaises(ValueError, s.read, 1024) + self.assertRaises(ValueError, s.write, b'hello') + + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: s.connect((HOST, server.port)) - s.close() + with open(support.TESTFN, 'rb') as file: + s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + + def test_session(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + + # first connection without session + stats = server_params_test(client_context, server_context) + session = stats['session'] + self.assertTrue(session.id) + self.assertGreater(session.time, 0) + self.assertGreater(session.timeout, 0) + self.assertTrue(session.has_ticket) + if ssl.OPENSSL_VERSION_INFO > (1, 0, 1): + self.assertGreater(session.ticket_lifetime_hint, 0) + self.assertFalse(stats['session_reused']) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 1) + self.assertEqual(sess_stat['hits'], 0) + + # reuse session + stats = server_params_test(client_context, server_context, session=session) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 2) + self.assertEqual(sess_stat['hits'], 1) + self.assertTrue(stats['session_reused']) + session2 = stats['session'] + self.assertEqual(session2.id, session.id) + self.assertEqual(session2, session) + self.assertIsNot(session2, session) + self.assertGreaterEqual(session2.time, session.time) + self.assertGreaterEqual(session2.timeout, session.timeout) + + # another one without session + stats = server_params_test(client_context, server_context) + self.assertFalse(stats['session_reused']) + session3 = stats['session'] + self.assertNotEqual(session3.id, session.id) + self.assertNotEqual(session3, session) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 3) + self.assertEqual(sess_stat['hits'], 1) + + # reuse session again + stats = server_params_test(client_context, server_context, session=session) + self.assertTrue(stats['session_reused']) + session4 = stats['session'] + self.assertEqual(session4.id, session.id) + self.assertEqual(session4, session) + self.assertGreaterEqual(session4.time, session.time) + self.assertGreaterEqual(session4.timeout, session.timeout) + sess_stat = server_context.session_stats() + self.assertEqual(sess_stat['accept'], 4) + self.assertEqual(sess_stat['hits'], 2) + + def test_session_handling(self): + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + + context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context2.verify_mode = ssl.CERT_REQUIRED + context2.load_verify_locations(CERTFILE) + context2.load_cert_chain(CERTFILE) + + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + # session is None before handshake + self.assertEqual(s.session, None) + self.assertEqual(s.session_reused, None) + s.connect((HOST, server.port)) + session = s.session + self.assertTrue(session) + with self.assertRaises(TypeError) as e: + s.session = object + self.assertEqual(str(e.exception), 'Value is not a SSLSession.') - self.assertRaises(ValueError, s.read, 1024) - self.assertRaises(ValueError, s.write, b'hello') - - def test_sendfile(self): - TEST_DATA = b"x" * 512 - with open(support.TESTFN, 'wb') as f: - f.write(TEST_DATA) - self.addCleanup(support.unlink, support.TESTFN) - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - server = ThreadedEchoServer(context=context, chatty=False) - with server: - with context.wrap_socket(socket.socket()) as s: - s.connect((HOST, server.port)) - with open(support.TESTFN, 'rb') as file: - s.sendfile(file) - self.assertEqual(s.recv(1024), TEST_DATA) + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + # cannot set session after handshake + with self.assertRaises(ValueError) as e: + s.session = session + self.assertEqual(str(e.exception), + 'Cannot set session after handshake.') - def test_session(self): - server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - server_context.load_cert_chain(SIGNED_CERTFILE) - client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - client_context.verify_mode = ssl.CERT_REQUIRED - client_context.load_verify_locations(SIGNING_CA) - - # first connection without session - stats = server_params_test(client_context, server_context) - session = stats['session'] - self.assertTrue(session.id) - self.assertGreater(session.time, 0) - self.assertGreater(session.timeout, 0) - self.assertTrue(session.has_ticket) - if ssl.OPENSSL_VERSION_INFO > (1, 0, 1): - self.assertGreater(session.ticket_lifetime_hint, 0) - self.assertFalse(stats['session_reused']) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 1) - self.assertEqual(sess_stat['hits'], 0) - - # reuse session - stats = server_params_test(client_context, server_context, session=session) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 2) - self.assertEqual(sess_stat['hits'], 1) - self.assertTrue(stats['session_reused']) - session2 = stats['session'] - self.assertEqual(session2.id, session.id) - self.assertEqual(session2, session) - self.assertIsNot(session2, session) - self.assertGreaterEqual(session2.time, session.time) - self.assertGreaterEqual(session2.timeout, session.timeout) - - # another one without session - stats = server_params_test(client_context, server_context) - self.assertFalse(stats['session_reused']) - session3 = stats['session'] - self.assertNotEqual(session3.id, session.id) - self.assertNotEqual(session3, session) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 3) - self.assertEqual(sess_stat['hits'], 1) - - # reuse session again - stats = server_params_test(client_context, server_context, session=session) - self.assertTrue(stats['session_reused']) - session4 = stats['session'] - self.assertEqual(session4.id, session.id) - self.assertEqual(session4, session) - self.assertGreaterEqual(session4.time, session.time) - self.assertGreaterEqual(session4.timeout, session.timeout) - sess_stat = server_context.session_stats() - self.assertEqual(sess_stat['accept'], 4) - self.assertEqual(sess_stat['hits'], 2) - - def test_session_handling(self): - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context.verify_mode = ssl.CERT_REQUIRED - context.load_verify_locations(CERTFILE) - context.load_cert_chain(CERTFILE) - - context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - context2.verify_mode = ssl.CERT_REQUIRED - context2.load_verify_locations(CERTFILE) - context2.load_cert_chain(CERTFILE) - - server = ThreadedEchoServer(context=context, chatty=False) - with server: - with context.wrap_socket(socket.socket()) as s: - # session is None before handshake - self.assertEqual(s.session, None) - self.assertEqual(s.session_reused, None) - s.connect((HOST, server.port)) - session = s.session - self.assertTrue(session) - with self.assertRaises(TypeError) as e: - s.session = object - self.assertEqual(str(e.exception), 'Value is not a SSLSession.') + with context.wrap_socket(socket.socket()) as s: + # can set session before handshake and before the + # connection was established + s.session = session + s.connect((HOST, server.port)) + self.assertEqual(s.session.id, session.id) + self.assertEqual(s.session, session) + self.assertEqual(s.session_reused, True) - with context.wrap_socket(socket.socket()) as s: - s.connect((HOST, server.port)) - # cannot set session after handshake - with self.assertRaises(ValueError) as e: - s.session = session - self.assertEqual(str(e.exception), - 'Cannot set session after handshake.') - - with context.wrap_socket(socket.socket()) as s: - # can set session before handshake and before the - # connection was established + with context2.wrap_socket(socket.socket()) as s: + # cannot re-use session with a different SSLContext + with self.assertRaises(ValueError) as e: s.session = session s.connect((HOST, server.port)) - self.assertEqual(s.session.id, session.id) - self.assertEqual(s.session, session) - self.assertEqual(s.session_reused, True) - - with context2.wrap_socket(socket.socket()) as s: - # cannot re-use session with a different SSLContext - with self.assertRaises(ValueError) as e: - s.session = session - s.connect((HOST, server.port)) - self.assertEqual(str(e.exception), - 'Session refers to a different SSLContext.') + self.assertEqual(str(e.exception), + 'Session refers to a different SSLContext.') def test_main(verbose=False): @@ -3610,22 +3603,17 @@ def test_main(verbose=False): tests = [ ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests, - SimpleBackgroundTests, + SimpleBackgroundTests, ThreadedTests, ] if support.is_resource_enabled('network'): tests.append(NetworkedTests) - if _have_threads: - thread_info = support.threading_setup() - if thread_info: - tests.append(ThreadedTests) - + thread_info = support.threading_setup() try: support.run_unittest(*tests) finally: - if _have_threads: - support.threading_cleanup(*thread_info) + support.threading_cleanup(*thread_info) if __name__ == "__main__": test_main() diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index b7079b1..10ef87b 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -14,6 +14,7 @@ import selectors import sysconfig import select import shutil +import threading import gc import textwrap @@ -25,11 +26,6 @@ else: import ctypes.util try: - import threading -except ImportError: - threading = None - -try: import _testcapi except ImportError: _testcapi = None @@ -1196,7 +1192,6 @@ class ProcessTestCase(BaseTestCase): self.assertEqual(stderr, "") self.assertEqual(proc.returncode, 0) - @unittest.skipIf(threading is None, "threading required") def test_double_close_on_error(self): # Issue #18851 fds = [] @@ -1226,7 +1221,6 @@ class ProcessTestCase(BaseTestCase): if exc is not None: raise exc - @unittest.skipIf(threading is None, "threading required") def test_threadsafe_wait(self): """Issue21291: Popen.wait() needs to be threadsafe for returncode.""" proc = subprocess.Popen([sys.executable, '-c', diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py index 3844812..04550e5 100644 --- a/Lib/test/test_sys.py +++ b/Lib/test/test_sys.py @@ -10,15 +10,12 @@ import codecs import gc import sysconfig import locale +import threading # count the number of test runs, used to create unique # strings to intern in test_intern() numruns = 0 -try: - import threading -except ImportError: - threading = None class SysModuleTest(unittest.TestCase): @@ -172,7 +169,6 @@ class SysModuleTest(unittest.TestCase): sys.setcheckinterval(n) self.assertEqual(sys.getcheckinterval(), n) - @unittest.skipUnless(threading, 'Threading required for this test.') def test_switchinterval(self): self.assertRaises(TypeError, sys.setswitchinterval) self.assertRaises(TypeError, sys.setswitchinterval, "a") @@ -348,21 +344,8 @@ class SysModuleTest(unittest.TestCase): ) # sys._current_frames() is a CPython-only gimmick. - def test_current_frames(self): - have_threads = True - try: - import _thread - except ImportError: - have_threads = False - - if have_threads: - self.current_frames_with_threads() - else: - self.current_frames_without_threads() - - # Test sys._current_frames() in a WITH_THREADS build. @test.support.reap_threads - def current_frames_with_threads(self): + def test_current_frames(self): import threading import traceback @@ -426,15 +409,6 @@ class SysModuleTest(unittest.TestCase): leave_g.set() t.join() - # Test sys._current_frames() when thread support doesn't exist. - def current_frames_without_threads(self): - # Not much happens here: there is only one thread, with artificial - # "thread id" 0. - d = sys._current_frames() - self.assertEqual(len(d), 1) - self.assertIn(0, d) - self.assertTrue(d[0] is sys._getframe()) - def test_attributes(self): self.assertIsInstance(sys.api_version, int) self.assertIsInstance(sys.argv, list) @@ -516,8 +490,6 @@ class SysModuleTest(unittest.TestCase): if not sys.platform.startswith('win'): self.assertIsInstance(sys.abiflags, str) - @unittest.skipUnless(hasattr(sys, 'thread_info'), - 'Threading required for this test.') def test_thread_info(self): info = sys.thread_info self.assertEqual(len(info), 3) diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py index 51d82e1..414c328 100644 --- a/Lib/test/test_telnetlib.py +++ b/Lib/test/test_telnetlib.py @@ -1,11 +1,11 @@ import socket import selectors import telnetlib +import threading import contextlib from test import support import unittest -threading = support.import_module('threading') HOST = support.HOST diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py index 75c66b0..035344b 100644 --- a/Lib/test/test_threaded_import.py +++ b/Lib/test/test_threaded_import.py @@ -11,12 +11,12 @@ import importlib import sys import time import shutil +import threading import unittest from unittest import mock from test.support import ( verbose, import_module, run_unittest, TESTFN, reap_threads, forget, unlink, rmtree, start_threads) -threading = import_module('threading') def task(N, done, done_tasks, errors): try: diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py index b742036..f3d4ba3 100644 --- a/Lib/test/test_threadedtempfile.py +++ b/Lib/test/test_threadedtempfile.py @@ -19,9 +19,9 @@ FILES_PER_THREAD = 50 import tempfile from test.support import start_threads, import_module -threading = import_module('threading') import unittest import io +import threading from traceback import print_exc startEvent = threading.Event() diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 800d26f..912eb3f 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -9,8 +9,8 @@ from test.support.script_helper import assert_python_ok, assert_python_failure import random import sys -_thread = import_module('_thread') -threading = import_module('threading') +import _thread +import threading import time import unittest import weakref diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py index 4092cf3..984f8dd 100644 --- a/Lib/test/test_threading_local.py +++ b/Lib/test/test_threading_local.py @@ -5,8 +5,8 @@ import weakref import gc # Modules under test -_thread = support.import_module('_thread') -threading = support.import_module('threading') +import _thread +import threading import _threading_local diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py index 810ec37..42323b9 100644 --- a/Lib/test/test_time.py +++ b/Lib/test/test_time.py @@ -9,10 +9,6 @@ import sysconfig import time import unittest try: - import threading -except ImportError: - threading = None -try: import _testcapi except ImportError: _testcapi = None diff --git a/Lib/test/test_tracemalloc.py b/Lib/test/test_tracemalloc.py index 742259b..780942a 100644 --- a/Lib/test/test_tracemalloc.py +++ b/Lib/test/test_tracemalloc.py @@ -7,10 +7,7 @@ from unittest.mock import patch from test.support.script_helper import (assert_python_ok, assert_python_failure, interpreter_requires_environment) from test import support -try: - import threading -except ImportError: - threading = None + try: import _testcapi except ImportError: diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py index f83f9cc..741d136 100644 --- a/Lib/test/test_urllib2_localnet.py +++ b/Lib/test/test_urllib2_localnet.py @@ -4,13 +4,12 @@ import email import urllib.parse import urllib.request import http.server +import threading import unittest import hashlib from test import support -threading = support.import_module('threading') - try: import ssl except ImportError: @@ -276,7 +275,6 @@ class FakeProxyHandler(http.server.BaseHTTPRequestHandler): # Test cases -@unittest.skipUnless(threading, "Threading required for this test.") class BasicAuthTests(unittest.TestCase): USER = "testUser" PASSWD = "testPass" @@ -317,7 +315,6 @@ class BasicAuthTests(unittest.TestCase): self.assertRaises(urllib.error.HTTPError, urllib.request.urlopen, self.server_url) -@unittest.skipUnless(threading, "Threading required for this test.") class ProxyAuthTests(unittest.TestCase): URL = "http://localhost" @@ -439,7 +436,6 @@ def GetRequestHandler(responses): return FakeHTTPRequestHandler -@unittest.skipUnless(threading, "Threading required for this test.") class TestUrlopen(unittest.TestCase): """Tests urllib.request.urlopen using the network. diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py index 2691632..aac010b 100644 --- a/Lib/test/test_venv.py +++ b/Lib/test/test_venv.py @@ -15,15 +15,10 @@ import sys import tempfile from test.support import (captured_stdout, captured_stderr, can_symlink, EnvironmentVarGuard, rmtree) +import threading import unittest import venv - -try: - import threading -except ImportError: - threading = None - try: import ctypes except ImportError: @@ -420,8 +415,6 @@ class EnsurePipTest(BaseTest): if not system_site_packages: self.assert_pip_not_installed() - @unittest.skipUnless(threading, 'some dependencies of pip import threading' - ' module unconditionally') # Issue #26610: pip/pep425tags.py requires ctypes @unittest.skipUnless(ctypes, 'pip requires ctypes') def test_with_pip(self): diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index 43cf2c0..0384a9f 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -6,6 +6,7 @@ import weakref import operator import contextlib import copy +import threading import time from test import support @@ -78,7 +79,6 @@ def collect_in_thread(period=0.0001): """ Ensure GC collections happen in a different thread, at a high frequency. """ - threading = support.import_module('threading') please_stop = False def collect(): diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py index a609eef..f2b496a 100644 --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -10,6 +10,7 @@ import xmlrpc.server import http.client import http, http.server import socket +import threading import re import io import contextlib @@ -19,10 +20,6 @@ try: import gzip except ImportError: gzip = None -try: - import threading -except ImportError: - threading = None alist = [{'astring': 'foo@bar.baz.spam', 'afloat': 7283.43, @@ -307,7 +304,6 @@ class XMLRPCTestCase(unittest.TestCase): except OSError: self.assertTrue(has_ssl) - @unittest.skipUnless(threading, "Threading required for this test.") def test_keepalive_disconnect(self): class RequestHandler(http.server.BaseHTTPRequestHandler): protocol_version = "HTTP/1.1" @@ -747,7 +743,6 @@ def make_request_and_skipIf(condition, reason): return make_request_and_skip return decorator -@unittest.skipUnless(threading, 'Threading required for this test.') class BaseServerTestCase(unittest.TestCase): requestHandler = None request_count = 1 @@ -1206,7 +1201,6 @@ class FailingMessageClass(http.client.HTTPMessage): return super().get(key, failobj) -@unittest.skipUnless(threading, 'Threading required for this test.') class FailingServerTestCase(unittest.TestCase): def setUp(self): self.evt = threading.Event() diff --git a/Lib/threading.py b/Lib/threading.py index 06dbc68..e4bf974 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -990,38 +990,12 @@ class Thread: def _delete(self): "Remove current thread from the dict of currently running threads." - - # Notes about running with _dummy_thread: - # - # Must take care to not raise an exception if _dummy_thread is being - # used (and thus this module is being used as an instance of - # dummy_threading). _dummy_thread.get_ident() always returns 1 since - # there is only one thread if _dummy_thread is being used. Thus - # len(_active) is always <= 1 here, and any Thread instance created - # overwrites the (if any) thread currently registered in _active. - # - # An instance of _MainThread is always created by 'threading'. This - # gets overwritten the instant an instance of Thread is created; both - # threads return 1 from _dummy_thread.get_ident() and thus have the - # same key in the dict. So when the _MainThread instance created by - # 'threading' tries to clean itself up when atexit calls this method - # it gets a KeyError if another Thread instance was created. - # - # This all means that KeyError from trying to delete something from - # _active if dummy_threading is being used is a red herring. But - # since it isn't if dummy_threading is *not* being used then don't - # hide the exception. - - try: - with _active_limbo_lock: - del _active[get_ident()] - # There must not be any python code between the previous line - # and after the lock is released. Otherwise a tracing function - # could try to acquire the lock again in the same thread, (in - # current_thread()), and would block. - except KeyError: - if 'dummy_threading' not in _sys.modules: - raise + with _active_limbo_lock: + del _active[get_ident()] + # There must not be any python code between the previous line + # and after the lock is released. Otherwise a tracing function + # could try to acquire the lock again in the same thread, (in + # current_thread()), and would block. def join(self, timeout=None): """Wait until the thread terminates. diff --git a/Lib/trace.py b/Lib/trace.py index e443edd..48a1d1b 100755 --- a/Lib/trace.py +++ b/Lib/trace.py @@ -61,21 +61,15 @@ import dis import pickle from time import monotonic as _time -try: - import threading -except ImportError: - _settrace = sys.settrace - - def _unsettrace(): - sys.settrace(None) -else: - def _settrace(func): - threading.settrace(func) - sys.settrace(func) - - def _unsettrace(): - sys.settrace(None) - threading.settrace(None) +import threading + +def _settrace(func): + threading.settrace(func) + sys.settrace(func) + +def _unsettrace(): + sys.settrace(None) + threading.settrace(None) PRAGMA_NOCOVER = "#pragma NO COVER" diff --git a/Lib/zipfile.py b/Lib/zipfile.py index cc46a6c..37ce328 100644 --- a/Lib/zipfile.py +++ b/Lib/zipfile.py @@ -12,11 +12,7 @@ import stat import shutil import struct import binascii - -try: - import threading -except ImportError: - import dummy_threading as threading +import threading try: import zlib # We may need its compression method |