summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/library/ast.rst14
-rw-r--r--Doc/whatsnew/3.14.rst7
-rw-r--r--Lib/ast.py71
-rw-r--r--Lib/test/test_ast.py121
-rw-r--r--Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst2
5 files changed, 210 insertions, 5 deletions
diff --git a/Doc/library/ast.rst b/Doc/library/ast.rst
index d4ccf28..9ee56b9 100644
--- a/Doc/library/ast.rst
+++ b/Doc/library/ast.rst
@@ -2472,6 +2472,20 @@ effects on the compilation of a program:
.. versionadded:: 3.8
+.. function:: compare(a, b, /, *, compare_attributes=False)
+
+ Recursively compares two ASTs.
+
+ *compare_attributes* affects whether AST attributes are considered
+ in the comparison. If *compare_attributes* is ``False`` (default), then
+ attributes are ignored. Otherwise they must all be equal. Ignoring
+ attributes is useful to check whether the ASTs are structurally equal but
+ differ in whitespace or similar details. Attributes include line numbers
+ and column offsets.
+
+ .. versionadded:: 3.14
+
+
.. _ast-cli:
Command-Line Usage
diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst
index 27c985b..39172ac 100644
--- a/Doc/whatsnew/3.14.rst
+++ b/Doc/whatsnew/3.14.rst
@@ -86,6 +86,13 @@ New Modules
Improved Modules
================
+ast
+---
+
+Added :func:`ast.compare` for comparing two ASTs.
+(Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`.)
+
+
Optimizations
=============
diff --git a/Lib/ast.py b/Lib/ast.py
index d7e51ab..031bab4 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -401,6 +401,77 @@ def walk(node):
yield node
def compare(
    a,
    b,
    /,
    *,
    compare_attributes=False,
):
    """Recursively compare two ASTs and return True if they are equal.

    Two nodes are equal when they have exactly the same type, the same
    _fields, and every field compares equal.  Child nodes are compared
    recursively, repeated fields (lists) element by element, and
    primitive values by exact type and ``==`` (so ``1``, ``1.0`` and
    ``True`` are all different).

    compare_attributes affects whether AST attributes (the names listed
    in _attributes, such as line numbers and column offsets) are
    considered in the comparison.  If compare_attributes is False
    (default), then attributes are ignored, which is useful to check
    whether the ASTs are structurally equal but might differ in
    whitespace or similar details.  Otherwise they must all be equal.

    A field or attribute that is listed on the node but was never set
    is treated as equal only to another unset field or attribute.
    """

    # Marks a name listed in _fields/_attributes that is not actually set
    # on the instance (possible for hand-built or mutated nodes).
    missing = object()

    def _compare(a, b):
        # Compare two fields on an AST object, which may themselves be
        # AST objects, lists of AST objects, or primitive ASDL types
        # like identifiers and constants.
        if isinstance(a, AST):
            return compare(
                a,
                b,
                compare_attributes=compare_attributes,
            )
        elif isinstance(a, list):
            # If a field is repeated, then both objects must represent
            # the value as a list.  Checking b explicitly avoids both a
            # TypeError from len() on an unsized value and a false match
            # between e.g. ["x"] and the string "x".
            if not isinstance(b, list) or len(a) != len(b):
                return False
            return all(
                _compare(a_item, b_item) for a_item, b_item in zip(a, b)
            )
        else:
            return type(a) is type(b) and a == b

    def _compare_members(a, b, names):
        # Compare the named fields or attributes of two nodes.
        for name in names:
            a_value = getattr(a, name, missing)
            b_value = getattr(b, name, missing)
            if a_value is missing and b_value is missing:
                # Unset on both nodes: nothing to compare.
                continue
            if a_value is missing or b_value is missing:
                # Set on only one of the nodes.
                return False
            if not _compare(a_value, b_value):
                return False
        return True

    def _compare_fields(a, b):
        if a._fields != b._fields:
            return False
        return _compare_members(a, b, a._fields)

    def _compare_attributes(a, b):
        if a._attributes != b._attributes:
            return False
        # Attributes are normally ints (or None), for which _compare()
        # reduces to a plain equality test.
        return _compare_members(a, b, a._attributes)

    if type(a) is not type(b):
        return False
    if not _compare_fields(a, b):
        return False
    if compare_attributes and not _compare_attributes(a, b):
        return False
    return True
+
+
class NodeVisitor(object):
"""
A node visitor base class that walks the abstract syntax tree and calls a
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 5422c86..8a4374c 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -38,6 +38,9 @@ def to_tuple(t):
result.append(to_tuple(getattr(t, f)))
return tuple(result)
# Directory containing the ast module, i.e. the top of the standard library.
STDLIB = os.path.dirname(ast.__file__)
# Every top-level .py file in that directory, followed by two test modules
# given as paths relative to STDLIB.
STDLIB_FILES = [
    *(name for name in os.listdir(STDLIB) if name.endswith(".py")),
    "test/test_grammar.py",
    "test/test_unpack_ex.py",
]
# These tests are compiled through "exec"
# There should be at least one test per statement
@@ -1066,6 +1069,114 @@ class AST_Tests(unittest.TestCase):
expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)
+ def test_compare_basics(self):
+ self.assertTrue(ast.compare(ast.parse("x = 10"), ast.parse("x = 10")))
+ self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("")))
+ self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("x")))
+ self.assertFalse(
+ ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))
+ )
+
+ def test_compare_modified_ast(self):
+ # The ast API is a bit underspecified. The objects are mutable,
+ # and even _fields and _attributes are mutable. The compare() does
+ # some simple things to accommodate mutability.
+ a = ast.parse("m * x + b", mode="eval")
+ b = ast.parse("m * x + b", mode="eval")
+ self.assertTrue(ast.compare(a, b))
+
+ a._fields = a._fields + ("spam",)
+ a.spam = "Spam"
+ self.assertNotEqual(a._fields, b._fields)
+ self.assertFalse(ast.compare(a, b))
+ self.assertFalse(ast.compare(b, a))
+
+ b._fields = a._fields
+ b.spam = a.spam
+ self.assertTrue(ast.compare(a, b))
+ self.assertTrue(ast.compare(b, a))
+
+ b._attributes = b._attributes + ("eggs",)
+ b.eggs = "eggs"
+ self.assertNotEqual(a._attributes, b._attributes)
+ self.assertFalse(ast.compare(a, b, compare_attributes=True))
+ self.assertFalse(ast.compare(b, a, compare_attributes=True))
+
+ a._attributes = b._attributes
+ a.eggs = b.eggs
+ self.assertTrue(ast.compare(a, b, compare_attributes=True))
+ self.assertTrue(ast.compare(b, a, compare_attributes=True))
+
+ def test_compare_literals(self):
+ constants = (
+ -20,
+ 20,
+ 20.0,
+ 1,
+ 1.0,
+ True,
+ 0,
+ False,
+ frozenset(),
+ tuple(),
+ "ABCD",
+ "abcd",
+ "中文字",
+ 1e1000,
+ -1e1000,
+ )
+ for next_index, constant in enumerate(constants[:-1], 1):
+ next_constant = constants[next_index]
+ with self.subTest(literal=constant, next_literal=next_constant):
+ self.assertTrue(
+ ast.compare(ast.Constant(constant), ast.Constant(constant))
+ )
+ self.assertFalse(
+ ast.compare(
+ ast.Constant(constant), ast.Constant(next_constant)
+ )
+ )
+
+ same_looking_literal_cases = [
+ {1, 1.0, True, 1 + 0j},
+ {0, 0.0, False, 0 + 0j},
+ ]
+ for same_looking_literals in same_looking_literal_cases:
+ for literal in same_looking_literals:
+ for same_looking_literal in same_looking_literals - {literal}:
+ self.assertFalse(
+ ast.compare(
+ ast.Constant(literal),
+ ast.Constant(same_looking_literal),
+ )
+ )
+
+ def test_compare_fieldless(self):
+ self.assertTrue(ast.compare(ast.Add(), ast.Add()))
+ self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
+
+ def test_compare_modes(self):
+ for mode, sources in (
+ ("exec", exec_tests),
+ ("eval", eval_tests),
+ ("single", single_tests),
+ ):
+ for source in sources:
+ a = ast.parse(source, mode=mode)
+ b = ast.parse(source, mode=mode)
+ self.assertTrue(
+ ast.compare(a, b), f"{ast.dump(a)} != {ast.dump(b)}"
+ )
+
+ def test_compare_attributes_option(self):
+ def parse(a, b):
+ return ast.parse(a), ast.parse(b)
+
+ a, b = parse("2 + 2", "2+2")
+ self.assertTrue(ast.compare(a, b))
+ self.assertTrue(ast.compare(a, b, compare_attributes=False))
+ self.assertFalse(ast.compare(a, b, compare_attributes=True))
+
def test_positional_only_feature_version(self):
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
@@ -1222,6 +1333,7 @@ class AST_Tests(unittest.TestCase):
for node, attr, source in tests:
self.assert_none_check(node, attr, source)
+
class ASTHelpers_Test(unittest.TestCase):
maxDiff = None
@@ -2191,16 +2303,15 @@ class ASTValidatorTests(unittest.TestCase):
@support.requires_resource('cpu')
def test_stdlib_validates(self):
- stdlib = os.path.dirname(ast.__file__)
- tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
- tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
- for module in tests:
+ for module in STDLIB_FILES:
with self.subTest(module):
- fn = os.path.join(stdlib, module)
+ fn = os.path.join(STDLIB, module)
with open(fn, "r", encoding="utf-8") as fp:
source = fp.read()
mod = ast.parse(source, fn)
compile(mod, fn, "exec")
+ mod2 = ast.parse(source, fn)
+ self.assertTrue(ast.compare(mod, mod2))
constant_1 = ast.Constant(1)
pattern_1 = ast.MatchValue(constant_1)
diff --git a/Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst b/Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst
new file mode 100644
index 0000000..b906393
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2020-03-28-21-00-54.bpo-15987.aBL8XS.rst
@@ -0,0 +1,2 @@
+Implemented :func:`ast.compare` for comparing two ASTs. Patch by Batuhan
+Taskaya with some help from Jeremy Hylton.