diff options
author | Eric Snow <ericsnowcurrently@gmail.com> | 2023-10-02 20:47:41 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-02 20:47:41 (GMT) |
commit | a8f5dab58daca9f01ec3c6f8c85e53329251b05d (patch) | |
tree | 7a9baf60b05d29b529799c40d3cdfa38e89d25bb | |
parent | 014aacda6239f0e33b3ad5ece343df66701804b2 (diff) | |
download | cpython-a8f5dab58daca9f01ec3c6f8c85e53329251b05d.zip cpython-a8f5dab58daca9f01ec3c6f8c85e53329251b05d.tar.gz cpython-a8f5dab58daca9f01ec3c6f8c85e53329251b05d.tar.bz2 |
gh-76785: Module-level Fixes for test.support.interpreters (gh-110236)
* add RecvChannel.close() and SendChannel.close()
* make RecvChannel and SendChannel shareable
* expose ChannelEmptyError and ChannelNotEmptyError
-rw-r--r-- | Lib/test/support/interpreters.py | 30 | ||||
-rw-r--r-- | Lib/test/test_interpreters.py | 16 | ||||
-rw-r--r-- | Modules/_xxinterpchannelsmodule.c | 185 |
3 files changed, 206 insertions, 25 deletions
diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py index eeff3ab..d2beba3 100644 --- a/Lib/test/support/interpreters.py +++ b/Lib/test/support/interpreters.py @@ -7,7 +7,8 @@ import _xxinterpchannels as _channels # aliases: from _xxsubinterpreters import is_shareable from _xxinterpchannels import ( - ChannelError, ChannelNotFoundError, ChannelEmptyError, + ChannelError, ChannelNotFoundError, ChannelClosedError, + ChannelEmptyError, ChannelNotEmptyError, ) @@ -117,10 +118,16 @@ def list_all_channels(): class _ChannelEnd: """The base class for RecvChannel and SendChannel.""" - def __init__(self, id): - if not isinstance(id, (int, _channels.ChannelID)): - raise TypeError(f'id must be an int, got {id!r}') - self._id = id + _end = None + + def __init__(self, cid): + if self._end == 'send': + cid = _channels._channel_id(cid, send=True, force=True) + elif self._end == 'recv': + cid = _channels._channel_id(cid, recv=True, force=True) + else: + raise NotImplementedError(self._end) + self._id = cid def __repr__(self): return f'{type(self).__name__}(id={int(self._id)})' @@ -147,6 +154,8 @@ _NOT_SET = object() class RecvChannel(_ChannelEnd): """The receiving end of a cross-interpreter channel.""" + _end = 'recv' + def recv(self, *, _sentinel=object(), _delay=10 / 1000): # 10 milliseconds """Return the next object from the channel. @@ -171,10 +180,15 @@ class RecvChannel(_ChannelEnd): else: return _channels.recv(self._id, default) + def close(self): + _channels.close(self._id, recv=True) + class SendChannel(_ChannelEnd): """The sending end of a cross-interpreter channel.""" + _end = 'send' + def send(self, obj): """Send the object (i.e. its data) to the channel's receiving end. @@ -196,3 +210,9 @@ class SendChannel(_ChannelEnd): # None. This should be fixed when channel_send_wait() is added. # See bpo-32604 and gh-19829. return _channels.send(self._id, obj) + + def close(self): + _channels.close(self._id, send=True) + + +_channels._register_end_types(SendChannel, RecvChannel) diff --git a/Lib/test/test_interpreters.py b/Lib/test/test_interpreters.py index e62859a..ffdd8a1 100644 --- a/Lib/test/test_interpreters.py +++ b/Lib/test/test_interpreters.py @@ -822,6 +822,22 @@ class TestChannels(TestBase): after = set(interpreters.list_all_channels()) self.assertEqual(after, created) + def test_shareable(self): + rch, sch = interpreters.create_channel() + + self.assertTrue( + interpreters.is_shareable(rch)) + self.assertTrue( + interpreters.is_shareable(sch)) + + sch.send_nowait(rch) + sch.send_nowait(sch) + rch2 = rch.recv() + sch2 = rch.recv() + + self.assertEqual(rch2, rch) + self.assertEqual(sch2, sch) + class TestRecvChannelAttrs(TestBase): diff --git a/Modules/_xxinterpchannelsmodule.c b/Modules/_xxinterpchannelsmodule.c index 6096f88..d5be76f 100644 --- a/Modules/_xxinterpchannelsmodule.c +++ b/Modules/_xxinterpchannelsmodule.c @@ -198,6 +198,9 @@ _release_xid_data(_PyCrossInterpreterData *data, int flags) /* module state *************************************************************/ typedef struct { + PyTypeObject *send_channel_type; + PyTypeObject *recv_channel_type; + /* heap types */ PyTypeObject *ChannelIDType; @@ -218,6 +221,21 @@ get_module_state(PyObject *mod) return state; } +static module_state * +_get_current_module_state(void) +{ + PyObject *mod = _get_current_module(); + if (mod == NULL) { + // XXX import it? + PyErr_SetString(PyExc_RuntimeError, + MODULE_NAME " module not imported yet"); + return NULL; + } + module_state *state = get_module_state(mod); + Py_DECREF(mod); + return state; +} + static int traverse_module_state(module_state *state, visitproc visit, void *arg) { @@ -237,6 +255,9 @@ traverse_module_state(module_state *state, visitproc visit, void *arg) static int clear_module_state(module_state *state) { + Py_CLEAR(state->send_channel_type); + Py_CLEAR(state->recv_channel_type); + /* heap types */ if (state->ChannelIDType != NULL) { (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType); @@ -1529,17 +1550,20 @@ typedef struct channelid { struct channel_id_converter_data { PyObject *module; int64_t cid; + int end; }; static int channel_id_converter(PyObject *arg, void *ptr) { int64_t cid; + int end = 0; struct channel_id_converter_data *data = ptr; module_state *state = get_module_state(data->module); assert(state != NULL); if (PyObject_TypeCheck(arg, state->ChannelIDType)) { cid = ((channelid *)arg)->id; + end = ((channelid *)arg)->end; } else if (PyIndex_Check(arg)) { cid = PyLong_AsLongLong(arg); @@ -1559,6 +1583,7 @@ channel_id_converter(PyObject *arg, void *ptr) return 0; } data->cid = cid; + data->end = end; return 1; } @@ -1600,6 +1625,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls, { static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL}; int64_t cid; + int end; struct channel_id_converter_data cid_data = { .module = mod, }; @@ -1614,6 +1640,7 @@ _channelid_new(PyObject *mod, PyTypeObject *cls, return NULL; } cid = cid_data.cid; + end = cid_data.end; // Handle "send" and "recv". if (send == 0 && recv == 0) { @@ -1621,14 +1648,17 @@ _channelid_new(PyObject *mod, PyTypeObject *cls, "'send' and 'recv' cannot both be False"); return NULL; } - - int end = 0; - if (send == 1) { + else if (send == 1) { if (recv == 0 || recv == -1) { end = CHANNEL_SEND; } + else { + assert(recv == 1); + end = 0; + } } else if (recv == 1) { + assert(send == 0 || send == -1); end = CHANNEL_RECV; } @@ -1773,21 +1803,12 @@ done: return res; } +static PyTypeObject * _get_current_channel_end_type(int end); + static PyObject * _channel_from_cid(PyObject *cid, int end) { - PyObject *highlevel = PyImport_ImportModule("interpreters"); - if (highlevel == NULL) { - PyErr_Clear(); - highlevel = PyImport_ImportModule("test.support.interpreters"); - if (highlevel == NULL) { - return NULL; - } - } - const char *clsname = (end == CHANNEL_RECV) ? "RecvChannel" : - "SendChannel"; - PyObject *cls = PyObject_GetAttrString(highlevel, clsname); - Py_DECREF(highlevel); + PyObject *cls = (PyObject *)_get_current_channel_end_type(end); if (cls == NULL) { return NULL; } @@ -1943,6 +1964,103 @@ static PyType_Spec ChannelIDType_spec = { }; +/* SendChannel and RecvChannel classes */ + +// XXX Use a new __xid__ protocol instead? + +static PyTypeObject * +_get_current_channel_end_type(int end) +{ + module_state *state = _get_current_module_state(); + if (state == NULL) { + return NULL; + } + PyTypeObject *cls; + if (end == CHANNEL_SEND) { + cls = state->send_channel_type; + } + else { + assert(end == CHANNEL_RECV); + cls = state->recv_channel_type; + } + if (cls == NULL) { + PyObject *highlevel = PyImport_ImportModule("interpreters"); + if (highlevel == NULL) { + PyErr_Clear(); + highlevel = PyImport_ImportModule("test.support.interpreters"); + if (highlevel == NULL) { + return NULL; + } + } + if (end == CHANNEL_SEND) { + cls = state->send_channel_type; + } + else { + cls = state->recv_channel_type; + } + assert(cls != NULL); + } + return cls; +} + +static PyObject * +_channel_end_from_xid(_PyCrossInterpreterData *data) +{ + channelid *cid = (channelid *)_channelid_from_xid(data); + if (cid == NULL) { + return NULL; + } + PyTypeObject *cls = _get_current_channel_end_type(cid->end); + if (cls == NULL) { + return NULL; + } + PyObject *obj = PyObject_CallOneArg((PyObject *)cls, (PyObject *)cid); + Py_DECREF(cid); + return obj; +} + +static int +_channel_end_shared(PyThreadState *tstate, PyObject *obj, + _PyCrossInterpreterData *data) +{ + PyObject *cidobj = PyObject_GetAttrString(obj, "_id"); + if (cidobj == NULL) { + return -1; + } + if (_channelid_shared(tstate, cidobj, data) < 0) { + return -1; + } + data->new_object = _channel_end_from_xid; + return 0; +} + +static int +set_channel_end_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv) +{ + module_state *state = get_module_state(mod); + if (state == NULL) { + return -1; + } + + if (state->send_channel_type != NULL + || state->recv_channel_type != NULL) + { + PyErr_SetString(PyExc_TypeError, "already registered"); + return -1; + } + state->send_channel_type = (PyTypeObject *)Py_NewRef(send); + state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv); + + if (_PyCrossInterpreterData_RegisterClass(send, _channel_end_shared)) { + return -1; + } + if (_PyCrossInterpreterData_RegisterClass(recv, _channel_end_shared)) { + return -1; + } + + return 0; +} + /* module level code ********************************************************/ /* globals is the process-global state for the module. It holds all @@ -2346,13 +2464,38 @@ channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds) return NULL; } PyTypeObject *cls = state->ChannelIDType; - PyObject *mod = get_module_from_owned_type(cls); - if (mod == NULL) { + assert(get_module_from_owned_type(cls) == self); + + return _channelid_new(self, cls, args, kwds); +} + +static PyObject * +channel__register_end_types(PyObject *self, PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"send", "recv", NULL}; + PyObject *send; + PyObject *recv; + if (!PyArg_ParseTupleAndKeywords(args, kwds, + "OO:_register_end_types", kwlist, + &send, &recv)) { return NULL; } - PyObject *cid = _channelid_new(mod, cls, args, kwds); - Py_DECREF(mod); - return cid; + if (!PyType_Check(send)) { + PyErr_SetString(PyExc_TypeError, "expected a type for 'send'"); + return NULL; + } + if (!PyType_Check(recv)) { + PyErr_SetString(PyExc_TypeError, "expected a type for 'recv'"); + return NULL; + } + PyTypeObject *cls_send = (PyTypeObject *)send; + PyTypeObject *cls_recv = (PyTypeObject *)recv; + + if (set_channel_end_types(self, cls_send, cls_recv) < 0) { + return NULL; + } + + Py_RETURN_NONE; } static PyMethodDef module_functions[] = { @@ -2374,6 +2517,8 @@ static PyMethodDef module_functions[] = { METH_VARARGS | METH_KEYWORDS, channel_release_doc}, {"_channel_id", _PyCFunction_CAST(channel__channel_id), METH_VARARGS | METH_KEYWORDS, NULL}, + {"_register_end_types", _PyCFunction_CAST(channel__register_end_types), + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL} /* sentinel */ }; |