diff options
author | Eric Snow <ericsnowcurrently@gmail.com> | 2023-10-17 22:32:00 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-17 22:32:00 (GMT) |
commit | a53d7cb6729dc3f254b70afcf19eaf71a2eed540 (patch) | |
tree | 679f68637995e4bfefc98f0ec40edac6150693cd /Modules | |
parent | e37620edfd77b78b913b5eab55cd91327c3e7fd3 (diff) | |
download | cpython-a53d7cb6729dc3f254b70afcf19eaf71a2eed540.zip cpython-a53d7cb6729dc3f254b70afcf19eaf71a2eed540.tar.gz cpython-a53d7cb6729dc3f254b70afcf19eaf71a2eed540.tar.bz2 |
gh-84570: Send-Wait Fixes for _xxinterpchannels (gh-111006)
There were a few things I did in gh-110565 that need to be fixed. I also forgot to add tests in that PR.
(Note that this PR exposes a refleak introduced by gh-110246. I'll take care of that separately.)
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/_threadmodule.c | 52 |
-rw-r--r-- | Modules/_xxinterpchannelsmodule.c | 385 |
2 files changed, 333 insertions, 104 deletions
diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c index 86bd560..7620511 100644 --- a/Modules/_threadmodule.c +++ b/Modules/_threadmodule.c @@ -3,7 +3,6 @@ /* Interface to Sjoerd's portable C thread library */ #include "Python.h" -#include "pycore_ceval.h" // _PyEval_MakePendingCalls() #include "pycore_dict.h" // _PyDict_Pop() #include "pycore_interp.h" // _PyInterpreterState.threads.count #include "pycore_moduleobject.h" // _PyModule_GetState() @@ -76,57 +75,10 @@ lock_dealloc(lockobject *self) Py_DECREF(tp); } -/* Helper to acquire an interruptible lock with a timeout. If the lock acquire - * is interrupted, signal handlers are run, and if they raise an exception, - * PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE - * are returned, depending on whether the lock can be acquired within the - * timeout. - */ -static PyLockStatus +static inline PyLockStatus acquire_timed(PyThread_type_lock lock, _PyTime_t timeout) { - PyThreadState *tstate = _PyThreadState_GET(); - _PyTime_t endtime = 0; - if (timeout > 0) { - endtime = _PyDeadline_Init(timeout); - } - - PyLockStatus r; - do { - _PyTime_t microseconds; - microseconds = _PyTime_AsMicroseconds(timeout, _PyTime_ROUND_CEILING); - - /* first a simple non-blocking try without releasing the GIL */ - r = PyThread_acquire_lock_timed(lock, 0, 0); - if (r == PY_LOCK_FAILURE && microseconds != 0) { - Py_BEGIN_ALLOW_THREADS - r = PyThread_acquire_lock_timed(lock, microseconds, 1); - Py_END_ALLOW_THREADS - } - - if (r == PY_LOCK_INTR) { - /* Run signal handlers if we were interrupted. Propagate - * exceptions from signal handlers, such as KeyboardInterrupt, by - * passing up PY_LOCK_INTR. */ - if (_PyEval_MakePendingCalls(tstate) < 0) { - return PY_LOCK_INTR; - } - - /* If we're using a timeout, recompute the timeout after processing - * signals, since those can take time. 
*/ - if (timeout > 0) { - timeout = _PyDeadline_Get(endtime); - - /* Check for negative values, since those mean block forever. - */ - if (timeout < 0) { - r = PY_LOCK_FAILURE; - } - } - } - } while (r == PY_LOCK_INTR); /* Retry if we were interrupted. */ - - return r; + return PyThread_acquire_lock_timed_with_retries(lock, timeout); } static int diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index 34efe9d..be53cbf 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -10,6 +10,13 @@ #include "pycore_pybuffer.h" // _PyBuffer_ReleaseInInterpreterAndRawFree() #include "pycore_interp.h" // _PyInterpreterState_LookUpID() +#ifdef MS_WINDOWS +#define WIN32_LEAN_AND_MEAN +#include <windows.h> // SwitchToThread() +#elif defined(HAVE_SCHED_H) +#include <sched.h> // sched_yield() +#endif + /* This module has the following process-global state: @@ -234,15 +241,25 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared, return cls; } -static void +static int wait_for_lock(PyThread_type_lock mutex) { - Py_BEGIN_ALLOW_THREADS - // XXX Handle eintr, etc. - PyThread_acquire_lock(mutex, WAIT_LOCK); - Py_END_ALLOW_THREADS - + PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT; + PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout); + if (res == PY_LOCK_INTR) { + /* KeyboardInterrupt, etc. 
*/ + assert(PyErr_Occurred()); + return -1; + } + else if (res == PY_LOCK_FAILURE) { + assert(!PyErr_Occurred()); + assert(timeout > 0); + PyErr_SetString(PyExc_TimeoutError, "timed out"); + return -1; + } + assert(res == PY_LOCK_ACQUIRED); PyThread_release_lock(mutex); + return 0; } @@ -489,6 +506,7 @@ _get_current_xibufferview_type(void) #define ERR_CHANNEL_MUTEX_INIT -7 #define ERR_CHANNELS_MUTEX_INIT -8 #define ERR_NO_NEXT_CHANNEL_ID -9 +#define ERR_CHANNEL_CLOSED_WAITING -10 static int exceptions_init(PyObject *mod) @@ -540,6 +558,10 @@ handle_channel_error(int err, PyObject *mod, int64_t cid) PyErr_Format(state->ChannelClosedError, "channel %" PRId64 " is closed", cid); } + else if (err == ERR_CHANNEL_CLOSED_WAITING) { + PyErr_Format(state->ChannelClosedError, + "channel %" PRId64 " has closed", cid); + } else if (err == ERR_CHANNEL_INTERP_CLOSED) { PyErr_Format(state->ChannelClosedError, "channel %" PRId64 " is already closed", cid); @@ -574,36 +596,145 @@ handle_channel_error(int err, PyObject *mod, int64_t cid) /* the channel queue */ +typedef uintptr_t _channelitem_id_t; + +typedef struct wait_info { + PyThread_type_lock mutex; + enum { + WAITING_NO_STATUS = 0, + WAITING_ACQUIRED = 1, + WAITING_RELEASING = 2, + WAITING_RELEASED = 3, + } status; + int received; + _channelitem_id_t itemid; +} _waiting_t; + +static int +_waiting_init(_waiting_t *waiting) +{ + PyThread_type_lock mutex = PyThread_allocate_lock(); + if (mutex == NULL) { + PyErr_NoMemory(); + return -1; + } + + *waiting = (_waiting_t){ + .mutex = mutex, + .status = WAITING_NO_STATUS, + }; + return 0; +} + +static void +_waiting_clear(_waiting_t *waiting) +{ + assert(waiting->status != WAITING_ACQUIRED + && waiting->status != WAITING_RELEASING); + if (waiting->mutex != NULL) { + PyThread_free_lock(waiting->mutex); + waiting->mutex = NULL; + } +} + +static _channelitem_id_t +_waiting_get_itemid(_waiting_t *waiting) +{ + return waiting->itemid; +} + +static void +_waiting_acquire(_waiting_t 
*waiting) +{ + assert(waiting->status == WAITING_NO_STATUS); + PyThread_acquire_lock(waiting->mutex, NOWAIT_LOCK); + waiting->status = WAITING_ACQUIRED; +} + +static void +_waiting_release(_waiting_t *waiting, int received) +{ + assert(waiting->mutex != NULL); + assert(waiting->status == WAITING_ACQUIRED); + assert(!waiting->received); + + waiting->status = WAITING_RELEASING; + PyThread_release_lock(waiting->mutex); + if (waiting->received != received) { + assert(received == 1); + waiting->received = received; + } + waiting->status = WAITING_RELEASED; +} + +static void +_waiting_finish_releasing(_waiting_t *waiting) +{ + while (waiting->status == WAITING_RELEASING) { +#ifdef MS_WINDOWS + SwitchToThread(); +#elif defined(HAVE_SCHED_H) + sched_yield(); +#endif + } +} + struct _channelitem; typedef struct _channelitem { _PyCrossInterpreterData *data; - PyThread_type_lock recv_mutex; + _waiting_t *waiting; struct _channelitem *next; } _channelitem; -static _channelitem * -_channelitem_new(void) +static inline _channelitem_id_t +_channelitem_ID(_channelitem *item) { - _channelitem *item = GLOBAL_MALLOC(_channelitem); - if (item == NULL) { - PyErr_NoMemory(); - return NULL; + return (_channelitem_id_t)item; +} + +static void +_channelitem_init(_channelitem *item, + _PyCrossInterpreterData *data, _waiting_t *waiting) +{ + *item = (_channelitem){ + .data = data, + .waiting = waiting, + }; + if (waiting != NULL) { + waiting->itemid = _channelitem_ID(item); } - item->data = NULL; - item->next = NULL; - return item; } static void _channelitem_clear(_channelitem *item) { + item->next = NULL; + if (item->data != NULL) { // It was allocated in _channel_send(). 
(void)_release_xid_data(item->data, XID_IGNORE_EXC & XID_FREE); item->data = NULL; } - item->next = NULL; + + if (item->waiting != NULL) { + if (item->waiting->status == WAITING_ACQUIRED) { + _waiting_release(item->waiting, 0); + } + item->waiting = NULL; + } +} + +static _channelitem * +_channelitem_new(_PyCrossInterpreterData *data, _waiting_t *waiting) +{ + _channelitem *item = GLOBAL_MALLOC(_channelitem); + if (item == NULL) { + PyErr_NoMemory(); + return NULL; + } + _channelitem_init(item, data, waiting); + return item; } static void @@ -623,14 +754,17 @@ _channelitem_free_all(_channelitem *item) } } -static _PyCrossInterpreterData * -_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex) +static void +_channelitem_popped(_channelitem *item, + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) { - _PyCrossInterpreterData *data = item->data; + assert(item->waiting == NULL || item->waiting->status == WAITING_ACQUIRED); + *p_data = item->data; + *p_waiting = item->waiting; + // We clear them here, so they won't be released in _channelitem_clear(). 
item->data = NULL; - *recv_mutex = item->recv_mutex; + item->waiting = NULL; _channelitem_free(item); - return data; } typedef struct _channelqueue { @@ -670,15 +804,13 @@ _channelqueue_free(_channelqueue *queue) } static int -_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data, - PyThread_type_lock recv_mutex) +_channelqueue_put(_channelqueue *queue, + _PyCrossInterpreterData *data, _waiting_t *waiting) { - _channelitem *item = _channelitem_new(); + _channelitem *item = _channelitem_new(data, waiting); if (item == NULL) { return -1; } - item->data = data; - item->recv_mutex = recv_mutex; queue->count += 1; if (queue->first == NULL) { @@ -688,15 +820,21 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data, queue->last->next = item; } queue->last = item; + + if (waiting != NULL) { + _waiting_acquire(waiting); + } + return 0; } -static _PyCrossInterpreterData * -_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex) +static int +_channelqueue_get(_channelqueue *queue, + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) { _channelitem *item = queue->first; if (item == NULL) { - return NULL; + return ERR_CHANNEL_EMPTY; } queue->first = item->next; if (queue->last == item) { @@ -704,7 +842,73 @@ _channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex) } queue->count -= 1; - return _channelitem_popped(item, recv_mutex); + _channelitem_popped(item, p_data, p_waiting); + return 0; +} + +static int +_channelqueue_find(_channelqueue *queue, _channelitem_id_t itemid, + _channelitem **p_item, _channelitem **p_prev) +{ + _channelitem *prev = NULL; + _channelitem *item = NULL; + if (queue->first != NULL) { + if (_channelitem_ID(queue->first) == itemid) { + item = queue->first; + } + else { + prev = queue->first; + while (prev->next != NULL) { + if (_channelitem_ID(prev->next) == itemid) { + item = prev->next; + break; + } + prev = prev->next; + } + if (item == NULL) { + prev = NULL; + } + } + } + if 
(p_item != NULL) { + *p_item = item; + } + if (p_prev != NULL) { + *p_prev = prev; + } + return (item != NULL); +} + +static void +_channelqueue_remove(_channelqueue *queue, _channelitem_id_t itemid, + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) +{ + _channelitem *prev = NULL; + _channelitem *item = NULL; + int found = _channelqueue_find(queue, itemid, &item, &prev); + if (!found) { + return; + } + + assert(item->waiting != NULL); + assert(!item->waiting->received); + if (prev == NULL) { + assert(queue->first == item); + queue->first = item->next; + } + else { + assert(queue->first != item); + assert(prev->next == item); + prev->next = item->next; + } + item->next = NULL; + + if (queue->last == item) { + queue->last = prev; + } + queue->count -= 1; + + _channelitem_popped(item, p_data, p_waiting); } static void @@ -1021,7 +1225,7 @@ _channel_free(_PyChannelState *chan) static int _channel_add(_PyChannelState *chan, int64_t interp, - _PyCrossInterpreterData *data, PyThread_type_lock recv_mutex) + _PyCrossInterpreterData *data, _waiting_t *waiting) { int res = -1; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -1035,9 +1239,10 @@ _channel_add(_PyChannelState *chan, int64_t interp, goto done; } - if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) { + if (_channelqueue_put(chan->queue, data, waiting) != 0) { goto done; } + // Any errors past this point must cause a _waiting_release() call. 
res = 0; done: @@ -1047,7 +1252,7 @@ done: static int _channel_next(_PyChannelState *chan, int64_t interp, - _PyCrossInterpreterData **res) + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) { int err = 0; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -1061,16 +1266,12 @@ _channel_next(_PyChannelState *chan, int64_t interp, goto done; } - PyThread_type_lock recv_mutex = NULL; - _PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex); - if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) { + int empty = _channelqueue_get(chan->queue, p_data, p_waiting); + assert(empty == 0 || empty == ERR_CHANNEL_EMPTY); + assert(!PyErr_Occurred()); + if (empty && chan->closing != NULL) { chan->open = 0; } - *res = data; - - if (recv_mutex != NULL) { - PyThread_release_lock(recv_mutex); - } done: PyThread_release_lock(chan->mutex); @@ -1080,6 +1281,26 @@ done: return err; } +static void +_channel_remove(_PyChannelState *chan, _channelitem_id_t itemid) +{ + _PyCrossInterpreterData *data = NULL; + _waiting_t *waiting = NULL; + + PyThread_acquire_lock(chan->mutex, WAIT_LOCK); + _channelqueue_remove(chan->queue, itemid, &data, &waiting); + PyThread_release_lock(chan->mutex); + + (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE); + if (waiting != NULL) { + _waiting_release(waiting, 0); + } + + if (chan->queue->count == 0) { + _channel_finish_closing(chan); + } +} + static int _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end) { @@ -1592,7 +1813,7 @@ _channel_destroy(_channels *channels, int64_t id) static int _channel_send(_channels *channels, int64_t id, PyObject *obj, - PyThread_type_lock recv_mutex) + _waiting_t *waiting) { PyInterpreterState *interp = _get_current_interp(); if (interp == NULL) { @@ -1627,8 +1848,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj, } // Add the data to the channel. 
- int res = _channel_add(chan, PyInterpreterState_GetID(interp), data, - recv_mutex); + int res = _channel_add(chan, PyInterpreterState_GetID(interp), + data, waiting); PyThread_release_lock(mutex); if (res != 0) { // We may chain an exception here: @@ -1640,31 +1861,74 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj, return 0; } +static void +_channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting) +{ + // Look up the channel. + PyThread_type_lock mutex = NULL; + _PyChannelState *chan = NULL; + int err = _channels_lookup(channels, cid, &mutex, &chan); + if (err != 0) { + // The channel was already closed, etc. + assert(waiting->status == WAITING_RELEASED); + return; // Ignore the error. + } + assert(chan != NULL); + // Past this point we are responsible for releasing the mutex. + + _channelitem_id_t itemid = _waiting_get_itemid(waiting); + _channel_remove(chan, itemid); + + PyThread_release_lock(mutex); +} + static int _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj) { - PyThread_type_lock mutex = PyThread_allocate_lock(); - if (mutex == NULL) { - PyErr_NoMemory(); + // We use a stack variable here, so we must ensure that &waiting + // is not held by any channel item at the point this function exits. + _waiting_t waiting; + if (_waiting_init(&waiting) < 0) { + assert(PyErr_Occurred()); return -1; } - PyThread_acquire_lock(mutex, NOWAIT_LOCK); /* Queue up the object. */ - int res = _channel_send(channels, cid, obj, mutex); + int res = _channel_send(channels, cid, obj, &waiting); if (res < 0) { - PyThread_release_lock(mutex); + assert(waiting.status == WAITING_NO_STATUS); goto finally; } /* Wait until the object is received. */ - wait_for_lock(mutex); + if (wait_for_lock(waiting.mutex) < 0) { + assert(PyErr_Occurred()); + _waiting_finish_releasing(&waiting); + /* The send() call is failing now, so make sure the item + won't be received. 
*/ + _channel_clear_sent(channels, cid, &waiting); + assert(waiting.status == WAITING_RELEASED); + if (!waiting.received) { + res = -1; + goto finally; + } + // XXX Emit a warning if not a TimeoutError? + PyErr_Clear(); + } + else { + _waiting_finish_releasing(&waiting); + assert(waiting.status == WAITING_RELEASED); + if (!waiting.received) { + res = ERR_CHANNEL_CLOSED_WAITING; + goto finally; + } + } /* success! */ res = 0; finally: - // XXX Delete the lock. + _waiting_clear(&waiting); return res; } @@ -1695,7 +1959,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res) // Pop off the next item from the channel. _PyCrossInterpreterData *data = NULL; - err = _channel_next(chan, PyInterpreterState_GetID(interp), &data); + _waiting_t *waiting = NULL; + err = _channel_next(chan, PyInterpreterState_GetID(interp), &data, + &waiting); PyThread_release_lock(mutex); if (err != 0) { return err; @@ -1711,6 +1977,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res) assert(PyErr_Occurred()); // It was allocated in _channel_send(), so we free it. (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE); + if (waiting != NULL) { + _waiting_release(waiting, 0); + } return -1; } // It was allocated in _channel_send(), so we free it. @@ -1719,9 +1988,17 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res) // The source interpreter has been destroyed already. assert(PyErr_Occurred()); Py_DECREF(obj); + if (waiting != NULL) { + _waiting_release(waiting, 0); + } return -1; } + // Notify the sender. + if (waiting != NULL) { + _waiting_release(waiting, 1); + } + *res = obj; return 0; } |