summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/enum.py24
-rw-r--r--Lib/test/test_enum.py18
2 files changed, 31 insertions, 11 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index 38d95c5..787945a 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -1,5 +1,3 @@
-"""Python Enumerations"""
-
import sys
from collections import OrderedDict
from types import MappingProxyType
@@ -154,11 +152,13 @@ class EnumMeta(type):
args = (args, ) # wrap it one more time
if not use_args:
enum_member = __new__(enum_class)
- enum_member._value = value
+ original_value = value
else:
enum_member = __new__(enum_class, *args)
- if not hasattr(enum_member, '_value'):
- enum_member._value = member_type(*args)
+ original_value = member_type(*args)
+ if not hasattr(enum_member, '_value'):
+ enum_member._value = original_value
+ value = enum_member._value
enum_member._member_type = member_type
enum_member._name = member_name
enum_member.__init__(*args)
@@ -416,12 +416,14 @@ class Enum(metaclass=EnumMeta):
return value
# by-value search for a matching enum member
# see if it's in the reverse mapping (for hashable values)
- if value in cls._value2member_map:
- return cls._value2member_map[value]
- # not there, now do long search -- O(n) behavior
- for member in cls._member_map.values():
- if member.value == value:
- return member
+ try:
+ if value in cls._value2member_map:
+ return cls._value2member_map[value]
+ except TypeError:
+ # not there, now do long search -- O(n) behavior
+ for member in cls._member_map.values():
+ if member.value == value:
+ return member
raise ValueError("%s is not a valid %s" % (value, cls.__name__))
def __repr__(self):
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 2b87c56..c947182 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -694,6 +694,7 @@ class TestEnum(unittest.TestCase):
x = ('the-x', 1)
y = ('the-y', 2)
+
self.assertIs(NEI.__new__, Enum.__new__)
self.assertEqual(repr(NEI.x + NEI.y), "NamedInt('(the-x + the-y)', 3)")
globals()['NamedInt'] = NamedInt
@@ -785,6 +786,7 @@ class TestEnum(unittest.TestCase):
[AutoNumber.first, AutoNumber.second, AutoNumber.third],
)
self.assertEqual(int(AutoNumber.second), 2)
+ self.assertEqual(AutoNumber.third.value, 3)
self.assertIs(AutoNumber(1), AutoNumber.first)
def test_inherited_new_from_enhanced_enum(self):
@@ -916,6 +918,22 @@ class TestEnum(unittest.TestCase):
self.assertEqual(round(Planet.EARTH.surface_gravity, 2), 9.80)
self.assertEqual(Planet.EARTH.value, (5.976e+24, 6.37814e6))
+ def test_nonhash_value(self):
+ class AutoNumberInAList(Enum):
+ def __new__(cls):
+ value = [len(cls.__members__) + 1]
+ obj = object.__new__(cls)
+ obj._value = value
+ return obj
+ class ColorInAList(AutoNumberInAList):
+ red = ()
+ green = ()
+ blue = ()
+ self.assertEqual(list(ColorInAList), [ColorInAList.red, ColorInAList.green, ColorInAList.blue])
+ self.assertEqual(ColorInAList.red.value, [1])
+ self.assertEqual(ColorInAList([1]), ColorInAList.red)
+
+
class TestUnique(unittest.TestCase):