diff options
author | Antoine Pitrou <antoine@python.org> | 2023-11-04 13:59:24 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-04 13:59:24 (GMT) |
commit | 0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09 (patch) | |
tree | 8febb8282c2c1ebd73a18205ec5b9229a99ac4fe /Lib | |
parent | a28a3967ab9a189122f895d51d2551f7b3a273b0 (diff) | |
download | cpython-0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09.zip cpython-0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09.tar.gz cpython-0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09.tar.bz2 |
GH-110829: Ensure Thread.join() joins the OS thread (#110848)
Joining a thread now ensures the underlying OS thread has exited. This is required for safer fork() in multi-threaded processes.
---------
Co-authored-by: blurb-it[bot] <43283697+blurb-it[bot]@users.noreply.github.com>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/_test_multiprocessing.py | 3 | ||||
-rw-r--r-- | Lib/test/audit-tests.py | 3 | ||||
-rw-r--r-- | Lib/test/test_audit.py | 2 | ||||
-rw-r--r-- | Lib/test/test_concurrent_futures/test_process_pool.py | 6 | ||||
-rw-r--r-- | Lib/test/test_thread.py | 126 | ||||
-rw-r--r-- | Lib/test/test_threading.py | 47 | ||||
-rw-r--r-- | Lib/threading.py | 63 |
7 files changed, 229 insertions, 21 deletions
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index bf87a3e..ec003d8 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -2693,6 +2693,9 @@ class _TestPool(BaseTestCase): p.join() def test_terminate(self): + if self.TYPE == 'threads': + self.skipTest("Threads cannot be terminated") + # Simulate slow tasks which take "forever" to complete p = self.Pool(3) args = [support.LONG_TIMEOUT for i in range(10_000)] diff --git a/Lib/test/audit-tests.py b/Lib/test/audit-tests.py index 89f407d..ce4a11b 100644 --- a/Lib/test/audit-tests.py +++ b/Lib/test/audit-tests.py @@ -455,6 +455,9 @@ def test_threading(): i = _thread.start_new_thread(test_func(), ()) lock.acquire() + handle = _thread.start_joinable_thread(test_func()) + handle.join() + def test_threading_abort(): # Ensures that aborting PyThreadState_New raises the correct exception diff --git a/Lib/test/test_audit.py b/Lib/test/test_audit.py index 47e5832..cd0a4e2 100644 --- a/Lib/test/test_audit.py +++ b/Lib/test/test_audit.py @@ -209,6 +209,8 @@ class AuditTest(unittest.TestCase): expected = [ ("_thread.start_new_thread", "(<test_func>, (), None)"), ("test.test_func", "()"), + ("_thread.start_joinable_thread", "(<test_func>,)"), + ("test.test_func", "()"), ] self.assertEqual(actual, expected) diff --git a/Lib/test/test_concurrent_futures/test_process_pool.py b/Lib/test/test_concurrent_futures/test_process_pool.py index c73c2da..3e61b0c 100644 --- a/Lib/test/test_concurrent_futures/test_process_pool.py +++ b/Lib/test/test_concurrent_futures/test_process_pool.py @@ -194,11 +194,11 @@ class ProcessPoolExecutorTest(ExecutorTest): context = self.get_context() - # gh-109047: Mock the threading.start_new_thread() function to inject + # gh-109047: Mock the threading.start_joinable_thread() function to inject # RuntimeError: simulate the error raised during Python finalization. # Block the second creation: create _ExecutorManagerThread, but block # QueueFeederThread. - orig_start_new_thread = threading._start_new_thread + orig_start_new_thread = threading._start_joinable_thread nthread = 0 def mock_start_new_thread(func, *args): nonlocal nthread @@ -208,7 +208,7 @@ class ProcessPoolExecutorTest(ExecutorTest): nthread += 1 return orig_start_new_thread(func, *args) - with support.swap_attr(threading, '_start_new_thread', + with support.swap_attr(threading, '_start_joinable_thread', mock_start_new_thread): executor = self.executor_type(max_workers=2, mp_context=context) with executor: diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py index 831aaf5..931cb4b 100644 --- a/Lib/test/test_thread.py +++ b/Lib/test/test_thread.py @@ -160,6 +160,132 @@ class ThreadRunningTests(BasicThreadTest): f"Exception ignored in thread started by {task!r}") self.assertIsNotNone(cm.unraisable.exc_traceback) + def test_join_thread(self): + finished = [] + + def task(): + time.sleep(0.05) + finished.append(thread.get_ident()) + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + handle.join() + self.assertEqual(len(finished), 1) + self.assertEqual(handle.ident, finished[0]) + + def test_join_thread_already_exited(self): + def task(): + pass + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + time.sleep(0.05) + handle.join() + + def test_join_several_times(self): + def task(): + pass + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + handle.join() + with self.assertRaisesRegex(ValueError, "not joinable"): + handle.join() + + def test_joinable_not_joined(self): + handle_destroyed = thread.allocate_lock() + handle_destroyed.acquire() + + def task(): + handle_destroyed.acquire() + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + del handle + handle_destroyed.release() + + def test_join_from_self(self): + errors = [] + handles = [] + start_joinable_thread_returned = thread.allocate_lock() + start_joinable_thread_returned.acquire() + task_tried_to_join = thread.allocate_lock() + task_tried_to_join.acquire() + + def task(): + start_joinable_thread_returned.acquire() + try: + handles[0].join() + except Exception as e: + errors.append(e) + finally: + task_tried_to_join.release() + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + handles.append(handle) + start_joinable_thread_returned.release() + # Can still join after joining failed in other thread + task_tried_to_join.acquire() + handle.join() + + assert len(errors) == 1 + with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"): + raise errors[0] + + def test_detach_from_self(self): + errors = [] + handles = [] + start_joinable_thread_returned = thread.allocate_lock() + start_joinable_thread_returned.acquire() + thread_detached = thread.allocate_lock() + thread_detached.acquire() + + def task(): + start_joinable_thread_returned.acquire() + try: + handles[0].detach() + except Exception as e: + errors.append(e) + finally: + thread_detached.release() + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + handles.append(handle) + start_joinable_thread_returned.release() + thread_detached.acquire() + with self.assertRaisesRegex(ValueError, "not joinable"): + handle.join() + + assert len(errors) == 0 + + def test_detach_then_join(self): + lock = thread.allocate_lock() + lock.acquire() + + def task(): + lock.acquire() + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + # detach() returns even though the thread is blocked on lock + handle.detach() + # join() then cannot be called anymore + with self.assertRaisesRegex(ValueError, "not joinable"): + handle.join() + lock.release() + + def test_join_then_detach(self): + def task(): + pass + + with threading_helper.wait_threads_exit(): + handle = thread.start_joinable_thread(task) + handle.join() + with self.assertRaisesRegex(ValueError, "not joinable"): + handle.detach() + class Barrier: def __init__(self, num_threads): diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 00a6437..146e2db 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -376,8 +376,8 @@ class ThreadTests(BaseTestCase): # Issue 7481: Failure to start thread should cleanup the limbo map. def fail_new_thread(*args): raise threading.ThreadError() - _start_new_thread = threading._start_new_thread - threading._start_new_thread = fail_new_thread + _start_joinable_thread = threading._start_joinable_thread + threading._start_joinable_thread = fail_new_thread try: t = threading.Thread(target=lambda: None) self.assertRaises(threading.ThreadError, t.start) @@ -385,7 +385,7 @@ class ThreadTests(BaseTestCase): t in threading._limbo, "Failed to cleanup _limbo map on failure of Thread.start().") finally: - threading._start_new_thread = _start_new_thread + threading._start_joinable_thread = _start_joinable_thread def test_finalize_running_thread(self): # Issue 1402: the PyGILState_Ensure / _Release functions may be called @@ -482,6 +482,47 @@ class ThreadTests(BaseTestCase): finally: sys.setswitchinterval(old_interval) + def test_join_from_multiple_threads(self): + # Thread.join() should be thread-safe + errors = [] + + def worker(): + time.sleep(0.005) + + def joiner(thread): + try: + thread.join() + except Exception as e: + errors.append(e) + + for N in range(2, 20): + threads = [threading.Thread(target=worker)] + for i in range(N): + threads.append(threading.Thread(target=joiner, + args=(threads[0],))) + for t in threads: + t.start() + time.sleep(0.01) + for t in threads: + t.join() + if errors: + raise errors[0] + + def test_join_with_timeout(self): + lock = _thread.allocate_lock() + lock.acquire() + + def worker(): + lock.acquire() + + thread = threading.Thread(target=worker) + thread.start() + thread.join(timeout=0.01) + assert thread.is_alive() + lock.release() + thread.join() + assert not thread.is_alive() + def test_no_refcycle_through_target(self): class RunSelfFunction(object): def __init__(self, should_raise): diff --git a/Lib/threading.py b/Lib/threading.py index 41c3a9f..85aff58 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -5,6 +5,7 @@ import sys as _sys import _thread import functools import warnings +import _weakref from time import monotonic as _time from _weakrefset import WeakSet @@ -33,7 +34,7 @@ __all__ = ['get_ident', 'active_count', 'Condition', 'current_thread', 'setprofile_all_threads','settrace_all_threads'] # Rename some stuff so "from threading import *" is safe -_start_new_thread = _thread.start_new_thread +_start_joinable_thread = _thread.start_joinable_thread _daemon_threads_allowed = _thread.daemon_threads_allowed _allocate_lock = _thread.allocate_lock _set_sentinel = _thread._set_sentinel @@ -589,7 +590,7 @@ class Event: return f"<{cls.__module__}.{cls.__qualname__} at {id(self):#x}: {status}>" def _at_fork_reinit(self): - # Private method called by Thread._reset_internal_locks() + # Private method called by Thread._after_fork() self._cond._at_fork_reinit() def is_set(self): @@ -924,6 +925,8 @@ class Thread: if _HAVE_THREAD_NATIVE_ID: self._native_id = None self._tstate_lock = None + self._join_lock = None + self._handle = None self._started = Event() self._is_stopped = False self._initialized = True @@ -933,22 +936,32 @@ class Thread: # For debugging and _after_fork() _dangling.add(self) - def _reset_internal_locks(self, is_alive): - # private! Called by _after_fork() to reset our internal locks as - # they may be in an invalid state leading to a deadlock or crash. + def _after_fork(self, new_ident=None): + # Private! Called by threading._after_fork(). self._started._at_fork_reinit() - if is_alive: + if new_ident is not None: + # This thread is alive. + self._ident = new_ident + if self._handle is not None: + self._handle.after_fork_alive() + assert self._handle.ident == new_ident # bpo-42350: If the fork happens when the thread is already stopped # (ex: after threading._shutdown() has been called), _tstate_lock # is None. Do nothing in this case. if self._tstate_lock is not None: self._tstate_lock._at_fork_reinit() self._tstate_lock.acquire() + if self._join_lock is not None: + self._join_lock._at_fork_reinit() else: - # The thread isn't alive after fork: it doesn't have a tstate + # This thread isn't alive after fork: it doesn't have a tstate # anymore. self._is_stopped = True self._tstate_lock = None + self._join_lock = None + if self._handle is not None: + self._handle.after_fork_dead() + self._handle = None def __repr__(self): assert self._initialized, "Thread.__init__() was not called" @@ -980,15 +993,18 @@ class Thread: if self._started.is_set(): raise RuntimeError("threads can only be started once") + self._join_lock = _allocate_lock() + with _active_limbo_lock: _limbo[self] = self try: - _start_new_thread(self._bootstrap, ()) + # Start joinable thread + self._handle = _start_joinable_thread(self._bootstrap) except Exception: with _active_limbo_lock: del _limbo[self] raise - self._started.wait() + self._started.wait() # Will set ident and native_id def run(self): """Method representing the thread's activity. @@ -1144,6 +1160,22 @@ class Thread: # historically .join(timeout=x) for x<0 has acted as if timeout=0 self._wait_for_tstate_lock(timeout=max(timeout, 0)) + if self._is_stopped: + self._join_os_thread() + + def _join_os_thread(self): + join_lock = self._join_lock + if join_lock is None: + return + with join_lock: + # Calling join() multiple times would raise an exception + # in one of the callers. + if self._handle is not None: + self._handle.join() + self._handle = None + # No need to keep this around + self._join_lock = None + def _wait_for_tstate_lock(self, block=True, timeout=-1): # Issue #18808: wait for the thread state to be gone. # At the end of the thread's life, after all knowledge of the thread @@ -1223,7 +1255,10 @@ class Thread: if self._is_stopped or not self._started.is_set(): return False self._wait_for_tstate_lock(False) - return not self._is_stopped + if not self._is_stopped: + return True + self._join_os_thread() + return False @property def daemon(self): @@ -1679,15 +1714,13 @@ def _after_fork(): # Any lock/condition variable may be currently locked or in an # invalid state, so we reinitialize them. if thread is current: - # There is only one active thread. We reset the ident to - # its new value since it can have changed. - thread._reset_internal_locks(True) + # This is the one and only active thread. ident = get_ident() - thread._ident = ident + thread._after_fork(new_ident=ident) new_active[ident] = thread else: # All the others are already stopped. - thread._reset_internal_locks(False) + thread._after_fork() thread._stop() _limbo.clear() |