summaryrefslogtreecommitdiffstats
path: root/Lib/enum.py
diff options
context:
space:
mode:
authorEthan Furman <ethan@stoneleaf.us>2024-10-22 18:04:00 (GMT)
committerGitHub <noreply@github.com>2024-10-22 18:04:00 (GMT)
commitaaed91cabcedc16c089c4b1c9abb1114659a83d3 (patch)
tree3ca58dbb680453ade52a9f4af49245bfeb4ec759 /Lib/enum.py
parent079875e39589eb0628b5883f7ffa387e7476ec06 (diff)
downloadcpython-aaed91cabcedc16c089c4b1c9abb1114659a83d3.zip
cpython-aaed91cabcedc16c089c4b1c9abb1114659a83d3.tar.gz
cpython-aaed91cabcedc16c089c4b1c9abb1114659a83d3.tar.bz2
gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735)
Diffstat (limited to 'Lib/enum.py')
-rw-r--r--Lib/enum.py26
1 files changed, 20 insertions, 6 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index 17d7273..4f99122 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -327,6 +327,8 @@ class _proto_member:
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_.setdefault(value, enum_member)
+ if value not in enum_class._hashable_values_:
+ enum_class._hashable_values_.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
enum_class._unhashable_values_.append(value)
@@ -538,7 +540,8 @@ class EnumType(type):
classdict['_member_names_'] = []
classdict['_member_map_'] = {}
classdict['_value2member_map_'] = {}
- classdict['_unhashable_values_'] = []
+ classdict['_hashable_values_'] = [] # for comparing with non-hashable types
+ classdict['_unhashable_values_'] = [] # e.g. frozenset() with set()
classdict['_unhashable_values_map_'] = {}
classdict['_member_type_'] = member_type
# now set the __repr__ for the value
@@ -748,7 +751,10 @@ class EnumType(type):
try:
return value in cls._value2member_map_
except TypeError:
- return value in cls._unhashable_values_
+ return (
+ value in cls._unhashable_values_ # both structures are lists
+ or value in cls._hashable_values_
+ )
def __delattr__(cls, attr):
# nicer error message when someone tries to delete an attribute
@@ -1166,8 +1172,11 @@ class Enum(metaclass=EnumType):
pass
except TypeError:
# not there, now do long search -- O(n) behavior
- for name, values in cls._unhashable_values_map_.items():
- if value in values:
+ for name, unhashable_values in cls._unhashable_values_map_.items():
+ if value in unhashable_values:
+ return cls[name]
+ for name, member in cls._member_map_.items():
+ if value == member._value_:
return cls[name]
# still not found -- verify that members exist, in-case somebody got here mistakenly
# (such as via super when trying to override __new__)
@@ -1233,6 +1242,7 @@ class Enum(metaclass=EnumType):
# to the map, and by-value lookups for this value will be
# linear.
cls._value2member_map_.setdefault(value, self)
+ cls._hashable_values_.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
cls._unhashable_values_.append(value)
@@ -1763,6 +1773,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
body['_member_names_'] = member_names = []
body['_member_map_'] = member_map = {}
body['_value2member_map_'] = value2member_map = {}
+ body['_hashable_values_'] = hashable_values = []
body['_unhashable_values_'] = unhashable_values = []
body['_unhashable_values_map_'] = {}
body['_member_type_'] = member_type = etype._member_type_
@@ -1826,7 +1837,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
contained = value2member_map.get(member._value_)
except TypeError:
contained = None
- if member._value_ in unhashable_values:
+ if member._value_ in unhashable_values or member.value in hashable_values:
for m in enum_class:
if m._value_ == member._value_:
contained = m
@@ -1846,6 +1857,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
else:
enum_class._add_member_(name, member)
value2member_map[value] = member
+ hashable_values.append(value)
if _is_single_bit(value):
# not a multi-bit alias, record in _member_names_ and _flag_mask_
member_names.append(name)
@@ -1882,7 +1894,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
contained = value2member_map.get(member._value_)
except TypeError:
contained = None
- if member._value_ in unhashable_values:
+ if member._value_ in unhashable_values or member._value_ in hashable_values:
for m in enum_class:
if m._value_ == member._value_:
contained = m
@@ -1908,6 +1920,8 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
# to the map, and by-value lookups for this value will be
# linear.
enum_class._value2member_map_.setdefault(value, member)
+ if value not in hashable_values:
+ hashable_values.append(value)
except TypeError:
# keep track of the value in a list so containment checks are quick
enum_class._unhashable_values_.append(value)