diff options
author | Ethan Furman <ethan@stoneleaf.us> | 2013-07-25 20:50:45 (GMT) |
---|---|---|
committer | Ethan Furman <ethan@stoneleaf.us> | 2013-07-25 20:50:45 (GMT) |
commit | b41803e3ef65624ce39a4bcf4caff6ff1184699f (patch) | |
tree | 8c3380a6dbce75d34e6058087881f9b9e7dd5c33 | |
parent | 4d35e75ca069b51ffdac7b34dad4ffb77e72a598 (diff) | |
download | cpython-b41803e3ef65624ce39a4bcf4caff6ff1184699f.zip cpython-b41803e3ef65624ce39a4bcf4caff6ff1184699f.tar.gz cpython-b41803e3ef65624ce39a4bcf4caff6ff1184699f.tar.bz2 |
Close #18545: now only executes member_type if no _value_ is assigned in __new__.
-rw-r--r-- | Lib/enum.py | 8 | ||||
-rw-r--r-- | Lib/test/test_enum.py | 16 |
2 files changed, 20 insertions, 4 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 0def138..33af042 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -152,12 +152,12 @@ class EnumMeta(type): args = (args, ) # wrap it one more time if not use_args: enum_member = __new__(enum_class) - original_value = value + if not hasattr(enum_member, '_value_'): + enum_member._value_ = value else: enum_member = __new__(enum_class, *args) - original_value = member_type(*args) - if not hasattr(enum_member, '_value_'): - enum_member._value_ = original_value + if not hasattr(enum_member, '_value_'): + enum_member._value_ = member_type(*args) value = enum_member._value_ enum_member._member_type_ = member_type enum_member._name_ = member_name diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index d0b4a1c..91c4b69 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -934,6 +934,22 @@ class TestEnum(unittest.TestCase): self.assertEqual(ColorInAList.red.value, [1]) self.assertEqual(ColorInAList([1]), ColorInAList.red) + def test_conflicting_types_resolved_in_new(self): + class LabelledIntEnum(int, Enum): + def __new__(cls, *args): + value, label = args + obj = int.__new__(cls, value) + obj.label = label + obj._value_ = value + return obj + + class LabelledList(LabelledIntEnum): + unprocessed = (1, "Unprocessed") + payment_complete = (2, "Payment Complete") + + self.assertEqual(list(LabelledList), [LabelledList.unprocessed, LabelledList.payment_complete]) + self.assertEqual(LabelledList.unprocessed, 1) + self.assertEqual(LabelledList(1), LabelledList.unprocessed) class TestUnique(unittest.TestCase): |