summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEric Snow <ericsnowcurrently@gmail.com>2023-10-02 20:47:41 (GMT)
committerGitHub <noreply@github.com>2023-10-02 20:47:41 (GMT)
commita8f5dab58daca9f01ec3c6f8c85e53329251b05d (patch)
tree7a9baf60b05d29b529799c40d3cdfa38e89d25bb
parent014aacda6239f0e33b3ad5ece343df66701804b2 (diff)
downloadcpython-a8f5dab58daca9f01ec3c6f8c85e53329251b05d.zip
cpython-a8f5dab58daca9f01ec3c6f8c85e53329251b05d.tar.gz
cpython-a8f5dab58daca9f01ec3c6f8c85e53329251b05d.tar.bz2
gh-76785: Module-level Fixes for test.support.interpreters (gh-110236)
* add RecvChannel.close() and SendChannel.close() * make RecvChannel and SendChannel shareable * expose ChannelEmptyError and ChannelNotEmptyError
-rw-r--r--Lib/test/support/interpreters.py30
-rw-r--r--Lib/test/test_interpreters.py16
-rw-r--r--Modules/_xxinterpchannelsmodule.c185
3 files changed, 206 insertions, 25 deletions
diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py
index eeff3ab..d2beba3 100644
--- a/Lib/test/support/interpreters.py
+++ b/Lib/test/support/interpreters.py
@@ -7,7 +7,8 @@ import _xxinterpchannels as _channels
# aliases:
from _xxsubinterpreters import is_shareable
from _xxinterpchannels import (
- ChannelError, ChannelNotFoundError, ChannelEmptyError,
+ ChannelError, ChannelNotFoundError, ChannelClosedError,
+ ChannelEmptyError, ChannelNotEmptyError,
)
@@ -117,10 +118,16 @@ def list_all_channels():
class _ChannelEnd:
"""The base class for RecvChannel and SendChannel."""
- def __init__(self, id):
- if not isinstance(id, (int, _channels.ChannelID)):
- raise TypeError(f'id must be an int, got {id!r}')
- self._id = id
+ _end = None
+
+ def __init__(self, cid):
+ if self._end == 'send':
+ cid = _channels._channel_id(cid, send=True, force=True)
+ elif self._end == 'recv':
+ cid = _channels._channel_id(cid, recv=True, force=True)
+ else:
+ raise NotImplementedError(self._end)
+ self._id = cid
def __repr__(self):
return f'{type(self).__name__}(id={int(self._id)})'
@@ -147,6 +154,8 @@ _NOT_SET = object()
class RecvChannel(_ChannelEnd):
"""The receiving end of a cross-interpreter channel."""
+ _end = 'recv'
+
def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds
"""Return the next object from the channel.
@@ -171,10 +180,15 @@ class RecvChannel(_ChannelEnd):
else:
return _channels.recv(self._id, default)
+ def close(self):
+ _channels.close(self._id, recv=True)
+
class SendChannel(_ChannelEnd):
"""The sending end of a cross-interpreter channel."""
+ _end = 'send'
+
def send(self, obj):
"""Send the object (i.e. its data) to the channel's receiving end.
@@ -196,3 +210,9 @@ class SendChannel(_ChannelEnd):
# None. This should be fixed when channel_send_wait() is added.
# See bpo-32604 and gh-19829.
return _channels.send(self._id, obj)
+
+ def close(self):
+ _channels.close(self._id, send=True)
+
+
+_channels._register_end_types(SendChannel, RecvChannel)
diff --git a/Lib/test/test_interpreters.py b/Lib/test/test_interpreters.py
index e62859a..ffdd8a1 100644
--- a/Lib/test/test_interpreters.py
+++ b/Lib/test/test_interpreters.py
@@ -822,6 +822,22 @@ class TestChannels(TestBase):
after = set(interpreters.list_all_channels())
self.assertEqual(after, created)
+ def test_shareable(self):
+ rch, sch = interpreters.create_channel()
+
+ self.assertTrue(
+ interpreters.is_shareable(rch))
+ self.assertTrue(
+ interpreters.is_shareable(sch))
+
+ sch.send_nowait(rch)
+ sch.send_nowait(sch)
+ rch2 = rch.recv()
+ sch2 = rch.recv()
+
+ self.assertEqual(rch2, rch)
+ self.assertEqual(sch2, sch)
+
class TestRecvChannelAttrs(TestBase):
diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c
index 6096f88..d5be76f 100644
--- a/Modules/_xxinterpchannelsmodule.c
+++ b/Modules/_xxinterpchannelsmodule.c
@@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags)
/* module state *************************************************************/
typedef struct {
+ PyTypeObject *send_channel_type;
+ PyTypeObject *recv_channel_type;
+
/* heap types */
PyTypeObject *ChannelIDType;
@@ -218,6 +221,21 @@ get_module_state(PyObject *mod)
return state;
}
+static module_state *
+_get_current_module_state(void)
+{
+ PyObject *mod = _get_current_module();
+ if (mod == NULL) {
+ // XXX import it?
+ PyErr_SetString(PyExc_RuntimeError,
+ MODULE_NAME " module not imported yet");
+ return NULL;
+ }
+ module_state *state = get_module_state(mod);
+ Py_DECREF(mod);
+ return state;
+}
+
static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
@@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
static int
clear_module_state(module_state *state)
{
+ Py_CLEAR(state->send_channel_type);
+ Py_CLEAR(state->recv_channel_type);
+
/* heap types */
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
@@ -1529,17 +1550,20 @@ typedef struct channelid {
struct channel_id_converter_data {
PyObject *module;
int64_t cid;
+ int end;
};
static int
channel_id_converter(PyObject *arg, void *ptr)
{
int64_t cid;
+ int end = 0;
struct channel_id_converter_data *data = ptr;
module_state *state = get_module_state(data->module);
assert(state != NULL);
if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
cid = ((channelid *)arg)->id;
+ end = ((channelid *)arg)->end;
}
else if (PyIndex_Check(arg)) {
cid = PyLong_AsLongLong(arg);
@@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr)
return 0;
}
data->cid = cid;
+ data->end = end;
return 1;
}
@@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
{
static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL};
int64_t cid;
+ int end;
struct channel_id_converter_data cid_data = {
.module = mod,
};
@@ -1614,6 +1640,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
return NULL;
}
cid = cid_data.cid;
+ end = cid_data.end;
// Handle "send" and "recv".
if (send == 0 && recv == 0) {
@@ -1621,14 +1648,17 @@ _channelid_new(PyObject *mod, PyTypeObject *cls,
"'send' and 'recv' cannot both be False");
return NULL;
}
-
- int end = 0;
- if (send == 1) {
+ else if (send == 1) {
if (recv == 0 || recv == -1) {
end = CHANNEL_SEND;
}
+ else {
+ assert(recv == 1);
+ end = 0;
+ }
}
else if (recv == 1) {
+ assert(send == 0 || send == -1);
end = CHANNEL_RECV;
}
@@ -1773,21 +1803,12 @@ done:
return res;
}
+static PyTypeObject * _get_current_channel_end_type(int end);
+
static PyObject *
_channel_from_cid(PyObject *cid, int end)
{
- PyObject *highlevel = PyImport_ImportModule("interpreters");
- if (highlevel == NULL) {
- PyErr_Clear();
- highlevel = PyImport_ImportModule("test.support.interpreters");
- if (highlevel == NULL) {
- return NULL;
- }
- }
- const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" :
- "SendChannel";
- PyObject *cls = PyObject_GetAttrString(highlevel, clsname);
- Py_DECREF(highlevel);
+ PyObject *cls = (PyObject *)_get_current_channel_end_type(end);
if (cls == NULL) {
return NULL;
}
@@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = {
};
+/* SendChannel and RecvChannel classes */
+
+// XXX Use a new __xid__ protocol instead?
+
+static PyTypeObject *
+_get_current_channel_end_type(int end)
+{
+ module_state *state = _get_current_module_state();
+ if (state == NULL) {
+ return NULL;
+ }
+ PyTypeObject *cls;
+ if (end == CHANNEL_SEND) {
+ cls = state->send_channel_type;
+ }
+ else {
+ assert(end == CHANNEL_RECV);
+ cls = state->recv_channel_type;
+ }
+ if (cls == NULL) {
+ PyObject *highlevel = PyImport_ImportModule("interpreters");
+ if (highlevel == NULL) {
+ PyErr_Clear();
+ highlevel = PyImport_ImportModule("test.support.interpreters");
+ if (highlevel == NULL) {
+ return NULL;
+ }
+ }
+ if (end == CHANNEL_SEND) {
+ cls = state->send_channel_type;
+ }
+ else {
+ cls = state->recv_channel_type;
+ }
+ assert(cls != NULL);
+ }
+ return cls;
+}
+
+static PyObject *
+_channel_end_from_xid(_PyCrossInterpreterData *data)
+{
+ channelid *cid = (channelid *)_channelid_from_xid(data);
+ if (cid == NULL) {
+ return NULL;
+ }
+ PyTypeObject *cls = _get_current_channel_end_type(cid->end);
+ if (cls == NULL) {
+ return NULL;
+ }
+ PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid);
+ Py_DECREF(cid);
+ return obj;
+}
+
+static int
+_channel_end_shared(PyThreadState *tstate, PyObject *obj,
+ _PyCrossInterpreterData *data)
+{
+ PyObject *cidobj = PyObject_GetAttrString(obj, "_id");
+ if (cidobj == NULL) {
+ return -1;
+ }
+ if (_channelid_shared(tstate, cidobj, data) < 0) {
+ return -1;
+ }
+ data->new_object = _channel_end_from_xid;
+ return 0;
+}
+
+static int
+set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
+{
+ module_state *state = get_module_state(mod);
+ if (state == NULL) {
+ return -1;
+ }
+
+ if (state->send_channel_type != NULL
+ || state->recv_channel_type != NULL)
+ {
+ PyErr_SetString(PyExc_TypeError, "already registered");
+ return -1;
+ }
+ state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
+ state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);
+
+ if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) {
+ return -1;
+ }
+ if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) {
+ return -1;
+ }
+
+ return 0;
+}
+
/* module level code ********************************************************/
/* globals is the process-global state for the module. It holds all
@@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
PyTypeObject *cls = state->ChannelIDType;
- PyObject *mod = get_module_from_owned_type(cls);
- if (mod == NULL) {
+ assert(get_module_from_owned_type(cls) == self);
+
+ return _channelid_new(self, cls, args, kwds);
+}
+
+static PyObject *
+channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds)
+{
+ static char *kwlist[] = {"send", "recv", NULL};
+ PyObject *send;
+ PyObject *recv;
+ if (!PyArg_ParseTupleAndKeywords(args, kwds,
+ "OO:_register_end_types", kwlist,
+ &send, &recv)) {
return NULL;
}
- PyObject *cid = _channelid_new(mod, cls, args, kwds);
- Py_DECREF(mod);
- return cid;
+ if (!PyType_Check(send)) {
+ PyErr_SetString(PyExc_TypeError, "expected a type for 'send'");
+ return NULL;
+ }
+ if (!PyType_Check(recv)) {
+ PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'");
+ return NULL;
+ }
+ PyTypeObject *cls_send = (PyTypeObject *)send;
+ PyTypeObject *cls_recv = (PyTypeObject *)recv;
+
+ if (set_channel_end_types(self, cls_send, cls_recv) < 0) {
+ return NULL;
+ }
+
+ Py_RETURN_NONE;
}
static PyMethodDef module_functions[] = {
@@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = {
METH_VARARGS | METH_KEYWORDS, channel_release_doc},
{"_channel_id", _PyCFunction_CAST(channel__channel_id),
METH_VARARGS | METH_KEYWORDS, NULL},
+ {"_register_end_types", _PyCFunction_CAST(channel__register_end_types),
+ METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL} /* sentinel */
};