summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorAntoine Pitrou <antoine@python.org>2023-11-04 13:59:24 (GMT)
committerGitHub <noreply@github.com>2023-11-04 13:59:24 (GMT)
commit0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09 (patch)
tree8febb8282c2c1ebd73a18205ec5b9229a99ac4fe /Lib
parenta28a3967ab9a189122f895d51d2551f7b3a273b0 (diff)
downloadcpython-0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09.zip
cpython-0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09.tar.gz
cpython-0e9c364f4ac18a2237bdbac702b96bcf8ef9cb09.tar.bz2
GH-110829: Ensure Thread.join() joins the OS thread (#110848)
Joining a thread now ensures the underlying OS thread has exited. This is required for safer fork() in multi-threaded processes. --------- Co-authored-by: blurb-it[bot] <43283697+blurb-it[bot]@users.noreply.github.com>
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/_test_multiprocessing.py3
-rw-r--r--Lib/test/audit-tests.py3
-rw-r--r--Lib/test/test_audit.py2
-rw-r--r--Lib/test/test_concurrent_futures/test_process_pool.py6
-rw-r--r--Lib/test/test_thread.py126
-rw-r--r--Lib/test/test_threading.py47
-rw-r--r--Lib/threading.py63
7 files changed, 229 insertions, 21 deletions
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py
index bf87a3e..ec003d8 100644
--- a/Lib/test/_test_multiprocessing.py
+++ b/Lib/test/_test_multiprocessing.py
@@ -2693,6 +2693,9 @@ class _TestPool(BaseTestCase):
p.join()
def test_terminate(self):
+ if self.TYPE == 'threads':
+ self.skipTest("Threads cannot be terminated")
+
# Simulate slow tasks which take "forever" to complete
p = self.Pool(3)
args = [support.LONG_TIMEOUT for i in range(10_000)]
diff --git a/Lib/test/audit-tests.py b/Lib/test/audit-tests.py
index 89f407d..ce4a11b 100644
--- a/Lib/test/audit-tests.py
+++ b/Lib/test/audit-tests.py
@@ -455,6 +455,9 @@ def test_threading():
i = _thread.start_new_thread(test_func(), ())
lock.acquire()
+ handle = _thread.start_joinable_thread(test_func())
+ handle.join()
+
def test_threading_abort():
# Ensures that aborting PyThreadState_New raises the correct exception
diff --git a/Lib/test/test_audit.py b/Lib/test/test_audit.py
index 47e5832..cd0a4e2 100644
--- a/Lib/test/test_audit.py
+++ b/Lib/test/test_audit.py
@@ -209,6 +209,8 @@ class AuditTest(unittest.TestCase):
expected = [
("_thread.start_new_thread", "(<test_func>, (), None)"),
("test.test_func", "()"),
+ ("_thread.start_joinable_thread", "(<test_func>,)"),
+ ("test.test_func", "()"),
]
self.assertEqual(actual, expected)
diff --git a/Lib/test/test_concurrent_futures/test_process_pool.py b/Lib/test/test_concurrent_futures/test_process_pool.py
index c73c2da..3e61b0c 100644
--- a/Lib/test/test_concurrent_futures/test_process_pool.py
+++ b/Lib/test/test_concurrent_futures/test_process_pool.py
@@ -194,11 +194,11 @@ class ProcessPoolExecutorTest(ExecutorTest):
context = self.get_context()
- # gh-109047: Mock the threading.start_new_thread() function to inject
+ # gh-109047: Mock the threading.start_joinable_thread() function to inject
# RuntimeError: simulate the error raised during Python finalization.
# Block the second creation: create _ExecutorManagerThread, but block
# QueueFeederThread.
- orig_start_new_thread = threading._start_new_thread
+ orig_start_new_thread = threading._start_joinable_thread
nthread = 0
def mock_start_new_thread(func, *args):
nonlocal nthread
@@ -208,7 +208,7 @@ class ProcessPoolExecutorTest(ExecutorTest):
nthread += 1
return orig_start_new_thread(func, *args)
- with support.swap_attr(threading, '_start_new_thread',
+ with support.swap_attr(threading, '_start_joinable_thread',
mock_start_new_thread):
executor = self.executor_type(max_workers=2, mp_context=context)
with executor:
diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py
index 831aaf5..931cb4b 100644
--- a/Lib/test/test_thread.py
+++ b/Lib/test/test_thread.py
@@ -160,6 +160,132 @@ class ThreadRunningTests(BasicThreadTest):
f"Exception ignored in thread started by {task!r}")
self.assertIsNotNone(cm.unraisable.exc_traceback)
+ def test_join_thread(self):
+ finished = []
+
+ def task():
+ time.sleep(0.05)
+ finished.append(thread.get_ident())
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ handle.join()
+ self.assertEqual(len(finished), 1)
+ self.assertEqual(handle.ident, finished[0])
+
+ def test_join_thread_already_exited(self):
+ def task():
+ pass
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ time.sleep(0.05)
+ handle.join()
+
+ def test_join_several_times(self):
+ def task():
+ pass
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ handle.join()
+ with self.assertRaisesRegex(ValueError, "not joinable"):
+ handle.join()
+
+ def test_joinable_not_joined(self):
+ handle_destroyed = thread.allocate_lock()
+ handle_destroyed.acquire()
+
+ def task():
+ handle_destroyed.acquire()
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ del handle
+ handle_destroyed.release()
+
+ def test_join_from_self(self):
+ errors = []
+ handles = []
+ start_joinable_thread_returned = thread.allocate_lock()
+ start_joinable_thread_returned.acquire()
+ task_tried_to_join = thread.allocate_lock()
+ task_tried_to_join.acquire()
+
+ def task():
+ start_joinable_thread_returned.acquire()
+ try:
+ handles[0].join()
+ except Exception as e:
+ errors.append(e)
+ finally:
+ task_tried_to_join.release()
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ handles.append(handle)
+ start_joinable_thread_returned.release()
+ # Can still join after joining failed in other thread
+ task_tried_to_join.acquire()
+ handle.join()
+
+ assert len(errors) == 1
+ with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"):
+ raise errors[0]
+
+ def test_detach_from_self(self):
+ errors = []
+ handles = []
+ start_joinable_thread_returned = thread.allocate_lock()
+ start_joinable_thread_returned.acquire()
+ thread_detached = thread.allocate_lock()
+ thread_detached.acquire()
+
+ def task():
+ start_joinable_thread_returned.acquire()
+ try:
+ handles[0].detach()
+ except Exception as e:
+ errors.append(e)
+ finally:
+ thread_detached.release()
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ handles.append(handle)
+ start_joinable_thread_returned.release()
+ thread_detached.acquire()
+ with self.assertRaisesRegex(ValueError, "not joinable"):
+ handle.join()
+
+ assert len(errors) == 0
+
+ def test_detach_then_join(self):
+ lock = thread.allocate_lock()
+ lock.acquire()
+
+ def task():
+ lock.acquire()
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ # detach() returns even though the thread is blocked on lock
+ handle.detach()
+ # join() then cannot be called anymore
+ with self.assertRaisesRegex(ValueError, "not joinable"):
+ handle.join()
+ lock.release()
+
+ def test_join_then_detach(self):
+ def task():
+ pass
+
+ with threading_helper.wait_threads_exit():
+ handle = thread.start_joinable_thread(task)
+ handle.join()
+ with self.assertRaisesRegex(ValueError, "not joinable"):
+ handle.detach()
+
class Barrier:
def __init__(self, num_threads):
diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py
index 00a6437..146e2db 100644
--- a/Lib/test/test_threading.py
+++ b/Lib/test/test_threading.py
@@ -376,8 +376,8 @@ class ThreadTests(BaseTestCase):
# Issue 7481: Failure to start thread should cleanup the limbo map.
def fail_new_thread(*args):
raise threading.ThreadError()
- _start_new_thread = threading._start_new_thread
- threading._start_new_thread = fail_new_thread
+ _start_joinable_thread = threading._start_joinable_thread
+ threading._start_joinable_thread = fail_new_thread
try:
t = threading.Thread(target=lambda: None)
self.assertRaises(threading.ThreadError, t.start)
@@ -385,7 +385,7 @@ class ThreadTests(BaseTestCase):
t in threading._limbo,
"Failed to cleanup _limbo map on failure of Thread.start().")
finally:
- threading._start_new_thread = _start_new_thread
+ threading._start_joinable_thread = _start_joinable_thread
def test_finalize_running_thread(self):
# Issue 1402: the PyGILState_Ensure / _Release functions may be called
@@ -482,6 +482,47 @@ class ThreadTests(BaseTestCase):
finally:
sys.setswitchinterval(old_interval)
+ def test_join_from_multiple_threads(self):
+ # Thread.join() should be thread-safe
+ errors = []
+
+ def worker():
+ time.sleep(0.005)
+
+ def joiner(thread):
+ try:
+ thread.join()
+ except Exception as e:
+ errors.append(e)
+
+ for N in range(2, 20):
+ threads = [threading.Thread(target=worker)]
+ for i in range(N):
+ threads.append(threading.Thread(target=joiner,
+ args=(threads[0],)))
+ for t in threads:
+ t.start()
+ time.sleep(0.01)
+ for t in threads:
+ t.join()
+ if errors:
+ raise errors[0]
+
+ def test_join_with_timeout(self):
+ lock = _thread.allocate_lock()
+ lock.acquire()
+
+ def worker():
+ lock.acquire()
+
+ thread = threading.Thread(target=worker)
+ thread.start()
+ thread.join(timeout=0.01)
+ assert thread.is_alive()
+ lock.release()
+ thread.join()
+ assert not thread.is_alive()
+
def test_no_refcycle_through_target(self):
class RunSelfFunction(object):
def __init__(self, should_raise):
diff --git a/Lib/threading.py b/Lib/threading.py
index 41c3a9f..85aff58 100644
--- a/Lib/threading.py
+++ b/Lib/threading.py
@@ -5,6 +5,7 @@ import sys as _sys
import _thread
import functools
import warnings
+import _weakref
from time import monotonic as _time
from _weakrefset import WeakSet
@@ -33,7 +34,7 @@ __all__ = ['get_ident', 'active_count', 'Condition', 'current_thread',
'setprofile_all_threads','settrace_all_threads']
# Rename some stuff so "from threading import *" is safe
-_start_new_thread = _thread.start_new_thread
+_start_joinable_thread = _thread.start_joinable_thread
_daemon_threads_allowed = _thread.daemon_threads_allowed
_allocate_lock = _thread.allocate_lock
_set_sentinel = _thread._set_sentinel
@@ -589,7 +590,7 @@ class Event:
return f"<{cls.__module__}.{cls.__qualname__} at {id(self):#x}: {status}>"
def _at_fork_reinit(self):
- # Private method called by Thread._reset_internal_locks()
+ # Private method called by Thread._after_fork()
self._cond._at_fork_reinit()
def is_set(self):
@@ -924,6 +925,8 @@ class Thread:
if _HAVE_THREAD_NATIVE_ID:
self._native_id = None
self._tstate_lock = None
+ self._join_lock = None
+ self._handle = None
self._started = Event()
self._is_stopped = False
self._initialized = True
@@ -933,22 +936,32 @@ class Thread:
# For debugging and _after_fork()
_dangling.add(self)
- def _reset_internal_locks(self, is_alive):
- # private! Called by _after_fork() to reset our internal locks as
- # they may be in an invalid state leading to a deadlock or crash.
+ def _after_fork(self, new_ident=None):
+ # Private! Called by threading._after_fork().
self._started._at_fork_reinit()
- if is_alive:
+ if new_ident is not None:
+ # This thread is alive.
+ self._ident = new_ident
+ if self._handle is not None:
+ self._handle.after_fork_alive()
+ assert self._handle.ident == new_ident
# bpo-42350: If the fork happens when the thread is already stopped
# (ex: after threading._shutdown() has been called), _tstate_lock
# is None. Do nothing in this case.
if self._tstate_lock is not None:
self._tstate_lock._at_fork_reinit()
self._tstate_lock.acquire()
+ if self._join_lock is not None:
+ self._join_lock._at_fork_reinit()
else:
- # The thread isn't alive after fork: it doesn't have a tstate
+ # This thread isn't alive after fork: it doesn't have a tstate
# anymore.
self._is_stopped = True
self._tstate_lock = None
+ self._join_lock = None
+ if self._handle is not None:
+ self._handle.after_fork_dead()
+ self._handle = None
def __repr__(self):
assert self._initialized, "Thread.__init__() was not called"
@@ -980,15 +993,18 @@ class Thread:
if self._started.is_set():
raise RuntimeError("threads can only be started once")
+ self._join_lock = _allocate_lock()
+
with _active_limbo_lock:
_limbo[self] = self
try:
- _start_new_thread(self._bootstrap, ())
+ # Start joinable thread
+ self._handle = _start_joinable_thread(self._bootstrap)
except Exception:
with _active_limbo_lock:
del _limbo[self]
raise
- self._started.wait()
+ self._started.wait() # Will set ident and native_id
def run(self):
"""Method representing the thread's activity.
@@ -1144,6 +1160,22 @@ class Thread:
# historically .join(timeout=x) for x<0 has acted as if timeout=0
self._wait_for_tstate_lock(timeout=max(timeout, 0))
+ if self._is_stopped:
+ self._join_os_thread()
+
+ def _join_os_thread(self):
+ join_lock = self._join_lock
+ if join_lock is None:
+ return
+ with join_lock:
+ # Calling join() multiple times would raise an exception
+ # in one of the callers.
+ if self._handle is not None:
+ self._handle.join()
+ self._handle = None
+ # No need to keep this around
+ self._join_lock = None
+
def _wait_for_tstate_lock(self, block=True, timeout=-1):
# Issue #18808: wait for the thread state to be gone.
# At the end of the thread's life, after all knowledge of the thread
@@ -1223,7 +1255,10 @@ class Thread:
if self._is_stopped or not self._started.is_set():
return False
self._wait_for_tstate_lock(False)
- return not self._is_stopped
+ if not self._is_stopped:
+ return True
+ self._join_os_thread()
+ return False
@property
def daemon(self):
@@ -1679,15 +1714,13 @@ def _after_fork():
# Any lock/condition variable may be currently locked or in an
# invalid state, so we reinitialize them.
if thread is current:
- # There is only one active thread. We reset the ident to
- # its new value since it can have changed.
- thread._reset_internal_locks(True)
+ # This is the one and only active thread.
ident = get_ident()
- thread._ident = ident
+ thread._after_fork(new_ident=ident)
new_active[ident] = thread
else:
# All the others are already stopped.
- thread._reset_internal_locks(False)
+ thread._after_fork()
thread._stop()
_limbo.clear()