diff options
author | Ethan Furman <ethan@stoneleaf.us> | 2024-10-22 18:04:00 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-22 18:04:00 (GMT) |
commit | aaed91cabcedc16c089c4b1c9abb1114659a83d3 (patch) | |
tree | 3ca58dbb680453ade52a9f4af49245bfeb4ec759 /Lib/enum.py | |
parent | 079875e39589eb0628b5883f7ffa387e7476ec06 (diff) | |
download | cpython-aaed91cabcedc16c089c4b1c9abb1114659a83d3.zip cpython-aaed91cabcedc16c089c4b1c9abb1114659a83d3.tar.gz cpython-aaed91cabcedc16c089c4b1c9abb1114659a83d3.tar.bz2 |
gh-125710: [Enum] fix hashable<->nonhashable comparisons for member values (GH-125735)
Diffstat (limited to 'Lib/enum.py')
-rw-r--r-- | Lib/enum.py | 26 |
1 files changed, 20 insertions, 6 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 17d7273..4f99122 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -327,6 +327,8 @@ class _proto_member: # to the map, and by-value lookups for this value will be # linear. enum_class._value2member_map_.setdefault(value, enum_member) + if value not in enum_class._hashable_values_: + enum_class._hashable_values_.append(value) except TypeError: # keep track of the value in a list so containment checks are quick enum_class._unhashable_values_.append(value) @@ -538,7 +540,8 @@ class EnumType(type): classdict['_member_names_'] = [] classdict['_member_map_'] = {} classdict['_value2member_map_'] = {} - classdict['_unhashable_values_'] = [] + classdict['_hashable_values_'] = [] # for comparing with non-hashable types + classdict['_unhashable_values_'] = [] # e.g. frozenset() with set() classdict['_unhashable_values_map_'] = {} classdict['_member_type_'] = member_type # now set the __repr__ for the value @@ -748,7 +751,10 @@ class EnumType(type): try: return value in cls._value2member_map_ except TypeError: - return value in cls._unhashable_values_ + return ( + value in cls._unhashable_values_ # both structures are lists + or value in cls._hashable_values_ + ) def __delattr__(cls, attr): # nicer error message when someone tries to delete an attribute @@ -1166,8 +1172,11 @@ class Enum(metaclass=EnumType): pass except TypeError: # not there, now do long search -- O(n) behavior - for name, values in cls._unhashable_values_map_.items(): - if value in values: + for name, unhashable_values in cls._unhashable_values_map_.items(): + if value in unhashable_values: + return cls[name] + for name, member in cls._member_map_.items(): + if value == member._value_: return cls[name] # still not found -- verify that members exist, in-case somebody got here mistakenly # (such as via super when trying to override __new__) @@ -1233,6 +1242,7 @@ class Enum(metaclass=EnumType): # to the map, and by-value lookups for this value will be # linear. cls._value2member_map_.setdefault(value, self) + cls._hashable_values_.append(value) except TypeError: # keep track of the value in a list so containment checks are quick cls._unhashable_values_.append(value) @@ -1763,6 +1773,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): body['_member_names_'] = member_names = [] body['_member_map_'] = member_map = {} body['_value2member_map_'] = value2member_map = {} + body['_hashable_values_'] = hashable_values = [] body['_unhashable_values_'] = unhashable_values = [] body['_unhashable_values_map_'] = {} body['_member_type_'] = member_type = etype._member_type_ @@ -1826,7 +1837,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): contained = value2member_map.get(member._value_) except TypeError: contained = None - if member._value_ in unhashable_values: + if member._value_ in unhashable_values or member.value in hashable_values: for m in enum_class: if m._value_ == member._value_: contained = m @@ -1846,6 +1857,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): else: enum_class._add_member_(name, member) value2member_map[value] = member + hashable_values.append(value) if _is_single_bit(value): # not a multi-bit alias, record in _member_names_ and _flag_mask_ member_names.append(name) @@ -1882,7 +1894,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): contained = value2member_map.get(member._value_) except TypeError: contained = None - if member._value_ in unhashable_values: + if member._value_ in unhashable_values or member._value_ in hashable_values: for m in enum_class: if m._value_ == member._value_: contained = m @@ -1908,6 +1920,8 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): # to the map, and by-value lookups for this value will be # linear. enum_class._value2member_map_.setdefault(value, member) + if value not in hashable_values: + hashable_values.append(value) except TypeError: # keep track of the value in a list so containment checks are quick enum_class._unhashable_values_.append(value) |