summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNikita Sobolev <mail@sobolevn.me>2023-01-21 21:44:41 (GMT)
committerGitHub <noreply@github.com>2023-01-21 21:44:41 (GMT)
commitc1c5882359a2899b74c1685a0d4e61d6e232161f (patch)
tree25735ca47f51c618a5acad9de4917588f56e2498
parentf63f525e161204970418ebc132efc542daaa24ed (diff)
downloadcpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.zip
cpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.tar.gz
cpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.tar.bz2
gh-100518: Add tests for `ast.NodeTransformer` (#100521)
-rw-r--r--Lib/test/support/ast_helper.py43
-rw-r--r--Lib/test/test_ast.py128
-rw-r--r--Lib/test/test_unparse.py42
3 files changed, 171 insertions, 42 deletions
diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py
new file mode 100644
index 0000000..8a0415b
--- /dev/null
+++ b/Lib/test/support/ast_helper.py
@@ -0,0 +1,43 @@
+import ast
+
+class ASTTestMixin:
+ """Test mixing to have basic assertions for AST nodes."""
+
+ def assertASTEqual(self, ast1, ast2):
+ # Ensure the comparisons start at an AST node
+ self.assertIsInstance(ast1, ast.AST)
+ self.assertIsInstance(ast2, ast.AST)
+
+ # An AST comparison routine modeled after ast.dump(), but
+ # instead of string building, it traverses the two trees
+ # in lock-step.
+ def traverse_compare(a, b, missing=object()):
+ if type(a) is not type(b):
+ self.fail(f"{type(a)!r} is not {type(b)!r}")
+ if isinstance(a, ast.AST):
+ for field in a._fields:
+ value1 = getattr(a, field, missing)
+ value2 = getattr(b, field, missing)
+ # Singletons are equal by definition, so further
+ # testing can be skipped.
+ if value1 is not value2:
+ traverse_compare(value1, value2)
+ elif isinstance(a, list):
+ try:
+ for node1, node2 in zip(a, b, strict=True):
+ traverse_compare(node1, node2)
+ except ValueError:
+ # Attempt a "pretty" error ala assertSequenceEqual()
+ len1 = len(a)
+ len2 = len(b)
+ if len1 > len2:
+ what = "First"
+ diff = len1 - len2
+ else:
+ what = "Second"
+ diff = len2 - len1
+ msg = f"{what} list contains {diff} additional elements."
+ raise self.failureException(msg) from None
+ elif a != b:
+ self.fail(f"{a!r} != {b!r}")
+ traverse_compare(ast1, ast2)
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 53a6418..c728d2b 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -11,6 +11,7 @@ import weakref
from textwrap import dedent
from test import support
+from test.support.ast_helper import ASTTestMixin
def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
@@ -2290,9 +2291,10 @@ class EndPositionTests(unittest.TestCase):
self.assertIsNone(ast.get_source_segment(s, x))
self.assertIsNone(ast.get_source_segment(s, y))
-class NodeVisitorTests(unittest.TestCase):
+class BaseNodeVisitorCases:
+ # Both `NodeVisitor` and `NodeTranformer` must raise these warnings:
def test_old_constant_nodes(self):
- class Visitor(ast.NodeVisitor):
+ class Visitor(self.visitor_class):
def visit_Num(self, node):
log.append((node.lineno, 'Num', node.n))
def visit_Str(self, node):
@@ -2340,6 +2342,128 @@ class NodeVisitorTests(unittest.TestCase):
])
+class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase):
+ visitor_class = ast.NodeVisitor
+
+
+class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase):
+ visitor_class = ast.NodeTransformer
+
+ def assertASTTransformation(self, tranformer_class,
+ initial_code, expected_code):
+ initial_ast = ast.parse(dedent(initial_code))
+ expected_ast = ast.parse(dedent(expected_code))
+
+ tranformer = tranformer_class()
+ result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast))
+
+ self.assertASTEqual(result_ast, expected_ast)
+
+ def test_node_remove_single(self):
+ code = 'def func(arg) -> SomeType: ...'
+ expected = 'def func(arg): ...'
+
+ # Since `FunctionDef.returns` is defined as a single value, we test
+ # the `if isinstance(old_value, AST):` branch here.
+ class SomeTypeRemover(ast.NodeTransformer):
+ def visit_Name(self, node: ast.Name):
+ self.generic_visit(node)
+ if node.id == 'SomeType':
+ return None
+ return node
+
+ self.assertASTTransformation(SomeTypeRemover, code, expected)
+
+ def test_node_remove_from_list(self):
+ code = """
+ def func(arg):
+ print(arg)
+ yield arg
+ """
+ expected = """
+ def func(arg):
+ print(arg)
+ """
+
+ # Since `FunctionDef.body` is defined as a list, we test
+ # the `if isinstance(old_value, list):` branch here.
+ class YieldRemover(ast.NodeTransformer):
+ def visit_Expr(self, node: ast.Expr):
+ self.generic_visit(node)
+ if isinstance(node.value, ast.Yield):
+ return None # Remove `yield` from a function
+ return node
+
+ self.assertASTTransformation(YieldRemover, code, expected)
+
+ def test_node_return_list(self):
+ code = """
+ class DSL(Base, kw1=True): ...
+ """
+ expected = """
+ class DSL(Base, kw1=True, kw2=True, kw3=False): ...
+ """
+
+ class ExtendKeywords(ast.NodeTransformer):
+ def visit_keyword(self, node: ast.keyword):
+ self.generic_visit(node)
+ if node.arg == 'kw1':
+ return [
+ node,
+ ast.keyword('kw2', ast.Constant(True)),
+ ast.keyword('kw3', ast.Constant(False)),
+ ]
+ return node
+
+ self.assertASTTransformation(ExtendKeywords, code, expected)
+
+ def test_node_mutate(self):
+ code = """
+ def func(arg):
+ print(arg)
+ """
+ expected = """
+ def func(arg):
+ log(arg)
+ """
+
+ class PrintToLog(ast.NodeTransformer):
+ def visit_Call(self, node: ast.Call):
+ self.generic_visit(node)
+ if isinstance(node.func, ast.Name) and node.func.id == 'print':
+ node.func.id = 'log'
+ return node
+
+ self.assertASTTransformation(PrintToLog, code, expected)
+
+ def test_node_replace(self):
+ code = """
+ def func(arg):
+ print(arg)
+ """
+ expected = """
+ def func(arg):
+ logger.log(arg, debug=True)
+ """
+
+ class PrintToLog(ast.NodeTransformer):
+ def visit_Call(self, node: ast.Call):
+ self.generic_visit(node)
+ if isinstance(node.func, ast.Name) and node.func.id == 'print':
+ return ast.Call(
+ func=ast.Attribute(
+ ast.Name('logger', ctx=ast.Load()),
+ attr='log',
+ ctx=ast.Load(),
+ ),
+ args=node.args,
+ keywords=[ast.keyword('debug', ast.Constant(True))],
+ )
+ return node
+
+ self.assertASTTransformation(PrintToLog, code, expected)
+
+
@support.cpython_only
class ModuleStateTests(unittest.TestCase):
# bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py
index f1f1dd5..88c7c3a 100644
--- a/Lib/test/test_unparse.py
+++ b/Lib/test/test_unparse.py
@@ -6,6 +6,7 @@ import pathlib
import random
import tokenize
import ast
+from test.support.ast_helper import ASTTestMixin
def read_pyfile(filename):
@@ -128,46 +129,7 @@ docstring_prefixes = (
"async def foo():\n ",
)
-class ASTTestCase(unittest.TestCase):
- def assertASTEqual(self, ast1, ast2):
- # Ensure the comparisons start at an AST node
- self.assertIsInstance(ast1, ast.AST)
- self.assertIsInstance(ast2, ast.AST)
-
- # An AST comparison routine modeled after ast.dump(), but
- # instead of string building, it traverses the two trees
- # in lock-step.
- def traverse_compare(a, b, missing=object()):
- if type(a) is not type(b):
- self.fail(f"{type(a)!r} is not {type(b)!r}")
- if isinstance(a, ast.AST):
- for field in a._fields:
- value1 = getattr(a, field, missing)
- value2 = getattr(b, field, missing)
- # Singletons are equal by definition, so further
- # testing can be skipped.
- if value1 is not value2:
- traverse_compare(value1, value2)
- elif isinstance(a, list):
- try:
- for node1, node2 in zip(a, b, strict=True):
- traverse_compare(node1, node2)
- except ValueError:
- # Attempt a "pretty" error ala assertSequenceEqual()
- len1 = len(a)
- len2 = len(b)
- if len1 > len2:
- what = "First"
- diff = len1 - len2
- else:
- what = "Second"
- diff = len2 - len1
- msg = f"{what} list contains {diff} additional elements."
- raise self.failureException(msg) from None
- elif a != b:
- self.fail(f"{a!r} != {b!r}")
- traverse_compare(ast1, ast2)
-
+class ASTTestCase(ASTTestMixin, unittest.TestCase):
def check_ast_roundtrip(self, code1, **kwargs):
with self.subTest(code1=code1, ast_parse_kwargs=kwargs):
ast1 = ast.parse(code1, **kwargs)