summaryrefslogtreecommitdiffstats
path: root/Lib/test/support
diff options
context:
space:
mode:
authorNikita Sobolev <mail@sobolevn.me>2023-01-21 21:44:41 (GMT)
committerGitHub <noreply@github.com>2023-01-21 21:44:41 (GMT)
commitc1c5882359a2899b74c1685a0d4e61d6e232161f (patch)
tree25735ca47f51c618a5acad9de4917588f56e2498 /Lib/test/support
parentf63f525e161204970418ebc132efc542daaa24ed (diff)
downloadcpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.zip
cpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.tar.gz
cpython-c1c5882359a2899b74c1685a0d4e61d6e232161f.tar.bz2
gh-100518: Add tests for `ast.NodeTransformer` (#100521)
Diffstat (limited to 'Lib/test/support')
-rw-r--r--Lib/test/support/ast_helper.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/Lib/test/support/ast_helper.py b/Lib/test/support/ast_helper.py
new file mode 100644
index 0000000..8a0415b
--- /dev/null
+++ b/Lib/test/support/ast_helper.py
@@ -0,0 +1,43 @@
+import ast
+
+class ASTTestMixin:
+ """Test mixing to have basic assertions for AST nodes."""
+
+ def assertASTEqual(self, ast1, ast2):
+ # Ensure the comparisons start at an AST node
+ self.assertIsInstance(ast1, ast.AST)
+ self.assertIsInstance(ast2, ast.AST)
+
+ # An AST comparison routine modeled after ast.dump(), but
+ # instead of string building, it traverses the two trees
+ # in lock-step.
+ def traverse_compare(a, b, missing=object()):
+ if type(a) is not type(b):
+ self.fail(f"{type(a)!r} is not {type(b)!r}")
+ if isinstance(a, ast.AST):
+ for field in a._fields:
+ value1 = getattr(a, field, missing)
+ value2 = getattr(b, field, missing)
+ # Singletons are equal by definition, so further
+ # testing can be skipped.
+ if value1 is not value2:
+ traverse_compare(value1, value2)
+ elif isinstance(a, list):
+ try:
+ for node1, node2 in zip(a, b, strict=True):
+ traverse_compare(node1, node2)
+ except ValueError:
+ # Attempt a "pretty" error ala assertSequenceEqual()
+ len1 = len(a)
+ len2 = len(b)
+ if len1 > len2:
+ what = "First"
+ diff = len1 - len2
+ else:
+ what = "Second"
+ diff = len2 - len1
+ msg = f"{what} list contains {diff} additional elements."
+ raise self.failureException(msg) from None
+ elif a != b:
+ self.fail(f"{a!r} != {b!r}")
+ traverse_compare(ast1, ast2)