summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_sqlite3/util.py
blob: 5599823838beea40c9217dc6c49574e91dbfe7e0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import contextlib
import functools
import io
import re
import sqlite3
import test.support


# Helper for temporary memory databases
def memory_database(*args, **kwargs):
    """Return a ``:memory:`` SQLite connection wrapped in contextlib.closing().

    Extra positional and keyword arguments are forwarded to
    sqlite3.connect().  Using the result in a ``with`` statement closes the
    connection when the block exits.
    """
    return contextlib.closing(sqlite3.connect(":memory:", *args, **kwargs))


# Temporarily limit a database connection parameter
@contextlib.contextmanager
def cx_limit(cx, category=sqlite3.SQLITE_LIMIT_SQL_LENGTH, limit=128):
    try:
        _prev = cx.setlimit(category, limit)
        yield limit
    finally:
        cx.setlimit(category, _prev)


def with_tracebacks(exc, regex="", name=""):
    """Decorator factory for tests that exercise callback tracebacks.

    The decorated test method is run twice.  The first run happens inside
    test.support.catch_unraisable_exception() and check_tracebacks(), which
    verify the captured unraisable exception against *exc*, the optional
    *regex* (searched in the exception message) and the optional *name*
    (compared with the ``__name__`` of the recorded object).  The second run
    happens afterwards, outside both context managers.
    """
    def decorator(func):
        # Compile once at decoration time rather than on every test run.
        if regex:
            pattern = re.compile(regex)
        else:
            pattern = None

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            # Pass 1: callback tracebacks enabled and checked.
            with (
                test.support.catch_unraisable_exception() as cm,
                check_tracebacks(self, cm, exc, pattern, name),
            ):
                func(self, *args, **kwargs)

            # Pass 2: plain run, tracebacks no longer enabled.
            func(self, *args, **kwargs)

        return wrapper

    return decorator


@contextlib.contextmanager
def check_tracebacks(self, cm, exc, regex, obj_name):
    """Convenience context manager for testing callback tracebacks.

    Enable sqlite3 callback tracebacks and run the ``with`` body while
    discarding anything written to stderr.  Then assert on the unraisable
    exception recorded by *cm*.

    self: TestCase instance used for the assertions.
    cm: object whose ``unraisable`` attribute holds the captured record
        (e.g. from test.support.catch_unraisable_exception()); may be None
        if nothing was captured.
    exc: expected exception type.
    regex: optional pattern (compiled or string) that must be found in
        ``str()`` of the exception value; skipped when falsy.
    obj_name: optional expected ``__name__`` of the recorded object;
        skipped when falsy.

    Callback tracebacks are disabled again on exit, even on failure.
    """
    sqlite3.enable_callback_tracebacks(True)
    try:
        # Keep anything printed to stderr out of the test output.
        with contextlib.redirect_stderr(io.StringIO()):
            yield

        unraisable = cm.unraisable
        # Report a clear test failure instead of an AttributeError on None
        # when the body never triggered an unraisable exception.
        self.assertIsNotNone(
            unraisable,
            f"no unraisable exception was captured; expected {exc!r}",
        )
        self.assertEqual(unraisable.exc_type, exc)
        if regex:
            # assertRegex reports both pattern and text on failure.
            self.assertRegex(str(unraisable.exc_value), regex)
        if obj_name:
            self.assertEqual(unraisable.object.__name__, obj_name)
    finally:
        sqlite3.enable_callback_tracebacks(False)


class MemoryDatabaseMixin:
    """Test-case mixin that provides a fresh in-memory database per test.

    setUp() opens a ``:memory:`` connection as ``self.con`` and a cursor on
    it as ``self.cur``.  tearDown() closes the cursor and then the
    connection.  ``cx`` and ``cu`` are read-only aliases for the two.
    """

    def setUp(self):
        connection = sqlite3.connect(":memory:")
        self.con = connection
        self.cur = connection.cursor()

    def tearDown(self):
        # Close the cursor before the connection that owns it.
        self.cur.close()
        self.con.close()

    cx = property(lambda self: self.con, doc="Alias for ``self.con``.")

    cu = property(lambda self: self.cur, doc="Alias for ``self.cur``.")