diff options
author | Gregory P. Smith <greg@krypto.org> | 2022-09-05 20:26:09 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-05 20:26:09 (GMT) |
commit | b5e331fdb38684808ffc540d53e8595bdc408b89 (patch) | |
tree | fff15beb4402c977a0a4dc51aaeab8976039650b /Lib | |
parent | 4f100fe9f1c691145e3fa959ef324646e303cdf3 (diff) | |
download | cpython-b5e331fdb38684808ffc540d53e8595bdc408b89.zip cpython-b5e331fdb38684808ffc540d53e8595bdc408b89.tar.gz cpython-b5e331fdb38684808ffc540d53e8595bdc408b89.tar.bz2 |
[3.8] gh-95778: CVE-2020-10735: Prevent DoS by very large int() (#96503)
* Correctly pre-check for int-to-str conversion
Converting a large enough `int` to a decimal string raises `ValueError` as expected. However, the raise comes _after_ the quadratic-time base-conversion algorithm has run to completion. For effective DOS prevention, we need some kind of check before entering the quadratic-time loop. Oops! =)
The quick fix: essentially we catch _most_ values that exceed the threshold up front. Those that slip through will still be on the small side (read: sufficiently fast), and will get caught by the existing check so that the limit remains exact.
The justification for the current check. The C code check is:
```c
max_str_digits / (3 * PyLong_SHIFT) <= (size_a - 11) / 10
```
In GitHub markdown math-speak, writing $M$ for `max_str_digits`, $L$ for `PyLong_SHIFT` and $s$ for `size_a`, that check is:
$$\left\lfloor\frac{M}{3L}\right\rfloor \le \left\lfloor\frac{s - 11}{10}\right\rfloor$$
From this it follows that
$$\frac{M}{3L} < \frac{s-1}{10}$$
hence that
$$\frac{L(s-1)}{M} > \frac{10}{3} > \log_2(10).$$
So
$$2^{L(s-1)} > 10^M.$$
But our input integer $a$ satisfies $|a| \ge 2^{L(s-1)}$, so $|a|$ is larger than $10^M$. This shows that we don't accidentally capture anything _below_ the intended limit in the check.
<!-- gh-issue-number: gh-95778 -->
* Issue: gh-95778
<!-- /gh-issue-number -->
Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
Co-authored-by: Christian Heimes <christian@python.org>
Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/support/__init__.py | 10 | ||||
-rw-r--r-- | Lib/test/test_ast.py | 8 | ||||
-rw-r--r-- | Lib/test/test_cmd_line.py | 33 | ||||
-rw-r--r-- | Lib/test/test_compile.py | 13 | ||||
-rw-r--r-- | Lib/test/test_decimal.py | 18 | ||||
-rw-r--r-- | Lib/test/test_int.py | 196 | ||||
-rw-r--r-- | Lib/test/test_json/test_decode.py | 11 | ||||
-rw-r--r-- | Lib/test/test_sys.py | 10 | ||||
-rw-r--r-- | Lib/test/test_xmlrpc.py | 10 |
9 files changed, 304 insertions, 5 deletions
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index fb09e06..fa5a028 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -3401,3 +3401,13 @@ def skip_if_broken_multiprocessing_synchronize(): synchronize.Lock(ctx=None) except OSError as exc: raise unittest.SkipTest(f"broken multiprocessing SemLock: {exc!r}") + +@contextlib.contextmanager +def adjust_int_max_str_digits(max_digits): + """Temporarily change the integer string conversion length limit.""" + current = sys.get_int_max_str_digits() + try: + sys.set_int_max_str_digits(max_digits) + yield + finally: + sys.set_int_max_str_digits(current) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index c625e69..c67cce1 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -885,6 +885,14 @@ class ASTHelpers_Test(unittest.TestCase): self.assertRaises(ValueError, ast.literal_eval, '+True') self.assertRaises(ValueError, ast.literal_eval, '2+3') + def test_literal_eval_str_int_limit(self): + with support.adjust_int_max_str_digits(4000): + ast.literal_eval('3'*4000) # no error + with self.assertRaises(SyntaxError) as err_ctx: + ast.literal_eval('3'*4001) + self.assertIn('Exceeds the limit ', str(err_ctx.exception)) + self.assertIn(' Consider hexadecimal ', str(err_ctx.exception)) + def test_literal_eval_complex(self): # Issue #4907 self.assertEqual(ast.literal_eval('6j'), 6j) diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index 871a9c7..9f09a50 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -788,6 +788,39 @@ class CmdLineTest(unittest.TestCase): self.assertTrue(proc.stderr.startswith(err_msg), proc.stderr) self.assertNotEqual(proc.returncode, 0) + def test_int_max_str_digits(self): + code = "import sys; print(sys.flags.int_max_str_digits, sys.get_int_max_str_digits())" + + assert_python_failure('-X', 'int_max_str_digits', '-c', code) + assert_python_failure('-X', 'int_max_str_digits=foo', '-c', code) + assert_python_failure('-X', 'int_max_str_digits=100', '-c', code) + + assert_python_failure('-c', code, PYTHONINTMAXSTRDIGITS='foo') + assert_python_failure('-c', code, PYTHONINTMAXSTRDIGITS='100') + + def res2int(res): + out = res.out.strip().decode("utf-8") + return tuple(int(i) for i in out.split()) + + res = assert_python_ok('-c', code) + self.assertEqual(res2int(res), (-1, sys.get_int_max_str_digits())) + res = assert_python_ok('-X', 'int_max_str_digits=0', '-c', code) + self.assertEqual(res2int(res), (0, 0)) + res = assert_python_ok('-X', 'int_max_str_digits=4000', '-c', code) + self.assertEqual(res2int(res), (4000, 4000)) + res = assert_python_ok('-X', 'int_max_str_digits=100000', '-c', code) + self.assertEqual(res2int(res), (100000, 100000)) + + res = assert_python_ok('-c', code, PYTHONINTMAXSTRDIGITS='0') + self.assertEqual(res2int(res), (0, 0)) + res = assert_python_ok('-c', code, PYTHONINTMAXSTRDIGITS='4000') + self.assertEqual(res2int(res), (4000, 4000)) + res = assert_python_ok( + '-X', 'int_max_str_digits=6000', '-c', code, + PYTHONINTMAXSTRDIGITS='4000' + ) + self.assertEqual(res2int(res), (6000, 6000)) + @unittest.skipIf(interpreter_requires_environment(), 'Cannot run -I tests when PYTHON env vars are required.') diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py index 566ca27..abb18c5 100644 --- a/Lib/test/test_compile.py +++ b/Lib/test/test_compile.py @@ -189,6 +189,19 @@ if 1: self.assertEqual(eval("0o777"), 511) self.assertEqual(eval("-0o0000010"), -8) + def test_int_literals_too_long(self): + n = 3000 + source = f"a = 1\nb = 2\nc = {'3'*n}\nd = 4" + with support.adjust_int_max_str_digits(n): + compile(source, "<long_int_pass>", "exec") # no errors. + with support.adjust_int_max_str_digits(n-1): + with self.assertRaises(SyntaxError) as err_ctx: + compile(source, "<long_int_fail>", "exec") + exc = err_ctx.exception + self.assertEqual(exc.lineno, 3) + self.assertIn('Exceeds the limit ', str(exc)) + self.assertIn(' Consider hexadecimal ', str(exc)) + def test_unary_minus(self): # Verify treatment of unary minus on negative numbers SF bug #660455 if sys.maxsize == 2147483647: diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 1f37b53..cfa9e17 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -2446,6 +2446,15 @@ class CUsabilityTest(UsabilityTest): class PyUsabilityTest(UsabilityTest): decimal = P + def setUp(self): + super().setUp() + self._previous_int_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(7000) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_int_limit) + super().tearDown() + class PythonAPItests(unittest.TestCase): def test_abc(self): @@ -4503,6 +4512,15 @@ class CCoverage(Coverage): class PyCoverage(Coverage): decimal = P + def setUp(self): + super().setUp() + self._previous_int_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(7000) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_int_limit) + super().tearDown() + class PyFunctionality(unittest.TestCase): """Extra functionality in decimal.py""" diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index 6fdf52e..cbbddf5 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -1,4 +1,5 @@ import sys +import time import unittest from test import support @@ -571,5 +572,200 @@ class IntTestCases(unittest.TestCase): self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807) +class IntStrDigitLimitsTests(unittest.TestCase): + + int_class = int # Override this in subclasses to reuse the suite. + + def setUp(self): + super().setUp() + self._previous_limit = sys.get_int_max_str_digits() + sys.set_int_max_str_digits(2048) + + def tearDown(self): + sys.set_int_max_str_digits(self._previous_limit) + super().tearDown() + + def test_disabled_limit(self): + self.assertGreater(sys.get_int_max_str_digits(), 0) + self.assertLess(sys.get_int_max_str_digits(), 20_000) + with support.adjust_int_max_str_digits(0): + self.assertEqual(sys.get_int_max_str_digits(), 0) + i = self.int_class('1' * 20_000) + str(i) + self.assertGreater(sys.get_int_max_str_digits(), 0) + + def test_max_str_digits_edge_cases(self): + """Ignore the +/- sign and space padding.""" + int_class = self.int_class + maxdigits = sys.get_int_max_str_digits() + + int_class('1' * maxdigits) + int_class(' ' + '1' * maxdigits) + int_class('1' * maxdigits + ' ') + int_class('+' + '1' * maxdigits) + int_class('-' + '1' * maxdigits) + self.assertEqual(len(str(10 ** (maxdigits - 1))), maxdigits) + + def check(self, i, base=None): + with self.assertRaises(ValueError): + if base is None: + self.int_class(i) + else: + self.int_class(i, base) + + def test_max_str_digits(self): + maxdigits = sys.get_int_max_str_digits() + + self.check('1' * (maxdigits + 1)) + self.check(' ' + '1' * (maxdigits + 1)) + self.check('1' * (maxdigits + 1) + ' ') + self.check('+' + '1' * (maxdigits + 1)) + self.check('-' + '1' * (maxdigits + 1)) + self.check('1' * (maxdigits + 1)) + + i = 10 ** maxdigits + with self.assertRaises(ValueError): + str(i) + + def test_denial_of_service_prevented_int_to_str(self): + """Regression test: ensure we fail before performing O(N**2) work.""" + maxdigits = sys.get_int_max_str_digits() + assert maxdigits < 50_000, maxdigits # A test prerequisite. + get_time = time.process_time + if get_time() <= 0: # some platforms like WASM lack process_time() + get_time = time.monotonic + + huge_int = int(f'0x{"c"*65_000}', base=16) # 78268 decimal digits. + digits = 78_268 + with support.adjust_int_max_str_digits(digits): + start = get_time() + huge_decimal = str(huge_int) + seconds_to_convert = get_time() - start + self.assertEqual(len(huge_decimal), digits) + # Ensuring that we chose a slow enough conversion to measure. + # It takes 0.1 seconds on a Zen based cloud VM in an opt build. + if seconds_to_convert < 0.005: + raise unittest.SkipTest('"slow" conversion took only ' + f'{seconds_to_convert} seconds.') + + # We test with the limit almost at the size needed to check performance. + # The performant limit check is slightly fuzzy, give it a some room. + with support.adjust_int_max_str_digits(int(.995 * digits)): + with self.assertRaises(ValueError) as err: + start = get_time() + str(huge_int) + seconds_to_fail_huge = get_time() - start + self.assertIn('conversion', str(err.exception)) + self.assertLess(seconds_to_fail_huge, seconds_to_convert/8) + + # Now we test that a conversion that would take 30x as long also fails + # in a similarly fast fashion. + extra_huge_int = int(f'0x{"c"*500_000}', base=16) # 602060 digits. + with self.assertRaises(ValueError) as err: + start = get_time() + # If not limited, 8 seconds said Zen based cloud VM. + str(extra_huge_int) + seconds_to_fail_extra_huge = get_time() - start + self.assertIn('conversion', str(err.exception)) + self.assertLess(seconds_to_fail_extra_huge, seconds_to_convert/8) + + def test_denial_of_service_prevented_str_to_int(self): + """Regression test: ensure we fail before performing O(N**2) work.""" + maxdigits = sys.get_int_max_str_digits() + assert maxdigits < 100_000, maxdigits # A test prerequisite. + get_time = time.process_time + if get_time() <= 0: # some platforms like WASM lack process_time() + get_time = time.monotonic + + digits = 133700 + huge = '8'*digits + with support.adjust_int_max_str_digits(digits): + start = get_time() + int(huge) + seconds_to_convert = get_time() - start + # Ensuring that we chose a slow enough conversion to measure. + # It takes 0.1 seconds on a Zen based cloud VM in an opt build. + if seconds_to_convert < 0.005: + raise unittest.SkipTest('"slow" conversion took only ' + f'{seconds_to_convert} seconds.') + + with support.adjust_int_max_str_digits(digits - 1): + with self.assertRaises(ValueError) as err: + start = get_time() + int(huge) + seconds_to_fail_huge = get_time() - start + self.assertIn('conversion', str(err.exception)) + self.assertLess(seconds_to_fail_huge, seconds_to_convert/8) + + # Now we test that a conversion that would take 30x as long also fails + # in a similarly fast fashion. + extra_huge = '7'*1_200_000 + with self.assertRaises(ValueError) as err: + start = get_time() + # If not limited, 8 seconds in the Zen based cloud VM. + int(extra_huge) + seconds_to_fail_extra_huge = get_time() - start + self.assertIn('conversion', str(err.exception)) + self.assertLess(seconds_to_fail_extra_huge, seconds_to_convert/8) + + def test_power_of_two_bases_unlimited(self): + """The limit does not apply to power of 2 bases.""" + maxdigits = sys.get_int_max_str_digits() + + for base in (2, 4, 8, 16, 32): + with self.subTest(base=base): + self.int_class('1' * (maxdigits + 1), base) + assert maxdigits < 100_000 + self.int_class('1' * 100_000, base) + + def test_underscores_ignored(self): + maxdigits = sys.get_int_max_str_digits() + + triples = maxdigits // 3 + s = '111' * triples + s_ = '1_11' * triples + self.int_class(s) # succeeds + self.int_class(s_) # succeeds + self.check(f'{s}111') + self.check(f'{s_}_111') + + def test_sign_not_counted(self): + int_class = self.int_class + max_digits = sys.get_int_max_str_digits() + s = '5' * max_digits + i = int_class(s) + pos_i = int_class(f'+{s}') + assert i == pos_i + neg_i = int_class(f'-{s}') + assert -pos_i == neg_i + str(pos_i) + str(neg_i) + + def _other_base_helper(self, base): + int_class = self.int_class + max_digits = sys.get_int_max_str_digits() + s = '2' * max_digits + i = int_class(s, base) + if base > 10: + with self.assertRaises(ValueError): + str(i) + elif base < 10: + str(i) + with self.assertRaises(ValueError) as err: + int_class(f'{s}1', base) + + def test_int_from_other_bases(self): + base = 3 + with self.subTest(base=base): + self._other_base_helper(base) + base = 36 + with self.subTest(base=base): + self._other_base_helper(base) + + +class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests): + int_class = IntSubclass + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py index 895c95b..124045b 100644 --- a/Lib/test/test_json/test_decode.py +++ b/Lib/test/test_json/test_decode.py @@ -2,6 +2,7 @@ import decimal from io import StringIO from collections import OrderedDict from test.test_json import PyTest, CTest +from test import support class TestDecode: @@ -95,9 +96,13 @@ class TestDecode: d = self.json.JSONDecoder() self.assertRaises(ValueError, d.raw_decode, 'a'*42, -50000) - def test_deprecated_encode(self): - with self.assertWarns(DeprecationWarning): - self.loads('{}', encoding='fake') + def test_limit_int(self): + maxdigits = 5000 + with support.adjust_int_max_str_digits(maxdigits): + self.loads('1' * maxdigits) + with self.assertRaises(ValueError): + self.loads('1' * (maxdigits + 1)) + class TestPyDecode(TestDecode, PyTest): pass class TestCDecode(TestDecode, CTest): pass diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py index 140c65a..581a7d6 100644 --- a/Lib/test/test_sys.py +++ b/Lib/test/test_sys.py @@ -447,11 +447,17 @@ class SysModuleTest(unittest.TestCase): self.assertIsInstance(sys.executable, str) self.assertEqual(len(sys.float_info), 11) self.assertEqual(sys.float_info.radix, 2) - self.assertEqual(len(sys.int_info), 2) + self.assertEqual(len(sys.int_info), 4) self.assertTrue(sys.int_info.bits_per_digit % 5 == 0) self.assertTrue(sys.int_info.sizeof_digit >= 1) + self.assertGreaterEqual(sys.int_info.default_max_str_digits, 500) + self.assertGreaterEqual(sys.int_info.str_digits_check_threshold, 100) + self.assertGreater(sys.int_info.default_max_str_digits, + sys.int_info.str_digits_check_threshold) self.assertEqual(type(sys.int_info.bits_per_digit), int) self.assertEqual(type(sys.int_info.sizeof_digit), int) + self.assertIsInstance(sys.int_info.default_max_str_digits, int) + self.assertIsInstance(sys.int_info.str_digits_check_threshold, int) self.assertIsInstance(sys.hexversion, int) self.assertEqual(len(sys.hash_info), 9) @@ -554,7 +560,7 @@ class SysModuleTest(unittest.TestCase): "inspect", "interactive", "optimize", "dont_write_bytecode", "no_user_site", "no_site", "ignore_environment", "verbose", "bytes_warning", "quiet", "hash_randomization", "isolated", - "dev_mode", "utf8_mode") + "dev_mode", "utf8_mode", "int_max_str_digits") for attr in attrs: self.assertTrue(hasattr(sys.flags, attr), attr) attr_type = bool if attr == "dev_mode" else int diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py index 52bacc1..aaa6707 100644 --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -283,6 +283,16 @@ class XMLRPCTestCase(unittest.TestCase): check('<bigdecimal>9876543210.0123456789</bigdecimal>', decimal.Decimal('9876543210.0123456789')) + def test_limit_int(self): + check = self.check_loads + maxdigits = 5000 + with support.adjust_int_max_str_digits(maxdigits): + s = '1' * (maxdigits + 1) + with self.assertRaises(ValueError): + check(f'<int>{s}</int>', None) + with self.assertRaises(ValueError): + check(f'<biginteger>{s}</biginteger>', None) + def test_get_host_info(self): # see bug #3613, this raised a TypeError transp = xmlrpc.client.Transport() |