diff options
author | Erlend Egeberg Aasland <erlend.aasland@innova.no> | 2021-11-29 15:22:32 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-29 15:22:32 (GMT) |
commit | c4a69a4ad035513ada1c0d41a46723606b538e13 (patch) | |
tree | 29ef43642bc01b65bfa9305d3e3c74952eff2d01 /Lib/test/test_sqlite3/test_userfunctions.py | |
parent | 6ac3c8a3140c17bd71ba98dfc5250c371101e77c (diff) | |
download | cpython-c4a69a4ad035513ada1c0d41a46723606b538e13.zip cpython-c4a69a4ad035513ada1c0d41a46723606b538e13.tar.gz cpython-c4a69a4ad035513ada1c0d41a46723606b538e13.tar.bz2 |
bpo-45828: Use unraisable exceptions within sqlite3 callbacks (FH-29591)
Diffstat (limited to 'Lib/test/test_sqlite3/test_userfunctions.py')
-rw-r--r-- | Lib/test/test_sqlite3/test_userfunctions.py | 59 |
1 files changed, 32 insertions, 27 deletions
diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py index 62a11a5..996437b 100644 --- a/Lib/test/test_sqlite3/test_userfunctions.py +++ b/Lib/test/test_sqlite3/test_userfunctions.py @@ -25,46 +25,52 @@ import contextlib import functools import gc import io +import re import sys import unittest import unittest.mock import sqlite3 as sqlite -from test.support import bigmemtest +from test.support import bigmemtest, catch_unraisable_exception from .test_dbapi import cx_limit -def with_tracebacks(strings, traceback=True): +def with_tracebacks(exc, regex="", name=""): """Convenience decorator for testing callback tracebacks.""" - if traceback: - strings.append('Traceback') - def decorator(func): + _regex = re.compile(regex) if regex else None @functools.wraps(func) def wrapper(self, *args, **kwargs): - # First, run the test with traceback enabled. - with check_tracebacks(self, strings): - func(self, *args, **kwargs) + with catch_unraisable_exception() as cm: + # First, run the test with traceback enabled. + with check_tracebacks(self, cm, exc, _regex, name): + func(self, *args, **kwargs) # Then run the test with traceback disabled. func(self, *args, **kwargs) return wrapper return decorator + @contextlib.contextmanager -def check_tracebacks(self, strings): +def check_tracebacks(self, cm, exc, regex, obj_name): """Convenience context manager for testing callback tracebacks.""" sqlite.enable_callback_tracebacks(True) try: buf = io.StringIO() with contextlib.redirect_stderr(buf): yield - tb = buf.getvalue() - for s in strings: - self.assertIn(s, tb) + + self.assertEqual(cm.unraisable.exc_type, exc) + if regex: + msg = str(cm.unraisable.exc_value) + self.assertIsNotNone(regex.search(msg)) + if obj_name: + self.assertEqual(cm.unraisable.object.__name__, obj_name) finally: sqlite.enable_callback_tracebacks(False) + def func_returntext(): return "foo" def func_returntextwithnull(): @@ -299,7 +305,7 @@ class FunctionTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(val, 1<<31) - @with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError']) + @with_tracebacks(ZeroDivisionError, name="func_raiseexception") def test_func_exception(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -307,14 +313,14 @@ class FunctionTests(unittest.TestCase): cur.fetchone() self.assertEqual(str(cm.exception), 'user-defined function raised exception') - @with_tracebacks(['func_memoryerror', 'MemoryError']) + @with_tracebacks(MemoryError, name="func_memoryerror") def test_func_memory_error(self): cur = self.con.cursor() with self.assertRaises(MemoryError): cur.execute("select memoryerror()") cur.fetchone() - @with_tracebacks(['func_overflowerror', 'OverflowError']) + @with_tracebacks(OverflowError, name="func_overflowerror") def test_func_overflow_error(self): cur = self.con.cursor() with self.assertRaises(sqlite.DataError): @@ -426,22 +432,21 @@ class FunctionTests(unittest.TestCase): del x,y gc.collect() + @with_tracebacks(OverflowError) def test_func_return_too_large_int(self): cur = self.con.cursor() for value in 2**63, -2**63-1, 2**64: self.con.create_function("largeint", 0, lambda value=value: value) - with check_tracebacks(self, ['OverflowError']): - with self.assertRaises(sqlite.DataError): - cur.execute("select largeint()") + with self.assertRaises(sqlite.DataError): + cur.execute("select largeint()") + @with_tracebacks(UnicodeEncodeError, "surrogates not allowed", "chr") def test_func_return_text_with_surrogates(self): cur = self.con.cursor() self.con.create_function("pychr", 1, chr) for value in 0xd8ff, 0xdcff: - with check_tracebacks(self, - ['UnicodeEncodeError', 'surrogates not allowed']): - with self.assertRaises(sqlite.OperationalError): - cur.execute("select pychr(?)", (value,)) + with self.assertRaises(sqlite.OperationalError): + cur.execute("select pychr(?)", (value,)) @unittest.skipUnless(sys.maxsize > 2**32, 'requires 64bit platform') @bigmemtest(size=2**31, memuse=3, dry_run=False) @@ -510,7 +515,7 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") - @with_tracebacks(['__init__', '5/0', 'ZeroDivisionError']) + @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit") def test_aggr_exception_in_init(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -518,7 +523,7 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") - @with_tracebacks(['step', '5/0', 'ZeroDivisionError']) + @with_tracebacks(ZeroDivisionError, name="AggrExceptionInStep") def test_aggr_exception_in_step(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -526,7 +531,7 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") - @with_tracebacks(['finalize', '5/0', 'ZeroDivisionError']) + @with_tracebacks(ZeroDivisionError, name="AggrExceptionInFinalize") def test_aggr_exception_in_finalize(self): cur = self.con.cursor() with self.assertRaises(sqlite.OperationalError) as cm: @@ -643,11 +648,11 @@ class AuthorizerRaiseExceptionTests(AuthorizerTests): raise ValueError return sqlite.SQLITE_OK - @with_tracebacks(['authorizer_cb', 'ValueError']) + @with_tracebacks(ValueError, name="authorizer_cb") def test_table_access(self): super().test_table_access() - @with_tracebacks(['authorizer_cb', 'ValueError']) + @with_tracebacks(ValueError, name="authorizer_cb") def test_column_access(self): super().test_table_access() |