summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_dummy_thread.py163
-rw-r--r--Lib/_pydecimal.py87
-rw-r--r--Lib/_pyio.py5
-rw-r--r--Lib/_strptime.py5
-rw-r--r--Lib/bz2.py6
-rw-r--r--Lib/ctypes/test/test_errno.py37
-rw-r--r--Lib/dummy_threading.py78
-rw-r--r--Lib/functools.py8
-rw-r--r--Lib/http/cookiejar.py5
-rw-r--r--Lib/logging/__init__.py18
-rw-r--r--Lib/logging/config.py9
-rw-r--r--Lib/logging/handlers.py215
-rw-r--r--Lib/queue.py5
-rw-r--r--Lib/reprlib.py5
-rw-r--r--Lib/sched.py5
-rw-r--r--Lib/socketserver.py5
-rw-r--r--Lib/sqlite3/test/dbapi.py6
-rw-r--r--Lib/subprocess.py5
-rw-r--r--Lib/tempfile.py6
-rw-r--r--Lib/test/fork_wait.py2
-rw-r--r--Lib/test/libregrtest/runtest_mp.py6
-rw-r--r--Lib/test/libregrtest/save_env.py9
-rw-r--r--Lib/test/support/__init__.py17
-rw-r--r--Lib/test/test_asynchat.py162
-rw-r--r--Lib/test/test_asyncio/__init__.py2
-rw-r--r--Lib/test/test_asyncore.py7
-rw-r--r--Lib/test/test_bz2.py13
-rw-r--r--Lib/test/test_capi.py10
-rw-r--r--Lib/test/test_concurrent_futures.py4
-rw-r--r--Lib/test/test_contextlib.py6
-rw-r--r--Lib/test/test_decimal.py9
-rw-r--r--Lib/test/test_docxmlrpc.py2
-rw-r--r--Lib/test/test_dummy_thread.py255
-rw-r--r--Lib/test/test_dummy_threading.py60
-rw-r--r--Lib/test/test_email/test_email.py5
-rw-r--r--Lib/test/test_enum.py7
-rw-r--r--Lib/test/test_faulthandler.py10
-rw-r--r--Lib/test/test_fork1.py2
-rw-r--r--Lib/test/test_ftplib.py2
-rw-r--r--Lib/test/test_functools.py8
-rw-r--r--Lib/test/test_gc.py7
-rw-r--r--Lib/test/test_gdb.py12
-rw-r--r--Lib/test/test_hashlib.py6
-rw-r--r--Lib/test/test_httpservers.py3
-rw-r--r--Lib/test/test_idle.py1
-rw-r--r--Lib/test/test_imaplib.py5
-rw-r--r--Lib/test/test_imp.py5
-rw-r--r--Lib/test/test_importlib/test_locks.py222
-rw-r--r--Lib/test/test_io.py13
-rw-r--r--Lib/test/test_logging.py534
-rw-r--r--Lib/test/test_nntplib.py9
-rw-r--r--Lib/test/test_os.py154
-rw-r--r--Lib/test/test_pdb.py3
-rw-r--r--Lib/test/test_poll.py6
-rw-r--r--Lib/test/test_poplib.py2
-rw-r--r--Lib/test/test_pydoc.py6
-rw-r--r--Lib/test/test_queue.py3
-rw-r--r--Lib/test/test_regrtest.py8
-rw-r--r--Lib/test/test_robotparser.py6
-rw-r--r--Lib/test/test_sched.py8
-rw-r--r--Lib/test/test_signal.py7
-rw-r--r--Lib/test/test_smtplib.py11
-rw-r--r--Lib/test/test_socket.py69
-rw-r--r--Lib/test/test_socketserver.py10
-rw-r--r--Lib/test/test_ssl.py3222
-rw-r--r--Lib/test/test_subprocess.py8
-rw-r--r--Lib/test/test_sys.py32
-rw-r--r--Lib/test/test_telnetlib.py2
-rw-r--r--Lib/test/test_threaded_import.py2
-rw-r--r--Lib/test/test_threadedtempfile.py2
-rw-r--r--Lib/test/test_threading.py4
-rw-r--r--Lib/test/test_threading_local.py4
-rw-r--r--Lib/test/test_time.py4
-rw-r--r--Lib/test/test_tracemalloc.py5
-rw-r--r--Lib/test/test_urllib2_localnet.py6
-rw-r--r--Lib/test/test_venv.py9
-rw-r--r--Lib/test/test_weakref.py2
-rw-r--r--Lib/test/test_xmlrpc.py8
-rw-r--r--Lib/threading.py38
-rwxr-xr-xLib/trace.py24
-rw-r--r--Lib/zipfile.py6
81 files changed, 2372 insertions, 3387 deletions
diff --git a/Lib/_dummy_thread.py b/Lib/_dummy_thread.py
deleted file mode 100644
index a2cae54..0000000
--- a/Lib/_dummy_thread.py
+++ /dev/null
@@ -1,163 +0,0 @@
-"""Drop-in replacement for the thread module.
-
-Meant to be used as a brain-dead substitute so that threaded code does
-not need to be rewritten for when the thread module is not present.
-
-Suggested usage is::
-
- try:
- import _thread
- except ImportError:
- import _dummy_thread as _thread
-
-"""
-# Exports only things specified by thread documentation;
-# skipping obsolete synonyms allocate(), start_new(), exit_thread().
-__all__ = ['error', 'start_new_thread', 'exit', 'get_ident', 'allocate_lock',
- 'interrupt_main', 'LockType']
-
-# A dummy value
-TIMEOUT_MAX = 2**31
-
-# NOTE: this module can be imported early in the extension building process,
-# and so top level imports of other modules should be avoided. Instead, all
-# imports are done when needed on a function-by-function basis. Since threads
-# are disabled, the import lock should not be an issue anyway (??).
-
-error = RuntimeError
-
-def start_new_thread(function, args, kwargs={}):
- """Dummy implementation of _thread.start_new_thread().
-
- Compatibility is maintained by making sure that ``args`` is a
- tuple and ``kwargs`` is a dictionary. If an exception is raised
- and it is SystemExit (which can be done by _thread.exit()) it is
- caught and nothing is done; all other exceptions are printed out
- by using traceback.print_exc().
-
- If the executed function calls interrupt_main the KeyboardInterrupt will be
- raised when the function returns.
-
- """
- if type(args) != type(tuple()):
- raise TypeError("2nd arg must be a tuple")
- if type(kwargs) != type(dict()):
- raise TypeError("3rd arg must be a dict")
- global _main
- _main = False
- try:
- function(*args, **kwargs)
- except SystemExit:
- pass
- except:
- import traceback
- traceback.print_exc()
- _main = True
- global _interrupt
- if _interrupt:
- _interrupt = False
- raise KeyboardInterrupt
-
-def exit():
- """Dummy implementation of _thread.exit()."""
- raise SystemExit
-
-def get_ident():
- """Dummy implementation of _thread.get_ident().
-
- Since this module should only be used when _threadmodule is not
- available, it is safe to assume that the current process is the
- only thread. Thus a constant can be safely returned.
- """
- return 1
-
-def allocate_lock():
- """Dummy implementation of _thread.allocate_lock()."""
- return LockType()
-
-def stack_size(size=None):
- """Dummy implementation of _thread.stack_size()."""
- if size is not None:
- raise error("setting thread stack size not supported")
- return 0
-
-def _set_sentinel():
- """Dummy implementation of _thread._set_sentinel()."""
- return LockType()
-
-class LockType(object):
- """Class implementing dummy implementation of _thread.LockType.
-
- Compatibility is maintained by maintaining self.locked_status
- which is a boolean that stores the state of the lock. Pickling of
- the lock, though, should not be done since if the _thread module is
- then used with an unpickled ``lock()`` from here problems could
- occur from this class not having atomic methods.
-
- """
-
- def __init__(self):
- self.locked_status = False
-
- def acquire(self, waitflag=None, timeout=-1):
- """Dummy implementation of acquire().
-
- For blocking calls, self.locked_status is automatically set to
- True and returned appropriately based on value of
- ``waitflag``. If it is non-blocking, then the value is
- actually checked and not set if it is already acquired. This
- is all done so that threading.Condition's assert statements
- aren't triggered and throw a little fit.
-
- """
- if waitflag is None or waitflag:
- self.locked_status = True
- return True
- else:
- if not self.locked_status:
- self.locked_status = True
- return True
- else:
- if timeout > 0:
- import time
- time.sleep(timeout)
- return False
-
- __enter__ = acquire
-
- def __exit__(self, typ, val, tb):
- self.release()
-
- def release(self):
- """Release the dummy lock."""
- # XXX Perhaps shouldn't actually bother to test? Could lead
- # to problems for complex, threaded code.
- if not self.locked_status:
- raise error
- self.locked_status = False
- return True
-
- def locked(self):
- return self.locked_status
-
- def __repr__(self):
- return "<%s %s.%s object at %s>" % (
- "locked" if self.locked_status else "unlocked",
- self.__class__.__module__,
- self.__class__.__qualname__,
- hex(id(self))
- )
-
-# Used to signal that interrupt_main was called in a "thread"
-_interrupt = False
-# True when not executing in a "thread"
-_main = True
-
-def interrupt_main():
- """Set _interrupt flag to True to have start_new_thread raise
- KeyboardInterrupt upon exiting."""
- if _main:
- raise KeyboardInterrupt
- else:
- global _interrupt
- _interrupt = True
diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py
index edabf72..a43c75f 100644
--- a/Lib/_pydecimal.py
+++ b/Lib/_pydecimal.py
@@ -436,75 +436,34 @@ _rounding_modes = (ROUND_DOWN, ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_CEILING,
# work for older Pythons. If threads are not part of the build, create a
# mock threading object with threading.local() returning the module namespace.
-try:
- import threading
-except ImportError:
- # Python was compiled without threads; create a mock object instead
- class MockThreading(object):
- def local(self, sys=sys):
- return sys.modules[__xname__]
- threading = MockThreading()
- del MockThreading
-
-try:
- threading.local
-
-except AttributeError:
-
- # To fix reloading, force it to create a new context
- # Old contexts have different exceptions in their dicts, making problems.
- if hasattr(threading.current_thread(), '__decimal_context__'):
- del threading.current_thread().__decimal_context__
-
- def setcontext(context):
- """Set this thread's context to context."""
- if context in (DefaultContext, BasicContext, ExtendedContext):
- context = context.copy()
- context.clear_flags()
- threading.current_thread().__decimal_context__ = context
-
- def getcontext():
- """Returns this thread's context.
-
- If this thread does not yet have a context, returns
- a new context and sets this thread's context.
- New contexts are copies of DefaultContext.
- """
- try:
- return threading.current_thread().__decimal_context__
- except AttributeError:
- context = Context()
- threading.current_thread().__decimal_context__ = context
- return context
+import threading
-else:
+local = threading.local()
+if hasattr(local, '__decimal_context__'):
+ del local.__decimal_context__
- local = threading.local()
- if hasattr(local, '__decimal_context__'):
- del local.__decimal_context__
+def getcontext(_local=local):
+ """Returns this thread's context.
- def getcontext(_local=local):
- """Returns this thread's context.
-
- If this thread does not yet have a context, returns
- a new context and sets this thread's context.
- New contexts are copies of DefaultContext.
- """
- try:
- return _local.__decimal_context__
- except AttributeError:
- context = Context()
- _local.__decimal_context__ = context
- return context
-
- def setcontext(context, _local=local):
- """Set this thread's context to context."""
- if context in (DefaultContext, BasicContext, ExtendedContext):
- context = context.copy()
- context.clear_flags()
+ If this thread does not yet have a context, returns
+ a new context and sets this thread's context.
+ New contexts are copies of DefaultContext.
+ """
+ try:
+ return _local.__decimal_context__
+ except AttributeError:
+ context = Context()
_local.__decimal_context__ = context
+ return context
+
+def setcontext(context, _local=local):
+ """Set this thread's context to context."""
+ if context in (DefaultContext, BasicContext, ExtendedContext):
+ context = context.copy()
+ context.clear_flags()
+ _local.__decimal_context__ = context
- del threading, local # Don't contaminate the namespace
+del threading, local # Don't contaminate the namespace
def localcontext(ctx=None):
"""Return a context manager for a copy of the supplied context
diff --git a/Lib/_pyio.py b/Lib/_pyio.py
index 4653847..1e105f2 100644
--- a/Lib/_pyio.py
+++ b/Lib/_pyio.py
@@ -9,10 +9,7 @@ import errno
import stat
import sys
# Import _thread instead of threading to reduce startup cost
-try:
- from _thread import allocate_lock as Lock
-except ImportError:
- from _dummy_thread import allocate_lock as Lock
+from _thread import allocate_lock as Lock
if sys.platform in {'win32', 'cygwin'}:
from msvcrt import setmode as _setmode
else:
diff --git a/Lib/_strptime.py b/Lib/_strptime.py
index fe94361..284175d 100644
--- a/Lib/_strptime.py
+++ b/Lib/_strptime.py
@@ -19,10 +19,7 @@ from re import escape as re_escape
from datetime import (date as datetime_date,
timedelta as datetime_timedelta,
timezone as datetime_timezone)
-try:
- from _thread import allocate_lock as _thread_allocate_lock
-except ImportError:
- from _dummy_thread import allocate_lock as _thread_allocate_lock
+from _thread import allocate_lock as _thread_allocate_lock
__all__ = []
diff --git a/Lib/bz2.py b/Lib/bz2.py
index 6f56328..3924aae 100644
--- a/Lib/bz2.py
+++ b/Lib/bz2.py
@@ -14,11 +14,7 @@ import io
import os
import warnings
import _compression
-
-try:
- from threading import RLock
-except ImportError:
- from dummy_threading import RLock
+from threading import RLock
from _bz2 import BZ2Compressor, BZ2Decompressor
diff --git a/Lib/ctypes/test/test_errno.py b/Lib/ctypes/test/test_errno.py
index 4690a0d..3685164 100644
--- a/Lib/ctypes/test/test_errno.py
+++ b/Lib/ctypes/test/test_errno.py
@@ -1,10 +1,8 @@
import unittest, os, errno
+import threading
+
from ctypes import *
from ctypes.util import find_library
-try:
- import threading
-except ImportError:
- threading = None
class Test(unittest.TestCase):
def test_open(self):
@@ -25,25 +23,24 @@ class Test(unittest.TestCase):
self.assertEqual(set_errno(32), errno.ENOENT)
self.assertEqual(get_errno(), 32)
- if threading:
- def _worker():
- set_errno(0)
+ def _worker():
+ set_errno(0)
- libc = CDLL(libc_name, use_errno=False)
- if os.name == "nt":
- libc_open = libc._open
- else:
- libc_open = libc.open
- libc_open.argtypes = c_char_p, c_int
- self.assertEqual(libc_open(b"", 0), -1)
- self.assertEqual(get_errno(), 0)
+ libc = CDLL(libc_name, use_errno=False)
+ if os.name == "nt":
+ libc_open = libc._open
+ else:
+ libc_open = libc.open
+ libc_open.argtypes = c_char_p, c_int
+ self.assertEqual(libc_open(b"", 0), -1)
+ self.assertEqual(get_errno(), 0)
- t = threading.Thread(target=_worker)
- t.start()
- t.join()
+ t = threading.Thread(target=_worker)
+ t.start()
+ t.join()
- self.assertEqual(get_errno(), 32)
- set_errno(0)
+ self.assertEqual(get_errno(), 32)
+ set_errno(0)
@unittest.skipUnless(os.name == "nt", 'Test specific to Windows')
def test_GetLastError(self):
diff --git a/Lib/dummy_threading.py b/Lib/dummy_threading.py
deleted file mode 100644
index 1bb7eee..0000000
--- a/Lib/dummy_threading.py
+++ /dev/null
@@ -1,78 +0,0 @@
-"""Faux ``threading`` version using ``dummy_thread`` instead of ``thread``.
-
-The module ``_dummy_threading`` is added to ``sys.modules`` in order
-to not have ``threading`` considered imported. Had ``threading`` been
-directly imported it would have made all subsequent imports succeed
-regardless of whether ``_thread`` was available which is not desired.
-
-"""
-from sys import modules as sys_modules
-
-import _dummy_thread
-
-# Declaring now so as to not have to nest ``try``s to get proper clean-up.
-holding_thread = False
-holding_threading = False
-holding__threading_local = False
-
-try:
- # Could have checked if ``_thread`` was not in sys.modules and gone
- # a different route, but decided to mirror technique used with
- # ``threading`` below.
- if '_thread' in sys_modules:
- held_thread = sys_modules['_thread']
- holding_thread = True
- # Must have some module named ``_thread`` that implements its API
- # in order to initially import ``threading``.
- sys_modules['_thread'] = sys_modules['_dummy_thread']
-
- if 'threading' in sys_modules:
- # If ``threading`` is already imported, might as well prevent
- # trying to import it more than needed by saving it if it is
- # already imported before deleting it.
- held_threading = sys_modules['threading']
- holding_threading = True
- del sys_modules['threading']
-
- if '_threading_local' in sys_modules:
- # If ``_threading_local`` is already imported, might as well prevent
- # trying to import it more than needed by saving it if it is
- # already imported before deleting it.
- held__threading_local = sys_modules['_threading_local']
- holding__threading_local = True
- del sys_modules['_threading_local']
-
- import threading
- # Need a copy of the code kept somewhere...
- sys_modules['_dummy_threading'] = sys_modules['threading']
- del sys_modules['threading']
- sys_modules['_dummy__threading_local'] = sys_modules['_threading_local']
- del sys_modules['_threading_local']
- from _dummy_threading import *
- from _dummy_threading import __all__
-
-finally:
- # Put back ``threading`` if we overwrote earlier
-
- if holding_threading:
- sys_modules['threading'] = held_threading
- del held_threading
- del holding_threading
-
- # Put back ``_threading_local`` if we overwrote earlier
-
- if holding__threading_local:
- sys_modules['_threading_local'] = held__threading_local
- del held__threading_local
- del holding__threading_local
-
- # Put back ``thread`` if we overwrote, else del the entry we made
- if holding_thread:
- sys_modules['_thread'] = held_thread
- del held_thread
- else:
- del sys_modules['_thread']
- del holding_thread
-
- del _dummy_thread
- del sys_modules
diff --git a/Lib/functools.py b/Lib/functools.py
index 23ad160..25075de 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -22,13 +22,7 @@ from collections import namedtuple
from types import MappingProxyType
from weakref import WeakKeyDictionary
from reprlib import recursive_repr
-try:
- from _thread import RLock
-except ImportError:
- class RLock:
- 'Dummy reentrant lock for builds without threads'
- def __enter__(self): pass
- def __exit__(self, exctype, excinst, exctb): pass
+from _thread import RLock
################################################################################
diff --git a/Lib/http/cookiejar.py b/Lib/http/cookiejar.py
index adf956d..e0f1032 100644
--- a/Lib/http/cookiejar.py
+++ b/Lib/http/cookiejar.py
@@ -33,10 +33,7 @@ import datetime
import re
import time
import urllib.parse, urllib.request
-try:
- import threading as _threading
-except ImportError:
- import dummy_threading as _threading
+import threading as _threading
import http.client # only for the default HTTP port
from calendar import timegm
diff --git a/Lib/logging/__init__.py b/Lib/logging/__init__.py
index 83db827..19b96b8 100644
--- a/Lib/logging/__init__.py
+++ b/Lib/logging/__init__.py
@@ -37,10 +37,7 @@ __all__ = ['BASIC_FORMAT', 'BufferingFormatter', 'CRITICAL', 'DEBUG', 'ERROR',
'warn', 'warning', 'getLogRecordFactory', 'setLogRecordFactory',
'lastResort', 'raiseExceptions']
-try:
- import threading
-except ImportError: #pragma: no cover
- threading = None
+import threading
__author__ = "Vinay Sajip <vinay_sajip@red-dove.com>"
__status__ = "production"
@@ -210,11 +207,7 @@ def _checkLevel(level):
#the lock would already have been acquired - so we need an RLock.
#The same argument applies to Loggers and Manager.loggerDict.
#
-if threading:
- _lock = threading.RLock()
-else: #pragma: no cover
- _lock = None
-
+_lock = threading.RLock()
def _acquireLock():
"""
@@ -295,7 +288,7 @@ class LogRecord(object):
self.created = ct
self.msecs = (ct - int(ct)) * 1000
self.relativeCreated = (self.created - _startTime) * 1000
- if logThreads and threading:
+ if logThreads:
self.thread = threading.get_ident()
self.threadName = threading.current_thread().name
else: # pragma: no cover
@@ -799,10 +792,7 @@ class Handler(Filterer):
"""
Acquire a thread lock for serializing access to the underlying I/O.
"""
- if threading:
- self.lock = threading.RLock()
- else: #pragma: no cover
- self.lock = None
+ self.lock = threading.RLock()
def acquire(self):
"""
diff --git a/Lib/logging/config.py b/Lib/logging/config.py
index b3f4e28..c16a75a 100644
--- a/Lib/logging/config.py
+++ b/Lib/logging/config.py
@@ -31,14 +31,9 @@ import logging.handlers
import re
import struct
import sys
+import threading
import traceback
-try:
- import _thread as thread
- import threading
-except ImportError: #pragma: no cover
- thread = None
-
from socketserver import ThreadingTCPServer, StreamRequestHandler
@@ -816,8 +811,6 @@ def listen(port=DEFAULT_LOGGING_CONFIG_PORT, verify=None):
normal. Note that you can return transformed bytes, e.g. by decrypting
the bytes passed in.
"""
- if not thread: #pragma: no cover
- raise NotImplementedError("listen() needs threading to work")
class ConfigStreamHandler(StreamRequestHandler):
"""
diff --git a/Lib/logging/handlers.py b/Lib/logging/handlers.py
index b5fdfbc..a815f03 100644
--- a/Lib/logging/handlers.py
+++ b/Lib/logging/handlers.py
@@ -26,10 +26,7 @@ To use, simply 'import logging.handlers' and log away!
import logging, socket, os, pickle, struct, time, re
from stat import ST_DEV, ST_INO, ST_MTIME
import queue
-try:
- import threading
-except ImportError: #pragma: no cover
- threading = None
+import threading
#
# Some constants...
@@ -1395,110 +1392,110 @@ class QueueHandler(logging.Handler):
except Exception:
self.handleError(record)
-if threading:
- class QueueListener(object):
- """
- This class implements an internal threaded listener which watches for
- LogRecords being added to a queue, removes them and passes them to a
- list of handlers for processing.
- """
- _sentinel = None
-
- def __init__(self, queue, *handlers, respect_handler_level=False):
- """
- Initialise an instance with the specified queue and
- handlers.
- """
- self.queue = queue
- self.handlers = handlers
- self._thread = None
- self.respect_handler_level = respect_handler_level
-
- def dequeue(self, block):
- """
- Dequeue a record and return it, optionally blocking.
-
- The base implementation uses get. You may want to override this method
- if you want to use timeouts or work with custom queue implementations.
- """
- return self.queue.get(block)
-
- def start(self):
- """
- Start the listener.
-
- This starts up a background thread to monitor the queue for
- LogRecords to process.
- """
- self._thread = t = threading.Thread(target=self._monitor)
- t.daemon = True
- t.start()
-
- def prepare(self , record):
- """
- Prepare a record for handling.
-
- This method just returns the passed-in record. You may want to
- override this method if you need to do any custom marshalling or
- manipulation of the record before passing it to the handlers.
- """
- return record
-
- def handle(self, record):
- """
- Handle a record.
-
- This just loops through the handlers offering them the record
- to handle.
- """
- record = self.prepare(record)
- for handler in self.handlers:
- if not self.respect_handler_level:
- process = True
- else:
- process = record.levelno >= handler.level
- if process:
- handler.handle(record)
-
- def _monitor(self):
- """
- Monitor the queue for records, and ask the handler
- to deal with them.
-
- This method runs on a separate, internal thread.
- The thread will terminate if it sees a sentinel object in the queue.
- """
- q = self.queue
- has_task_done = hasattr(q, 'task_done')
- while True:
- try:
- record = self.dequeue(True)
- if record is self._sentinel:
- break
- self.handle(record)
- if has_task_done:
- q.task_done()
- except queue.Empty:
+
+class QueueListener(object):
+ """
+ This class implements an internal threaded listener which watches for
+ LogRecords being added to a queue, removes them and passes them to a
+ list of handlers for processing.
+ """
+ _sentinel = None
+
+ def __init__(self, queue, *handlers, respect_handler_level=False):
+ """
+ Initialise an instance with the specified queue and
+ handlers.
+ """
+ self.queue = queue
+ self.handlers = handlers
+ self._thread = None
+ self.respect_handler_level = respect_handler_level
+
+ def dequeue(self, block):
+ """
+ Dequeue a record and return it, optionally blocking.
+
+ The base implementation uses get. You may want to override this method
+ if you want to use timeouts or work with custom queue implementations.
+ """
+ return self.queue.get(block)
+
+ def start(self):
+ """
+ Start the listener.
+
+ This starts up a background thread to monitor the queue for
+ LogRecords to process.
+ """
+ self._thread = t = threading.Thread(target=self._monitor)
+ t.daemon = True
+ t.start()
+
+ def prepare(self , record):
+ """
+ Prepare a record for handling.
+
+ This method just returns the passed-in record. You may want to
+ override this method if you need to do any custom marshalling or
+ manipulation of the record before passing it to the handlers.
+ """
+ return record
+
+ def handle(self, record):
+ """
+ Handle a record.
+
+ This just loops through the handlers offering them the record
+ to handle.
+ """
+ record = self.prepare(record)
+ for handler in self.handlers:
+ if not self.respect_handler_level:
+ process = True
+ else:
+ process = record.levelno >= handler.level
+ if process:
+ handler.handle(record)
+
+ def _monitor(self):
+ """
+ Monitor the queue for records, and ask the handler
+ to deal with them.
+
+ This method runs on a separate, internal thread.
+ The thread will terminate if it sees a sentinel object in the queue.
+ """
+ q = self.queue
+ has_task_done = hasattr(q, 'task_done')
+ while True:
+ try:
+ record = self.dequeue(True)
+ if record is self._sentinel:
break
+ self.handle(record)
+ if has_task_done:
+ q.task_done()
+ except queue.Empty:
+ break
+
+ def enqueue_sentinel(self):
+ """
+ This is used to enqueue the sentinel record.
+
+ The base implementation uses put_nowait. You may want to override this
+ method if you want to use timeouts or work with custom queue
+ implementations.
+ """
+ self.queue.put_nowait(self._sentinel)
- def enqueue_sentinel(self):
- """
- This is used to enqueue the sentinel record.
-
- The base implementation uses put_nowait. You may want to override this
- method if you want to use timeouts or work with custom queue
- implementations.
- """
- self.queue.put_nowait(self._sentinel)
-
- def stop(self):
- """
- Stop the listener.
-
- This asks the thread to terminate, and then waits for it to do so.
- Note that if you don't call this before your application exits, there
- may be some records still left on the queue, which won't be processed.
- """
- self.enqueue_sentinel()
- self._thread.join()
- self._thread = None
+ def stop(self):
+ """
+ Stop the listener.
+
+ This asks the thread to terminate, and then waits for it to do so.
+ Note that if you don't call this before your application exits, there
+ may be some records still left on the queue, which won't be processed.
+ """
+ self.enqueue_sentinel()
+ self._thread.join()
+ self._thread = None
diff --git a/Lib/queue.py b/Lib/queue.py
index 572425e..c803b96 100644
--- a/Lib/queue.py
+++ b/Lib/queue.py
@@ -1,9 +1,6 @@
'''A multi-producer, multi-consumer queue.'''
-try:
- import threading
-except ImportError:
- import dummy_threading as threading
+import threading
from collections import deque
from heapq import heappush, heappop
from time import monotonic as time
diff --git a/Lib/reprlib.py b/Lib/reprlib.py
index 40d991f..616b343 100644
--- a/Lib/reprlib.py
+++ b/Lib/reprlib.py
@@ -4,10 +4,7 @@ __all__ = ["Repr", "repr", "recursive_repr"]
import builtins
from itertools import islice
-try:
- from _thread import get_ident
-except ImportError:
- from _dummy_thread import get_ident
+from _thread import get_ident
def recursive_repr(fillvalue='...'):
'Decorator to make a repr function return fillvalue for a recursive call'
diff --git a/Lib/sched.py b/Lib/sched.py
index 3d8c011..ff87874 100644
--- a/Lib/sched.py
+++ b/Lib/sched.py
@@ -26,10 +26,7 @@ has another way to reference private data (besides global variables).
import time
import heapq
from collections import namedtuple
-try:
- import threading
-except ImportError:
- import dummy_threading as threading
+import threading
from time import monotonic as _time
__all__ = ["scheduler"]
diff --git a/Lib/socketserver.py b/Lib/socketserver.py
index df17114..721eb50 100644
--- a/Lib/socketserver.py
+++ b/Lib/socketserver.py
@@ -127,10 +127,7 @@ import socket
import selectors
import os
import sys
-try:
- import threading
-except ImportError:
- import dummy_threading as threading
+import threading
from io import BufferedIOBase
from time import monotonic as time
diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py
index cb85814..12c9fb4 100644
--- a/Lib/sqlite3/test/dbapi.py
+++ b/Lib/sqlite3/test/dbapi.py
@@ -21,12 +21,9 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
+import threading
import unittest
import sqlite3 as sqlite
-try:
- import threading
-except ImportError:
- threading = None
from test.support import TESTFN, unlink
@@ -503,7 +500,6 @@ class CursorTests(unittest.TestCase):
self.assertEqual(results, expected)
-@unittest.skipUnless(threading, 'This test requires threading.')
class ThreadTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
diff --git a/Lib/subprocess.py b/Lib/subprocess.py
index 25e5698..f61ff0c 100644
--- a/Lib/subprocess.py
+++ b/Lib/subprocess.py
@@ -138,10 +138,7 @@ else:
import _posixsubprocess
import select
import selectors
- try:
- import threading
- except ImportError:
- import dummy_threading as threading
+ import threading
# When select or poll has indicated that the file is writable,
# we can write up to _PIPE_BUF bytes without risk of blocking.
diff --git a/Lib/tempfile.py b/Lib/tempfile.py
index 6146235..71ecafa 100644
--- a/Lib/tempfile.py
+++ b/Lib/tempfile.py
@@ -44,11 +44,7 @@ import shutil as _shutil
import errno as _errno
from random import Random as _Random
import weakref as _weakref
-
-try:
- import _thread
-except ImportError:
- import _dummy_thread as _thread
+import _thread
_allocate_lock = _thread.allocate_lock
_text_openflags = _os.O_RDWR | _os.O_CREAT | _os.O_EXCL
diff --git a/Lib/test/fork_wait.py b/Lib/test/fork_wait.py
index 6af79ad..9850b06 100644
--- a/Lib/test/fork_wait.py
+++ b/Lib/test/fork_wait.py
@@ -10,9 +10,9 @@ active threads survive in the child after a fork(); this is an error.
"""
import os, sys, time, unittest
+import threading
import test.support as support
-threading = support.import_module('threading')
LONGSLEEP = 2
SHORTSLEEP = 0.5
diff --git a/Lib/test/libregrtest/runtest_mp.py b/Lib/test/libregrtest/runtest_mp.py
index 779ff01..31b830d 100644
--- a/Lib/test/libregrtest/runtest_mp.py
+++ b/Lib/test/libregrtest/runtest_mp.py
@@ -3,15 +3,11 @@ import json
import os
import queue
import sys
+import threading
import time
import traceback
import types
from test import support
-try:
- import threading
-except ImportError:
- print("Multiprocess option requires thread support")
- sys.exit(2)
from test.libregrtest.runtest import (
runtest, INTERRUPTED, CHILD_ERROR, PROGRESS_MIN_TIME,
diff --git a/Lib/test/libregrtest/save_env.py b/Lib/test/libregrtest/save_env.py
index 3c45621..45b365d 100644
--- a/Lib/test/libregrtest/save_env.py
+++ b/Lib/test/libregrtest/save_env.py
@@ -5,13 +5,10 @@ import os
import shutil
import sys
import sysconfig
+import threading
import warnings
from test import support
try:
- import threading
-except ImportError:
- threading = None
-try:
import _multiprocessing, multiprocessing.process
except ImportError:
multiprocessing = None
@@ -181,13 +178,9 @@ class saved_test_environment:
# Controlling dangling references to Thread objects can make it easier
# to track reference leaks.
def get_threading__dangling(self):
- if not threading:
- return None
# This copies the weakrefs without making any strong reference
return threading._dangling.copy()
def restore_threading__dangling(self, saved):
- if not threading:
- return
threading._dangling.clear()
threading._dangling.update(saved)
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index 0235498..bfceba1 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -25,6 +25,8 @@ import subprocess
import sys
import sysconfig
import tempfile
+import _thread
+import threading
import time
import types
import unittest
@@ -32,11 +34,6 @@ import urllib.error
import warnings
try:
- import _thread, threading
-except ImportError:
- _thread = None
- threading = None
-try:
import multiprocessing.process
except ImportError:
multiprocessing = None
@@ -2028,16 +2025,11 @@ environment_altered = False
# at the end of a test run.
def threading_setup():
- if _thread:
- return _thread._count(), threading._dangling.copy()
- else:
- return 1, ()
+ return _thread._count(), threading._dangling.copy()
def threading_cleanup(*original_values):
global environment_altered
- if not _thread:
- return
_MAX_COUNT = 100
t0 = time.monotonic()
for count in range(_MAX_COUNT):
@@ -2061,9 +2053,6 @@ def reap_threads(func):
ensure that the threads are cleaned up even when the test fails.
If threading is unavailable this function does nothing.
"""
- if not _thread:
- return func
-
@functools.wraps(func)
def decorator(*args):
key = threading_setup()
diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py
index 0eba76d..2362834 100644
--- a/Lib/test/test_asynchat.py
+++ b/Lib/test/test_asynchat.py
@@ -2,110 +2,104 @@
from test import support
-# If this fails, the test will be skipped.
-thread = support.import_module('_thread')
-
import asynchat
import asyncore
import errno
import socket
import sys
+import _thread as thread
+import threading
import time
import unittest
import unittest.mock
-try:
- import threading
-except ImportError:
- threading = None
HOST = support.HOST
SERVER_QUIT = b'QUIT\n'
TIMEOUT = 3.0
-if threading:
- class echo_server(threading.Thread):
- # parameter to determine the number of bytes passed back to the
- # client each send
- chunk_size = 1
-
- def __init__(self, event):
- threading.Thread.__init__(self)
- self.event = event
- self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.port = support.bind_port(self.sock)
- # This will be set if the client wants us to wait before echoing
- # data back.
- self.start_resend_event = None
-
- def run(self):
- self.sock.listen()
- self.event.set()
- conn, client = self.sock.accept()
- self.buffer = b""
- # collect data until quit message is seen
- while SERVER_QUIT not in self.buffer:
- data = conn.recv(1)
- if not data:
- break
- self.buffer = self.buffer + data
-
- # remove the SERVER_QUIT message
- self.buffer = self.buffer.replace(SERVER_QUIT, b'')
-
- if self.start_resend_event:
- self.start_resend_event.wait()
-
- # re-send entire set of collected data
- try:
- # this may fail on some tests, such as test_close_when_done,
- # since the client closes the channel when it's done sending
- while self.buffer:
- n = conn.send(self.buffer[:self.chunk_size])
- time.sleep(0.001)
- self.buffer = self.buffer[n:]
- except:
- pass
-
- conn.close()
- self.sock.close()
- class echo_client(asynchat.async_chat):
-
- def __init__(self, terminator, server_port):
- asynchat.async_chat.__init__(self)
- self.contents = []
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- self.connect((HOST, server_port))
- self.set_terminator(terminator)
- self.buffer = b""
-
- def handle_connect(self):
+class echo_server(threading.Thread):
+ # parameter to determine the number of bytes passed back to the
+ # client each send
+ chunk_size = 1
+
+ def __init__(self, event):
+ threading.Thread.__init__(self)
+ self.event = event
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.port = support.bind_port(self.sock)
+ # This will be set if the client wants us to wait before echoing
+ # data back.
+ self.start_resend_event = None
+
+ def run(self):
+ self.sock.listen()
+ self.event.set()
+ conn, client = self.sock.accept()
+ self.buffer = b""
+ # collect data until quit message is seen
+ while SERVER_QUIT not in self.buffer:
+ data = conn.recv(1)
+ if not data:
+ break
+ self.buffer = self.buffer + data
+
+ # remove the SERVER_QUIT message
+ self.buffer = self.buffer.replace(SERVER_QUIT, b'')
+
+ if self.start_resend_event:
+ self.start_resend_event.wait()
+
+ # re-send entire set of collected data
+ try:
+ # this may fail on some tests, such as test_close_when_done,
+ # since the client closes the channel when it's done sending
+ while self.buffer:
+ n = conn.send(self.buffer[:self.chunk_size])
+ time.sleep(0.001)
+ self.buffer = self.buffer[n:]
+ except:
+ pass
+
+ conn.close()
+ self.sock.close()
+
+class echo_client(asynchat.async_chat):
+
+ def __init__(self, terminator, server_port):
+ asynchat.async_chat.__init__(self)
+ self.contents = []
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.connect((HOST, server_port))
+ self.set_terminator(terminator)
+ self.buffer = b""
+
+ def handle_connect(self):
+ pass
+
+ if sys.platform == 'darwin':
+ # select.poll returns a select.POLLHUP at the end of the tests
+ # on darwin, so just ignore it
+ def handle_expt(self):
pass
- if sys.platform == 'darwin':
- # select.poll returns a select.POLLHUP at the end of the tests
- # on darwin, so just ignore it
- def handle_expt(self):
- pass
-
- def collect_incoming_data(self, data):
- self.buffer += data
+ def collect_incoming_data(self, data):
+ self.buffer += data
- def found_terminator(self):
- self.contents.append(self.buffer)
- self.buffer = b""
+ def found_terminator(self):
+ self.contents.append(self.buffer)
+ self.buffer = b""
- def start_echo_server():
- event = threading.Event()
- s = echo_server(event)
- s.start()
- event.wait()
- event.clear()
- time.sleep(0.01) # Give server time to start accepting.
- return s, event
+def start_echo_server():
+ event = threading.Event()
+ s = echo_server(event)
+ s.start()
+ event.wait()
+ event.clear()
+ time.sleep(0.01) # Give server time to start accepting.
+ return s, event
-@unittest.skipUnless(threading, 'Threading required for this test.')
class TestAsynchat(unittest.TestCase):
usepoll = False
diff --git a/Lib/test/test_asyncio/__init__.py b/Lib/test/test_asyncio/__init__.py
index 80a9eea..c77c7a8 100644
--- a/Lib/test/test_asyncio/__init__.py
+++ b/Lib/test/test_asyncio/__init__.py
@@ -1,8 +1,6 @@
import os
from test.support import load_package_tests, import_module
-# Skip tests if we don't have threading.
-import_module('threading')
# Skip tests if we don't have concurrent.futures.
import_module('concurrent.futures')
diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py
index 07edf22..c8e9727 100644
--- a/Lib/test/test_asyncore.py
+++ b/Lib/test/test_asyncore.py
@@ -7,6 +7,7 @@ import sys
import time
import errno
import struct
+import threading
from test import support
from io import BytesIO
@@ -14,10 +15,6 @@ from io import BytesIO
if support.PGO:
raise unittest.SkipTest("test is not helpful for PGO")
-try:
- import threading
-except ImportError:
- threading = None
TIMEOUT = 3
HAS_UNIX_SOCKETS = hasattr(socket, 'AF_UNIX')
@@ -326,7 +323,6 @@ class DispatcherWithSendTests(unittest.TestCase):
def tearDown(self):
asyncore.close_all()
- @unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads
def test_send(self):
evt = threading.Event()
@@ -776,7 +772,6 @@ class BaseTestAPI:
self.assertTrue(s.socket.getsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR))
- @unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads
def test_quick_connect(self):
# see: http://bugs.python.org/issue10340
diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py
index e76b283..58dbf96 100644
--- a/Lib/test/test_bz2.py
+++ b/Lib/test/test_bz2.py
@@ -10,13 +10,10 @@ import pathlib
import random
import shutil
import subprocess
+import threading
from test.support import unlink
import _compression
-try:
- import threading
-except ImportError:
- threading = None
# Skip tests if the bz2 module doesn't exist.
bz2 = support.import_module('bz2')
@@ -491,7 +488,6 @@ class BZ2FileTest(BaseTest):
else:
self.fail("1/0 didn't raise an exception")
- @unittest.skipUnless(threading, 'Threading required for this test.')
def testThreading(self):
# Issue #7205: Using a BZ2File from several threads shouldn't deadlock.
data = b"1" * 2**20
@@ -504,13 +500,6 @@ class BZ2FileTest(BaseTest):
with support.start_threads(threads):
pass
- def testWithoutThreading(self):
- module = support.import_fresh_module("bz2", blocked=("threading",))
- with module.BZ2File(self.filename, "wb") as f:
- f.write(b"abc")
- with module.BZ2File(self.filename, "rb") as f:
- self.assertEqual(f.read(), b"abc")
-
def testMixedIterationAndReads(self):
self.createTempFile()
linelen = len(self.TEXT_LINES[0])
diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py
index c3a04b4..1b826ee 100644
--- a/Lib/test/test_capi.py
+++ b/Lib/test/test_capi.py
@@ -11,6 +11,7 @@ import subprocess
import sys
import sysconfig
import textwrap
+import threading
import time
import unittest
from test import support
@@ -20,10 +21,7 @@ try:
import _posixsubprocess
except ImportError:
_posixsubprocess = None
-try:
- import threading
-except ImportError:
- threading = None
+
# Skip this test if the _testcapi module isn't available.
_testcapi = support.import_module('_testcapi')
@@ -52,7 +50,6 @@ class CAPITest(unittest.TestCase):
self.assertEqual(testfunction.attribute, "test")
self.assertRaises(AttributeError, setattr, inst.testfunction, "attribute", "test")
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_no_FatalError_infinite_loop(self):
with support.SuppressCrashReport():
p = subprocess.Popen([sys.executable, "-c",
@@ -276,7 +273,6 @@ class CAPITest(unittest.TestCase):
self.assertIn(b'MemoryError 3 30', out)
-@unittest.skipUnless(threading, 'Threading required for this test.')
class TestPendingCalls(unittest.TestCase):
def pendingcalls_submit(self, l, n):
@@ -685,7 +681,6 @@ class SkipitemTest(unittest.TestCase):
parse((1,), {}, 'O|OO', ['', 'a', ''])
-@unittest.skipUnless(threading, 'Threading required for this test.')
class TestThreadState(unittest.TestCase):
@support.reap_threads
@@ -762,7 +757,6 @@ class PyMemDebugTests(unittest.TestCase):
regex = regex.format(ptr=self.PTR_REGEX)
self.assertRegex(out, regex)
- @unittest.skipUnless(threading, 'Test requires a GIL (multithreading)')
def check_malloc_without_gil(self, code):
out = self.check(code)
expected = ('Fatal Python error: Python memory allocator called '
diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py
index 03f8d1d..a888dca 100644
--- a/Lib/test/test_concurrent_futures.py
+++ b/Lib/test/test_concurrent_futures.py
@@ -4,10 +4,6 @@ import test.support
test.support.import_module('_multiprocessing')
# Skip tests if sem_open implementation is broken.
test.support.import_module('multiprocessing.synchronize')
-# import threading after _multiprocessing to raise a more relevant error
-# message: "No module named _multiprocessing". _multiprocessing is not compiled
-# without thread support.
-test.support.import_module('threading')
from test.support.script_helper import assert_python_ok
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 2301f75..64b6578 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -3,13 +3,10 @@
import io
import sys
import tempfile
+import threading
import unittest
from contextlib import * # Tests __all__
from test import support
-try:
- import threading
-except ImportError:
- threading = None
class TestAbstractContextManager(unittest.TestCase):
@@ -275,7 +272,6 @@ class FileContextTestCase(unittest.TestCase):
finally:
support.unlink(tfn)
-@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
def boilerPlate(self, lock, locked):
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 53f71ca..5d9da0e 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -38,10 +38,7 @@ from test.support import (import_fresh_module, TestFailed,
run_with_locale, cpython_only)
import random
import inspect
-try:
- import threading
-except ImportError:
- threading = None
+import threading
C = import_fresh_module('decimal', fresh=['_decimal'])
@@ -1625,10 +1622,10 @@ class ThreadingTest(unittest.TestCase):
DefaultContext.Emax = save_emax
DefaultContext.Emin = save_emin
-@unittest.skipUnless(threading, 'threading required')
+
class CThreadingTest(ThreadingTest):
decimal = C
-@unittest.skipUnless(threading, 'threading required')
+
class PyThreadingTest(ThreadingTest):
decimal = P
diff --git a/Lib/test/test_docxmlrpc.py b/Lib/test/test_docxmlrpc.py
index 0090333..f077f05 100644
--- a/Lib/test/test_docxmlrpc.py
+++ b/Lib/test/test_docxmlrpc.py
@@ -1,8 +1,8 @@
from xmlrpc.server import DocXMLRPCServer
import http.client
import sys
+import threading
from test import support
-threading = support.import_module('threading')
import unittest
def make_request_and_skipIf(condition, reason):
diff --git a/Lib/test/test_dummy_thread.py b/Lib/test/test_dummy_thread.py
deleted file mode 100644
index 0840be6..0000000
--- a/Lib/test/test_dummy_thread.py
+++ /dev/null
@@ -1,255 +0,0 @@
-import _dummy_thread as _thread
-import time
-import queue
-import random
-import unittest
-from test import support
-from unittest import mock
-
-DELAY = 0
-
-
-class LockTests(unittest.TestCase):
- """Test lock objects."""
-
- def setUp(self):
- # Create a lock
- self.lock = _thread.allocate_lock()
-
- def test_initlock(self):
- #Make sure locks start locked
- self.assertFalse(self.lock.locked(),
- "Lock object is not initialized unlocked.")
-
- def test_release(self):
- # Test self.lock.release()
- self.lock.acquire()
- self.lock.release()
- self.assertFalse(self.lock.locked(),
- "Lock object did not release properly.")
-
- def test_LockType_context_manager(self):
- with _thread.LockType():
- pass
- self.assertFalse(self.lock.locked(),
- "Acquired Lock was not released")
-
- def test_improper_release(self):
- #Make sure release of an unlocked thread raises RuntimeError
- self.assertRaises(RuntimeError, self.lock.release)
-
- def test_cond_acquire_success(self):
- #Make sure the conditional acquiring of the lock works.
- self.assertTrue(self.lock.acquire(0),
- "Conditional acquiring of the lock failed.")
-
- def test_cond_acquire_fail(self):
- #Test acquiring locked lock returns False
- self.lock.acquire(0)
- self.assertFalse(self.lock.acquire(0),
- "Conditional acquiring of a locked lock incorrectly "
- "succeeded.")
-
- def test_uncond_acquire_success(self):
- #Make sure unconditional acquiring of a lock works.
- self.lock.acquire()
- self.assertTrue(self.lock.locked(),
- "Uncondional locking failed.")
-
- def test_uncond_acquire_return_val(self):
- #Make sure that an unconditional locking returns True.
- self.assertIs(self.lock.acquire(1), True,
- "Unconditional locking did not return True.")
- self.assertIs(self.lock.acquire(), True)
-
- def test_uncond_acquire_blocking(self):
- #Make sure that unconditional acquiring of a locked lock blocks.
- def delay_unlock(to_unlock, delay):
- """Hold on to lock for a set amount of time before unlocking."""
- time.sleep(delay)
- to_unlock.release()
-
- self.lock.acquire()
- start_time = int(time.time())
- _thread.start_new_thread(delay_unlock,(self.lock, DELAY))
- if support.verbose:
- print()
- print("*** Waiting for thread to release the lock "\
- "(approx. %s sec.) ***" % DELAY)
- self.lock.acquire()
- end_time = int(time.time())
- if support.verbose:
- print("done")
- self.assertGreaterEqual(end_time - start_time, DELAY,
- "Blocking by unconditional acquiring failed.")
-
- @mock.patch('time.sleep')
- def test_acquire_timeout(self, mock_sleep):
- """Test invoking acquire() with a positive timeout when the lock is
- already acquired. Ensure that time.sleep() is invoked with the given
- timeout and that False is returned."""
-
- self.lock.acquire()
- retval = self.lock.acquire(waitflag=0, timeout=1)
- self.assertTrue(mock_sleep.called)
- mock_sleep.assert_called_once_with(1)
- self.assertEqual(retval, False)
-
- def test_lock_representation(self):
- self.lock.acquire()
- self.assertIn("locked", repr(self.lock))
- self.lock.release()
- self.assertIn("unlocked", repr(self.lock))
-
-
-class MiscTests(unittest.TestCase):
- """Miscellaneous tests."""
-
- def test_exit(self):
- self.assertRaises(SystemExit, _thread.exit)
-
- def test_ident(self):
- self.assertIsInstance(_thread.get_ident(), int,
- "_thread.get_ident() returned a non-integer")
- self.assertGreater(_thread.get_ident(), 0)
-
- def test_LockType(self):
- self.assertIsInstance(_thread.allocate_lock(), _thread.LockType,
- "_thread.LockType is not an instance of what "
- "is returned by _thread.allocate_lock()")
-
- def test_set_sentinel(self):
- self.assertIsInstance(_thread._set_sentinel(), _thread.LockType,
- "_thread._set_sentinel() did not return a "
- "LockType instance.")
-
- def test_interrupt_main(self):
- #Calling start_new_thread with a function that executes interrupt_main
- # should raise KeyboardInterrupt upon completion.
- def call_interrupt():
- _thread.interrupt_main()
-
- self.assertRaises(KeyboardInterrupt,
- _thread.start_new_thread,
- call_interrupt,
- tuple())
-
- def test_interrupt_in_main(self):
- self.assertRaises(KeyboardInterrupt, _thread.interrupt_main)
-
- def test_stack_size_None(self):
- retval = _thread.stack_size(None)
- self.assertEqual(retval, 0)
-
- def test_stack_size_not_None(self):
- with self.assertRaises(_thread.error) as cm:
- _thread.stack_size("")
- self.assertEqual(cm.exception.args[0],
- "setting thread stack size not supported")
-
-
-class ThreadTests(unittest.TestCase):
- """Test thread creation."""
-
- def test_arg_passing(self):
- #Make sure that parameter passing works.
- def arg_tester(queue, arg1=False, arg2=False):
- """Use to test _thread.start_new_thread() passes args properly."""
- queue.put((arg1, arg2))
-
- testing_queue = queue.Queue(1)
- _thread.start_new_thread(arg_tester, (testing_queue, True, True))
- result = testing_queue.get()
- self.assertTrue(result[0] and result[1],
- "Argument passing for thread creation "
- "using tuple failed")
-
- _thread.start_new_thread(
- arg_tester,
- tuple(),
- {'queue':testing_queue, 'arg1':True, 'arg2':True})
-
- result = testing_queue.get()
- self.assertTrue(result[0] and result[1],
- "Argument passing for thread creation "
- "using kwargs failed")
-
- _thread.start_new_thread(
- arg_tester,
- (testing_queue, True),
- {'arg2':True})
-
- result = testing_queue.get()
- self.assertTrue(result[0] and result[1],
- "Argument passing for thread creation using both tuple"
- " and kwargs failed")
-
- def test_multi_thread_creation(self):
- def queue_mark(queue, delay):
- time.sleep(delay)
- queue.put(_thread.get_ident())
-
- thread_count = 5
- testing_queue = queue.Queue(thread_count)
-
- if support.verbose:
- print()
- print("*** Testing multiple thread creation "
- "(will take approx. %s to %s sec.) ***" % (
- DELAY, thread_count))
-
- for count in range(thread_count):
- if DELAY:
- local_delay = round(random.random(), 1)
- else:
- local_delay = 0
- _thread.start_new_thread(queue_mark,
- (testing_queue, local_delay))
- time.sleep(DELAY)
- if support.verbose:
- print('done')
- self.assertEqual(testing_queue.qsize(), thread_count,
- "Not all %s threads executed properly "
- "after %s sec." % (thread_count, DELAY))
-
- def test_args_not_tuple(self):
- """
- Test invoking start_new_thread() with a non-tuple value for "args".
- Expect TypeError with a meaningful error message to be raised.
- """
- with self.assertRaises(TypeError) as cm:
- _thread.start_new_thread(mock.Mock(), [])
- self.assertEqual(cm.exception.args[0], "2nd arg must be a tuple")
-
- def test_kwargs_not_dict(self):
- """
- Test invoking start_new_thread() with a non-dict value for "kwargs".
- Expect TypeError with a meaningful error message to be raised.
- """
- with self.assertRaises(TypeError) as cm:
- _thread.start_new_thread(mock.Mock(), tuple(), kwargs=[])
- self.assertEqual(cm.exception.args[0], "3rd arg must be a dict")
-
- def test_SystemExit(self):
- """
- Test invoking start_new_thread() with a function that raises
- SystemExit.
- The exception should be discarded.
- """
- func = mock.Mock(side_effect=SystemExit())
- try:
- _thread.start_new_thread(func, tuple())
- except SystemExit:
- self.fail("start_new_thread raised SystemExit.")
-
- @mock.patch('traceback.print_exc')
- def test_RaiseException(self, mock_print_exc):
- """
- Test invoking start_new_thread() with a function that raises exception.
-
- The exception should be discarded and the traceback should be printed
- via traceback.print_exc()
- """
- func = mock.Mock(side_effect=Exception)
- _thread.start_new_thread(func, tuple())
- self.assertTrue(mock_print_exc.called)
diff --git a/Lib/test/test_dummy_threading.py b/Lib/test/test_dummy_threading.py
deleted file mode 100644
index a0c2972..0000000
--- a/Lib/test/test_dummy_threading.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from test import support
-import unittest
-import dummy_threading as _threading
-import time
-
-class DummyThreadingTestCase(unittest.TestCase):
-
- class TestThread(_threading.Thread):
-
- def run(self):
- global running
- global sema
- global mutex
- # Uncomment if testing another module, such as the real 'threading'
- # module.
- #delay = random.random() * 2
- delay = 0
- if support.verbose:
- print('task', self.name, 'will run for', delay, 'sec')
- sema.acquire()
- mutex.acquire()
- running += 1
- if support.verbose:
- print(running, 'tasks are running')
- mutex.release()
- time.sleep(delay)
- if support.verbose:
- print('task', self.name, 'done')
- mutex.acquire()
- running -= 1
- if support.verbose:
- print(self.name, 'is finished.', running, 'tasks are running')
- mutex.release()
- sema.release()
-
- def setUp(self):
- self.numtasks = 10
- global sema
- sema = _threading.BoundedSemaphore(value=3)
- global mutex
- mutex = _threading.RLock()
- global running
- running = 0
- self.threads = []
-
- def test_tasks(self):
- for i in range(self.numtasks):
- t = self.TestThread(name="<thread %d>"%i)
- self.threads.append(t)
- t.start()
-
- if support.verbose:
- print('waiting for all tasks to complete')
- for t in self.threads:
- t.join()
- if support.verbose:
- print('all tasks done')
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py
index f97ccc6..621754c 100644
--- a/Lib/test/test_email/test_email.py
+++ b/Lib/test/test_email/test_email.py
@@ -12,10 +12,7 @@ from io import StringIO, BytesIO
from itertools import chain
from random import choice
from socket import getfqdn
-try:
- from threading import Thread
-except ImportError:
- from dummy_threading import Thread
+from threading import Thread
import email
import email.policy
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index ea52de7..e6324d4 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -2,15 +2,12 @@ import enum
import inspect
import pydoc
import unittest
+import threading
from collections import OrderedDict
from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto
from io import StringIO
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
from test import support
-try:
- import threading
-except ImportError:
- threading = None
# for pickle tests
@@ -1988,7 +1985,6 @@ class TestFlag(unittest.TestCase):
d = 6
self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>')
- @unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads
def test_unique_composite(self):
# override __eq__ to be identity only
@@ -2339,7 +2335,6 @@ class TestIntFlag(unittest.TestCase):
for f in Open:
self.assertEqual(bool(f.value), bool(f))
- @unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads
def test_unique_composite(self):
# override __eq__ to be identity only
diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py
index e2fcb2b..889e641 100644
--- a/Lib/test/test_faulthandler.py
+++ b/Lib/test/test_faulthandler.py
@@ -8,15 +8,11 @@ import sys
from test import support
from test.support import script_helper, is_android, requires_android_level
import tempfile
+import threading
import unittest
from textwrap import dedent
try:
- import threading
- HAVE_THREADS = True
-except ImportError:
- HAVE_THREADS = False
-try:
import _testcapi
except ImportError:
_testcapi = None
@@ -154,7 +150,6 @@ class FaultHandlerTests(unittest.TestCase):
3,
'Segmentation fault')
- @unittest.skipIf(not HAVE_THREADS, 'need threads')
def test_fatal_error_c_thread(self):
self.check_fatal_error("""
import faulthandler
@@ -231,7 +226,7 @@ class FaultHandlerTests(unittest.TestCase):
2,
'xyz')
- @unittest.skipIf(sys.platform.startswith('openbsd') and HAVE_THREADS,
+ @unittest.skipIf(sys.platform.startswith('openbsd'),
"Issue #12868: sigaltstack() doesn't work on "
"OpenBSD if Python is compiled with pthread")
@unittest.skipIf(not hasattr(faulthandler, '_stack_overflow'),
@@ -456,7 +451,6 @@ class FaultHandlerTests(unittest.TestCase):
self.assertEqual(trace, expected)
self.assertEqual(exitcode, 0)
- @unittest.skipIf(not HAVE_THREADS, 'need threads')
def check_dump_traceback_threads(self, filename):
"""
Call explicitly dump_traceback(all_threads=True) and check the output.
diff --git a/Lib/test/test_fork1.py b/Lib/test/test_fork1.py
index da46fe5..9ca9724 100644
--- a/Lib/test/test_fork1.py
+++ b/Lib/test/test_fork1.py
@@ -5,6 +5,7 @@ import _imp as imp
import os
import signal
import sys
+import threading
import time
import unittest
@@ -12,7 +13,6 @@ from test.fork_wait import ForkWait
from test.support import (reap_children, get_attribute,
import_module, verbose)
-threading = import_module('threading')
# Skip test if fork does not exist.
get_attribute(os, 'fork')
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py
index 24ea382..151e091 100644
--- a/Lib/test/test_ftplib.py
+++ b/Lib/test/test_ftplib.py
@@ -10,6 +10,7 @@ import socket
import io
import errno
import os
+import threading
import time
try:
import ssl
@@ -19,7 +20,6 @@ except ImportError:
from unittest import TestCase, skipUnless
from test import support
from test.support import HOST, HOSTv6
-threading = support.import_module('threading')
TIMEOUT = 3
# the dummy data returned by server over the data channel when
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 3acfb92..f7a1166 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -8,15 +8,12 @@ import pickle
from random import choice
import sys
from test import support
+import threading
import time
import unittest
import unittest.mock
from weakref import proxy
import contextlib
-try:
- import threading
-except ImportError:
- threading = None
import functools
@@ -1406,7 +1403,6 @@ class TestLRU:
for attr in self.module.WRAPPER_ASSIGNMENTS:
self.assertEqual(getattr(g, attr), getattr(f, attr))
- @unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
n, m = 5, 11
def orig(x, y):
@@ -1455,7 +1451,6 @@ class TestLRU:
finally:
sys.setswitchinterval(orig_si)
- @unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded2(self):
# Simultaneous call with the same arguments
n, m = 5, 7
@@ -1483,7 +1478,6 @@ class TestLRU:
pause.reset()
self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
- @unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded3(self):
@self.module.lru_cache(maxsize=2)
def f(x):
diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py
index e727499..914efec 100644
--- a/Lib/test/test_gc.py
+++ b/Lib/test/test_gc.py
@@ -8,11 +8,7 @@ import sys
import time
import gc
import weakref
-
-try:
- import threading
-except ImportError:
- threading = None
+import threading
try:
from _testcapi import with_tp_del
@@ -352,7 +348,6 @@ class GCTests(unittest.TestCase):
v = {1: v, 2: Ouch()}
gc.disable()
- @unittest.skipUnless(threading, "test meaningless on builds without threads")
def test_trashcan_threads(self):
# Issue #13992: trashcan mechanism should be thread-safe
NESTING = 60
diff --git a/Lib/test/test_gdb.py b/Lib/test/test_gdb.py
index 46736f6..9e0eaea 100644
--- a/Lib/test/test_gdb.py
+++ b/Lib/test/test_gdb.py
@@ -12,12 +12,6 @@ import sysconfig
import textwrap
import unittest
-# Is this Python configured to support threads?
-try:
- import _thread
-except ImportError:
- _thread = None
-
from test import support
from test.support import run_unittest, findfile, python_is_optimized
@@ -755,8 +749,6 @@ Traceback \(most recent call first\):
foo\(1, 2, 3\)
''')
- @unittest.skipUnless(_thread,
- "Python was compiled without thread support")
def test_threads(self):
'Verify that "py-bt" indicates threads that are waiting for the GIL'
cmd = '''
@@ -794,8 +786,6 @@ id(42)
# Some older versions of gdb will fail with
# "Cannot find new threads: generic error"
# unless we add LD_PRELOAD=PATH-TO-libpthread.so.1 as a workaround
- @unittest.skipUnless(_thread,
- "Python was compiled without thread support")
def test_gc(self):
'Verify that "py-bt" indicates if a thread is garbage-collecting'
cmd = ('from gc import collect\n'
@@ -822,8 +812,6 @@ id(42)
# Some older versions of gdb will fail with
# "Cannot find new threads: generic error"
# unless we add LD_PRELOAD=PATH-TO-libpthread.so.1 as a workaround
- @unittest.skipUnless(_thread,
- "Python was compiled without thread support")
def test_pycfunction(self):
'Verify that "py-bt" displays invocations of PyCFunction instances'
# Tested function must not be defined with METH_NOARGS or METH_O,
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
index f748b46..5c8f090 100644
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -12,10 +12,7 @@ import hashlib
import itertools
import os
import sys
-try:
- import threading
-except ImportError:
- threading = None
+import threading
import unittest
import warnings
from test import support
@@ -738,7 +735,6 @@ class HashLibTestCase(unittest.TestCase):
m = hashlib.md5(b'x' * gil_minsize)
self.assertEqual(m.hexdigest(), 'cfb767f225d58469c5de3632a8803958')
- @unittest.skipUnless(threading, 'Threading required for this test.')
@support.reap_threads
def test_threaded_hashing(self):
# Updating the same hash object from several threads at once
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index 8cddcdc..20e6f66 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -22,12 +22,13 @@ import urllib.parse
import tempfile
import time
import datetime
+import threading
from unittest import mock
from io import BytesIO
import unittest
from test import support
-threading = support.import_module('threading')
+
class NoLogRequestHandler:
def log_message(self, *args):
diff --git a/Lib/test/test_idle.py b/Lib/test/test_idle.py
index da05da5..b7ef70d 100644
--- a/Lib/test/test_idle.py
+++ b/Lib/test/test_idle.py
@@ -3,7 +3,6 @@ from test.support import import_module
# Skip test if _thread or _tkinter wasn't built, if idlelib is missing,
# or if tcl/tk is not the 8.5+ needed for ttk widgets.
-import_module('threading') # imported by PyShell, imports _thread
tk = import_module('tkinter') # imports _tkinter
if tk.TkVersion < 8.5:
raise unittest.SkipTest("IDLE requires tk 8.5 or later.")
diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py
index 7c12438..132c586 100644
--- a/Lib/test/test_imaplib.py
+++ b/Lib/test/test_imaplib.py
@@ -1,8 +1,4 @@
from test import support
-# If we end up with a significant number of tests that don't require
-# threading, this test module should be split. Right now we skip
-# them all if we don't have threading.
-threading = support.import_module('threading')
from contextlib import contextmanager
import imaplib
@@ -10,6 +6,7 @@ import os.path
import socketserver
import time
import calendar
+import threading
from test.support import (reap_threads, verbose, transient_internet,
run_with_tz, run_with_locale, cpython_only)
diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py
index 6f35f49..3bfcb09 100644
--- a/Lib/test/test_imp.py
+++ b/Lib/test/test_imp.py
@@ -1,7 +1,3 @@
-try:
- import _thread
-except ImportError:
- _thread = None
import importlib
import importlib.util
import os
@@ -23,7 +19,6 @@ def requires_load_dynamic(meth):
'imp.load_dynamic() required')(meth)
-@unittest.skipIf(_thread is None, '_thread module is required')
class LockTests(unittest.TestCase):
"""Very basic test of import lock functions."""
diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py
index dbce9c2..d86172a 100644
--- a/Lib/test/test_importlib/test_locks.py
+++ b/Lib/test/test_importlib/test_locks.py
@@ -3,134 +3,112 @@ from . import util as test_util
init = test_util.import_importlib('importlib')
import sys
+import threading
import unittest
import weakref
from test import support
-
-try:
- import threading
-except ImportError:
- threading = None
-else:
- from test import lock_tests
-
-if threading is not None:
- class ModuleLockAsRLockTests:
- locktype = classmethod(lambda cls: cls.LockType("some_lock"))
-
- # _is_owned() unsupported
- test__is_owned = None
- # acquire(blocking=False) unsupported
- test_try_acquire = None
- test_try_acquire_contended = None
- # `with` unsupported
- test_with = None
- # acquire(timeout=...) unsupported
- test_timeout = None
- # _release_save() unsupported
- test_release_save_unacquired = None
- # lock status in repr unsupported
- test_repr = None
- test_locked_repr = None
-
- LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock
- for kind, splitinit in init.items()}
-
- (Frozen_ModuleLockAsRLockTests,
- Source_ModuleLockAsRLockTests
- ) = test_util.test_both(ModuleLockAsRLockTests, lock_tests.RLockTests,
- LockType=LOCK_TYPES)
-else:
- LOCK_TYPES = {}
-
- class Frozen_ModuleLockAsRLockTests(unittest.TestCase):
- pass
-
- class Source_ModuleLockAsRLockTests(unittest.TestCase):
- pass
-
-
-if threading is not None:
- class DeadlockAvoidanceTests:
-
- def setUp(self):
+from test import lock_tests
+
+
+class ModuleLockAsRLockTests:
+ locktype = classmethod(lambda cls: cls.LockType("some_lock"))
+
+ # _is_owned() unsupported
+ test__is_owned = None
+ # acquire(blocking=False) unsupported
+ test_try_acquire = None
+ test_try_acquire_contended = None
+ # `with` unsupported
+ test_with = None
+ # acquire(timeout=...) unsupported
+ test_timeout = None
+ # _release_save() unsupported
+ test_release_save_unacquired = None
+ # lock status in repr unsupported
+ test_repr = None
+ test_locked_repr = None
+
+LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock
+ for kind, splitinit in init.items()}
+
+(Frozen_ModuleLockAsRLockTests,
+ Source_ModuleLockAsRLockTests
+ ) = test_util.test_both(ModuleLockAsRLockTests, lock_tests.RLockTests,
+ LockType=LOCK_TYPES)
+
+
+class DeadlockAvoidanceTests:
+
+ def setUp(self):
+ try:
+ self.old_switchinterval = sys.getswitchinterval()
+ support.setswitchinterval(0.000001)
+ except AttributeError:
+ self.old_switchinterval = None
+
+ def tearDown(self):
+ if self.old_switchinterval is not None:
+ sys.setswitchinterval(self.old_switchinterval)
+
+ def run_deadlock_avoidance_test(self, create_deadlock):
+ NLOCKS = 10
+ locks = [self.LockType(str(i)) for i in range(NLOCKS)]
+ pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)]
+ if create_deadlock:
+ NTHREADS = NLOCKS
+ else:
+ NTHREADS = NLOCKS - 1
+ barrier = threading.Barrier(NTHREADS)
+ results = []
+
+ def _acquire(lock):
+ """Try to acquire the lock. Return True on success,
+ False on deadlock."""
try:
- self.old_switchinterval = sys.getswitchinterval()
- support.setswitchinterval(0.000001)
- except AttributeError:
- self.old_switchinterval = None
-
- def tearDown(self):
- if self.old_switchinterval is not None:
- sys.setswitchinterval(self.old_switchinterval)
-
- def run_deadlock_avoidance_test(self, create_deadlock):
- NLOCKS = 10
- locks = [self.LockType(str(i)) for i in range(NLOCKS)]
- pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)]
- if create_deadlock:
- NTHREADS = NLOCKS
+ lock.acquire()
+ except self.DeadlockError:
+ return False
else:
- NTHREADS = NLOCKS - 1
- barrier = threading.Barrier(NTHREADS)
- results = []
-
- def _acquire(lock):
- """Try to acquire the lock. Return True on success,
- False on deadlock."""
- try:
- lock.acquire()
- except self.DeadlockError:
- return False
- else:
- return True
-
- def f():
- a, b = pairs.pop()
- ra = _acquire(a)
- barrier.wait()
- rb = _acquire(b)
- results.append((ra, rb))
- if rb:
- b.release()
- if ra:
- a.release()
- lock_tests.Bunch(f, NTHREADS).wait_for_finished()
- self.assertEqual(len(results), NTHREADS)
- return results
-
- def test_deadlock(self):
- results = self.run_deadlock_avoidance_test(True)
- # At least one of the threads detected a potential deadlock on its
- # second acquire() call. It may be several of them, because the
- # deadlock avoidance mechanism is conservative.
- nb_deadlocks = results.count((True, False))
- self.assertGreaterEqual(nb_deadlocks, 1)
- self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks)
-
- def test_no_deadlock(self):
- results = self.run_deadlock_avoidance_test(False)
- self.assertEqual(results.count((True, False)), 0)
- self.assertEqual(results.count((True, True)), len(results))
-
-
- DEADLOCK_ERRORS = {kind: splitinit._bootstrap._DeadlockError
- for kind, splitinit in init.items()}
-
- (Frozen_DeadlockAvoidanceTests,
- Source_DeadlockAvoidanceTests
- ) = test_util.test_both(DeadlockAvoidanceTests,
- LockType=LOCK_TYPES,
- DeadlockError=DEADLOCK_ERRORS)
-else:
- DEADLOCK_ERRORS = {}
-
- class Frozen_DeadlockAvoidanceTests(unittest.TestCase):
- pass
-
- class Source_DeadlockAvoidanceTests(unittest.TestCase):
- pass
+ return True
+
+ def f():
+ a, b = pairs.pop()
+ ra = _acquire(a)
+ barrier.wait()
+ rb = _acquire(b)
+ results.append((ra, rb))
+ if rb:
+ b.release()
+ if ra:
+ a.release()
+ lock_tests.Bunch(f, NTHREADS).wait_for_finished()
+ self.assertEqual(len(results), NTHREADS)
+ return results
+
+ def test_deadlock(self):
+ results = self.run_deadlock_avoidance_test(True)
+ # At least one of the threads detected a potential deadlock on its
+ # second acquire() call. It may be several of them, because the
+ # deadlock avoidance mechanism is conservative.
+ nb_deadlocks = results.count((True, False))
+ self.assertGreaterEqual(nb_deadlocks, 1)
+ self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks)
+
+ def test_no_deadlock(self):
+ results = self.run_deadlock_avoidance_test(False)
+ self.assertEqual(results.count((True, False)), 0)
+ self.assertEqual(results.count((True, True)), len(results))
+
+
+DEADLOCK_ERRORS = {kind: splitinit._bootstrap._DeadlockError
+ for kind, splitinit in init.items()}
+
+(Frozen_DeadlockAvoidanceTests,
+ Source_DeadlockAvoidanceTests
+ ) = test_util.test_both(DeadlockAvoidanceTests,
+ LockType=LOCK_TYPES,
+ DeadlockError=DEADLOCK_ERRORS)
class LifetimeTests:
diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py
index 48270c8..d4685dd 100644
--- a/Lib/test/test_io.py
+++ b/Lib/test/test_io.py
@@ -28,6 +28,7 @@ import pickle
import random
import signal
import sys
+import threading
import time
import unittest
import warnings
@@ -40,10 +41,6 @@ from test.support.script_helper import assert_python_ok, run_python_until_end
import codecs
import io # C implementation of io
import _pyio as pyio # Python implementation of io
-try:
- import threading
-except ImportError:
- threading = None
try:
import ctypes
@@ -443,8 +440,6 @@ class IOTest(unittest.TestCase):
(self.BytesIO, "rws"), (self.StringIO, "rws"),
)
for [test, abilities] in tests:
- if test is pipe_writer and not threading:
- continue # Skip subtest that uses a background thread
with self.subTest(test), test() as obj:
readable = "r" in abilities
self.assertEqual(obj.readable(), readable)
@@ -1337,7 +1332,6 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests):
self.assertEqual(b"abcdefg", bufio.read())
- @unittest.skipUnless(threading, 'Threading required for this test.')
@support.requires_resource('cpu')
def test_threads(self):
try:
@@ -1664,7 +1658,6 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests):
with self.open(support.TESTFN, "rb", buffering=0) as f:
self.assertEqual(f.read(), b"abc")
- @unittest.skipUnless(threading, 'Threading required for this test.')
@support.requires_resource('cpu')
def test_threads(self):
try:
@@ -3053,7 +3046,6 @@ class TextIOWrapperTest(unittest.TestCase):
self.assertEqual(f.errors, "replace")
@support.no_tracing
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_threads_write(self):
# Issue6750: concurrent writes could duplicate data
event = threading.Event()
@@ -3804,7 +3796,6 @@ class CMiscIOTest(MiscIOTest):
b = bytearray(2)
self.assertRaises(ValueError, bufio.readinto, b)
- @unittest.skipUnless(threading, 'Threading required for this test.')
def check_daemon_threads_shutdown_deadlock(self, stream_name):
# Issue #23309: deadlocks at shutdown should be avoided when a
# daemon thread and the main thread both write to a file.
@@ -3868,7 +3859,6 @@ class SignalsTest(unittest.TestCase):
def alarm_interrupt(self, sig, frame):
1/0
- @unittest.skipUnless(threading, 'Threading required for this test.')
def check_interrupted_write(self, item, bytes, **fdopen_kwargs):
"""Check that a partial write, when it gets interrupted, properly
invokes the signal handler, and bubbles up the exception raised
@@ -3990,7 +3980,6 @@ class SignalsTest(unittest.TestCase):
self.check_interrupted_read_retry(lambda x: x,
mode="r")
- @unittest.skipUnless(threading, 'Threading required for this test.')
def check_interrupted_write_retry(self, item, **fdopen_kwargs):
"""Check that a buffered write, when it gets interrupted (either
returning a partial result or EINTR), properly invokes the signal
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index f4aef9f..9c3816a 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -42,22 +42,19 @@ import tempfile
from test.support.script_helper import assert_python_ok
from test import support
import textwrap
+import threading
import time
import unittest
import warnings
import weakref
-try:
- import threading
- # The following imports are needed only for tests which
- # require threading
- import asyncore
- from http.server import HTTPServer, BaseHTTPRequestHandler
- import smtpd
- from urllib.parse import urlparse, parse_qs
- from socketserver import (ThreadingUDPServer, DatagramRequestHandler,
- ThreadingTCPServer, StreamRequestHandler)
-except ImportError:
- threading = None
+
+import asyncore
+from http.server import HTTPServer, BaseHTTPRequestHandler
+import smtpd
+from urllib.parse import urlparse, parse_qs
+from socketserver import (ThreadingUDPServer, DatagramRequestHandler,
+ ThreadingTCPServer, StreamRequestHandler)
+
try:
import win32evtlog, win32evtlogutil, pywintypes
except ImportError:
@@ -625,7 +622,6 @@ class HandlerTest(BaseTest):
os.unlink(fn)
@unittest.skipIf(os.name == 'nt', 'WatchedFileHandler not appropriate for Windows.')
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_race(self):
# Issue #14632 refers.
def remove_loop(fname, tries):
@@ -719,276 +715,274 @@ class StreamHandlerTest(BaseTest):
# -- The following section could be moved into a server_helper.py module
# -- if it proves to be of wider utility than just test_logging
-if threading:
- class TestSMTPServer(smtpd.SMTPServer):
+class TestSMTPServer(smtpd.SMTPServer):
+ """
+ This class implements a test SMTP server.
+
+ :param addr: A (host, port) tuple which the server listens on.
+ You can specify a port value of zero: the server's
+ *port* attribute will hold the actual port number
+ used, which can be used in client connections.
+ :param handler: A callable which will be called to process
+ incoming messages. The handler will be passed
+ the client address tuple, who the message is from,
+ a list of recipients and the message data.
+ :param poll_interval: The interval, in seconds, used in the underlying
+ :func:`select` or :func:`poll` call by
+ :func:`asyncore.loop`.
+ :param sockmap: A dictionary which will be used to hold
+ :class:`asyncore.dispatcher` instances used by
+ :func:`asyncore.loop`. This avoids changing the
+ :mod:`asyncore` module's global state.
+ """
+
+ def __init__(self, addr, handler, poll_interval, sockmap):
+ smtpd.SMTPServer.__init__(self, addr, None, map=sockmap,
+ decode_data=True)
+ self.port = self.socket.getsockname()[1]
+ self._handler = handler
+ self._thread = None
+ self.poll_interval = poll_interval
+
+ def process_message(self, peer, mailfrom, rcpttos, data):
+ """
+ Delegates to the handler passed in to the server's constructor.
+
+ Typically, this will be a test case method.
+ :param peer: The client (host, port) tuple.
+ :param mailfrom: The address of the sender.
+ :param rcpttos: The addresses of the recipients.
+ :param data: The message.
"""
- This class implements a test SMTP server.
-
- :param addr: A (host, port) tuple which the server listens on.
- You can specify a port value of zero: the server's
- *port* attribute will hold the actual port number
- used, which can be used in client connections.
- :param handler: A callable which will be called to process
- incoming messages. The handler will be passed
- the client address tuple, who the message is from,
- a list of recipients and the message data.
+ self._handler(peer, mailfrom, rcpttos, data)
+
+ def start(self):
+ """
+ Start the server running on a separate daemon thread.
+ """
+ self._thread = t = threading.Thread(target=self.serve_forever,
+ args=(self.poll_interval,))
+ t.setDaemon(True)
+ t.start()
+
+ def serve_forever(self, poll_interval):
+ """
+ Run the :mod:`asyncore` loop until normal termination
+ conditions arise.
:param poll_interval: The interval, in seconds, used in the underlying
:func:`select` or :func:`poll` call by
:func:`asyncore.loop`.
- :param sockmap: A dictionary which will be used to hold
- :class:`asyncore.dispatcher` instances used by
- :func:`asyncore.loop`. This avoids changing the
- :mod:`asyncore` module's global state.
"""
+ try:
+ asyncore.loop(poll_interval, map=self._map)
+ except OSError:
+ # On FreeBSD 8, closing the server repeatably
+ # raises this error. We swallow it if the
+ # server has been closed.
+ if self.connected or self.accepting:
+ raise
- def __init__(self, addr, handler, poll_interval, sockmap):
- smtpd.SMTPServer.__init__(self, addr, None, map=sockmap,
- decode_data=True)
- self.port = self.socket.getsockname()[1]
- self._handler = handler
- self._thread = None
- self.poll_interval = poll_interval
+ def stop(self, timeout=None):
+ """
+ Stop the thread by closing the server instance.
+ Wait for the server thread to terminate.
- def process_message(self, peer, mailfrom, rcpttos, data):
- """
- Delegates to the handler passed in to the server's constructor.
+ :param timeout: How long to wait for the server thread
+ to terminate.
+ """
+ self.close()
+ self._thread.join(timeout)
+ asyncore.close_all(map=self._map, ignore_all=True)
- Typically, this will be a test case method.
- :param peer: The client (host, port) tuple.
- :param mailfrom: The address of the sender.
- :param rcpttos: The addresses of the recipients.
- :param data: The message.
- """
- self._handler(peer, mailfrom, rcpttos, data)
+ alive = self._thread.is_alive()
+ self._thread = None
+ if alive:
+ self.fail("join() timed out")
- def start(self):
- """
- Start the server running on a separate daemon thread.
- """
- self._thread = t = threading.Thread(target=self.serve_forever,
- args=(self.poll_interval,))
- t.setDaemon(True)
- t.start()
+class ControlMixin(object):
+ """
+ This mixin is used to start a server on a separate thread, and
+ shut it down programmatically. Request handling is simplified - instead
+ of needing to derive a suitable RequestHandler subclass, you just
+ provide a callable which will be passed each received request to be
+ processed.
+
+ :param handler: A handler callable which will be called with a
+ single parameter - the request - in order to
+ process the request. This handler is called on the
+ server thread, effectively meaning that requests are
+ processed serially. While not quite Web scale ;-),
+ this should be fine for testing applications.
+ :param poll_interval: The polling interval in seconds.
+ """
+ def __init__(self, handler, poll_interval):
+ self._thread = None
+ self.poll_interval = poll_interval
+ self._handler = handler
+ self.ready = threading.Event()
- def serve_forever(self, poll_interval):
- """
- Run the :mod:`asyncore` loop until normal termination
- conditions arise.
- :param poll_interval: The interval, in seconds, used in the underlying
- :func:`select` or :func:`poll` call by
- :func:`asyncore.loop`.
- """
- try:
- asyncore.loop(poll_interval, map=self._map)
- except OSError:
- # On FreeBSD 8, closing the server repeatably
- # raises this error. We swallow it if the
- # server has been closed.
- if self.connected or self.accepting:
- raise
-
- def stop(self, timeout=None):
- """
- Stop the thread by closing the server instance.
- Wait for the server thread to terminate.
+ def start(self):
+ """
+ Create a daemon thread to run the server, and start it.
+ """
+ self._thread = t = threading.Thread(target=self.serve_forever,
+ args=(self.poll_interval,))
+ t.setDaemon(True)
+ t.start()
- :param timeout: How long to wait for the server thread
- to terminate.
- """
- self.close()
- self._thread.join(timeout)
- asyncore.close_all(map=self._map, ignore_all=True)
+ def serve_forever(self, poll_interval):
+ """
+ Run the server. Set the ready flag before entering the
+ service loop.
+ """
+ self.ready.set()
+ super(ControlMixin, self).serve_forever(poll_interval)
+ def stop(self, timeout=None):
+ """
+ Tell the server thread to stop, and wait for it to do so.
+
+ :param timeout: How long to wait for the server thread
+ to terminate.
+ """
+ self.shutdown()
+ if self._thread is not None:
+ self._thread.join(timeout)
alive = self._thread.is_alive()
self._thread = None
if alive:
self.fail("join() timed out")
+ self.server_close()
+ self.ready.clear()
- class ControlMixin(object):
- """
- This mixin is used to start a server on a separate thread, and
- shut it down programmatically. Request handling is simplified - instead
- of needing to derive a suitable RequestHandler subclass, you just
- provide a callable which will be passed each received request to be
- processed.
-
- :param handler: A handler callable which will be called with a
- single parameter - the request - in order to
- process the request. This handler is called on the
- server thread, effectively meaning that requests are
- processed serially. While not quite Web scale ;-),
- this should be fine for testing applications.
- :param poll_interval: The polling interval in seconds.
- """
- def __init__(self, handler, poll_interval):
- self._thread = None
- self.poll_interval = poll_interval
- self._handler = handler
- self.ready = threading.Event()
+class TestHTTPServer(ControlMixin, HTTPServer):
+ """
+ An HTTP server which is controllable using :class:`ControlMixin`.
+
+ :param addr: A tuple with the IP address and port to listen on.
+ :param handler: A handler callable which will be called with a
+ single parameter - the request - in order to
+ process the request.
+ :param poll_interval: The polling interval in seconds.
+ :param log: Pass ``True`` to enable log messages.
+ """
+ def __init__(self, addr, handler, poll_interval=0.5,
+ log=False, sslctx=None):
+ class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler):
+ def __getattr__(self, name, default=None):
+ if name.startswith('do_'):
+ return self.process_request
+ raise AttributeError(name)
+
+ def process_request(self):
+ self.server._handler(self)
+
+ def log_message(self, format, *args):
+ if log:
+ super(DelegatingHTTPRequestHandler,
+ self).log_message(format, *args)
+ HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler)
+ ControlMixin.__init__(self, handler, poll_interval)
+ self.sslctx = sslctx
+
+ def get_request(self):
+ try:
+ sock, addr = self.socket.accept()
+ if self.sslctx:
+ sock = self.sslctx.wrap_socket(sock, server_side=True)
+ except OSError as e:
+ # socket errors are silenced by the caller, print them here
+ sys.stderr.write("Got an error:\n%s\n" % e)
+ raise
+ return sock, addr
- def start(self):
- """
- Create a daemon thread to run the server, and start it.
- """
- self._thread = t = threading.Thread(target=self.serve_forever,
- args=(self.poll_interval,))
- t.setDaemon(True)
- t.start()
+class TestTCPServer(ControlMixin, ThreadingTCPServer):
+ """
+ A TCP server which is controllable using :class:`ControlMixin`.
+
+ :param addr: A tuple with the IP address and port to listen on.
+ :param handler: A handler callable which will be called with a single
+ parameter - the request - in order to process the request.
+ :param poll_interval: The polling interval in seconds.
+ :bind_and_activate: If True (the default), binds the server and starts it
+ listening. If False, you need to call
+ :meth:`server_bind` and :meth:`server_activate` at
+ some later time before calling :meth:`start`, so that
+ the server will set up the socket and listen on it.
+ """
- def serve_forever(self, poll_interval):
- """
- Run the server. Set the ready flag before entering the
- service loop.
- """
- self.ready.set()
- super(ControlMixin, self).serve_forever(poll_interval)
+ allow_reuse_address = True
- def stop(self, timeout=None):
- """
- Tell the server thread to stop, and wait for it to do so.
+ def __init__(self, addr, handler, poll_interval=0.5,
+ bind_and_activate=True):
+ class DelegatingTCPRequestHandler(StreamRequestHandler):
- :param timeout: How long to wait for the server thread
- to terminate.
- """
- self.shutdown()
- if self._thread is not None:
- self._thread.join(timeout)
- alive = self._thread.is_alive()
- self._thread = None
- if alive:
- self.fail("join() timed out")
- self.server_close()
- self.ready.clear()
-
- class TestHTTPServer(ControlMixin, HTTPServer):
- """
- An HTTP server which is controllable using :class:`ControlMixin`.
-
- :param addr: A tuple with the IP address and port to listen on.
- :param handler: A handler callable which will be called with a
- single parameter - the request - in order to
- process the request.
- :param poll_interval: The polling interval in seconds.
- :param log: Pass ``True`` to enable log messages.
- """
- def __init__(self, addr, handler, poll_interval=0.5,
- log=False, sslctx=None):
- class DelegatingHTTPRequestHandler(BaseHTTPRequestHandler):
- def __getattr__(self, name, default=None):
- if name.startswith('do_'):
- return self.process_request
- raise AttributeError(name)
-
- def process_request(self):
- self.server._handler(self)
-
- def log_message(self, format, *args):
- if log:
- super(DelegatingHTTPRequestHandler,
- self).log_message(format, *args)
- HTTPServer.__init__(self, addr, DelegatingHTTPRequestHandler)
- ControlMixin.__init__(self, handler, poll_interval)
- self.sslctx = sslctx
-
- def get_request(self):
- try:
- sock, addr = self.socket.accept()
- if self.sslctx:
- sock = self.sslctx.wrap_socket(sock, server_side=True)
- except OSError as e:
- # socket errors are silenced by the caller, print them here
- sys.stderr.write("Got an error:\n%s\n" % e)
- raise
- return sock, addr
+ def handle(self):
+ self.server._handler(self)
+ ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler,
+ bind_and_activate)
+ ControlMixin.__init__(self, handler, poll_interval)
- class TestTCPServer(ControlMixin, ThreadingTCPServer):
- """
- A TCP server which is controllable using :class:`ControlMixin`.
-
- :param addr: A tuple with the IP address and port to listen on.
- :param handler: A handler callable which will be called with a single
- parameter - the request - in order to process the request.
- :param poll_interval: The polling interval in seconds.
- :bind_and_activate: If True (the default), binds the server and starts it
- listening. If False, you need to call
- :meth:`server_bind` and :meth:`server_activate` at
- some later time before calling :meth:`start`, so that
- the server will set up the socket and listen on it.
- """
+ def server_bind(self):
+ super(TestTCPServer, self).server_bind()
+ self.port = self.socket.getsockname()[1]
- allow_reuse_address = True
+class TestUDPServer(ControlMixin, ThreadingUDPServer):
+ """
+ A UDP server which is controllable using :class:`ControlMixin`.
+
+ :param addr: A tuple with the IP address and port to listen on.
+ :param handler: A handler callable which will be called with a
+ single parameter - the request - in order to
+ process the request.
+ :param poll_interval: The polling interval for shutdown requests,
+ in seconds.
+ :bind_and_activate: If True (the default), binds the server and
+ starts it listening. If False, you need to
+ call :meth:`server_bind` and
+ :meth:`server_activate` at some later time
+ before calling :meth:`start`, so that the server will
+ set up the socket and listen on it.
+ """
+ def __init__(self, addr, handler, poll_interval=0.5,
+ bind_and_activate=True):
+ class DelegatingUDPRequestHandler(DatagramRequestHandler):
- def __init__(self, addr, handler, poll_interval=0.5,
- bind_and_activate=True):
- class DelegatingTCPRequestHandler(StreamRequestHandler):
+ def handle(self):
+ self.server._handler(self)
- def handle(self):
- self.server._handler(self)
- ThreadingTCPServer.__init__(self, addr, DelegatingTCPRequestHandler,
- bind_and_activate)
- ControlMixin.__init__(self, handler, poll_interval)
+ def finish(self):
+ data = self.wfile.getvalue()
+ if data:
+ try:
+ super(DelegatingUDPRequestHandler, self).finish()
+ except OSError:
+ if not self.server._closed:
+ raise
- def server_bind(self):
- super(TestTCPServer, self).server_bind()
- self.port = self.socket.getsockname()[1]
+ ThreadingUDPServer.__init__(self, addr,
+ DelegatingUDPRequestHandler,
+ bind_and_activate)
+ ControlMixin.__init__(self, handler, poll_interval)
+ self._closed = False
- class TestUDPServer(ControlMixin, ThreadingUDPServer):
- """
- A UDP server which is controllable using :class:`ControlMixin`.
-
- :param addr: A tuple with the IP address and port to listen on.
- :param handler: A handler callable which will be called with a
- single parameter - the request - in order to
- process the request.
- :param poll_interval: The polling interval for shutdown requests,
- in seconds.
- :bind_and_activate: If True (the default), binds the server and
- starts it listening. If False, you need to
- call :meth:`server_bind` and
- :meth:`server_activate` at some later time
- before calling :meth:`start`, so that the server will
- set up the socket and listen on it.
- """
- def __init__(self, addr, handler, poll_interval=0.5,
- bind_and_activate=True):
- class DelegatingUDPRequestHandler(DatagramRequestHandler):
-
- def handle(self):
- self.server._handler(self)
-
- def finish(self):
- data = self.wfile.getvalue()
- if data:
- try:
- super(DelegatingUDPRequestHandler, self).finish()
- except OSError:
- if not self.server._closed:
- raise
-
- ThreadingUDPServer.__init__(self, addr,
- DelegatingUDPRequestHandler,
- bind_and_activate)
- ControlMixin.__init__(self, handler, poll_interval)
- self._closed = False
-
- def server_bind(self):
- super(TestUDPServer, self).server_bind()
- self.port = self.socket.getsockname()[1]
-
- def server_close(self):
- super(TestUDPServer, self).server_close()
- self._closed = True
+ def server_bind(self):
+ super(TestUDPServer, self).server_bind()
+ self.port = self.socket.getsockname()[1]
- if hasattr(socket, "AF_UNIX"):
- class TestUnixStreamServer(TestTCPServer):
- address_family = socket.AF_UNIX
+ def server_close(self):
+ super(TestUDPServer, self).server_close()
+ self._closed = True
- class TestUnixDatagramServer(TestUDPServer):
- address_family = socket.AF_UNIX
+if hasattr(socket, "AF_UNIX"):
+ class TestUnixStreamServer(TestTCPServer):
+ address_family = socket.AF_UNIX
+
+ class TestUnixDatagramServer(TestUDPServer):
+ address_family = socket.AF_UNIX
# - end of server_helper section
-@unittest.skipUnless(threading, 'Threading required for this test.')
class SMTPHandlerTest(BaseTest):
TIMEOUT = 8.0
@@ -1472,14 +1466,12 @@ class ConfigFileTest(BaseTest):
@unittest.skipIf(True, "FIXME: bpo-30830")
-@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketHandlerTest(BaseTest):
"""Test for SocketHandler objects."""
- if threading:
- server_class = TestTCPServer
- address = ('localhost', 0)
+ server_class = TestTCPServer
+ address = ('localhost', 0)
def setUp(self):
"""Set up a TCP server to receive log messages, and a SocketHandler
@@ -1573,12 +1565,11 @@ def _get_temp_domain_socket():
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required")
-@unittest.skipUnless(threading, 'Threading required for this test.')
class UnixSocketHandlerTest(SocketHandlerTest):
"""Test for SocketHandler with unix sockets."""
- if threading and hasattr(socket, "AF_UNIX"):
+ if hasattr(socket, "AF_UNIX"):
server_class = TestUnixStreamServer
def setUp(self):
@@ -1591,14 +1582,12 @@ class UnixSocketHandlerTest(SocketHandlerTest):
support.unlink(self.address)
@unittest.skipIf(True, "FIXME: bpo-30830")
-@unittest.skipUnless(threading, 'Threading required for this test.')
class DatagramHandlerTest(BaseTest):
"""Test for DatagramHandler."""
- if threading:
- server_class = TestUDPServer
- address = ('localhost', 0)
+ server_class = TestUDPServer
+ address = ('localhost', 0)
def setUp(self):
"""Set up a UDP server to receive log messages, and a DatagramHandler
@@ -1659,12 +1648,11 @@ class DatagramHandlerTest(BaseTest):
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required")
-@unittest.skipUnless(threading, 'Threading required for this test.')
class UnixDatagramHandlerTest(DatagramHandlerTest):
"""Test for DatagramHandler using Unix sockets."""
- if threading and hasattr(socket, "AF_UNIX"):
+ if hasattr(socket, "AF_UNIX"):
server_class = TestUnixDatagramServer
def setUp(self):
@@ -1676,14 +1664,12 @@ class UnixDatagramHandlerTest(DatagramHandlerTest):
DatagramHandlerTest.tearDown(self)
support.unlink(self.address)
-@unittest.skipUnless(threading, 'Threading required for this test.')
class SysLogHandlerTest(BaseTest):
"""Test for SysLogHandler using UDP."""
- if threading:
- server_class = TestUDPServer
- address = ('localhost', 0)
+ server_class = TestUDPServer
+ address = ('localhost', 0)
def setUp(self):
"""Set up a UDP server to receive log messages, and a SysLogHandler
@@ -1747,12 +1733,11 @@ class SysLogHandlerTest(BaseTest):
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required")
-@unittest.skipUnless(threading, 'Threading required for this test.')
class UnixSysLogHandlerTest(SysLogHandlerTest):
"""Test for SysLogHandler with Unix sockets."""
- if threading and hasattr(socket, "AF_UNIX"):
+ if hasattr(socket, "AF_UNIX"):
server_class = TestUnixDatagramServer
def setUp(self):
@@ -1767,7 +1752,6 @@ class UnixSysLogHandlerTest(SysLogHandlerTest):
@unittest.skipIf(True, "FIXME: bpo-30830")
@unittest.skipUnless(support.IPV6_ENABLED,
'IPv6 support required for this test.')
-@unittest.skipUnless(threading, 'Threading required for this test.')
class IPv6SysLogHandlerTest(SysLogHandlerTest):
"""Test for SysLogHandler with IPv6 host."""
@@ -1783,7 +1767,6 @@ class IPv6SysLogHandlerTest(SysLogHandlerTest):
self.server_class.address_family = socket.AF_INET
super(IPv6SysLogHandlerTest, self).tearDown()
-@unittest.skipUnless(threading, 'Threading required for this test.')
class HTTPHandlerTest(BaseTest):
"""Test for HTTPHandler."""
@@ -2892,7 +2875,6 @@ class ConfigDictTest(BaseTest):
# listen() uses ConfigSocketReceiver which is based
# on socketserver.ThreadingTCPServer
@unittest.skipIf(True, "FIXME: bpo-30830")
- @unittest.skipUnless(threading, 'listen() needs threading to work')
def setup_via_listener(self, text, verify=None):
text = text.encode("utf-8")
# Ask for a randomly assigned port (by using port 0)
@@ -2923,7 +2905,6 @@ class ConfigDictTest(BaseTest):
if t.is_alive():
self.fail("join() timed out")
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_listen_config_10_ok(self):
with support.captured_stdout() as output:
self.setup_via_listener(json.dumps(self.config10))
@@ -2943,7 +2924,6 @@ class ConfigDictTest(BaseTest):
('ERROR', '4'),
], stream=output)
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_listen_config_1_ok(self):
with support.captured_stdout() as output:
self.setup_via_listener(textwrap.dedent(ConfigFileTest.config1))
@@ -2958,7 +2938,6 @@ class ConfigDictTest(BaseTest):
# Original logger output is empty.
self.assert_log_lines([])
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_listen_verify(self):
def verify_fail(stuff):
@@ -3713,9 +3692,8 @@ class LogRecordTest(BaseTest):
def test_optional(self):
r = logging.makeLogRecord({})
NOT_NONE = self.assertIsNotNone
- if threading:
- NOT_NONE(r.thread)
- NOT_NONE(r.threadName)
+ NOT_NONE(r.thread)
+ NOT_NONE(r.threadName)
NOT_NONE(r.process)
NOT_NONE(r.processName)
log_threads = logging.logThreads
diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py
index e2cd36a..bb780bf 100644
--- a/Lib/test/test_nntplib.py
+++ b/Lib/test/test_nntplib.py
@@ -6,6 +6,8 @@ import unittest
import functools
import contextlib
import os.path
+import threading
+
from test import support
from nntplib import NNTP, GroupInfo
import nntplib
@@ -14,10 +16,7 @@ try:
import ssl
except ImportError:
ssl = None
-try:
- import threading
-except ImportError:
- threading = None
+
TIMEOUT = 30
certfile = os.path.join(os.path.dirname(__file__), 'keycert3.pem')
@@ -1520,7 +1519,7 @@ class MockSslTests(MockSocketTests):
def nntp_class(*pos, **kw):
return nntplib.NNTP_SSL(*pos, ssl_context=bypass_context, **kw)
-@unittest.skipUnless(threading, 'requires multithreading')
+
class LocalServerTests(unittest.TestCase):
def setUp(self):
sock = socket.socket()
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index 234f701..c3c8238 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -22,15 +22,13 @@ import stat
import subprocess
import sys
import sysconfig
+import threading
import time
import unittest
import uuid
import warnings
from test import support
-try:
- import threading
-except ImportError:
- threading = None
+
try:
import resource
except ImportError:
@@ -2516,92 +2514,90 @@ class ProgramPriorityTests(unittest.TestCase):
raise
-if threading is not None:
- class SendfileTestServer(asyncore.dispatcher, threading.Thread):
-
- class Handler(asynchat.async_chat):
+class SendfileTestServer(asyncore.dispatcher, threading.Thread):
- def __init__(self, conn):
- asynchat.async_chat.__init__(self, conn)
- self.in_buffer = []
- self.closed = False
- self.push(b"220 ready\r\n")
+ class Handler(asynchat.async_chat):
- def handle_read(self):
- data = self.recv(4096)
- self.in_buffer.append(data)
+ def __init__(self, conn):
+ asynchat.async_chat.__init__(self, conn)
+ self.in_buffer = []
+ self.closed = False
+ self.push(b"220 ready\r\n")
- def get_data(self):
- return b''.join(self.in_buffer)
+ def handle_read(self):
+ data = self.recv(4096)
+ self.in_buffer.append(data)
- def handle_close(self):
- self.close()
- self.closed = True
+ def get_data(self):
+ return b''.join(self.in_buffer)
- def handle_error(self):
- raise
-
- def __init__(self, address):
- threading.Thread.__init__(self)
- asyncore.dispatcher.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- self.bind(address)
- self.listen(5)
- self.host, self.port = self.socket.getsockname()[:2]
- self.handler_instance = None
- self._active = False
- self._active_lock = threading.Lock()
-
- # --- public API
-
- @property
- def running(self):
- return self._active
-
- def start(self):
- assert not self.running
- self.__flag = threading.Event()
- threading.Thread.start(self)
- self.__flag.wait()
-
- def stop(self):
- assert self.running
- self._active = False
- self.join()
-
- def wait(self):
- # wait for handler connection to be closed, then stop the server
- while not getattr(self.handler_instance, "closed", False):
- time.sleep(0.001)
- self.stop()
-
- # --- internals
-
- def run(self):
- self._active = True
- self.__flag.set()
- while self._active and asyncore.socket_map:
- self._active_lock.acquire()
- asyncore.loop(timeout=0.001, count=1)
- self._active_lock.release()
- asyncore.close_all()
-
- def handle_accept(self):
- conn, addr = self.accept()
- self.handler_instance = self.Handler(conn)
-
- def handle_connect(self):
+ def handle_close(self):
self.close()
- handle_read = handle_connect
-
- def writable(self):
- return 0
+ self.closed = True
def handle_error(self):
raise
+ def __init__(self, address):
+ threading.Thread.__init__(self)
+ asyncore.dispatcher.__init__(self)
+ self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.bind(address)
+ self.listen(5)
+ self.host, self.port = self.socket.getsockname()[:2]
+ self.handler_instance = None
+ self._active = False
+ self._active_lock = threading.Lock()
+
+ # --- public API
+
+ @property
+ def running(self):
+ return self._active
+
+ def start(self):
+ assert not self.running
+ self.__flag = threading.Event()
+ threading.Thread.start(self)
+ self.__flag.wait()
+
+ def stop(self):
+ assert self.running
+ self._active = False
+ self.join()
+
+ def wait(self):
+ # wait for handler connection to be closed, then stop the server
+ while not getattr(self.handler_instance, "closed", False):
+ time.sleep(0.001)
+ self.stop()
+
+ # --- internals
+
+ def run(self):
+ self._active = True
+ self.__flag.set()
+ while self._active and asyncore.socket_map:
+ self._active_lock.acquire()
+ asyncore.loop(timeout=0.001, count=1)
+ self._active_lock.release()
+ asyncore.close_all()
+
+ def handle_accept(self):
+ conn, addr = self.accept()
+ self.handler_instance = self.Handler(conn)
+
+ def handle_connect(self):
+ self.close()
+ handle_read = handle_connect
+
+ def writable(self):
+ return 0
+
+ def handle_error(self):
+ raise
+
-@unittest.skipUnless(threading is not None, "test needs threading module")
@unittest.skipUnless(hasattr(os, 'sendfile'), "test needs os.sendfile()")
class TestSendfile(unittest.TestCase):
diff --git a/Lib/test/test_pdb.py b/Lib/test/test_pdb.py
index 0ea2af5..755d265 100644
--- a/Lib/test/test_pdb.py
+++ b/Lib/test/test_pdb.py
@@ -1040,9 +1040,6 @@ class PdbTestCase(unittest.TestCase):
# invoking "continue" on a non-main thread triggered an exception
# inside signal.signal
- # raises SkipTest if python was built without threads
- support.import_module('threading')
-
with open(support.TESTFN, 'wb') as f:
f.write(textwrap.dedent("""
import threading
diff --git a/Lib/test/test_poll.py b/Lib/test/test_poll.py
index 6a2bf6e..16c2d2e 100644
--- a/Lib/test/test_poll.py
+++ b/Lib/test/test_poll.py
@@ -4,10 +4,7 @@ import os
import subprocess
import random
import select
-try:
- import threading
-except ImportError:
- threading = None
+import threading
import time
import unittest
from test.support import TESTFN, run_unittest, reap_threads, cpython_only
@@ -179,7 +176,6 @@ class PollTests(unittest.TestCase):
self.assertRaises(OverflowError, pollster.poll, INT_MAX + 1)
self.assertRaises(OverflowError, pollster.poll, UINT_MAX + 1)
- @unittest.skipUnless(threading, 'Threading required for this test.')
@reap_threads
def test_threaded_poll(self):
r, w = os.pipe()
diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py
index e5b16dc..92febbf 100644
--- a/Lib/test/test_poplib.py
+++ b/Lib/test/test_poplib.py
@@ -9,10 +9,10 @@ import asynchat
import socket
import os
import errno
+import threading
from unittest import TestCase, skipUnless
from test import support as test_support
-threading = test_support.import_module('threading')
HOST = test_support.HOST
PORT = 0
diff --git a/Lib/test/test_pydoc.py b/Lib/test/test_pydoc.py
index d68ab55..1ac08ed 100644
--- a/Lib/test/test_pydoc.py
+++ b/Lib/test/test_pydoc.py
@@ -20,6 +20,7 @@ import urllib.parse
import xml.etree
import xml.etree.ElementTree
import textwrap
+import threading
from io import StringIO
from collections import namedtuple
from test.support.script_helper import assert_python_ok
@@ -30,10 +31,6 @@ from test.support import (
)
from test import pydoc_mod
-try:
- import threading
-except ImportError:
- threading = None
class nonascii:
'Це не латиниця'
@@ -902,7 +899,6 @@ class TestDescriptions(unittest.TestCase):
"stat(path, *, dir_fd=None, follow_symlinks=True)")
-@unittest.skipUnless(threading, 'Threading required for this test.')
class PydocServerTest(unittest.TestCase):
"""Tests for pydoc._start_server"""
diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py
index 4ccaa39..718ed67 100644
--- a/Lib/test/test_queue.py
+++ b/Lib/test/test_queue.py
@@ -1,10 +1,11 @@
# Some simple queue module tests, plus some failure conditions
# to ensure the Queue locks remain stable.
import queue
+import threading
import time
import unittest
from test import support
-threading = support.import_module('threading')
+
QUEUE_SIZE = 5
diff --git a/Lib/test/test_regrtest.py b/Lib/test/test_regrtest.py
index b756839..8364767 100644
--- a/Lib/test/test_regrtest.py
+++ b/Lib/test/test_regrtest.py
@@ -15,6 +15,7 @@ import sys
import sysconfig
import tempfile
import textwrap
+import threading
import unittest
from test import libregrtest
from test import support
@@ -741,12 +742,7 @@ class ArgsTestCase(BaseTestCase):
code = TEST_INTERRUPTED
test = self.create_test("sigint", code=code)
- try:
- import threading
- tests = (False, True)
- except ImportError:
- tests = (False,)
- for multiprocessing in tests:
+ for multiprocessing in (False, True):
if multiprocessing:
args = ("--slowest", "-j2", test)
else:
diff --git a/Lib/test/test_robotparser.py b/Lib/test/test_robotparser.py
index 0f64ba8..5c1a571 100644
--- a/Lib/test/test_robotparser.py
+++ b/Lib/test/test_robotparser.py
@@ -1,14 +1,11 @@
import io
import os
+import threading
import unittest
import urllib.robotparser
from collections import namedtuple
from test import support
from http.server import BaseHTTPRequestHandler, HTTPServer
-try:
- import threading
-except ImportError:
- threading = None
class BaseRobotTest:
@@ -255,7 +252,6 @@ class RobotHandler(BaseHTTPRequestHandler):
pass
-@unittest.skipUnless(threading, 'threading required for this test')
class PasswordProtectedSiteTestCase(unittest.TestCase):
def setUp(self):
diff --git a/Lib/test/test_sched.py b/Lib/test/test_sched.py
index ebf8856..794c637 100644
--- a/Lib/test/test_sched.py
+++ b/Lib/test/test_sched.py
@@ -1,11 +1,9 @@
import queue
import sched
+import threading
import time
import unittest
-try:
- import threading
-except ImportError:
- threading = None
+
TIMEOUT = 10
@@ -58,7 +56,6 @@ class TestCase(unittest.TestCase):
scheduler.run()
self.assertEqual(l, [0.01, 0.02, 0.03, 0.04, 0.05])
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_enter_concurrent(self):
q = queue.Queue()
fun = q.put
@@ -113,7 +110,6 @@ class TestCase(unittest.TestCase):
scheduler.run()
self.assertEqual(l, [0.02, 0.03, 0.04])
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_cancel_concurrent(self):
q = queue.Queue()
fun = q.put
diff --git a/Lib/test/test_signal.py b/Lib/test/test_signal.py
index 0e1d067..dc048e5 100644
--- a/Lib/test/test_signal.py
+++ b/Lib/test/test_signal.py
@@ -5,15 +5,12 @@ import socket
import statistics
import subprocess
import sys
+import threading
import time
import unittest
from test import support
from test.support.script_helper import assert_python_ok, spawn_python
try:
- import threading
-except ImportError:
- threading = None
-try:
import _testcapi
except ImportError:
_testcapi = None
@@ -21,7 +18,6 @@ except ImportError:
class GenericTests(unittest.TestCase):
- @unittest.skipIf(threading is None, "test needs threading module")
def test_enums(self):
for name in dir(signal):
sig = getattr(signal, name)
@@ -807,7 +803,6 @@ class PendingSignalsTests(unittest.TestCase):
'need signal.sigwait()')
@unittest.skipUnless(hasattr(signal, 'pthread_sigmask'),
'need signal.pthread_sigmask()')
- @unittest.skipIf(threading is None, "test needs threading module")
def test_sigwait_thread(self):
# Check that calling sigwait() from a thread doesn't suspend the whole
# process. A new interpreter is spawned to avoid problems when mixing
diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py
index 28539f3..42f4266 100644
--- a/Lib/test/test_smtplib.py
+++ b/Lib/test/test_smtplib.py
@@ -15,14 +15,11 @@ import time
import select
import errno
import textwrap
+import threading
import unittest
from test import support, mock_socket
-try:
- import threading
-except ImportError:
- threading = None
HOST = support.HOST
@@ -191,7 +188,6 @@ MSG_END = '------------ END MESSAGE ------------\n'
# test server times out, causing the test to fail.
# Test behavior of smtpd.DebuggingServer
-@unittest.skipUnless(threading, 'Threading required for this test.')
class DebuggingServerTests(unittest.TestCase):
maxDiff = None
@@ -570,7 +566,6 @@ class NonConnectingTests(unittest.TestCase):
# test response of client to a non-successful HELO message
-@unittest.skipUnless(threading, 'Threading required for this test.')
class BadHELOServerTests(unittest.TestCase):
def setUp(self):
@@ -590,7 +585,6 @@ class BadHELOServerTests(unittest.TestCase):
HOST, self.port, 'localhost', 3)
-@unittest.skipUnless(threading, 'Threading required for this test.')
class TooLongLineTests(unittest.TestCase):
respdata = b'250 OK' + (b'.' * smtplib._MAXLINE * 2) + b'\n'
@@ -835,7 +829,6 @@ class SimSMTPServer(smtpd.SMTPServer):
# Test various SMTP & ESMTP commands/behaviors that require a simulated server
# (i.e., something with more features than DebuggingServer)
-@unittest.skipUnless(threading, 'Threading required for this test.')
class SMTPSimTests(unittest.TestCase):
def setUp(self):
@@ -1091,7 +1084,6 @@ class SimSMTPUTF8Server(SimSMTPServer):
self.last_rcpt_options = rcpt_options
-@unittest.skipUnless(threading, 'Threading required for this test.')
class SMTPUTF8SimTests(unittest.TestCase):
maxDiff = None
@@ -1227,7 +1219,6 @@ class SimSMTPAUTHInitialResponseServer(SimSMTPServer):
channel_class = SimSMTPAUTHInitialResponseChannel
-@unittest.skipUnless(threading, 'Threading required for this test.')
class SMTPAUTHInitialResponseSimTests(unittest.TestCase):
def setUp(self):
self.real_getfqdn = socket.getfqdn
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 50016ab..27d9d49 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -21,6 +21,8 @@ import pickle
import struct
import random
import string
+import _thread as thread
+import threading
try:
import multiprocessing
except ImportError:
@@ -36,12 +38,6 @@ MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') ## test unicode string
VSOCKPORT = 1234
try:
- import _thread as thread
- import threading
-except ImportError:
- thread = None
- threading = None
-try:
import _socket
except ImportError:
_socket = None
@@ -143,18 +139,17 @@ class ThreadSafeCleanupTestCase(unittest.TestCase):
with a recursive lock.
"""
- if threading:
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._cleanup_lock = threading.RLock()
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._cleanup_lock = threading.RLock()
- def addCleanup(self, *args, **kwargs):
- with self._cleanup_lock:
- return super().addCleanup(*args, **kwargs)
+ def addCleanup(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().addCleanup(*args, **kwargs)
- def doCleanups(self, *args, **kwargs):
- with self._cleanup_lock:
- return super().doCleanups(*args, **kwargs)
+ def doCleanups(self, *args, **kwargs):
+ with self._cleanup_lock:
+ return super().doCleanups(*args, **kwargs)
class SocketCANTest(unittest.TestCase):
@@ -407,7 +402,6 @@ class ThreadedRDSSocketTest(SocketRDSTest, ThreadableTest):
ThreadableTest.clientTearDown(self)
@unittest.skipIf(fcntl is None, "need fcntl")
-@unittest.skipUnless(thread, 'Threading required for this test.')
@unittest.skipUnless(HAVE_SOCKET_VSOCK,
'VSOCK sockets required for this test.')
@unittest.skipUnless(get_cid() != 2,
@@ -1684,7 +1678,6 @@ class BasicCANTest(unittest.TestCase):
@unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.')
-@unittest.skipUnless(thread, 'Threading required for this test.')
class CANTest(ThreadedCANSocketTest):
def __init__(self, methodName='runTest'):
@@ -1838,7 +1831,6 @@ class BasicRDSTest(unittest.TestCase):
@unittest.skipUnless(HAVE_SOCKET_RDS, 'RDS sockets required for this test.')
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RDSTest(ThreadedRDSSocketTest):
def __init__(self, methodName='runTest'):
@@ -1977,7 +1969,7 @@ class BasicVSOCKTest(unittest.TestCase):
s.getsockopt(socket.AF_VSOCK,
socket.SO_VM_SOCKETS_BUFFER_MIN_SIZE))
-@unittest.skipUnless(thread, 'Threading required for this test.')
+
class BasicTCPTest(SocketConnectedTest):
def __init__(self, methodName='runTest'):
@@ -2100,7 +2092,7 @@ class BasicTCPTest(SocketConnectedTest):
def _testDetach(self):
self.serv_conn.send(MSG)
-@unittest.skipUnless(thread, 'Threading required for this test.')
+
class BasicUDPTest(ThreadedUDPSocketTest):
def __init__(self, methodName='runTest'):
@@ -3697,17 +3689,14 @@ class SendrecvmsgUDPTestBase(SendrecvmsgDgramFlagsBase,
pass
@requireAttrs(socket.socket, "sendmsg")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class SendmsgUDPTest(SendmsgConnectionlessTests, SendrecvmsgUDPTestBase):
pass
@requireAttrs(socket.socket, "recvmsg")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgUDPTest(RecvmsgTests, SendrecvmsgUDPTestBase):
pass
@requireAttrs(socket.socket, "recvmsg_into")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgIntoUDPTest(RecvmsgIntoTests, SendrecvmsgUDPTestBase):
pass
@@ -3724,21 +3713,18 @@ class SendrecvmsgUDP6TestBase(SendrecvmsgDgramFlagsBase,
@requireAttrs(socket.socket, "sendmsg")
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class SendmsgUDP6Test(SendmsgConnectionlessTests, SendrecvmsgUDP6TestBase):
pass
@requireAttrs(socket.socket, "recvmsg")
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgUDP6Test(RecvmsgTests, SendrecvmsgUDP6TestBase):
pass
@requireAttrs(socket.socket, "recvmsg_into")
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
@requireSocket("AF_INET6", "SOCK_DGRAM")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase):
pass
@@ -3746,7 +3732,6 @@ class RecvmsgIntoUDP6Test(RecvmsgIntoTests, SendrecvmsgUDP6TestBase):
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest,
SendrecvmsgUDP6TestBase):
pass
@@ -3755,7 +3740,6 @@ class RecvmsgRFC3542AncillaryUDP6Test(RFC3542AncillaryTest,
@unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 required for this test.')
@requireAttrs(socket, "IPPROTO_IPV6")
@requireSocket("AF_INET6", "SOCK_DGRAM")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgIntoRFC3542AncillaryUDP6Test(RecvmsgIntoMixin,
RFC3542AncillaryTest,
SendrecvmsgUDP6TestBase):
@@ -3767,18 +3751,15 @@ class SendrecvmsgTCPTestBase(SendrecvmsgConnectedBase,
pass
@requireAttrs(socket.socket, "sendmsg")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class SendmsgTCPTest(SendmsgStreamTests, SendrecvmsgTCPTestBase):
pass
@requireAttrs(socket.socket, "recvmsg")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgTCPTest(RecvmsgTests, RecvmsgGenericStreamTests,
SendrecvmsgTCPTestBase):
pass
@requireAttrs(socket.socket, "recvmsg_into")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgIntoTCPTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
SendrecvmsgTCPTestBase):
pass
@@ -3791,13 +3772,11 @@ class SendrecvmsgSCTPStreamTestBase(SendrecvmsgSCTPFlagsBase,
@requireAttrs(socket.socket, "sendmsg")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class SendmsgSCTPStreamTest(SendmsgStreamTests, SendrecvmsgSCTPStreamTestBase):
pass
@requireAttrs(socket.socket, "recvmsg")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
SendrecvmsgSCTPStreamTestBase):
@@ -3811,7 +3790,6 @@ class RecvmsgSCTPStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
@requireAttrs(socket.socket, "recvmsg_into")
@requireSocket("AF_INET", "SOCK_STREAM", "IPPROTO_SCTP")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgIntoSCTPStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
SendrecvmsgSCTPStreamTestBase):
@@ -3830,33 +3808,28 @@ class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase,
@requireAttrs(socket.socket, "sendmsg")
@requireAttrs(socket, "AF_UNIX")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase):
pass
@requireAttrs(socket.socket, "recvmsg")
@requireAttrs(socket, "AF_UNIX")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgUnixStreamTest(RecvmsgTests, RecvmsgGenericStreamTests,
SendrecvmsgUnixStreamTestBase):
pass
@requireAttrs(socket.socket, "recvmsg_into")
@requireAttrs(socket, "AF_UNIX")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgIntoUnixStreamTest(RecvmsgIntoTests, RecvmsgGenericStreamTests,
SendrecvmsgUnixStreamTestBase):
pass
@requireAttrs(socket.socket, "sendmsg", "recvmsg")
@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgSCMRightsStreamTest(SCMRightsTest, SendrecvmsgUnixStreamTestBase):
pass
@requireAttrs(socket.socket, "sendmsg", "recvmsg_into")
@requireAttrs(socket, "AF_UNIX", "SOL_SOCKET", "SCM_RIGHTS")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class RecvmsgIntoSCMRightsStreamTest(RecvmsgIntoMixin, SCMRightsTest,
SendrecvmsgUnixStreamTestBase):
pass
@@ -3944,7 +3917,6 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):
@requireAttrs(signal, "siginterrupt")
@unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"),
"Don't have signal.alarm or signal.setitimer")
-@unittest.skipUnless(thread, 'Threading required for this test.')
class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
ThreadSafeCleanupTestCase,
SocketListeningTestMixin, TCPTestBase):
@@ -3997,7 +3969,6 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase,
self.checkInterruptedSend(self.serv_conn.sendmsg, [b"a"*512])
-@unittest.skipUnless(thread, 'Threading required for this test.')
class TCPCloserTest(ThreadedTCPSocketTest):
def testClose(self):
@@ -4017,7 +3988,7 @@ class TCPCloserTest(ThreadedTCPSocketTest):
self.cli.connect((HOST, self.port))
time.sleep(1.0)
-@unittest.skipUnless(thread, 'Threading required for this test.')
+
class BasicSocketPairTest(SocketPairTest):
def __init__(self, methodName='runTest'):
@@ -4052,7 +4023,7 @@ class BasicSocketPairTest(SocketPairTest):
msg = self.cli.recv(1024)
self.assertEqual(msg, MSG)
-@unittest.skipUnless(thread, 'Threading required for this test.')
+
class NonBlockingTCPTests(ThreadedTCPSocketTest):
def __init__(self, methodName='runTest'):
@@ -4180,7 +4151,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):
time.sleep(0.1)
self.cli.send(MSG)
-@unittest.skipUnless(thread, 'Threading required for this test.')
+
class FileObjectClassTestCase(SocketConnectedTest):
"""Unit tests for the object returned by socket.makefile()
@@ -4564,7 +4535,6 @@ class NetworkConnectionNoServer(unittest.TestCase):
socket.create_connection((HOST, 1234))
-@unittest.skipUnless(thread, 'Threading required for this test.')
class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
def __init__(self, methodName='runTest'):
@@ -4633,7 +4603,7 @@ class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
self.addCleanup(self.cli.close)
self.assertEqual(self.cli.gettimeout(), 30)
-@unittest.skipUnless(thread, 'Threading required for this test.')
+
class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
def __init__(self, methodName='runTest'):
@@ -4877,7 +4847,7 @@ class TestUnixDomain(unittest.TestCase):
self.addCleanup(support.unlink, path)
self.assertEqual(self.sock.getsockname(), path)
-@unittest.skipUnless(thread, 'Threading required for this test.')
+
class BufferIOTest(SocketConnectedTest):
"""
Test the buffer versions of socket.recv() and socket.send().
@@ -5050,7 +5020,6 @@ class TIPCThreadableTest(unittest.TestCase, ThreadableTest):
self.cli.close()
-@unittest.skipUnless(thread, 'Threading required for this test.')
class ContextManagersTest(ThreadedTCPSocketTest):
def _testSocketClass(self):
@@ -5312,7 +5281,6 @@ class TestSocketSharing(SocketTCPTest):
source.close()
-@unittest.skipUnless(thread, 'Threading required for this test.')
class SendfileUsingSendTest(ThreadedTCPSocketTest):
"""
Test the send() implementation of socket.sendfile().
@@ -5570,7 +5538,6 @@ class SendfileUsingSendTest(ThreadedTCPSocketTest):
meth, file, count=-1)
-@unittest.skipUnless(thread, 'Threading required for this test.')
@unittest.skipUnless(hasattr(os, "sendfile"),
'os.sendfile() required for this test.')
class SendfileUsingSendfileTest(SendfileUsingSendTest):
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 3d93566..a23373f 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -9,15 +9,13 @@ import select
import signal
import socket
import tempfile
+import threading
import unittest
import socketserver
import test.support
from test.support import reap_children, reap_threads, verbose
-try:
- import threading
-except ImportError:
- threading = None
+
test.support.requires("network")
@@ -68,7 +66,6 @@ def simple_subprocess(testcase):
testcase.assertEqual(72 << 8, status)
-@unittest.skipUnless(threading, 'Threading required for this test.')
class SocketServerTest(unittest.TestCase):
"""Test all socket servers."""
@@ -306,12 +303,10 @@ class ErrorHandlerTest(unittest.TestCase):
BaseErrorTestServer(SystemExit)
self.check_result(handled=False)
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_threading_handled(self):
ThreadingErrorTestServer(ValueError)
self.check_result(handled=True)
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_threading_not_handled(self):
ThreadingErrorTestServer(SystemExit)
self.check_result(handled=False)
@@ -396,7 +391,6 @@ class SocketWriterTest(unittest.TestCase):
self.assertIsInstance(server.wfile, io.BufferedIOBase)
self.assertEqual(server.wfile_fileno, server.request_fileno)
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_write(self):
# Test that wfile.write() sends data immediately, and that it does
# not truncate sends when interrupted by a Unix signal
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 16cad9d..89b4609 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -12,6 +12,7 @@ import os
import errno
import pprint
import urllib.request
+import threading
import traceback
import asyncore
import weakref
@@ -20,12 +21,6 @@ import functools
ssl = support.import_module("ssl")
-try:
- import threading
-except ImportError:
- _have_threads = False
-else:
- _have_threads = True
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
@@ -1468,7 +1463,6 @@ class MemoryBIOTests(unittest.TestCase):
self.assertRaises(TypeError, bio.write, 1)
-@unittest.skipUnless(_have_threads, "Needs threading module")
class SimpleBackgroundTests(unittest.TestCase):
"""Tests that connect to a simple server running in the background"""
@@ -1828,1744 +1822,1743 @@ def _test_get_server_certificate_fail(test, host, port):
test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
-if _have_threads:
- from test.ssl_servers import make_https_server
+from test.ssl_servers import make_https_server
- class ThreadedEchoServer(threading.Thread):
+class ThreadedEchoServer(threading.Thread):
- class ConnectionHandler(threading.Thread):
+ class ConnectionHandler(threading.Thread):
- """A mildly complicated class, because we want it to work both
- with and without the SSL wrapper around the socket connection, so
- that we can test the STARTTLS functionality."""
+ """A mildly complicated class, because we want it to work both
+ with and without the SSL wrapper around the socket connection, so
+ that we can test the STARTTLS functionality."""
- def __init__(self, server, connsock, addr):
- self.server = server
+ def __init__(self, server, connsock, addr):
+ self.server = server
+ self.running = False
+ self.sock = connsock
+ self.addr = addr
+ self.sock.setblocking(1)
+ self.sslconn = None
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def wrap_conn(self):
+ try:
+ self.sslconn = self.server.context.wrap_socket(
+ self.sock, server_side=True)
+ self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
+ self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
+ except (ssl.SSLError, ConnectionResetError, OSError) as e:
+ # We treat ConnectionResetError as though it were an
+ # SSLError - OpenSSL on Ubuntu abruptly closes the
+ # connection when asked to use an unsupported protocol.
+ #
+ # OSError may occur with wrong protocols, e.g. both
+ # sides use PROTOCOL_TLS_SERVER.
+ #
+ # XXX Various errors can have happened here, for example
+ # a mismatching protocol version, an invalid certificate,
+ # or a low-level bug. This should be made more discriminating.
+ #
+ # bpo-31323: Store the exception as string to prevent
+ # a reference leak: server -> conn_errors -> exception
+ # -> traceback -> self (ConnectionHandler) -> server
+ self.server.conn_errors.append(str(e))
+ if self.server.chatty:
+ handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n")
self.running = False
- self.sock = connsock
- self.addr = addr
- self.sock.setblocking(1)
- self.sslconn = None
- threading.Thread.__init__(self)
- self.daemon = True
-
- def wrap_conn(self):
- try:
- self.sslconn = self.server.context.wrap_socket(
- self.sock, server_side=True)
- self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
- self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
- except (ssl.SSLError, ConnectionResetError, OSError) as e:
- # We treat ConnectionResetError as though it were an
- # SSLError - OpenSSL on Ubuntu abruptly closes the
- # connection when asked to use an unsupported protocol.
- #
- # OSError may occur with wrong protocols, e.g. both
- # sides use PROTOCOL_TLS_SERVER.
- #
- # XXX Various errors can have happened here, for example
- # a mismatching protocol version, an invalid certificate,
- # or a low-level bug. This should be made more discriminating.
- #
- # bpo-31323: Store the exception as string to prevent
- # a reference leak: server -> conn_errors -> exception
- # -> traceback -> self (ConnectionHandler) -> server
- self.server.conn_errors.append(str(e))
- if self.server.chatty:
- handle_error("\n server: bad connection attempt from " + repr(self.addr) + ":\n")
- self.running = False
- self.server.stop()
- self.close()
- return False
- else:
- self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
- if self.server.context.verify_mode == ssl.CERT_REQUIRED:
- cert = self.sslconn.getpeercert()
- if support.verbose and self.server.chatty:
- sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
- cert_binary = self.sslconn.getpeercert(True)
- if support.verbose and self.server.chatty:
- sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
- cipher = self.sslconn.cipher()
+ self.server.stop()
+ self.close()
+ return False
+ else:
+ self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
+ if self.server.context.verify_mode == ssl.CERT_REQUIRED:
+ cert = self.sslconn.getpeercert()
if support.verbose and self.server.chatty:
- sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
- sys.stdout.write(" server: selected protocol is now "
- + str(self.sslconn.selected_npn_protocol()) + "\n")
- return True
-
- def read(self):
- if self.sslconn:
- return self.sslconn.read()
- else:
- return self.sock.recv(1024)
+ sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
+ cert_binary = self.sslconn.getpeercert(True)
+ if support.verbose and self.server.chatty:
+ sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
+ cipher = self.sslconn.cipher()
+ if support.verbose and self.server.chatty:
+ sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
+ sys.stdout.write(" server: selected protocol is now "
+ + str(self.sslconn.selected_npn_protocol()) + "\n")
+ return True
+
+ def read(self):
+ if self.sslconn:
+ return self.sslconn.read()
+ else:
+ return self.sock.recv(1024)
- def write(self, bytes):
- if self.sslconn:
- return self.sslconn.write(bytes)
- else:
- return self.sock.send(bytes)
+ def write(self, bytes):
+ if self.sslconn:
+ return self.sslconn.write(bytes)
+ else:
+ return self.sock.send(bytes)
- def close(self):
- if self.sslconn:
- self.sslconn.close()
- else:
- self.sock.close()
+ def close(self):
+ if self.sslconn:
+ self.sslconn.close()
+ else:
+ self.sock.close()
- def run(self):
- self.running = True
- if not self.server.starttls_server:
- if not self.wrap_conn():
- return
- while self.running:
- try:
- msg = self.read()
- stripped = msg.strip()
- if not stripped:
- # eof, so quit this handler
- self.running = False
- try:
- self.sock = self.sslconn.unwrap()
- except OSError:
- # Many tests shut the TCP connection down
- # without an SSL shutdown. This causes
- # unwrap() to raise OSError with errno=0!
- pass
- else:
- self.sslconn = None
- self.close()
- elif stripped == b'over':
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: client closed connection\n")
- self.close()
- return
- elif (self.server.starttls_server and
- stripped == b'STARTTLS'):
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
- self.write(b"OK\n")
- if not self.wrap_conn():
- return
- elif (self.server.starttls_server and self.sslconn
- and stripped == b'ENDTLS'):
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
- self.write(b"OK\n")
+ def run(self):
+ self.running = True
+ if not self.server.starttls_server:
+ if not self.wrap_conn():
+ return
+ while self.running:
+ try:
+ msg = self.read()
+ stripped = msg.strip()
+ if not stripped:
+ # eof, so quit this handler
+ self.running = False
+ try:
self.sock = self.sslconn.unwrap()
- self.sslconn = None
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: connection is now unencrypted...\n")
- elif stripped == b'CB tls-unique':
- if support.verbose and self.server.connectionchatty:
- sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
- data = self.sslconn.get_channel_binding("tls-unique")
- self.write(repr(data).encode("us-ascii") + b"\n")
+ except OSError:
+ # Many tests shut the TCP connection down
+ # without an SSL shutdown. This causes
+ # unwrap() to raise OSError with errno=0!
+ pass
else:
- if (support.verbose and
- self.server.connectionchatty):
- ctype = (self.sslconn and "encrypted") or "unencrypted"
- sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
- % (msg, ctype, msg.lower(), ctype))
- self.write(msg.lower())
- except OSError:
- if self.server.chatty:
- handle_error("Test server failure:\n")
+ self.sslconn = None
self.close()
- self.running = False
- # normally, we'd just stop here, but for the test
- # harness, we want to stop the server
- self.server.stop()
-
- def __init__(self, certificate=None, ssl_version=None,
- certreqs=None, cacerts=None,
- chatty=True, connectionchatty=False, starttls_server=False,
- npn_protocols=None, alpn_protocols=None,
- ciphers=None, context=None):
- if context:
- self.context = context
- else:
- self.context = ssl.SSLContext(ssl_version
- if ssl_version is not None
- else ssl.PROTOCOL_TLSv1)
- self.context.verify_mode = (certreqs if certreqs is not None
- else ssl.CERT_NONE)
- if cacerts:
- self.context.load_verify_locations(cacerts)
- if certificate:
- self.context.load_cert_chain(certificate)
- if npn_protocols:
- self.context.set_npn_protocols(npn_protocols)
- if alpn_protocols:
- self.context.set_alpn_protocols(alpn_protocols)
- if ciphers:
- self.context.set_ciphers(ciphers)
- self.chatty = chatty
- self.connectionchatty = connectionchatty
- self.starttls_server = starttls_server
- self.sock = socket.socket()
- self.port = support.bind_port(self.sock)
- self.flag = None
- self.active = False
- self.selected_npn_protocols = []
- self.selected_alpn_protocols = []
- self.shared_ciphers = []
- self.conn_errors = []
- threading.Thread.__init__(self)
- self.daemon = True
-
- def __enter__(self):
- self.start(threading.Event())
- self.flag.wait()
- return self
-
- def __exit__(self, *args):
- self.stop()
- self.join()
-
- def start(self, flag=None):
- self.flag = flag
- threading.Thread.start(self)
+ elif stripped == b'over':
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: client closed connection\n")
+ self.close()
+ return
+ elif (self.server.starttls_server and
+ stripped == b'STARTTLS'):
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
+ self.write(b"OK\n")
+ if not self.wrap_conn():
+ return
+ elif (self.server.starttls_server and self.sslconn
+ and stripped == b'ENDTLS'):
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
+ self.write(b"OK\n")
+ self.sock = self.sslconn.unwrap()
+ self.sslconn = None
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: connection is now unencrypted...\n")
+ elif stripped == b'CB tls-unique':
+ if support.verbose and self.server.connectionchatty:
+ sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
+ data = self.sslconn.get_channel_binding("tls-unique")
+ self.write(repr(data).encode("us-ascii") + b"\n")
+ else:
+ if (support.verbose and
+ self.server.connectionchatty):
+ ctype = (self.sslconn and "encrypted") or "unencrypted"
+ sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
+ % (msg, ctype, msg.lower(), ctype))
+ self.write(msg.lower())
+ except OSError:
+ if self.server.chatty:
+ handle_error("Test server failure:\n")
+ self.close()
+ self.running = False
+ # normally, we'd just stop here, but for the test
+ # harness, we want to stop the server
+ self.server.stop()
- def run(self):
- self.sock.settimeout(0.05)
- self.sock.listen()
- self.active = True
- if self.flag:
- # signal an event
- self.flag.set()
- while self.active:
- try:
- newconn, connaddr = self.sock.accept()
- if support.verbose and self.chatty:
- sys.stdout.write(' server: new connection from '
- + repr(connaddr) + '\n')
- handler = self.ConnectionHandler(self, newconn, connaddr)
- handler.start()
- handler.join()
- except socket.timeout:
- pass
- except KeyboardInterrupt:
- self.stop()
- self.sock.close()
+ def __init__(self, certificate=None, ssl_version=None,
+ certreqs=None, cacerts=None,
+ chatty=True, connectionchatty=False, starttls_server=False,
+ npn_protocols=None, alpn_protocols=None,
+ ciphers=None, context=None):
+ if context:
+ self.context = context
+ else:
+ self.context = ssl.SSLContext(ssl_version
+ if ssl_version is not None
+ else ssl.PROTOCOL_TLSv1)
+ self.context.verify_mode = (certreqs if certreqs is not None
+ else ssl.CERT_NONE)
+ if cacerts:
+ self.context.load_verify_locations(cacerts)
+ if certificate:
+ self.context.load_cert_chain(certificate)
+ if npn_protocols:
+ self.context.set_npn_protocols(npn_protocols)
+ if alpn_protocols:
+ self.context.set_alpn_protocols(alpn_protocols)
+ if ciphers:
+ self.context.set_ciphers(ciphers)
+ self.chatty = chatty
+ self.connectionchatty = connectionchatty
+ self.starttls_server = starttls_server
+ self.sock = socket.socket()
+ self.port = support.bind_port(self.sock)
+ self.flag = None
+ self.active = False
+ self.selected_npn_protocols = []
+ self.selected_alpn_protocols = []
+ self.shared_ciphers = []
+ self.conn_errors = []
+ threading.Thread.__init__(self)
+ self.daemon = True
+
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+ return self
+
+ def __exit__(self, *args):
+ self.stop()
+ self.join()
+
+ def start(self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run(self):
+ self.sock.settimeout(0.05)
+ self.sock.listen()
+ self.active = True
+ if self.flag:
+ # signal an event
+ self.flag.set()
+ while self.active:
+ try:
+ newconn, connaddr = self.sock.accept()
+ if support.verbose and self.chatty:
+ sys.stdout.write(' server: new connection from '
+ + repr(connaddr) + '\n')
+ handler = self.ConnectionHandler(self, newconn, connaddr)
+ handler.start()
+ handler.join()
+ except socket.timeout:
+ pass
+ except KeyboardInterrupt:
+ self.stop()
+ self.sock.close()
- def stop(self):
- self.active = False
+ def stop(self):
+ self.active = False
- class AsyncoreEchoServer(threading.Thread):
+class AsyncoreEchoServer(threading.Thread):
- # this one's based on asyncore.dispatcher
+ # this one's based on asyncore.dispatcher
- class EchoServer (asyncore.dispatcher):
+ class EchoServer (asyncore.dispatcher):
- class ConnectionHandler(asyncore.dispatcher_with_send):
+ class ConnectionHandler(asyncore.dispatcher_with_send):
- def __init__(self, conn, certfile):
- self.socket = test_wrap_socket(conn, server_side=True,
- certfile=certfile,
- do_handshake_on_connect=False)
- asyncore.dispatcher_with_send.__init__(self, self.socket)
- self._ssl_accepting = True
- self._do_ssl_handshake()
+ def __init__(self, conn, certfile):
+ self.socket = test_wrap_socket(conn, server_side=True,
+ certfile=certfile,
+ do_handshake_on_connect=False)
+ asyncore.dispatcher_with_send.__init__(self, self.socket)
+ self._ssl_accepting = True
+ self._do_ssl_handshake()
- def readable(self):
- if isinstance(self.socket, ssl.SSLSocket):
- while self.socket.pending() > 0:
- self.handle_read_event()
- return True
+ def readable(self):
+ if isinstance(self.socket, ssl.SSLSocket):
+ while self.socket.pending() > 0:
+ self.handle_read_event()
+ return True
- def _do_ssl_handshake(self):
- try:
- self.socket.do_handshake()
- except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
- return
- except ssl.SSLEOFError:
+ def _do_ssl_handshake(self):
+ try:
+ self.socket.do_handshake()
+ except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
+ return
+ except ssl.SSLEOFError:
+ return self.handle_close()
+ except ssl.SSLError:
+ raise
+ except OSError as err:
+ if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
- except ssl.SSLError:
- raise
- except OSError as err:
- if err.args[0] == errno.ECONNABORTED:
- return self.handle_close()
- else:
- self._ssl_accepting = False
-
- def handle_read(self):
- if self._ssl_accepting:
- self._do_ssl_handshake()
- else:
- data = self.recv(1024)
- if support.verbose:
- sys.stdout.write(" server: read %s from client\n" % repr(data))
- if not data:
- self.close()
- else:
- self.send(data.lower())
+ else:
+ self._ssl_accepting = False
- def handle_close(self):
- self.close()
+ def handle_read(self):
+ if self._ssl_accepting:
+ self._do_ssl_handshake()
+ else:
+ data = self.recv(1024)
if support.verbose:
- sys.stdout.write(" server: closed connection %s\n" % self.socket)
-
- def handle_error(self):
- raise
-
- def __init__(self, certfile):
- self.certfile = certfile
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.port = support.bind_port(sock, '')
- asyncore.dispatcher.__init__(self, sock)
- self.listen(5)
+ sys.stdout.write(" server: read %s from client\n" % repr(data))
+ if not data:
+ self.close()
+ else:
+ self.send(data.lower())
- def handle_accepted(self, sock_obj, addr):
+ def handle_close(self):
+ self.close()
if support.verbose:
- sys.stdout.write(" server: new connection from %s:%s\n" %addr)
- self.ConnectionHandler(sock_obj, self.certfile)
+ sys.stdout.write(" server: closed connection %s\n" % self.socket)
def handle_error(self):
raise
def __init__(self, certfile):
- self.flag = None
- self.active = False
- self.server = self.EchoServer(certfile)
- self.port = self.server.port
- threading.Thread.__init__(self)
- self.daemon = True
+ self.certfile = certfile
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.port = support.bind_port(sock, '')
+ asyncore.dispatcher.__init__(self, sock)
+ self.listen(5)
- def __str__(self):
- return "<%s %s>" % (self.__class__.__name__, self.server)
+ def handle_accepted(self, sock_obj, addr):
+ if support.verbose:
+ sys.stdout.write(" server: new connection from %s:%s\n" %addr)
+ self.ConnectionHandler(sock_obj, self.certfile)
- def __enter__(self):
- self.start(threading.Event())
- self.flag.wait()
- return self
+ def handle_error(self):
+ raise
- def __exit__(self, *args):
- if support.verbose:
- sys.stdout.write(" cleanup: stopping server.\n")
- self.stop()
- if support.verbose:
- sys.stdout.write(" cleanup: joining server thread.\n")
- self.join()
- if support.verbose:
- sys.stdout.write(" cleanup: successfully joined.\n")
- # make sure that ConnectionHandler is removed from socket_map
- asyncore.close_all(ignore_all=True)
+ def __init__(self, certfile):
+ self.flag = None
+ self.active = False
+ self.server = self.EchoServer(certfile)
+ self.port = self.server.port
+ threading.Thread.__init__(self)
+ self.daemon = True
- def start (self, flag=None):
- self.flag = flag
- threading.Thread.start(self)
+ def __str__(self):
+ return "<%s %s>" % (self.__class__.__name__, self.server)
- def run(self):
- self.active = True
- if self.flag:
- self.flag.set()
- while self.active:
- try:
- asyncore.loop(1)
- except:
- pass
+ def __enter__(self):
+ self.start(threading.Event())
+ self.flag.wait()
+ return self
- def stop(self):
- self.active = False
- self.server.close()
+ def __exit__(self, *args):
+ if support.verbose:
+ sys.stdout.write(" cleanup: stopping server.\n")
+ self.stop()
+ if support.verbose:
+ sys.stdout.write(" cleanup: joining server thread.\n")
+ self.join()
+ if support.verbose:
+ sys.stdout.write(" cleanup: successfully joined.\n")
+ # make sure that ConnectionHandler is removed from socket_map
+ asyncore.close_all(ignore_all=True)
+
+ def start (self, flag=None):
+ self.flag = flag
+ threading.Thread.start(self)
+
+ def run(self):
+ self.active = True
+ if self.flag:
+ self.flag.set()
+ while self.active:
+ try:
+ asyncore.loop(1)
+ except:
+ pass
- def server_params_test(client_context, server_context, indata=b"FOO\n",
- chatty=True, connectionchatty=False, sni_name=None,
- session=None):
- """
- Launch a server, connect a client to it and try various reads
- and writes.
- """
- stats = {}
- server = ThreadedEchoServer(context=server_context,
- chatty=chatty,
- connectionchatty=False)
- with server:
- with client_context.wrap_socket(socket.socket(),
- server_hostname=sni_name, session=session) as s:
- s.connect((HOST, server.port))
- for arg in [indata, bytearray(indata), memoryview(indata)]:
- if connectionchatty:
- if support.verbose:
- sys.stdout.write(
- " client: sending %r...\n" % indata)
- s.write(arg)
- outdata = s.read()
- if connectionchatty:
- if support.verbose:
- sys.stdout.write(" client: read %r\n" % outdata)
- if outdata != indata.lower():
- raise AssertionError(
- "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
- % (outdata[:20], len(outdata),
- indata[:20].lower(), len(indata)))
- s.write(b"over\n")
+ def stop(self):
+ self.active = False
+ self.server.close()
+
+def server_params_test(client_context, server_context, indata=b"FOO\n",
+ chatty=True, connectionchatty=False, sni_name=None,
+ session=None):
+ """
+ Launch a server, connect a client to it and try various reads
+ and writes.
+ """
+ stats = {}
+ server = ThreadedEchoServer(context=server_context,
+ chatty=chatty,
+ connectionchatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=sni_name, session=session) as s:
+ s.connect((HOST, server.port))
+ for arg in [indata, bytearray(indata), memoryview(indata)]:
if connectionchatty:
if support.verbose:
- sys.stdout.write(" client: closing connection.\n")
- stats.update({
- 'compression': s.compression(),
- 'cipher': s.cipher(),
- 'peercert': s.getpeercert(),
- 'client_alpn_protocol': s.selected_alpn_protocol(),
- 'client_npn_protocol': s.selected_npn_protocol(),
- 'version': s.version(),
- 'session_reused': s.session_reused,
- 'session': s.session,
- })
- s.close()
- stats['server_alpn_protocols'] = server.selected_alpn_protocols
- stats['server_npn_protocols'] = server.selected_npn_protocols
- stats['server_shared_ciphers'] = server.shared_ciphers
- return stats
+ sys.stdout.write(
+ " client: sending %r...\n" % indata)
+ s.write(arg)
+ outdata = s.read()
+ if connectionchatty:
+ if support.verbose:
+ sys.stdout.write(" client: read %r\n" % outdata)
+ if outdata != indata.lower():
+ raise AssertionError(
+ "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
+ % (outdata[:20], len(outdata),
+ indata[:20].lower(), len(indata)))
+ s.write(b"over\n")
+ if connectionchatty:
+ if support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ stats.update({
+ 'compression': s.compression(),
+ 'cipher': s.cipher(),
+ 'peercert': s.getpeercert(),
+ 'client_alpn_protocol': s.selected_alpn_protocol(),
+ 'client_npn_protocol': s.selected_npn_protocol(),
+ 'version': s.version(),
+ 'session_reused': s.session_reused,
+ 'session': s.session,
+ })
+ s.close()
+ stats['server_alpn_protocols'] = server.selected_alpn_protocols
+ stats['server_npn_protocols'] = server.selected_npn_protocols
+ stats['server_shared_ciphers'] = server.shared_ciphers
+ return stats
+
+def try_protocol_combo(server_protocol, client_protocol, expect_success,
+ certsreqs=None, server_options=0, client_options=0):
+ """
+ Try to SSL-connect using *client_protocol* to *server_protocol*.
+ If *expect_success* is true, assert that the connection succeeds,
+ if it's false, assert that the connection fails.
+ Also, if *expect_success* is a string, assert that it is the protocol
+ version actually used by the connection.
+ """
+ if certsreqs is None:
+ certsreqs = ssl.CERT_NONE
+ certtype = {
+ ssl.CERT_NONE: "CERT_NONE",
+ ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
+ ssl.CERT_REQUIRED: "CERT_REQUIRED",
+ }[certsreqs]
+ if support.verbose:
+ formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
+ sys.stdout.write(formatstr %
+ (ssl.get_protocol_name(client_protocol),
+ ssl.get_protocol_name(server_protocol),
+ certtype))
+ client_context = ssl.SSLContext(client_protocol)
+ client_context.options |= client_options
+ server_context = ssl.SSLContext(server_protocol)
+ server_context.options |= server_options
+
+ # NOTE: we must enable "ALL" ciphers on the client, otherwise an
+ # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
+ # starting from OpenSSL 1.0.0 (see issue #8322).
+ if client_context.protocol == ssl.PROTOCOL_SSLv23:
+ client_context.set_ciphers("ALL")
+
+ for ctx in (client_context, server_context):
+ ctx.verify_mode = certsreqs
+ ctx.load_cert_chain(CERTFILE)
+ ctx.load_verify_locations(CERTFILE)
+ try:
+ stats = server_params_test(client_context, server_context,
+ chatty=False, connectionchatty=False)
+ # Protocol mismatch can result in either an SSLError, or a
+ # "Connection reset by peer" error.
+ except ssl.SSLError:
+ if expect_success:
+ raise
+ except OSError as e:
+ if expect_success or e.errno != errno.ECONNRESET:
+ raise
+ else:
+ if not expect_success:
+ raise AssertionError(
+ "Client protocol %s succeeded with server protocol %s!"
+ % (ssl.get_protocol_name(client_protocol),
+ ssl.get_protocol_name(server_protocol)))
+ elif (expect_success is not True
+ and expect_success != stats['version']):
+ raise AssertionError("version mismatch: expected %r, got %r"
+ % (expect_success, stats['version']))
- def try_protocol_combo(server_protocol, client_protocol, expect_success,
- certsreqs=None, server_options=0, client_options=0):
- """
- Try to SSL-connect using *client_protocol* to *server_protocol*.
- If *expect_success* is true, assert that the connection succeeds,
- if it's false, assert that the connection fails.
- Also, if *expect_success* is a string, assert that it is the protocol
- version actually used by the connection.
- """
- if certsreqs is None:
- certsreqs = ssl.CERT_NONE
- certtype = {
- ssl.CERT_NONE: "CERT_NONE",
- ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
- ssl.CERT_REQUIRED: "CERT_REQUIRED",
- }[certsreqs]
- if support.verbose:
- formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
- sys.stdout.write(formatstr %
- (ssl.get_protocol_name(client_protocol),
- ssl.get_protocol_name(server_protocol),
- certtype))
- client_context = ssl.SSLContext(client_protocol)
- client_context.options |= client_options
- server_context = ssl.SSLContext(server_protocol)
- server_context.options |= server_options
-
- # NOTE: we must enable "ALL" ciphers on the client, otherwise an
- # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
- # starting from OpenSSL 1.0.0 (see issue #8322).
- if client_context.protocol == ssl.PROTOCOL_SSLv23:
- client_context.set_ciphers("ALL")
-
- for ctx in (client_context, server_context):
- ctx.verify_mode = certsreqs
- ctx.load_cert_chain(CERTFILE)
- ctx.load_verify_locations(CERTFILE)
- try:
- stats = server_params_test(client_context, server_context,
- chatty=False, connectionchatty=False)
- # Protocol mismatch can result in either an SSLError, or a
- # "Connection reset by peer" error.
- except ssl.SSLError:
- if expect_success:
- raise
- except OSError as e:
- if expect_success or e.errno != errno.ECONNRESET:
- raise
- else:
- if not expect_success:
- raise AssertionError(
- "Client protocol %s succeeded with server protocol %s!"
- % (ssl.get_protocol_name(client_protocol),
- ssl.get_protocol_name(server_protocol)))
- elif (expect_success is not True
- and expect_success != stats['version']):
- raise AssertionError("version mismatch: expected %r, got %r"
- % (expect_success, stats['version']))
-
-
- class ThreadedTests(unittest.TestCase):
-
- @skip_if_broken_ubuntu_ssl
- def test_echo(self):
- """Basic test of an SSL client connecting to a server"""
- if support.verbose:
- sys.stdout.write("\n")
- for protocol in PROTOCOLS:
- if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
- continue
- with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
- context = ssl.SSLContext(protocol)
- context.load_cert_chain(CERTFILE)
- server_params_test(context, context,
- chatty=True, connectionchatty=True)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
- client_context.load_verify_locations(SIGNING_CA)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
- # server_context.load_verify_locations(SIGNING_CA)
- server_context.load_cert_chain(SIGNED_CERTFILE2)
+class ThreadedTests(unittest.TestCase):
- with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
- server_params_test(client_context=client_context,
- server_context=server_context,
+ @skip_if_broken_ubuntu_ssl
+ def test_echo(self):
+ """Basic test of an SSL client connecting to a server"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ for protocol in PROTOCOLS:
+ if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
+ continue
+ with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
+ context = ssl.SSLContext(protocol)
+ context.load_cert_chain(CERTFILE)
+ server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ client_context.load_verify_locations(SIGNING_CA)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ # server_context.load_verify_locations(SIGNING_CA)
+ server_context.load_cert_chain(SIGNED_CERTFILE2)
+
+ with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
+ server_params_test(client_context=client_context,
+ server_context=server_context,
+ chatty=True, connectionchatty=True,
+ sni_name='fakehostname')
+
+ client_context.check_hostname = False
+ with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
+ with self.assertRaises(ssl.SSLError) as e:
+ server_params_test(client_context=server_context,
+ server_context=client_context,
chatty=True, connectionchatty=True,
sni_name='fakehostname')
+ self.assertIn('called a function you should not call',
+ str(e.exception))
- client_context.check_hostname = False
- with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
- with self.assertRaises(ssl.SSLError) as e:
- server_params_test(client_context=server_context,
- server_context=client_context,
- chatty=True, connectionchatty=True,
- sni_name='fakehostname')
- self.assertIn('called a function you should not call',
- str(e.exception))
-
- with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
- with self.assertRaises(ssl.SSLError) as e:
- server_params_test(client_context=server_context,
- server_context=server_context,
- chatty=True, connectionchatty=True)
- self.assertIn('called a function you should not call',
- str(e.exception))
+ with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
+ with self.assertRaises(ssl.SSLError) as e:
+ server_params_test(client_context=server_context,
+ server_context=server_context,
+ chatty=True, connectionchatty=True)
+ self.assertIn('called a function you should not call',
+ str(e.exception))
- with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
- with self.assertRaises(ssl.SSLError) as e:
- server_params_test(client_context=server_context,
- server_context=client_context,
- chatty=True, connectionchatty=True)
- self.assertIn('called a function you should not call',
- str(e.exception))
+ with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
+ with self.assertRaises(ssl.SSLError) as e:
+ server_params_test(client_context=server_context,
+ server_context=client_context,
+ chatty=True, connectionchatty=True)
+ self.assertIn('called a function you should not call',
+ str(e.exception))
- def test_getpeercert(self):
+ def test_getpeercert(self):
+ if support.verbose:
+ sys.stdout.write("\n")
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ s = context.wrap_socket(socket.socket(),
+ do_handshake_on_connect=False)
+ s.connect((HOST, server.port))
+ # getpeercert() raise ValueError while the handshake isn't
+ # done.
+ with self.assertRaises(ValueError):
+ s.getpeercert()
+ s.do_handshake()
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
+ cipher = s.cipher()
if support.verbose:
- sys.stdout.write("\n")
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = ThreadedEchoServer(context=context, chatty=False)
- with server:
- s = context.wrap_socket(socket.socket(),
- do_handshake_on_connect=False)
+ sys.stdout.write(pprint.pformat(cert) + '\n')
+ sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
+ if 'subject' not in cert:
+ self.fail("No subject field in certificate: %s." %
+ pprint.pformat(cert))
+ if ((('organizationName', 'Python Software Foundation'),)
+ not in cert['subject']):
+ self.fail(
+ "Missing or invalid 'organizationName' field in certificate subject; "
+ "should be 'Python Software Foundation'.")
+ self.assertIn('notBefore', cert)
+ self.assertIn('notAfter', cert)
+ before = ssl.cert_time_to_seconds(cert['notBefore'])
+ after = ssl.cert_time_to_seconds(cert['notAfter'])
+ self.assertLess(before, after)
+ s.close()
+
+ @unittest.skipUnless(have_verify_flags(),
+ "verify_flags need OpenSSL > 0.9.8")
+ def test_crl_check(self):
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(SIGNING_CA)
+ tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
+ self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf)
+
+ # VERIFY_DEFAULT should pass
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
- # getpeercert() raise ValueError while the handshake isn't
- # done.
- with self.assertRaises(ValueError):
- s.getpeercert()
- s.do_handshake()
cert = s.getpeercert()
self.assertTrue(cert, "Can't get peer certificate.")
- cipher = s.cipher()
- if support.verbose:
- sys.stdout.write(pprint.pformat(cert) + '\n')
- sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
- if 'subject' not in cert:
- self.fail("No subject field in certificate: %s." %
- pprint.pformat(cert))
- if ((('organizationName', 'Python Software Foundation'),)
- not in cert['subject']):
- self.fail(
- "Missing or invalid 'organizationName' field in certificate subject; "
- "should be 'Python Software Foundation'.")
- self.assertIn('notBefore', cert)
- self.assertIn('notAfter', cert)
- before = ssl.cert_time_to_seconds(cert['notBefore'])
- after = ssl.cert_time_to_seconds(cert['notAfter'])
- self.assertLess(before, after)
- s.close()
- @unittest.skipUnless(have_verify_flags(),
- "verify_flags need OpenSSL > 0.9.8")
- def test_crl_check(self):
- if support.verbose:
- sys.stdout.write("\n")
+ # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
+ context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
-
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(SIGNING_CA)
- tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
- self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf)
-
- # VERIFY_DEFAULT should pass
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket()) as s:
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ with self.assertRaisesRegex(ssl.SSLError,
+ "certificate verify failed"):
s.connect((HOST, server.port))
- cert = s.getpeercert()
- self.assertTrue(cert, "Can't get peer certificate.")
- # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
- context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
+ # now load a CRL file. The CRL file is signed by the CA.
+ context.load_verify_locations(CRLFILE)
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- with self.assertRaisesRegex(ssl.SSLError,
- "certificate verify failed"):
- s.connect((HOST, server.port))
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ s.connect((HOST, server.port))
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
- # now load a CRL file. The CRL file is signed by the CA.
- context.load_verify_locations(CRLFILE)
+ def test_check_hostname(self):
+ if support.verbose:
+ sys.stdout.write("\n")
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- s.connect((HOST, server.port))
- cert = s.getpeercert()
- self.assertTrue(cert, "Can't get peer certificate.")
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
- def test_check_hostname(self):
- if support.verbose:
- sys.stdout.write("\n")
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.check_hostname = True
+ context.load_verify_locations(SIGNING_CA)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
-
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.verify_mode = ssl.CERT_REQUIRED
- context.check_hostname = True
- context.load_verify_locations(SIGNING_CA)
-
- # correct hostname should verify
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket(),
- server_hostname="localhost") as s:
+ # correct hostname should verify
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket(),
+ server_hostname="localhost") as s:
+ s.connect((HOST, server.port))
+ cert = s.getpeercert()
+ self.assertTrue(cert, "Can't get peer certificate.")
+
+ # incorrect hostname should raise an exception
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with context.wrap_socket(socket.socket(),
+ server_hostname="invalid") as s:
+ with self.assertRaisesRegex(ssl.CertificateError,
+ "hostname 'invalid' doesn't match 'localhost'"):
s.connect((HOST, server.port))
- cert = s.getpeercert()
- self.assertTrue(cert, "Can't get peer certificate.")
-
- # incorrect hostname should raise an exception
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with context.wrap_socket(socket.socket(),
- server_hostname="invalid") as s:
- with self.assertRaisesRegex(ssl.CertificateError,
- "hostname 'invalid' doesn't match 'localhost'"):
- s.connect((HOST, server.port))
-
- # missing server_hostname arg should cause an exception, too
- server = ThreadedEchoServer(context=server_context, chatty=True)
- with server:
- with socket.socket() as s:
- with self.assertRaisesRegex(ValueError,
- "check_hostname requires server_hostname"):
- context.wrap_socket(s)
-
- def test_wrong_cert(self):
- """Connecting when the server rejects the client's certificate
-
- Launch a server with CERT_REQUIRED, and check that trying to
- connect to it with a wrong client certificate fails.
- """
- certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
- "wrongcert.pem")
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_REQUIRED,
- cacerts=CERTFILE, chatty=False,
- connectionchatty=False)
- with server, \
- socket.socket() as sock, \
- test_wrap_socket(sock,
- certfile=certfile,
- ssl_version=ssl.PROTOCOL_TLSv1) as s:
+
+ # missing server_hostname arg should cause an exception, too
+ server = ThreadedEchoServer(context=server_context, chatty=True)
+ with server:
+ with socket.socket() as s:
+ with self.assertRaisesRegex(ValueError,
+ "check_hostname requires server_hostname"):
+ context.wrap_socket(s)
+
+ def test_wrong_cert(self):
+ """Connecting when the server rejects the client's certificate
+
+ Launch a server with CERT_REQUIRED, and check that trying to
+ connect to it with a wrong client certificate fails.
+ """
+ certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
+ "wrongcert.pem")
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_REQUIRED,
+ cacerts=CERTFILE, chatty=False,
+ connectionchatty=False)
+ with server, \
+ socket.socket() as sock, \
+ test_wrap_socket(sock,
+ certfile=certfile,
+ ssl_version=ssl.PROTOCOL_TLSv1) as s:
+ try:
+ # Expect either an SSL error about the server rejecting
+ # the connection, or a low-level connection reset (which
+ # sometimes happens on Windows)
+ s.connect((HOST, server.port))
+ except ssl.SSLError as e:
+ if support.verbose:
+ sys.stdout.write("\nSSLError is %r\n" % e)
+ except OSError as e:
+ if e.errno != errno.ECONNRESET:
+ raise
+ if support.verbose:
+ sys.stdout.write("\nsocket.error is %r\n" % e)
+ else:
+ self.fail("Use of invalid cert should have failed!")
+
+ def test_rude_shutdown(self):
+ """A brutal shutdown of an SSL server should raise an OSError
+ in the client when attempting handshake.
+ """
+ listener_ready = threading.Event()
+ listener_gone = threading.Event()
+
+ s = socket.socket()
+ port = support.bind_port(s, HOST)
+
+ # `listener` runs in a thread. It sits in an accept() until
+ # the main thread connects. Then it rudely closes the socket,
+ # and sets Event `listener_gone` to let the main thread know
+ # the socket is gone.
+ def listener():
+ s.listen()
+ listener_ready.set()
+ newsock, addr = s.accept()
+ newsock.close()
+ s.close()
+ listener_gone.set()
+
+ def connector():
+ listener_ready.wait()
+ with socket.socket() as c:
+ c.connect((HOST, port))
+ listener_gone.wait()
try:
- # Expect either an SSL error about the server rejecting
- # the connection, or a low-level connection reset (which
- # sometimes happens on Windows)
- s.connect((HOST, server.port))
- except ssl.SSLError as e:
- if support.verbose:
- sys.stdout.write("\nSSLError is %r\n" % e)
- except OSError as e:
- if e.errno != errno.ECONNRESET:
- raise
- if support.verbose:
- sys.stdout.write("\nsocket.error is %r\n" % e)
+ ssl_sock = test_wrap_socket(c)
+ except OSError:
+ pass
else:
- self.fail("Use of invalid cert should have failed!")
+ self.fail('connecting to closed SSL socket should have failed')
- def test_rude_shutdown(self):
- """A brutal shutdown of an SSL server should raise an OSError
- in the client when attempting handshake.
- """
- listener_ready = threading.Event()
- listener_gone = threading.Event()
+ t = threading.Thread(target=listener)
+ t.start()
+ try:
+ connector()
+ finally:
+ t.join()
- s = socket.socket()
- port = support.bind_port(s, HOST)
-
- # `listener` runs in a thread. It sits in an accept() until
- # the main thread connects. Then it rudely closes the socket,
- # and sets Event `listener_gone` to let the main thread know
- # the socket is gone.
- def listener():
- s.listen()
- listener_ready.set()
- newsock, addr = s.accept()
- newsock.close()
- s.close()
- listener_gone.set()
-
- def connector():
- listener_ready.wait()
- with socket.socket() as c:
- c.connect((HOST, port))
- listener_gone.wait()
- try:
- ssl_sock = test_wrap_socket(c)
- except OSError:
- pass
- else:
- self.fail('connecting to closed SSL socket should have failed')
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
+ "OpenSSL is compiled without SSLv2 support")
+ def test_protocol_sslv2(self):
+ """Connecting to an SSLv2 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
+ # SSLv23 client with specific SSL options
+ if no_sslv2_implies_sslv3_hello():
+ # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_SSLv2)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1)
- t = threading.Thread(target=listener)
- t.start()
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_sslv23(self):
+ """Connecting to an SSLv23 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
try:
- connector()
- finally:
- t.join()
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
+ except OSError as x:
+ # this fails on some older versions of OpenSSL (0.9.7l, for instance)
+ if support.verbose:
+ sys.stdout.write(
+ " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
+ % str(x))
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
+
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
+
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
+
+ # Server with specific SSL options
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False,
+ server_options=ssl.OP_NO_SSLv3)
+ # Will choose TLSv1
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
+ server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
+ server_options=ssl.OP_NO_TLSv1)
+
+
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'),
+ "OpenSSL is compiled without SSLv3 support")
+ def test_protocol_sslv3(self):
+ """Connecting to an SSLv3 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_SSLv3)
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
+ if no_sslv2_implies_sslv3_hello():
+ # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
+ try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23,
+ False, client_options=ssl.OP_NO_SSLv2)
+
+ @skip_if_broken_ubuntu_ssl
+ def test_protocol_tlsv1(self):
+ """Connecting to a TLSv1 server with various client options"""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1)
+
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
+ "TLS version 1.1 not supported.")
+ def test_protocol_tlsv1_1(self):
+ """Connecting to a TLSv1.1 server with various client options.
+ Testing against older TLS versions."""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1_1)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
- "OpenSSL is compiled without SSLv2 support")
- def test_protocol_sslv2(self):
- """Connecting to an SSLv2 server with various client options"""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
- # SSLv23 client with specific SSL options
- if no_sslv2_implies_sslv3_hello():
- # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_SSLv2)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1)
- @skip_if_broken_ubuntu_ssl
- def test_protocol_sslv23(self):
- """Connecting to an SSLv23 server with various client options"""
+ @skip_if_broken_ubuntu_ssl
+ @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
+ "TLS version 1.2 not supported.")
+ def test_protocol_tlsv1_2(self):
+ """Connecting to a TLSv1.2 server with various client options.
+ Testing against older TLS versions."""
+ if support.verbose:
+ sys.stdout.write("\n")
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
+ server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
+ client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
+ if hasattr(ssl, 'PROTOCOL_SSLv2'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
+ if hasattr(ssl, 'PROTOCOL_SSLv3'):
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
+ client_options=ssl.OP_NO_TLSv1_2)
+
+ try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
+ try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
+
+ def test_starttls(self):
+ """Switching from clear text to encrypted and back again."""
+ msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
+
+ server = ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ starttls_server=True,
+ chatty=True,
+ connectionchatty=True)
+ wrapped = False
+ with server:
+ s = socket.socket()
+ s.setblocking(1)
+ s.connect((HOST, server.port))
if support.verbose:
sys.stdout.write("\n")
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try:
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
- except OSError as x:
- # this fails on some older versions of OpenSSL (0.9.7l, for instance)
+ for indata in msgs:
+ if support.verbose:
+ sys.stdout.write(
+ " client: sending %r...\n" % indata)
+ if wrapped:
+ conn.write(indata)
+ outdata = conn.read()
+ else:
+ s.send(indata)
+ outdata = s.recv(1024)
+ msg = outdata.strip().lower()
+ if indata == b"STARTTLS" and msg.startswith(b"ok"):
+ # STARTTLS ok, switch to secure mode
if support.verbose:
sys.stdout.write(
- " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
- % str(x))
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
-
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
-
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
-
- # Server with specific SSL options
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False,
- server_options=ssl.OP_NO_SSLv3)
- # Will choose TLSv1
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
- server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
- server_options=ssl.OP_NO_TLSv1)
-
-
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'),
- "OpenSSL is compiled without SSLv3 support")
- def test_protocol_sslv3(self):
- """Connecting to an SSLv3 server with various client options"""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_SSLv3)
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
- if no_sslv2_implies_sslv3_hello():
- # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
- try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23,
- False, client_options=ssl.OP_NO_SSLv2)
-
- @skip_if_broken_ubuntu_ssl
- def test_protocol_tlsv1(self):
- """Connecting to a TLSv1 server with various client options"""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1)
-
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
- "TLS version 1.1 not supported.")
- def test_protocol_tlsv1_1(self):
- """Connecting to a TLSv1.1 server with various client options.
- Testing against older TLS versions."""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1_1)
-
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
-
-
- @skip_if_broken_ubuntu_ssl
- @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
- "TLS version 1.2 not supported.")
- def test_protocol_tlsv1_2(self):
- """Connecting to a TLSv1.2 server with various client options.
- Testing against older TLS versions."""
- if support.verbose:
- sys.stdout.write("\n")
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
- server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
- client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
- if hasattr(ssl, 'PROTOCOL_SSLv2'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
- if hasattr(ssl, 'PROTOCOL_SSLv3'):
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
- client_options=ssl.OP_NO_TLSv1_2)
-
- try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
- try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
-
- def test_starttls(self):
- """Switching from clear text to encrypted and back again."""
- msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")
-
- server = ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- starttls_server=True,
- chatty=True,
- connectionchatty=True)
- wrapped = False
- with server:
- s = socket.socket()
- s.setblocking(1)
- s.connect((HOST, server.port))
- if support.verbose:
- sys.stdout.write("\n")
- for indata in msgs:
+ " client: read %r from server, starting TLS...\n"
+ % msg)
+ conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
+ wrapped = True
+ elif indata == b"ENDTLS" and msg.startswith(b"ok"):
+ # ENDTLS ok, switch back to clear text
if support.verbose:
sys.stdout.write(
- " client: sending %r...\n" % indata)
- if wrapped:
- conn.write(indata)
- outdata = conn.read()
- else:
- s.send(indata)
- outdata = s.recv(1024)
- msg = outdata.strip().lower()
- if indata == b"STARTTLS" and msg.startswith(b"ok"):
- # STARTTLS ok, switch to secure mode
- if support.verbose:
- sys.stdout.write(
- " client: read %r from server, starting TLS...\n"
- % msg)
- conn = test_wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
- wrapped = True
- elif indata == b"ENDTLS" and msg.startswith(b"ok"):
- # ENDTLS ok, switch back to clear text
- if support.verbose:
- sys.stdout.write(
- " client: read %r from server, ending TLS...\n"
- % msg)
- s = conn.unwrap()
- wrapped = False
- else:
- if support.verbose:
- sys.stdout.write(
- " client: read %r from server\n" % msg)
- if support.verbose:
- sys.stdout.write(" client: closing connection.\n")
- if wrapped:
- conn.write(b"over\n")
+ " client: read %r from server, ending TLS...\n"
+ % msg)
+ s = conn.unwrap()
+ wrapped = False
else:
- s.send(b"over\n")
- if wrapped:
- conn.close()
- else:
- s.close()
-
- def test_socketserver(self):
- """Using socketserver to create and manage SSL connections."""
- server = make_https_server(self, certfile=CERTFILE)
- # try to connect
- if support.verbose:
- sys.stdout.write('\n')
- with open(CERTFILE, 'rb') as f:
- d1 = f.read()
- d2 = ''
- # now fetch the same data from the HTTPS server
- url = 'https://localhost:%d/%s' % (
- server.port, os.path.split(CERTFILE)[1])
- context = ssl.create_default_context(cafile=CERTFILE)
- f = urllib.request.urlopen(url, context=context)
- try:
- dlen = f.info().get("content-length")
- if dlen and (int(dlen) > 0):
- d2 = f.read(int(dlen))
if support.verbose:
sys.stdout.write(
- " client: read %d bytes from remote server '%s'\n"
- % (len(d2), server))
- finally:
- f.close()
- self.assertEqual(d1, d2)
-
- def test_asyncore_server(self):
- """Check the example asyncore integration."""
+ " client: read %r from server\n" % msg)
if support.verbose:
- sys.stdout.write("\n")
+ sys.stdout.write(" client: closing connection.\n")
+ if wrapped:
+ conn.write(b"over\n")
+ else:
+ s.send(b"over\n")
+ if wrapped:
+ conn.close()
+ else:
+ s.close()
- indata = b"FOO\n"
- server = AsyncoreEchoServer(CERTFILE)
- with server:
- s = test_wrap_socket(socket.socket())
- s.connect(('127.0.0.1', server.port))
+ def test_socketserver(self):
+ """Using socketserver to create and manage SSL connections."""
+ server = make_https_server(self, certfile=CERTFILE)
+ # try to connect
+ if support.verbose:
+ sys.stdout.write('\n')
+ with open(CERTFILE, 'rb') as f:
+ d1 = f.read()
+ d2 = ''
+ # now fetch the same data from the HTTPS server
+ url = 'https://localhost:%d/%s' % (
+ server.port, os.path.split(CERTFILE)[1])
+ context = ssl.create_default_context(cafile=CERTFILE)
+ f = urllib.request.urlopen(url, context=context)
+ try:
+ dlen = f.info().get("content-length")
+ if dlen and (int(dlen) > 0):
+ d2 = f.read(int(dlen))
if support.verbose:
sys.stdout.write(
- " client: sending %r...\n" % indata)
- s.write(indata)
- outdata = s.read()
- if support.verbose:
- sys.stdout.write(" client: read %r\n" % outdata)
- if outdata != indata.lower():
- self.fail(
- "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
- % (outdata[:20], len(outdata),
- indata[:20].lower(), len(indata)))
- s.write(b"over\n")
- if support.verbose:
- sys.stdout.write(" client: closing connection.\n")
- s.close()
- if support.verbose:
- sys.stdout.write(" client: connection closed.\n")
+ " client: read %d bytes from remote server '%s'\n"
+ % (len(d2), server))
+ finally:
+ f.close()
+ self.assertEqual(d1, d2)
- def test_recv_send(self):
- """Test recv(), send() and friends."""
+ def test_asyncore_server(self):
+ """Check the example asyncore integration."""
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ indata = b"FOO\n"
+ server = AsyncoreEchoServer(CERTFILE)
+ with server:
+ s = test_wrap_socket(socket.socket())
+ s.connect(('127.0.0.1', server.port))
if support.verbose:
- sys.stdout.write("\n")
+ sys.stdout.write(
+ " client: sending %r...\n" % indata)
+ s.write(indata)
+ outdata = s.read()
+ if support.verbose:
+ sys.stdout.write(" client: read %r\n" % outdata)
+ if outdata != indata.lower():
+ self.fail(
+ "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
+ % (outdata[:20], len(outdata),
+ indata[:20].lower(), len(indata)))
+ s.write(b"over\n")
+ if support.verbose:
+ sys.stdout.write(" client: closing connection.\n")
+ s.close()
+ if support.verbose:
+ sys.stdout.write(" client: connection closed.\n")
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- cacerts=CERTFILE,
- chatty=True,
- connectionchatty=False)
- with server:
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- # helper methods for standardising recv* method signatures
- def _recv_into():
- b = bytearray(b"\0"*100)
- count = s.recv_into(b)
- return b[:count]
-
- def _recvfrom_into():
- b = bytearray(b"\0"*100)
- count, addr = s.recvfrom_into(b)
- return b[:count]
-
- # (name, method, expect success?, *args, return value func)
- send_methods = [
- ('send', s.send, True, [], len),
- ('sendto', s.sendto, False, ["some.address"], len),
- ('sendall', s.sendall, True, [], lambda x: None),
- ]
- # (name, method, whether to expect success, *args)
- recv_methods = [
- ('recv', s.recv, True, []),
- ('recvfrom', s.recvfrom, False, ["some.address"]),
- ('recv_into', _recv_into, True, []),
- ('recvfrom_into', _recvfrom_into, False, []),
- ]
- data_prefix = "PREFIX_"
-
- for (meth_name, send_meth, expect_success, args,
- ret_val_meth) in send_methods:
- indata = (data_prefix + meth_name).encode('ascii')
- try:
- ret = send_meth(indata, *args)
- msg = "sending with {}".format(meth_name)
- self.assertEqual(ret, ret_val_meth(indata), msg=msg)
- outdata = s.read()
- if outdata != indata.lower():
- self.fail(
- "While sending with <<{name:s}>> bad data "
- "<<{outdata:r}>> ({nout:d}) received; "
- "expected <<{indata:r}>> ({nin:d})\n".format(
- name=meth_name, outdata=outdata[:20],
- nout=len(outdata),
- indata=indata[:20], nin=len(indata)
- )
- )
- except ValueError as e:
- if expect_success:
- self.fail(
- "Failed to send with method <<{name:s}>>; "
- "expected to succeed.\n".format(name=meth_name)
+ def test_recv_send(self):
+ """Test recv(), send() and friends."""
+ if support.verbose:
+ sys.stdout.write("\n")
+
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ # helper methods for standardising recv* method signatures
+ def _recv_into():
+ b = bytearray(b"\0"*100)
+ count = s.recv_into(b)
+ return b[:count]
+
+ def _recvfrom_into():
+ b = bytearray(b"\0"*100)
+ count, addr = s.recvfrom_into(b)
+ return b[:count]
+
+ # (name, method, expect success?, *args, return value func)
+ send_methods = [
+ ('send', s.send, True, [], len),
+ ('sendto', s.sendto, False, ["some.address"], len),
+ ('sendall', s.sendall, True, [], lambda x: None),
+ ]
+ # (name, method, whether to expect success, *args)
+ recv_methods = [
+ ('recv', s.recv, True, []),
+ ('recvfrom', s.recvfrom, False, ["some.address"]),
+ ('recv_into', _recv_into, True, []),
+ ('recvfrom_into', _recvfrom_into, False, []),
+ ]
+ data_prefix = "PREFIX_"
+
+ for (meth_name, send_meth, expect_success, args,
+ ret_val_meth) in send_methods:
+ indata = (data_prefix + meth_name).encode('ascii')
+ try:
+ ret = send_meth(indata, *args)
+ msg = "sending with {}".format(meth_name)
+ self.assertEqual(ret, ret_val_meth(indata), msg=msg)
+ outdata = s.read()
+ if outdata != indata.lower():
+ self.fail(
+ "While sending with <<{name:s}>> bad data "
+ "<<{outdata:r}>> ({nout:d}) received; "
+ "expected <<{indata:r}>> ({nin:d})\n".format(
+ name=meth_name, outdata=outdata[:20],
+ nout=len(outdata),
+ indata=indata[:20], nin=len(indata)
)
- if not str(e).startswith(meth_name):
- self.fail(
- "Method <<{name:s}>> failed with unexpected "
- "exception message: {exp:s}\n".format(
- name=meth_name, exp=e
- )
+ )
+ except ValueError as e:
+ if expect_success:
+ self.fail(
+ "Failed to send with method <<{name:s}>>; "
+ "expected to succeed.\n".format(name=meth_name)
+ )
+ if not str(e).startswith(meth_name):
+ self.fail(
+ "Method <<{name:s}>> failed with unexpected "
+ "exception message: {exp:s}\n".format(
+ name=meth_name, exp=e
)
+ )
- for meth_name, recv_meth, expect_success, args in recv_methods:
- indata = (data_prefix + meth_name).encode('ascii')
- try:
- s.send(indata)
- outdata = recv_meth(*args)
- if outdata != indata.lower():
- self.fail(
- "While receiving with <<{name:s}>> bad data "
- "<<{outdata:r}>> ({nout:d}) received; "
- "expected <<{indata:r}>> ({nin:d})\n".format(
- name=meth_name, outdata=outdata[:20],
- nout=len(outdata),
- indata=indata[:20], nin=len(indata)
- )
- )
- except ValueError as e:
- if expect_success:
- self.fail(
- "Failed to receive with method <<{name:s}>>; "
- "expected to succeed.\n".format(name=meth_name)
+ for meth_name, recv_meth, expect_success, args in recv_methods:
+ indata = (data_prefix + meth_name).encode('ascii')
+ try:
+ s.send(indata)
+ outdata = recv_meth(*args)
+ if outdata != indata.lower():
+ self.fail(
+ "While receiving with <<{name:s}>> bad data "
+ "<<{outdata:r}>> ({nout:d}) received; "
+ "expected <<{indata:r}>> ({nin:d})\n".format(
+ name=meth_name, outdata=outdata[:20],
+ nout=len(outdata),
+ indata=indata[:20], nin=len(indata)
)
- if not str(e).startswith(meth_name):
- self.fail(
- "Method <<{name:s}>> failed with unexpected "
- "exception message: {exp:s}\n".format(
- name=meth_name, exp=e
- )
+ )
+ except ValueError as e:
+ if expect_success:
+ self.fail(
+ "Failed to receive with method <<{name:s}>>; "
+ "expected to succeed.\n".format(name=meth_name)
+ )
+ if not str(e).startswith(meth_name):
+ self.fail(
+ "Method <<{name:s}>> failed with unexpected "
+ "exception message: {exp:s}\n".format(
+ name=meth_name, exp=e
)
- # consume data
- s.read()
+ )
+ # consume data
+ s.read()
- # read(-1, buffer) is supported, even though read(-1) is not
- data = b"data"
- s.send(data)
- buffer = bytearray(len(data))
- self.assertEqual(s.read(-1, buffer), len(data))
- self.assertEqual(buffer, data)
+ # read(-1, buffer) is supported, even though read(-1) is not
+ data = b"data"
+ s.send(data)
+ buffer = bytearray(len(data))
+ self.assertEqual(s.read(-1, buffer), len(data))
+ self.assertEqual(buffer, data)
- # Make sure sendmsg et al are disallowed to avoid
- # inadvertent disclosure of data and/or corruption
- # of the encrypted data stream
- self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
- self.assertRaises(NotImplementedError, s.recvmsg, 100)
- self.assertRaises(NotImplementedError,
- s.recvmsg_into, bytearray(100))
+ # Make sure sendmsg et al are disallowed to avoid
+ # inadvertent disclosure of data and/or corruption
+ # of the encrypted data stream
+ self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
+ self.assertRaises(NotImplementedError, s.recvmsg, 100)
+ self.assertRaises(NotImplementedError,
+ s.recvmsg_into, bytearray(100))
- s.write(b"over\n")
+ s.write(b"over\n")
- self.assertRaises(ValueError, s.recv, -1)
- self.assertRaises(ValueError, s.read, -1)
+ self.assertRaises(ValueError, s.recv, -1)
+ self.assertRaises(ValueError, s.read, -1)
- s.close()
+ s.close()
- def test_recv_zero(self):
- server = ThreadedEchoServer(CERTFILE)
- server.__enter__()
- self.addCleanup(server.__exit__, None, None)
- s = socket.create_connection((HOST, server.port))
- self.addCleanup(s.close)
- s = test_wrap_socket(s, suppress_ragged_eofs=False)
- self.addCleanup(s.close)
+ def test_recv_zero(self):
+ server = ThreadedEchoServer(CERTFILE)
+ server.__enter__()
+ self.addCleanup(server.__exit__, None, None)
+ s = socket.create_connection((HOST, server.port))
+ self.addCleanup(s.close)
+ s = test_wrap_socket(s, suppress_ragged_eofs=False)
+ self.addCleanup(s.close)
- # recv/read(0) should return no data
- s.send(b"data")
- self.assertEqual(s.recv(0), b"")
- self.assertEqual(s.read(0), b"")
- self.assertEqual(s.read(), b"data")
+ # recv/read(0) should return no data
+ s.send(b"data")
+ self.assertEqual(s.recv(0), b"")
+ self.assertEqual(s.read(0), b"")
+ self.assertEqual(s.read(), b"data")
+
+ # Should not block if the other end sends no data
+ s.setblocking(False)
+ self.assertEqual(s.recv(0), b"")
+ self.assertEqual(s.recv_into(bytearray()), 0)
- # Should not block if the other end sends no data
+ def test_nonblocking_send(self):
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
s.setblocking(False)
- self.assertEqual(s.recv(0), b"")
- self.assertEqual(s.recv_into(bytearray()), 0)
-
- def test_nonblocking_send(self):
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- cacerts=CERTFILE,
- chatty=True,
- connectionchatty=False)
- with server:
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
- s.connect((HOST, server.port))
- s.setblocking(False)
-
- # If we keep sending data, at some point the buffers
- # will be full and the call will block
- buf = bytearray(8192)
- def fill_buffer():
- while True:
- s.send(buf)
- self.assertRaises((ssl.SSLWantWriteError,
- ssl.SSLWantReadError), fill_buffer)
-
- # Now read all the output and discard it
- s.setblocking(True)
- s.close()
- def test_handshake_timeout(self):
- # Issue #5103: SSL handshake must respect the socket timeout
- server = socket.socket(socket.AF_INET)
- host = "127.0.0.1"
- port = support.bind_port(server)
- started = threading.Event()
- finish = False
-
- def serve():
- server.listen()
- started.set()
- conns = []
- while not finish:
- r, w, e = select.select([server], [], [], 0.1)
- if server in r:
- # Let the socket hang around rather than having
- # it closed by garbage collection.
- conns.append(server.accept()[0])
- for sock in conns:
- sock.close()
-
- t = threading.Thread(target=serve)
- t.start()
- started.wait()
+ # If we keep sending data, at some point the buffers
+ # will be full and the call will block
+ buf = bytearray(8192)
+ def fill_buffer():
+ while True:
+ s.send(buf)
+ self.assertRaises((ssl.SSLWantWriteError,
+ ssl.SSLWantReadError), fill_buffer)
+
+ # Now read all the output and discard it
+ s.setblocking(True)
+ s.close()
+
+ def test_handshake_timeout(self):
+ # Issue #5103: SSL handshake must respect the socket timeout
+ server = socket.socket(socket.AF_INET)
+ host = "127.0.0.1"
+ port = support.bind_port(server)
+ started = threading.Event()
+ finish = False
+
+ def serve():
+ server.listen()
+ started.set()
+ conns = []
+ while not finish:
+ r, w, e = select.select([server], [], [], 0.1)
+ if server in r:
+ # Let the socket hang around rather than having
+ # it closed by garbage collection.
+ conns.append(server.accept()[0])
+ for sock in conns:
+ sock.close()
+
+ t = threading.Thread(target=serve)
+ t.start()
+ started.wait()
+ try:
try:
- try:
- c = socket.socket(socket.AF_INET)
- c.settimeout(0.2)
- c.connect((host, port))
- # Will attempt handshake and time out
- self.assertRaisesRegex(socket.timeout, "timed out",
- test_wrap_socket, c)
- finally:
- c.close()
- try:
- c = socket.socket(socket.AF_INET)
- c = test_wrap_socket(c)
- c.settimeout(0.2)
- # Will attempt handshake and time out
- self.assertRaisesRegex(socket.timeout, "timed out",
- c.connect, (host, port))
- finally:
- c.close()
+ c = socket.socket(socket.AF_INET)
+ c.settimeout(0.2)
+ c.connect((host, port))
+ # Will attempt handshake and time out
+ self.assertRaisesRegex(socket.timeout, "timed out",
+ test_wrap_socket, c)
finally:
- finish = True
- t.join()
- server.close()
-
- def test_server_accept(self):
- # Issue #16357: accept() on a SSLSocket created through
- # SSLContext.wrap_socket().
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = socket.socket(socket.AF_INET)
- host = "127.0.0.1"
- port = support.bind_port(server)
- server = context.wrap_socket(server, server_side=True)
- self.assertTrue(server.server_side)
-
- evt = threading.Event()
- remote = None
- peer = None
- def serve():
- nonlocal remote, peer
- server.listen()
- # Block on the accept and wait on the connection to close.
- evt.set()
- remote, peer = server.accept()
- remote.recv(1)
-
- t = threading.Thread(target=serve)
- t.start()
- # Client wait until server setup and perform a connect.
- evt.wait()
- client = context.wrap_socket(socket.socket())
- client.connect((host, port))
- client_addr = client.getsockname()
- client.close()
+ c.close()
+ try:
+ c = socket.socket(socket.AF_INET)
+ c = test_wrap_socket(c)
+ c.settimeout(0.2)
+ # Will attempt handshake and time out
+ self.assertRaisesRegex(socket.timeout, "timed out",
+ c.connect, (host, port))
+ finally:
+ c.close()
+ finally:
+ finish = True
t.join()
- remote.close()
server.close()
- # Sanity checks.
- self.assertIsInstance(remote, ssl.SSLSocket)
- self.assertEqual(peer, client_addr)
-
- def test_getpeercert_enotconn(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- with context.wrap_socket(socket.socket()) as sock:
- with self.assertRaises(OSError) as cm:
- sock.getpeercert()
- self.assertEqual(cm.exception.errno, errno.ENOTCONN)
-
- def test_do_handshake_enotconn(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- with context.wrap_socket(socket.socket()) as sock:
- with self.assertRaises(OSError) as cm:
- sock.do_handshake()
- self.assertEqual(cm.exception.errno, errno.ENOTCONN)
-
- def test_default_ciphers(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- try:
- # Force a set of weak ciphers on our client context
- context.set_ciphers("DES")
- except ssl.SSLError:
- self.skipTest("no DES cipher available")
- with ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_SSLv23,
- chatty=False) as server:
- with context.wrap_socket(socket.socket()) as s:
- with self.assertRaises(OSError):
- s.connect((HOST, server.port))
- self.assertIn("no shared cipher", server.conn_errors[0])
-
- def test_version_basic(self):
- """
- Basic tests for SSLSocket.version().
- More tests are done in the test_protocol_*() methods.
- """
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- with ThreadedEchoServer(CERTFILE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- chatty=False) as server:
- with context.wrap_socket(socket.socket()) as s:
- self.assertIs(s.version(), None)
- s.connect((HOST, server.port))
- self.assertEqual(s.version(), 'TLSv1')
- self.assertIs(s.version(), None)
- @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
- def test_default_ecdh_curve(self):
- # Issue #21015: elliptic curve-based Diffie Hellman key exchange
- # should be enabled by default on SSL contexts.
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.load_cert_chain(CERTFILE)
- # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
- # explicitly using the 'ECCdraft' cipher alias. Otherwise,
- # our default cipher list should prefer ECDH-based ciphers
- # automatically.
- if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
- context.set_ciphers("ECCdraft:ECDH")
- with ThreadedEchoServer(context=context) as server:
- with context.wrap_socket(socket.socket()) as s:
+ def test_server_accept(self):
+ # Issue #16357: accept() on a SSLSocket created through
+ # SSLContext.wrap_socket().
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = socket.socket(socket.AF_INET)
+ host = "127.0.0.1"
+ port = support.bind_port(server)
+ server = context.wrap_socket(server, server_side=True)
+ self.assertTrue(server.server_side)
+
+ evt = threading.Event()
+ remote = None
+ peer = None
+ def serve():
+ nonlocal remote, peer
+ server.listen()
+ # Block on the accept and wait on the connection to close.
+ evt.set()
+ remote, peer = server.accept()
+ remote.recv(1)
+
+ t = threading.Thread(target=serve)
+ t.start()
+ # Client wait until server setup and perform a connect.
+ evt.wait()
+ client = context.wrap_socket(socket.socket())
+ client.connect((host, port))
+ client_addr = client.getsockname()
+ client.close()
+ t.join()
+ remote.close()
+ server.close()
+ # Sanity checks.
+ self.assertIsInstance(remote, ssl.SSLSocket)
+ self.assertEqual(peer, client_addr)
+
+ def test_getpeercert_enotconn(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ with context.wrap_socket(socket.socket()) as sock:
+ with self.assertRaises(OSError) as cm:
+ sock.getpeercert()
+ self.assertEqual(cm.exception.errno, errno.ENOTCONN)
+
+ def test_do_handshake_enotconn(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ with context.wrap_socket(socket.socket()) as sock:
+ with self.assertRaises(OSError) as cm:
+ sock.do_handshake()
+ self.assertEqual(cm.exception.errno, errno.ENOTCONN)
+
+ def test_default_ciphers(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ try:
+ # Force a set of weak ciphers on our client context
+ context.set_ciphers("DES")
+ except ssl.SSLError:
+ self.skipTest("no DES cipher available")
+ with ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_SSLv23,
+ chatty=False) as server:
+ with context.wrap_socket(socket.socket()) as s:
+ with self.assertRaises(OSError):
s.connect((HOST, server.port))
- self.assertIn("ECDH", s.cipher()[0])
+ self.assertIn("no shared cipher", server.conn_errors[0])
- @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
- "'tls-unique' channel binding not available")
- def test_tls_unique_channel_binding(self):
- """Test tls-unique channel binding."""
- if support.verbose:
- sys.stdout.write("\n")
-
- server = ThreadedEchoServer(CERTFILE,
- certreqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1,
- cacerts=CERTFILE,
- chatty=True,
- connectionchatty=False)
- with server:
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
+ def test_version_basic(self):
+ """
+ Basic tests for SSLSocket.version().
+ More tests are done in the test_protocol_*() methods.
+ """
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ with ThreadedEchoServer(CERTFILE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ chatty=False) as server:
+ with context.wrap_socket(socket.socket()) as s:
+ self.assertIs(s.version(), None)
s.connect((HOST, server.port))
- # get the data
- cb_data = s.get_channel_binding("tls-unique")
- if support.verbose:
- sys.stdout.write(" got channel binding data: {0!r}\n"
- .format(cb_data))
-
- # check if it is sane
- self.assertIsNotNone(cb_data)
- self.assertEqual(len(cb_data), 12) # True for TLSv1
-
- # and compare with the peers version
- s.write(b"CB tls-unique\n")
- peer_data_repr = s.read().strip()
- self.assertEqual(peer_data_repr,
- repr(cb_data).encode("us-ascii"))
- s.close()
-
- # now, again
- s = test_wrap_socket(socket.socket(),
- server_side=False,
- certfile=CERTFILE,
- ca_certs=CERTFILE,
- cert_reqs=ssl.CERT_NONE,
- ssl_version=ssl.PROTOCOL_TLSv1)
+ self.assertEqual(s.version(), 'TLSv1')
+ self.assertIs(s.version(), None)
+
+ @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
+ def test_default_ecdh_curve(self):
+ # Issue #21015: elliptic curve-based Diffie Hellman key exchange
+ # should be enabled by default on SSL contexts.
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.load_cert_chain(CERTFILE)
+ # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
+ # explicitly using the 'ECCdraft' cipher alias. Otherwise,
+ # our default cipher list should prefer ECDH-based ciphers
+ # automatically.
+ if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
+ context.set_ciphers("ECCdraft:ECDH")
+ with ThreadedEchoServer(context=context) as server:
+ with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
- new_cb_data = s.get_channel_binding("tls-unique")
- if support.verbose:
- sys.stdout.write(" got another channel binding data: {0!r}\n"
- .format(new_cb_data))
- # is it really unique
- self.assertNotEqual(cb_data, new_cb_data)
- self.assertIsNotNone(cb_data)
- self.assertEqual(len(cb_data), 12) # True for TLSv1
- s.write(b"CB tls-unique\n")
- peer_data_repr = s.read().strip()
- self.assertEqual(peer_data_repr,
- repr(new_cb_data).encode("us-ascii"))
- s.close()
+ self.assertIn("ECDH", s.cipher()[0])
- def test_compression(self):
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- if support.verbose:
- sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
- self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
-
- @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
- "ssl.OP_NO_COMPRESSION needed for this test")
- def test_compression_disabled(self):
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- context.options |= ssl.OP_NO_COMPRESSION
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['compression'], None)
-
- def test_dh_params(self):
- # Check we can get a connection with ephemeral Diffie-Hellman
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- context.load_dh_params(DHFILE)
- context.set_ciphers("kEDH")
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- cipher = stats["cipher"][0]
- parts = cipher.split("-")
- if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
- self.fail("Non-DH cipher: " + cipher[0])
-
- def test_selected_alpn_protocol(self):
- # selected_alpn_protocol() is None unless ALPN is used.
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['client_alpn_protocol'], None)
+ @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
+ "'tls-unique' channel binding not available")
+ def test_tls_unique_channel_binding(self):
+ """Test tls-unique channel binding."""
+ if support.verbose:
+ sys.stdout.write("\n")
- @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
- def test_selected_alpn_protocol_if_server_uses_alpn(self):
- # selected_alpn_protocol() is None unless ALPN is used by the client.
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.load_verify_locations(CERTFILE)
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server = ThreadedEchoServer(CERTFILE,
+ certreqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1,
+ cacerts=CERTFILE,
+ chatty=True,
+ connectionchatty=False)
+ with server:
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ # get the data
+ cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got channel binding data: {0!r}\n"
+ .format(cb_data))
+
+ # check if it is sane
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+
+ # and compare with the peers version
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(cb_data).encode("us-ascii"))
+ s.close()
+
+ # now, again
+ s = test_wrap_socket(socket.socket(),
+ server_side=False,
+ certfile=CERTFILE,
+ ca_certs=CERTFILE,
+ cert_reqs=ssl.CERT_NONE,
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ s.connect((HOST, server.port))
+ new_cb_data = s.get_channel_binding("tls-unique")
+ if support.verbose:
+ sys.stdout.write(" got another channel binding data: {0!r}\n"
+ .format(new_cb_data))
+ # is it really unique
+ self.assertNotEqual(cb_data, new_cb_data)
+ self.assertIsNotNone(cb_data)
+ self.assertEqual(len(cb_data), 12) # True for TLSv1
+ s.write(b"CB tls-unique\n")
+ peer_data_repr = s.read().strip()
+ self.assertEqual(peer_data_repr,
+ repr(new_cb_data).encode("us-ascii"))
+ s.close()
+
+ def test_compression(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ if support.verbose:
+ sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
+ self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
+
+ @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
+ "ssl.OP_NO_COMPRESSION needed for this test")
+ def test_compression_disabled(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.options |= ssl.OP_NO_COMPRESSION
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['compression'], None)
+
+ def test_dh_params(self):
+ # Check we can get a connection with ephemeral Diffie-Hellman
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ context.load_dh_params(DHFILE)
+ context.set_ciphers("kEDH")
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ cipher = stats["cipher"][0]
+ parts = cipher.split("-")
+ if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
+ self.fail("Non-DH cipher: " + cipher[0])
+
+ def test_selected_alpn_protocol(self):
+ # selected_alpn_protocol() is None unless ALPN is used.
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
+ def test_selected_alpn_protocol_if_server_uses_alpn(self):
+ # selected_alpn_protocol() is None unless ALPN is used by the client.
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.load_verify_locations(CERTFILE)
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_alpn_protocols(['foo', 'bar'])
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_alpn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
+ def test_alpn_protocols(self):
+ server_protocols = ['foo', 'bar', 'milkshake']
+ protocol_tests = [
+ (['foo', 'bar'], 'foo'),
+ (['bar', 'foo'], 'foo'),
+ (['milkshake'], 'milkshake'),
+ (['http/3.0', 'http/4.0'], None)
+ ]
+ for client_protocols, expected in protocol_tests:
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
server_context.load_cert_chain(CERTFILE)
- server_context.set_alpn_protocols(['foo', 'bar'])
- stats = server_params_test(client_context, server_context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['client_alpn_protocol'], None)
-
- @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
- def test_alpn_protocols(self):
- server_protocols = ['foo', 'bar', 'milkshake']
- protocol_tests = [
- (['foo', 'bar'], 'foo'),
- (['bar', 'foo'], 'foo'),
- (['milkshake'], 'milkshake'),
- (['http/3.0', 'http/4.0'], None)
- ]
- for client_protocols, expected in protocol_tests:
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
- server_context.load_cert_chain(CERTFILE)
- server_context.set_alpn_protocols(server_protocols)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
- client_context.load_cert_chain(CERTFILE)
- client_context.set_alpn_protocols(client_protocols)
+ server_context.set_alpn_protocols(server_protocols)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_alpn_protocols(client_protocols)
- try:
- stats = server_params_test(client_context,
- server_context,
- chatty=True,
- connectionchatty=True)
- except ssl.SSLError as e:
- stats = e
-
- if (expected is None and IS_OPENSSL_1_1
- and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
- # OpenSSL 1.1.0 to 1.1.0e raises handshake error
- self.assertIsInstance(stats, ssl.SSLError)
- else:
- msg = "failed trying %s (s) and %s (c).\n" \
- "was expecting %s, but got %%s from the %%s" \
- % (str(server_protocols), str(client_protocols),
- str(expected))
- client_result = stats['client_alpn_protocol']
- self.assertEqual(client_result, expected,
- msg % (client_result, "client"))
- server_result = stats['server_alpn_protocols'][-1] \
- if len(stats['server_alpn_protocols']) else 'nothing'
- self.assertEqual(server_result, expected,
- msg % (server_result, "server"))
-
- def test_selected_npn_protocol(self):
- # selected_npn_protocol() is None unless NPN is used
- context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- context.load_cert_chain(CERTFILE)
- stats = server_params_test(context, context,
- chatty=True, connectionchatty=True)
- self.assertIs(stats['client_npn_protocol'], None)
-
- @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
- def test_npn_protocols(self):
- server_protocols = ['http/1.1', 'spdy/2']
- protocol_tests = [
- (['http/1.1', 'spdy/2'], 'http/1.1'),
- (['spdy/2', 'http/1.1'], 'http/1.1'),
- (['spdy/2', 'test'], 'spdy/2'),
- (['abc', 'def'], 'abc')
- ]
- for client_protocols, expected in protocol_tests:
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(CERTFILE)
- server_context.set_npn_protocols(server_protocols)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.load_cert_chain(CERTFILE)
- client_context.set_npn_protocols(client_protocols)
- stats = server_params_test(client_context, server_context,
- chatty=True, connectionchatty=True)
+ try:
+ stats = server_params_test(client_context,
+ server_context,
+ chatty=True,
+ connectionchatty=True)
+ except ssl.SSLError as e:
+ stats = e
+ if (expected is None and IS_OPENSSL_1_1
+ and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
+ # OpenSSL 1.1.0 to 1.1.0e raises handshake error
+ self.assertIsInstance(stats, ssl.SSLError)
+ else:
msg = "failed trying %s (s) and %s (c).\n" \
- "was expecting %s, but got %%s from the %%s" \
- % (str(server_protocols), str(client_protocols),
- str(expected))
- client_result = stats['client_npn_protocol']
- self.assertEqual(client_result, expected, msg % (client_result, "client"))
- server_result = stats['server_npn_protocols'][-1] \
- if len(stats['server_npn_protocols']) else 'nothing'
- self.assertEqual(server_result, expected, msg % (server_result, "server"))
-
- def sni_contexts(self):
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_alpn_protocol']
+ self.assertEqual(client_result, expected,
+ msg % (client_result, "client"))
+ server_result = stats['server_alpn_protocols'][-1] \
+ if len(stats['server_alpn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected,
+ msg % (server_result, "server"))
+
+ def test_selected_npn_protocol(self):
+ # selected_npn_protocol() is None unless NPN is used
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.load_cert_chain(CERTFILE)
+ stats = server_params_test(context, context,
+ chatty=True, connectionchatty=True)
+ self.assertIs(stats['client_npn_protocol'], None)
+
+ @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
+ def test_npn_protocols(self):
+ server_protocols = ['http/1.1', 'spdy/2']
+ protocol_tests = [
+ (['http/1.1', 'spdy/2'], 'http/1.1'),
+ (['spdy/2', 'http/1.1'], 'http/1.1'),
+ (['spdy/2', 'test'], 'spdy/2'),
+ (['abc', 'def'], 'abc')
+ ]
+ for client_protocols, expected in protocol_tests:
server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
- other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- other_context.load_cert_chain(SIGNED_CERTFILE2)
+ server_context.load_cert_chain(CERTFILE)
+ server_context.set_npn_protocols(server_protocols)
client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
- client_context.load_verify_locations(SIGNING_CA)
- return server_context, other_context, client_context
+ client_context.load_cert_chain(CERTFILE)
+ client_context.set_npn_protocols(client_protocols)
+ stats = server_params_test(client_context, server_context,
+ chatty=True, connectionchatty=True)
- def check_common_name(self, stats, name):
- cert = stats['peercert']
- self.assertIn((('commonName', name),), cert['subject'])
+ msg = "failed trying %s (s) and %s (c).\n" \
+ "was expecting %s, but got %%s from the %%s" \
+ % (str(server_protocols), str(client_protocols),
+ str(expected))
+ client_result = stats['client_npn_protocol']
+ self.assertEqual(client_result, expected, msg % (client_result, "client"))
+ server_result = stats['server_npn_protocols'][-1] \
+ if len(stats['server_npn_protocols']) else 'nothing'
+ self.assertEqual(server_result, expected, msg % (server_result, "server"))
+
+ def sni_contexts(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ other_context.load_cert_chain(SIGNED_CERTFILE2)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+ return server_context, other_context, client_context
+
+ def check_common_name(self, stats, name):
+ cert = stats['peercert']
+ self.assertIn((('commonName', name),), cert['subject'])
+
+ @needs_sni
+ def test_sni_callback(self):
+ calls = []
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def servername_cb(ssl_sock, server_name, initial_context):
+ calls.append((server_name, initial_context))
+ if server_name is not None:
+ ssl_sock.context = other_context
+ server_context.set_servername_callback(servername_cb)
+
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name='supermessage')
+ # The hostname was fetched properly, and the certificate was
+ # changed for the connection.
+ self.assertEqual(calls, [("supermessage", server_context)])
+ # CERTFILE4 was selected
+ self.check_common_name(stats, 'fakehostname')
+
+ calls = []
+ # The callback is called with server_name=None
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name=None)
+ self.assertEqual(calls, [(None, server_context)])
+ self.check_common_name(stats, 'localhost')
+
+ # Check disabling the callback
+ calls = []
+ server_context.set_servername_callback(None)
+
+ stats = server_params_test(client_context, server_context,
+ chatty=True,
+ sni_name='notfunny')
+ # Certificate didn't change
+ self.check_common_name(stats, 'localhost')
+ self.assertEqual(calls, [])
- @needs_sni
- def test_sni_callback(self):
- calls = []
- server_context, other_context, client_context = self.sni_contexts()
+ @needs_sni
+ def test_sni_callback_alert(self):
+ # Returning a TLS alert is reflected to the connecting client
+ server_context, other_context, client_context = self.sni_contexts()
- def servername_cb(ssl_sock, server_name, initial_context):
- calls.append((server_name, initial_context))
- if server_name is not None:
- ssl_sock.context = other_context
- server_context.set_servername_callback(servername_cb)
+ def cb_returning_alert(ssl_sock, server_name, initial_context):
+ return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
+ server_context.set_servername_callback(cb_returning_alert)
+ with self.assertRaises(ssl.SSLError) as cm:
stats = server_params_test(client_context, server_context,
- chatty=True,
+ chatty=False,
sni_name='supermessage')
- # The hostname was fetched properly, and the certificate was
- # changed for the connection.
- self.assertEqual(calls, [("supermessage", server_context)])
- # CERTFILE4 was selected
- self.check_common_name(stats, 'fakehostname')
-
- calls = []
- # The callback is called with server_name=None
+ self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
+
+ @needs_sni
+ def test_sni_callback_raising(self):
+ # Raising fails the connection with a TLS handshake failure alert.
+ server_context, other_context, client_context = self.sni_contexts()
+
+ def cb_raising(ssl_sock, server_name, initial_context):
+ 1/0
+ server_context.set_servername_callback(cb_raising)
+
+ with self.assertRaises(ssl.SSLError) as cm, \
+ support.captured_stderr() as stderr:
stats = server_params_test(client_context, server_context,
- chatty=True,
- sni_name=None)
- self.assertEqual(calls, [(None, server_context)])
- self.check_common_name(stats, 'localhost')
+ chatty=False,
+ sni_name='supermessage')
+ self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
+ self.assertIn("ZeroDivisionError", stderr.getvalue())
+
+ @needs_sni
+ def test_sni_callback_wrong_return_type(self):
+ # Returning the wrong return type terminates the TLS connection
+ # with an internal error alert.
+ server_context, other_context, client_context = self.sni_contexts()
- # Check disabling the callback
- calls = []
- server_context.set_servername_callback(None)
+ def cb_wrong_return_type(ssl_sock, server_name, initial_context):
+ return "foo"
+ server_context.set_servername_callback(cb_wrong_return_type)
+ with self.assertRaises(ssl.SSLError) as cm, \
+ support.captured_stderr() as stderr:
stats = server_params_test(client_context, server_context,
- chatty=True,
- sni_name='notfunny')
- # Certificate didn't change
- self.check_common_name(stats, 'localhost')
- self.assertEqual(calls, [])
-
- @needs_sni
- def test_sni_callback_alert(self):
- # Returning a TLS alert is reflected to the connecting client
- server_context, other_context, client_context = self.sni_contexts()
-
- def cb_returning_alert(ssl_sock, server_name, initial_context):
- return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
- server_context.set_servername_callback(cb_returning_alert)
-
- with self.assertRaises(ssl.SSLError) as cm:
- stats = server_params_test(client_context, server_context,
- chatty=False,
- sni_name='supermessage')
- self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
-
- @needs_sni
- def test_sni_callback_raising(self):
- # Raising fails the connection with a TLS handshake failure alert.
- server_context, other_context, client_context = self.sni_contexts()
-
- def cb_raising(ssl_sock, server_name, initial_context):
- 1/0
- server_context.set_servername_callback(cb_raising)
-
- with self.assertRaises(ssl.SSLError) as cm, \
- support.captured_stderr() as stderr:
- stats = server_params_test(client_context, server_context,
- chatty=False,
- sni_name='supermessage')
- self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
- self.assertIn("ZeroDivisionError", stderr.getvalue())
-
- @needs_sni
- def test_sni_callback_wrong_return_type(self):
- # Returning the wrong return type terminates the TLS connection
- # with an internal error alert.
- server_context, other_context, client_context = self.sni_contexts()
-
- def cb_wrong_return_type(ssl_sock, server_name, initial_context):
- return "foo"
- server_context.set_servername_callback(cb_wrong_return_type)
-
- with self.assertRaises(ssl.SSLError) as cm, \
- support.captured_stderr() as stderr:
- stats = server_params_test(client_context, server_context,
- chatty=False,
- sni_name='supermessage')
- self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
- self.assertIn("TypeError", stderr.getvalue())
-
- def test_shared_ciphers(self):
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
- client_context.load_verify_locations(SIGNING_CA)
- if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
- client_context.set_ciphers("AES128:AES256")
- server_context.set_ciphers("AES256")
- alg1 = "AES256"
- alg2 = "AES-256"
- else:
- client_context.set_ciphers("AES:3DES")
- server_context.set_ciphers("3DES")
- alg1 = "3DES"
- alg2 = "DES-CBC3"
-
- stats = server_params_test(client_context, server_context)
- ciphers = stats['server_shared_ciphers'][0]
- self.assertGreater(len(ciphers), 0)
- for name, tls_version, bits in ciphers:
- if not alg1 in name.split("-") and alg2 not in name:
- self.fail(name)
-
- def test_read_write_after_close_raises_valuerror(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = ThreadedEchoServer(context=context, chatty=False)
-
- with server:
- s = context.wrap_socket(socket.socket())
+ chatty=False,
+ sni_name='supermessage')
+ self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
+ self.assertIn("TypeError", stderr.getvalue())
+
+ def test_shared_ciphers(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+ if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
+ client_context.set_ciphers("AES128:AES256")
+ server_context.set_ciphers("AES256")
+ alg1 = "AES256"
+ alg2 = "AES-256"
+ else:
+ client_context.set_ciphers("AES:3DES")
+ server_context.set_ciphers("3DES")
+ alg1 = "3DES"
+ alg2 = "DES-CBC3"
+
+ stats = server_params_test(client_context, server_context)
+ ciphers = stats['server_shared_ciphers'][0]
+ self.assertGreater(len(ciphers), 0)
+ for name, tls_version, bits in ciphers:
+ if not alg1 in name.split("-") and alg2 not in name:
+ self.fail(name)
+
+ def test_read_write_after_close_raises_valuerror(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+
+ with server:
+ s = context.wrap_socket(socket.socket())
+ s.connect((HOST, server.port))
+ s.close()
+
+ self.assertRaises(ValueError, s.read, 1024)
+ self.assertRaises(ValueError, s.write, b'hello')
+
+ def test_sendfile(self):
+ TEST_DATA = b"x" * 512
+ with open(support.TESTFN, 'wb') as f:
+ f.write(TEST_DATA)
+ self.addCleanup(support.unlink, support.TESTFN)
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
- s.close()
+ with open(support.TESTFN, 'rb') as file:
+ s.sendfile(file)
+ self.assertEqual(s.recv(1024), TEST_DATA)
+
+ def test_session(self):
+ server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ server_context.load_cert_chain(SIGNED_CERTFILE)
+ client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ client_context.verify_mode = ssl.CERT_REQUIRED
+ client_context.load_verify_locations(SIGNING_CA)
+
+ # first connection without session
+ stats = server_params_test(client_context, server_context)
+ session = stats['session']
+ self.assertTrue(session.id)
+ self.assertGreater(session.time, 0)
+ self.assertGreater(session.timeout, 0)
+ self.assertTrue(session.has_ticket)
+ if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
+ self.assertGreater(session.ticket_lifetime_hint, 0)
+ self.assertFalse(stats['session_reused'])
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 1)
+ self.assertEqual(sess_stat['hits'], 0)
+
+ # reuse session
+ stats = server_params_test(client_context, server_context, session=session)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 2)
+ self.assertEqual(sess_stat['hits'], 1)
+ self.assertTrue(stats['session_reused'])
+ session2 = stats['session']
+ self.assertEqual(session2.id, session.id)
+ self.assertEqual(session2, session)
+ self.assertIsNot(session2, session)
+ self.assertGreaterEqual(session2.time, session.time)
+ self.assertGreaterEqual(session2.timeout, session.timeout)
+
+ # another one without session
+ stats = server_params_test(client_context, server_context)
+ self.assertFalse(stats['session_reused'])
+ session3 = stats['session']
+ self.assertNotEqual(session3.id, session.id)
+ self.assertNotEqual(session3, session)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 3)
+ self.assertEqual(sess_stat['hits'], 1)
+
+ # reuse session again
+ stats = server_params_test(client_context, server_context, session=session)
+ self.assertTrue(stats['session_reused'])
+ session4 = stats['session']
+ self.assertEqual(session4.id, session.id)
+ self.assertEqual(session4, session)
+ self.assertGreaterEqual(session4.time, session.time)
+ self.assertGreaterEqual(session4.timeout, session.timeout)
+ sess_stat = server_context.session_stats()
+ self.assertEqual(sess_stat['accept'], 4)
+ self.assertEqual(sess_stat['hits'], 2)
+
+ def test_session_handling(self):
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(CERTFILE)
+ context.load_cert_chain(CERTFILE)
+
+ context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context2.verify_mode = ssl.CERT_REQUIRED
+ context2.load_verify_locations(CERTFILE)
+ context2.load_cert_chain(CERTFILE)
+
+ server = ThreadedEchoServer(context=context, chatty=False)
+ with server:
+ with context.wrap_socket(socket.socket()) as s:
+ # session is None before handshake
+ self.assertEqual(s.session, None)
+ self.assertEqual(s.session_reused, None)
+ s.connect((HOST, server.port))
+ session = s.session
+ self.assertTrue(session)
+ with self.assertRaises(TypeError) as e:
+ s.session = object
+ self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
- self.assertRaises(ValueError, s.read, 1024)
- self.assertRaises(ValueError, s.write, b'hello')
-
- def test_sendfile(self):
- TEST_DATA = b"x" * 512
- with open(support.TESTFN, 'wb') as f:
- f.write(TEST_DATA)
- self.addCleanup(support.unlink, support.TESTFN)
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
- server = ThreadedEchoServer(context=context, chatty=False)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- s.connect((HOST, server.port))
- with open(support.TESTFN, 'rb') as file:
- s.sendfile(file)
- self.assertEqual(s.recv(1024), TEST_DATA)
+ with context.wrap_socket(socket.socket()) as s:
+ s.connect((HOST, server.port))
+ # cannot set session after handshake
+ with self.assertRaises(ValueError) as e:
+ s.session = session
+ self.assertEqual(str(e.exception),
+ 'Cannot set session after handshake.')
- def test_session(self):
- server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- server_context.load_cert_chain(SIGNED_CERTFILE)
- client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
- client_context.verify_mode = ssl.CERT_REQUIRED
- client_context.load_verify_locations(SIGNING_CA)
-
- # first connection without session
- stats = server_params_test(client_context, server_context)
- session = stats['session']
- self.assertTrue(session.id)
- self.assertGreater(session.time, 0)
- self.assertGreater(session.timeout, 0)
- self.assertTrue(session.has_ticket)
- if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
- self.assertGreater(session.ticket_lifetime_hint, 0)
- self.assertFalse(stats['session_reused'])
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 1)
- self.assertEqual(sess_stat['hits'], 0)
-
- # reuse session
- stats = server_params_test(client_context, server_context, session=session)
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 2)
- self.assertEqual(sess_stat['hits'], 1)
- self.assertTrue(stats['session_reused'])
- session2 = stats['session']
- self.assertEqual(session2.id, session.id)
- self.assertEqual(session2, session)
- self.assertIsNot(session2, session)
- self.assertGreaterEqual(session2.time, session.time)
- self.assertGreaterEqual(session2.timeout, session.timeout)
-
- # another one without session
- stats = server_params_test(client_context, server_context)
- self.assertFalse(stats['session_reused'])
- session3 = stats['session']
- self.assertNotEqual(session3.id, session.id)
- self.assertNotEqual(session3, session)
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 3)
- self.assertEqual(sess_stat['hits'], 1)
-
- # reuse session again
- stats = server_params_test(client_context, server_context, session=session)
- self.assertTrue(stats['session_reused'])
- session4 = stats['session']
- self.assertEqual(session4.id, session.id)
- self.assertEqual(session4, session)
- self.assertGreaterEqual(session4.time, session.time)
- self.assertGreaterEqual(session4.timeout, session.timeout)
- sess_stat = server_context.session_stats()
- self.assertEqual(sess_stat['accept'], 4)
- self.assertEqual(sess_stat['hits'], 2)
-
- def test_session_handling(self):
- context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context.verify_mode = ssl.CERT_REQUIRED
- context.load_verify_locations(CERTFILE)
- context.load_cert_chain(CERTFILE)
-
- context2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
- context2.verify_mode = ssl.CERT_REQUIRED
- context2.load_verify_locations(CERTFILE)
- context2.load_cert_chain(CERTFILE)
-
- server = ThreadedEchoServer(context=context, chatty=False)
- with server:
- with context.wrap_socket(socket.socket()) as s:
- # session is None before handshake
- self.assertEqual(s.session, None)
- self.assertEqual(s.session_reused, None)
- s.connect((HOST, server.port))
- session = s.session
- self.assertTrue(session)
- with self.assertRaises(TypeError) as e:
- s.session = object
- self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
+ with context.wrap_socket(socket.socket()) as s:
+ # can set session before handshake and before the
+ # connection was established
+ s.session = session
+ s.connect((HOST, server.port))
+ self.assertEqual(s.session.id, session.id)
+ self.assertEqual(s.session, session)
+ self.assertEqual(s.session_reused, True)
- with context.wrap_socket(socket.socket()) as s:
- s.connect((HOST, server.port))
- # cannot set session after handshake
- with self.assertRaises(ValueError) as e:
- s.session = session
- self.assertEqual(str(e.exception),
- 'Cannot set session after handshake.')
-
- with context.wrap_socket(socket.socket()) as s:
- # can set session before handshake and before the
- # connection was established
+ with context2.wrap_socket(socket.socket()) as s:
+ # cannot re-use session with a different SSLContext
+ with self.assertRaises(ValueError) as e:
s.session = session
s.connect((HOST, server.port))
- self.assertEqual(s.session.id, session.id)
- self.assertEqual(s.session, session)
- self.assertEqual(s.session_reused, True)
-
- with context2.wrap_socket(socket.socket()) as s:
- # cannot re-use session with a different SSLContext
- with self.assertRaises(ValueError) as e:
- s.session = session
- s.connect((HOST, server.port))
- self.assertEqual(str(e.exception),
- 'Session refers to a different SSLContext.')
+ self.assertEqual(str(e.exception),
+ 'Session refers to a different SSLContext.')
def test_main(verbose=False):
@@ -3610,22 +3603,17 @@ def test_main(verbose=False):
tests = [
ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
- SimpleBackgroundTests,
+ SimpleBackgroundTests, ThreadedTests,
]
if support.is_resource_enabled('network'):
tests.append(NetworkedTests)
- if _have_threads:
- thread_info = support.threading_setup()
- if thread_info:
- tests.append(ThreadedTests)
-
+ thread_info = support.threading_setup()
try:
support.run_unittest(*tests)
finally:
- if _have_threads:
- support.threading_cleanup(*thread_info)
+ support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py
index b7079b1..10ef87b 100644
--- a/Lib/test/test_subprocess.py
+++ b/Lib/test/test_subprocess.py
@@ -14,6 +14,7 @@ import selectors
import sysconfig
import select
import shutil
+import threading
import gc
import textwrap
@@ -25,11 +26,6 @@ else:
import ctypes.util
try:
- import threading
-except ImportError:
- threading = None
-
-try:
import _testcapi
except ImportError:
_testcapi = None
@@ -1196,7 +1192,6 @@ class ProcessTestCase(BaseTestCase):
self.assertEqual(stderr, "")
self.assertEqual(proc.returncode, 0)
- @unittest.skipIf(threading is None, "threading required")
def test_double_close_on_error(self):
# Issue #18851
fds = []
@@ -1226,7 +1221,6 @@ class ProcessTestCase(BaseTestCase):
if exc is not None:
raise exc
- @unittest.skipIf(threading is None, "threading required")
def test_threadsafe_wait(self):
"""Issue21291: Popen.wait() needs to be threadsafe for returncode."""
proc = subprocess.Popen([sys.executable, '-c',
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 3844812..04550e5 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -10,15 +10,12 @@ import codecs
import gc
import sysconfig
import locale
+import threading
# count the number of test runs, used to create unique
# strings to intern in test_intern()
numruns = 0
-try:
- import threading
-except ImportError:
- threading = None
class SysModuleTest(unittest.TestCase):
@@ -172,7 +169,6 @@ class SysModuleTest(unittest.TestCase):
sys.setcheckinterval(n)
self.assertEqual(sys.getcheckinterval(), n)
- @unittest.skipUnless(threading, 'Threading required for this test.')
def test_switchinterval(self):
self.assertRaises(TypeError, sys.setswitchinterval)
self.assertRaises(TypeError, sys.setswitchinterval, "a")
@@ -348,21 +344,8 @@ class SysModuleTest(unittest.TestCase):
)
# sys._current_frames() is a CPython-only gimmick.
- def test_current_frames(self):
- have_threads = True
- try:
- import _thread
- except ImportError:
- have_threads = False
-
- if have_threads:
- self.current_frames_with_threads()
- else:
- self.current_frames_without_threads()
-
- # Test sys._current_frames() in a WITH_THREADS build.
@test.support.reap_threads
- def current_frames_with_threads(self):
+ def test_current_frames(self):
import threading
import traceback
@@ -426,15 +409,6 @@ class SysModuleTest(unittest.TestCase):
leave_g.set()
t.join()
- # Test sys._current_frames() when thread support doesn't exist.
- def current_frames_without_threads(self):
- # Not much happens here: there is only one thread, with artificial
- # "thread id" 0.
- d = sys._current_frames()
- self.assertEqual(len(d), 1)
- self.assertIn(0, d)
- self.assertTrue(d[0] is sys._getframe())
-
def test_attributes(self):
self.assertIsInstance(sys.api_version, int)
self.assertIsInstance(sys.argv, list)
@@ -516,8 +490,6 @@ class SysModuleTest(unittest.TestCase):
if not sys.platform.startswith('win'):
self.assertIsInstance(sys.abiflags, str)
- @unittest.skipUnless(hasattr(sys, 'thread_info'),
- 'Threading required for this test.')
def test_thread_info(self):
info = sys.thread_info
self.assertEqual(len(info), 3)
diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py
index 51d82e1..414c328 100644
--- a/Lib/test/test_telnetlib.py
+++ b/Lib/test/test_telnetlib.py
@@ -1,11 +1,11 @@
import socket
import selectors
import telnetlib
+import threading
import contextlib
from test import support
import unittest
-threading = support.import_module('threading')
HOST = support.HOST
diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py
index 75c66b0..035344b 100644
--- a/Lib/test/test_threaded_import.py
+++ b/Lib/test/test_threaded_import.py
@@ -11,12 +11,12 @@ import importlib
import sys
import time
import shutil
+import threading
import unittest
from unittest import mock
from test.support import (
verbose, import_module, run_unittest, TESTFN, reap_threads,
forget, unlink, rmtree, start_threads)
-threading = import_module('threading')
def task(N, done, done_tasks, errors):
try:
diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py
index b742036..f3d4ba3 100644
--- a/Lib/test/test_threadedtempfile.py
+++ b/Lib/test/test_threadedtempfile.py
@@ -19,9 +19,9 @@ FILES_PER_THREAD = 50
import tempfile
from test.support import start_threads, import_module
-threading = import_module('threading')
import unittest
import io
+import threading
from traceback import print_exc
startEvent = threading.Event()
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index 800d26f..912eb3f 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -9,8 +9,8 @@ from test.support.script_helper import assert_python_ok, assert_python_failure
import random
import sys
-_thread = import_module('_thread')
-threading = import_module('threading')
+import _thread
+import threading
import time
import unittest
import weakref
diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py
index 4092cf3..984f8dd 100644
--- a/Lib/test/test_threading_local.py
+++ b/Lib/test/test_threading_local.py
@@ -5,8 +5,8 @@ import weakref
import gc
# Modules under test
-_thread = support.import_module('_thread')
-threading = support.import_module('threading')
+import _thread
+import threading
import _threading_local
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
index 810ec37..42323b9 100644
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -9,10 +9,6 @@ import sysconfig
import time
import unittest
try:
- import threading
-except ImportError:
- threading = None
-try:
import _testcapi
except ImportError:
_testcapi = None
diff --git a/Lib/test/test_tracemalloc.py b/Lib/test/test_tracemalloc.py
index 742259b..780942a 100644
--- a/Lib/test/test_tracemalloc.py
+++ b/Lib/test/test_tracemalloc.py
@@ -7,10 +7,7 @@ from unittest.mock import patch
from test.support.script_helper import (assert_python_ok, assert_python_failure,
interpreter_requires_environment)
from test import support
-try:
- import threading
-except ImportError:
- threading = None
+
try:
import _testcapi
except ImportError:
diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py
index f83f9cc..741d136 100644
--- a/Lib/test/test_urllib2_localnet.py
+++ b/Lib/test/test_urllib2_localnet.py
@@ -4,13 +4,12 @@ import email
import urllib.parse
import urllib.request
import http.server
+import threading
import unittest
import hashlib
from test import support
-threading = support.import_module('threading')
-
try:
import ssl
except ImportError:
@@ -276,7 +275,6 @@ class FakeProxyHandler(http.server.BaseHTTPRequestHandler):
# Test cases
-@unittest.skipUnless(threading, "Threading required for this test.")
class BasicAuthTests(unittest.TestCase):
USER = "testUser"
PASSWD = "testPass"
@@ -317,7 +315,6 @@ class BasicAuthTests(unittest.TestCase):
self.assertRaises(urllib.error.HTTPError, urllib.request.urlopen, self.server_url)
-@unittest.skipUnless(threading, "Threading required for this test.")
class ProxyAuthTests(unittest.TestCase):
URL = "http://localhost"
@@ -439,7 +436,6 @@ def GetRequestHandler(responses):
return FakeHTTPRequestHandler
-@unittest.skipUnless(threading, "Threading required for this test.")
class TestUrlopen(unittest.TestCase):
"""Tests urllib.request.urlopen using the network.
diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py
index 2691632..aac010b 100644
--- a/Lib/test/test_venv.py
+++ b/Lib/test/test_venv.py
@@ -15,15 +15,10 @@ import sys
import tempfile
from test.support import (captured_stdout, captured_stderr,
can_symlink, EnvironmentVarGuard, rmtree)
+import threading
import unittest
import venv
-
-try:
- import threading
-except ImportError:
- threading = None
-
try:
import ctypes
except ImportError:
@@ -420,8 +415,6 @@ class EnsurePipTest(BaseTest):
if not system_site_packages:
self.assert_pip_not_installed()
- @unittest.skipUnless(threading, 'some dependencies of pip import threading'
- ' module unconditionally')
# Issue #26610: pip/pep425tags.py requires ctypes
@unittest.skipUnless(ctypes, 'pip requires ctypes')
def test_with_pip(self):
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
index 43cf2c0..0384a9f 100644
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -6,6 +6,7 @@ import weakref
import operator
import contextlib
import copy
+import threading
import time
from test import support
@@ -78,7 +79,6 @@ def collect_in_thread(period=0.0001):
"""
Ensure GC collections happen in a different thread, at a high frequency.
"""
- threading = support.import_module('threading')
please_stop = False
def collect():
diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py
index a609eef..f2b496a 100644
--- a/Lib/test/test_xmlrpc.py
+++ b/Lib/test/test_xmlrpc.py
@@ -10,6 +10,7 @@ import xmlrpc.server
import http.client
import http, http.server
import socket
+import threading
import re
import io
import contextlib
@@ -19,10 +20,6 @@ try:
import gzip
except ImportError:
gzip = None
-try:
- import threading
-except ImportError:
- threading = None
alist = [{'astring': 'foo@bar.baz.spam',
'afloat': 7283.43,
@@ -307,7 +304,6 @@ class XMLRPCTestCase(unittest.TestCase):
except OSError:
self.assertTrue(has_ssl)
- @unittest.skipUnless(threading, "Threading required for this test.")
def test_keepalive_disconnect(self):
class RequestHandler(http.server.BaseHTTPRequestHandler):
protocol_version = "HTTP/1.1"
@@ -747,7 +743,6 @@ def make_request_and_skipIf(condition, reason):
return make_request_and_skip
return decorator
-@unittest.skipUnless(threading, 'Threading required for this test.')
class BaseServerTestCase(unittest.TestCase):
requestHandler = None
request_count = 1
@@ -1206,7 +1201,6 @@ class FailingMessageClass(http.client.HTTPMessage):
return super().get(key, failobj)
-@unittest.skipUnless(threading, 'Threading required for this test.')
class FailingServerTestCase(unittest.TestCase):
def setUp(self):
self.evt = threading.Event()
diff --git a/Lib/threading.py b/Lib/threading.py
index 06dbc68..e4bf974 100644
--- a/Lib/threading.py
+++ b/Lib/threading.py
@@ -990,38 +990,12 @@ class Thread:
def _delete(self):
"Remove current thread from the dict of currently running threads."
-
- # Notes about running with _dummy_thread:
- #
- # Must take care to not raise an exception if _dummy_thread is being
- # used (and thus this module is being used as an instance of
- # dummy_threading). _dummy_thread.get_ident() always returns 1 since
- # there is only one thread if _dummy_thread is being used. Thus
- # len(_active) is always <= 1 here, and any Thread instance created
- # overwrites the (if any) thread currently registered in _active.
- #
- # An instance of _MainThread is always created by 'threading'. This
- # gets overwritten the instant an instance of Thread is created; both
- # threads return 1 from _dummy_thread.get_ident() and thus have the
- # same key in the dict. So when the _MainThread instance created by
- # 'threading' tries to clean itself up when atexit calls this method
- # it gets a KeyError if another Thread instance was created.
- #
- # This all means that KeyError from trying to delete something from
- # _active if dummy_threading is being used is a red herring. But
- # since it isn't if dummy_threading is *not* being used then don't
- # hide the exception.
-
- try:
- with _active_limbo_lock:
- del _active[get_ident()]
- # There must not be any python code between the previous line
- # and after the lock is released. Otherwise a tracing function
- # could try to acquire the lock again in the same thread, (in
- # current_thread()), and would block.
- except KeyError:
- if 'dummy_threading' not in _sys.modules:
- raise
+ with _active_limbo_lock:
+ del _active[get_ident()]
+ # There must not be any python code between the previous line
+ # and after the lock is released. Otherwise a tracing function
+ # could try to acquire the lock again in the same thread, (in
+ # current_thread()), and would block.
def join(self, timeout=None):
"""Wait until the thread terminates.
diff --git a/Lib/trace.py b/Lib/trace.py
index e443edd..48a1d1b 100755
--- a/Lib/trace.py
+++ b/Lib/trace.py
@@ -61,21 +61,15 @@ import dis
import pickle
from time import monotonic as _time
-try:
- import threading
-except ImportError:
- _settrace = sys.settrace
-
- def _unsettrace():
- sys.settrace(None)
-else:
- def _settrace(func):
- threading.settrace(func)
- sys.settrace(func)
-
- def _unsettrace():
- sys.settrace(None)
- threading.settrace(None)
+import threading
+
+def _settrace(func):
+ threading.settrace(func)
+ sys.settrace(func)
+
+def _unsettrace():
+ sys.settrace(None)
+ threading.settrace(None)
PRAGMA_NOCOVER = "#pragma NO COVER"
diff --git a/Lib/zipfile.py b/Lib/zipfile.py
index cc46a6c..37ce328 100644
--- a/Lib/zipfile.py
+++ b/Lib/zipfile.py
@@ -12,11 +12,7 @@ import stat
import shutil
import struct
import binascii
-
-try:
- import threading
-except ImportError:
- import dummy_threading as threading
+import threading
try:
import zlib # We may need its compression method