summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2024-03-01 18:01:27 (GMT)
committerGitHub <noreply@github.com>2024-03-01 18:01:27 (GMT)
commit90f75e1069f2d692480bcd305fc35b4fe7847e18 (patch)
tree7699f462cf551869091f80c8a2e5a7c26516607b /Lib
parent16be4a3b93ca315d6d95a5e5dd17c81d03ed578a (diff)
downloadcpython-90f75e1069f2d692480bcd305fc35b4fe7847e18.zip
cpython-90f75e1069f2d692480bcd305fc35b4fe7847e18.tar.gz
cpython-90f75e1069f2d692480bcd305fc35b4fe7847e18.tar.bz2
[3.12] gh-112281: Allow `Union` with unhashable `Annotated` metadata (GH-112283) (#116213)
Co-authored-by: Nikita Sobolev <mail@sobolevn.me> Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_types.py20
-rw-r--r--Lib/test/test_typing.py107
-rw-r--r--Lib/typing.py45
3 files changed, 154 insertions, 18 deletions
diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py
index b86392f4..5ffe408 100644
--- a/Lib/test/test_types.py
+++ b/Lib/test/test_types.py
@@ -709,6 +709,26 @@ class UnionTests(unittest.TestCase):
self.assertEqual(hash(int | str), hash(str | int))
self.assertEqual(hash(int | str), hash(typing.Union[int, str]))
+ def test_union_of_unhashable(self):
+ class UnhashableMeta(type):
+ __hash__ = None
+
+ class A(metaclass=UnhashableMeta): ...
+ class B(metaclass=UnhashableMeta): ...
+
+ self.assertEqual((A | B).__args__, (A, B))
+ union1 = A | B
+ with self.assertRaises(TypeError):
+ hash(union1)
+
+ union2 = int | B
+ with self.assertRaises(TypeError):
+ hash(union2)
+
+ union3 = A | int
+ with self.assertRaises(TypeError):
+ hash(union3)
+
def test_instancecheck_and_subclasscheck(self):
for x in (int | str, typing.Union[int, str]):
with self.subTest(x=x):
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index 7f9c10d..e0f7146 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -2,10 +2,11 @@ import contextlib
import collections
import collections.abc
from collections import defaultdict
-from functools import lru_cache, wraps
+from functools import lru_cache, wraps, reduce
import gc
import inspect
import itertools
+import operator
import pickle
import re
import sys
@@ -1770,6 +1771,26 @@ class UnionTests(BaseTestCase):
v = Union[u, Employee]
self.assertEqual(v, Union[int, float, Employee])
+ def test_union_of_unhashable(self):
+ class UnhashableMeta(type):
+ __hash__ = None
+
+ class A(metaclass=UnhashableMeta): ...
+ class B(metaclass=UnhashableMeta): ...
+
+ self.assertEqual(Union[A, B].__args__, (A, B))
+ union1 = Union[A, B]
+ with self.assertRaises(TypeError):
+ hash(union1)
+
+ union2 = Union[int, B]
+ with self.assertRaises(TypeError):
+ hash(union2)
+
+ union3 = Union[A, int]
+ with self.assertRaises(TypeError):
+ hash(union3)
+
def test_repr(self):
self.assertEqual(repr(Union), 'typing.Union')
u = Union[Employee, int]
@@ -5295,10 +5316,8 @@ class OverrideDecoratorTests(BaseTestCase):
self.assertFalse(hasattr(WithOverride.some, "__override__"))
def test_multiple_decorators(self):
- import functools
-
def with_wraps(f): # similar to `lru_cache` definition
- @functools.wraps(f)
+ @wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper
@@ -8183,6 +8202,76 @@ class AnnotatedTests(BaseTestCase):
self.assertEqual(A.__metadata__, (4, 5))
self.assertEqual(A.__origin__, int)
+ def test_deduplicate_from_union(self):
+ # Regular:
+ self.assertEqual(get_args(Annotated[int, 1] | int),
+ (Annotated[int, 1], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], int]),
+ (Annotated[int, 1], int))
+ self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
+ (Annotated[int, 1], Annotated[int, 2], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
+ (Annotated[int, 1], Annotated[int, 2], int))
+ self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
+ (Annotated[int, 1], Annotated[str, 1], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
+ (Annotated[int, 1], Annotated[str, 1], int))
+
+ # Duplicates:
+ self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
+ Annotated[int, 1] | int)
+ self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
+ Union[Annotated[int, 1], int])
+
+ # Unhashable metadata:
+ self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
+ (str, Annotated[int, {}], Annotated[int, set()], int))
+ self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
+ (str, Annotated[int, {}], Annotated[int, set()], int))
+ self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
+ (str, Annotated[int, {}], Annotated[str, {}], int))
+ self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
+ (str, Annotated[int, {}], Annotated[str, {}], int))
+
+ self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
+ (Annotated[int, 1], str, Annotated[str, {}], int))
+ self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
+ (Annotated[int, 1], str, Annotated[str, {}], int))
+
+ import dataclasses
+ @dataclasses.dataclass
+ class ValueRange:
+ lo: int
+ hi: int
+ v = ValueRange(1, 2)
+ self.assertEqual(get_args(Annotated[int, v] | None),
+ (Annotated[int, v], types.NoneType))
+ self.assertEqual(get_args(Union[Annotated[int, v], None]),
+ (Annotated[int, v], types.NoneType))
+ self.assertEqual(get_args(Optional[Annotated[int, v]]),
+ (Annotated[int, v], types.NoneType))
+
+ # Unhashable metadata duplicated:
+ self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
+ Annotated[int, {}] | int)
+ self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
+ int | Annotated[int, {}])
+ self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
+ Union[Annotated[int, {}], int])
+ self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
+ Union[int, Annotated[int, {}]])
+
+ def test_order_in_union(self):
+ expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
+ for args in itertools.permutations(get_args(expr1)):
+ with self.subTest(args=args):
+ self.assertEqual(expr1, reduce(operator.or_, args))
+
+ expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
+ for args in itertools.permutations(get_args(expr2)):
+ with self.subTest(args=args):
+ self.assertEqual(expr2, Union[args])
+
def test_specialize(self):
L = Annotated[List[T], "my decoration"]
LI = Annotated[List[int], "my decoration"]
@@ -8203,6 +8292,16 @@ class AnnotatedTests(BaseTestCase):
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
)
+ # Unhashable `metadata` raises `TypeError`:
+ a1 = Annotated[int, []]
+ with self.assertRaises(TypeError):
+ hash(a1)
+
+ class A:
+ __hash__ = None
+ a2 = Annotated[int, A()]
+ with self.assertRaises(TypeError):
+ hash(a2)
def test_instantiate(self):
class C:
diff --git a/Lib/typing.py b/Lib/typing.py
index 1e4c725..7581c16 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -314,19 +314,33 @@ def _unpack_args(args):
newargs.append(arg)
return newargs
-def _deduplicate(params):
+def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
- all_params = set(params)
- if len(all_params) < len(params):
- new_params = []
- for t in params:
- if t in all_params:
- new_params.append(t)
- all_params.remove(t)
- params = new_params
- assert not all_params, all_params
- return params
-
+ try:
+ return dict.fromkeys(params)
+ except TypeError:
+ if not unhashable_fallback:
+ raise
+ # Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
+ return _deduplicate_unhashable(params)
+
+def _deduplicate_unhashable(unhashable_params):
+ new_unhashable = []
+ for t in unhashable_params:
+ if t not in new_unhashable:
+ new_unhashable.append(t)
+ return new_unhashable
+
+def _compare_args_orderless(first_args, second_args):
+ first_unhashable = _deduplicate_unhashable(first_args)
+ second_unhashable = _deduplicate_unhashable(second_args)
+ t = list(second_unhashable)
+ try:
+ for elem in first_unhashable:
+ t.remove(elem)
+ except ValueError:
+ return False
+ return not t
def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
@@ -341,7 +355,7 @@ def _remove_dups_flatten(parameters):
else:
params.append(p)
- return tuple(_deduplicate(params))
+ return tuple(_deduplicate(params, unhashable_fallback=True))
def _flatten_literal_params(parameters):
@@ -1548,7 +1562,10 @@ class _UnionGenericAlias(_NotIterable, _GenericAlias, _root=True):
def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
return NotImplemented
- return set(self.__args__) == set(other.__args__)
+ try: # fast path
+ return set(self.__args__) == set(other.__args__)
+ except TypeError: # not hashable, slow path
+ return _compare_args_orderless(self.__args__, other.__args__)
def __hash__(self):
return hash(frozenset(self.__args__))