diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_typing.py | 11 | ||||
-rw-r--r-- | Lib/typing.py | 26 |
2 files changed, 34 insertions, 3 deletions
diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index c074e7a..08f7d02 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -4876,6 +4876,17 @@ class GetTypeHintTests(BaseTestCase): 'a': Annotated[Required[int], "a", "b", "c"] }) + def test_get_type_hints_collections_abc_callable(self): + # https://github.com/python/cpython/issues/91621 + P = ParamSpec('P') + def f(x: collections.abc.Callable[[int], int]): ... + def g(x: collections.abc.Callable[..., int]): ... + def h(x: collections.abc.Callable[P, int]): ... + + self.assertEqual(get_type_hints(f), {'x': collections.abc.Callable[[int], int]}) + self.assertEqual(get_type_hints(g), {'x': collections.abc.Callable[..., int]}) + self.assertEqual(get_type_hints(h), {'x': collections.abc.Callable[P, int]}) + class GetUtilitiesTestCase(TestCase): def test_get_origin(self): diff --git a/Lib/typing.py b/Lib/typing.py index 29a3f43..84f0fd1 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -203,6 +203,24 @@ def _is_param_expr(arg): (tuple, list, ParamSpec, _ConcatenateGenericAlias)) +def _should_unflatten_callable_args(typ, args): + """Internal helper for munging collections.abc.Callable's __args__. + + The canonical representation for a Callable's __args__ flattens the + argument types, see https://bugs.python.org/issue42195. For example: + + collections.abc.Callable[[int, int], str].__args__ == (int, int, str) + collections.abc.Callable[ParamSpec, str].__args__ == (ParamSpec, str) + + As a result, if we need to reconstruct the Callable from its __args__, + we need to unflatten it. + """ + return ( + typ.__origin__ is collections.abc.Callable + and not (len(args) == 2 and _is_param_expr(args[0])) + ) + + def _type_repr(obj): """Return the repr() of an object, special-casing types (internal helper). @@ -351,7 +369,10 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()): ForwardRef(arg) if isinstance(arg, str) else arg for arg in t.__args__ ) - t = t.__origin__[args] + if _should_unflatten_callable_args(t, args): + t = t.__origin__[(args[:-1], args[-1])] + else: + t = t.__origin__[args] ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) if ev_args == t.__args__: return t @@ -2361,8 +2382,7 @@ def get_args(tp): return (tp.__origin__,) + tp.__metadata__ if isinstance(tp, (_GenericAlias, GenericAlias)): res = tp.__args__ - if (tp.__origin__ is collections.abc.Callable - and not (len(res) == 2 and _is_param_expr(res[0]))): + if _should_unflatten_callable_args(tp, res): res = (list(res[:-1]), res[-1]) return res if isinstance(tp, types.UnionType): |