diff options
-rw-r--r-- | Lib/enum.py | 24 | ||||
-rw-r--r-- | Lib/test/test_enum.py | 18 |
2 files changed, 31 insertions, 11 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 38d95c5..787945a 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1,5 +1,3 @@ -"""Python Enumerations""" - import sys from collections import OrderedDict from types import MappingProxyType @@ -154,11 +152,13 @@ class EnumMeta(type): args = (args, ) # wrap it one more time if not use_args: enum_member = __new__(enum_class) - enum_member._value = value + original_value = value else: enum_member = __new__(enum_class, *args) - if not hasattr(enum_member, '_value'): - enum_member._value = member_type(*args) + original_value = member_type(*args) + if not hasattr(enum_member, '_value'): + enum_member._value = original_value + value = enum_member._value enum_member._member_type = member_type enum_member._name = member_name enum_member.__init__(*args) @@ -416,12 +416,14 @@ class Enum(metaclass=EnumMeta): return value # by-value search for a matching enum member # see if it's in the reverse mapping (for hashable values) - if value in cls._value2member_map: - return cls._value2member_map[value] - # not there, now do long search -- O(n) behavior - for member in cls._member_map.values(): - if member.value == value: - return member + try: + if value in cls._value2member_map: + return cls._value2member_map[value] + except TypeError: + # not there, now do long search -- O(n) behavior + for member in cls._member_map.values(): + if member.value == value: + return member raise ValueError("%s is not a valid %s" % (value, cls.__name__)) def __repr__(self): diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 2b87c56..c947182 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -694,6 +694,7 @@ class TestEnum(unittest.TestCase): x = ('the-x', 1) y = ('the-y', 2) + self.assertIs(NEI.__new__, Enum.__new__) self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)") globals()['NamedInt'] = NamedInt @@ -785,6 +786,7 @@ class TestEnum(unittest.TestCase): [AutoNumber.first, AutoNumber.second, AutoNumber.third], ) self.assertEqual(int(AutoNumber.second), 2) + self.assertEqual(AutoNumber.third.value, 3) self.assertIs(AutoNumber(1), AutoNumber.first) def test_inherited_new_from_enhanced_enum(self): @@ -916,6 +918,22 @@ class TestEnum(unittest.TestCase): self.assertEqual(round(Planet.EARTH.surface_gravity, 2), 9.80) self.assertEqual(Planet.EARTH.value, (5.976e+24, 6.37814e6)) + def test_nonhash_value(self): + class AutoNumberInAList(Enum): + def __new__(cls): + value = [len(cls.__members__) + 1] + obj = object.__new__(cls) + obj._value = value + return obj + class ColorInAList(AutoNumberInAList): + red = () + green = () + blue = () + self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue]) + self.assertEqual(ColorInAList.red.value, [1]) + self.assertEqual(ColorInAList([1]), ColorInAList.red) + + class TestUnique(unittest.TestCase): |