diff options
author | Brandt Bucher <brandt@python.org> | 2022-02-09 23:15:36 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-09 23:15:36 (GMT) |
commit | 78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1 (patch) | |
tree | e81f0366f8d8524947e6a5037fc946b98473779c /Lib | |
parent | 5a3f97291eea96037cceee097ebc00bba44bc9ed (diff) | |
download | cpython-78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1.zip cpython-78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1.tar.gz cpython-78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1.tar.bz2 |
bpo-46528: Attempt SWAPs at compile-time (GH-30970)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_peepholer.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py index 2df5883..6f24b29 100644 --- a/Lib/test/test_peepholer.py +++ b/Lib/test/test_peepholer.py @@ -1,10 +1,25 @@ import dis from itertools import combinations, product +import textwrap import unittest from test.support.bytecode_helper import BytecodeTestCase +def compile_pattern_with_fast_locals(pattern): + source = textwrap.dedent( + f""" + def f(x): + match x: + case {pattern}: + pass + """ + ) + namespace = {} + exec(source, namespace) + return namespace["f"].__code__ + + def count_instr_recursively(f, opname): count = 0 for instr in dis.get_instructions(f): @@ -580,6 +595,58 @@ class TestTranforms(BytecodeTestCase): 'not all arguments converted during string formatting'): eval("'%s, %s' % (x, *y)", {'x': 1, 'y': [2, 3]}) + def test_static_swaps_unpack_two(self): + def f(a, b): + a, b = a, b + b, a = a, b + self.assertNotInBytecode(f, "SWAP") + + def test_static_swaps_unpack_three(self): + def f(a, b, c): + a, b, c = a, b, c + a, c, b = a, b, c + b, a, c = a, b, c + b, c, a = a, b, c + c, a, b = a, b, c + c, b, a = a, b, c + self.assertNotInBytecode(f, "SWAP") + + def test_static_swaps_match_mapping(self): + for a, b, c in product("_a", "_b", "_c"): + pattern = f"{{'a': {a}, 'b': {b}, 'c': {c}}}" + with self.subTest(pattern): + code = compile_pattern_with_fast_locals(pattern) + self.assertNotInBytecode(code, "SWAP") + + def test_static_swaps_match_class(self): + forms = [ + "C({}, {}, {})", + "C({}, {}, c={})", + "C({}, b={}, c={})", + "C(a={}, b={}, c={})" + ] + for a, b, c in product("_a", "_b", "_c"): + for form in forms: + pattern = form.format(a, b, c) + with self.subTest(pattern): + code = compile_pattern_with_fast_locals(pattern) + self.assertNotInBytecode(code, "SWAP") + + def test_static_swaps_match_sequence(self): + swaps = {"*_, b, c", "a, *_, c", "a, b, *_"} + forms = ["{}, {}, {}", "{}, {}, *{}", "{}, *{}, {}", "*{}, {}, {}"] + for a, b, c in product("_a", "_b", "_c"): + for form in forms: + pattern = form.format(a, b, c) + with self.subTest(pattern): + code = compile_pattern_with_fast_locals(pattern) + if pattern in swaps: + # If this fails... great! Remove this pattern from swaps + # to prevent regressing on any improvement: + self.assertInBytecode(code, "SWAP") + else: + self.assertNotInBytecode(code, "SWAP") + class TestBuglets(unittest.TestCase): |