summaryrefslogtreecommitdiffstats
path: root/Objects/moduleobject.c
diff options
context:
space:
mode:
Diffstat (limited to 'Objects/moduleobject.c')
-rw-r--r--Objects/moduleobject.c103
1 files changed, 95 insertions, 8 deletions
diff --git a/Objects/moduleobject.c b/Objects/moduleobject.c
index 46995b9..73ad971 100644
--- a/Objects/moduleobject.c
+++ b/Objects/moduleobject.c
@@ -5,6 +5,7 @@
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_fileutils.h" // _Py_wgetcwd
#include "pycore_interp.h" // PyInterpreterState.importlib
+#include "pycore_long.h" // _PyLong_GetOne()
#include "pycore_modsupport.h" // _PyModule_CreateInitialized()
#include "pycore_moduleobject.h" // _PyModule_GetDef()
#include "pycore_object.h" // _PyType_AllocNoTrack
@@ -1133,7 +1134,7 @@ static PyMethodDef module_methods[] = {
};
static PyObject *
-module_get_annotations(PyModuleObject *m, void *Py_UNUSED(ignored))
+module_get_dict(PyModuleObject *m)
{
PyObject *dict = PyObject_GetAttr((PyObject *)m, &_Py_ID(__dict__));
if (dict == NULL) {
@@ -1144,10 +1145,97 @@ module_get_annotations(PyModuleObject *m, void *Py_UNUSED(ignored))
Py_DECREF(dict);
return NULL;
}
+ return dict;
+}
+
+static PyObject *
+module_get_annotate(PyModuleObject *m, void *Py_UNUSED(ignored))
+{
+ PyObject *dict = module_get_dict(m);
+ if (dict == NULL) {
+ return NULL;
+ }
+
+ PyObject *annotate;
+ if (PyDict_GetItemRef(dict, &_Py_ID(__annotate__), &annotate) == 0) {
+ annotate = Py_None;
+ if (PyDict_SetItem(dict, &_Py_ID(__annotate__), annotate) == -1) {
+ Py_CLEAR(annotate);
+ }
+ }
+ Py_DECREF(dict);
+ return annotate;
+}
+
+static int
+module_set_annotate(PyModuleObject *m, PyObject *value, void *Py_UNUSED(ignored))
+{
+ if (value == NULL) {
+ PyErr_SetString(PyExc_TypeError, "cannot delete __annotate__ attribute");
+ return -1;
+ }
+ PyObject *dict = module_get_dict(m);
+ if (dict == NULL) {
+ return -1;
+ }
+
+ if (!Py_IsNone(value) && !PyCallable_Check(value)) {
+ PyErr_SetString(PyExc_TypeError, "__annotate__ must be callable or None");
+ Py_DECREF(dict);
+ return -1;
+ }
+
+ if (PyDict_SetItem(dict, &_Py_ID(__annotate__), value) == -1) {
+ Py_DECREF(dict);
+ return -1;
+ }
+ if (!Py_IsNone(value)) {
+ if (PyDict_Pop(dict, &_Py_ID(__annotations__), NULL) == -1) {
+ Py_DECREF(dict);
+ return -1;
+ }
+ }
+ Py_DECREF(dict);
+ return 0;
+}
+
+static PyObject *
+module_get_annotations(PyModuleObject *m, void *Py_UNUSED(ignored))
+{
+ PyObject *dict = module_get_dict(m);
+ if (dict == NULL) {
+ return NULL;
+ }
PyObject *annotations;
if (PyDict_GetItemRef(dict, &_Py_ID(__annotations__), &annotations) == 0) {
- annotations = PyDict_New();
+ PyObject *annotate;
+ int annotate_result = PyDict_GetItemRef(dict, &_Py_ID(__annotate__), &annotate);
+ if (annotate_result < 0) {
+ Py_DECREF(dict);
+ return NULL;
+ }
+ if (annotate_result == 1 && PyCallable_Check(annotate)) {
+ PyObject *one = _PyLong_GetOne();
+ annotations = _PyObject_CallOneArg(annotate, one);
+ if (annotations == NULL) {
+ Py_DECREF(annotate);
+ Py_DECREF(dict);
+ return NULL;
+ }
+ if (!PyDict_Check(annotations)) {
+ PyErr_Format(PyExc_TypeError, "__annotate__ returned non-dict of type '%.100s'",
+ Py_TYPE(annotations)->tp_name);
+ Py_DECREF(annotate);
+ Py_DECREF(annotations);
+ Py_DECREF(dict);
+ return NULL;
+ }
+ }
+ else {
+ annotations = PyDict_New();
+ }
+ Py_XDECREF(annotate);
if (annotations) {
int result = PyDict_SetItem(
dict, &_Py_ID(__annotations__), annotations);
@@ -1164,14 +1252,10 @@ static int
module_set_annotations(PyModuleObject *m, PyObject *value, void *Py_UNUSED(ignored))
{
int ret = -1;
- PyObject *dict = PyObject_GetAttr((PyObject *)m, &_Py_ID(__dict__));
+ PyObject *dict = module_get_dict(m);
if (dict == NULL) {
return -1;
}
- if (!PyDict_Check(dict)) {
- PyErr_Format(PyExc_TypeError, "<module>.__dict__ is not a dictionary");
- goto exit;
- }
if (value != NULL) {
/* set */
@@ -1188,8 +1272,10 @@ module_set_annotations(PyModuleObject *m, PyObject *value, void *Py_UNUSED(ignor
ret = 0;
}
}
+ if (ret == 0 && PyDict_Pop(dict, &_Py_ID(__annotate__), NULL) < 0) {
+ ret = -1;
+ }
-exit:
Py_DECREF(dict);
return ret;
}
@@ -1197,6 +1283,7 @@ exit:
static PyGetSetDef module_getsets[] = {
{"__annotations__", (getter)module_get_annotations, (setter)module_set_annotations},
+ {"__annotate__", (getter)module_get_annotate, (setter)module_set_annotate},
{NULL}
};