summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEric Snow <ericsnowcurrently@gmail.com>2023-10-17 23:05:49 (GMT)
committerGitHub <noreply@github.com>2023-10-17 23:05:49 (GMT)
commitc58c63fdf615a1c2bfc995dd0b938d82e32b6cde (patch)
treec55bd273d26fdcaf30013a44a7bf973940771a23
parent7029c1a1c5b864056aa00298b1d0e0269f073f99 (diff)
downloadcpython-c58c63fdf615a1c2bfc995dd0b938d82e32b6cde.zip
cpython-c58c63fdf615a1c2bfc995dd0b938d82e32b6cde.tar.gz
cpython-c58c63fdf615a1c2bfc995dd0b938d82e32b6cde.tar.bz2
gh-84570: Add Timeouts to SendChannel.send() and RecvChannel.recv() (gh-110567)
-rw-r--r--Include/internal/pycore_pythread.h6
-rw-r--r--Lib/test/support/interpreters.py20
-rw-r--r--Lib/test/test__xxinterpchannels.py128
-rw-r--r--Lib/test/test_interpreters.py5
-rw-r--r--Modules/_queuemodule.c2
-rw-r--r--Modules/_threadmodule.c11
-rw-r--r--Modules/_xxinterpchannelsmodule.c43
-rw-r--r--Python/thread.c34
8 files changed, 202 insertions, 47 deletions
diff --git a/Include/internal/pycore_pythread.h b/Include/internal/pycore_pythread.h
index ffd7398..d31ffc7 100644
--- a/Include/internal/pycore_pythread.h
+++ b/Include/internal/pycore_pythread.h
@@ -89,6 +89,12 @@ extern int _PyThread_at_fork_reinit(PyThread_type_lock *lock);
// unset: -1 seconds, in nanoseconds
#define PyThread_UNSET_TIMEOUT ((_PyTime_t)(-1 * 1000 * 1000 * 1000))
+// Exported for the _xxinterpchannels module.
+PyAPI_FUNC(int) PyThread_ParseTimeoutArg(
+ PyObject *arg,
+ int blocking,
+ PY_TIMEOUT_T *timeout);
+
/* Helper to acquire an interruptible lock with a timeout. If the lock acquire
* is interrupted, signal handlers are run, and if they raise an exception,
* PY_LOCK_INTR is returned. Otherwise, PY_LOCK_ACQUIRED or PY_LOCK_FAILURE
diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py
index 9ba6862..f8f42c0 100644
--- a/Lib/test/support/interpreters.py
+++ b/Lib/test/support/interpreters.py
@@ -170,15 +170,25 @@ class RecvChannel(_ChannelEnd):
_end = 'recv'
- def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds
+ def recv(self, timeout=None, *,
+ _sentinel=object(),
+ _delay=10 / 1000, # 10 milliseconds
+ ):
"""Return the next object from the channel.
This blocks until an object has been sent, if none have been
sent already.
"""
+ if timeout is not None:
+ timeout = int(timeout)
+ if timeout < 0:
+ raise ValueError(f'timeout value must be non-negative')
+ end = time.time() + timeout
obj = _channels.recv(self._id, _sentinel)
while obj is _sentinel:
time.sleep(_delay)
+ if timeout is not None and time.time() >= end:
+ raise TimeoutError
obj = _channels.recv(self._id, _sentinel)
return obj
@@ -203,12 +213,12 @@ class SendChannel(_ChannelEnd):
_end = 'send'
- def send(self, obj):
+ def send(self, obj, timeout=None):
"""Send the object (i.e. its data) to the channel's receiving end.
This blocks until the object is received.
"""
- _channels.send(self._id, obj, blocking=True)
+ _channels.send(self._id, obj, timeout=timeout, blocking=True)
def send_nowait(self, obj):
"""Send the object to the channel's receiving end.
@@ -221,12 +231,12 @@ class SendChannel(_ChannelEnd):
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj, blocking=False)
- def send_buffer(self, obj):
+ def send_buffer(self, obj, timeout=None):
"""Send the object's buffer to the channel's receiving end.
This blocks until the object is received.
"""
- _channels.send_buffer(self._id, obj, blocking=True)
+ _channels.send_buffer(self._id, obj, timeout=timeout, blocking=True)
def send_buffer_nowait(self, obj):
"""Send the object's buffer to the channel's receiving end.
diff --git a/Lib/test/test__xxinterpchannels.py b/Lib/test/test__xxinterpchannels.py
index 90a1224..1c1ef3f 100644
--- a/Lib/test/test__xxinterpchannels.py
+++ b/Lib/test/test__xxinterpchannels.py
@@ -864,22 +864,97 @@ class ChannelTests(TestBase):
self.assertEqual(received, obj)
+ def test_send_timeout(self):
+ obj = b'spam'
+
+ with self.subTest('non-blocking with timeout'):
+ cid = channels.create()
+ with self.assertRaises(ValueError):
+ channels.send(cid, obj, blocking=False, timeout=0.1)
+
+ with self.subTest('timeout hit'):
+ cid = channels.create()
+ with self.assertRaises(TimeoutError):
+ channels.send(cid, obj, blocking=True, timeout=0.1)
+ with self.assertRaises(channels.ChannelEmptyError):
+ received = channels.recv(cid)
+ print(repr(received))
+
+ with self.subTest('timeout not hit'):
+ cid = channels.create()
+ def f():
+ recv_wait(cid)
+ t = threading.Thread(target=f)
+ t.start()
+ channels.send(cid, obj, blocking=True, timeout=10)
+ t.join()
+
+ def test_send_buffer_timeout(self):
+ try:
+ self._has_run_once_timeout
+ except AttributeError:
+ # At the moment, this test leaks a few references.
+ # It looks like the leak originates with the addition
+ # of _channels.send_buffer() (gh-110246), whereas the
+ # tests were added afterward. We want this test even
+ # if the refleak isn't fixed yet, so we skip here.
+ raise unittest.SkipTest('temporarily skipped due to refleaks')
+ else:
+ self._has_run_once_timeout = True
+
+ obj = bytearray(b'spam')
+
+ with self.subTest('non-blocking with timeout'):
+ cid = channels.create()
+ with self.assertRaises(ValueError):
+ channels.send_buffer(cid, obj, blocking=False, timeout=0.1)
+
+ with self.subTest('timeout hit'):
+ cid = channels.create()
+ with self.assertRaises(TimeoutError):
+ channels.send_buffer(cid, obj, blocking=True, timeout=0.1)
+ with self.assertRaises(channels.ChannelEmptyError):
+ received = channels.recv(cid)
+ print(repr(received))
+
+ with self.subTest('timeout not hit'):
+ cid = channels.create()
+ def f():
+ recv_wait(cid)
+ t = threading.Thread(target=f)
+ t.start()
+ channels.send_buffer(cid, obj, blocking=True, timeout=10)
+ t.join()
+
def test_send_closed_while_waiting(self):
obj = b'spam'
wait = self.build_send_waiter(obj)
- cid = channels.create()
- def f():
- wait()
- channels.close(cid, force=True)
- t = threading.Thread(target=f)
- t.start()
- with self.assertRaises(channels.ChannelClosedError):
- channels.send(cid, obj, blocking=True)
- t.join()
+
+ with self.subTest('without timeout'):
+ cid = channels.create()
+ def f():
+ wait()
+ channels.close(cid, force=True)
+ t = threading.Thread(target=f)
+ t.start()
+ with self.assertRaises(channels.ChannelClosedError):
+ channels.send(cid, obj, blocking=True)
+ t.join()
+
+ with self.subTest('with timeout'):
+ cid = channels.create()
+ def f():
+ wait()
+ channels.close(cid, force=True)
+ t = threading.Thread(target=f)
+ t.start()
+ with self.assertRaises(channels.ChannelClosedError):
+ channels.send(cid, obj, blocking=True, timeout=30)
+ t.join()
def test_send_buffer_closed_while_waiting(self):
try:
- self._has_run_once
+ self._has_run_once_closed
except AttributeError:
# At the moment, this test leaks a few references.
# It looks like the leak originates with the addition
@@ -888,19 +963,32 @@ class ChannelTests(TestBase):
# if the refleak isn't fixed yet, so we skip here.
raise unittest.SkipTest('temporarily skipped due to refleaks')
else:
- self._has_run_once = True
+ self._has_run_once_closed = True
obj = bytearray(b'spam')
wait = self.build_send_waiter(obj, buffer=True)
- cid = channels.create()
- def f():
- wait()
- channels.close(cid, force=True)
- t = threading.Thread(target=f)
- t.start()
- with self.assertRaises(channels.ChannelClosedError):
- channels.send_buffer(cid, obj, blocking=True)
- t.join()
+
+ with self.subTest('without timeout'):
+ cid = channels.create()
+ def f():
+ wait()
+ channels.close(cid, force=True)
+ t = threading.Thread(target=f)
+ t.start()
+ with self.assertRaises(channels.ChannelClosedError):
+ channels.send_buffer(cid, obj, blocking=True)
+ t.join()
+
+ with self.subTest('with timeout'):
+ cid = channels.create()
+ def f():
+ wait()
+ channels.close(cid, force=True)
+ t = threading.Thread(target=f)
+ t.start()
+ with self.assertRaises(channels.ChannelClosedError):
+ channels.send_buffer(cid, obj, blocking=True, timeout=30)
+ t.join()
#-------------------
# close
diff --git a/Lib/test/test_interpreters.py b/Lib/test/test_interpreters.py
index 0910b51..d2d52ec 100644
--- a/Lib/test/test_interpreters.py
+++ b/Lib/test/test_interpreters.py
@@ -1022,6 +1022,11 @@ class TestSendRecv(TestBase):
self.assertEqual(obj2, b'eggs')
self.assertNotEqual(id(obj2), int(out))
+ def test_recv_timeout(self):
+ r, _ = interpreters.create_channel()
+ with self.assertRaises(TimeoutError):
+ r.recv(timeout=1)
+
def test_recv_channel_does_not_exist(self):
ch = interpreters.RecvChannel(1_000_000)
with self.assertRaises(interpreters.ChannelNotFoundError):
diff --git a/Modules/_queuemodule.c b/Modules/_queuemodule.c
index b4bafb3..81a06cd 100644
--- a/Modules/_queuemodule.c
+++ b/Modules/_queuemodule.c
@@ -214,6 +214,8 @@ _queue_SimpleQueue_get_impl(simplequeueobject *self, PyTypeObject *cls,
PY_TIMEOUT_T microseconds;
PyThreadState *tstate = PyThreadState_Get();
+ // XXX Use PyThread_ParseTimeoutArg().
+
if (block == 0) {
/* Non-blocking */
microseconds = 0;
diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c
index 7620511..4d45304 100644
--- a/Modules/_threadmodule.c
+++ b/Modules/_threadmodule.c
@@ -88,14 +88,15 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
char *kwlist[] = {"blocking", "timeout", NULL};
int blocking = 1;
PyObject *timeout_obj = NULL;
- const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);
-
- *timeout = unset_timeout ;
-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|pO:acquire", kwlist,
&blocking, &timeout_obj))
return -1;
+ // XXX Use PyThread_ParseTimeoutArg().
+
+ const _PyTime_t unset_timeout = _PyTime_FromSeconds(-1);
+ *timeout = unset_timeout;
+
if (timeout_obj
&& _PyTime_FromSecondsObject(timeout,
timeout_obj, _PyTime_ROUND_TIMEOUT) < 0)
@@ -108,7 +109,7 @@ lock_acquire_parse_args(PyObject *args, PyObject *kwds,
}
if (*timeout < 0 && *timeout != unset_timeout) {
PyErr_SetString(PyExc_ValueError,
- "timeout value must be positive");
+ "timeout value must be a non-negative number");
return -1;
}
if (!blocking)
diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c
index be53cbf..2e2878d 100644
--- a/Modules/_xxinterpchannelsmodule.c
+++ b/Modules/_xxinterpchannelsmodule.c
@@ -242,9 +242,8 @@ add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
}
static int
-wait_for_lock(PyThread_type_lock mutex)
+wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
{
- PY_TIMEOUT_T timeout = PyThread_UNSET_TIMEOUT;
PyLockStatus res = PyThread_acquire_lock_timed_with_retries(mutex, timeout);
if (res == PY_LOCK_INTR) {
/* KeyboardInterrupt, etc. */
@@ -1883,7 +1882,8 @@ _channel_clear_sent(_channels *channels, int64_t cid, _waiting_t *waiting)
}
static int
-_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
+_channel_send_wait(_channels *channels, int64_t cid, PyObject *obj,
+ PY_TIMEOUT_T timeout)
{
// We use a stack variable here, so we must ensure that &waiting
// is not held by any channel item at the point this function exits.
@@ -1901,7 +1901,7 @@ _channel_send_wait(_channels *channels, int64_t cid, PyObject *obj)
}
/* Wait until the object is received. */
- if (wait_for_lock(waiting.mutex) < 0) {
+ if (wait_for_lock(waiting.mutex, timeout) < 0) {
assert(PyErr_Occurred());
_waiting_finish_releasing(&waiting);
/* The send() call is failing now, so make sure the item
@@ -2816,25 +2816,29 @@ receive end.");
static PyObject *
channel_send(PyObject *self, PyObject *args, PyObject *kwds)
{
- // XXX Add a timeout arg.
- static char *kwlist[] = {"cid", "obj", "blocking", NULL};
- int64_t cid;
+ static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
int blocking = 1;
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$p:channel_send", kwlist,
+ PyObject *timeout_obj = NULL;
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O|$pO:channel_send", kwlist,
channel_id_converter, &cid_data, &obj,
- &blocking)) {
+ &blocking, &timeout_obj)) {
+ return NULL;
+ }
+
+ int64_t cid = cid_data.cid;
+ PY_TIMEOUT_T timeout;
+ if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
return NULL;
}
- cid = cid_data.cid;
/* Queue up the object. */
int err = 0;
if (blocking) {
- err = _channel_send_wait(&_globals.channels, cid, obj);
+ err = _channel_send_wait(&_globals.channels, cid, obj, timeout);
}
else {
err = _channel_send(&_globals.channels, cid, obj, NULL);
@@ -2855,20 +2859,25 @@ By default this waits for the object to be received.");
static PyObject *
channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
{
- static char *kwlist[] = {"cid", "obj", "blocking", NULL};
- int64_t cid;
+ static char *kwlist[] = {"cid", "obj", "blocking", "timeout", NULL};
struct channel_id_converter_data cid_data = {
.module = self,
};
PyObject *obj;
int blocking = 1;
+ PyObject *timeout_obj = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds,
- "O&O|$p:channel_send_buffer", kwlist,
+ "O&O|$pO:channel_send_buffer", kwlist,
channel_id_converter, &cid_data, &obj,
- &blocking)) {
+ &blocking, &timeout_obj)) {
+ return NULL;
+ }
+
+ int64_t cid = cid_data.cid;
+ PY_TIMEOUT_T timeout;
+ if (PyThread_ParseTimeoutArg(timeout_obj, blocking, &timeout) < 0) {
return NULL;
}
- cid = cid_data.cid;
PyObject *tempobj = PyMemoryView_FromObject(obj);
if (tempobj == NULL) {
@@ -2878,7 +2887,7 @@ channel_send_buffer(PyObject *self, PyObject *args, PyObject *kwds)
/* Queue up the object. */
int err = 0;
if (blocking) {
- err = _channel_send_wait(&_globals.channels, cid, tempobj);
+ err = _channel_send_wait(&_globals.channels, cid, tempobj, timeout);
}
else {
err = _channel_send(&_globals.channels, cid, tempobj, NULL);
diff --git a/Python/thread.c b/Python/thread.c
index 7185dd4..fefae83 100644
--- a/Python/thread.c
+++ b/Python/thread.c
@@ -93,6 +93,40 @@ PyThread_set_stacksize(size_t size)
}
+int
+PyThread_ParseTimeoutArg(PyObject *arg, int blocking, PY_TIMEOUT_T *timeout_p)
+{
+ assert(_PyTime_FromSeconds(-1) == PyThread_UNSET_TIMEOUT);
+ if (arg == NULL || arg == Py_None) {
+ *timeout_p = blocking ? PyThread_UNSET_TIMEOUT : 0;
+ return 0;
+ }
+ if (!blocking) {
+ PyErr_SetString(PyExc_ValueError,
+ "can't specify a timeout for a non-blocking call");
+ return -1;
+ }
+
+ _PyTime_t timeout;
+ if (_PyTime_FromSecondsObject(&timeout, arg, _PyTime_ROUND_TIMEOUT) < 0) {
+ return -1;
+ }
+ if (timeout < 0) {
+ PyErr_SetString(PyExc_ValueError,
+ "timeout value must be a non-negative number");
+ return -1;
+ }
+
+ if (_PyTime_AsMicroseconds(timeout,
+ _PyTime_ROUND_TIMEOUT) > PY_TIMEOUT_MAX) {
+ PyErr_SetString(PyExc_OverflowError,
+ "timeout value is too large");
+ return -1;
+ }
+ *timeout_p = timeout;
+ return 0;
+}
+
PyLockStatus
PyThread_acquire_lock_timed_with_retries(PyThread_type_lock lock,
PY_TIMEOUT_T timeout)