diff options
author | Carl Meyer <carl@oddbird.net> | 2022-10-21 13:41:51 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-21 13:41:51 (GMT) |
commit | 82ccbf69a842db25d8117f1c41b47aa5b4ed96ab (patch) | |
tree | 3a8bdfc3eb837106664433c2e385276fccf64c06 /Objects | |
parent | 8367ca136ed7616cb1f71bd9f1ec98dbcfd35d98 (diff) | |
download | cpython-82ccbf69a842db25d8117f1c41b47aa5b4ed96ab.zip cpython-82ccbf69a842db25d8117f1c41b47aa5b4ed96ab.tar.gz cpython-82ccbf69a842db25d8117f1c41b47aa5b4ed96ab.tar.bz2 |
gh-91051: allow setting a callback hook on PyType_Modified (GH-97875)
Diffstat (limited to 'Objects')
-rw-r--r-- | Objects/typeobject.c | 100 |
1 files changed, 97 insertions, 3 deletions
diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 196a6ae..7f8f2c7 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -372,6 +372,83 @@ _PyTypes_Fini(PyInterpreterState *interp) static PyObject * lookup_subclasses(PyTypeObject *); +int +PyType_AddWatcher(PyType_WatchCallback callback) +{ + PyInterpreterState *interp = _PyInterpreterState_GET(); + + for (int i = 0; i < TYPE_MAX_WATCHERS; i++) { + if (!interp->type_watchers[i]) { + interp->type_watchers[i] = callback; + return i; + } + } + + PyErr_SetString(PyExc_RuntimeError, "no more type watcher IDs available"); + return -1; +} + +static inline int +validate_watcher_id(PyInterpreterState *interp, int watcher_id) +{ + if (watcher_id < 0 || watcher_id >= TYPE_MAX_WATCHERS) { + PyErr_Format(PyExc_ValueError, "Invalid type watcher ID %d", watcher_id); + return -1; + } + if (!interp->type_watchers[watcher_id]) { + PyErr_Format(PyExc_ValueError, "No type watcher set for ID %d", watcher_id); + return -1; + } + return 0; +} + +int +PyType_ClearWatcher(int watcher_id) +{ + PyInterpreterState *interp = _PyInterpreterState_GET(); + if (validate_watcher_id(interp, watcher_id) < 0) { + return -1; + } + interp->type_watchers[watcher_id] = NULL; + return 0; +} + +static int assign_version_tag(PyTypeObject *type); + +int +PyType_Watch(int watcher_id, PyObject* obj) +{ + if (!PyType_Check(obj)) { + PyErr_SetString(PyExc_ValueError, "Cannot watch non-type"); + return -1; + } + PyTypeObject *type = (PyTypeObject *)obj; + PyInterpreterState *interp = _PyInterpreterState_GET(); + if (validate_watcher_id(interp, watcher_id) < 0) { + return -1; + } + // ensure we will get a callback on the next modification + assign_version_tag(type); + type->tp_watched |= (1 << watcher_id); + return 0; +} + +int +PyType_Unwatch(int watcher_id, PyObject* obj) +{ + if (!PyType_Check(obj)) { + PyErr_SetString(PyExc_ValueError, "Cannot watch non-type"); + return -1; + } + PyTypeObject *type = (PyTypeObject *)obj; + PyInterpreterState *interp = _PyInterpreterState_GET(); + if (validate_watcher_id(interp, watcher_id)) { + return -1; + } + type->tp_watched &= ~(1 << watcher_id); + return 0; +} + void PyType_Modified(PyTypeObject *type) { @@ -409,6 +486,23 @@ PyType_Modified(PyTypeObject *type) } } + if (type->tp_watched) { + PyInterpreterState *interp = _PyInterpreterState_GET(); + int bits = type->tp_watched; + int i = 0; + while(bits && i < TYPE_MAX_WATCHERS) { + if (bits & 1) { + PyType_WatchCallback cb = interp->type_watchers[i]; + if (cb && (cb(type) < 0)) { + PyErr_WriteUnraisable((PyObject *)type); + } + } + i += 1; + bits >>= 1; + } + } + + type->tp_flags &= ~Py_TPFLAGS_VALID_VERSION_TAG; type->tp_version_tag = 0; /* 0 is not a valid version tag */ } @@ -467,7 +561,7 @@ type_mro_modified(PyTypeObject *type, PyObject *bases) { } static int -assign_version_tag(struct type_cache *cache, PyTypeObject *type) +assign_version_tag(PyTypeObject *type) { /* Ensure that the tp_version_tag is valid and set Py_TPFLAGS_VALID_VERSION_TAG. To respect the invariant, this @@ -492,7 +586,7 @@ assign_version_tag(struct type_cache *cache, PyTypeObject *type) Py_ssize_t n = PyTuple_GET_SIZE(bases); for (Py_ssize_t i = 0; i < n; i++) { PyObject *b = PyTuple_GET_ITEM(bases, i); - if (!assign_version_tag(cache, _PyType_CAST(b))) + if (!assign_version_tag(_PyType_CAST(b))) return 0; } type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG; @@ -4111,7 +4205,7 @@ _PyType_Lookup(PyTypeObject *type, PyObject *name) return NULL; } - if (MCACHE_CACHEABLE_NAME(name) && assign_version_tag(cache, type)) { + if (MCACHE_CACHEABLE_NAME(name) && assign_version_tag(type)) { h = MCACHE_HASH_METHOD(type, name); struct type_cache_entry *entry = &cache->hashtable[h]; entry->version = type->tp_version_tag; |