summaryrefslogtreecommitdiffstats
path: root/Lib/test
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/test_ast.py121
1 files changed, 116 insertions, 5 deletions
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index 5422c86..8a4374c 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -38,6 +38,9 @@ def to_tuple(t):
result.append(to_tuple(getattr(t, f)))
return tuple(result)
+STDLIB = os.path.dirname(ast.__file__)
+STDLIB_FILES = [fn for fn in os.listdir(STDLIB) if fn.endswith(".py")]
+STDLIB_FILES.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
# These tests are compiled through "exec"
# There should be at least one test per statement
@@ -1066,6 +1069,114 @@ class AST_Tests(unittest.TestCase):
expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)
+ def test_compare_basics(self):
+ self.assertTrue(ast.compare(ast.parse("x = 10"), ast.parse("x = 10")))
+ self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("")))
+ self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("x")))
+ self.assertFalse(
+ ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))
+ )
+
+ def test_compare_modified_ast(self):
+ # The ast API is a bit underspecified. The objects are mutable,
+ # and even _fields and _attributes are mutable. The compare() does
+ # some simple things to accommodate mutability.
+ a = ast.parse("m * x + b", mode="eval")
+ b = ast.parse("m * x + b", mode="eval")
+ self.assertTrue(ast.compare(a, b))
+
+ a._fields = a._fields + ("spam",)
+ a.spam = "Spam"
+ self.assertNotEqual(a._fields, b._fields)
+ self.assertFalse(ast.compare(a, b))
+ self.assertFalse(ast.compare(b, a))
+
+ b._fields = a._fields
+ b.spam = a.spam
+ self.assertTrue(ast.compare(a, b))
+ self.assertTrue(ast.compare(b, a))
+
+ b._attributes = b._attributes + ("eggs",)
+ b.eggs = "eggs"
+ self.assertNotEqual(a._attributes, b._attributes)
+ self.assertFalse(ast.compare(a, b, compare_attributes=True))
+ self.assertFalse(ast.compare(b, a, compare_attributes=True))
+
+ a._attributes = b._attributes
+ a.eggs = b.eggs
+ self.assertTrue(ast.compare(a, b, compare_attributes=True))
+ self.assertTrue(ast.compare(b, a, compare_attributes=True))
+
+ def test_compare_literals(self):
+ constants = (
+ -20,
+ 20,
+ 20.0,
+ 1,
+ 1.0,
+ True,
+ 0,
+ False,
+ frozenset(),
+ tuple(),
+ "ABCD",
+ "abcd",
+ "中文字",
+ 1e1000,
+ -1e1000,
+ )
+ for next_index, constant in enumerate(constants[:-1], 1):
+ next_constant = constants[next_index]
+ with self.subTest(literal=constant, next_literal=next_constant):
+ self.assertTrue(
+ ast.compare(ast.Constant(constant), ast.Constant(constant))
+ )
+ self.assertFalse(
+ ast.compare(
+ ast.Constant(constant), ast.Constant(next_constant)
+ )
+ )
+
+ same_looking_literal_cases = [
+ {1, 1.0, True, 1 + 0j},
+ {0, 0.0, False, 0 + 0j},
+ ]
+ for same_looking_literals in same_looking_literal_cases:
+ for literal in same_looking_literals:
+ for same_looking_literal in same_looking_literals - {literal}:
+ self.assertFalse(
+ ast.compare(
+ ast.Constant(literal),
+ ast.Constant(same_looking_literal),
+ )
+ )
+
+ def test_compare_fieldless(self):
+ self.assertTrue(ast.compare(ast.Add(), ast.Add()))
+ self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
+
+ def test_compare_modes(self):
+ for mode, sources in (
+ ("exec", exec_tests),
+ ("eval", eval_tests),
+ ("single", single_tests),
+ ):
+ for source in sources:
+ a = ast.parse(source, mode=mode)
+ b = ast.parse(source, mode=mode)
+ self.assertTrue(
+ ast.compare(a, b), f"{ast.dump(a)} != {ast.dump(b)}"
+ )
+
+ def test_compare_attributes_option(self):
+ def parse(a, b):
+ return ast.parse(a), ast.parse(b)
+
+ a, b = parse("2 + 2", "2+2")
+ self.assertTrue(ast.compare(a, b))
+ self.assertTrue(ast.compare(a, b, compare_attributes=False))
+ self.assertFalse(ast.compare(a, b, compare_attributes=True))
+
def test_positional_only_feature_version(self):
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
@@ -1222,6 +1333,7 @@ class AST_Tests(unittest.TestCase):
for node, attr, source in tests:
self.assert_none_check(node, attr, source)
+
class ASTHelpers_Test(unittest.TestCase):
maxDiff = None
@@ -2191,16 +2303,15 @@ class ASTValidatorTests(unittest.TestCase):
@support.requires_resource('cpu')
def test_stdlib_validates(self):
- stdlib = os.path.dirname(ast.__file__)
- tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
- tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
- for module in tests:
+ for module in STDLIB_FILES:
with self.subTest(module):
- fn = os.path.join(stdlib, module)
+ fn = os.path.join(STDLIB, module)
with open(fn, "r", encoding="utf-8") as fp:
source = fp.read()
mod = ast.parse(source, fn)
compile(mod, fn, "exec")
+ mod2 = ast.parse(source, fn)
+ self.assertTrue(ast.compare(mod, mod2))
constant_1 = ast.Constant(1)
pattern_1 = ast.MatchValue(constant_1)