From 503af8fe9a93ea6bc5bdfc76eb56b106a47c7292 Mon Sep 17 00:00:00 2001
From: Eric Snow <ericsnowcurrently@gmail.com>
Date: Mon, 12 Aug 2024 13:19:33 -0600
Subject: gh-117482: Make the Slot Wrapper Inheritance Tests Much More Thorough
 (gh-122867)

There were a still a number of gaps in the tests, including not looking
at all the builtin types and not checking wrappers in subinterpreters
that weren't in the main interpreter. This fixes all that.

I considered incorporating the names of the PyTypeObject fields
(a la gh-122866), but figured doing so doesn't add much value.
---
 Include/internal/pycore_typeobject.h |   6 ++
 Lib/test/support/__init__.py         | 142 ++++++++++++++++++++++++++++++++---
 Lib/test/test_embed.py               |  57 ++++++++------
 Lib/test/test_types.py               |  60 +++++++++------
 Modules/_testinternalcapi.c          |  16 ++++
 Objects/typeobject.c                 |  41 ++++++++++
 6 files changed, 268 insertions(+), 54 deletions(-)

diff --git a/Include/internal/pycore_typeobject.h b/Include/internal/pycore_typeobject.h
index df6bfef..8ba635c 100644
--- a/Include/internal/pycore_typeobject.h
+++ b/Include/internal/pycore_typeobject.h
@@ -183,6 +183,9 @@ PyAPI_FUNC(int) _PyStaticType_InitForExtension(
     PyInterpreterState *interp,
      PyTypeObject *self);
 
+// Export for _testinternalcapi extension.
+PyAPI_FUNC(PyObject *) _PyStaticType_GetBuiltins(void);
+
 
 /* Like PyType_GetModuleState, but skips verification
  * that type is a heap type with an associated module */
@@ -209,6 +212,9 @@ extern PyObject* _PyType_GetSubclasses(PyTypeObject *);
 extern int _PyType_HasSubclasses(PyTypeObject *);
 PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef2(PyTypeObject *, PyTypeObject *, PyModuleDef *);
 
+// Export for _testinternalcapi extension.
+PyAPI_FUNC(PyObject *) _PyType_GetSlotWrapperNames(void);
+
 // PyType_Ready() must be called if _PyType_IsReady() is false.
 // See also the Py_TPFLAGS_READY flag.
 static inline int
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index f4dce79..e21a0be 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -5,6 +5,7 @@ if __name__ != 'test.support':
 
 import contextlib
 import functools
+import inspect
 import _opcode
 import os
 import re
@@ -892,8 +893,16 @@ def calcvobjsize(fmt):
     return struct.calcsize(_vheader + fmt + _align)
 
 
-_TPFLAGS_HAVE_GC = 1<<14
+_TPFLAGS_STATIC_BUILTIN = 1<<1
+_TPFLAGS_DISALLOW_INSTANTIATION = 1<<7
+_TPFLAGS_IMMUTABLETYPE = 1<<8
 _TPFLAGS_HEAPTYPE = 1<<9
+_TPFLAGS_BASETYPE = 1<<10
+_TPFLAGS_READY = 1<<12
+_TPFLAGS_READYING = 1<<13
+_TPFLAGS_HAVE_GC = 1<<14
+_TPFLAGS_BASE_EXC_SUBCLASS = 1<<30
+_TPFLAGS_TYPE_SUBCLASS = 1<<31
 
 def check_sizeof(test, o, size):
     try:
@@ -2608,19 +2617,121 @@ def copy_python_src_ignore(path, names):
     return ignored
 
 
-def iter_builtin_types():
-    for obj in __builtins__.values():
-        if not isinstance(obj, type):
+# XXX Move this to the inspect module?
+def walk_class_hierarchy(top, *, topdown=True):
+    # This is based on the logic in os.walk().
+    assert isinstance(top, type), repr(top)
+    stack = [top]
+    while stack:
+        top = stack.pop()
+        if isinstance(top, tuple):
+            yield top
             continue
-        cls = obj
-        if cls.__module__ != 'builtins':
+
+        subs = type(top).__subclasses__(top)
+        if topdown:
+            # Yield before subclass traversal if going top down.
+            yield top, subs
+            # Traverse into subclasses.
+            for sub in reversed(subs):
+                stack.append(sub)
+        else:
+            # Yield after subclass traversal if going bottom up.
+            stack.append((top, subs))
+            # Traverse into subclasses.
+            for sub in reversed(subs):
+                stack.append(sub)
+
+
+def iter_builtin_types():
+    # First try the explicit route.
+    try:
+        import _testinternalcapi
+    except ImportError:
+        _testinternalcapi = None
+    if _testinternalcapi is not None:
+        yield from _testinternalcapi.get_static_builtin_types()
+        return
+
+    # Fall back to making a best-effort guess.
+    if hasattr(object, '__flags__'):
+        # Look for any type object with the Py_TPFLAGS_STATIC_BUILTIN flag set.
+        import datetime
+        seen = set()
+        for cls, subs in walk_class_hierarchy(object):
+            if cls in seen:
+                continue
+            seen.add(cls)
+            if not (cls.__flags__ & _TPFLAGS_STATIC_BUILTIN):
+                # Do not walk its subclasses.
+                subs[:] = []
+                continue
+            yield cls
+    else:
+        # Fall back to a naive approach.
+        seen = set()
+        for obj in __builtins__.values():
+            if not isinstance(obj, type):
+                continue
+            cls = obj
+            # XXX?
+            if cls.__module__ != 'builtins':
+                continue
+            if cls == ExceptionGroup:
+                # It's a heap type.
+                continue
+            if cls in seen:
+                continue
+            seen.add(cls)
+            yield cls
+
+
+# XXX Move this to the inspect module?
+def iter_name_in_mro(cls, name):
+    """Yield matching items found in base.__dict__ across the MRO.
+
+    The descriptor protocol is not invoked.
+
+    list(iter_name_in_mro(cls, name))[0] is roughly equivalent to
+    find_name_in_mro() in Objects/typeobject.c (AKA PyType_Lookup()).
+
+    inspect.getattr_static() is similar.
+    """
+    # This can fail if "cls" is weird.
+    for base in inspect._static_getmro(cls):
+        # This can fail if "base" is weird.
+        ns = inspect._get_dunder_dict_of_class(base)
+        try:
+            obj = ns[name]
+        except KeyError:
             continue
-        yield cls
+        yield obj, base
 
 
-def iter_slot_wrappers(cls):
-    assert cls.__module__ == 'builtins', cls
+# XXX Move this to the inspect module?
+def find_name_in_mro(cls, name, default=inspect._sentinel):
+    for res in iter_name_in_mro(cls, name):
+        # Return the first one.
+        return res
+    if default is not inspect._sentinel:
+        return default, None
+    raise AttributeError(name)
+
+
+# XXX The return value should always be exactly the same...
+def identify_type_slot_wrappers():
+    try:
+        import _testinternalcapi
+    except ImportError:
+        _testinternalcapi = None
+    if _testinternalcapi is not None:
+        names = {n: None for n in _testinternalcapi.identify_type_slot_wrappers()}
+        return list(names)
+    else:
+        raise NotImplementedError
+
 
+def iter_slot_wrappers(cls):
     def is_slot_wrapper(name, value):
         if not isinstance(value, types.WrapperDescriptorType):
             assert not repr(value).startswith('<slot wrapper '), (cls, name, value)
@@ -2630,6 +2741,19 @@ def iter_slot_wrappers(cls):
         assert name.startswith('__') and name.endswith('__'), (cls, name, value)
         return True
 
+    try:
+        attrs = identify_type_slot_wrappers()
+    except NotImplementedError:
+        attrs = None
+    if attrs is not None:
+        for attr in sorted(attrs):
+            obj, base = find_name_in_mro(cls, attr, None)
+            if obj is not None and is_slot_wrapper(attr, obj):
+                yield attr, base is cls
+        return
+
+    # Fall back to a naive best-effort approach.
+
     ns = vars(cls)
     unused = set(ns)
     for name in dir(cls):
diff --git a/Lib/test/test_embed.py b/Lib/test/test_embed.py
index 916a9a7..aab4333 100644
--- a/Lib/test/test_embed.py
+++ b/Lib/test/test_embed.py
@@ -420,45 +420,54 @@ class EmbeddingTests(EmbeddingTestsMixin, unittest.TestCase):
     def test_static_types_inherited_slots(self):
         script = textwrap.dedent("""
             import test.support
-
-            results = {}
-            def add(cls, slot, own):
-                value = getattr(cls, slot)
-                try:
-                    subresults = results[cls.__name__]
-                except KeyError:
-                    subresults = results[cls.__name__] = {}
-                subresults[slot] = [repr(value), own]
-
+            results = []
             for cls in test.support.iter_builtin_types():
-                for slot, own in test.support.iter_slot_wrappers(cls):
-                    add(cls, slot, own)
+                for attr, _ in test.support.iter_slot_wrappers(cls):
+                    wrapper = getattr(cls, attr)
+                    res = (cls, attr, wrapper)
+                    results.append(res)
+            results = ((repr(c), a, repr(w)) for c, a, w in results)
             """)
+        def collate_results(raw):
+            results = {}
+            for cls, attr, wrapper in raw:
+                key = cls, attr
+                assert key not in results, (results, key, wrapper)
+                results[key] = wrapper
+            return results
 
         ns = {}
         exec(script, ns, ns)
-        all_expected = ns['results']
+        main_results = collate_results(ns['results'])
         del ns
 
         script += textwrap.dedent("""
             import json
             import sys
-            text = json.dumps(results)
+            text = json.dumps(list(results))
             print(text, file=sys.stderr)
             """)
         out, err = self.run_embedded_interpreter(
                 "test_repeated_init_exec", script, script)
-        results = err.split('--- Loop #')[1:]
-        results = [res.rpartition(' ---\n')[-1] for res in results]
-
+        _results = err.split('--- Loop #')[1:]
+        (_embedded, _reinit,
+         ) = [json.loads(res.rpartition(' ---\n')[-1]) for res in _results]
+        embedded_results = collate_results(_embedded)
+        reinit_results = collate_results(_reinit)
+
+        for key, expected in main_results.items():
+            cls, attr = key
+            for src, results in [
+                ('embedded', embedded_results),
+                ('reinit', reinit_results),
+            ]:
+                with self.subTest(src, cls=cls, slotattr=attr):
+                    actual = results.pop(key)
+                    self.assertEqual(actual, expected)
         self.maxDiff = None
-        for i, text in enumerate(results, start=1):
-            result = json.loads(text)
-            for classname, expected in all_expected.items():
-                with self.subTest(loop=i, cls=classname):
-                    slots = result.pop(classname)
-                    self.assertEqual(slots, expected)
-            self.assertEqual(result, {})
+        self.assertEqual(embedded_results, {})
+        self.assertEqual(reinit_results, {})
+
         self.assertEqual(out, '')
 
     def test_getargs_reset_static_parser(self):
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index bcdd6d2..2ee4660 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -2396,35 +2396,53 @@ class SubinterpreterTests(unittest.TestCase):
     def test_static_types_inherited_slots(self):
         rch, sch = interpreters.channels.create()
 
-        slots = []
-        script = ''
-        for cls in iter_builtin_types():
-            for slot, own in iter_slot_wrappers(cls):
-                if cls is bool and slot in self.NUMERIC_METHODS:
+        script = textwrap.dedent("""
+            import test.support
+            results = []
+            for cls in test.support.iter_builtin_types():
+                for attr, _ in test.support.iter_slot_wrappers(cls):
+                    wrapper = getattr(cls, attr)
+                    res = (cls, attr, wrapper)
+                    results.append(res)
+            results = tuple((repr(c), a, repr(w)) for c, a, w in results)
+            sch.send_nowait(results)
+            """)
+        def collate_results(raw):
+            results = {}
+            for cls, attr, wrapper in raw:
+                # XXX This should not be necessary.
+                if cls == repr(bool) and attr in self.NUMERIC_METHODS:
                     continue
-                slots.append((cls, slot, own))
-                script += textwrap.dedent(f"""
-                    text = repr({cls.__name__}.{slot})
-                    sch.send_nowait(({cls.__name__!r}, {slot!r}, text))
-                    """)
+                key = cls, attr
+                assert key not in results, (results, key, wrapper)
+                results[key] = wrapper
+            return results
 
         exec(script)
-        all_expected = []
-        for cls, slot, _ in slots:
-            result = rch.recv()
-            assert result == (cls.__name__, slot, result[-1]), (cls, slot, result)
-            all_expected.append(result)
+        raw = rch.recv_nowait()
+        main_results = collate_results(raw)
 
         interp = interpreters.create()
         interp.exec('from test.support import interpreters')
         interp.prepare_main(sch=sch)
         interp.exec(script)
-
-        for i, (cls, slot, _) in enumerate(slots):
-            with self.subTest(cls=cls, slot=slot):
-                expected = all_expected[i]
-                result = rch.recv()
-                self.assertEqual(result, expected)
+        raw = rch.recv_nowait()
+        interp_results = collate_results(raw)
+
+        for key, expected in main_results.items():
+            cls, attr = key
+            with self.subTest(cls=cls, slotattr=attr):
+                actual = interp_results.pop(key)
+                # XXX This should not be necessary.
+                if cls == "<class 'collections.OrderedDict'>" and attr == '__len__':
+                    continue
+                self.assertEqual(actual, expected)
+        # XXX This should not be necessary.
+        interp_results = {k: v for k, v in interp_results.items() if k[1] != '__hash__'}
+        # XXX This should not be necessary.
+        interp_results.pop(("<class 'collections.OrderedDict'>", '__getitem__'), None)
+        self.maxDiff = None
+        self.assertEqual(interp_results, {})
 
 
 if __name__ == '__main__':
diff --git a/Modules/_testinternalcapi.c b/Modules/_testinternalcapi.c
index 6e6386b..00174ff 100644
--- a/Modules/_testinternalcapi.c
+++ b/Modules/_testinternalcapi.c
@@ -2035,6 +2035,20 @@ gh_119213_getargs_impl(PyObject *module, PyObject *spam)
 }
 
 
+static PyObject *
+get_static_builtin_types(PyObject *self, PyObject *Py_UNUSED(ignored))
+{
+    return _PyStaticType_GetBuiltins();
+}
+
+
+static PyObject *
+identify_type_slot_wrappers(PyObject *self, PyObject *Py_UNUSED(ignored))
+{
+    return _PyType_GetSlotWrapperNames();
+}
+
+
 static PyMethodDef module_functions[] = {
     {"get_configs", get_configs, METH_NOARGS},
     {"get_recursion_depth", get_recursion_depth, METH_NOARGS},
@@ -2129,6 +2143,8 @@ static PyMethodDef module_functions[] = {
     {"uop_symbols_test", _Py_uop_symbols_test, METH_NOARGS},
 #endif
     GH_119213_GETARGS_METHODDEF
+    {"get_static_builtin_types", get_static_builtin_types, METH_NOARGS},
+    {"identify_type_slot_wrappers", identify_type_slot_wrappers, METH_NOARGS},
     {NULL, NULL} /* sentinel */
 };
 
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 00f0dc9..0d7009a 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -324,6 +324,29 @@ managed_static_type_get_def(PyTypeObject *self, int isbuiltin)
     return &_PyRuntime.types.managed_static.types[full_index].def;
 }
 
+
+PyObject *
+_PyStaticType_GetBuiltins(void)
+{
+    PyInterpreterState *interp = _PyInterpreterState_GET();
+    Py_ssize_t count = (Py_ssize_t)interp->types.builtins.num_initialized;
+    assert(count <= _Py_MAX_MANAGED_STATIC_BUILTIN_TYPES);
+
+    PyObject *results = PyList_New(count);
+    if (results == NULL) {
+        return NULL;
+    }
+    for (Py_ssize_t i = 0; i < count; i++) {
+        PyTypeObject *cls = interp->types.builtins.initialized[i].type;
+        assert(cls != NULL);
+        assert(interp->types.builtins.initialized[i].isbuiltin);
+        PyList_SET_ITEM(results, i, Py_NewRef((PyObject *)cls));
+    }
+
+    return results;
+}
+
+
 // Also see _PyStaticType_InitBuiltin() and _PyStaticType_FiniBuiltin().
 
 /* end static builtin helpers */
@@ -10927,6 +10950,24 @@ update_all_slots(PyTypeObject* type)
 }
 
 
+PyObject *
+_PyType_GetSlotWrapperNames(void)
+{
+    size_t len = Py_ARRAY_LENGTH(slotdefs) - 1;
+    PyObject *names = PyList_New(len);
+    if (names == NULL) {
+        return NULL;
+    }
+    assert(slotdefs[len].name == NULL);
+    for (size_t i = 0; i < len; i++) {
+        pytype_slotdef *slotdef = &slotdefs[i];
+        assert(slotdef->name != NULL);
+        PyList_SET_ITEM(names, i, Py_NewRef(slotdef->name_strobj));
+    }
+    return names;
+}
+
+
 /* Call __set_name__ on all attributes (including descriptors)
   in a newly generated type */
 static int
-- 
cgit v0.12