summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Include/internal/pycore_pythread.h15
-rw-r--r--Lib/test/test__xxinterpchannels.py217
-rw-r--r--Modules/_threadmodule.c52
-rw-r--r--Modules/_xxinterpchannelsmodule.c385
-rw-r--r--Python/thread.c50
5 files changed, 571 insertions, 148 deletions
diff --git a/Include/internal/pycore_pythread.h b/Include/internal/pycore_pythread.h
index 8ce5a79..ffd7398 100644
--- a/Include/internal/pycore_pythread.h
+++ b/Include/internal/pycore_pythread.h
@@ -86,6 +86,21 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock);
#endif /* HAVE_FORK */
+// unset: -1 seconds, in nanoseconds
+#define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000))
+
+/* Helper to acquire an interruptible lock with a timeout. If the lock acquire
+ * is interrupted, signal handlers are run, and if they raise an exception,
+ * PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE
+ * are returned, depending on whether the lock can be acquired within the
+ * timeout.
+ */
+// Exported for the _xxinterpchannels module.
+PyAPI_FUNC(PyLockStatus) PyThread_acquire_lock_timed_with_retries(
+ PyThread_type_lock,
+ PY_TIMEOUT_T microseconds);
+
+
#ifdef __cplusplus
}
#endif
diff --git a/Lib/test/test__xxinterpchannels.py b/Lib/test/test__xxinterpchannels.py
index ff01a33..90a1224 100644
--- a/Lib/test/test__xxinterpchannels.py
+++ b/Lib/test/test__xxinterpchannels.py
@@ -564,7 +564,62 @@ class ChannelTests(TestBase):
with self.assertRaises(channels.ChannelClosedError):
channels.list_interpreters(cid, send=False)
- ####################
+ def test_allowed_types(self):
+ cid = channels.create()
+ objects = [
+ None,
+ 'spam',
+ b'spam',
+ 42,
+ ]
+ for obj in objects:
+ with self.subTest(obj):
+ channels.send(cid, obj, blocking=False)
+ got = channels.recv(cid)
+
+ self.assertEqual(got, obj)
+ self.assertIs(type(got), type(obj))
+ # XXX Check the following?
+ #self.assertIsNot(got, obj)
+ # XXX What about between interpreters?
+
+ def test_run_string_arg_unresolved(self):
+ cid = channels.create()
+ interp = interpreters.create()
+
+ out = _run_output(interp, dedent("""
+ import _xxinterpchannels as _channels
+ print(cid.end)
+ _channels.send(cid, b'spam', blocking=False)
+ """),
+ dict(cid=cid.send))
+ obj = channels.recv(cid)
+
+ self.assertEqual(obj, b'spam')
+ self.assertEqual(out.strip(), 'send')
+
+ # XXX For now there is no high-level channel into which the
+ # sent channel ID can be converted...
+ # Note: this test caused crashes on some buildbots (bpo-33615).
+ @unittest.skip('disabled until high-level channels exist')
+ def test_run_string_arg_resolved(self):
+ cid = channels.create()
+ cid = channels._channel_id(cid, _resolve=True)
+ interp = interpreters.create()
+
+ out = _run_output(interp, dedent("""
+ import _xxinterpchannels as _channels
+ print(chan.id.end)
+ _channels.send(chan.id, b'spam', blocking=False)
+ """),
+ dict(chan=cid.send))
+ obj = channels.recv(cid)
+
+ self.assertEqual(obj, b'spam')
+ self.assertEqual(out.strip(), 'send')
+
+ #-------------------
+ # send/recv
def test_send_recv_main(self):
cid = channels.create()
@@ -705,6 +760,9 @@ class ChannelTests(TestBase):
channels.recv(cid2)
del cid2
+ #-------------------
+ # send_buffer
+
def test_send_buffer(self):
buf = bytearray(b'spamspamspam')
cid = channels.create()
@@ -720,60 +778,131 @@ class ChannelTests(TestBase):
obj[4:8] = b'ham.'
self.assertEqual(obj, buf)
- def test_allowed_types(self):
+ #-------------------
+ # send with waiting
+
+ def build_send_waiter(self, obj, *, buffer=False):
+ # We want a long enough sleep that send() actually has to wait.
+
+ if buffer:
+ send = channels.send_buffer
+ else:
+ send = channels.send
+
cid = channels.create()
- objects = [
- None,
- 'spam',
- b'spam',
- 42,
- ]
- for obj in objects:
- with self.subTest(obj):
- channels.send(cid, obj, blocking=False)
- got = channels.recv(cid)
+ try:
+ started = time.monotonic()
+ send(cid, obj, blocking=False)
+ stopped = time.monotonic()
+ channels.recv(cid)
+ finally:
+ channels.destroy(cid)
+ delay = stopped - started # seconds
+ delay *= 3
- self.assertEqual(got, obj)
- self.assertIs(type(got), type(obj))
- # XXX Check the following?
- #self.assertIsNot(got, obj)
- # XXX What about between interpreters?
+ def wait():
+ time.sleep(delay)
+ return wait
- def test_run_string_arg_unresolved(self):
+ def test_send_blocking_waiting(self):
+ received = None
+ obj = b'spam'
+ wait = self.build_send_waiter(obj)
cid = channels.create()
- interp = interpreters.create()
+ def f():
+ nonlocal received
+ wait()
+ received = recv_wait(cid)
+ t = threading.Thread(target=f)
+ t.start()
+ channels.send(cid, obj, blocking=True)
+ t.join()
- out = _run_output(interp, dedent("""
- import _xxinterpchannels as _channels
- print(cid.end)
- _channels.send(cid, b'spam', blocking=False)
- """),
- dict(cid=cid.send))
- obj = channels.recv(cid)
+ self.assertEqual(received, obj)
- self.assertEqual(obj, b'spam')
- self.assertEqual(out.strip(), 'send')
+ def test_send_buffer_blocking_waiting(self):
+ received = None
+ obj = bytearray(b'spam')
+ wait = self.build_send_waiter(obj, buffer=True)
+ cid = channels.create()
+ def f():
+ nonlocal received
+ wait()
+ received = recv_wait(cid)
+ t = threading.Thread(target=f)
+ t.start()
+ channels.send_buffer(cid, obj, blocking=True)
+ t.join()
- # XXX For now there is no high-level channel into which the
- # sent channel ID can be converted...
- # Note: this test caused crashes on some buildbots (bpo-33615).
- @unittest.skip('disabled until high-level channels exist')
- def test_run_string_arg_resolved(self):
+ self.assertEqual(received, obj)
+
+ def test_send_blocking_no_wait(self):
+ received = None
+ obj = b'spam'
cid = channels.create()
- cid = channels._channel_id(cid, _resolve=True)
- interp = interpreters.create()
+ def f():
+ nonlocal received
+ received = recv_wait(cid)
+ t = threading.Thread(target=f)
+ t.start()
+ channels.send(cid, obj, blocking=True)
+ t.join()
- out = _run_output(interp, dedent("""
- import _xxinterpchannels as _channels
- print(chan.id.end)
- _channels.send(chan.id, b'spam', blocking=False)
- """),
- dict(chan=cid.send))
- obj = channels.recv(cid)
+ self.assertEqual(received, obj)
- self.assertEqual(obj, b'spam')
- self.assertEqual(out.strip(), 'send')
+ def test_send_buffer_blocking_no_wait(self):
+ received = None
+ obj = bytearray(b'spam')
+ cid = channels.create()
+ def f():
+ nonlocal received
+ received = recv_wait(cid)
+ t = threading.Thread(target=f)
+ t.start()
+ channels.send_buffer(cid, obj, blocking=True)
+ t.join()
+
+ self.assertEqual(received, obj)
+
+ def test_send_closed_while_waiting(self):
+ obj = b'spam'
+ wait = self.build_send_waiter(obj)
+ cid = channels.create()
+ def f():
+ wait()
+ channels.close(cid, force=True)
+ t = threading.Thread(target=f)
+ t.start()
+ with self.assertRaises(channels.ChannelClosedError):
+ channels.send(cid, obj, blocking=True)
+ t.join()
+
+ def test_send_buffer_closed_while_waiting(self):
+ try:
+ self._has_run_once
+ except AttributeError:
+ # At the moment, this test leaks a few references.
+ # It looks like the leak originates with the addition
+ # of _channels.send_buffer() (gh-110246), whereas the
+ # tests were added afterward. We want this test even
+ # if the refleak isn't fixed yet, so we skip here.
+ raise unittest.SkipTest('temporarily skipped due to refleaks')
+ else:
+ self._has_run_once = True
+
+ obj = bytearray(b'spam')
+ wait = self.build_send_waiter(obj, buffer=True)
+ cid = channels.create()
+ def f():
+ wait()
+ channels.close(cid, force=True)
+ t = threading.Thread(target=f)
+ t.start()
+ with self.assertRaises(channels.ChannelClosedError):
+ channels.send_buffer(cid, obj, blocking=True)
+ t.join()
+ #-------------------
# close
def test_close_single_user(self):
diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c
index 86bd560..7620511 100644
--- a/Modules/_threadmodule.c
+++ b/Modules/_threadmodule.c
@@ -3,7 +3,6 @@
/* Interface to Sjoerd's portable C thread library */
#include "Python.h"
-#include "pycore_ceval.h" // _PyEval_MakePendingCalls()
#include "pycore_dict.h" // _PyDict_Pop()
#include "pycore_interp.h" // _PyInterpreterState.threads.count
#include "pycore_moduleobject.h" // _PyModule_GetState()
@@ -76,57 +75,10 @@ lock_dealloc(lockobject *self)
Py_DECREF(tp);
}
-/* Helper to acquire an interruptible lock with a timeout. If the lock acquire
- * is interrupted, signal handlers are run, and if they raise an exception,
- * PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE
- * are returned, depending on whether the lock can be acquired within the
- * timeout.
- */
-static PyLockStatus
+static inline PyLockStatus
acquire_timed(PyThread_type_lock lock, _PyTime_t timeout)
{
- PyThreadState *tstate = _PyThreadState_GET();
- _PyTime_t endtime = 0;
- if (timeout > 0) {
- endtime = _PyDeadline_Init(timeout);
- }
-
- PyLockStatus r;
- do {
- _PyTime_t microseconds;
- microseconds = _PyTime_AsMicroseconds(timeout, _PyTime_ROUND_CEILING);
-
- /* first a simple non-blocking try without releasing the GIL */
- r = PyThread_acquire_lock_timed(lock, 0, 0);
- if (r == PY_LOCK_FAILURE && microseconds != 0) {
- Py_BEGIN_ALLOW_THREADS
- r = PyThread_acquire_lock_timed(lock, microseconds, 1);
- Py_END_ALLOW_THREADS
- }
-
- if (r == PY_LOCK_INTR) {
- /* Run signal handlers if we were interrupted. Propagate
- * exceptions from signal handlers, such as KeyboardInterrupt, by
- * passing up PY_LOCK_INTR. */
- if (_PyEval_MakePendingCalls(tstate) < 0) {
- return PY_LOCK_INTR;
- }
-
- /* If we're using a timeout, recompute the timeout after processing
- * signals, since those can take time. */
- if (timeout > 0) {
- timeout = _PyDeadline_Get(endtime);
-
- /* Check for negative values, since those mean block forever.
- */
- if (timeout < 0) {
- r = PY_LOCK_FAILURE;
- }
- }
- }
- } while (r == PY_LOCK_INTR); /* Retry if we were interrupted. */
-
- return r;
+ return PyThread_acquire_lock_timed_with_retries(lock, timeout);
}
static int
diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c
index 34efe9d..be53cbf 100644
--- a/Modules/_xxinterpchannelsmodule.c
+++ b/Modules/_xxinterpchannelsmodule.c
@@ -10,6 +10,13 @@
#include "pycore_pybuffer.h" // _PyBuffer_ReleaseInInterpreterAndRawFree()
#include "pycore_interp.h" // _PyInterpreterState_LookUpID()
+#ifdef MS_WINDOWS
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h> // SwitchToThread()
+#elif defined(HAVE_SCHED_H)
+#include <sched.h> // sched_yield()
+#endif
+
/*
This module has the following process-global state:
@@ -234,15 +241,25 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
return cls;
}
-static void
+static int
wait_for_lock(PyThread_type_lock mutex)
{
- Py_BEGIN_ALLOW_THREADS
- // XXX Handle eintr, etc.
- PyThread_acquire_lock(mutex, WAIT_LOCK);
- Py_END_ALLOW_THREADS
-
+ PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT;
+ PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
+ if (res == PY_LOCK_INTR) {
+ /* KeyboardInterrupt, etc. */
+ assert(PyErr_Occurred());
+ return -1;
+ }
+ else if (res == PY_LOCK_FAILURE) {
+ assert(!PyErr_Occurred());
+ assert(timeout > 0);
+ PyErr_SetString(PyExc_TimeoutError, "timed out");
+ return -1;
+ }
+ assert(res == PY_LOCK_ACQUIRED);
PyThread_release_lock(mutex);
+ return 0;
}
@@ -489,6 +506,7 @@ _get_current_xibufferview_type(void)
#define ERR_CHANNEL_MUTEX_INIT -7
#define ERR_CHANNELS_MUTEX_INIT -8
#define ERR_NO_NEXT_CHANNEL_ID -9
+#define ERR_CHANNEL_CLOSED_WAITING -10
static int
exceptions_init(PyObject *mod)
@@ -540,6 +558,10 @@ handle_channel_error(int err, PyObject *mod, int64_t cid)
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is closed", cid);
}
+ else if (err == ERR_CHANNEL_CLOSED_WAITING) {
+ PyErr_Format(state->ChannelClosedError,
+ "channel %" PRId64 " has closed", cid);
+ }
else if (err == ERR_CHANNEL_INTERP_CLOSED) {
PyErr_Format(state->ChannelClosedError,
"channel %" PRId64 " is already closed", cid);
@@ -574,36 +596,145 @@ handle_channel_error(int err, PyObject *mod, int64_t cid)
/* the channel queue */
+typedef uintptr_t _channelitem_id_t;
+
+typedef struct wait_info {
+ PyThread_type_lock mutex;
+ enum {
+ WAITING_NO_STATUS = 0,
+ WAITING_ACQUIRED = 1,
+ WAITING_RELEASING = 2,
+ WAITING_RELEASED = 3,
+ } status;
+ int received;
+ _channelitem_id_t itemid;
+} _waiting_t;
+
+static int
+_waiting_init(_waiting_t *waiting)
+{
+ PyThread_type_lock mutex = PyThread_allocate_lock();
+ if (mutex == NULL) {
+ PyErr_NoMemory();
+ return -1;
+ }
+
+ *waiting = (_waiting_t){
+ .mutex = mutex,
+ .status = WAITING_NO_STATUS,
+ };
+ return 0;
+}
+
+static void
+_waiting_clear(_waiting_t *waiting)
+{
+ assert(waiting->status != WAITING_ACQUIRED
+ && waiting->status != WAITING_RELEASING);
+ if (waiting->mutex != NULL) {
+ PyThread_free_lock(waiting->mutex);
+ waiting->mutex = NULL;
+ }
+}
+
+static _channelitem_id_t
+_waiting_get_itemid(_waiting_t *waiting)
+{
+ return waiting->itemid;
+}
+
+static void
+_waiting_acquire(_waiting_t *waiting)
+{
+ assert(waiting->status == WAITING_NO_STATUS);
+ PyThread_acquire_lock(waiting->mutex, NOWAIT_LOCK);
+ waiting->status = WAITING_ACQUIRED;
+}
+
+static void
+_waiting_release(_waiting_t *waiting, int received)
+{
+ assert(waiting->mutex != NULL);
+ assert(waiting->status == WAITING_ACQUIRED);
+ assert(!waiting->received);
+
+ waiting->status = WAITING_RELEASING;
+ PyThread_release_lock(waiting->mutex);
+ if (waiting->received != received) {
+ assert(received == 1);
+ waiting->received = received;
+ }
+ waiting->status = WAITING_RELEASED;
+}
+
+static void
+_waiting_finish_releasing(_waiting_t *waiting)
+{
+ while (waiting->status == WAITING_RELEASING) {
+#ifdef MS_WINDOWS
+ SwitchToThread();
+#elif defined(HAVE_SCHED_H)
+ sched_yield();
+#endif
+ }
+}
+
struct _channelitem;
typedef struct _channelitem {
_PyCrossInterpreterData *data;
- PyThread_type_lock recv_mutex;
+ _waiting_t *waiting;
struct _channelitem *next;
} _channelitem;
-static _channelitem *
-_channelitem_new(void)
+static inline _channelitem_id_t
+_channelitem_ID(_channelitem *item)
{
- _channelitem *item = GLOBAL_MALLOC(_channelitem);
- if (item == NULL) {
- PyErr_NoMemory();
- return NULL;
+ return (_channelitem_id_t)item;
+}
+
+static void
+_channelitem_init(_channelitem *item,
+ _PyCrossInterpreterData *data, _waiting_t *waiting)
+{
+ *item = (_channelitem){
+ .data = data,
+ .waiting = waiting,
+ };
+ if (waiting != NULL) {
+ waiting->itemid = _channelitem_ID(item);
}
- item->data = NULL;
- item->next = NULL;
- return item;
}
static void
_channelitem_clear(_channelitem *item)
{
+ item->next = NULL;
+
if (item->data != NULL) {
// It was allocated in _channel_send().
(void)_release_xid_data(item->data, XID_IGNORE_EXC & XID_FREE);
item->data = NULL;
}
- item->next = NULL;
+
+ if (item->waiting != NULL) {
+ if (item->waiting->status == WAITING_ACQUIRED) {
+ _waiting_release(item->waiting, 0);
+ }
+ item->waiting = NULL;
+ }
+}
+
+static _channelitem *
+_channelitem_new(_PyCrossInterpreterData *data, _waiting_t *waiting)
+{
+ _channelitem *item = GLOBAL_MALLOC(_channelitem);
+ if (item == NULL) {
+ PyErr_NoMemory();
+ return NULL;
+ }
+ _channelitem_init(item, data, waiting);
+ return item;
}
static void
@@ -623,14 +754,17 @@ _channelitem_free_all(_channelitem *item)
}
}
-static _PyCrossInterpreterData *
-_channelitem_popped(_channelitem *item, PyThread_type_lock *recv_mutex)
+static void
+_channelitem_popped(_channelitem *item,
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
{
- _PyCrossInterpreterData *data = item->data;
+ assert(item->waiting == NULL || item->waiting->status == WAITING_ACQUIRED);
+ *p_data = item->data;
+ *p_waiting = item->waiting;
+ // We clear them here, so they won't be released in _channelitem_clear().
item->data = NULL;
- *recv_mutex = item->recv_mutex;
+ item->waiting = NULL;
_channelitem_free(item);
- return data;
}
typedef struct _channelqueue {
@@ -670,15 +804,13 @@ _channelqueue_free(_channelqueue *queue)
}
static int
-_channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
- PyThread_type_lock recv_mutex)
+_channelqueue_put(_channelqueue *queue,
+ _PyCrossInterpreterData *data, _waiting_t *waiting)
{
- _channelitem *item = _channelitem_new();
+ _channelitem *item = _channelitem_new(data, waiting);
if (item == NULL) {
return -1;
}
- item->data = data;
- item->recv_mutex = recv_mutex;
queue->count += 1;
if (queue->first == NULL) {
@@ -688,15 +820,21 @@ _channelqueue_put(_channelqueue *queue, _PyCrossInterpreterData *data,
queue->last->next = item;
}
queue->last = item;
+
+ if (waiting != NULL) {
+ _waiting_acquire(waiting);
+ }
+
return 0;
}
-static _PyCrossInterpreterData *
-_channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
+static int
+_channelqueue_get(_channelqueue *queue,
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
{
_channelitem *item = queue->first;
if (item == NULL) {
- return NULL;
+ return ERR_CHANNEL_EMPTY;
}
queue->first = item->next;
if (queue->last == item) {
@@ -704,7 +842,73 @@ _channelqueue_get(_channelqueue *queue, PyThread_type_lock *recv_mutex)
}
queue->count -= 1;
- return _channelitem_popped(item, recv_mutex);
+ _channelitem_popped(item, p_data, p_waiting);
+ return 0;
+}
+
+static int
+_channelqueue_find(_channelqueue *queue, _channelitem_id_t itemid,
+ _channelitem **p_item, _channelitem **p_prev)
+{
+ _channelitem *prev = NULL;
+ _channelitem *item = NULL;
+ if (queue->first != NULL) {
+ if (_channelitem_ID(queue->first) == itemid) {
+ item = queue->first;
+ }
+ else {
+ prev = queue->first;
+ while (prev->next != NULL) {
+ if (_channelitem_ID(prev->next) == itemid) {
+ item = prev->next;
+ break;
+ }
+ prev = prev->next;
+ }
+ if (item == NULL) {
+ prev = NULL;
+ }
+ }
+ }
+ if (p_item != NULL) {
+ *p_item = item;
+ }
+ if (p_prev != NULL) {
+ *p_prev = prev;
+ }
+ return (item != NULL);
+}
+
+static void
+_channelqueue_remove(_channelqueue *queue, _channelitem_id_t itemid,
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
+{
+ _channelitem *prev = NULL;
+ _channelitem *item = NULL;
+ int found = _channelqueue_find(queue, itemid, &item, &prev);
+ if (!found) {
+ return;
+ }
+
+ assert(item->waiting != NULL);
+ assert(!item->waiting->received);
+ if (prev == NULL) {
+ assert(queue->first == item);
+ queue->first = item->next;
+ }
+ else {
+ assert(queue->first != item);
+ assert(prev->next == item);
+ prev->next = item->next;
+ }
+ item->next = NULL;
+
+ if (queue->last == item) {
+ queue->last = prev;
+ }
+ queue->count -= 1;
+
+ _channelitem_popped(item, p_data, p_waiting);
}
static void
@@ -1021,7 +1225,7 @@ _channel_free(_PyChannelState *chan)
static int
_channel_add(_PyChannelState *chan, int64_t interp,
- _PyCrossInterpreterData *data, PyThread_type_lock recv_mutex)
+ _PyCrossInterpreterData *data, _waiting_t *waiting)
{
int res = -1;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -1035,9 +1239,10 @@ _channel_add(_PyChannelState *chan, int64_t interp,
goto done;
}
- if (_channelqueue_put(chan->queue, data, recv_mutex) != 0) {
+ if (_channelqueue_put(chan->queue, data, waiting) != 0) {
goto done;
}
+ // Any errors past this point must cause a _waiting_release() call.
res = 0;
done:
@@ -1047,7 +1252,7 @@ done:
static int
_channel_next(_PyChannelState *chan, int64_t interp,
- _PyCrossInterpreterData **res)
+ _PyCrossInterpreterData **p_data, _waiting_t **p_waiting)
{
int err = 0;
PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
@@ -1061,16 +1266,12 @@ _channel_next(_PyChannelState *chan, int64_t interp,
goto done;
}
- PyThread_type_lock recv_mutex = NULL;
- _PyCrossInterpreterData *data = _channelqueue_get(chan->queue, &recv_mutex);
- if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) {
+ int empty = _channelqueue_get(chan->queue, p_data, p_waiting);
+ assert(empty == 0 || empty == ERR_CHANNEL_EMPTY);
+ assert(!PyErr_Occurred());
+ if (empty && chan->closing != NULL) {
chan->open = 0;
}
- *res = data;
-
- if (recv_mutex != NULL) {
- PyThread_release_lock(recv_mutex);
- }
done:
PyThread_release_lock(chan->mutex);
@@ -1080,6 +1281,26 @@ done:
return err;
}
+static void
+_channel_remove(_PyChannelState *chan, _channelitem_id_t itemid)
+{
+ _PyCrossInterpreterData *data = NULL;
+ _waiting_t *waiting = NULL;
+
+ PyThread_acquire_lock(chan->mutex, WAIT_LOCK);
+ _channelqueue_remove(chan->queue, itemid, &data, &waiting);
+ PyThread_release_lock(chan->mutex);
+
+ (void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
+ if (waiting != NULL) {
+ _waiting_release(waiting, 0);
+ }
+
+ if (chan->queue->count == 0) {
+ _channel_finish_closing(chan);
+ }
+}
+
static int
_channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end)
{
@@ -1592,7 +1813,7 @@ _channel_destroy(_channels *channels, int64_t id)
static int
_channel_send(_channels *channels, int64_t id, PyObject *obj,
- PyThread_type_lock recv_mutex)
+ _waiting_t *waiting)
{
PyInterpreterState *interp = _get_current_interp();
if (interp == NULL) {
@@ -1627,8 +1848,8 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj,
}
// Add the data to the channel.
- int res = _channel_add(chan, PyInterpreterState_GetID(interp), data,
- recv_mutex);
+ int res = _channel_add(chan, PyInterpreterState_GetID(interp),
+ data, waiting);
PyThread_release_lock(mutex);
if (res != 0) {
// We may chain an exception here:
@@ -1640,31 +1861,74 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj,
return 0;
}
+static void
+_channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
+{
+ // Look up the channel.
+ PyThread_type_lock mutex = NULL;
+ _PyChannelState *chan = NULL;
+ int err = _channels_lookup(channels, cid, &mutex, &chan);
+ if (err != 0) {
+ // The channel was already closed, etc.
+ assert(waiting->status == WAITING_RELEASED);
+ return; // Ignore the error.
+ }
+ assert(chan != NULL);
+ // Past this point we are responsible for releasing the mutex.
+
+ _channelitem_id_t itemid = _waiting_get_itemid(waiting);
+ _channel_remove(chan, itemid);
+
+ PyThread_release_lock(mutex);
+}
+
static int
_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
{
- PyThread_type_lock mutex = PyThread_allocate_lock();
- if (mutex == NULL) {
- PyErr_NoMemory();
+ // We use a stack variable here, so we must ensure that &waiting
+ // is not held by any channel item at the point this function exits.
+ _waiting_t waiting;
+ if (_waiting_init(&waiting) < 0) {
+ assert(PyErr_Occurred());
return -1;
}
- PyThread_acquire_lock(mutex, NOWAIT_LOCK);
/* Queue up the object. */
- int res = _channel_send(channels, cid, obj, mutex);
+ int res = _channel_send(channels, cid, obj, &waiting);
if (res < 0) {
- PyThread_release_lock(mutex);
+ assert(waiting.status == WAITING_NO_STATUS);
goto finally;
}
/* Wait until the object is received. */
- wait_for_lock(mutex);
+ if (wait_for_lock(waiting.mutex) < 0) {
+ assert(PyErr_Occurred());
+ _waiting_finish_releasing(&waiting);
+ /* The send() call is failing now, so make sure the item
+ won't be received. */
+ _channel_clear_sent(channels, cid, &waiting);
+ assert(waiting.status == WAITING_RELEASED);
+ if (!waiting.received) {
+ res = -1;
+ goto finally;
+ }
+ // XXX Emit a warning if not a TimeoutError?
+ PyErr_Clear();
+ }
+ else {
+ _waiting_finish_releasing(&waiting);
+ assert(waiting.status == WAITING_RELEASED);
+ if (!waiting.received) {
+ res = ERR_CHANNEL_CLOSED_WAITING;
+ goto finally;
+ }
+ }
/* success! */
res = 0;
finally:
- // XXX Delete the lock.
+ _waiting_clear(&waiting);
return res;
}
@@ -1695,7 +1959,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
// Pop off the next item from the channel.
_PyCrossInterpreterData *data = NULL;
- err = _channel_next(chan, PyInterpreterState_GetID(interp), &data);
+ _waiting_t *waiting = NULL;
+ err = _channel_next(chan, PyInterpreterState_GetID(interp), &data,
+ &waiting);
PyThread_release_lock(mutex);
if (err != 0) {
return err;
@@ -1711,6 +1977,9 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
assert(PyErr_Occurred());
// It was allocated in _channel_send(), so we free it.
(void)_release_xid_data(data, XID_IGNORE_EXC | XID_FREE);
+ if (waiting != NULL) {
+ _waiting_release(waiting, 0);
+ }
return -1;
}
// It was allocated in _channel_send(), so we free it.
@@ -1719,9 +1988,17 @@ _channel_recv(_channels *channels, int64_t id, PyObject **res)
// The source interpreter has been destroyed already.
assert(PyErr_Occurred());
Py_DECREF(obj);
+ if (waiting != NULL) {
+ _waiting_release(waiting, 0);
+ }
return -1;
}
+ // Notify the sender.
+ if (waiting != NULL) {
+ _waiting_release(waiting, 1);
+ }
+
*res = obj;
return 0;
}
diff --git a/Python/thread.c b/Python/thread.c
index bf207ce..7185dd4 100644
--- a/Python/thread.c
+++ b/Python/thread.c
@@ -6,6 +6,7 @@
Stuff shared by all thread_*.h files is collected here. */
#include "Python.h"
+#include "pycore_ceval.h" // _PyEval_MakePendingCalls()
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include "pycore_structseq.h" // _PyStructSequence_FiniBuiltin()
#include "pycore_pythread.h" // _POSIX_THREADS
@@ -92,6 +93,55 @@ PyThread_set_stacksize(size_t size)
}
+PyLockStatus
+PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock,
+ PY_TIMEOUT_T timeout)
+{
+ PyThreadState *tstate = _PyThreadState_GET();
+ _PyTime_t endtime = 0;
+ if (timeout > 0) {
+ endtime = _PyDeadline_Init(timeout);
+ }
+
+ PyLockStatus r;
+ do {
+ _PyTime_t microseconds;
+ microseconds = _PyTime_AsMicroseconds(timeout, _PyTime_ROUND_CEILING);
+
+ /* first a simple non-blocking try without releasing the GIL */
+ r = PyThread_acquire_lock_timed(lock, 0, 0);
+ if (r == PY_LOCK_FAILURE && microseconds != 0) {
+ Py_BEGIN_ALLOW_THREADS
+ r = PyThread_acquire_lock_timed(lock, microseconds, 1);
+ Py_END_ALLOW_THREADS
+ }
+
+ if (r == PY_LOCK_INTR) {
+ /* Run signal handlers if we were interrupted. Propagate
+ * exceptions from signal handlers, such as KeyboardInterrupt, by
+ * passing up PY_LOCK_INTR. */
+ if (_PyEval_MakePendingCalls(tstate) < 0) {
+ return PY_LOCK_INTR;
+ }
+
+ /* If we're using a timeout, recompute the timeout after processing
+ * signals, since those can take time. */
+ if (timeout > 0) {
+ timeout = _PyDeadline_Get(endtime);
+
+ /* Check for negative values, since those mean block forever.
+ */
+ if (timeout < 0) {
+ r = PY_LOCK_FAILURE;
+ }
+ }
+ }
+ } while (r == PY_LOCK_INTR); /* Retry if we were interrupted. */
+
+ return r;
+}
+
+
/* Thread Specific Storage (TSS) API
Cross-platform components of TSS API implementation.