From e747562345238cefd68ef6581feb17707a3a06ff Mon Sep 17 00:00:00 2001 From: "Miss Islington (bot)" <31488909+miss-islington@users.noreply.github.com> Date: Sun, 17 Jul 2022 19:18:41 -0700 Subject: gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942) (cherry picked from commit c961d14f85a0e3e53d5ad1182206ef34030f10b8) Co-authored-by: Ethan Furman --- Lib/enum.py | 26 +++++++++++++++++++++----- Lib/test/test_enum.py | 26 +++++++++++++++++++++----- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/Lib/enum.py b/Lib/enum.py index b19d40c..f5c29ed 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -247,7 +247,10 @@ class _proto_member: if not enum_class._use_args_: enum_member = enum_class._new_member_(enum_class) if not hasattr(enum_member, '_value_'): - enum_member._value_ = value + try: + enum_member._value_ = enum_class._member_type_(*args) + except Exception as exc: + enum_member._value_ = value else: enum_member = enum_class._new_member_(enum_class, *args) if not hasattr(enum_member, '_value_'): @@ -562,7 +565,13 @@ class EnumType(type): classdict['__str__'] = enum_class.__str__ for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): if name not in classdict: - setattr(enum_class, name, getattr(first_enum, name)) + # check for mixin overrides before replacing + enum_method = getattr(first_enum, name) + found_method = getattr(enum_class, name) + object_method = getattr(object, name) + data_type_method = getattr(member_type, name) + if found_method in (data_type_method, object_method): + setattr(enum_class, name, enum_method) # # for Flag, add __or__, __and__, __xor__, and __invert__ if Flag is not None and issubclass(enum_class, Flag): @@ -950,16 +959,18 @@ class EnumType(type): @classmethod def _find_data_type_(mcls, class_name, bases): data_types = set() + base_chain = set() for chain in bases: candidate = None for base in chain.__mro__: + base_chain.add(base) if base is object: continue elif issubclass(base, Enum): if base._member_type_ is not object: data_types.add(base._member_type_) break - elif '__new__' in base.__dict__: + elif '__new__' in base.__dict__ or '__init__' in base.__dict__: if issubclass(base, Enum): continue data_types.add(candidate or base) @@ -1671,7 +1682,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True) for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): if name not in body: - setattr(enum_class, name, getattr(etype, name)) + # check for mixin overrides before replacing + enum_method = getattr(etype, name) + found_method = getattr(enum_class, name) + object_method = getattr(object, name) + data_type_method = getattr(member_type, name) + if found_method in (data_type_method, object_method): + setattr(enum_class, name, enum_method) gnv_last_values = [] if issubclass(enum_class, Flag): # Flag / IntFlag @@ -2002,7 +2019,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None): members.sort(key=lambda t: t[0]) cls = etype(name, members, module=module, boundary=boundary or KEEP) cls.__reduce_ex__ = _reduce_ex_by_global_name - cls.__repr__ = global_enum_repr return cls _stdlib_enums = IntEnum, StrEnum, IntFlag diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 74f31be..80834f2 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -2658,12 +2658,15 @@ class TestSpecial(unittest.TestCase): @dataclass class Foo: __qualname__ = 'Foo' - a: int = 0 + a: int class Entries(Foo, Enum): - ENTRY1 = Foo(1) + ENTRY1 = 1 + self.assertTrue(isinstance(Entries.ENTRY1, Foo)) + self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_) + self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value) self.assertEqual(repr(Entries.ENTRY1), '') - def test_repr_with_non_data_type_mixin(self): + def test_repr_with_init_data_type_mixin(self): # non-data_type is a mixin that doesn't define __new__ class Foo: def __init__(self, a): @@ -2671,10 +2674,23 @@ class TestSpecial(unittest.TestCase): def __repr__(self): return f'Foo(a={self.a!r})' class Entries(Foo, Enum): - ENTRY1 = Foo(1) - + ENTRY1 = 1 + # self.assertEqual(repr(Entries.ENTRY1), '') + def test_repr_and_str_with_non_data_type_mixin(self): + # non-data_type is a mixin that doesn't define __new__ + class Foo: + def __repr__(self): + return 'Foo' + def __str__(self): + return 'ooF' + class Entries(Foo, Enum): + ENTRY1 = 1 + # + self.assertEqual(repr(Entries.ENTRY1), 'Foo') + self.assertEqual(str(Entries.ENTRY1), 'ooF') + def test_value_backup_assign(self): # check that enum will add missing values when custom __new__ does not class Some(Enum): -- cgit v0.12