summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJelle Zijlstra <jelle.zijlstra@gmail.com>2024-09-25 21:14:03 (GMT)
committerGitHub <noreply@github.com>2024-09-25 21:14:03 (GMT)
commit17a544b257ee3aeaa525350717ee56fd409d9c08 (patch)
treef95619d160ac984403339ca67b25fabfc4eb97da
parent9d8f2d8e08336695fdade5846da4bbcc3eb5f152 (diff)
downloadcpython-17a544b257ee3aeaa525350717ee56fd409d9c08.zip
cpython-17a544b257ee3aeaa525350717ee56fd409d9c08.tar.gz
cpython-17a544b257ee3aeaa525350717ee56fd409d9c08.tar.bz2
gh-119180: Avoid going through AST and eval() when possible in annotationlib (#124337)
Often, ForwardRefs represent a single simple name. In that case, we can avoid going through the overhead of creating AST nodes and code objects and calling eval(): we can simply look up the name directly in the relevant namespaces. Co-authored-by: Victor Stinner <vstinner@python.org>
-rw-r--r--Lib/annotationlib.py79
-rw-r--r--Lib/test/test_annotationlib.py37
2 files changed, 88 insertions, 28 deletions
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
index 0a67742..be3bc27 100644
--- a/Lib/annotationlib.py
+++ b/Lib/annotationlib.py
@@ -1,8 +1,10 @@
"""Helpers for introspecting and wrapping annotations."""
import ast
+import builtins
import enum
import functools
+import keyword
import sys
import types
@@ -154,8 +156,19 @@ class ForwardRef:
globals[param_name] = param
locals.pop(param_name, None)
- code = self.__forward_code__
- value = eval(code, globals=globals, locals=locals)
+ arg = self.__forward_arg__
+ if arg.isidentifier() and not keyword.iskeyword(arg):
+ if arg in locals:
+ value = locals[arg]
+ elif arg in globals:
+ value = globals[arg]
+ elif hasattr(builtins, arg):
+ return getattr(builtins, arg)
+ else:
+ raise NameError(arg)
+ else:
+ code = self.__forward_code__
+ value = eval(code, globals=globals, locals=locals)
self.__forward_evaluated__ = True
self.__forward_value__ = value
return value
@@ -254,7 +267,9 @@ class _Stringifier:
__slots__ = _SLOTS
def __init__(self, node, globals=None, owner=None, is_class=False, cell=None):
- assert isinstance(node, ast.AST)
+ # Either an AST node or a simple str (for the common case where a ForwardRef
+ # represent a single name).
+ assert isinstance(node, (ast.AST, str))
self.__arg__ = None
self.__forward_evaluated__ = False
self.__forward_value__ = None
@@ -267,18 +282,26 @@ class _Stringifier:
self.__cell__ = cell
self.__owner__ = owner
- def __convert(self, other):
+ def __convert_to_ast(self, other):
if isinstance(other, _Stringifier):
+ if isinstance(other.__ast_node__, str):
+ return ast.Name(id=other.__ast_node__)
return other.__ast_node__
elif isinstance(other, slice):
return ast.Slice(
- lower=self.__convert(other.start) if other.start is not None else None,
- upper=self.__convert(other.stop) if other.stop is not None else None,
- step=self.__convert(other.step) if other.step is not None else None,
+ lower=self.__convert_to_ast(other.start) if other.start is not None else None,
+ upper=self.__convert_to_ast(other.stop) if other.stop is not None else None,
+ step=self.__convert_to_ast(other.step) if other.step is not None else None,
)
else:
return ast.Constant(value=other)
+ def __get_ast(self):
+ node = self.__ast_node__
+ if isinstance(node, str):
+ return ast.Name(id=node)
+ return node
+
def __make_new(self, node):
return _Stringifier(
node, self.__globals__, self.__owner__, self.__forward_is_class__
@@ -292,38 +315,37 @@ class _Stringifier:
def __getitem__(self, other):
# Special case, to avoid stringifying references to class-scoped variables
# as '__classdict__["x"]'.
- if (
- isinstance(self.__ast_node__, ast.Name)
- and self.__ast_node__.id == "__classdict__"
- ):
+ if self.__ast_node__ == "__classdict__":
raise KeyError
if isinstance(other, tuple):
- elts = [self.__convert(elt) for elt in other]
+ elts = [self.__convert_to_ast(elt) for elt in other]
other = ast.Tuple(elts)
else:
- other = self.__convert(other)
+ other = self.__convert_to_ast(other)
assert isinstance(other, ast.AST), repr(other)
- return self.__make_new(ast.Subscript(self.__ast_node__, other))
+ return self.__make_new(ast.Subscript(self.__get_ast(), other))
def __getattr__(self, attr):
- return self.__make_new(ast.Attribute(self.__ast_node__, attr))
+ return self.__make_new(ast.Attribute(self.__get_ast(), attr))
def __call__(self, *args, **kwargs):
return self.__make_new(
ast.Call(
- self.__ast_node__,
- [self.__convert(arg) for arg in args],
+ self.__get_ast(),
+ [self.__convert_to_ast(arg) for arg in args],
[
- ast.keyword(key, self.__convert(value))
+ ast.keyword(key, self.__convert_to_ast(value))
for key, value in kwargs.items()
],
)
)
def __iter__(self):
- yield self.__make_new(ast.Starred(self.__ast_node__))
+ yield self.__make_new(ast.Starred(self.__get_ast()))
def __repr__(self):
+ if isinstance(self.__ast_node__, str):
+ return self.__ast_node__
return ast.unparse(self.__ast_node__)
def __format__(self, format_spec):
@@ -332,7 +354,7 @@ class _Stringifier:
def _make_binop(op: ast.AST):
def binop(self, other):
return self.__make_new(
- ast.BinOp(self.__ast_node__, op, self.__convert(other))
+ ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
)
return binop
@@ -356,7 +378,7 @@ class _Stringifier:
def _make_rbinop(op: ast.AST):
def rbinop(self, other):
return self.__make_new(
- ast.BinOp(self.__convert(other), op, self.__ast_node__)
+ ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
)
return rbinop
@@ -381,9 +403,9 @@ class _Stringifier:
def compare(self, other):
return self.__make_new(
ast.Compare(
- left=self.__ast_node__,
+ left=self.__get_ast(),
ops=[op],
- comparators=[self.__convert(other)],
+ comparators=[self.__convert_to_ast(other)],
)
)
@@ -400,7 +422,7 @@ class _Stringifier:
def _make_unary_op(op):
def unary_op(self):
- return self.__make_new(ast.UnaryOp(op, self.__ast_node__))
+ return self.__make_new(ast.UnaryOp(op, self.__get_ast()))
return unary_op
@@ -422,7 +444,7 @@ class _StringifierDict(dict):
def __missing__(self, key):
fwdref = _Stringifier(
- ast.Name(id=key),
+ key,
globals=self.globals,
owner=self.owner,
is_class=self.is_class,
@@ -480,7 +502,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
name = freevars[i]
else:
name = "__cell__"
- fwdref = _Stringifier(ast.Name(id=name))
+ fwdref = _Stringifier(name)
new_closure.append(types.CellType(fwdref))
closure = tuple(new_closure)
else:
@@ -532,7 +554,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
else:
name = "__cell__"
fwdref = _Stringifier(
- ast.Name(id=name),
+ name,
cell=cell,
owner=owner,
globals=annotate.__globals__,
@@ -555,6 +577,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
result = func(Format.VALUE)
for obj in globals.stringifiers:
obj.__class__ = ForwardRef
+ if isinstance(obj.__ast_node__, str):
+ obj.__arg__ = obj.__ast_node__
+ obj.__ast_node__ = None
return result
elif format == Format.VALUE:
# Should be impossible because __annotate__ functions must not raise
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
index dd8ceb5..cc051ef 100644
--- a/Lib/test/test_annotationlib.py
+++ b/Lib/test/test_annotationlib.py
@@ -1,6 +1,7 @@
"""Tests for the annotations module."""
import annotationlib
+import builtins
import collections
import functools
import itertools
@@ -280,7 +281,14 @@ class TestForwardRefClass(unittest.TestCase):
def test_fwdref_with_module(self):
self.assertIs(ForwardRef("Format", module="annotationlib").evaluate(), Format)
- self.assertIs(ForwardRef("Counter", module="collections").evaluate(), collections.Counter)
+ self.assertIs(
+ ForwardRef("Counter", module="collections").evaluate(),
+ collections.Counter
+ )
+ self.assertEqual(
+ ForwardRef("Counter[int]", module="collections").evaluate(),
+ collections.Counter[int],
+ )
with self.assertRaises(NameError):
# If globals are passed explicitly, we don't look at the module dict
@@ -305,6 +313,33 @@ class TestForwardRefClass(unittest.TestCase):
self.assertIs(fr.evaluate(globals={"hello": str}), str)
self.assertIs(fr.evaluate(), str)
+ def test_fwdref_with_owner(self):
+ self.assertEqual(
+ ForwardRef("Counter[int]", owner=collections).evaluate(),
+ collections.Counter[int],
+ )
+
+ def test_name_lookup_without_eval(self):
+ # test the codepath where we look up simple names directly in the
+ # namespaces without going through eval()
+ self.assertIs(ForwardRef("int").evaluate(), int)
+ self.assertIs(ForwardRef("int").evaluate(locals={"int": str}), str)
+ self.assertIs(ForwardRef("int").evaluate(locals={"int": float}, globals={"int": str}), float)
+ self.assertIs(ForwardRef("int").evaluate(globals={"int": str}), str)
+ with support.swap_attr(builtins, "int", dict):
+ self.assertIs(ForwardRef("int").evaluate(), dict)
+
+ with self.assertRaises(NameError):
+ ForwardRef("doesntexist").evaluate()
+
+ def test_fwdref_invalid_syntax(self):
+ fr = ForwardRef("if")
+ with self.assertRaises(SyntaxError):
+ fr.evaluate()
+ fr = ForwardRef("1+")
+ with self.assertRaises(SyntaxError):
+ fr.evaluate()
+
class TestGetAnnotations(unittest.TestCase):
def test_builtin_type(self):