From a53d7cb6729dc3f254b70afcf19eaf71a2eed540 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Tue, 17 Oct 2023 16:32:00 -0600 Subject: gh-84570: Send-Wait Fixes for _xxinterpchannels (gh-111006) There were a few things I did in gh-110565 that need to be fixed. I also forgot to add tests in that PR. (Note that this PR exposes a refleak introduced by gh-110246. I'll take care of that separately.) --- Include/internal/pycore_pythread.h | 15 ++ Lib/test/test__xxinterpchannels.py | 217 ++++++++++++++++----- Modules/_threadmodule.c | 52 +---- Modules/_xxinterpchannelsmodule.c | 385 +++++++++++++++++++++++++++++++------ Python/thread.c | 50 +++++ 5 files changed, 571 insertions(+), 148 deletions(-) diff --git a/Include/internal/pycore_pythread.h b/Include/internal/pycore_pythread.h index 8ce5a79..ffd7398 100644 --- a/Include/internal/pycore_pythread.h +++ b/Include/internal/pycore_pythread.h @@ -86,6 +86,21 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock); #endif /* HAVE_FORK */ +// unset: -1 seconds, in nanoseconds +#define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000)) + +/* Helper to acquire an interruptible lock with a timeout. If the lock acquire + * is interrupted, signal handlers are run, and if they raise an exception, + * PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE + * are returned, depending on whether the lock can be acquired within the + * timeout. + */ +// Exported for the _xxinterpchannels module. +PyAPI_FUNC(PyLockStatus) PyThread_acquire_lock_timed_with_retries( + PyThread_type_lock, + PY_TIMEOUT_T microseconds); + + #ifdef __cplusplus } #endif diff --git a/Lib/test/test__xxinterpchannels.py b/Lib/test/test__xxinterpchannels.py index ff01a33..90a1224 100644 --- a/Lib/test/test__xxinterpchannels.py +++ b/Lib/test/test__xxinterpchannels.py @@ -564,7 +564,62 @@ class ChannelTests(TestBase): with self.assertRaises(channels.ChannelClosedError): channels.list_interpreters(cid, send=False) - #################### + def test_allowed_types(self): + cid = channels.create() + objects = [ + None, + 'spam', + b'spam', + 42, + ] + for obj in objects: + with self.subTest(obj): + channels.send(cid, obj, blocking=False) + got = channels.recv(cid) + + self.assertEqual(got, obj) + self.assertIs(type(got), type(obj)) + # XXX Check the following? + #self.assertIsNot(got, obj) + # XXX What about between interpreters? + + def test_run_string_arg_unresolved(self): + cid = channels.create() + interp = interpreters.create() + + out = _run_output(interp, dedent(""" + import _xxinterpchannels as _channels + print(cid.end) + _channels.send(cid, b'spam', blocking=False) + """), + dict(cid=cid.send)) + obj = channels.recv(cid) + + self.assertEqual(obj, b'spam') + self.assertEqual(out.strip(), 'send') + + # XXX For now there is no high-level channel into which the + # sent channel ID can be converted... + # Note: this test caused crashes on some buildbots (bpo-33615). + @unittest.skip('disabled until high-level channels exist') + def test_run_string_arg_resolved(self): + cid = channels.create() + cid = channels._channel_id(cid, _resolve=True) + interp = interpreters.create() + + out = _run_output(interp, dedent(""" + import _xxinterpchannels as _channels + print(chan.id.end) + _channels.send(chan.id, b'spam', blocking=False) + """), + dict(chan=cid.send)) + obj = channels.recv(cid) + + self.assertEqual(obj, b'spam') + self.assertEqual(out.strip(), 'send') + + #------------------- + # send/recv def test_send_recv_main(self): cid = channels.create() @@ -705,6 +760,9 @@ class ChannelTests(TestBase): channels.recv(cid2) del cid2 + #------------------- + # send_buffer + def test_send_buffer(self): buf = bytearray(b'spamspamspam') cid = channels.create() @@ -720,60 +778,131 @@ class ChannelTests(TestBase): obj[4:8] = b'ham.' self.assertEqual(obj, buf) - def test_allowed_types(self): + #------------------- + # send with waiting + + def build_send_waiter(self, obj, *, buffer=False): + # We want a long enough sleep that send() actually has to wait. + + if buffer: + send = channels.send_buffer + else: + send = channels.send + cid = channels.create() - objects = [ - None, - 'spam', - b'spam', - 42, - ] - for obj in objects: - with self.subTest(obj): - channels.send(cid, obj, blocking=False) - got = channels.recv(cid) + try: + started = time.monotonic() + send(cid, obj, blocking=False) + stopped = time.monotonic() + channels.recv(cid) + finally: + channels.destroy(cid) + delay = stopped - started # seconds + delay *= 3 - self.assertEqual(got, obj) - self.assertIs(type(got), type(obj)) - # XXX Check the following? - #self.assertIsNot(got, obj) - # XXX What about between interpreters? + def wait(): + time.sleep(delay) + return wait - def test_run_string_arg_unresolved(self): + def test_send_blocking_waiting(self): + received = None + obj = b'spam' + wait = self.build_send_waiter(obj) cid = channels.create() - interp = interpreters.create() + def f(): + nonlocal received + wait() + received = recv_wait(cid) + t = threading.Thread(target=f) + t.start() + channels.send(cid, obj, blocking=True) + t.join() - out = _run_output(interp, dedent(""" - import _xxinterpchannels as _channels - print(cid.end) - _channels.send(cid, b'spam', blocking=False) - """), - dict(cid=cid.send)) - obj = channels.recv(cid) + self.assertEqual(received, obj) - self.assertEqual(obj, b'spam') - self.assertEqual(out.strip(), 'send') + def test_send_buffer_blocking_waiting(self): + received = None + obj = bytearray(b'spam') + wait = self.build_send_waiter(obj, buffer=True) + cid = channels.create() + def f(): + nonlocal received + wait() + received = recv_wait(cid) + t = threading.Thread(target=f) + t.start() + channels.send_buffer(cid, obj, blocking=True) + t.join() - # XXX For now there is no high-level channel into which the - # sent channel ID can be converted... - # Note: this test caused crashes on some buildbots (bpo-33615). - @unittest.skip('disabled until high-level channels exist') - def test_run_string_arg_resolved(self): + self.assertEqual(received, obj) + + def test_send_blocking_no_wait(self): + received = None + obj = b'spam' cid = channels.create() - cid = channels._channel_id(cid, _resolve=True) - interp = interpreters.create() + def f(): + nonlocal received + received = recv_wait(cid) + t = threading.Thread(target=f) + t.start() + channels.send(cid, obj, blocking=True) + t.join() - out = _run_output(interp, dedent(""" - import _xxinterpchannels as _channels - print(chan.id.end) - _channels.send(chan.id, b'spam', blocking=False) - """), - dict(chan=cid.send)) - obj = channels.recv(cid) + self.assertEqual(received, obj) - self.assertEqual(obj, b'spam') - self.assertEqual(out.strip(), 'send') + def test_send_buffer_blocking_no_wait(self): + received = None + obj = bytearray(b'spam') + cid = channels.create() + def f(): + nonlocal received + received = recv_wait(cid) + t = threading.Thread(target=f) + t.start() + channels.send_buffer(cid, obj, blocking=True) + t.join() + + self.assertEqual(received, obj) + + def test_send_closed_while_waiting(self): + obj = b'spam' + wait = self.build_send_waiter(obj) + cid = channels.create() + def f(): + wait() + channels.close(cid, force=True) + t = threading.Thread(target=f) + t.start() + with self.assertRaises(channels.ChannelClosedError): + channels.send(cid, obj, blocking=True) + t.join() + + def test_send_buffer_closed_while_waiting(self): + try: + self._has_run_once + except AttributeError: + # At the moment, this test leaks a few references. + # It looks like the leak originates with the addition + # of _channels.send_buffer() (gh-110246), whereas the + # tests were added afterward. We want this test even + # if the refleak isn't fixed yet, so we skip here. + raise unittest.SkipTest('temporarily skipped due to refleaks') + else: + self._has_run_once = True + + obj = bytearray(b'spam') + wait = self.build_send_waiter(obj, buffer=True) + cid = channels.create() + def f(): + wait() + channels.close(cid, force=True) + t = threading.Thread(target=f) + t.start() + with self.assertRaises(channels.ChannelClosedError): + channels.send_buffer(cid, obj, blocking=True) + t.join() + #------------------- # close def test_close_single_user(self): diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c index 86bd560..7620511 100644 --- a/Modules/_threadmodule.c +++ b/Modules/_threadmodule.c @@ -3,7 +3,6 @@ /* Interface to Sjoerd's portable C thread library */ #include "Python.h" -#include "pycore_ceval.h" // _PyEval_MakePendingCalls() #include "pycore_dict.h" // _PyDict_Pop() #include "pycore_interp.h" // _PyInterpreterState.threads.count #include "pycore_moduleobject.h" // _PyModule_GetState() @@ -76,57 +75,10 @@ lock_dealloc(lockobject *self) Py_DECREF(tp); } -/* Helper to acquire an interruptible lock with a timeout. If the lock acquire - * is interrupted, signal handlers are run, and if they raise an exception, - * PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE - * are returned, depending on whether the lock can be acquired within the - * timeout. - */ -static PyLockStatus +static inline PyLockStatus acquire_timed(PyThread_type_lock lock, _PyTime_t timeout) { - PyThreadState *tstate = _PyThreadState_GET(); - _PyTime_t endtime = 0; - if (timeout > 0) { - endtime = _PyDeadline_Init(timeout); - } - - PyLockStatus r; - do { - _PyTime_t microseconds; - microseconds = _PyTime_AsMicroseconds(timeout, _PyTime_ROUND_CEILING); - - /* first a simple non-blocking try without releasing the GIL */ - r = PyThread_acquire_lock_timed(lock, 0, 0); - if (r == PY_LOCK_FAILURE && microseconds != 0) { - Py_BEGIN_ALLOW_THREADS - r = PyThread_acquire_lock_timed(lock, microseconds, 1); - Py_END_ALLOW_THREADS - } - - if (r == PY_LOCK_INTR) { - /* Run signal handlers if we were interrupted. Propagate - * exceptions from signal handlers, such as KeyboardInterrupt, by - * passing up PY_LOCK_INTR. */ - if (_PyEval_MakePendingCalls(tstate) < 0) { - return PY_LOCK_INTR; - } - - /* If we're using a timeout, recompute the timeout after processing - * signals, since those can take time. */ - if (timeout > 0) { - timeout = _PyDeadline_Get(endtime); - - /* Check for negative values, since those mean block forever. - */ - if (timeout < 0) { - r = PY_LOCK_FAILURE; - } - } - } - } while (r == PY_LOCK_INTR); /* Retry if we were interrupted. */ - - return r; + return PyThread_acquire_lock_timed_with_retries(lock, timeout); } static int diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index 34efe9d..be53cbf 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -10,6 +10,13 @@ #include "pycore_pybuffer.h" // _PyBuffer_ReleaseInInterpreterAndRawFree() #include "pycore_interp.h" // _PyInterpreterState_LookUpID() +#ifdef MS_WINDOWS +#define WIN32_LEAN_AND_MEAN +#include // SwitchToThread() +#elif defined(HAVE_SCHED_H) +#include // sched_yield() +#endif + /* This module has the following process-global state: @@ -234,15 +241,25 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared, return cls; } -static void +static int wait_for_lock(PyThread_type_lock mutex) { - Py_BEGIN_ALLOW_THREADS - // XXX Handle eintr, etc. - PyThread_acquire_lock(mutex, WAIT_LOCK); - Py_END_ALLOW_THREADS - + PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT; + PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout); + if (res == PY_LOCK_INTR) { + /* KeyboardInterrupt, etc. */ + assert(PyErr_Occurred()); + return -1; + } + else if (res == PY_LOCK_FAILURE) { + assert(!PyErr_Occurred()); + assert(timeout > 0); + PyErr_SetString(PyExc_TimeoutError, "timed out"); + return -1; + } + assert(res == PY_LOCK_ACQUIRED); PyThread_release_lock(mutex); + return 0; } @@ -489,6 +506,7 @@ _get_current_xibufferview_type(void) #define ERR_CHANNEL_MUTEX_INIT -7 #define ERR_CHANNELS_MUTEX_INIT -8 #define ERR_NO_NEXT_CHANNEL_ID -9 +#define ERR_CHANNEL_CLOSED_WAITING -10 static int exceptions_init(PyObject *mod) @@ -540,6 +558,10 @@ handle_channel_error(int err, PyObject *mod, int64_t cid) PyErr_Format(state->ChannelClosedError, "channel %" PRId64 " is closed", cid); } + else if (err == ERR_CHANNEL_CLOSED_WAITING) { + PyErr_Format(state->ChannelClosedError, + "channel %" PRId64 " has closed", cid); + } else if (err == ERR_CHANNEL_INTERP_CLOSED) { PyErr_Format(state->ChannelClosedError, "channel %" PRId64 " is already closed", cid); @@ -574,36 +596,145 @@ handle_channel_error(int err, PyObject *mod, int64_t cid) /* the channel queue */ +typedef uintptr_t _channelitem_id_t; + +typedef struct wait_info { + PyThread_type_lock mutex; + enum { + WAITING_NO_STATUS = 0, + WAITING_ACQUIRED = 1, + WAITING_RELEASING = 2, + WAITING_RELEASED = 3, + } status; + int received; + _channelitem_id_t itemid; +} _waiting_t; + +static int +_waiting_init(_waiting_t *waiting) +{ + PyThread_type_lock mutex = PyThread_allocate_lock(); + if (mutex == NULL) { + PyErr_NoMemory(); + return -1; + } + + *waiting = (_waiting_t){ + .mutex = mutex, + .status = WAITING_NO_STATUS, + }; + return 0; +} + +static void +_waiting_clear(_waiting_t *waiting) +{ + assert(waiting->status != WAITING_ACQUIRED + && waiting->status != WAITING_RELEASING); + if (waiting->mutex != NULL) { + PyThread_free_lock(waiting->mutex); + waiting->mutex = NULL; + } +} + +static _channelitem_id_t +_waiting_get_itemid(_waiting_t *waiting) +{ + return waiting->itemid; +} + +static void +_waiting_acquire(_waiting_t *waiting) +{ + assert(waiting->status == WAITING_NO_STATUS); + PyThread_acquire_lock(waiting->mutex, NOWAIT_LOCK); + waiting->status = WAITING_ACQUIRED; +} + +static void +_waiting_release(_waiting_t *waiting, int received) +{ + assert(waiting->mutex != NULL); + assert(waiting->status == WAITING_ACQUIRED); + assert(!waiting->received); + + waiting->status = WAITING_RELEASING; + PyThread_release_lock(waiting->mutex); + if (waiting->received != received) { + assert(received == 1); + waiting->received = received; + } + waiting->status = WAITING_RELEASED; +} + +static void +_waiting_finish_releasing(_waiting_t *waiting) +{ + while (waiting->status == WAITING_RELEASING) { +#ifdef MS_WINDOWS + SwitchToThread(); +#elif defined(HAVE_SCHED_H) + sched_yield(); +#endif + } +} + struct _channelitem; typedef struct _channelitem { _PyCrossInterpreterData *data; - PyThread_type_lock recv_mutex; + _waiting_t *waiting; struct _channelitem *next; } _channelitem; -static _channelitem * -_channelitem_new(void) +static inline _channelitem_id_t +_channelitem_ID(_channelitem *item) { - _channelitem *item = GLOBAL_MALLOC(_channelitem); - if (item == NULL) { - PyErr_NoMemory(); - return NULL; + return (_channelitem_id_t)item; +} + +static void +_channelitem_init(_channelitem *item, + _PyCrossInterpreterData *data, _waiting_t *waiting) +{ + *item = (_channelitem){ + .data = data, + .waiting = waiting, + }; + if (waiting != NULL) { + waiting->itemid = _channelitem_ID(item); } - item->data = NULL; - item->next = NULL; - return item; } static void _channelitem_clear(_channelitem *item) { + item->next = NULL; + if (item->data != NULL) { // It was allocated in _channel_send(). (void)_release_xid_data(item->data, XID_IGNORE_EXC & XID_FREE); item->data = NULL; } - item->next = NULL; + + if (item->waiting != NULL) { + if (item->waiting->status == WAITING_ACQUIRED) { + _waiting_release(item->waiting, 0); + } + item->waiting = NULL; + } +} + +static _channelitem * +_channelitem_new(_PyCrossInterpreterData *data, _waiting_t *waiting) +{ + _channelitem *item = GLOBAL_MALLOC(_channelitem); + if (item == NULL) { + PyErr_NoMemory(); + return NULL; + } + _channelitem_init(item, data, waiting); + return item; } static void @@ -623,14 +754,17 @@ _channelitem_free_all(_channelitem *item) } } -static _PyCrossInterpreterData * -_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex) +static void +_channelitem_popped(_channelitem *item, + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) { - _PyCrossInterpreterData *data = item->data; + assert(item->waiting == NULL || item->waiting->status == WAITING_ACQUIRED); + *p_data = item->data; + *p_waiting = item->waiting; + // We clear them here, so they won't be released in _channelitem_clear(). item->data = NULL; - *recv_mutex = item->recv_mutex; + item->waiting = NULL; _channelitem_free(item); - return data; } typedef struct _channelqueue { @@ -670,15 +804,13 @@ _channelqueue_free(_channelqueue *queue) } static int -_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data, - PyThread_type_lock recv_mutex) +_channelqueue_put(_channelqueue *queue, + _PyCrossInterpreterData *data, _waiting_t *waiting) { - _channelitem *item = _channelitem_new(); + _channelitem *item = _channelitem_new(data, waiting); if (item == NULL) { return -1; } - item->data = data; - item->recv_mutex = recv_mutex; queue->count += 1; if (queue->first == NULL) { @@ -688,15 +820,21 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data, queue->last->next = item; } queue->last = item; + + if (waiting != NULL) { + _waiting_acquire(waiting); + } + return 0; } -static _PyCrossInterpreterData * -_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex) +static int +_channelqueue_get(_channelqueue *queue, + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) { _channelitem *item = queue->first; if (item == NULL) { - return NULL; + return ERR_CHANNEL_EMPTY; } queue->first = item->next; if (queue->last == item) { @@ -704,7 +842,73 @@ _channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex) } queue->count -= 1; - return _channelitem_popped(item, recv_mutex); + _channelitem_popped(item, p_data, p_waiting); + return 0; +} + +static int +_channelqueue_find(_channelqueue *queue, _channelitem_id_t itemid, + _channelitem **p_item, _channelitem **p_prev) +{ + _channelitem *prev = NULL; + _channelitem *item = NULL; + if (queue->first != NULL) { + if (_channelitem_ID(queue->first) == itemid) { + item = queue->first; + } + else { + prev = queue->first; + while (prev->next != NULL) { + if (_channelitem_ID(prev->next) == itemid) { + item = prev->next; + break; + } + prev = prev->next; + } + if (item == NULL) { + prev = NULL; + } + } + } + if (p_item != NULL) { + *p_item = item; + } + if (p_prev != NULL) { + *p_prev = prev; + } + return (item != NULL); +} + +static void +_channelqueue_remove(_channelqueue *queue, _channelitem_id_t itemid, + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) +{ + _channelitem *prev = NULL; + _channelitem *item = NULL; + int found = _channelqueue_find(queue, itemid, &item, &prev); + if (!found) { + return; + } + + assert(item->waiting != NULL); + assert(!item->waiting->received); + if (prev == NULL) { + assert(queue->first == item); + queue->first = item->next; + } + else { + assert(queue->first != item); + assert(prev->next == item); + prev->next = item->next; + } + item->next = NULL; + + if (queue->last == item) { + queue->last = prev; + } + queue->count -= 1; + + _channelitem_popped(item, p_data, p_waiting); } static void @@ -1021,7 +1225,7 @@ _channel_free(_PyChannelState *chan) static int _channel_add(_PyChannelState *chan, int64_t interp, - _PyCrossInterpreterData *data, PyThread_type_lock recv_mutex) + _PyCrossInterpreterData *data, _waiting_t *waiting) { int res = -1; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -1035,9 +1239,10 @@ _channel_add(_PyChannelState *chan, int64_t interp, goto done; } - if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) { + if (_channelqueue_put(chan->queue, data, waiting) != 0) { goto done; } + // Any errors past this point must cause a _waiting_release() call. res = 0; done: @@ -1047,7 +1252,7 @@ done: static int _channel_next(_PyChannelState *chan, int64_t interp, - _PyCrossInterpreterData **res) + _PyCrossInterpreterData **p_data, _waiting_t **p_waiting) { int err = 0; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); @@ -1061,16 +1266,12 @@ _channel_next(_PyChannelState *chan, int64_t interp, goto done; } - PyThread_type_lock recv_mutex = NULL; - _PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex); - if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) { + int empty = _channelqueue_get(chan->queue, p_data, p_waiting); + assert(empty == 0 || empty == ERR_CHANNEL_EMPTY); + assert(!PyErr_Occurred()); + if (empty && chan->closing != NULL) { chan->open = 0; } - *res = data; - - if (recv_mutex != NULL) { - PyThread_release_lock(recv_mutex); - } done: PyThread_release_lock(chan->mutex); @@ -1080,6 +1281,26 @@ done: return err; } +static void +_channel_remove(_PyChannelState *chan, _channelitem_id_t itemid) +{ + _PyCrossInterpreterData *data = NULL; + _waiting_t *waiting = NULL; + + PyThread_acquire_lock(chan->mutex, WAIT_LOCK); + _channelqueue_remove(chan->queue, itemid, &data, &waiting); + PyThread_release_lock(chan->mutex); + + (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE); + if (waiting != NULL) { + _waiting_release(waiting, 0); + } + + if (chan->queue->count == 0) { + _channel_finish_closing(chan); + } +} + static int _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end) { @@ -1592,7 +1813,7 @@ _channel_destroy(_channels *channels, int64_t id) static int _channel_send(_channels *channels, int64_t id, PyObject *obj, - PyThread_type_lock recv_mutex) + _waiting_t *waiting) { PyInterpreterState *interp = _get_current_interp(); if (interp == NULL) { @@ -1627,8 +1848,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj, } // Add the data to the channel. - int res = _channel_add(chan, PyInterpreterState_GetID(interp), data, - recv_mutex); + int res = _channel_add(chan, PyInterpreterState_GetID(interp), + data, waiting); PyThread_release_lock(mutex); if (res != 0) { // We may chain an exception here: @@ -1640,31 +1861,74 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj, return 0; } +static void +_channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting) +{ + // Look up the channel. + PyThread_type_lock mutex = NULL; + _PyChannelState *chan = NULL; + int err = _channels_lookup(channels, cid, &mutex, &chan); + if (err != 0) { + // The channel was already closed, etc. + assert(waiting->status == WAITING_RELEASED); + return; // Ignore the error. + } + assert(chan != NULL); + // Past this point we are responsible for releasing the mutex. + + _channelitem_id_t itemid = _waiting_get_itemid(waiting); + _channel_remove(chan, itemid); + + PyThread_release_lock(mutex); +} + static int _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj) { - PyThread_type_lock mutex = PyThread_allocate_lock(); - if (mutex == NULL) { - PyErr_NoMemory(); + // We use a stack variable here, so we must ensure that &waiting + // is not held by any channel item at the point this function exits. + _waiting_t waiting; + if (_waiting_init(&waiting) < 0) { + assert(PyErr_Occurred()); return -1; } - PyThread_acquire_lock(mutex, NOWAIT_LOCK); /* Queue up the object. */ - int res = _channel_send(channels, cid, obj, mutex); + int res = _channel_send(channels, cid, obj, &waiting); if (res < 0) { - PyThread_release_lock(mutex); + assert(waiting.status == WAITING_NO_STATUS); goto finally; } /* Wait until the object is received. */ - wait_for_lock(mutex); + if (wait_for_lock(waiting.mutex) < 0) { + assert(PyErr_Occurred()); + _waiting_finish_releasing(&waiting); + /* The send() call is failing now, so make sure the item + won't be received. */ + _channel_clear_sent(channels, cid, &waiting); + assert(waiting.status == WAITING_RELEASED); + if (!waiting.received) { + res = -1; + goto finally; + } + // XXX Emit a warning if not a TimeoutError? + PyErr_Clear(); + } + else { + _waiting_finish_releasing(&waiting); + assert(waiting.status == WAITING_RELEASED); + if (!waiting.received) { + res = ERR_CHANNEL_CLOSED_WAITING; + goto finally; + } + } /* success! */ res = 0; finally: - // XXX Delete the lock. + _waiting_clear(&waiting); return res; } @@ -1695,7 +1959,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res) // Pop off the next item from the channel. _PyCrossInterpreterData *data = NULL; - err = _channel_next(chan, PyInterpreterState_GetID(interp), &data); + _waiting_t *waiting = NULL; + err = _channel_next(chan, PyInterpreterState_GetID(interp), &data, + &waiting); PyThread_release_lock(mutex); if (err != 0) { return err; @@ -1711,6 +1977,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res) assert(PyErr_Occurred()); // It was allocated in _channel_send(), so we free it. (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE); + if (waiting != NULL) { + _waiting_release(waiting, 0); + } return -1; } // It was allocated in _channel_send(), so we free it. @@ -1719,9 +1988,17 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res) // The source interpreter has been destroyed already. assert(PyErr_Occurred()); Py_DECREF(obj); + if (waiting != NULL) { + _waiting_release(waiting, 0); + } return -1; } + // Notify the sender. + if (waiting != NULL) { + _waiting_release(waiting, 1); + } + *res = obj; return 0; } diff --git a/Python/thread.c b/Python/thread.c index bf207ce..7185dd4 100644 --- a/Python/thread.c +++ b/Python/thread.c @@ -6,6 +6,7 @@ Stuff shared by all thread_*.h files is collected here. */ #include "Python.h" +#include "pycore_ceval.h" // _PyEval_MakePendingCalls() #include "pycore_pystate.h" // _PyInterpreterState_GET() #include "pycore_structseq.h" // _PyStructSequence_FiniBuiltin() #include "pycore_pythread.h" // _POSIX_THREADS @@ -92,6 +93,55 @@ PyThread_set_stacksize(size_t size) } +PyLockStatus +PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock, + PY_TIMEOUT_T timeout) +{ + PyThreadState *tstate = _PyThreadState_GET(); + _PyTime_t endtime = 0; + if (timeout > 0) { + endtime = _PyDeadline_Init(timeout); + } + + PyLockStatus r; + do { + _PyTime_t microseconds; + microseconds = _PyTime_AsMicroseconds(timeout, _PyTime_ROUND_CEILING); + + /* first a simple non-blocking try without releasing the GIL */ + r = PyThread_acquire_lock_timed(lock, 0, 0); + if (r == PY_LOCK_FAILURE && microseconds != 0) { + Py_BEGIN_ALLOW_THREADS + r = PyThread_acquire_lock_timed(lock, microseconds, 1); + Py_END_ALLOW_THREADS + } + + if (r == PY_LOCK_INTR) { + /* Run signal handlers if we were interrupted. Propagate + * exceptions from signal handlers, such as KeyboardInterrupt, by + * passing up PY_LOCK_INTR. */ + if (_PyEval_MakePendingCalls(tstate) < 0) { + return PY_LOCK_INTR; + } + + /* If we're using a timeout, recompute the timeout after processing + * signals, since those can take time. */ + if (timeout > 0) { + timeout = _PyDeadline_Get(endtime); + + /* Check for negative values, since those mean block forever. + */ + if (timeout < 0) { + r = PY_LOCK_FAILURE; + } + } + } + } while (r == PY_LOCK_INTR); /* Retry if we were interrupted. */ + + return r; +} + + /* Thread Specific Storage (TSS) API Cross-platform components of TSS API implementation. -- cgit v0.12