From 0fd27375cabd12e68a2f12cfeca11a2d5043429e Mon Sep 17 00:00:00 2001 From: Serhiy Storchaka Date: Sat, 17 Jul 2021 22:44:10 +0300 Subject: bpo-44654: Refactor and clean up the union type implementation (GH-27196) --- Include/genericaliasobject.h | 5 -- Include/internal/pycore_unionobject.h | 10 ++- Lib/test/test_types.py | 42 ++++++----- Lib/typing.py | 2 +- Objects/abstract.c | 6 +- Objects/genericaliasobject.c | 7 +- Objects/object.c | 4 +- Objects/typeobject.c | 2 +- Objects/unionobject.c | 135 ++++++++++++---------------------- 9 files changed, 88 insertions(+), 125 deletions(-) diff --git a/Include/genericaliasobject.h b/Include/genericaliasobject.h index 4ce9244..cf00297 100644 --- a/Include/genericaliasobject.h +++ b/Include/genericaliasobject.h @@ -5,11 +5,6 @@ extern "C" { #endif -#ifndef Py_LIMITED_API -PyAPI_FUNC(PyObject *) _Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *); -PyAPI_FUNC(PyObject *) _Py_make_parameters(PyObject *); -#endif - PyAPI_FUNC(PyObject *) Py_GenericAlias(PyObject *, PyObject *); PyAPI_DATA(PyTypeObject) Py_GenericAliasType; diff --git a/Include/internal/pycore_unionobject.h b/Include/internal/pycore_unionobject.h index 4d82b6f..236989f 100644 --- a/Include/internal/pycore_unionobject.h +++ b/Include/internal/pycore_unionobject.h @@ -8,9 +8,13 @@ extern "C" { # error "this header requires Py_BUILD_CORE define" #endif -PyAPI_FUNC(PyObject *) _Py_Union(PyObject *args); -PyAPI_DATA(PyTypeObject) _Py_UnionType; -PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject* self, PyObject* param); +PyAPI_DATA(PyTypeObject) _PyUnion_Type; +#define _PyUnion_Check(op) Py_IS_TYPE(op, &_PyUnion_Type) +PyAPI_FUNC(PyObject *) _Py_union_type_or(PyObject *, PyObject *); + +#define _PyGenericAlias_Check(op) PyObject_TypeCheck(op, &Py_GenericAliasType) +PyAPI_FUNC(PyObject *) _Py_subs_parameters(PyObject *, PyObject *, PyObject *, PyObject *); +PyAPI_FUNC(PyObject *) _Py_make_parameters(PyObject *); #ifdef __cplusplus } diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 2d0e33f..b2e1130 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -611,6 +611,18 @@ class TypesTests(unittest.TestCase): self.assertIsInstance(int.from_bytes, types.BuiltinMethodType) self.assertIsInstance(int.__new__, types.BuiltinMethodType) + def test_ellipsis_type(self): + self.assertIsInstance(Ellipsis, types.EllipsisType) + + def test_notimplemented_type(self): + self.assertIsInstance(NotImplemented, types.NotImplementedType) + + def test_none_type(self): + self.assertIsInstance(None, types.NoneType) + + +class UnionTests(unittest.TestCase): + def test_or_types_operator(self): self.assertEqual(int | str, typing.Union[int, str]) self.assertNotEqual(int | list, typing.Union[int, str]) @@ -657,18 +669,23 @@ class TypesTests(unittest.TestCase): with self.assertRaises(TypeError): Example() | int x = int | str - self.assertNotEqual(x, {}) + self.assertEqual(x, int | str) + self.assertEqual(x, str | int) + self.assertNotEqual(x, {}) # should not raise exception with self.assertRaises(TypeError): - (int | str) < typing.Union[str, int] + x < x with self.assertRaises(TypeError): - (int | str) < (int | bool) + x <= x + y = typing.Union[str, int] with self.assertRaises(TypeError): - (int | str) <= (int | str) + x < y + y = int | bool with self.assertRaises(TypeError): - # Check that we don't crash if typing.Union does not have a tuple in __args__ - x = typing.Union[str, int] - x.__args__ = [str, int] - (int | str ) == x + x < y + # Check that we don't crash if typing.Union does not have a tuple in __args__ + y = typing.Union[str, int] + y.__args__ = [str, int] + self.assertEqual(x, y) def test_hash(self): self.assertEqual(hash(int | str), hash(str | int)) @@ -873,15 +890,6 @@ class TypesTests(unittest.TestCase): self.assertLessEqual(sys.gettotalrefcount() - before, leeway, msg='Check for union reference leak.') - def test_ellipsis_type(self): - self.assertIsInstance(Ellipsis, types.EllipsisType) - - def test_notimplemented_type(self): - self.assertIsInstance(NotImplemented, types.NotImplementedType) - - def test_none_type(self): - self.assertIsInstance(None, types.NoneType) - class MappingProxyTests(unittest.TestCase): mappingproxy = types.MappingProxyType diff --git a/Lib/typing.py b/Lib/typing.py index cc7f41d..59f3ca3 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -1184,7 +1184,7 @@ class _UnionGenericAlias(_GenericAlias, _root=True): return Union[params] def __eq__(self, other): - if not isinstance(other, _UnionGenericAlias): + if not isinstance(other, (_UnionGenericAlias, types.Union)): return NotImplemented return set(self.__args__) == set(other.__args__) diff --git a/Objects/abstract.c b/Objects/abstract.c index fcfe2db..f14a923 100644 --- a/Objects/abstract.c +++ b/Objects/abstract.c @@ -6,7 +6,7 @@ #include "pycore_object.h" // _Py_CheckSlotResult() #include "pycore_pyerrors.h" // _PyErr_Occurred() #include "pycore_pystate.h" // _PyThreadState_GET() -#include "pycore_unionobject.h" // _Py_UnionType && _Py_Union() +#include "pycore_unionobject.h" // _PyUnion_Check() #include #include // offsetof() #include "longintrepr.h" @@ -2623,9 +2623,7 @@ recursive_issubclass(PyObject *derived, PyObject *cls) "issubclass() arg 1 must be a class")) return -1; - PyTypeObject *type = Py_TYPE(cls); - int is_union = (PyType_Check(type) && type == &_Py_UnionType); - if (!is_union && !check_class(cls, + if (!_PyUnion_Check(cls) && !check_class(cls, "issubclass() arg 2 must be a class," " a tuple of classes, or a union.")) { return -1; diff --git a/Objects/genericaliasobject.c b/Objects/genericaliasobject.c index d3d3871..dda53cb 100644 --- a/Objects/genericaliasobject.c +++ b/Objects/genericaliasobject.c @@ -2,7 +2,7 @@ #include "Python.h" #include "pycore_object.h" -#include "pycore_unionobject.h" // _Py_union_as_number +#include "pycore_unionobject.h" // _Py_union_type_or, _PyGenericAlias_Check #include "structmember.h" // PyMemberDef typedef struct { @@ -441,8 +441,7 @@ ga_getattro(PyObject *self, PyObject *name) static PyObject * ga_richcompare(PyObject *a, PyObject *b, int op) { - if (!PyObject_TypeCheck(a, &Py_GenericAliasType) || - !PyObject_TypeCheck(b, &Py_GenericAliasType) || + if (!_PyGenericAlias_Check(b) || (op != Py_EQ && op != Py_NE)) { Py_RETURN_NOTIMPLEMENTED; @@ -622,7 +621,7 @@ ga_new(PyTypeObject *type, PyObject *args, PyObject *kwds) } static PyNumberMethods ga_as_number = { - .nb_or = (binaryfunc)_Py_union_type_or, // Add __or__ function + .nb_or = _Py_union_type_or, // Add __or__ function }; // TODO: diff --git a/Objects/object.c b/Objects/object.c index 8a854c7..446c974 100644 --- a/Objects/object.c +++ b/Objects/object.c @@ -11,7 +11,7 @@ #include "pycore_pymem.h" // _PyMem_IsPtrFreed() #include "pycore_pystate.h" // _PyThreadState_GET() #include "pycore_symtable.h" // PySTEntry_Type -#include "pycore_unionobject.h" // _Py_UnionType +#include "pycore_unionobject.h" // _PyUnion_Type #include "frameobject.h" #include "interpreteridobject.h" @@ -1878,7 +1878,7 @@ _PyTypes_Init(void) INIT_TYPE(_PyWeakref_CallableProxyType); INIT_TYPE(_PyWeakref_ProxyType); INIT_TYPE(_PyWeakref_RefType); - INIT_TYPE(_Py_UnionType); + INIT_TYPE(_PyUnion_Type); return _PyStatus_OK(); #undef INIT_TYPE diff --git a/Objects/typeobject.c b/Objects/typeobject.c index 3331fee..badd706 100644 --- a/Objects/typeobject.c +++ b/Objects/typeobject.c @@ -9,7 +9,7 @@ #include "pycore_object.h" #include "pycore_pyerrors.h" #include "pycore_pystate.h" // _PyThreadState_GET() -#include "pycore_unionobject.h" // _Py_Union(), _Py_union_type_or +#include "pycore_unionobject.h" // _Py_union_type_or #include "frameobject.h" #include "opcode.h" // MAKE_CELL #include "structmember.h" // PyMemberDef diff --git a/Objects/unionobject.c b/Objects/unionobject.c index b3a6506..c744c87 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -5,6 +5,9 @@ #include "structmember.h" +static PyObject *make_union(PyObject *); + + typedef struct { PyObject_HEAD PyObject *args; @@ -46,11 +49,12 @@ union_hash(PyObject *self) } static int -is_generic_alias_in_args(PyObject *args) { +is_generic_alias_in_args(PyObject *args) +{ Py_ssize_t nargs = PyTuple_GET_SIZE(args); for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { PyObject *arg = PyTuple_GET_ITEM(args, iarg); - if (PyObject_TypeCheck(arg, &Py_GenericAliasType)) { + if (_PyGenericAlias_Check(arg)) { return 0; } } @@ -108,22 +112,26 @@ union_subclasscheck(PyObject *self, PyObject *instance) } } } - Py_RETURN_FALSE; + Py_RETURN_FALSE; } static int -is_typing_module(PyObject *obj) { - PyObject *module = PyObject_GetAttrString(obj, "__module__"); - if (module == NULL) { +is_typing_module(PyObject *obj) +{ + _Py_IDENTIFIER(__module__); + PyObject *module; + if (_PyObject_LookupAttrId(obj, &PyId___module__, &module) < 0) { return -1; } - int is_typing = PyUnicode_Check(module) && _PyUnicode_EqualToASCIIString(module, "typing"); - Py_DECREF(module); + int is_typing = (module != NULL && + PyUnicode_Check(module) && + _PyUnicode_EqualToASCIIString(module, "typing")); + Py_XDECREF(module); return is_typing; } static int -is_typing_name(PyObject *obj, char *name) +is_typing_name(PyObject *obj, const char *name) { PyTypeObject *type = Py_TYPE(obj); if (strcmp(type->tp_name, name) != 0) { @@ -135,66 +143,22 @@ is_typing_name(PyObject *obj, char *name) static PyObject * union_richcompare(PyObject *a, PyObject *b, int op) { - PyObject *result = NULL; - if (op != Py_EQ && op != Py_NE) { - result = Py_NotImplemented; - Py_INCREF(result); - return result; + if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) { + Py_RETURN_NOTIMPLEMENTED; } - PyTypeObject *type = Py_TYPE(b); - - PyObject* a_set = PySet_New(((unionobject*)a)->args); + PyObject *a_set = PySet_New(((unionobject*)a)->args); if (a_set == NULL) { return NULL; } - PyObject* b_set = PySet_New(NULL); + PyObject *b_set = PySet_New(((unionobject*)b)->args); if (b_set == NULL) { - goto exit; - } - - // Populate b_set with the data from the right object - int is_typing_union = is_typing_name(b, "_UnionGenericAlias"); - if (is_typing_union < 0) { - goto exit; - } - if (is_typing_union) { - PyObject *b_args = PyObject_GetAttrString(b, "__args__"); - if (b_args == NULL) { - goto exit; - } - if (!PyTuple_CheckExact(b_args)) { - Py_DECREF(b_args); - PyErr_SetString(PyExc_TypeError, "__args__ argument of typing.Union object is not a tuple"); - goto exit; - } - Py_ssize_t b_arg_length = PyTuple_GET_SIZE(b_args); - for (Py_ssize_t i = 0; i < b_arg_length; i++) { - PyObject* arg = PyTuple_GET_ITEM(b_args, i); - if (PySet_Add(b_set, arg) == -1) { - Py_DECREF(b_args); - goto exit; - } - } - Py_DECREF(b_args); - } else if (type == &_Py_UnionType) { - PyObject* args = ((unionobject*) b)->args; - Py_ssize_t arg_length = PyTuple_GET_SIZE(args); - for (Py_ssize_t i = 0; i < arg_length; i++) { - PyObject* arg = PyTuple_GET_ITEM(args, i); - if (PySet_Add(b_set, arg) == -1) { - goto exit; - } - } - } else { Py_DECREF(a_set); - Py_DECREF(b_set); - Py_RETURN_NOTIMPLEMENTED; + return NULL; } - result = PyObject_RichCompare(a_set, b_set, op); -exit: - Py_XDECREF(a_set); - Py_XDECREF(b_set); + PyObject *result = PyObject_RichCompare(a_set, b_set, op); + Py_DECREF(b_set); + Py_DECREF(a_set); return result; } @@ -206,8 +170,7 @@ flatten_args(PyObject* args) // Get number of total args once it's flattened. for (Py_ssize_t i = 0; i < arg_length; i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); - PyTypeObject* arg_type = Py_TYPE(arg); - if (arg_type == &_Py_UnionType) { + if (_PyUnion_Check(arg)) { total_args += PyTuple_GET_SIZE(((unionobject*) arg)->args); } else { total_args++; @@ -221,8 +184,7 @@ flatten_args(PyObject* args) Py_ssize_t pos = 0; for (Py_ssize_t i = 0; i < arg_length; i++) { PyObject *arg = PyTuple_GET_ITEM(args, i); - PyTypeObject* arg_type = Py_TYPE(arg); - if (arg_type == &_Py_UnionType) { + if (_PyUnion_Check(arg)) { PyObject* nested_args = ((unionobject*)arg)->args; Py_ssize_t nested_arg_length = PyTuple_GET_SIZE(nested_args); for (Py_ssize_t j = 0; j < nested_arg_length; j++) { @@ -240,6 +202,7 @@ flatten_args(PyObject* args) pos++; } } + assert(pos == total_args); return flattened_args; } @@ -253,6 +216,7 @@ dedup_and_flatten_args(PyObject* args) Py_ssize_t arg_length = PyTuple_GET_SIZE(args); PyObject *new_args = PyTuple_New(arg_length); if (new_args == NULL) { + Py_DECREF(args); return NULL; } // Add unique elements to an array. @@ -262,8 +226,8 @@ dedup_and_flatten_args(PyObject* args) PyObject* i_element = PyTuple_GET_ITEM(args, i); for (Py_ssize_t j = 0; j < added_items; j++) { PyObject* j_element = PyTuple_GET_ITEM(new_args, j); - int is_ga = PyObject_TypeCheck(i_element, &Py_GenericAliasType) && - PyObject_TypeCheck(j_element, &Py_GenericAliasType); + int is_ga = _PyGenericAlias_Check(i_element) && + _PyGenericAlias_Check(j_element); // RichCompare to also deduplicate GenericAlias types (slower) is_duplicate = is_ga ? PyObject_RichCompareBool(i_element, j_element, Py_EQ) : i_element == j_element; @@ -314,7 +278,7 @@ is_new_type(PyObject *obj) #define CHECK_RES(res) { \ int result = res; \ if (result) { \ - return result; \ + return result; \ } \ } @@ -322,28 +286,27 @@ is_new_type(PyObject *obj) static int is_unionable(PyObject *obj) { - if (obj == Py_None) { + if (obj == Py_None || + PyType_Check(obj) || + _PyGenericAlias_Check(obj) || + _PyUnion_Check(obj)) + { return 1; } - PyTypeObject *type = Py_TYPE(obj); CHECK_RES(is_typevar(obj)); CHECK_RES(is_new_type(obj)); CHECK_RES(is_special_form(obj)); - return ( - // The following checks never fail. - PyType_Check(obj) || - PyObject_TypeCheck(obj, &Py_GenericAliasType) || - type == &_Py_UnionType); + return 0; } PyObject * -_Py_union_type_or(PyObject* self, PyObject* param) +_Py_union_type_or(PyObject* self, PyObject* other) { - PyObject *tuple = PyTuple_Pack(2, self, param); + PyObject *tuple = PyTuple_Pack(2, self, other); if (tuple == NULL) { return NULL; } - PyObject *new_union = _Py_Union(tuple); + PyObject *new_union = make_union(tuple); Py_DECREF(tuple); return new_union; } @@ -471,7 +434,7 @@ union_getitem(PyObject *self, PyObject *item) return NULL; } - PyObject *res = _Py_Union(newargs); + PyObject *res = make_union(newargs); Py_DECREF(newargs); return res; @@ -504,7 +467,7 @@ static PyNumberMethods union_as_number = { .nb_or = _Py_union_type_or, // Add __or__ function }; -PyTypeObject _Py_UnionType = { +PyTypeObject _PyUnion_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) .tp_name = "types.Union", .tp_doc = "Represent a PEP 604 union type\n" @@ -527,8 +490,8 @@ PyTypeObject _Py_UnionType = { .tp_getset = union_properties, }; -PyObject * -_Py_Union(PyObject *args) +static PyObject * +make_union(PyObject *args) { assert(PyTuple_CheckExact(args)); @@ -538,16 +501,12 @@ _Py_Union(PyObject *args) Py_ssize_t nargs = PyTuple_GET_SIZE(args); for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { PyObject *arg = PyTuple_GET_ITEM(args, iarg); - if (arg == NULL) { - return NULL; - } int is_arg_unionable = is_unionable(arg); if (is_arg_unionable < 0) { return NULL; } if (!is_arg_unionable) { - Py_INCREF(Py_NotImplemented); - return Py_NotImplemented; + Py_RETURN_NOTIMPLEMENTED; } } @@ -562,7 +521,7 @@ _Py_Union(PyObject *args) return result1; } - result = PyObject_GC_New(unionobject, &_Py_UnionType); + result = PyObject_GC_New(unionobject, &_PyUnion_Type); if (result == NULL) { Py_DECREF(args); return NULL; -- cgit v0.12