summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEthan Furman <ethan@stoneleaf.us>2022-07-18 01:51:04 (GMT)
committerGitHub <noreply@github.com>2022-07-18 01:51:04 (GMT)
commitc961d14f85a0e3e53d5ad1182206ef34030f10b8 (patch)
tree9d96259939db36f0c1d3007fb1e8bb5f425e6165
parent07aeb7405ea42729b95ecae225f1d96a4aea5121 (diff)
downloadcpython-c961d14f85a0e3e53d5ad1182206ef34030f10b8.zip
cpython-c961d14f85a0e3e53d5ad1182206ef34030f10b8.tar.gz
cpython-c961d14f85a0e3e53d5ad1182206ef34030f10b8.tar.bz2
gh-94601: [Enum] fix inheritance for __str__ and friends (GH-94942)
-rw-r--r--Lib/enum.py26
-rw-r--r--Lib/test/test_enum.py26
2 files changed, 42 insertions, 10 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index a4f1f09..80945c1 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -247,7 +247,10 @@ class _proto_member:
if not enum_class._use_args_:
enum_member = enum_class._new_member_(enum_class)
if not hasattr(enum_member, '_value_'):
- enum_member._value_ = value
+ try:
+ enum_member._value_ = enum_class._member_type_(*args)
+ except Exception as exc:
+ enum_member._value_ = value
else:
enum_member = enum_class._new_member_(enum_class, *args)
if not hasattr(enum_member, '_value_'):
@@ -562,7 +565,13 @@ class EnumType(type):
classdict['__str__'] = enum_class.__str__
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
if name not in classdict:
- setattr(enum_class, name, getattr(first_enum, name))
+ # check for mixin overrides before replacing
+ enum_method = getattr(first_enum, name)
+ found_method = getattr(enum_class, name)
+ object_method = getattr(object, name)
+ data_type_method = getattr(member_type, name)
+ if found_method in (data_type_method, object_method):
+ setattr(enum_class, name, enum_method)
#
# for Flag, add __or__, __and__, __xor__, and __invert__
if Flag is not None and issubclass(enum_class, Flag):
@@ -937,16 +946,18 @@ class EnumType(type):
@classmethod
def _find_data_type_(mcls, class_name, bases):
data_types = set()
+ base_chain = set()
for chain in bases:
candidate = None
for base in chain.__mro__:
+ base_chain.add(base)
if base is object:
continue
elif issubclass(base, Enum):
if base._member_type_ is not object:
data_types.add(base._member_type_)
break
- elif '__new__' in base.__dict__:
+ elif '__new__' in base.__dict__ or '__init__' in base.__dict__:
if issubclass(base, Enum):
continue
data_types.add(candidate or base)
@@ -1658,7 +1669,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
enum_class = type(cls_name, (etype, ), body, boundary=boundary, _simple=True)
for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'):
if name not in body:
- setattr(enum_class, name, getattr(etype, name))
+ # check for mixin overrides before replacing
+ enum_method = getattr(etype, name)
+ found_method = getattr(enum_class, name)
+ object_method = getattr(object, name)
+ data_type_method = getattr(member_type, name)
+ if found_method in (data_type_method, object_method):
+ setattr(enum_class, name, enum_method)
gnv_last_values = []
if issubclass(enum_class, Flag):
# Flag / IntFlag
@@ -1989,7 +2006,6 @@ def _old_convert_(etype, name, module, filter, source=None, *, boundary=None):
members.sort(key=lambda t: t[0])
cls = etype(name, members, module=module, boundary=boundary or KEEP)
cls.__reduce_ex__ = _reduce_ex_by_global_name
- cls.__repr__ = global_enum_repr
return cls
_stdlib_enums = IntEnum, StrEnum, IntFlag
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 87d7c72..69fba9a 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -2693,12 +2693,15 @@ class TestSpecial(unittest.TestCase):
@dataclass
class Foo:
__qualname__ = 'Foo'
- a: int = 0
+ a: int
class Entries(Foo, Enum):
- ENTRY1 = Foo(1)
+ ENTRY1 = 1
+ self.assertTrue(isinstance(Entries.ENTRY1, Foo))
+ self.assertTrue(Entries._member_type_ is Foo, Entries._member_type_)
+ self.assertTrue(Entries.ENTRY1.value == Foo(1), Entries.ENTRY1.value)
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
- def test_repr_with_non_data_type_mixin(self):
+ def test_repr_with_init_data_type_mixin(self):
# non-data_type is a mixin that doesn't define __new__
class Foo:
def __init__(self, a):
@@ -2706,10 +2709,23 @@ class TestSpecial(unittest.TestCase):
def __repr__(self):
return f'Foo(a={self.a!r})'
class Entries(Foo, Enum):
- ENTRY1 = Foo(1)
-
+ ENTRY1 = 1
+ #
self.assertEqual(repr(Entries.ENTRY1), '<Entries.ENTRY1: Foo(a=1)>')
+ def test_repr_and_str_with_non_data_type_mixin(self):
+ # non-data_type is a mixin that doesn't define __new__
+ class Foo:
+ def __repr__(self):
+ return 'Foo'
+ def __str__(self):
+ return 'ooF'
+ class Entries(Foo, Enum):
+ ENTRY1 = 1
+ #
+ self.assertEqual(repr(Entries.ENTRY1), 'Foo')
+ self.assertEqual(str(Entries.ENTRY1), 'ooF')
+
def test_value_backup_assign(self):
# check that enum will add missing values when custom __new__ does not
class Some(Enum):