summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_typing.py11
-rw-r--r--Lib/typing.py26
2 files changed, 34 insertions, 3 deletions
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py
index c074e7a..08f7d02 100644
--- a/Lib/test/test_typing.py
+++ b/Lib/test/test_typing.py
@@ -4876,6 +4876,17 @@ class GetTypeHintTests(BaseTestCase):
'a': Annotated[Required[int], "a", "b", "c"]
})
+ def test_get_type_hints_collections_abc_callable(self):
+ # https://github.com/python/cpython/issues/91621
+ P = ParamSpec('P')
+ def f(x: collections.abc.Callable[[int], int]): ...
+ def g(x: collections.abc.Callable[..., int]): ...
+ def h(x: collections.abc.Callable[P, int]): ...
+
+ self.assertEqual(get_type_hints(f), {'x': collections.abc.Callable[[int], int]})
+ self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]})
+ self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]})
+
class GetUtilitiesTestCase(TestCase):
def test_get_origin(self):
diff --git a/Lib/typing.py b/Lib/typing.py
index 29a3f43..84f0fd1 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -203,6 +203,24 @@ def _is_param_expr(arg):
(tuple, list, ParamSpec, _ConcatenateGenericAlias))
+def _should_unflatten_callable_args(typ, args):
+ """Internal helper for munging collections.abc.Callable's __args__.
+
+ The canonical representation for a Callable's __args__ flattens the
+ argument types, see https://bugs.python.org/issue42195. For example:
+
+ collections.abc.Callable[[int, int], str].__args__ == (int, int, str)
+ collections.abc.Callable[ParamSpec, str].__args__ == (ParamSpec, str)
+
+ As a result, if we need to reconstruct the Callable from its __args__,
+ we need to unflatten it.
+ """
+ return (
+ typ.__origin__ is collections.abc.Callable
+ and not (len(args) == 2 and _is_param_expr(args[0]))
+ )
+
+
def _type_repr(obj):
"""Return the repr() of an object, special-casing types (internal helper).
@@ -351,7 +369,10 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
ForwardRef(arg) if isinstance(arg, str) else arg
for arg in t.__args__
)
- t = t.__origin__[args]
+ if _should_unflatten_callable_args(t, args):
+ t = t.__origin__[(args[:-1], args[-1])]
+ else:
+ t = t.__origin__[args]
ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
if ev_args == t.__args__:
return t
@@ -2361,8 +2382,7 @@ def get_args(tp):
return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, (_GenericAlias, GenericAlias)):
res = tp.__args__
- if (tp.__origin__ is collections.abc.Callable
- and not (len(res) == 2 and _is_param_expr(res[0]))):
+ if _should_unflatten_callable_args(tp, res):
res = (list(res[:-1]), res[-1])
return res
if isinstance(tp, types.UnionType):