diff options
Diffstat (limited to 'Lib/typing.py')
-rw-r--r-- | Lib/typing.py | 35 |
1 files changed, 27 insertions, 8 deletions
diff --git a/Lib/typing.py b/Lib/typing.py index 231492c..a0b68f5 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -399,7 +399,8 @@ def _tp_cache(func=None, /, *, typed=False): return decorator -def _eval_type(t, globalns, localns, recursive_guard=frozenset()): + +def _eval_type(t, globalns, localns, type_params, *, recursive_guard=frozenset()): """Evaluate all forward references in the given type t. For use of globalns and localns see the docstring for get_type_hints(). @@ -407,7 +408,7 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()): ForwardRef. """ if isinstance(t, ForwardRef): - return t._evaluate(globalns, localns, recursive_guard) + return t._evaluate(globalns, localns, type_params, recursive_guard=recursive_guard) if isinstance(t, (_GenericAlias, GenericAlias, types.UnionType)): if isinstance(t, GenericAlias): args = tuple( @@ -421,7 +422,13 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()): t = t.__origin__[args] if is_unpacked: t = Unpack[t] - ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__) + + ev_args = tuple( + _eval_type( + a, globalns, localns, type_params, recursive_guard=recursive_guard + ) + for a in t.__args__ + ) if ev_args == t.__args__: return t if isinstance(t, GenericAlias): @@ -974,7 +981,7 @@ class ForwardRef(_Final, _root=True): self.__forward_is_class__ = is_class self.__forward_module__ = module - def _evaluate(self, globalns, localns, recursive_guard): + def _evaluate(self, globalns, localns, type_params, *, recursive_guard): if self.__forward_arg__ in recursive_guard: return self if not self.__forward_evaluated__ or localns is not globalns: @@ -988,14 +995,25 @@ class ForwardRef(_Final, _root=True): globalns = getattr( sys.modules.get(self.__forward_module__, None), '__dict__', globalns ) + if type_params: + # "Inject" type parameters into the local namespace + # (unless they are shadowed by assignments *in* the local namespace), + # as a way of emulating annotation scopes when calling `eval()` + locals_to_pass = {param.__name__: param for param in type_params} | localns + else: + locals_to_pass = localns type_ = _type_check( - eval(self.__forward_code__, globalns, localns), + eval(self.__forward_code__, globalns, locals_to_pass), "Forward references must evaluate to types.", is_argument=self.__forward_is_argument__, allow_special_forms=self.__forward_is_class__, ) self.__forward_value__ = _eval_type( - type_, globalns, localns, recursive_guard | {self.__forward_arg__} + type_, + globalns, + localns, + type_params, + recursive_guard=(recursive_guard | {self.__forward_arg__}), ) self.__forward_evaluated__ = True return self.__forward_value__ @@ -2334,7 +2352,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): value = type(None) if isinstance(value, str): value = ForwardRef(value, is_argument=False, is_class=True) - value = _eval_type(value, base_globals, base_locals) + value = _eval_type(value, base_globals, base_locals, base.__type_params__) hints[name] = value return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} @@ -2360,6 +2378,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): raise TypeError('{!r} is not a module, class, method, ' 'or function.'.format(obj)) hints = dict(hints) + type_params = getattr(obj, "__type_params__", ()) for name, value in hints.items(): if value is None: value = type(None) @@ -2371,7 +2390,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False): is_argument=not isinstance(obj, types.ModuleType), is_class=False, ) - hints[name] = _eval_type(value, globalns, localns) + hints[name] = _eval_type(value, globalns, localns, type_params) return hints if include_extras else {k: _strip_annotations(t) for k, t in hints.items()} |