summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlex Waygood <Alex.Waygood@Gmail.com>2023-08-08 12:12:49 (GMT)
committerGitHub <noreply@github.com>2023-08-08 12:12:49 (GMT)
commit7c5153de5a2bd2c886173a317f116885a925cfce (patch)
tree91d0b7fc4f83dae5fee488702bbb2911717394a8
parent5df8b0d5c71a168a94fb64ad9d8190377b6e73da (diff)
downloadcpython-7c5153de5a2bd2c886173a317f116885a925cfce.zip
cpython-7c5153de5a2bd2c886173a317f116885a925cfce.tar.gz
cpython-7c5153de5a2bd2c886173a317f116885a925cfce.tar.bz2
gh-106368: Argument clinic: add tests for more failure paths (#107731)
-rw-r--r--Lib/test/test_clinic.py108
-rwxr-xr-xTools/clinic/clinic.py5
2 files changed, 97 insertions, 16 deletions
diff --git a/Lib/test/test_clinic.py b/Lib/test/test_clinic.py
index d13d862..6c2411f 100644
--- a/Lib/test/test_clinic.py
+++ b/Lib/test/test_clinic.py
@@ -45,6 +45,7 @@ def _expect_failure(tc, parser, code, errmsg, *, filename=None, lineno=None):
tc.assertEqual(cm.exception.filename, filename)
if lineno is not None:
tc.assertEqual(cm.exception.lineno, lineno)
+ return cm.exception
class ClinicWholeFileTest(TestCase):
@@ -222,6 +223,15 @@ class ClinicWholeFileTest(TestCase):
last_line.startswith("/*[clinic end generated code: output=")
)
+ def test_directive_wrong_arg_number(self):
+ raw = dedent("""
+ /*[clinic input]
+ preserve foo bar baz eggs spam ham mushrooms
+ [clinic start generated code]*/
+ """)
+ err = "takes 1 positional argument but 8 were given"
+ self.expect_failure(raw, err)
+
def test_unknown_destination_command(self):
raw = """
/*[clinic input]
@@ -600,6 +610,31 @@ class ClinicWholeFileTest(TestCase):
self.expect_failure(block, err, lineno=2)
+class ParseFileUnitTest(TestCase):
+ def expect_parsing_failure(
+ self, *, filename, expected_error, verify=True, output=None
+ ):
+ errmsg = re.escape(dedent(expected_error).strip())
+ with self.assertRaisesRegex(clinic.ClinicError, errmsg):
+ clinic.parse_file(filename)
+
+ def test_parse_file_no_extension(self) -> None:
+ self.expect_parsing_failure(
+ filename="foo",
+ expected_error="Can't extract file type for file 'foo'"
+ )
+
+ def test_parse_file_strange_extension(self) -> None:
+ filenames_to_errors = {
+ "foo.rs": "Can't identify file type for file 'foo.rs'",
+ "foo.hs": "Can't identify file type for file 'foo.hs'",
+ "foo.js": "Can't identify file type for file 'foo.js'",
+ }
+ for filename, errmsg in filenames_to_errors.items():
+ with self.subTest(filename=filename):
+ self.expect_parsing_failure(filename=filename, expected_error=errmsg)
+
+
class ClinicGroupPermuterTest(TestCase):
def _test(self, l, m, r, output):
computed = clinic.permute_optional_groups(l, m, r)
@@ -794,8 +829,8 @@ class ClinicParserTest(TestCase):
return s[function_index]
def expect_failure(self, block, err, *, filename=None, lineno=None):
- _expect_failure(self, self.parse_function, block, err,
- filename=filename, lineno=lineno)
+ return _expect_failure(self, self.parse_function, block, err,
+ filename=filename, lineno=lineno)
def checkDocstring(self, fn, expected):
self.assertTrue(hasattr(fn, "docstring"))
@@ -877,6 +912,41 @@ class ClinicParserTest(TestCase):
"""
self.expect_failure(block, err, lineno=2)
+ def test_param_with_bizarre_default_fails_correctly(self):
+ template = """
+ module os
+ os.access
+ follow_symlinks: int = {default}
+ """
+ err = "Unsupported expression as default value"
+ for bad_default_value in (
+ "{1, 2, 3}",
+ "3 if bool() else 4",
+ "[x for x in range(42)]"
+ ):
+ with self.subTest(bad_default=bad_default_value):
+ block = template.format(default=bad_default_value)
+ self.expect_failure(block, err, lineno=2)
+
+ def test_unspecified_not_allowed_as_default_value(self):
+ block = """
+ module os
+ os.access
+ follow_symlinks: int(c_default='MAXSIZE') = unspecified
+ """
+ err = "'unspecified' is not a legal default value!"
+ exc = self.expect_failure(block, err, lineno=2)
+ self.assertNotIn('Malformed expression given as default value', str(exc))
+
+ def test_malformed_expression_as_default_value(self):
+ block = """
+ module os
+ os.access
+ follow_symlinks: int(c_default='MAXSIZE') = 1/0
+ """
+ err = "Malformed expression given as default value"
+ self.expect_failure(block, err, lineno=2)
+
def test_param_default_expr_binop(self):
err = (
"When you specify an expression ('a + b') as your default value, "
@@ -1041,6 +1111,28 @@ class ClinicParserTest(TestCase):
""")
self.assertEqual("os_stat_fn", function.c_basename)
+ def test_base_invalid_syntax(self):
+ block = """
+ module os
+ os.stat
+ invalid syntax: int = 42
+ """
+ err = dedent(r"""
+ Function 'stat' has an invalid parameter declaration:
+ \s+'invalid syntax: int = 42'
+ """).strip()
+ with self.assertRaisesRegex(clinic.ClinicError, err):
+ self.parse_function(block)
+
+ def test_param_default_invalid_syntax(self):
+ block = """
+ module os
+ os.stat
+ x: int = invalid syntax
+ """
+ err = r"Syntax error: 'x = invalid syntax\n'"
+ self.expect_failure(block, err, lineno=2)
+
def test_cloning_nonexistent_function_correctly_fails(self):
block = """
cloned = fooooooooooooooooo
@@ -1414,18 +1506,6 @@ class ClinicParserTest(TestCase):
with self.subTest(block=block):
self.expect_failure(block, err)
- def test_parameters_required_after_depr_star(self):
- dataset = (
- "module foo\nfoo.bar\n * [from 3.14]",
- "module foo\nfoo.bar\n * [from 3.14]\nDocstring here.",
- "module foo\nfoo.bar\n this: int\n * [from 3.14]",
- "module foo\nfoo.bar\n this: int\n * [from 3.14]\nDocstring.",
- )
- err = "Function 'foo.bar' specifies '* [from 3.14]' without any parameters afterwards."
- for block in dataset:
- with self.subTest(block=block):
- self.expect_failure(block, err)
-
def test_depr_star_invalid_format_1(self):
block = """
module foo
diff --git a/Tools/clinic/clinic.py b/Tools/clinic/clinic.py
index c6cf43a..0b336d9 100755
--- a/Tools/clinic/clinic.py
+++ b/Tools/clinic/clinic.py
@@ -5207,13 +5207,14 @@ class DSLParser:
# but at least make an attempt at ensuring it's a valid expression.
try:
value = eval(default)
- if value is unspecified:
- fail("'unspecified' is not a legal default value!")
except NameError:
pass # probably a named constant
except Exception as e:
fail("Malformed expression given as default value "
f"{default!r} caused {e!r}")
+ else:
+ if value is unspecified:
+ fail("'unspecified' is not a legal default value!")
if bad:
fail(f"Unsupported expression as default value: {default!r}")