diff options
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/_xxsubinterpretersmodule.c | 349 |
1 files changed, 226 insertions, 123 deletions
diff --git a/Modules/_xxsubinterpretersmodule.c b/Modules/_xxsubinterpretersmodule.c index 3e064ca..d7d7fca 100644 --- a/Modules/_xxsubinterpretersmodule.c +++ b/Modules/_xxsubinterpretersmodule.c @@ -96,18 +96,20 @@ add_new_exception(PyObject *mod, const char *name, PyObject *base) add_new_exception(MOD, MODULE_NAME "." Py_STRINGIFY(NAME), BASE) static PyTypeObject * -add_new_type(PyObject *mod, PyTypeObject *cls, crossinterpdatafunc shared) +add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared) { - if (PyType_Ready(cls) != 0) { + PyTypeObject *cls = (PyTypeObject *)PyType_FromMetaclass( + NULL, mod, spec, NULL); + if (cls == NULL) { return NULL; } - if (PyModule_AddType(mod, cls) != 0) { - // XXX When this becomes a heap type, we need to decref here. + if (PyModule_AddType(mod, cls) < 0) { + Py_DECREF(cls); return NULL; } if (shared != NULL) { if (_PyCrossInterpreterData_RegisterClass(cls, shared)) { - // XXX When this becomes a heap type, we need to decref here. + Py_DECREF(cls); return NULL; } } @@ -135,12 +137,7 @@ _release_xid_data(_PyCrossInterpreterData *data, int ignoreexc) * shareable types are all very basic, with no GC. * That said, it becomes much messier once interpreters * no longer share a GIL, so this needs to be fixed before then. */ - // We do what _release_xidata() does in pystate.c. - if (data->free != NULL) { - data->free(data->data); - data->data = NULL; - } - Py_CLEAR(data->obj); + _PyCrossInterpreterData_Clear(NULL, data); if (ignoreexc) { // XXX Emit a warning? PyErr_Clear(); @@ -153,6 +150,69 @@ _release_xid_data(_PyCrossInterpreterData *data, int ignoreexc) } +/* module state *************************************************************/ + +typedef struct { + PyTypeObject *ChannelIDType; + + /* interpreter exceptions */ + PyObject *RunFailedError; + + /* channel exceptions */ + PyObject *ChannelError; + PyObject *ChannelNotFoundError; + PyObject *ChannelClosedError; + PyObject *ChannelEmptyError; + PyObject *ChannelNotEmptyError; +} module_state; + +static inline module_state * +get_module_state(PyObject *mod) +{ + assert(mod != NULL); + module_state *state = PyModule_GetState(mod); + assert(state != NULL); + return state; +} + +static int +traverse_module_state(module_state *state, visitproc visit, void *arg) +{ + /* heap types */ + Py_VISIT(state->ChannelIDType); + + /* interpreter exceptions */ + Py_VISIT(state->RunFailedError); + + /* channel exceptions */ + Py_VISIT(state->ChannelError); + Py_VISIT(state->ChannelNotFoundError); + Py_VISIT(state->ChannelClosedError); + Py_VISIT(state->ChannelEmptyError); + Py_VISIT(state->ChannelNotEmptyError); + return 0; +} + +static int +clear_module_state(module_state *state) +{ + /* heap types */ + (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType); + Py_CLEAR(state->ChannelIDType); + + /* interpreter exceptions */ + Py_CLEAR(state->RunFailedError); + + /* channel exceptions */ + Py_CLEAR(state->ChannelError); + Py_CLEAR(state->ChannelNotFoundError); + Py_CLEAR(state->ChannelClosedError); + Py_CLEAR(state->ChannelEmptyError); + Py_CLEAR(state->ChannelNotEmptyError); + return 0; +} + + /* data-sharing-specific code ***********************************************/ struct _sharednsitem { @@ -420,82 +480,80 @@ _sharedexception_apply(_sharedexception *exc, PyObject *wrapperclass) #define ERR_CHANNELS_MUTEX_INIT -8 #define ERR_NO_NEXT_CHANNEL_ID -9 -static PyObject *ChannelError; -static PyObject *ChannelNotFoundError; -static PyObject *ChannelClosedError; -static PyObject *ChannelEmptyError; -static PyObject *ChannelNotEmptyError; - static int channel_exceptions_init(PyObject *mod) { - // XXX Move the exceptions into per-module memory? + module_state *state = get_module_state(mod); + if (state == NULL) { + return -1; + } #define ADD(NAME, BASE) \ do { \ - if (NAME == NULL) { \ - NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \ - if (NAME == NULL) { \ - return -1; \ - } \ + assert(state->NAME == NULL); \ + state->NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \ + if (state->NAME == NULL) { \ + return -1; \ } \ } while (0) // A channel-related operation failed. ADD(ChannelError, PyExc_RuntimeError); // An operation tried to use a channel that doesn't exist. - ADD(ChannelNotFoundError, ChannelError); + ADD(ChannelNotFoundError, state->ChannelError); // An operation tried to use a closed channel. - ADD(ChannelClosedError, ChannelError); + ADD(ChannelClosedError, state->ChannelError); // An operation tried to pop from an empty channel. - ADD(ChannelEmptyError, ChannelError); + ADD(ChannelEmptyError, state->ChannelError); // An operation tried to close a non-empty channel. - ADD(ChannelNotEmptyError, ChannelError); + ADD(ChannelNotEmptyError, state->ChannelError); #undef ADD return 0; } static int -handle_channel_error(int err, PyObject *Py_UNUSED(mod), int64_t cid) +handle_channel_error(int err, PyObject *mod, int64_t cid) { if (err == 0) { assert(!PyErr_Occurred()); return 0; } assert(err < 0); + module_state *state = get_module_state(mod); + assert(state != NULL); if (err == ERR_CHANNEL_NOT_FOUND) { - PyErr_Format(ChannelNotFoundError, + PyErr_Format(state->ChannelNotFoundError, "channel %" PRId64 " not found", cid); } else if (err == ERR_CHANNEL_CLOSED) { - PyErr_Format(ChannelClosedError, + PyErr_Format(state->ChannelClosedError, "channel %" PRId64 " is closed", cid); } else if (err == ERR_CHANNEL_INTERP_CLOSED) { - PyErr_Format(ChannelClosedError, + PyErr_Format(state->ChannelClosedError, "channel %" PRId64 " is already closed", cid); } else if (err == ERR_CHANNEL_EMPTY) { - PyErr_Format(ChannelEmptyError, + PyErr_Format(state->ChannelEmptyError, "channel %" PRId64 " is empty", cid); } else if (err == ERR_CHANNEL_NOT_EMPTY) { - PyErr_Format(ChannelNotEmptyError, + PyErr_Format(state->ChannelNotEmptyError, "channel %" PRId64 " may not be closed " "if not empty (try force=True)", cid); } else if (err == ERR_CHANNEL_MUTEX_INIT) { - PyErr_SetString(ChannelError, + PyErr_SetString(state->ChannelError, "can't initialize mutex for new channel"); } else if (err == ERR_CHANNELS_MUTEX_INIT) { - PyErr_SetString(ChannelError, + PyErr_SetString(state->ChannelError, "can't initialize mutex for channel management"); } else if (err == ERR_NO_NEXT_CHANNEL_ID) { - PyErr_SetString(ChannelError, + PyErr_SetString(state->ChannelError, "failed to get a channel ID"); } else { @@ -1604,8 +1662,6 @@ _channel_is_associated(_channels *channels, int64_t cid, int64_t interp, /* ChannelID class */ -static PyTypeObject ChannelIDType; - typedef struct channelid { PyObject_HEAD int64_t id; @@ -1624,7 +1680,9 @@ channel_id_converter(PyObject *arg, void *ptr) { int64_t cid; struct channel_id_converter_data *data = ptr; - if (PyObject_TypeCheck(arg, &ChannelIDType)) { + module_state *state = get_module_state(data->module); + assert(state != NULL); + if (PyObject_TypeCheck(arg, state->ChannelIDType)) { cid = ((channelid *)arg)->id; } else if (PyIndex_Check(arg)) { @@ -1731,11 +1789,20 @@ _channelid_new(PyObject *mod, PyTypeObject *cls, } static void -channelid_dealloc(PyObject *v) +channelid_dealloc(PyObject *self) { - int64_t cid = ((channelid *)v)->id; - _channels *channels = ((channelid *)v)->channels; - Py_TYPE(v)->tp_free(v); + int64_t cid = ((channelid *)self)->id; + _channels *channels = ((channelid *)self)->channels; + + PyTypeObject *tp = Py_TYPE(self); + tp->tp_free(self); + /* "Instances of heap-allocated types hold a reference to their type." + * See: https://docs.python.org/3.11/howto/isolating-extensions.html#garbage-collection-protocol + * See: https://docs.python.org/3.11/c-api/typeobj.html#c.PyTypeObject.tp_traverse + */ + // XXX Why don't we implement Py_TPFLAGS_HAVE_GC, e.g. Py_tp_traverse, + // like we do for _abc._abc_data? + Py_DECREF(tp); _channels_drop_id_object(channels, cid); } @@ -1774,11 +1841,6 @@ channelid_int(PyObject *self) return PyLong_FromLongLong(cid->id); } -static PyNumberMethods channelid_as_number = { - .nb_int = (unaryfunc)channelid_int, /* nb_int */ - .nb_index = (unaryfunc)channelid_int, /* nb_index */ -}; - static Py_hash_t channelid_hash(PyObject *self) { @@ -1804,15 +1866,19 @@ channelid_richcompare(PyObject *self, PyObject *other, int op) if (mod == NULL) { return NULL; } + module_state *state = get_module_state(mod); + if (state == NULL) { + goto done; + } - if (!PyObject_TypeCheck(self, &ChannelIDType)) { + if (!PyObject_TypeCheck(self, state->ChannelIDType)) { res = Py_NewRef(Py_NotImplemented); goto done; } channelid *cid = (channelid *)self; int equal; - if (PyObject_TypeCheck(other, &ChannelIDType)) { + if (PyObject_TypeCheck(other, state->ChannelIDType)) { channelid *othercid = (channelid *)other; equal = (cid->end == othercid->end) && (cid->id == othercid->id); } @@ -1892,10 +1958,14 @@ _channelid_from_xid(_PyCrossInterpreterData *data) if (mod == NULL) { return NULL; } + module_state *state = get_module_state(mod); + if (state == NULL) { + return NULL; + } // Note that we do not preserve the "resolve" flag. PyObject *cid = NULL; - int err = newchannelid(&ChannelIDType, xid->id, xid->end, + int err = newchannelid(state->ChannelIDType, xid->id, xid->end, _global_channels(), 0, 0, (channelid **)&cid); if (err != 0) { @@ -1926,20 +1996,20 @@ done: } static int -_channelid_shared(PyObject *obj, _PyCrossInterpreterData *data) -{ - struct _channelid_xid *xid = PyMem_NEW(struct _channelid_xid, 1); - if (xid == NULL) { +_channelid_shared(PyThreadState *tstate, PyObject *obj, + _PyCrossInterpreterData *data) +{ + if (_PyCrossInterpreterData_InitWithSize( + data, tstate->interp, sizeof(struct _channelid_xid), obj, + _channelid_from_xid + ) < 0) + { return -1; } + struct _channelid_xid *xid = (struct _channelid_xid *)data->data; xid->id = ((channelid *)obj)->id; xid->end = ((channelid *)obj)->end; xid->resolve = ((channelid *)obj)->resolve; - - data->data = xid; - data->obj = Py_NewRef(obj); - data->new_object = _channelid_from_xid; - data->free = PyMem_Free; return 0; } @@ -1992,61 +2062,45 @@ static PyGetSetDef channelid_getsets[] = { PyDoc_STRVAR(channelid_doc, "A channel ID identifies a channel and may be used as an int."); -static PyTypeObject ChannelIDType = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) - "_xxsubinterpreters.ChannelID", /* tp_name */ - sizeof(channelid), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)channelid_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)channelid_repr, /* tp_repr */ - &channelid_as_number, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - channelid_hash, /* tp_hash */ - 0, /* tp_call */ - (reprfunc)channelid_str, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - // Use Py_TPFLAGS_DISALLOW_INSTANTIATION so the type cannot be instantiated - // from Python code. We do this because there is a strong relationship - // between channel IDs and the channel lifecycle, so this limitation avoids - // related complications. Use the _channel_id() function instead. - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE - | Py_TPFLAGS_DISALLOW_INSTANTIATION, /* tp_flags */ - channelid_doc, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - channelid_richcompare, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - channelid_getsets, /* tp_getset */ +static PyType_Slot ChannelIDType_slots[] = { + {Py_tp_dealloc, (destructor)channelid_dealloc}, + {Py_tp_doc, (void *)channelid_doc}, + {Py_tp_repr, (reprfunc)channelid_repr}, + {Py_tp_str, (reprfunc)channelid_str}, + {Py_tp_hash, channelid_hash}, + {Py_tp_richcompare, channelid_richcompare}, + {Py_tp_getset, channelid_getsets}, + // number slots + {Py_nb_int, (unaryfunc)channelid_int}, + {Py_nb_index, (unaryfunc)channelid_int}, + {0, NULL}, }; +static PyType_Spec ChannelIDType_spec = { + .name = "_xxsubinterpreters.ChannelID", + .basicsize = sizeof(channelid), + .flags = (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | + Py_TPFLAGS_DISALLOW_INSTANTIATION | Py_TPFLAGS_IMMUTABLETYPE), + .slots = ChannelIDType_slots, +}; -/* interpreter-specific code ************************************************/ -static PyObject * RunFailedError = NULL; +/* interpreter-specific code ************************************************/ static int interp_exceptions_init(PyObject *mod) { - // XXX Move the exceptions into per-module memory? + module_state *state = get_module_state(mod); + if (state == NULL) { + return -1; + } #define ADD(NAME, BASE) \ do { \ - if (NAME == NULL) { \ - NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \ - if (NAME == NULL) { \ - return -1; \ - } \ + assert(state->NAME == NULL); \ + state->NAME = ADD_NEW_EXCEPTION(mod, NAME, BASE); \ + if (state->NAME == NULL) { \ + return -1; \ } \ } while (0) @@ -2167,9 +2221,10 @@ _run_script_in_interpreter(PyObject *mod, PyInterpreterState *interp, if (_ensure_not_running(interp) < 0) { return -1; } + module_state *state = get_module_state(mod); int needs_import = 0; - _sharedns *shared = _get_shared_ns(shareables, &ChannelIDType, + _sharedns *shared = _get_shared_ns(shareables, state->ChannelIDType, &needs_import); if (shared == NULL && PyErr_Occurred()) { return -1; @@ -2195,7 +2250,8 @@ _run_script_in_interpreter(PyObject *mod, PyInterpreterState *interp, // Propagate any exception out to the caller. if (exc != NULL) { - _sharedexception_apply(exc, RunFailedError); + assert(state != NULL); + _sharedexception_apply(exc, state->RunFailedError); _sharedexception_free(exc); } else if (result != 0) { @@ -2530,8 +2586,12 @@ channel_create(PyObject *self, PyObject *Py_UNUSED(ignored)) (void)handle_channel_error(cid, self, -1); return NULL; } + module_state *state = get_module_state(self); + if (state == NULL) { + return NULL; + } PyObject *id = NULL; - int err = newchannelid(&ChannelIDType, cid, 0, + int err = newchannelid(state->ChannelIDType, cid, 0, &_globals.channels, 0, 0, (channelid **)&id); if (handle_channel_error(err, self, cid)) { @@ -2594,10 +2654,16 @@ channel_list_all(PyObject *self, PyObject *Py_UNUSED(ignored)) if (ids == NULL) { goto finally; } + module_state *state = get_module_state(self); + if (state == NULL) { + Py_DECREF(ids); + ids = NULL; + goto finally; + } int64_t *cur = cids; for (int64_t i=0; i < count; cur++, i++) { PyObject *id = NULL; - int err = newchannelid(&ChannelIDType, *cur, 0, + int err = newchannelid(state->ChannelIDType, *cur, 0, &_globals.channels, 0, 0, (channelid **)&id); if (handle_channel_error(err, self, *cur)) { @@ -2850,7 +2916,11 @@ ends are closed. Closing an already closed end is a noop."); static PyObject * channel__channel_id(PyObject *self, PyObject *args, PyObject *kwds) { - PyTypeObject *cls = &ChannelIDType; + module_state *state = get_module_state(self); + if (state == NULL) { + return NULL; + } + PyTypeObject *cls = state->ChannelIDType; PyObject *mod = get_module_from_owned_type(cls); if (mod == NULL) { return NULL; @@ -2924,9 +2994,16 @@ module_exec(PyObject *mod) } /* Add other types */ - if (add_new_type(mod, &ChannelIDType, _channelid_shared) == NULL) { + module_state *state = get_module_state(mod); + + // ChannelID + state->ChannelIDType = add_new_type( + mod, &ChannelIDType_spec, _channelid_shared); + if (state->ChannelIDType == NULL) { goto error; } + + // PyInterpreterID if (PyModule_AddType(mod, &_PyInterpreterID_Type) < 0) { goto error; } @@ -2934,31 +3011,57 @@ module_exec(PyObject *mod) return 0; error: - (void)_PyCrossInterpreterData_UnregisterClass(&ChannelIDType); + (void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType); _globals_fini(); return -1; } +static struct PyModuleDef_Slot module_slots[] = { + {Py_mod_exec, module_exec}, + {0, NULL}, +}; + +static int +module_traverse(PyObject *mod, visitproc visit, void *arg) +{ + module_state *state = get_module_state(mod); + assert(state != NULL); + traverse_module_state(state, visit, arg); + return 0; +} + +static int +module_clear(PyObject *mod) +{ + module_state *state = get_module_state(mod); + assert(state != NULL); + clear_module_state(state); + return 0; +} + +static void +module_free(void *mod) +{ + module_state *state = get_module_state(mod); + assert(state != NULL); + clear_module_state(state); + _globals_fini(); +} + static struct PyModuleDef moduledef = { .m_base = PyModuleDef_HEAD_INIT, .m_name = MODULE_NAME, .m_doc = module_doc, - .m_size = -1, + .m_size = sizeof(module_state), .m_methods = module_functions, + .m_slots = module_slots, + .m_traverse = module_traverse, + .m_clear = module_clear, + .m_free = (freefunc)module_free, }; - PyMODINIT_FUNC PyInit__xxsubinterpreters(void) { - /* Create the module */ - PyObject *mod = PyModule_Create(&moduledef); - if (mod == NULL) { - return NULL; - } - if (module_exec(mod) < 0) { - Py_DECREF(mod); - return NULL; - } - return mod; + return PyModuleDef_Init(&moduledef); } |