diff options
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/support/__init__.py | 39 | ||||
-rw-r--r-- | Lib/test/test_bz2.py | 6 | ||||
-rw-r--r-- | Lib/test/test_capi.py | 14 | ||||
-rw-r--r-- | Lib/test/test_gc.py | 12 | ||||
-rw-r--r-- | Lib/test/test_io.py | 30 | ||||
-rw-r--r-- | Lib/test/test_threaded_import.py | 10 | ||||
-rw-r--r-- | Lib/test/test_threadedtempfile.py | 32 | ||||
-rw-r--r-- | Lib/test/test_threading_local.py | 11 |
8 files changed, 74 insertions, 80 deletions
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 10c48b4..75fff21 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -6,6 +6,7 @@ if __name__ != 'test.support': import collections.abc import contextlib import errno +import faulthandler import fnmatch import functools import gc @@ -96,7 +97,7 @@ __all__ = [ # logging "TestHandler", # threads - "threading_setup", "threading_cleanup", + "threading_setup", "threading_cleanup", "reap_threads", "start_threads", # miscellaneous "check_warnings", "EnvironmentVarGuard", "run_with_locale", "swap_item", "swap_attr", "Matcher", "set_memlimit", "SuppressCrashReport", "sortdict", @@ -1941,6 +1942,42 @@ def reap_children(): break @contextlib.contextmanager +def start_threads(threads, unlock=None): + threads = list(threads) + started = [] + try: + try: + for t in threads: + t.start() + started.append(t) + except: + if verbose: + print("Can't start %d threads, only %d threads started" % + (len(threads), len(started))) + raise + yield + finally: + try: + if unlock: + unlock() + endtime = starttime = time.time() + for timeout in range(1, 16): + endtime += 60 + for t in started: + t.join(max(endtime - time.time(), 0.01)) + started = [t for t in started if t.isAlive()] + if not started: + break + if verbose: + print('Unable to join %d threads during a period of ' + '%d minutes' % (len(started), timeout)) + finally: + started = [t for t in started if t.isAlive()] + if started: + faulthandler.dump_traceback(sys.stdout) + raise AssertionError('Unable to join %d threads' % len(started)) + +@contextlib.contextmanager def swap_attr(obj, attr, new_val): """Temporary swap out an attribute with a new object. diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index 1535e8e..beef275 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -493,10 +493,8 @@ class BZ2FileTest(BaseTest): for i in range(5): f.write(data) threads = [threading.Thread(target=comp) for i in range(nthreads)] - for t in threads: - t.start() - for t in threads: - t.join() + with support.start_threads(threads): + pass def testWithoutThreading(self): module = support.import_fresh_module("bz2", blocked=("threading",)) diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py index ba7c38d..36c62376 100644 --- a/Lib/test/test_capi.py +++ b/Lib/test/test_capi.py @@ -202,15 +202,11 @@ class TestPendingCalls(unittest.TestCase): context.lock = threading.Lock() context.event = threading.Event() - for i in range(context.nThreads): - t = threading.Thread(target=self.pendingcalls_thread, args = (context,)) - t.start() - threads.append(t) - - self.pendingcalls_wait(context.l, n, context) - - for t in threads: - t.join() + threads = [threading.Thread(target=self.pendingcalls_thread, + args=(context,)) + for i in range(context.nThreads)] + with support.start_threads(threads): + self.pendingcalls_wait(context.l, n, context) def pendingcalls_thread(self, context): try: diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py index c025512..2ac1d4b 100644 --- a/Lib/test/test_gc.py +++ b/Lib/test/test_gc.py @@ -1,6 +1,6 @@ import unittest from test.support import (verbose, refcount_test, run_unittest, - strip_python_stderr, cpython_only) + strip_python_stderr, cpython_only, start_threads) from test.script_helper import assert_python_ok, make_script, temp_dir import sys @@ -397,19 +397,13 @@ class GCTests(unittest.TestCase): old_switchinterval = sys.getswitchinterval() sys.setswitchinterval(1e-5) try: - exit = False + exit = [] threads = [] for i in range(N_THREADS): t = threading.Thread(target=run_thread) threads.append(t) - try: - for t in threads: - t.start() - finally: + with start_threads(threads, lambda: exit.append(1)): time.sleep(1.0) - exit = True - for t in threads: - t.join() finally: sys.setswitchinterval(old_switchinterval) gc.collect() diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index ec19562..95277d9 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -1070,11 +1070,8 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests): errors.append(e) raise threads = [threading.Thread(target=f) for x in range(20)] - for t in threads: - t.start() - time.sleep(0.02) # yield - for t in threads: - t.join() + with support.start_threads(threads): + time.sleep(0.02) # yield self.assertFalse(errors, "the following exceptions were caught: %r" % errors) s = b''.join(results) @@ -1393,11 +1390,8 @@ class BufferedWriterTest(unittest.TestCase, CommonBufferedTests): errors.append(e) raise threads = [threading.Thread(target=f) for x in range(20)] - for t in threads: - t.start() - time.sleep(0.02) # yield - for t in threads: - t.join() + with support.start_threads(threads): + time.sleep(0.02) # yield self.assertFalse(errors, "the following exceptions were caught: %r" % errors) bufio.close() @@ -2691,14 +2685,10 @@ class TextIOWrapperTest(unittest.TestCase): text = "Thread%03d\n" % n event.wait() f.write(text) - threads = [threading.Thread(target=lambda n=x: run(n)) + threads = [threading.Thread(target=run, args=(x,)) for x in range(20)] - for t in threads: - t.start() - time.sleep(0.02) - event.set() - for t in threads: - t.join() + with support.start_threads(threads, event.set): + time.sleep(0.02) with self.open(support.TESTFN) as f: content = f.read() for n in range(20): @@ -3402,11 +3392,11 @@ class SignalsTest(unittest.TestCase): # handlers, which in this case will invoke alarm_interrupt(). signal.alarm(1) try: - self.assertRaises(ZeroDivisionError, - wio.write, item * (support.PIPE_MAX_SIZE // len(item) + 1)) + with self.assertRaises(ZeroDivisionError): + wio.write(item * (support.PIPE_MAX_SIZE // len(item) + 1)) finally: signal.alarm(0) - t.join() + t.join() # We got one byte, get another one and check that it isn't a # repeat of the first one. read_results.append(os.read(r, 1)) diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py index 192fa08..4be615a 100644 --- a/Lib/test/test_threaded_import.py +++ b/Lib/test/test_threaded_import.py @@ -14,7 +14,7 @@ import shutil import unittest from test.support import ( verbose, import_module, run_unittest, TESTFN, reap_threads, - forget, unlink, rmtree) + forget, unlink, rmtree, start_threads) threading = import_module('threading') def task(N, done, done_tasks, errors): @@ -115,10 +115,10 @@ class ThreadedImportTests(unittest.TestCase): errors = [] done_tasks = [] done.clear() - for i in range(N): - t = threading.Thread(target=task, - args=(N, done, done_tasks, errors,)) - t.start() + with start_threads(threading.Thread(target=task, + args=(N, done, done_tasks, errors,)) + for i in range(N)): + pass self.assertTrue(done.wait(60)) self.assertFalse(errors) if verbose: diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py index 2dfd3a0..b742036 100644 --- a/Lib/test/test_threadedtempfile.py +++ b/Lib/test/test_threadedtempfile.py @@ -18,7 +18,7 @@ FILES_PER_THREAD = 50 import tempfile -from test.support import threading_setup, threading_cleanup, run_unittest, import_module +from test.support import start_threads, import_module threading = import_module('threading') import unittest import io @@ -46,33 +46,17 @@ class TempFileGreedy(threading.Thread): class ThreadedTempFileTest(unittest.TestCase): def test_main(self): - threads = [] - thread_info = threading_setup() - - for i in range(NUM_THREADS): - t = TempFileGreedy() - threads.append(t) - t.start() - - startEvent.set() - - ok = 0 - errors = [] - for t in threads: - t.join() - ok += t.ok_count - if t.error_count: - errors.append(str(t.name) + str(t.errors.getvalue())) - - threading_cleanup(*thread_info) + threads = [TempFileGreedy() for i in range(NUM_THREADS)] + with start_threads(threads, startEvent.set): + pass + ok = sum(t.ok_count for t in threads) + errors = [str(t.name) + str(t.errors.getvalue()) + for t in threads if t.error_count] msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok, '\n'.join(errors)) self.assertEqual(errors, [], msg) self.assertEqual(ok, NUM_THREADS * FILES_PER_THREAD) -def test_main(): - run_unittest(ThreadedTempFileTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py index c886a25..c7f394c 100644 --- a/Lib/test/test_threading_local.py +++ b/Lib/test/test_threading_local.py @@ -64,14 +64,9 @@ class BaseLocalTest: # Simply check that the variable is correctly set self.assertEqual(local.x, i) - threads= [] - for i in range(10): - t = threading.Thread(target=f, args=(i,)) - t.start() - threads.append(t) - - for t in threads: - t.join() + with support.start_threads(threading.Thread(target=f, args=(i,)) + for i in range(10)): + pass def test_derived_cycle_dealloc(self): # http://bugs.python.org/issue6990 |