From 4848b0b92ce2737cea08fa3b322fd0f0a671bb07 Mon Sep 17 00:00:00 2001
From: Eric Snow <ericsnowcurrently@gmail.com>
Date: Mon, 21 Oct 2024 15:49:58 -0600
Subject: gh-125716: Use A Global Mutex When Initializing Global State For The
 _interpqueues Module (gh-125803)

This includes a drive-by cleanup in _queues_init() and _queues_fini().

This change also applies to the _interpchannels module.
---
 Modules/_interpchannelsmodule.c | 64 ++++++++++++++++++++++--------------
 Modules/_interpqueuesmodule.c   | 72 +++++++++++++++++++++++------------------
 2 files changed, 79 insertions(+), 57 deletions(-)

diff --git a/Modules/_interpchannelsmodule.c b/Modules/_interpchannelsmodule.c
index c52cde6..8e6b21d 100644
--- a/Modules/_interpchannelsmodule.c
+++ b/Modules/_interpchannelsmodule.c
@@ -28,6 +28,7 @@
 This module has the following process-global state:
 
 _globals (static struct globals):
+    mutex (PyMutex)
     module_count (int)
     channels (struct _channels):
         numopen (int64_t)
@@ -1349,21 +1350,29 @@ typedef struct _channels {
 static void
 _channels_init(_channels *channels, PyThread_type_lock mutex)
 {
-    channels->mutex = mutex;
-    channels->head = NULL;
-    channels->numopen = 0;
-    channels->next_id = 0;
+    assert(mutex != NULL);
+    assert(channels->mutex == NULL);
+    *channels = (_channels){
+        .mutex = mutex,
+        .head = NULL,
+        .numopen = 0,
+        .next_id = 0,
+    };
 }
 
 static void
-_channels_fini(_channels *channels)
+_channels_fini(_channels *channels, PyThread_type_lock *p_mutex)
 {
+    PyThread_type_lock mutex = channels->mutex;
+    assert(mutex != NULL);
+
+    PyThread_acquire_lock(mutex, WAIT_LOCK);
     assert(channels->numopen == 0);
     assert(channels->head == NULL);
-    if (channels->mutex != NULL) {
-        PyThread_free_lock(channels->mutex);
-        channels->mutex = NULL;
-    }
+    *channels = (_channels){0};
+    PyThread_release_lock(mutex);
+
+    *p_mutex = mutex;
 }
 
 static int64_t
@@ -2812,6 +2821,7 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
    the data that we need to share between interpreters, so it cannot
    hold PyObject values. */
 static struct globals {
+    PyMutex mutex;
     int module_count;
     _channels channels;
 } _globals = {0};
@@ -2819,32 +2829,36 @@ static struct globals {
 static int
 _globals_init(void)
 {
-    // XXX This isn't thread-safe.
+    PyMutex_Lock(&_globals.mutex);
+    assert(_globals.module_count >= 0);
     _globals.module_count++;
-    if (_globals.module_count > 1) {
-        // Already initialized.
-        return 0;
-    }
-
-    assert(_globals.channels.mutex == NULL);
-    PyThread_type_lock mutex = PyThread_allocate_lock();
-    if (mutex == NULL) {
-        return ERR_CHANNELS_MUTEX_INIT;
+    if (_globals.module_count == 1) {
+        // Called for the first time.
+        PyThread_type_lock mutex = PyThread_allocate_lock();
+        if (mutex == NULL) {
+            _globals.module_count--;
+            PyMutex_Unlock(&_globals.mutex);
+            return ERR_CHANNELS_MUTEX_INIT;
+        }
+        _channels_init(&_globals.channels, mutex);
     }
-    _channels_init(&_globals.channels, mutex);
+    PyMutex_Unlock(&_globals.mutex);
     return 0;
 }
 
 static void
 _globals_fini(void)
 {
-    // XXX This isn't thread-safe.
+    PyMutex_Lock(&_globals.mutex);
+    assert(_globals.module_count > 0);
     _globals.module_count--;
-    if (_globals.module_count > 0) {
-        return;
+    if (_globals.module_count == 0) {
+        PyThread_type_lock mutex;
+        _channels_fini(&_globals.channels, &mutex);
+        assert(mutex != NULL);
+        PyThread_free_lock(mutex);
     }
-
-    _channels_fini(&_globals.channels);
+    PyMutex_Unlock(&_globals.mutex);
 }
 
 static _channels *
diff --git a/Modules/_interpqueuesmodule.c b/Modules/_interpqueuesmodule.c
index aa70134..297a176 100644
--- a/Modules/_interpqueuesmodule.c
+++ b/Modules/_interpqueuesmodule.c
@@ -845,28 +845,31 @@ typedef struct _queues {
 static void
 _queues_init(_queues *queues, PyThread_type_lock mutex)
 {
-    queues->mutex = mutex;
-    queues->head = NULL;
-    queues->count = 0;
-    queues->next_id = 1;
+    assert(mutex != NULL);
+    assert(queues->mutex == NULL);
+    *queues = (_queues){
+        .mutex = mutex,
+        .head = NULL,
+        .count = 0,
+        .next_id = 1,
+    };
 }
 
 static void
-_queues_fini(_queues *queues)
+_queues_fini(_queues *queues, PyThread_type_lock *p_mutex)
 {
+    PyThread_type_lock mutex = queues->mutex;
+    assert(mutex != NULL);
+
+    PyThread_acquire_lock(mutex, WAIT_LOCK);
     if (queues->count > 0) {
-        PyThread_acquire_lock(queues->mutex, WAIT_LOCK);
-        assert((queues->count == 0) != (queues->head != NULL));
-        _queueref *head = queues->head;
-        queues->head = NULL;
-        queues->count = 0;
-        PyThread_release_lock(queues->mutex);
-        _queuerefs_clear(head);
-    }
-    if (queues->mutex != NULL) {
-        PyThread_free_lock(queues->mutex);
-        queues->mutex = NULL;
+        assert(queues->head != NULL);
+        _queuerefs_clear(queues->head);
     }
+    *queues = (_queues){0};
+    PyThread_release_lock(mutex);
+
+    *p_mutex = mutex;
 }
 
 static int64_t
@@ -1398,6 +1401,7 @@ _queueobj_shared(PyThreadState *tstate, PyObject *queueobj,
    the data that we need to share between interpreters, so it cannot
    hold PyObject values. */
 static struct globals {
+    PyMutex mutex;
     int module_count;
     _queues queues;
 } _globals = {0};
@@ -1405,32 +1409,36 @@ static struct globals {
 static int
 _globals_init(void)
 {
-    // XXX This isn't thread-safe.
+    PyMutex_Lock(&_globals.mutex);
+    assert(_globals.module_count >= 0);
     _globals.module_count++;
-    if (_globals.module_count > 1) {
-        // Already initialized.
-        return 0;
-    }
-
-    assert(_globals.queues.mutex == NULL);
-    PyThread_type_lock mutex = PyThread_allocate_lock();
-    if (mutex == NULL) {
-        return ERR_QUEUES_ALLOC;
+    if (_globals.module_count == 1) {
+        // Called for the first time.
+        PyThread_type_lock mutex = PyThread_allocate_lock();
+        if (mutex == NULL) {
+            _globals.module_count--;
+            PyMutex_Unlock(&_globals.mutex);
+            return ERR_QUEUES_ALLOC;
+        }
+        _queues_init(&_globals.queues, mutex);
     }
-    _queues_init(&_globals.queues, mutex);
+    PyMutex_Unlock(&_globals.mutex);
     return 0;
 }
 
 static void
 _globals_fini(void)
 {
-    // XXX This isn't thread-safe.
+    PyMutex_Lock(&_globals.mutex);
+    assert(_globals.module_count > 0);
     _globals.module_count--;
-    if (_globals.module_count > 0) {
-        return;
+    if (_globals.module_count == 0) {
+        PyThread_type_lock mutex;
+        _queues_fini(&_globals.queues, &mutex);
+        assert(mutex != NULL);
+        PyThread_free_lock(mutex);
     }
-
-    _queues_fini(&_globals.queues);
+    PyMutex_Unlock(&_globals.mutex);
 }
 
 static _queues *
-- 
cgit v0.12