summaryrefslogtreecommitdiffstats
path: root/Modules/_xxinterpchannelsmodule.c
diff options
context:
space:
mode:
authorEric Snow <ericsnowcurrently@gmail.com>2023-10-17 22:32:00 (GMT)
committerGitHub <noreply@github.com>2023-10-17 22:32:00 (GMT)
commita53d7cb6729dc3f254b70afcf19eaf71a2eed540 (patch)
tree679f68637995e4bfefc98f0ec40edac6150693cd /Modules/_xxinterpchannelsmodule.c
parente37620edfd77b78b913b5eab55cd91327c3e7fd3 (diff)
downloadcpython-a53d7cb6729dc3f254b70afcf19eaf71a2eed540.zip
cpython-a53d7cb6729dc3f254b70afcf19eaf71a2eed540.tar.gz
cpython-a53d7cb6729dc3f254b70afcf19eaf71a2eed540.tar.bz2
gh-84570: Send-Wait Fixes for _xxinterpchannels (gh-111006)
There were a few things I did in gh-110565 that need to be fixed. I also forgot to add tests in that PR. (Note that this PR exposes a refleak introduced by gh-110246. I'll take care of that separately.)
Diffstat (limited to 'Modules/_xxinterpchannelsmodule.c')
-rw-r--r--Modules/_xxinterpchannelsmodule.c385
1 files changed, 331 insertions, 54 deletions
diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c
index 34efe9d..be53cbf 100644
--- a/Modules/_xxinterpchannelsmodule.c
+++ b/Modules/_xxinterpchannelsmodule.c
@@ -10,6 +10,13 @@
#include "pycore_pybuffer.h" // _PyBuffer_ReleaseInInterpreterAndRawFree()
#include "pycore_interp.h" // _PyInterpreterState_LookUpID()
+#ifdef MS_WINDOWS
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h> // SwitchToThread()
+#elif defined(HAVE_SCHED_H)
+#include <sched.h> // sched_yield()
+#endif
+
/*
This module has the following process-global state:
@@ -234,15 +241,25 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
return cls;
}
-static void
+static int
wait_for_lock(PyThread_type_lock mutex)
{
- Py_BEGIN_ALLOW_THREADS
- // XXX Handle eintr, etc.
- PyThread_acquire_lock(mutex, WAIT_LOCK);
- Py_END_ALLOW_THREADS
-
+ PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT;
+ PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
+ if (res == PY_LOCK_INTR) {
+ /* KeyboardInterrupt, etc. */
+ assert(PyErr_Occurred());
+ return -1;
+ }
+ else if (res == PY_LOCK_FAILURE) {
+ assert(!PyErr_Occurred());
+ assert(timeout > 0);
+ PyErr_SetString(PyExc_TimeoutError, "timed out");
+ return -1;
+ }
+ assert(res == PY_LOCK_ACQUIRED);
PyThread_release_lock(mutex);
+ return 0;
}
@@ -489,6 +506,7 @@ _get_current_xibufferview_type(void)
#define ERR_CHANNEL_MUTEX_INIT -7
#define ERR_CHANNELS_MUTEX_INIT -8
#define ERR_NO_NEXT_CHANNEL_ID -9
+#define ERR_CHANNEL_CLOSED_WAITING -10
static int
exceptions_init(PyObject *mod)
@@ -540,6 +558,10 @@ handle_channel_error(int err, PyObject *mod, int64_t cid)
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is closed", cid);
}
+ else if (err == ERR_CHANNEL_CLOSED_WAITING) {
+ PyErr_Format(state->ChannelClosedError,
+ "channel %" PRId64 " has closed", cid);
+ }
else if (err == ERR_CHANNEL_INTERP_CLOSED) {
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is already closed", cid);
@@ -574,36 +596,145 @@ handle_channel_error(int err, PyObject *mod, int64_t cid)
/* the channel queue */
+typedef uintptr_t _channelitem_id_t;
+
+typedef struct wait_info {
+ PyThread_type_lock mutex;
+ enum {
+ WAITING_NO_STATUS = 0,
+ WAITING_ACQUIRED = 1,
+ WAITING_RELEASING = 2,
+ WAITING_RELEASED = 3,
+ } status;
+ int received;
+ _channelitem_id_t itemid;
+} _waiting_t;
+
+static int
+_waiting_init(_waiting_t *waiting)
+{
+ PyThread_type_lock mutex = PyThread_allocate_lock();
+ if (mutex == NULL) {
+ PyErr_NoMemory();
+ return -1;
+ }
+
+ *waiting = (_waiting_t){
+ .mutex = mutex,
+ .status = WAITING_NO_STATUS,
+ };
+ return 0;
+}
+
+static void
+_waiting_clear(_waiting_t *waiting)
+{
+ assert(waiting->status != WAITING_ACQUIRED
+ && waiting->status != WAITING_RELEASING);
+ if (waiting->mutex != NULL) {
+ PyThread_free_lock(waiting->mutex);
+ waiting->mutex = NULL;
+ }
+}
+
+static _channelitem_id_t
+_waiting_get_itemid(_waiting_t *waiting)
+{
+ return waiting->itemid;
+}
+
+static void
+_waiting_acquire(_waiting_t *waiting)
+{
+ assert(waiting->status == WAITING_NO_STATUS);
+ PyThread_acquire_lock(waiting->mutex, NOWAIT_LOCK);
+ waiting->status = WAITING_ACQUIRED;
+}
+
+static void
+_waiting_release(_waiting_t *waiting, int received)
+{
+ assert(waiting->mutex != NULL);
+ assert(waiting->status == WAITING_ACQUIRED);
+ assert(!waiting->received);
+
+ waiting->status = WAITING_RELEASING;
+ PyThread_release_lock(waiting->mutex);
+ if (waiting->received != received) {
+ assert(received == 1);
+ waiting->received = received;
+ }
+ waiting->status = WAITING_RELEASED;
+}
+
+static void
+_waiting_finish_releasing(_waiting_t *waiting)
+{
+ while (waiting->status == WAITING_RELEASING) {
+#ifdef MS_WINDOWS
+ SwitchToThread();
+#elif defined(HAVE_SCHED_H)
+ sched_yield();
+#endif
+ }
+}
+
struct _channelitem;
typedef struct _channelitem {
_PyCrossInterpreterData *data;
- PyThread_type_lock recv_mutex;
+ _waiting_t *waiting;
struct _channelitem *next;
} _channelitem;
-static _channelitem *
-_channelitem_new(void)
+static inline _channelitem_id_t
+_channelitem_ID(_channelitem *item)
{
- _channelitem *item = GLOBAL_MALLOC(_channelitem);
- if (item == NULL) {
- PyErr_NoMemory();
- return NULL;
+ return (_channelitem_id_t)item;
+}
+
+static void
+_channelitem_init(_channelitem *item,
+ _PyCrossInterpreterData *data, _waiting_t *waiting)
+{
+ *item = (_channelitem){
+ .data = data,
+ .waiting = waiting,
+ };
+ if (waiting != NULL) {
+ waiting->itemid = _channelitem_ID(item);
}
- item->data = NULL;
- item->next = NULL;
- return item;
}
static void
_channelitem_clear(_channelitem *item)
{
+ item->next = NULL;
+
if (item->data != NULL) {
// It was allocated in _channel_send().
(void)_release_xid_data(item->data, XID_IGNORE_EXC & XID_FREE);
item->data = NULL;
}
- item->next = NULL;
+
+ if (item->waiting != NULL) {
+ if (item->waiting->status == WAITING_ACQUIRED) {
+ _waiting_release(item->waiting, 0);
+ }
+ item->waiting = NULL;
+ }
+}
+
+static _channelitem *
+_channelitem_new(_PyCrossInterpreterData *data, _waiting_t *waiting)
+{
+ _channelitem *item = GLOBAL_MALLOC(_channelitem);
+ if (item == NULL) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ _channelitem_init(item, data, waiting);
+ return item;
}
static void
@@ -623,14 +754,17 @@ _channelitem_free_all(_channelitem *item)
}
}
-static _PyCrossInterpreterData *
-_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
+static void
+_channelitem_popped(_channelitem *item,
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
{
- _PyCrossInterpreterData *data = item->data;
+ assert(item->waiting == NULL || item->waiting->status == WAITING_ACQUIRED);
+ *p_data = item->data;
+ *p_waiting = item->waiting;
+ // We clear them here, so they won't be released in _channelitem_clear().
item->data = NULL;
- *recv_mutex = item->recv_mutex;
+ item->waiting = NULL;
_channelitem_free(item);
- return data;
}
typedef struct _channelqueue {
@@ -670,15 +804,13 @@ _channelqueue_free(_channelqueue *queue)
}
static int
-_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
- PyThread_type_lock recv_mutex)
+_channelqueue_put(_channelqueue *queue,
+ _PyCrossInterpreterData *data, _waiting_t *waiting)
{
- _channelitem *item = _channelitem_new();
+ _channelitem *item = _channelitem_new(data, waiting);
if (item == NULL) {
return -1;
}
- item->data = data;
- item->recv_mutex = recv_mutex;
queue->count += 1;
if (queue->first == NULL) {
@@ -688,15 +820,21 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
queue->last->next = item;
}
queue->last = item;
+
+ if (waiting != NULL) {
+ _waiting_acquire(waiting);
+ }
+
return 0;
}
-static _PyCrossInterpreterData *
-_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
+static int
+_channelqueue_get(_channelqueue *queue,
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
{
_channelitem *item = queue->first;
if (item == NULL) {
- return NULL;
+ return ERR_CHANNEL_EMPTY;
}
queue->first = item->next;
if (queue->last == item) {
@@ -704,7 +842,73 @@ _channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
}
queue->count -= 1;
- return _channelitem_popped(item, recv_mutex);
+ _channelitem_popped(item, p_data, p_waiting);
+ return 0;
+}
+
+static int
+_channelqueue_find(_channelqueue *queue, _channelitem_id_t itemid,
+ _channelitem **p_item, _channelitem **p_prev)
+{
+ _channelitem *prev = NULL;
+ _channelitem *item = NULL;
+ if (queue->first != NULL) {
+ if (_channelitem_ID(queue->first) == itemid) {
+ item = queue->first;
+ }
+ else {
+ prev = queue->first;
+ while (prev->next != NULL) {
+ if (_channelitem_ID(prev->next) == itemid) {
+ item = prev->next;
+ break;
+ }
+ prev = prev->next;
+ }
+ if (item == NULL) {
+ prev = NULL;
+ }
+ }
+ }
+ if (p_item != NULL) {
+ *p_item = item;
+ }
+ if (p_prev != NULL) {
+ *p_prev = prev;
+ }
+ return (item != NULL);
+}
+
+static void
+_channelqueue_remove(_channelqueue *queue, _channelitem_id_t itemid,
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
+{
+ _channelitem *prev = NULL;
+ _channelitem *item = NULL;
+ int found = _channelqueue_find(queue, itemid, &item, &prev);
+ if (!found) {
+ return;
+ }
+
+ assert(item->waiting != NULL);
+ assert(!item->waiting->received);
+ if (prev == NULL) {
+ assert(queue->first == item);
+ queue->first = item->next;
+ }
+ else {
+ assert(queue->first != item);
+ assert(prev->next == item);
+ prev->next = item->next;
+ }
+ item->next = NULL;
+
+ if (queue->last == item) {
+ queue->last = prev;
+ }
+ queue->count -= 1;
+
+ _channelitem_popped(item, p_data, p_waiting);
}
static void
@@ -1021,7 +1225,7 @@ _channel_free(_PyChannelState *chan)
static int
_channel_add(_PyChannelState *chan, int64_t interp,
- _PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
+ _PyCrossInterpreterData *data, _waiting_t *waiting)
{
int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -1035,9 +1239,10 @@ _channel_add(_PyChannelState *chan, int64_t interp,
goto done;
}
- if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
+ if (_channelqueue_put(chan->queue, data, waiting) != 0) {
goto done;
}
+ // Any errors past this point must cause a _waiting_release() call.
res = 0;
done:
@@ -1047,7 +1252,7 @@ done:
static int
_channel_next(_PyChannelState *chan, int64_t interp,
- _PyCrossInterpreterData **res)
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
{
int err = 0;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -1061,16 +1266,12 @@ _channel_next(_PyChannelState *chan, int64_t interp,
goto done;
}
- PyThread_type_lock recv_mutex = NULL;
- _PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex);
- if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
+ int empty = _channelqueue_get(chan->queue, p_data, p_waiting);
+ assert(empty == 0 || empty == ERR_CHANNEL_EMPTY);
+ assert(!PyErr_Occurred());
+ if (empty && chan->closing != NULL) {
chan->open = 0;
}
- *res = data;
-
- if (recv_mutex != NULL) {
- PyThread_release_lock(recv_mutex);
- }
done:
PyThread_release_lock(chan->mutex);
@@ -1080,6 +1281,26 @@ done:
return err;
}
+static void
+_channel_remove(_PyChannelState *chan, _channelitem_id_t itemid)
+{
+ _PyCrossInterpreterData *data = NULL;
+ _waiting_t *waiting = NULL;
+
+ PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+ _channelqueue_remove(chan->queue, itemid, &data, &waiting);
+ PyThread_release_lock(chan->mutex);
+
+ (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
+ if (waiting != NULL) {
+ _waiting_release(waiting, 0);
+ }
+
+ if (chan->queue->count == 0) {
+ _channel_finish_closing(chan);
+ }
+}
+
static int
_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
{
@@ -1592,7 +1813,7 @@ _channel_destroy(_channels *channels, int64_t id)
static int
_channel_send(_channels *channels, int64_t id, PyObject *obj,
- PyThread_type_lock recv_mutex)
+ _waiting_t *waiting)
{
PyInterpreterState *interp = _get_current_interp();
if (interp == NULL) {
@@ -1627,8 +1848,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj,
}
// Add the data to the channel.
- int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
- recv_mutex);
+ int res = _channel_add(chan, PyInterpreterState_GetID(interp),
+ data, waiting);
PyThread_release_lock(mutex);
if (res != 0) {
// We may chain an exception here:
@@ -1640,31 +1861,74 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj,
return 0;
}
+static void
+_channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
+{
+ // Look up the channel.
+ PyThread_type_lock mutex = NULL;
+ _PyChannelState *chan = NULL;
+ int err = _channels_lookup(channels, cid, &mutex, &chan);
+ if (err != 0) {
+ // The channel was already closed, etc.
+ assert(waiting->status == WAITING_RELEASED);
+ return; // Ignore the error.
+ }
+ assert(chan != NULL);
+ // Past this point we are responsible for releasing the mutex.
+
+ _channelitem_id_t itemid = _waiting_get_itemid(waiting);
+ _channel_remove(chan, itemid);
+
+ PyThread_release_lock(mutex);
+}
+
static int
_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
{
- PyThread_type_lock mutex = PyThread_allocate_lock();
- if (mutex == NULL) {
- PyErr_NoMemory();
+ // We use a stack variable here, so we must ensure that &waiting
+ // is not held by any channel item at the point this function exits.
+ _waiting_t waiting;
+ if (_waiting_init(&waiting) < 0) {
+ assert(PyErr_Occurred());
return -1;
}
- PyThread_acquire_lock(mutex, NOWAIT_LOCK);
/* Queue up the object. */
- int res = _channel_send(channels, cid, obj, mutex);
+ int res = _channel_send(channels, cid, obj, &waiting);
if (res < 0) {
- PyThread_release_lock(mutex);
+ assert(waiting.status == WAITING_NO_STATUS);
goto finally;
}
/* Wait until the object is received. */
- wait_for_lock(mutex);
+ if (wait_for_lock(waiting.mutex) < 0) {
+ assert(PyErr_Occurred());
+ _waiting_finish_releasing(&waiting);
+ /* The send() call is failing now, so make sure the item
+ won't be received. */
+ _channel_clear_sent(channels, cid, &waiting);
+ assert(waiting.status == WAITING_RELEASED);
+ if (!waiting.received) {
+ res = -1;
+ goto finally;
+ }
+ // XXX Emit a warning if not a TimeoutError?
+ PyErr_Clear();
+ }
+ else {
+ _waiting_finish_releasing(&waiting);
+ assert(waiting.status == WAITING_RELEASED);
+ if (!waiting.received) {
+ res = ERR_CHANNEL_CLOSED_WAITING;
+ goto finally;
+ }
+ }
/* success! */
res = 0;
finally:
- // XXX Delete the lock.
+ _waiting_clear(&waiting);
return res;
}
@@ -1695,7 +1959,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
// Pop off the next item from the channel.
_PyCrossInterpreterData *data = NULL;
- err = _channel_next(chan, PyInterpreterState_GetID(interp), &data);
+ _waiting_t *waiting = NULL;
+ err = _channel_next(chan, PyInterpreterState_GetID(interp), &data,
+ &waiting);
PyThread_release_lock(mutex);
if (err != 0) {
return err;
@@ -1711,6 +1977,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
assert(PyErr_Occurred());
// It was allocated in _channel_send(), so we free it.
(void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
+ if (waiting != NULL) {
+ _waiting_release(waiting, 0);
+ }
return -1;
}
// It was allocated in _channel_send(), so we free it.
@@ -1719,9 +1988,17 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
// The source interpreter has been destroyed already.
assert(PyErr_Occurred());
Py_DECREF(obj);
+ if (waiting != NULL) {
+ _waiting_release(waiting, 0);
+ }
return -1;
}
+ // Notify the sender.
+ if (waiting != NULL) {
+ _waiting_release(waiting, 1);
+ }
+
*res = obj;
return 0;
}