summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_ast/test_ast.py104
-rw-r--r--Misc/NEWS.d/next/Core and Builtins/2024-08-27-13-16-40.gh-issue-123344.56Or78.rst1
-rw-r--r--Python/ast_opt.c3
3 files changed, 99 insertions, 9 deletions
diff --git a/Lib/test/test_ast/test_ast.py b/Lib/test/test_ast/test_ast.py
index c37b24a..77596ec 100644
--- a/Lib/test/test_ast/test_ast.py
+++ b/Lib/test/test_ast/test_ast.py
@@ -3062,8 +3062,8 @@ class ASTOptimiziationTests(unittest.TestCase):
def wrap_expr(self, expr):
return ast.Module(body=[ast.Expr(value=expr)])
- def wrap_for(self, for_statement):
- return ast.Module(body=[for_statement])
+ def wrap_statement(self, statement):
+ return ast.Module(body=[statement])
def assert_ast(self, code, non_optimized_target, optimized_target):
non_optimized_tree = ast.parse(code, optimize=-1)
@@ -3090,16 +3090,16 @@ class ASTOptimiziationTests(unittest.TestCase):
f"{ast.dump(optimized_tree)}",
)
+ def create_binop(self, operand, left=ast.Constant(1), right=ast.Constant(1)):
+ return ast.BinOp(left=left, op=self.binop[operand], right=right)
+
def test_folding_binop(self):
code = "1 %s 1"
operators = self.binop.keys()
- def create_binop(operand, left=ast.Constant(1), right=ast.Constant(1)):
- return ast.BinOp(left=left, op=self.binop[operand], right=right)
-
for op in operators:
result_code = code % op
- non_optimized_target = self.wrap_expr(create_binop(op))
+ non_optimized_target = self.wrap_expr(self.create_binop(op))
optimized_target = self.wrap_expr(ast.Constant(value=eval(result_code)))
with self.subTest(
@@ -3111,7 +3111,7 @@ class ASTOptimiziationTests(unittest.TestCase):
# Multiplication of constant tuples must be folded
code = "(1,) * 3"
- non_optimized_target = self.wrap_expr(create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
+ non_optimized_target = self.wrap_expr(self.create_binop("*", ast.Tuple(elts=[ast.Constant(value=1)]), ast.Constant(value=3)))
optimized_target = self.wrap_expr(ast.Constant(eval(code)))
self.assert_ast(code, non_optimized_target, optimized_target)
@@ -3222,12 +3222,12 @@ class ASTOptimiziationTests(unittest.TestCase):
]
for left, right, ast_cls, optimized_iter in braces:
- non_optimized_target = self.wrap_for(ast.For(
+ non_optimized_target = self.wrap_statement(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast_cls(elts=[ast.Constant(1)]),
body=[ast.Pass()]
))
- optimized_target = self.wrap_for(ast.For(
+ optimized_target = self.wrap_statement(ast.For(
target=ast.Name(id="_", ctx=ast.Store()),
iter=ast.Constant(value=optimized_iter),
body=[ast.Pass()]
@@ -3245,6 +3245,92 @@ class ASTOptimiziationTests(unittest.TestCase):
self.assert_ast(code, non_optimized_target, optimized_target)
+ def test_folding_type_param_in_function_def(self):
+ code = "def foo[%s = 1 + 1](): pass"
+
+ unoptimized_binop = self.create_binop("+")
+ unoptimized_type_params = [
+ ("T", "T", ast.TypeVar),
+ ("**P", "P", ast.ParamSpec),
+ ("*Ts", "Ts", ast.TypeVarTuple),
+ ]
+
+ for type, name, type_param in unoptimized_type_params:
+ result_code = code % type
+ optimized_target = self.wrap_statement(
+ ast.FunctionDef(
+ name='foo',
+ args=ast.arguments(),
+ body=[ast.Pass()],
+ type_params=[type_param(name=name, default_value=ast.Constant(2))]
+ )
+ )
+ non_optimized_target = self.wrap_statement(
+ ast.FunctionDef(
+ name='foo',
+ args=ast.arguments(),
+ body=[ast.Pass()],
+ type_params=[type_param(name=name, default_value=unoptimized_binop)]
+ )
+ )
+ self.assert_ast(result_code, non_optimized_target, optimized_target)
+
+ def test_folding_type_param_in_class_def(self):
+ code = "class foo[%s = 1 + 1]: pass"
+
+ unoptimized_binop = self.create_binop("+")
+ unoptimized_type_params = [
+ ("T", "T", ast.TypeVar),
+ ("**P", "P", ast.ParamSpec),
+ ("*Ts", "Ts", ast.TypeVarTuple),
+ ]
+
+ for type, name, type_param in unoptimized_type_params:
+ result_code = code % type
+ optimized_target = self.wrap_statement(
+ ast.ClassDef(
+ name='foo',
+ body=[ast.Pass()],
+ type_params=[type_param(name=name, default_value=ast.Constant(2))]
+ )
+ )
+ non_optimized_target = self.wrap_statement(
+ ast.ClassDef(
+ name='foo',
+ body=[ast.Pass()],
+ type_params=[type_param(name=name, default_value=unoptimized_binop)]
+ )
+ )
+ self.assert_ast(result_code, non_optimized_target, optimized_target)
+
+ def test_folding_type_param_in_type_alias(self):
+ code = "type foo[%s = 1 + 1] = 1"
+
+ unoptimized_binop = self.create_binop("+")
+ unoptimized_type_params = [
+ ("T", "T", ast.TypeVar),
+ ("**P", "P", ast.ParamSpec),
+ ("*Ts", "Ts", ast.TypeVarTuple),
+ ]
+
+ for type, name, type_param in unoptimized_type_params:
+ result_code = code % type
+ optimized_target = self.wrap_statement(
+ ast.TypeAlias(
+ name=ast.Name(id='foo', ctx=ast.Store()),
+ type_params=[type_param(name=name, default_value=ast.Constant(2))],
+ value=ast.Constant(value=1),
+ )
+ )
+ non_optimized_target = self.wrap_statement(
+ ast.TypeAlias(
+ name=ast.Name(id='foo', ctx=ast.Store()),
+ type_params=[type_param(name=name, default_value=unoptimized_binop)],
+ value=ast.Constant(value=1),
+ )
+ )
+ self.assert_ast(result_code, non_optimized_target, optimized_target)
+
if __name__ == "__main__":
unittest.main()
diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-08-27-13-16-40.gh-issue-123344.56Or78.rst b/Misc/NEWS.d/next/Core and Builtins/2024-08-27-13-16-40.gh-issue-123344.56Or78.rst
new file mode 100644
index 0000000..b8b373d
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2024-08-27-13-16-40.gh-issue-123344.56Or78.rst
@@ -0,0 +1 @@
+Add AST optimizations for type parameter defaults.
diff --git a/Python/ast_opt.c b/Python/ast_opt.c
index d7a26e6..503715e 100644
--- a/Python/ast_opt.c
+++ b/Python/ast_opt.c
@@ -1087,10 +1087,13 @@ astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *stat
switch (node_->kind) {
case TypeVar_kind:
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
+ CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.default_value);
break;
case ParamSpec_kind:
+ CALL_OPT(astfold_expr, expr_ty, node_->v.ParamSpec.default_value);
break;
case TypeVarTuple_kind:
+ CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVarTuple.default_value);
break;
}
return 1;