diff options
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/_pickle.c | 107 | ||||
-rw-r--r-- | Modules/pyexpat.c | 10 |
2 files changed, 70 insertions, 47 deletions
diff --git a/Modules/_pickle.c b/Modules/_pickle.c index fb69f14..d531dee 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -1649,13 +1649,40 @@ getattribute(PyObject *obj, PyObject *name, int allow_qualname) return attr; } +static int +_checkmodule(PyObject *module_name, PyObject *module, + PyObject *global, PyObject *dotted_path) +{ + if (module == Py_None) { + return -1; + } + if (PyUnicode_Check(module_name) && + _PyUnicode_EqualToASCIIString(module_name, "__main__")) { + return -1; + } + + PyObject *candidate = get_deep_attribute(module, dotted_path, NULL); + if (candidate == NULL) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + } + return -1; + } + if (candidate != global) { + Py_DECREF(candidate); + return -1; + } + Py_DECREF(candidate); + return 0; +} + static PyObject * whichmodule(PyObject *global, PyObject *dotted_path) { PyObject *module_name; - PyObject *modules_dict; - PyObject *module; + PyObject *module = NULL; Py_ssize_t i; + PyObject *modules; _Py_IDENTIFIER(__module__); _Py_IDENTIFIER(modules); _Py_IDENTIFIER(__main__); @@ -1678,35 +1705,48 @@ whichmodule(PyObject *global, PyObject *dotted_path) assert(module_name == NULL); /* Fallback on walking sys.modules */ - modules_dict = _PySys_GetObjectId(&PyId_modules); - if (modules_dict == NULL) { + modules = _PySys_GetObjectId(&PyId_modules); + if (modules == NULL) { PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); return NULL; } - - i = 0; - while (PyDict_Next(modules_dict, &i, &module_name, &module)) { - PyObject *candidate; - if (PyUnicode_Check(module_name) && - _PyUnicode_EqualToASCIIString(module_name, "__main__")) - continue; - if (module == Py_None) - continue; - - candidate = get_deep_attribute(module, dotted_path, NULL); - if (candidate == NULL) { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + if (PyDict_CheckExact(modules)) { + i = 0; + while (PyDict_Next(modules, &i, &module_name, &module)) { + if (_checkmodule(module_name, module, global, dotted_path) == 0) { + Py_INCREF(module_name); + return module_name; + } + if (PyErr_Occurred()) { return NULL; - PyErr_Clear(); - continue; + } } - - if (candidate == global) { - Py_INCREF(module_name); - Py_DECREF(candidate); - return module_name; + } + else { + PyObject *iterator = PyObject_GetIter(modules); + if (iterator == NULL) { + return NULL; } - Py_DECREF(candidate); + while ((module_name = PyIter_Next(iterator))) { + module = PyObject_GetItem(modules, module_name); + if (module == NULL) { + Py_DECREF(module_name); + Py_DECREF(iterator); + return NULL; + } + if (_checkmodule(module_name, module, global, dotted_path) == 0) { + Py_DECREF(module); + Py_DECREF(iterator); + return module_name; + } + Py_DECREF(module); + Py_DECREF(module_name); + if (PyErr_Occurred()) { + Py_DECREF(iterator); + return NULL; + } + } + Py_DECREF(iterator); } /* If no module is found, use __main__. */ @@ -6424,9 +6464,7 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, /*[clinic end generated code: output=becc08d7f9ed41e3 input=e2e6a865de093ef4]*/ { PyObject *global; - PyObject *modules_dict; PyObject *module; - _Py_IDENTIFIER(modules); /* Try to map the old names used in Python 2.x to the new ones used in Python 3.x. We do this only with old pickle protocols and when the @@ -6483,25 +6521,16 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, } } - modules_dict = _PySys_GetObjectId(&PyId_modules); - if (modules_dict == NULL) { - PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); - return NULL; - } - - module = PyDict_GetItemWithError(modules_dict, module_name); + module = PyImport_GetModule(module_name); if (module == NULL) { if (PyErr_Occurred()) return NULL; module = PyImport_Import(module_name); if (module == NULL) return NULL; - global = getattribute(module, global_name, self->proto >= 4); - Py_DECREF(module); - } - else { - global = getattribute(module, global_name, self->proto >= 4); } + global = getattribute(module, global_name, self->proto >= 4); + Py_DECREF(module); return global; } diff --git a/Modules/pyexpat.c b/Modules/pyexpat.c index d9cfa3e..c8a01d4 100644 --- a/Modules/pyexpat.c +++ b/Modules/pyexpat.c @@ -1643,7 +1643,6 @@ MODULE_INITFUNC(void) PyObject *errors_module; PyObject *modelmod_name; PyObject *model_module; - PyObject *sys_modules; PyObject *tmpnum, *tmpstr; PyObject *codes_dict; PyObject *rev_codes_dict; @@ -1693,11 +1692,6 @@ MODULE_INITFUNC(void) */ PyModule_AddStringConstant(m, "native_encoding", "UTF-8"); - sys_modules = PySys_GetObject("modules"); - if (sys_modules == NULL) { - Py_DECREF(m); - return NULL; - } d = PyModule_GetDict(m); if (d == NULL) { Py_DECREF(m); @@ -1707,7 +1701,7 @@ MODULE_INITFUNC(void) if (errors_module == NULL) { errors_module = PyModule_New(MODULE_NAME ".errors"); if (errors_module != NULL) { - PyDict_SetItem(sys_modules, errmod_name, errors_module); + _PyImport_SetModule(errmod_name, errors_module); /* gives away the reference to errors_module */ PyModule_AddObject(m, "errors", errors_module); } @@ -1717,7 +1711,7 @@ MODULE_INITFUNC(void) if (model_module == NULL) { model_module = PyModule_New(MODULE_NAME ".model"); if (model_module != NULL) { - PyDict_SetItem(sys_modules, modelmod_name, model_module); + _PyImport_SetModule(modelmod_name, model_module); /* gives away the reference to model_module */ PyModule_AddObject(m, "model", model_module); } |