diff options
author | Eric Snow <ericsnowcurrently@gmail.com> | 2022-12-02 18:36:57 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-02 18:36:57 (GMT) |
commit | 0547a981ae413248b21a6bb0cb62dda7d236fe45 (patch) | |
tree | 284f9abff17da763537e6087b30452a6f4aeadc9 | |
parent | ab02262cd0385a2fb5eb8a6ee3cedd4b4bb969f3 (diff) | |
download | cpython-0547a981ae413248b21a6bb0cb62dda7d236fe45.zip cpython-0547a981ae413248b21a6bb0cb62dda7d236fe45.tar.gz cpython-0547a981ae413248b21a6bb0cb62dda7d236fe45.tar.bz2 |
gh-99741: Clean Up the _xxsubinterpreters Module (gh-99940)
This cleanup up resolves a few subtle bugs and makes the implementation for multi-phase init much cleaner.
https://github.com/python/cpython/issues/99741
-rw-r--r-- | Include/cpython/pystate.h | 2 | ||||
-rw-r--r-- | Lib/test/test__xxsubinterpreters.py | 26 | ||||
-rw-r--r-- | Modules/_xxsubinterpretersmodule.c | 914 | ||||
-rw-r--r-- | Python/pystate.c | 20 |
4 files changed, 652 insertions, 310 deletions
diff --git a/Include/cpython/pystate.h b/Include/cpython/pystate.h index 7468a1c..0f56b1f 100644 --- a/Include/cpython/pystate.h +++ b/Include/cpython/pystate.h @@ -394,7 +394,7 @@ struct _xid { PyAPI_FUNC(int) _PyObject_GetCrossInterpreterData(PyObject *, _PyCrossInterpreterData *); PyAPI_FUNC(PyObject *) _PyCrossInterpreterData_NewObject(_PyCrossInterpreterData *); -PyAPI_FUNC(void) _PyCrossInterpreterData_Release(_PyCrossInterpreterData *); +PyAPI_FUNC(int) _PyCrossInterpreterData_Release(_PyCrossInterpreterData *); PyAPI_FUNC(int) _PyObject_CheckCrossInterpreterData(PyObject *); diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py index 66f29b9..f274b63 100644 --- a/Lib/test/test__xxsubinterpreters.py +++ b/Lib/test/test__xxsubinterpreters.py @@ -386,7 +386,6 @@ class ShareableTypeTests(unittest.TestCase): self._assert_values([ b'spam', 9999, - self.cid, ]) def test_bytes(self): @@ -1213,6 +1212,18 @@ class ChannelIDTests(TestBase): self.assertFalse(cid1 != cid2) self.assertTrue(cid1 != cid3) + def test_shareable(self): + chan = interpreters.channel_create() + + obj = interpreters.channel_create() + interpreters.channel_send(chan, obj) + got = interpreters.channel_recv(chan) + + self.assertEqual(got, obj) + self.assertIs(type(got), type(obj)) + # XXX Check the following in the channel tests? + #self.assertIsNot(got, obj) + class ChannelTests(TestBase): @@ -1545,6 +1556,19 @@ class ChannelTests(TestBase): self.assertEqual(obj5, b'eggs') self.assertIs(obj6, default) + def test_recv_sending_interp_destroyed(self): + cid = interpreters.channel_create() + interp = interpreters.create() + interpreters.run_string(interp, dedent(f""" + import _xxsubinterpreters as _interpreters + _interpreters.channel_send({cid}, b'spam') + """)) + interpreters.destroy(interp) + + with self.assertRaisesRegex(RuntimeError, + 'unrecognized interpreter ID'): + interpreters.channel_recv(cid) + def test_run_string_arg_unresolved(self): cid = interpreters.channel_create() interp = interpreters.create() diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 2c9e0cd..3e064ca 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -6,11 +6,15 @@ #endif #include "Python.h" +// XXX This module should not rely on internal API. #include "pycore_frame.h" #include "pycore_pystate.h" // _PyThreadState_GET() #include "pycore_interpreteridobject.h" +#define MODULE_NAME "_xxsubinterpreters" + + static char * _copy_raw_string(PyObject *strobj) { @@ -28,13 +32,126 @@ _copy_raw_string(PyObject *strobj) } static PyInterpreterState * -_get_current(void) +_get_current_interp(void) { // PyInterpreterState_Get() aborts if lookup fails, so don't need // to check the result for NULL. return PyInterpreterState_Get(); } +static PyObject * +_get_current_module(void) +{ + // We ensured it was imported in _run_script(). + PyObject *name = PyUnicode_FromString(MODULE_NAME); + if (name == NULL) { + return NULL; + } + PyObject *mod = PyImport_GetModule(name); + Py_DECREF(name); + if (mod == NULL) { + return NULL; + } + assert(mod != Py_None); + return mod; +} + +static PyObject * +get_module_from_owned_type(PyTypeObject *cls) +{ + assert(cls != NULL); + return _get_current_module(); + // XXX Use the more efficient API now that we use heap types: + //return PyType_GetModule(cls); +} + +static struct PyModuleDef moduledef; + +static PyObject * +get_module_from_type(PyTypeObject *cls) +{ + assert(cls != NULL); + return _get_current_module(); + // XXX Use the more efficient API now that we use heap types: + //return PyType_GetModuleByDef(cls, &moduledef); +} + +static PyObject * +add_new_exception(PyObject *mod, const char *name, PyObject *base) +{ + assert(!PyObject_HasAttrString(mod, name)); + PyObject *exctype = PyErr_NewException(name, base, NULL); + if (exctype == NULL) { + return NULL; + } + int res = PyModule_AddType(mod, (PyTypeObject *)exctype); + if (res < 0) { + Py_DECREF(exctype); + return NULL; + } + return exctype; +} + +#define ADD_NEW_EXCEPTION(MOD, NAME, BASE) \ + add_new_exception(MOD, MODULE_NAME "." Py_STRINGIFY(NAME), BASE) + +static PyTypeObject * +add_new_type(PyObject *mod, PyTypeObject *cls, crossinterpdatafunc shared) +{ + if (PyType_Ready(cls) != 0) { + return NULL; + } + if (PyModule_AddType(mod, cls) != 0) { + // XXX When this becomes a heap type, we need to decref here. + return NULL; + } + if (shared != NULL) { + if (_PyCrossInterpreterData_RegisterClass(cls, shared)) { + // XXX When this becomes a heap type, we need to decref here. + return NULL; + } + } + return cls; +} + +static int +_release_xid_data(_PyCrossInterpreterData *data, int ignoreexc) +{ + PyObject *exctype, *excval, *exctb; + if (ignoreexc) { + PyErr_Fetch(&exctype, &excval, &exctb); + } + int res = _PyCrossInterpreterData_Release(data); + if (res < 0) { + // XXX Fix this! + /* The owning interpreter is already destroyed. + * Ideally, this shouldn't ever happen. When an interpreter is + * about to be destroyed, we should clear out all of its objects + * from every channel associated with that interpreter. + * For now we hack around that to resolve refleaks, by decref'ing + * the released object here, even if its the wrong interpreter. + * The owning interpreter has already been destroyed + * so we should be okay, especially since the currently + * shareable types are all very basic, with no GC. + * That said, it becomes much messier once interpreters + * no longer share a GIL, so this needs to be fixed before then. */ + // We do what _release_xidata() does in pystate.c. + if (data->free != NULL) { + data->free(data->data); + data->data = NULL; + } + Py_CLEAR(data->obj); + if (ignoreexc) { + // XXX Emit a warning? + PyErr_Clear(); + } + } + if (ignoreexc) { + PyErr_Restore(exctype, excval, exctb); + } + return res; +} + /* data-sharing-specific code ***********************************************/ @@ -66,7 +183,7 @@ _sharednsitem_clear(struct _sharednsitem *item) PyMem_Free(item->name); item->name = NULL; } - _PyCrossInterpreterData_Release(&item->data); + (void)_release_xid_data(&item->data, 1); } static int @@ -121,8 +238,10 @@ _sharedns_free(_sharedns *shared) } static _sharedns * -_get_shared_ns(PyObject *shareable) +_get_shared_ns(PyObject *shareable, PyTypeObject *channelidtype, + int *needs_import) { + *needs_import = 0; if (shareable == NULL || shareable == Py_None) { return NULL; } @@ -144,6 +263,9 @@ _get_shared_ns(PyObject *shareable) if (_sharednsitem_init(&shared->items[i], key, value) != 0) { break; } + if (Py_TYPE(value) == channelidtype) { + *needs_import = 1; + } } if (PyErr_Occurred()) { _sharedns_free(shared); @@ -287,6 +409,17 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass) #define CHANNEL_BOTH 0 #define CHANNEL_RECV -1 +/* channel errors */ + +#define ERR_CHANNEL_NOT_FOUND -2 +#define ERR_CHANNEL_CLOSED -3 +#define ERR_CHANNEL_INTERP_CLOSED -4 +#define ERR_CHANNEL_EMPTY -5 +#define ERR_CHANNEL_NOT_EMPTY -6 +#define ERR_CHANNEL_MUTEX_INIT -7 +#define ERR_CHANNELS_MUTEX_INIT -8 +#define ERR_NO_NEXT_CHANNEL_ID -9 + static PyObject *ChannelError; static PyObject *ChannelNotFoundError; static PyObject *ChannelClosedError; @@ -294,61 +427,81 @@ static PyObject *ChannelEmptyError; static PyObject *ChannelNotEmptyError; static int -channel_exceptions_init(PyObject *ns) +channel_exceptions_init(PyObject *mod) { // XXX Move the exceptions into per-module memory? +#define ADD(NAME, BASE) \ + do { \ + if (NAME == NULL) { \ + NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \ + if (NAME == NULL) { \ + return -1; \ + } \ + } \ + } while (0) + // A channel-related operation failed. - ChannelError = PyErr_NewException("_xxsubinterpreters.ChannelError", - PyExc_RuntimeError, NULL); - if (ChannelError == NULL) { - return -1; + ADD(ChannelError, PyExc_RuntimeError); + // An operation tried to use a channel that doesn't exist. + ADD(ChannelNotFoundError, ChannelError); + // An operation tried to use a closed channel. + ADD(ChannelClosedError, ChannelError); + // An operation tried to pop from an empty channel. + ADD(ChannelEmptyError, ChannelError); + // An operation tried to close a non-empty channel. + ADD(ChannelNotEmptyError, ChannelError); +#undef ADD + + return 0; +} + +static int +handle_channel_error(int err, PyObject *Py_UNUSED(mod), int64_t cid) +{ + if (err == 0) { + assert(!PyErr_Occurred()); + return 0; } - if (PyDict_SetItemString(ns, "ChannelError", ChannelError) != 0) { - return -1; + assert(err < 0); + if (err == ERR_CHANNEL_NOT_FOUND) { + PyErr_Format(ChannelNotFoundError, + "channel %" PRId64 " not found", cid); } - - // An operation tried to use a channel that doesn't exist. - ChannelNotFoundError = PyErr_NewException( - "_xxsubinterpreters.ChannelNotFoundError", ChannelError, NULL); - if (ChannelNotFoundError == NULL) { - return -1; + else if (err == ERR_CHANNEL_CLOSED) { + PyErr_Format(ChannelClosedError, + "channel %" PRId64 " is closed", cid); } - if (PyDict_SetItemString(ns, "ChannelNotFoundError", ChannelNotFoundError) != 0) { - return -1; + else if (err == ERR_CHANNEL_INTERP_CLOSED) { + PyErr_Format(ChannelClosedError, + "channel %" PRId64 " is already closed", cid); } - - // An operation tried to use a closed channel. - ChannelClosedError = PyErr_NewException( - "_xxsubinterpreters.ChannelClosedError", ChannelError, NULL); - if (ChannelClosedError == NULL) { - return -1; + else if (err == ERR_CHANNEL_EMPTY) { + PyErr_Format(ChannelEmptyError, + "channel %" PRId64 " is empty", cid); } - if (PyDict_SetItemString(ns, "ChannelClosedError", ChannelClosedError) != 0) { - return -1; + else if (err == ERR_CHANNEL_NOT_EMPTY) { + PyErr_Format(ChannelNotEmptyError, + "channel %" PRId64 " may not be closed " + "if not empty (try force=True)", + cid); } - - // An operation tried to pop from an empty channel. - ChannelEmptyError = PyErr_NewException( - "_xxsubinterpreters.ChannelEmptyError", ChannelError, NULL); - if (ChannelEmptyError == NULL) { - return -1; + else if (err == ERR_CHANNEL_MUTEX_INIT) { + PyErr_SetString(ChannelError, + "can't initialize mutex for new channel"); } - if (PyDict_SetItemString(ns, "ChannelEmptyError", ChannelEmptyError) != 0) { - return -1; + else if (err == ERR_CHANNELS_MUTEX_INIT) { + PyErr_SetString(ChannelError, + "can't initialize mutex for channel management"); } - - // An operation tried to close a non-empty channel. - ChannelNotEmptyError = PyErr_NewException( - "_xxsubinterpreters.ChannelNotEmptyError", ChannelError, NULL); - if (ChannelNotEmptyError == NULL) { - return -1; + else if (err == ERR_NO_NEXT_CHANNEL_ID) { + PyErr_SetString(ChannelError, + "failed to get a channel ID"); } - if (PyDict_SetItemString(ns, "ChannelNotEmptyError", ChannelNotEmptyError) != 0) { - return -1; + else { + assert(PyErr_Occurred()); } - - return 0; + return 1; } /* the channel queue */ @@ -377,7 +530,7 @@ static void _channelitem_clear(_channelitem *item) { if (item->data != NULL) { - _PyCrossInterpreterData_Release(item->data); + (void)_release_xid_data(item->data, 1); PyMem_Free(item->data); item->data = NULL; } @@ -621,8 +774,7 @@ _channelends_associate(_channelends *ends, int64_t interp, int send) interp, &prev); if (end != NULL) { if (!end->open) { - PyErr_SetString(ChannelClosedError, "channel already closed"); - return -1; + return ERR_CHANNEL_CLOSED; } // already associated return 0; @@ -721,19 +873,13 @@ typedef struct _channel { } _PyChannelState; static _PyChannelState * -_channel_new(void) +_channel_new(PyThread_type_lock mutex) { _PyChannelState *chan = PyMem_NEW(_PyChannelState, 1); if (chan == NULL) { return NULL; } - chan->mutex = PyThread_allocate_lock(); - if (chan->mutex == NULL) { - PyMem_Free(chan); - PyErr_SetString(ChannelError, - "can't initialize mutex for new channel"); - return NULL; - } + chan->mutex = mutex; chan->queue = _channelqueue_new(); if (chan->queue == NULL) { PyMem_Free(chan); @@ -771,10 +917,11 @@ _channel_add(_PyChannelState *chan, int64_t interp, PyThread_acquire_lock(chan->mutex, WAIT_LOCK); if (!chan->open) { - PyErr_SetString(ChannelClosedError, "channel closed"); + res = ERR_CHANNEL_CLOSED; goto done; } if (_channelends_associate(chan->ends, interp, 1) != 0) { + res = ERR_CHANNEL_INTERP_CLOSED; goto done; } @@ -788,31 +935,34 @@ done: return res; } -static _PyCrossInterpreterData * -_channel_next(_PyChannelState *chan, int64_t interp) +static int +_channel_next(_PyChannelState *chan, int64_t interp, + _PyCrossInterpreterData **res) { - _PyCrossInterpreterData *data = NULL; + int err = 0; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); if (!chan->open) { - PyErr_SetString(ChannelClosedError, "channel closed"); + err = ERR_CHANNEL_CLOSED; goto done; } if (_channelends_associate(chan->ends, interp, 0) != 0) { + err = ERR_CHANNEL_INTERP_CLOSED; goto done; } - data = _channelqueue_get(chan->queue); + _PyCrossInterpreterData *data = _channelqueue_get(chan->queue); if (data == NULL && !PyErr_Occurred() && chan->closing != NULL) { chan->open = 0; } + *res = data; done: PyThread_release_lock(chan->mutex); if (chan->queue->count == 0) { _channel_finish_closing(chan); } - return data; + return err; } static int @@ -822,7 +972,7 @@ _channel_close_interpreter(_PyChannelState *chan, int64_t interp, int end) int res = -1; if (!chan->open) { - PyErr_SetString(ChannelClosedError, "channel already closed"); + res = ERR_CHANNEL_CLOSED; goto done; } @@ -844,13 +994,12 @@ _channel_close_all(_PyChannelState *chan, int end, int force) PyThread_acquire_lock(chan->mutex, WAIT_LOCK); if (!chan->open) { - PyErr_SetString(ChannelClosedError, "channel already closed"); + res = ERR_CHANNEL_CLOSED; goto done; } if (!force && chan->queue->count > 0) { - PyErr_SetString(ChannelNotEmptyError, - "may not be closed if not empty (try force=True)"); + res = ERR_CHANNEL_NOT_EMPTY; goto done; } @@ -935,21 +1084,24 @@ typedef struct _channels { int64_t next_id; } _channels; -static int -_channels_init(_channels *channels) +static void +_channels_init(_channels *channels, PyThread_type_lock mutex) { - if (channels->mutex == NULL) { - channels->mutex = PyThread_allocate_lock(); - if (channels->mutex == NULL) { - PyErr_SetString(ChannelError, - "can't initialize mutex for channel management"); - return -1; - } - } + channels->mutex = mutex; channels->head = NULL; channels->numopen = 0; channels->next_id = 0; - return 0; +} + +static void +_channels_fini(_channels *channels) +{ + assert(channels->numopen == 0); + assert(channels->head == NULL); + if (channels->mutex != NULL) { + PyThread_free_lock(channels->mutex); + channels->mutex = NULL; + } } static int64_t @@ -958,17 +1110,17 @@ _channels_next_id(_channels *channels) // needs lock int64_t id = channels->next_id; if (id < 0) { /* overflow */ - PyErr_SetString(ChannelError, - "failed to get a channel ID"); return -1; } channels->next_id += 1; return id; } -static _PyChannelState * -_channels_lookup(_channels *channels, int64_t id, PyThread_type_lock *pmutex) +static int +_channels_lookup(_channels *channels, int64_t id, PyThread_type_lock *pmutex, + _PyChannelState **res) { + int err = -1; _PyChannelState *chan = NULL; PyThread_acquire_lock(channels->mutex, WAIT_LOCK); if (pmutex != NULL) { @@ -977,11 +1129,11 @@ _channels_lookup(_channels *channels, int64_t id, PyThread_type_lock *pmutex) _channelref *ref = _channelref_find(channels->head, id, NULL); if (ref == NULL) { - PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id); + err = ERR_CHANNEL_NOT_FOUND; goto done; } if (ref->chan == NULL || !ref->chan->open) { - PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", id); + err = ERR_CHANNEL_CLOSED; goto done; } @@ -991,11 +1143,14 @@ _channels_lookup(_channels *channels, int64_t id, PyThread_type_lock *pmutex) } chan = ref->chan; + err = 0; + done: if (pmutex == NULL || *pmutex == NULL) { PyThread_release_lock(channels->mutex); } - return chan; + *res = chan; + return err; } static int64_t @@ -1007,6 +1162,7 @@ _channels_add(_channels *channels, _PyChannelState *chan) // Create a new ref. int64_t id = _channels_next_id(channels); if (id < 0) { + cid = ERR_NO_NEXT_CHANNEL_ID; goto done; } _channelref *ref = _channelref_new(id, chan); @@ -1041,31 +1197,32 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan, _channelref *ref = _channelref_find(channels->head, cid, NULL); if (ref == NULL) { - PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", cid); + res = ERR_CHANNEL_NOT_FOUND; goto done; } if (ref->chan == NULL) { - PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid); + res = ERR_CHANNEL_CLOSED; goto done; } else if (!force && end == CHANNEL_SEND && ref->chan->closing != NULL) { - PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid); + res = ERR_CHANNEL_CLOSED; goto done; } else { - if (_channel_close_all(ref->chan, end, force) != 0) { - if (end == CHANNEL_SEND && - PyErr_ExceptionMatches(ChannelNotEmptyError)) { + int err = _channel_close_all(ref->chan, end, force); + if (err != 0) { + if (end == CHANNEL_SEND && err == ERR_CHANNEL_NOT_EMPTY) { if (ref->chan->closing != NULL) { - PyErr_Format(ChannelClosedError, - "channel %" PRId64 " closed", cid); + res = ERR_CHANNEL_CLOSED; goto done; } // Mark the channel as closing and return. The channel // will be cleaned up in _channel_next(). PyErr_Clear(); - if (_channel_set_closing(ref, channels->mutex) != 0) { + int err = _channel_set_closing(ref, channels->mutex); + if (err != 0) { + res = err; goto done; } if (pchan != NULL) { @@ -1073,6 +1230,9 @@ _channels_close(_channels *channels, int64_t cid, _PyChannelState **pchan, } res = 0; } + else { + res = err; + } goto done; } if (pchan != NULL) { @@ -1121,7 +1281,7 @@ _channels_remove(_channels *channels, int64_t id, _PyChannelState **pchan) _channelref *prev = NULL; _channelref *ref = _channelref_find(channels->head, id, &prev); if (ref == NULL) { - PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id); + res = ERR_CHANNEL_NOT_FOUND; goto done; } @@ -1141,7 +1301,7 @@ _channels_add_id_object(_channels *channels, int64_t id) _channelref *ref = _channelref_find(channels->head, id, NULL); if (ref == NULL) { - PyErr_Format(ChannelNotFoundError, "channel %" PRId64 " not found", id); + res = ERR_CHANNEL_NOT_FOUND; goto done; } ref->objcount += 1; @@ -1215,7 +1375,7 @@ _channel_set_closing(struct _channelref *ref, PyThread_type_lock mutex) { int res = -1; PyThread_acquire_lock(chan->mutex, WAIT_LOCK); if (chan->closing != NULL) { - PyErr_SetString(ChannelClosedError, "channel closed"); + res = ERR_CHANNEL_CLOSED; goto done; } chan->closing = PyMem_NEW(struct _channel_closing, 1); @@ -1258,14 +1418,18 @@ _channel_finish_closing(struct _channel *chan) { static int64_t _channel_create(_channels *channels) { - _PyChannelState *chan = _channel_new(); + PyThread_type_lock mutex = PyThread_allocate_lock(); + if (mutex == NULL) { + return ERR_CHANNEL_MUTEX_INIT; + } + _PyChannelState *chan = _channel_new(mutex); if (chan == NULL) { + PyThread_free_lock(mutex); return -1; } int64_t id = _channels_add(channels, chan); if (id < 0) { _channel_free(chan); - return -1; } return id; } @@ -1274,8 +1438,9 @@ static int _channel_destroy(_channels *channels, int64_t id) { _PyChannelState *chan = NULL; - if (_channels_remove(channels, id, &chan) != 0) { - return -1; + int err = _channels_remove(channels, id, &chan); + if (err != 0) { + return err; } if (chan != NULL) { _channel_free(chan); @@ -1286,23 +1451,24 @@ _channel_destroy(_channels *channels, int64_t id) static int _channel_send(_channels *channels, int64_t id, PyObject *obj) { - PyInterpreterState *interp = _get_current(); + PyInterpreterState *interp = _get_current_interp(); if (interp == NULL) { return -1; } // Look up the channel. PyThread_type_lock mutex = NULL; - _PyChannelState *chan = _channels_lookup(channels, id, &mutex); - if (chan == NULL) { - return -1; + _PyChannelState *chan = NULL; + int err = _channels_lookup(channels, id, &mutex, &chan); + if (err != 0) { + return err; } + assert(chan != NULL); // Past this point we are responsible for releasing the mutex. if (chan->closing != NULL) { - PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", id); PyThread_release_lock(mutex); - return -1; + return ERR_CHANNEL_CLOSED; } // Convert the object to cross-interpreter data. @@ -1321,61 +1487,87 @@ _channel_send(_channels *channels, int64_t id, PyObject *obj) int res = _channel_add(chan, PyInterpreterState_GetID(interp), data); PyThread_release_lock(mutex); if (res != 0) { - _PyCrossInterpreterData_Release(data); + // We may chain an exception here: + (void)_release_xid_data(data, 0); PyMem_Free(data); - return -1; + return res; } return 0; } -static PyObject * -_channel_recv(_channels *channels, int64_t id) +static int +_channel_recv(_channels *channels, int64_t id, PyObject **res) { - PyInterpreterState *interp = _get_current(); + int err; + *res = NULL; + + PyInterpreterState *interp = _get_current_interp(); if (interp == NULL) { - return NULL; + // XXX Is this always an error? + if (PyErr_Occurred()) { + return -1; + } + return 0; } // Look up the channel. PyThread_type_lock mutex = NULL; - _PyChannelState *chan = _channels_lookup(channels, id, &mutex); - if (chan == NULL) { - return NULL; + _PyChannelState *chan = NULL; + err = _channels_lookup(channels, id, &mutex, &chan); + if (err != 0) { + return err; } + assert(chan != NULL); // Past this point we are responsible for releasing the mutex. // Pop off the next item from the channel. - _PyCrossInterpreterData *data = _channel_next(chan, PyInterpreterState_GetID(interp)); + _PyCrossInterpreterData *data = NULL; + err = _channel_next(chan, PyInterpreterState_GetID(interp), &data); PyThread_release_lock(mutex); - if (data == NULL) { - return NULL; + if (err != 0) { + return err; + } + else if (data == NULL) { + assert(!PyErr_Occurred()); + return 0; } // Convert the data back to an object. PyObject *obj = _PyCrossInterpreterData_NewObject(data); - _PyCrossInterpreterData_Release(data); - PyMem_Free(data); if (obj == NULL) { - return NULL; + assert(PyErr_Occurred()); + (void)_release_xid_data(data, 1); + PyMem_Free(data); + return -1; + } + int release_res = _release_xid_data(data, 0); + PyMem_Free(data); + if (release_res < 0) { + // The source interpreter has been destroyed already. + assert(PyErr_Occurred()); + Py_DECREF(obj); + return -1; } - return obj; + *res = obj; + return 0; } static int _channel_drop(_channels *channels, int64_t id, int send, int recv) { - PyInterpreterState *interp = _get_current(); + PyInterpreterState *interp = _get_current_interp(); if (interp == NULL) { return -1; } // Look up the channel. PyThread_type_lock mutex = NULL; - _PyChannelState *chan = _channels_lookup(channels, id, &mutex); - if (chan == NULL) { - return -1; + _PyChannelState *chan = NULL; + int err = _channels_lookup(channels, id, &mutex, &chan); + if (err != 0) { + return err; } // Past this point we are responsible for releasing the mutex. @@ -1395,12 +1587,13 @@ static int _channel_is_associated(_channels *channels, int64_t cid, int64_t interp, int send) { - _PyChannelState *chan = _channels_lookup(channels, cid, NULL); - if (chan == NULL) { - return -1; - } else if (send && chan->closing != NULL) { - PyErr_Format(ChannelClosedError, "channel %" PRId64 " closed", cid); - return -1; + _PyChannelState *chan = NULL; + int err = _channels_lookup(channels, cid, NULL, &chan); + if (err != 0) { + return err; + } + else if (send && chan->closing != NULL) { + return ERR_CHANNEL_CLOSED; } _channelend *end = _channelend_find(send ? chan->ends->send : chan->ends->recv, @@ -1411,7 +1604,7 @@ _channel_is_associated(_channels *channels, int64_t cid, int64_t interp, /* ChannelID class */ -static PyTypeObject ChannelIDtype; +static PyTypeObject ChannelIDType; typedef struct channelid { PyObject_HEAD @@ -1421,11 +1614,17 @@ typedef struct channelid { _channels *channels; } channelid; +struct channel_id_converter_data { + PyObject *module; + int64_t cid; +}; + static int channel_id_converter(PyObject *arg, void *ptr) { int64_t cid; - if (PyObject_TypeCheck(arg, &ChannelIDtype)) { + struct channel_id_converter_data *data = ptr; + if (PyObject_TypeCheck(arg, &ChannelIDType)) { cid = ((channelid *)arg)->id; } else if (PyIndex_Check(arg)) { @@ -1445,51 +1644,62 @@ channel_id_converter(PyObject *arg, void *ptr) Py_TYPE(arg)->tp_name); return 0; } - *(int64_t *)ptr = cid; + data->cid = cid; return 1; } -static channelid * +static int newchannelid(PyTypeObject *cls, int64_t cid, int end, _channels *channels, - int force, int resolve) + int force, int resolve, channelid **res) { + *res = NULL; + channelid *self = PyObject_New(channelid, cls); if (self == NULL) { - return NULL; + return -1; } self->id = cid; self->end = end; self->resolve = resolve; self->channels = channels; - if (_channels_add_id_object(channels, cid) != 0) { - if (force && PyErr_ExceptionMatches(ChannelNotFoundError)) { - PyErr_Clear(); + int err = _channels_add_id_object(channels, cid); + if (err != 0) { + if (force && err == ERR_CHANNEL_NOT_FOUND) { + assert(!PyErr_Occurred()); } else { Py_DECREF((PyObject *)self); - return NULL; + return err; } } - return self; + *res = self; + return 0; } static _channels * _global_channels(void); static PyObject * -channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) +_channelid_new(PyObject *mod, PyTypeObject *cls, + PyObject *args, PyObject *kwds) { static char *kwlist[] = {"id", "send", "recv", "force", "_resolve", NULL}; int64_t cid; + struct channel_id_converter_data cid_data = { + .module = mod, + }; int send = -1; int recv = -1; int force = 0; int resolve = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|$pppp:ChannelID.__new__", kwlist, - channel_id_converter, &cid, &send, &recv, &force, &resolve)) + channel_id_converter, &cid_data, + &send, &recv, &force, &resolve)) { return NULL; + } + cid = cid_data.cid; // Handle "send" and "recv". if (send == 0 && recv == 0) { @@ -1508,8 +1718,16 @@ channelid_new(PyTypeObject *cls, PyObject *args, PyObject *kwds) end = CHANNEL_RECV; } - return (PyObject *)newchannelid(cls, cid, end, _global_channels(), - force, resolve); + PyObject *id = NULL; + int err = newchannelid(cls, cid, end, _global_channels(), + force, resolve, + (channelid **)&id); + if (handle_channel_error(err, mod, cid)) { + assert(id == NULL); + return NULL; + } + assert(id != NULL); + return id; } static void @@ -1557,43 +1775,8 @@ channelid_int(PyObject *self) } static PyNumberMethods channelid_as_number = { - 0, /* nb_add */ - 0, /* nb_subtract */ - 0, /* nb_multiply */ - 0, /* nb_remainder */ - 0, /* nb_divmod */ - 0, /* nb_power */ - 0, /* nb_negative */ - 0, /* nb_positive */ - 0, /* nb_absolute */ - 0, /* nb_bool */ - 0, /* nb_invert */ - 0, /* nb_lshift */ - 0, /* nb_rshift */ - 0, /* nb_and */ - 0, /* nb_xor */ - 0, /* nb_or */ - (unaryfunc)channelid_int, /* nb_int */ - 0, /* nb_reserved */ - 0, /* nb_float */ - - 0, /* nb_inplace_add */ - 0, /* nb_inplace_subtract */ - 0, /* nb_inplace_multiply */ - 0, /* nb_inplace_remainder */ - 0, /* nb_inplace_power */ - 0, /* nb_inplace_lshift */ - 0, /* nb_inplace_rshift */ - 0, /* nb_inplace_and */ - 0, /* nb_inplace_xor */ - 0, /* nb_inplace_or */ - - 0, /* nb_floor_divide */ - 0, /* nb_true_divide */ - 0, /* nb_inplace_floor_divide */ - 0, /* nb_inplace_true_divide */ - - (unaryfunc)channelid_int, /* nb_index */ + .nb_int = (unaryfunc)channelid_int, /* nb_int */ + .nb_index = (unaryfunc)channelid_int, /* nb_index */ }; static Py_hash_t @@ -1612,17 +1795,24 @@ channelid_hash(PyObject *self) static PyObject * channelid_richcompare(PyObject *self, PyObject *other, int op) { + PyObject *res = NULL; if (op != Py_EQ && op != Py_NE) { Py_RETURN_NOTIMPLEMENTED; } - if (!PyObject_TypeCheck(self, &ChannelIDtype)) { - Py_RETURN_NOTIMPLEMENTED; + PyObject *mod = get_module_from_type(Py_TYPE(self)); + if (mod == NULL) { + return NULL; + } + + if (!PyObject_TypeCheck(self, &ChannelIDType)) { + res = Py_NewRef(Py_NotImplemented); + goto done; } channelid *cid = (channelid *)self; int equal; - if (PyObject_TypeCheck(other, &ChannelIDtype)) { + if (PyObject_TypeCheck(other, &ChannelIDType)) { channelid *othercid = (channelid *)other; equal = (cid->end == othercid->end) && (cid->id == othercid->id); } @@ -1631,27 +1821,34 @@ channelid_richcompare(PyObject *self, PyObject *other, int op) int overflow; long long othercid = PyLong_AsLongLongAndOverflow(other, &overflow); if (othercid == -1 && PyErr_Occurred()) { - return NULL; + goto done; } equal = !overflow && (othercid >= 0) && (cid->id == othercid); } else if (PyNumber_Check(other)) { PyObject *pyid = PyLong_FromLongLong(cid->id); if (pyid == NULL) { - return NULL; + goto done; } - PyObject *res = PyObject_RichCompare(pyid, other, op); + res = PyObject_RichCompare(pyid, other, op); Py_DECREF(pyid); - return res; + goto done; } else { - Py_RETURN_NOTIMPLEMENTED; + res = Py_NewRef(Py_NotImplemented); + goto done; } if ((op == Py_EQ && equal) || (op == Py_NE && !equal)) { - Py_RETURN_TRUE; + res = Py_NewRef(Py_True); } - Py_RETURN_FALSE; + else { + res = Py_NewRef(Py_False); + } + +done: + Py_DECREF(mod); + return res; } static PyObject * @@ -1690,24 +1887,42 @@ static PyObject * _channelid_from_xid(_PyCrossInterpreterData *data) { struct _channelid_xid *xid = (struct _channelid_xid *)data->data; + + PyObject *mod = _get_current_module(); + if (mod == NULL) { + return NULL; + } + // Note that we do not preserve the "resolve" flag. - PyObject *cid = (PyObject *)newchannelid(&ChannelIDtype, xid->id, xid->end, - _global_channels(), 0, 0); + PyObject *cid = NULL; + int err = newchannelid(&ChannelIDType, xid->id, xid->end, + _global_channels(), 0, 0, + (channelid **)&cid); + if (err != 0) { + assert(cid == NULL); + (void)handle_channel_error(err, mod, xid->id); + goto done; + } + assert(cid != NULL); if (xid->end == 0) { - return cid; + goto done; } if (!xid->resolve) { - return cid; + goto done; } /* Try returning a high-level channel end but fall back to the ID. */ PyObject *chan = _channel_from_cid(cid, xid->end); if (chan == NULL) { PyErr_Clear(); - return cid; + goto done; } Py_DECREF(cid); - return chan; + cid = chan; + +done: + Py_DECREF(mod); + return cid; } static int @@ -1734,8 +1949,22 @@ channelid_end(PyObject *self, void *end) int force = 1; channelid *cid = (channelid *)self; if (end != NULL) { - return (PyObject *)newchannelid(Py_TYPE(self), cid->id, *(int *)end, - cid->channels, force, cid->resolve); + PyObject *id = NULL; + int err = newchannelid(Py_TYPE(self), cid->id, *(int *)end, + cid->channels, force, cid->resolve, + (channelid **)&id); + if (err != 0) { + assert(id == NULL); + PyObject *mod = get_module_from_type(Py_TYPE(self)); + if (mod == NULL) { + return NULL; + } + (void)handle_channel_error(err, mod, cid->id); + Py_DECREF(mod); + return NULL; + } + assert(id != NULL); + return id; } if (cid->end == CHANNEL_SEND) { @@ -1763,7 +1992,7 @@ static PyGetSetDef channelid_getsets[] = { PyDoc_STRVAR(channelid_doc, "A channel ID identifies a channel and may be used as an int."); -static PyTypeObject ChannelIDtype = { +static PyTypeObject ChannelIDType = { PyVarObject_HEAD_INIT(&PyType_Type, 0) "_xxsubinterpreters.ChannelID", /* tp_name */ sizeof(channelid), /* tp_basicsize */ @@ -1807,21 +2036,23 @@ static PyTypeObject ChannelIDtype = { static PyObject * RunFailedError = NULL; static int -interp_exceptions_init(PyObject *ns) +interp_exceptions_init(PyObject *mod) { // XXX Move the exceptions into per-module memory? - if (RunFailedError == NULL) { - // An uncaught exception came out of interp_run_string(). - RunFailedError = PyErr_NewException("_xxsubinterpreters.RunFailedError", - PyExc_RuntimeError, NULL); - if (RunFailedError == NULL) { - return -1; - } - if (PyDict_SetItemString(ns, "RunFailedError", RunFailedError) != 0) { - return -1; - } - } +#define ADD(NAME, BASE) \ + do { \ + if (NAME == NULL) { \ + NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \ + if (NAME == NULL) { \ + return -1; \ + } \ + } \ + } while (0) + + // An uncaught exception came out of interp_run_string(). + ADD(RunFailedError, PyExc_RuntimeError); +#undef ADD return 0; } @@ -1860,12 +2091,24 @@ _ensure_not_running(PyInterpreterState *interp) static int _run_script(PyInterpreterState *interp, const char *codestr, - _sharedns *shared, _sharedexception **exc) + _sharedns *shared, int needs_import, + _sharedexception **exc) { PyObject *exctype = NULL; PyObject *excval = NULL; PyObject *tb = NULL; + if (needs_import) { + // It might not have been imported yet in the current interpreter. + // However, it will (almost) always have been imported already + // in the main interpreter. + PyObject *mod = PyImport_ImportModule(MODULE_NAME); + if (mod == NULL) { + goto error; + } + Py_DECREF(mod); + } + PyObject *main_mod = _PyInterpreterState_GetMainModule(interp); if (main_mod == NULL) { goto error; @@ -1918,14 +2161,16 @@ error: } static int -_run_script_in_interpreter(PyInterpreterState *interp, const char *codestr, - PyObject *shareables) +_run_script_in_interpreter(PyObject *mod, PyInterpreterState *interp, + const char *codestr, PyObject *shareables) { if (_ensure_not_running(interp) < 0) { return -1; } - _sharedns *shared = _get_shared_ns(shareables); + int needs_import = 0; + _sharedns *shared = _get_shared_ns(shareables, &ChannelIDType, + &needs_import); if (shared == NULL && PyErr_Occurred()) { return -1; } @@ -1941,7 +2186,7 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr, // Run the script. _sharedexception *exc = NULL; - int result = _run_script(interp, codestr, shared, &exc); + int result = _run_script(interp, codestr, shared, needs_import, &exc); // Switch back. if (save_tstate != NULL) { @@ -1972,18 +2217,41 @@ _run_script_in_interpreter(PyInterpreterState *interp, const char *codestr, the data that we need to share between interpreters, so it cannot hold PyObject values. */ static struct globals { + int module_count; _channels channels; -} _globals = {{0}}; +} _globals = {0}; static int -_init_globals(void) +_globals_init(void) { - if (_channels_init(&_globals.channels) != 0) { - return -1; + // XXX This isn't thread-safe. + _globals.module_count++; + if (_globals.module_count > 1) { + // Already initialized. + return 0; } + + assert(_globals.channels.mutex == NULL); + PyThread_type_lock mutex = PyThread_allocate_lock(); + if (mutex == NULL) { + return ERR_CHANNELS_MUTEX_INIT; + } + _channels_init(&_globals.channels, mutex); return 0; } +static void +_globals_fini(void) +{ + // XXX This isn't thread-safe. + _globals.module_count--; + if (_globals.module_count > 0) { + return; + } + + _channels_fini(&_globals.channels); +} + static _channels * _global_channels(void) { return &_globals.channels; @@ -2052,7 +2320,7 @@ interp_destroy(PyObject *self, PyObject *args, PyObject *kwds) } // Ensure we don't try to destroy the current interpreter. - PyInterpreterState *current = _get_current(); + PyInterpreterState *current = _get_current_interp(); if (current == NULL) { return NULL; } @@ -2129,7 +2397,7 @@ Return a list containing the ID of every existing interpreter."); static PyObject * interp_get_current(PyObject *self, PyObject *Py_UNUSED(ignored)) { - PyInterpreterState *interp =_get_current(); + PyInterpreterState *interp =_get_current_interp(); if (interp == NULL) { return NULL; } @@ -2187,7 +2455,7 @@ interp_run_string(PyObject *self, PyObject *args, PyObject *kwds) } // Run the code in the interpreter. - if (_run_script_in_interpreter(interp, codestr, shared) != 0) { + if (_run_script_in_interpreter(self, interp, codestr, shared) != 0) { return NULL; } Py_RETURN_NONE; @@ -2259,16 +2527,22 @@ channel_create(PyObject *self, PyObject *Py_UNUSED(ignored)) { int64_t cid = _channel_create(&_globals.channels); if (cid < 0) { + (void)handle_channel_error(cid, self, -1); return NULL; } - PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, cid, 0, - &_globals.channels, 0, 0); - if (id == NULL) { - if (_channel_destroy(&_globals.channels, cid) != 0) { + PyObject *id = NULL; + int err = newchannelid(&ChannelIDType, cid, 0, + &_globals.channels, 0, 0, + (channelid **)&id); + if (handle_channel_error(err, self, cid)) { + assert(id == NULL); + err = _channel_destroy(&_globals.channels, cid); + if (handle_channel_error(err, self, cid)) { // XXX issue a warning? } return NULL; } + assert(id != NULL); assert(((channelid *)id)->channels != NULL); return id; } @@ -2283,12 +2557,17 @@ channel_destroy(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", NULL}; int64_t cid; + struct channel_id_converter_data cid_data = { + .module = self, + }; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&:channel_destroy", kwlist, - channel_id_converter, &cid)) { + channel_id_converter, &cid_data)) { return NULL; } + cid = cid_data.cid; - if (_channel_destroy(&_globals.channels, cid) != 0) { + int err = _channel_destroy(&_globals.channels, cid); + if (handle_channel_error(err, self, cid)) { return NULL; } Py_RETURN_NONE; @@ -2317,12 +2596,16 @@ channel_list_all(PyObject *self, PyObject *Py_UNUSED(ignored)) } int64_t *cur = cids; for (int64_t i=0; i < count; cur++, i++) { - PyObject *id = (PyObject *)newchannelid(&ChannelIDtype, *cur, 0, - &_globals.channels, 0, 0); - if (id == NULL) { + PyObject *id = NULL; + int err = newchannelid(&ChannelIDType, *cur, 0, + &_globals.channels, 0, 0, + (channelid **)&id); + if (handle_channel_error(err, self, *cur)) { + assert(id == NULL); Py_SETREF(ids, NULL); break; } + assert(id != NULL); PyList_SET_ITEM(ids, (Py_ssize_t)i, id); } @@ -2341,6 +2624,9 @@ channel_list_interpreters(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", "send", NULL}; int64_t cid; /* Channel ID */ + struct channel_id_converter_data cid_data = { + .module = self, + }; int send = 0; /* Send or receive end? */ int64_t id; PyObject *ids, *id_obj; @@ -2348,9 +2634,10 @@ channel_list_interpreters(PyObject *self, PyObject *args, PyObject *kwds) if (!PyArg_ParseTupleAndKeywords( args, kwds, "O&$p:channel_list_interpreters", - kwlist, channel_id_converter, &cid, &send)) { + kwlist, channel_id_converter, &cid_data, &send)) { return NULL; } + cid = cid_data.cid; ids = PyList_New(0); if (ids == NULL) { @@ -2363,6 +2650,7 @@ channel_list_interpreters(PyObject *self, PyObject *args, PyObject *kwds) assert(id >= 0); int res = _channel_is_associated(&_globals.channels, cid, id, send); if (res < 0) { + (void)handle_channel_error(res, self, cid); goto except; } if (res) { @@ -2402,13 +2690,18 @@ channel_send(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", "obj", NULL}; int64_t cid; + struct channel_id_converter_data cid_data = { + .module = self, + }; PyObject *obj; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&O:channel_send", kwlist, - channel_id_converter, &cid, &obj)) { + channel_id_converter, &cid_data, &obj)) { return NULL; } + cid = cid_data.cid; - if (_channel_send(&_globals.channels, cid, obj) != 0) { + int err = _channel_send(&_globals.channels, cid, obj); + if (handle_channel_error(err, self, cid)) { return NULL; } Py_RETURN_NONE; @@ -2424,26 +2717,32 @@ channel_recv(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", "default", NULL}; int64_t cid; + struct channel_id_converter_data cid_data = { + .module = self, + }; PyObject *dflt = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|O:channel_recv", kwlist, - channel_id_converter, &cid, &dflt)) { + channel_id_converter, &cid_data, &dflt)) { return NULL; } - Py_XINCREF(dflt); + cid = cid_data.cid; - PyObject *obj = _channel_recv(&_globals.channels, cid); - if (obj != NULL) { - Py_XDECREF(dflt); - return obj; - } else if (PyErr_Occurred()) { - Py_XDECREF(dflt); - return NULL; - } else if (dflt != NULL) { - return dflt; - } else { - PyErr_Format(ChannelEmptyError, "channel %" PRId64 " is empty", cid); + PyObject *obj = NULL; + int err = _channel_recv(&_globals.channels, cid, &obj); + if (handle_channel_error(err, self, cid)) { return NULL; } + Py_XINCREF(dflt); + if (obj == NULL) { + // Use the default. + if (dflt == NULL) { + (void)handle_channel_error(ERR_CHANNEL_EMPTY, self, cid); + return NULL; + } + obj = Py_NewRef(dflt); + } + Py_XDECREF(dflt); + return obj; } PyDoc_STRVAR(channel_recv_doc, @@ -2459,16 +2758,22 @@ channel_close(PyObject *self, PyObject *args, PyObject *kwds) { static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; int64_t cid; + struct channel_id_converter_data cid_data = { + .module = self, + }; int send = 0; int recv = 0; int force = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|$ppp:channel_close", kwlist, - channel_id_converter, &cid, &send, &recv, &force)) { + channel_id_converter, &cid_data, + &send, &recv, &force)) { return NULL; } + cid = cid_data.cid; - if (_channel_close(&_globals.channels, cid, send-recv, force) != 0) { + int err = _channel_close(&_globals.channels, cid, send-recv, force); + if (handle_channel_error(err, self, cid)) { return NULL; } Py_RETURN_NONE; @@ -2507,14 +2812,19 @@ channel_release(PyObject *self, PyObject *args, PyObject *kwds) // Note that only the current interpreter is affected. static char *kwlist[] = {"cid", "send", "recv", "force", NULL}; int64_t cid; + struct channel_id_converter_data cid_data = { + .module = self, + }; int send = 0; int recv = 0; int force = 0; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|$ppp:channel_release", kwlist, - channel_id_converter, &cid, &send, &recv, &force)) { + channel_id_converter, &cid_data, + &send, &recv, &force)) { return NULL; } + cid = cid_data.cid; if (send == 0 && recv == 0) { send = 1; recv = 1; @@ -2523,7 +2833,8 @@ channel_release(PyObject *self, PyObject *args, PyObject *kwds) // XXX Handle force is True. // XXX Fix implicit release. - if (_channel_drop(&_globals.channels, cid, send, recv) != 0) { + int err = _channel_drop(&_globals.channels, cid, send, recv); + if (handle_channel_error(err, self, cid)) { return NULL; } Py_RETURN_NONE; @@ -2539,7 +2850,14 @@ ends are closed. Closing an already closed end is a noop."); static PyObject * channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds) { - return channelid_new(&ChannelIDtype, args, kwds); + PyTypeObject *cls = &ChannelIDType; + PyObject *mod = get_module_from_owned_type(cls); + if (mod == NULL) { + return NULL; + } + PyObject *cid = _channelid_new(mod, cls, args, kwds); + Py_DECREF(mod); + return cid; } static PyMethodDef module_functions[] = { @@ -2590,59 +2908,57 @@ PyDoc_STRVAR(module_doc, "This module provides primitive operations to manage Python interpreters.\n\ The 'interpreters' module provides a more convenient interface."); -static struct PyModuleDef interpretersmodule = { - PyModuleDef_HEAD_INIT, - "_xxsubinterpreters", /* m_name */ - module_doc, /* m_doc */ - -1, /* m_size */ - module_functions, /* m_methods */ - NULL, /* m_slots */ - NULL, /* m_traverse */ - NULL, /* m_clear */ - NULL /* m_free */ -}; - - -PyMODINIT_FUNC -PyInit__xxsubinterpreters(void) +static int +module_exec(PyObject *mod) { - if (_init_globals() != 0) { - return NULL; - } - - /* Initialize types */ - if (PyType_Ready(&ChannelIDtype) != 0) { - return NULL; - } - - /* Create the module */ - PyObject *module = PyModule_Create(&interpretersmodule); - if (module == NULL) { - return NULL; + if (_globals_init() != 0) { + return -1; } /* Add exception types */ - PyObject *ns = PyModule_GetDict(module); // borrowed - if (interp_exceptions_init(ns) != 0) { - return NULL; + if (interp_exceptions_init(mod) != 0) { + goto error; } - if (channel_exceptions_init(ns) != 0) { - return NULL; + if (channel_exceptions_init(mod) != 0) { + goto error; } /* Add other types */ - if (PyDict_SetItemString(ns, "ChannelID", - Py_NewRef(&ChannelIDtype)) != 0) { - return NULL; + if (add_new_type(mod, &ChannelIDType, _channelid_shared) == NULL) { + goto error; } - if (PyDict_SetItemString(ns, "InterpreterID", - Py_NewRef(&_PyInterpreterID_Type)) != 0) { - return NULL; + if (PyModule_AddType(mod, &_PyInterpreterID_Type) < 0) { + goto error; } - if (_PyCrossInterpreterData_RegisterClass(&ChannelIDtype, _channelid_shared)) { + return 0; + +error: + (void)_PyCrossInterpreterData_UnregisterClass(&ChannelIDType); + _globals_fini(); + return -1; +} + +static struct PyModuleDef moduledef = { + .m_base = PyModuleDef_HEAD_INIT, + .m_name = MODULE_NAME, + .m_doc = module_doc, + .m_size = -1, + .m_methods = module_functions, +}; + + +PyMODINIT_FUNC +PyInit__xxsubinterpreters(void) +{ + /* Create the module */ + PyObject *mod = PyModule_Create(&moduledef); + if (mod == NULL) { return NULL; } - - return module; + if (module_exec(mod) < 0) { + Py_DECREF(mod); + return NULL; + } + return mod; } diff --git a/Python/pystate.c b/Python/pystate.c index 2554cc2..0fdcdf1 100644 --- a/Python/pystate.c +++ b/Python/pystate.c @@ -1865,7 +1865,7 @@ _PyObject_GetCrossInterpreterData(PyObject *obj, _PyCrossInterpreterData *data) // Fill in the blanks and validate the result. data->interp = interp->id; if (_check_xidata(tstate, data) != 0) { - _PyCrossInterpreterData_Release(data); + (void)_PyCrossInterpreterData_Release(data); return -1; } @@ -1878,8 +1878,8 @@ _release_xidata(void *arg) _PyCrossInterpreterData *data = (_PyCrossInterpreterData *)arg; if (data->free != NULL) { data->free(data->data); - data->data = NULL; } + data->data = NULL; Py_CLEAR(data->obj); } @@ -1910,27 +1910,29 @@ _call_in_interpreter(struct _gilstate_runtime_state *gilstate, } } -void +int _PyCrossInterpreterData_Release(_PyCrossInterpreterData *data) { - if (data->data == NULL && data->obj == NULL) { + if (data->free == NULL && data->obj == NULL) { // Nothing to release! - return; + data->data = NULL; + return 0; } // Switch to the original interpreter. PyInterpreterState *interp = _PyInterpreterState_LookUpID(data->interp); if (interp == NULL) { // The interpreter was already destroyed. - if (data->free != NULL) { - // XXX Someone leaked some memory... - } - return; + // This function shouldn't have been called. + // XXX Someone leaked some memory... + assert(PyErr_Occurred()); + return -1; } // "Release" the data and/or the object. struct _gilstate_runtime_state *gilstate = &_PyRuntime.gilstate; _call_in_interpreter(gilstate, interp, _release_xidata, data); + return 0; } PyObject * |