From 33ec88ac81f23668293d101b83367b086c795e5e Mon Sep 17 00:00:00 2001 From: Mark Shannon Date: Mon, 3 May 2021 00:38:22 +0100 Subject: bpo-43977: Make sure that tp_flags for pattern matching are inherited correctly. (GH-25813) --- Lib/test/test_collections.py | 6 ++++ Lib/test/test_patma.py | 41 ++++++++++++++++++++++ .../2021-05-02-11-59-00.bpo-43977.R0hSDo.rst | 1 + Modules/_abc.c | 11 ++++-- Objects/typeobject.c | 18 ++++++---- 5 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 98690d2..2ba1a19 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1967,6 +1967,12 @@ class TestCollectionABCs(ABCTestCase): self.assertEqual(len(mss), len(mss2)) self.assertEqual(list(mss), list(mss2)) + def test_illegal_patma_flags(self): + with self.assertRaises(TypeError): + class Both(Collection): + __abc_tpflags__ = (Sequence.__flags__ | Mapping.__flags__) + + ################################################################################ ### Counter diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 8a273be..084d087 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -2979,6 +2979,47 @@ class TestPatma(unittest.TestCase): self.assertEqual(f((False, range(10, 20), True)), alts[4]) +class TestInheritance(unittest.TestCase): + + def test_multiple_inheritance(self): + class C: + pass + class S1(collections.UserList, collections.abc.Mapping): + pass + class S2(C, collections.UserList, collections.abc.Mapping): + pass + class S3(list, C, collections.abc.Mapping): + pass + class S4(collections.UserList, dict, C): + pass + class M1(collections.UserDict, collections.abc.Sequence): + pass + class M2(C, collections.UserDict, collections.abc.Sequence): + pass + class M3(collections.UserDict, C, list): + pass + class M4(dict, collections.abc.Sequence, C): + pass + def f(x): + match x: + case []: + return "seq" + case {}: + return "map" + def g(x): + match x: + case {}: + return "map" + case []: + return "seq" + for Seq in (S1, S2, S3, S4): + self.assertEqual(f(Seq()), "seq") + self.assertEqual(g(Seq()), "seq") + for Map in (M1, M2, M3, M4): + self.assertEqual(f(Map()), "map") + self.assertEqual(g(Map()), "map") + + class PerfPatma(TestPatma): def assertEqual(*_, **__): diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst b/Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst new file mode 100644 index 0000000..95aacaf --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2021-05-02-11-59-00.bpo-43977.R0hSDo.rst @@ -0,0 +1 @@ +Prevent classes being both a sequence and a mapping when pattern matching. diff --git a/Modules/_abc.c b/Modules/_abc.c index 39261dd..7720d40 100644 --- a/Modules/_abc.c +++ b/Modules/_abc.c @@ -467,6 +467,10 @@ _abc__abc_init(PyObject *module, PyObject *self) if (val == -1 && PyErr_Occurred()) { return NULL; } + if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) { + PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING"); + return NULL; + } ((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS); } if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) { @@ -527,9 +531,12 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass) /* Invalidate negative cache */ get_abc_state(module)->abc_invalidation_counter++; - if (PyType_Check(subclass) && PyType_Check(self) && - !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE)) + /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */ + if (PyType_Check(self) && + !PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) && + ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS) { + ((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS; ((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS); } Py_INCREF(subclass); diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 621bb0c..e511cf9 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -5713,12 +5713,6 @@ inherit_special(PyTypeObject *type, PyTypeObject *base) if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) { type->tp_flags |= _Py_TPFLAGS_MATCH_SELF; } - if (PyType_HasFeature(base, Py_TPFLAGS_SEQUENCE)) { - type->tp_flags |= Py_TPFLAGS_SEQUENCE; - } - if (PyType_HasFeature(base, Py_TPFLAGS_MAPPING)) { - type->tp_flags |= Py_TPFLAGS_MAPPING; - } } static int @@ -5936,6 +5930,7 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base) static int add_operators(PyTypeObject *); static int add_tp_new_wrapper(PyTypeObject *type); +#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING) static int type_ready_checks(PyTypeObject *type) @@ -5962,6 +5957,10 @@ type_ready_checks(PyTypeObject *type) _PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL); } + /* Consistency checks for pattern matching + * Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING are mutually exclusive */ + _PyObject_ASSERT((PyObject *)type, (type->tp_flags & COLLECTION_FLAGS) != COLLECTION_FLAGS); + if (type->tp_name == NULL) { PyErr_Format(PyExc_SystemError, "Type does not define the tp_name field."); @@ -6156,6 +6155,12 @@ type_ready_inherit_as_structs(PyTypeObject *type, PyTypeObject *base) } } +static void +inherit_patma_flags(PyTypeObject *type, PyTypeObject *base) { + if ((type->tp_flags & COLLECTION_FLAGS) == 0) { + type->tp_flags |= base->tp_flags & COLLECTION_FLAGS; + } +} static int type_ready_inherit(PyTypeObject *type) @@ -6175,6 +6180,7 @@ type_ready_inherit(PyTypeObject *type) if (inherit_slots(type, (PyTypeObject *)b) < 0) { return -1; } + inherit_patma_flags(type, (PyTypeObject *)b); } } -- cgit v0.12