summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorGregory P. Smith <greg@krypto.org>2022-09-05 20:26:09 (GMT)
committerGitHub <noreply@github.com>2022-09-05 20:26:09 (GMT)
commitb5e331fdb38684808ffc540d53e8595bdc408b89 (patch)
treefff15beb4402c977a0a4dc51aaeab8976039650b /Lib
parent4f100fe9f1c691145e3fa959ef324646e303cdf3 (diff)
downloadcpython-b5e331fdb38684808ffc540d53e8595bdc408b89.zip
cpython-b5e331fdb38684808ffc540d53e8595bdc408b89.tar.gz
cpython-b5e331fdb38684808ffc540d53e8595bdc408b89.tar.bz2
[3.8] gh-95778: CVE-2020-10735: Prevent DoS by very large int() (#96503)
* Correctly pre-check for int-to-str conversion Converting a large enough `int` to a decimal string raises `ValueError` as expected. However, the raise comes _after_ the quadratic-time base-conversion algorithm has run to completion. For effective DOS prevention, we need some kind of check before entering the quadratic-time loop. Oops! =) The quick fix: essentially we catch _most_ values that exceed the threshold up front. Those that slip through will still be on the small side (read: sufficiently fast), and will get caught by the existing check so that the limit remains exact. The justification for the current check. The C code check is: ```c max_str_digits / (3 * PyLong_SHIFT) <= (size_a - 11) / 10 ``` In GitHub markdown math-speak, writing $M$ for `max_str_digits`, $L$ for `PyLong_SHIFT` and $s$ for `size_a`, that check is: $$\left\lfloor\frac{M}{3L}\right\rfloor \le \left\lfloor\frac{s - 11}{10}\right\rfloor$$ From this it follows that $$\frac{M}{3L} < \frac{s-1}{10}$$ hence that $$\frac{L(s-1)}{M} > \frac{10}{3} > \log_2(10).$$ So $$2^{L(s-1)} > 10^M.$$ But our input integer $a$ satisfies $|a| \ge 2^{L(s-1)}$, so $|a|$ is larger than $10^M$. This shows that we don't accidentally capture anything _below_ the intended limit in the check. <!-- gh-issue-number: gh-95778 --> * Issue: gh-95778 <!-- /gh-issue-number --> Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org> Co-authored-by: Christian Heimes <christian@python.org> Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
Diffstat (limited to 'Lib')
-rw-r--r--Lib/test/support/__init__.py10
-rw-r--r--Lib/test/test_ast.py8
-rw-r--r--Lib/test/test_cmd_line.py33
-rw-r--r--Lib/test/test_compile.py13
-rw-r--r--Lib/test/test_decimal.py18
-rw-r--r--Lib/test/test_int.py196
-rw-r--r--Lib/test/test_json/test_decode.py11
-rw-r--r--Lib/test/test_sys.py10
-rw-r--r--Lib/test/test_xmlrpc.py10
9 files changed, 304 insertions, 5 deletions
diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index fb09e06..fa5a028 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -3401,3 +3401,13 @@ def skip_if_broken_multiprocessing_synchronize():
synchronize.Lock(ctx=None)
except OSError as exc:
raise unittest.SkipTest(f"broken multiprocessing SemLock: {exc!r}")
+
+@contextlib.contextmanager
+def adjust_int_max_str_digits(max_digits):
+ """Temporarily change the integer string conversion length limit."""
+ current = sys.get_int_max_str_digits()
+ try:
+ sys.set_int_max_str_digits(max_digits)
+ yield
+ finally:
+ sys.set_int_max_str_digits(current)
diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py
index c625e69..c67cce1 100644
--- a/Lib/test/test_ast.py
+++ b/Lib/test/test_ast.py
@@ -885,6 +885,14 @@ class ASTHelpers_Test(unittest.TestCase):
self.assertRaises(ValueError, ast.literal_eval, '+True')
self.assertRaises(ValueError, ast.literal_eval, '2+3')
+ def test_literal_eval_str_int_limit(self):
+ with support.adjust_int_max_str_digits(4000):
+ ast.literal_eval('3'*4000) # no error
+ with self.assertRaises(SyntaxError) as err_ctx:
+ ast.literal_eval('3'*4001)
+ self.assertIn('Exceeds the limit ', str(err_ctx.exception))
+ self.assertIn(' Consider hexadecimal ', str(err_ctx.exception))
+
def test_literal_eval_complex(self):
# Issue #4907
self.assertEqual(ast.literal_eval('6j'), 6j)
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index 871a9c7..9f09a50 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -788,6 +788,39 @@ class CmdLineTest(unittest.TestCase):
self.assertTrue(proc.stderr.startswith(err_msg), proc.stderr)
self.assertNotEqual(proc.returncode, 0)
+ def test_int_max_str_digits(self):
+ code = "import sys; print(sys.flags.int_max_str_digits, sys.get_int_max_str_digits())"
+
+ assert_python_failure('-X', 'int_max_str_digits', '-c', code)
+ assert_python_failure('-X', 'int_max_str_digits=foo', '-c', code)
+ assert_python_failure('-X', 'int_max_str_digits=100', '-c', code)
+
+ assert_python_failure('-c', code, PYTHONINTMAXSTRDIGITS='foo')
+ assert_python_failure('-c', code, PYTHONINTMAXSTRDIGITS='100')
+
+ def res2int(res):
+ out = res.out.strip().decode("utf-8")
+ return tuple(int(i) for i in out.split())
+
+ res = assert_python_ok('-c', code)
+ self.assertEqual(res2int(res), (-1, sys.get_int_max_str_digits()))
+ res = assert_python_ok('-X', 'int_max_str_digits=0', '-c', code)
+ self.assertEqual(res2int(res), (0, 0))
+ res = assert_python_ok('-X', 'int_max_str_digits=4000', '-c', code)
+ self.assertEqual(res2int(res), (4000, 4000))
+ res = assert_python_ok('-X', 'int_max_str_digits=100000', '-c', code)
+ self.assertEqual(res2int(res), (100000, 100000))
+
+ res = assert_python_ok('-c', code, PYTHONINTMAXSTRDIGITS='0')
+ self.assertEqual(res2int(res), (0, 0))
+ res = assert_python_ok('-c', code, PYTHONINTMAXSTRDIGITS='4000')
+ self.assertEqual(res2int(res), (4000, 4000))
+ res = assert_python_ok(
+ '-X', 'int_max_str_digits=6000', '-c', code,
+ PYTHONINTMAXSTRDIGITS='4000'
+ )
+ self.assertEqual(res2int(res), (6000, 6000))
+
@unittest.skipIf(interpreter_requires_environment(),
'Cannot run -I tests when PYTHON env vars are required.')
diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py
index 566ca27..abb18c5 100644
--- a/Lib/test/test_compile.py
+++ b/Lib/test/test_compile.py
@@ -189,6 +189,19 @@ if 1:
self.assertEqual(eval("0o777"), 511)
self.assertEqual(eval("-0o0000010"), -8)
+ def test_int_literals_too_long(self):
+ n = 3000
+ source = f"a = 1\nb = 2\nc = {'3'*n}\nd = 4"
+ with support.adjust_int_max_str_digits(n):
+ compile(source, "<long_int_pass>", "exec") # no errors.
+ with support.adjust_int_max_str_digits(n-1):
+ with self.assertRaises(SyntaxError) as err_ctx:
+ compile(source, "<long_int_fail>", "exec")
+ exc = err_ctx.exception
+ self.assertEqual(exc.lineno, 3)
+ self.assertIn('Exceeds the limit ', str(exc))
+ self.assertIn(' Consider hexadecimal ', str(exc))
+
def test_unary_minus(self):
# Verify treatment of unary minus on negative numbers SF bug #660455
if sys.maxsize == 2147483647:
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 1f37b53..cfa9e17 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -2446,6 +2446,15 @@ class CUsabilityTest(UsabilityTest):
class PyUsabilityTest(UsabilityTest):
decimal = P
+ def setUp(self):
+ super().setUp()
+ self._previous_int_limit = sys.get_int_max_str_digits()
+ sys.set_int_max_str_digits(7000)
+
+ def tearDown(self):
+ sys.set_int_max_str_digits(self._previous_int_limit)
+ super().tearDown()
+
class PythonAPItests(unittest.TestCase):
def test_abc(self):
@@ -4503,6 +4512,15 @@ class CCoverage(Coverage):
class PyCoverage(Coverage):
decimal = P
+ def setUp(self):
+ super().setUp()
+ self._previous_int_limit = sys.get_int_max_str_digits()
+ sys.set_int_max_str_digits(7000)
+
+ def tearDown(self):
+ sys.set_int_max_str_digits(self._previous_int_limit)
+ super().tearDown()
+
class PyFunctionality(unittest.TestCase):
"""Extra functionality in decimal.py"""
diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py
index 6fdf52e..cbbddf5 100644
--- a/Lib/test/test_int.py
+++ b/Lib/test/test_int.py
@@ -1,4 +1,5 @@
import sys
+import time
import unittest
from test import support
@@ -571,5 +572,200 @@ class IntTestCases(unittest.TestCase):
self.assertEqual(int('1_2_3_4_5_6_7', 32), 1144132807)
+class IntStrDigitLimitsTests(unittest.TestCase):
+
+ int_class = int # Override this in subclasses to reuse the suite.
+
+ def setUp(self):
+ super().setUp()
+ self._previous_limit = sys.get_int_max_str_digits()
+ sys.set_int_max_str_digits(2048)
+
+ def tearDown(self):
+ sys.set_int_max_str_digits(self._previous_limit)
+ super().tearDown()
+
+ def test_disabled_limit(self):
+ self.assertGreater(sys.get_int_max_str_digits(), 0)
+ self.assertLess(sys.get_int_max_str_digits(), 20_000)
+ with support.adjust_int_max_str_digits(0):
+ self.assertEqual(sys.get_int_max_str_digits(), 0)
+ i = self.int_class('1' * 20_000)
+ str(i)
+ self.assertGreater(sys.get_int_max_str_digits(), 0)
+
+ def test_max_str_digits_edge_cases(self):
+ """Ignore the +/- sign and space padding."""
+ int_class = self.int_class
+ maxdigits = sys.get_int_max_str_digits()
+
+ int_class('1' * maxdigits)
+ int_class(' ' + '1' * maxdigits)
+ int_class('1' * maxdigits + ' ')
+ int_class('+' + '1' * maxdigits)
+ int_class('-' + '1' * maxdigits)
+ self.assertEqual(len(str(10 ** (maxdigits - 1))), maxdigits)
+
+ def check(self, i, base=None):
+ with self.assertRaises(ValueError):
+ if base is None:
+ self.int_class(i)
+ else:
+ self.int_class(i, base)
+
+ def test_max_str_digits(self):
+ maxdigits = sys.get_int_max_str_digits()
+
+ self.check('1' * (maxdigits + 1))
+ self.check(' ' + '1' * (maxdigits + 1))
+ self.check('1' * (maxdigits + 1) + ' ')
+ self.check('+' + '1' * (maxdigits + 1))
+ self.check('-' + '1' * (maxdigits + 1))
+ self.check('1' * (maxdigits + 1))
+
+ i = 10 ** maxdigits
+ with self.assertRaises(ValueError):
+ str(i)
+
+ def test_denial_of_service_prevented_int_to_str(self):
+ """Regression test: ensure we fail before performing O(N**2) work."""
+ maxdigits = sys.get_int_max_str_digits()
+ assert maxdigits < 50_000, maxdigits # A test prerequisite.
+ get_time = time.process_time
+ if get_time() <= 0: # some platforms like WASM lack process_time()
+ get_time = time.monotonic
+
+ huge_int = int(f'0x{"c"*65_000}', base=16) # 78268 decimal digits.
+ digits = 78_268
+ with support.adjust_int_max_str_digits(digits):
+ start = get_time()
+ huge_decimal = str(huge_int)
+ seconds_to_convert = get_time() - start
+ self.assertEqual(len(huge_decimal), digits)
+ # Ensuring that we chose a slow enough conversion to measure.
+ # It takes 0.1 seconds on a Zen based cloud VM in an opt build.
+ if seconds_to_convert < 0.005:
+ raise unittest.SkipTest('"slow" conversion took only '
+ f'{seconds_to_convert} seconds.')
+
+ # We test with the limit almost at the size needed to check performance.
+ # The performant limit check is slightly fuzzy, give it a some room.
+ with support.adjust_int_max_str_digits(int(.995 * digits)):
+ with self.assertRaises(ValueError) as err:
+ start = get_time()
+ str(huge_int)
+ seconds_to_fail_huge = get_time() - start
+ self.assertIn('conversion', str(err.exception))
+ self.assertLess(seconds_to_fail_huge, seconds_to_convert/8)
+
+ # Now we test that a conversion that would take 30x as long also fails
+ # in a similarly fast fashion.
+ extra_huge_int = int(f'0x{"c"*500_000}', base=16) # 602060 digits.
+ with self.assertRaises(ValueError) as err:
+ start = get_time()
+ # If not limited, 8 seconds said Zen based cloud VM.
+ str(extra_huge_int)
+ seconds_to_fail_extra_huge = get_time() - start
+ self.assertIn('conversion', str(err.exception))
+ self.assertLess(seconds_to_fail_extra_huge, seconds_to_convert/8)
+
+ def test_denial_of_service_prevented_str_to_int(self):
+ """Regression test: ensure we fail before performing O(N**2) work."""
+ maxdigits = sys.get_int_max_str_digits()
+ assert maxdigits < 100_000, maxdigits # A test prerequisite.
+ get_time = time.process_time
+ if get_time() <= 0: # some platforms like WASM lack process_time()
+ get_time = time.monotonic
+
+ digits = 133700
+ huge = '8'*digits
+ with support.adjust_int_max_str_digits(digits):
+ start = get_time()
+ int(huge)
+ seconds_to_convert = get_time() - start
+ # Ensuring that we chose a slow enough conversion to measure.
+ # It takes 0.1 seconds on a Zen based cloud VM in an opt build.
+ if seconds_to_convert < 0.005:
+ raise unittest.SkipTest('"slow" conversion took only '
+ f'{seconds_to_convert} seconds.')
+
+ with support.adjust_int_max_str_digits(digits - 1):
+ with self.assertRaises(ValueError) as err:
+ start = get_time()
+ int(huge)
+ seconds_to_fail_huge = get_time() - start
+ self.assertIn('conversion', str(err.exception))
+ self.assertLess(seconds_to_fail_huge, seconds_to_convert/8)
+
+ # Now we test that a conversion that would take 30x as long also fails
+ # in a similarly fast fashion.
+ extra_huge = '7'*1_200_000
+ with self.assertRaises(ValueError) as err:
+ start = get_time()
+ # If not limited, 8 seconds in the Zen based cloud VM.
+ int(extra_huge)
+ seconds_to_fail_extra_huge = get_time() - start
+ self.assertIn('conversion', str(err.exception))
+ self.assertLess(seconds_to_fail_extra_huge, seconds_to_convert/8)
+
+ def test_power_of_two_bases_unlimited(self):
+ """The limit does not apply to power of 2 bases."""
+ maxdigits = sys.get_int_max_str_digits()
+
+ for base in (2, 4, 8, 16, 32):
+ with self.subTest(base=base):
+ self.int_class('1' * (maxdigits + 1), base)
+ assert maxdigits < 100_000
+ self.int_class('1' * 100_000, base)
+
+ def test_underscores_ignored(self):
+ maxdigits = sys.get_int_max_str_digits()
+
+ triples = maxdigits // 3
+ s = '111' * triples
+ s_ = '1_11' * triples
+ self.int_class(s) # succeeds
+ self.int_class(s_) # succeeds
+ self.check(f'{s}111')
+ self.check(f'{s_}_111')
+
+ def test_sign_not_counted(self):
+ int_class = self.int_class
+ max_digits = sys.get_int_max_str_digits()
+ s = '5' * max_digits
+ i = int_class(s)
+ pos_i = int_class(f'+{s}')
+ assert i == pos_i
+ neg_i = int_class(f'-{s}')
+ assert -pos_i == neg_i
+ str(pos_i)
+ str(neg_i)
+
+ def _other_base_helper(self, base):
+ int_class = self.int_class
+ max_digits = sys.get_int_max_str_digits()
+ s = '2' * max_digits
+ i = int_class(s, base)
+ if base > 10:
+ with self.assertRaises(ValueError):
+ str(i)
+ elif base < 10:
+ str(i)
+ with self.assertRaises(ValueError) as err:
+ int_class(f'{s}1', base)
+
+ def test_int_from_other_bases(self):
+ base = 3
+ with self.subTest(base=base):
+ self._other_base_helper(base)
+ base = 36
+ with self.subTest(base=base):
+ self._other_base_helper(base)
+
+
+class IntSubclassStrDigitLimitsTests(IntStrDigitLimitsTests):
+ int_class = IntSubclass
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py
index 895c95b..124045b 100644
--- a/Lib/test/test_json/test_decode.py
+++ b/Lib/test/test_json/test_decode.py
@@ -2,6 +2,7 @@ import decimal
from io import StringIO
from collections import OrderedDict
from test.test_json import PyTest, CTest
+from test import support
class TestDecode:
@@ -95,9 +96,13 @@ class TestDecode:
d = self.json.JSONDecoder()
self.assertRaises(ValueError, d.raw_decode, 'a'*42, -50000)
- def test_deprecated_encode(self):
- with self.assertWarns(DeprecationWarning):
- self.loads('{}', encoding='fake')
+ def test_limit_int(self):
+ maxdigits = 5000
+ with support.adjust_int_max_str_digits(maxdigits):
+ self.loads('1' * maxdigits)
+ with self.assertRaises(ValueError):
+ self.loads('1' * (maxdigits + 1))
+
class TestPyDecode(TestDecode, PyTest): pass
class TestCDecode(TestDecode, CTest): pass
diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py
index 140c65a..581a7d6 100644
--- a/Lib/test/test_sys.py
+++ b/Lib/test/test_sys.py
@@ -447,11 +447,17 @@ class SysModuleTest(unittest.TestCase):
self.assertIsInstance(sys.executable, str)
self.assertEqual(len(sys.float_info), 11)
self.assertEqual(sys.float_info.radix, 2)
- self.assertEqual(len(sys.int_info), 2)
+ self.assertEqual(len(sys.int_info), 4)
self.assertTrue(sys.int_info.bits_per_digit % 5 == 0)
self.assertTrue(sys.int_info.sizeof_digit >= 1)
+ self.assertGreaterEqual(sys.int_info.default_max_str_digits, 500)
+ self.assertGreaterEqual(sys.int_info.str_digits_check_threshold, 100)
+ self.assertGreater(sys.int_info.default_max_str_digits,
+ sys.int_info.str_digits_check_threshold)
self.assertEqual(type(sys.int_info.bits_per_digit), int)
self.assertEqual(type(sys.int_info.sizeof_digit), int)
+ self.assertIsInstance(sys.int_info.default_max_str_digits, int)
+ self.assertIsInstance(sys.int_info.str_digits_check_threshold, int)
self.assertIsInstance(sys.hexversion, int)
self.assertEqual(len(sys.hash_info), 9)
@@ -554,7 +560,7 @@ class SysModuleTest(unittest.TestCase):
"inspect", "interactive", "optimize", "dont_write_bytecode",
"no_user_site", "no_site", "ignore_environment", "verbose",
"bytes_warning", "quiet", "hash_randomization", "isolated",
- "dev_mode", "utf8_mode")
+ "dev_mode", "utf8_mode", "int_max_str_digits")
for attr in attrs:
self.assertTrue(hasattr(sys.flags, attr), attr)
attr_type = bool if attr == "dev_mode" else int
diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py
index 52bacc1..aaa6707 100644
--- a/Lib/test/test_xmlrpc.py
+++ b/Lib/test/test_xmlrpc.py
@@ -283,6 +283,16 @@ class XMLRPCTestCase(unittest.TestCase):
check('<bigdecimal>9876543210.0123456789</bigdecimal>',
decimal.Decimal('9876543210.0123456789'))
+ def test_limit_int(self):
+ check = self.check_loads
+ maxdigits = 5000
+ with support.adjust_int_max_str_digits(maxdigits):
+ s = '1' * (maxdigits + 1)
+ with self.assertRaises(ValueError):
+ check(f'<int>{s}</int>', None)
+ with self.assertRaises(ValueError):
+ check(f'<biginteger>{s}</biginteger>', None)
+
def test_get_host_info(self):
# see bug #3613, this raised a TypeError
transp = xmlrpc.client.Transport()