diff options
author | Nikita Sobolev <mail@sobolevn.me> | 2023-01-21 21:44:41 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-21 21:44:41 (GMT) |
commit | c1c5882359a2899b74c1685a0d4e61d6e232161f (patch) | |
tree | 25735ca47f51c618a5acad9de4917588f56e2498 | |
parent | f63f525e161204970418ebc132efc542daaa24ed (diff) | |
download | cpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.zip cpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.tar.gz cpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.tar.bz2 |
gh-100518: Add tests for `ast.NodeTransformer` (#100521)
-rw-r--r-- | Lib/test/support/ast_helper.py | 43 | ||||
-rw-r--r-- | Lib/test/test_ast.py | 128 | ||||
-rw-r--r-- | Lib/test/test_unparse.py | 42 |
3 files changed, 171 insertions, 42 deletions
diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py new file mode 100644 index 0000000..8a0415b --- /dev/null +++ b/Lib/test/support/ast_helper.py @@ -0,0 +1,43 @@ +import ast + +class ASTTestMixin: + """Test mixing to have basic assertions for AST nodes.""" + + def assertASTEqual(self, ast1, ast2): + # Ensure the comparisons start at an AST node + self.assertIsInstance(ast1, ast.AST) + self.assertIsInstance(ast2, ast.AST) + + # An AST comparison routine modeled after ast.dump(), but + # instead of string building, it traverses the two trees + # in lock-step. + def traverse_compare(a, b, missing=object()): + if type(a) is not type(b): + self.fail(f"{type(a)!r} is not {type(b)!r}") + if isinstance(a, ast.AST): + for field in a._fields: + value1 = getattr(a, field, missing) + value2 = getattr(b, field, missing) + # Singletons are equal by definition, so further + # testing can be skipped. + if value1 is not value2: + traverse_compare(value1, value2) + elif isinstance(a, list): + try: + for node1, node2 in zip(a, b, strict=True): + traverse_compare(node1, node2) + except ValueError: + # Attempt a "pretty" error ala assertSequenceEqual() + len1 = len(a) + len2 = len(b) + if len1 > len2: + what = "First" + diff = len1 - len2 + else: + what = "Second" + diff = len2 - len1 + msg = f"{what} list contains {diff} additional elements." + raise self.failureException(msg) from None + elif a != b: + self.fail(f"{a!r} != {b!r}") + traverse_compare(ast1, ast2) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 53a6418..c728d2b 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -11,6 +11,7 @@ import weakref from textwrap import dedent from test import support +from test.support.ast_helper import ASTTestMixin def to_tuple(t): if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis: @@ -2290,9 +2291,10 @@ class EndPositionTests(unittest.TestCase): self.assertIsNone(ast.get_source_segment(s, x)) self.assertIsNone(ast.get_source_segment(s, y)) -class NodeVisitorTests(unittest.TestCase): +class BaseNodeVisitorCases: + # Both `NodeVisitor` and `NodeTranformer` must raise these warnings: def test_old_constant_nodes(self): - class Visitor(ast.NodeVisitor): + class Visitor(self.visitor_class): def visit_Num(self, node): log.append((node.lineno, 'Num', node.n)) def visit_Str(self, node): @@ -2340,6 +2342,128 @@ class NodeVisitorTests(unittest.TestCase): ]) +class NodeVisitorTests(BaseNodeVisitorCases, unittest.TestCase): + visitor_class = ast.NodeVisitor + + +class NodeTransformerTests(ASTTestMixin, BaseNodeVisitorCases, unittest.TestCase): + visitor_class = ast.NodeTransformer + + def assertASTTransformation(self, tranformer_class, + initial_code, expected_code): + initial_ast = ast.parse(dedent(initial_code)) + expected_ast = ast.parse(dedent(expected_code)) + + tranformer = tranformer_class() + result_ast = ast.fix_missing_locations(tranformer.visit(initial_ast)) + + self.assertASTEqual(result_ast, expected_ast) + + def test_node_remove_single(self): + code = 'def func(arg) -> SomeType: ...' + expected = 'def func(arg): ...' + + # Since `FunctionDef.returns` is defined as a single value, we test + # the `if isinstance(old_value, AST):` branch here. + class SomeTypeRemover(ast.NodeTransformer): + def visit_Name(self, node: ast.Name): + self.generic_visit(node) + if node.id == 'SomeType': + return None + return node + + self.assertASTTransformation(SomeTypeRemover, code, expected) + + def test_node_remove_from_list(self): + code = """ + def func(arg): + print(arg) + yield arg + """ + expected = """ + def func(arg): + print(arg) + """ + + # Since `FunctionDef.body` is defined as a list, we test + # the `if isinstance(old_value, list):` branch here. + class YieldRemover(ast.NodeTransformer): + def visit_Expr(self, node: ast.Expr): + self.generic_visit(node) + if isinstance(node.value, ast.Yield): + return None # Remove `yield` from a function + return node + + self.assertASTTransformation(YieldRemover, code, expected) + + def test_node_return_list(self): + code = """ + class DSL(Base, kw1=True): ... + """ + expected = """ + class DSL(Base, kw1=True, kw2=True, kw3=False): ... + """ + + class ExtendKeywords(ast.NodeTransformer): + def visit_keyword(self, node: ast.keyword): + self.generic_visit(node) + if node.arg == 'kw1': + return [ + node, + ast.keyword('kw2', ast.Constant(True)), + ast.keyword('kw3', ast.Constant(False)), + ] + return node + + self.assertASTTransformation(ExtendKeywords, code, expected) + + def test_node_mutate(self): + code = """ + def func(arg): + print(arg) + """ + expected = """ + def func(arg): + log(arg) + """ + + class PrintToLog(ast.NodeTransformer): + def visit_Call(self, node: ast.Call): + self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id == 'print': + node.func.id = 'log' + return node + + self.assertASTTransformation(PrintToLog, code, expected) + + def test_node_replace(self): + code = """ + def func(arg): + print(arg) + """ + expected = """ + def func(arg): + logger.log(arg, debug=True) + """ + + class PrintToLog(ast.NodeTransformer): + def visit_Call(self, node: ast.Call): + self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id == 'print': + return ast.Call( + func=ast.Attribute( + ast.Name('logger', ctx=ast.Load()), + attr='log', + ctx=ast.Load(), + ), + args=node.args, + keywords=[ast.keyword('debug', ast.Constant(True))], + ) + return node + + self.assertASTTransformation(PrintToLog, code, expected) + + @support.cpython_only class ModuleStateTests(unittest.TestCase): # bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state. diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index f1f1dd5..88c7c3a 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -6,6 +6,7 @@ import pathlib import random import tokenize import ast +from test.support.ast_helper import ASTTestMixin def read_pyfile(filename): @@ -128,46 +129,7 @@ docstring_prefixes = ( "async def foo():\n ", ) -class ASTTestCase(unittest.TestCase): - def assertASTEqual(self, ast1, ast2): - # Ensure the comparisons start at an AST node - self.assertIsInstance(ast1, ast.AST) - self.assertIsInstance(ast2, ast.AST) - - # An AST comparison routine modeled after ast.dump(), but - # instead of string building, it traverses the two trees - # in lock-step. - def traverse_compare(a, b, missing=object()): - if type(a) is not type(b): - self.fail(f"{type(a)!r} is not {type(b)!r}") - if isinstance(a, ast.AST): - for field in a._fields: - value1 = getattr(a, field, missing) - value2 = getattr(b, field, missing) - # Singletons are equal by definition, so further - # testing can be skipped. - if value1 is not value2: - traverse_compare(value1, value2) - elif isinstance(a, list): - try: - for node1, node2 in zip(a, b, strict=True): - traverse_compare(node1, node2) - except ValueError: - # Attempt a "pretty" error ala assertSequenceEqual() - len1 = len(a) - len2 = len(b) - if len1 > len2: - what = "First" - diff = len1 - len2 - else: - what = "Second" - diff = len2 - len1 - msg = f"{what} list contains {diff} additional elements." - raise self.failureException(msg) from None - elif a != b: - self.fail(f"{a!r} != {b!r}") - traverse_compare(ast1, ast2) - +class ASTTestCase(ASTTestMixin, unittest.TestCase): def check_ast_roundtrip(self, code1, **kwargs): with self.subTest(code1=code1, ast_parse_kwargs=kwargs): ast1 = ast.parse(code1, **kwargs) |