summaryrefslogtreecommitdiffstats
path: root/Lib/ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ast.py')
-rw-r--r--Lib/ast.py71
1 files changed, 71 insertions, 0 deletions
diff --git a/Lib/ast.py b/Lib/ast.py
index d7e51ab..031bab4 100644
--- a/Lib/ast.py
+++ b/Lib/ast.py
@@ -401,6 +401,77 @@ def walk(node):
yield node
+def compare(
+ a,
+ b,
+ /,
+ *,
+ compare_attributes=False,
+):
+ """Recursively compares two ASTs.
+
+ compare_attributes affects whether AST attributes are considered
+ in the comparison. If compare_attributes is False (default), then
+ attributes are ignored. Otherwise they must all be equal. This
+ option is useful to check whether the ASTs are structurally equal but
+ might differ in whitespace or similar details.
+ """
+
+ def _compare(a, b):
+ # Compare two fields on an AST object, which may themselves be
+ # AST objects, lists of AST objects, or primitive ASDL types
+ # like identifiers and constants.
+ if isinstance(a, AST):
+ return compare(
+ a,
+ b,
+ compare_attributes=compare_attributes,
+ )
+ elif isinstance(a, list):
+ # If a field is repeated, then both objects will represent
+ # the value as a list.
+ if len(a) != len(b):
+ return False
+ for a_item, b_item in zip(a, b):
+ if not _compare(a_item, b_item):
+ return False
+ else:
+ return True
+ else:
+ return type(a) is type(b) and a == b
+
+ def _compare_fields(a, b):
+ if a._fields != b._fields:
+ return False
+ for field in a._fields:
+ a_field = getattr(a, field)
+ b_field = getattr(b, field)
+ if not _compare(a_field, b_field):
+ return False
+ else:
+ return True
+
+ def _compare_attributes(a, b):
+ if a._attributes != b._attributes:
+ return False
+ # Attributes are always ints.
+ for attr in a._attributes:
+ a_attr = getattr(a, attr)
+ b_attr = getattr(b, attr)
+ if a_attr != b_attr:
+ return False
+ else:
+ return True
+
+ if type(a) is not type(b):
+ return False
+ if not _compare_fields(a, b):
+ return False
+ if compare_attributes and not _compare_attributes(a, b):
+ return False
+ return True
+
+
class NodeVisitor(object):
"""
A node visitor base class that walks the abstract syntax tree and calls a