summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorBrandt Bucher <brandt@python.org>2022-02-09 23:15:36 (GMT)
committerGitHub <noreply@github.com>2022-02-09 23:15:36 (GMT)
commit78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1 (patch)
treee81f0366f8d8524947e6a5037fc946b98473779c /Lib
parent5a3f97291eea96037cceee097ebc00bba44bc9ed (diff)
downloadcpython-78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1.zip
cpython-78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1.tar.gz
cpython-78ae4cc6dc949e8bc39fab25fea5efe983dc0ad1.tar.bz2
bpo-46528: Attempt SWAPs at compile-time (GH-30970)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/test_peepholer.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py
index 2df5883..6f24b29 100644
--- a/Lib/test/test_peepholer.py
+++ b/Lib/test/test_peepholer.py
@@ -1,10 +1,25 @@
import dis
from itertools import combinations, product
+import textwrap
import unittest
from test.support.bytecode_helper import BytecodeTestCase
+def compile_pattern_with_fast_locals(pattern):
+ source = textwrap.dedent(
+ f"""
+ def f(x):
+ match x:
+ case {pattern}:
+ pass
+ """
+ )
+ namespace = {}
+ exec(source, namespace)
+ return namespace["f"].__code__
+
+
def count_instr_recursively(f, opname):
count = 0
for instr in dis.get_instructions(f):
@@ -580,6 +595,58 @@ class TestTranforms(BytecodeTestCase):
'not all arguments converted during string formatting'):
eval("'%s, %s' % (x, *y)", {'x': 1, 'y': [2, 3]})
+ def test_static_swaps_unpack_two(self):
+ def f(a, b):
+ a, b = a, b
+ b, a = a, b
+ self.assertNotInBytecode(f, "SWAP")
+
+ def test_static_swaps_unpack_three(self):
+ def f(a, b, c):
+ a, b, c = a, b, c
+ a, c, b = a, b, c
+ b, a, c = a, b, c
+ b, c, a = a, b, c
+ c, a, b = a, b, c
+ c, b, a = a, b, c
+ self.assertNotInBytecode(f, "SWAP")
+
+ def test_static_swaps_match_mapping(self):
+ for a, b, c in product("_a", "_b", "_c"):
+ pattern = f"{{'a': {a}, 'b': {b}, 'c': {c}}}"
+ with self.subTest(pattern):
+ code = compile_pattern_with_fast_locals(pattern)
+ self.assertNotInBytecode(code, "SWAP")
+
+ def test_static_swaps_match_class(self):
+ forms = [
+ "C({}, {}, {})",
+ "C({}, {}, c={})",
+ "C({}, b={}, c={})",
+ "C(a={}, b={}, c={})"
+ ]
+ for a, b, c in product("_a", "_b", "_c"):
+ for form in forms:
+ pattern = form.format(a, b, c)
+ with self.subTest(pattern):
+ code = compile_pattern_with_fast_locals(pattern)
+ self.assertNotInBytecode(code, "SWAP")
+
+ def test_static_swaps_match_sequence(self):
+ swaps = {"*_, b, c", "a, *_, c", "a, b, *_"}
+ forms = ["{}, {}, {}", "{}, {}, *{}", "{}, *{}, {}", "*{}, {}, {}"]
+ for a, b, c in product("_a", "_b", "_c"):
+ for form in forms:
+ pattern = form.format(a, b, c)
+ with self.subTest(pattern):
+ code = compile_pattern_with_fast_locals(pattern)
+ if pattern in swaps:
+ # If this fails... great! Remove this pattern from swaps
+ # to prevent regressing on any improvement:
+ self.assertInBytecode(code, "SWAP")
+ else:
+ self.assertNotInBytecode(code, "SWAP")
+
class TestBuglets(unittest.TestCase):