diff options
Diffstat (limited to 'Lib/sqlite3/test/userfunctions.py')
-rw-r--r-- | Lib/sqlite3/test/userfunctions.py | 107 |
1 files changed, 88 insertions, 19 deletions
diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 78656e7..31bf289 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -55,6 +55,9 @@ class AggrNoStep: def __init__(self): pass + def finalize(self): + return 1 + class AggrNoFinalize: def __init__(self): pass @@ -144,9 +147,12 @@ class FunctionTests(unittest.TestCase): def CheckFuncRefCount(self): def getfunc(): def f(): - return val + return 1 return f - self.con.create_function("reftest", 0, getfunc()) + f = getfunc() + globals()["foo"] = f + # self.con.create_function("reftest", 0, getfunc()) + self.con.create_function("reftest", 0, f) cur = self.con.cursor() cur.execute("select reftest()") @@ -195,9 +201,12 @@ class FunctionTests(unittest.TestCase): def CheckFuncException(self): cur = self.con.cursor() - cur.execute("select raiseexception()") - val = cur.fetchone()[0] - self.failUnlessEqual(val, None) + try: + cur.execute("select raiseexception()") + cur.fetchone() + self.fail("should have raised OperationalError") + except sqlite.OperationalError, e: + self.failUnlessEqual(e.args[0], 'user-defined function raised exception') def CheckParamString(self): cur = self.con.cursor() @@ -267,31 +276,47 @@ class AggregateTests(unittest.TestCase): def CheckAggrNoStep(self): cur = self.con.cursor() - cur.execute("select nostep(t) from test") + try: + cur.execute("select nostep(t) from test") + self.fail("should have raised an AttributeError") + except AttributeError, e: + self.failUnlessEqual(e.args[0], "AggrNoStep instance has no attribute 'step'") def CheckAggrNoFinalize(self): cur = self.con.cursor() - cur.execute("select nofinalize(t) from test") - val = cur.fetchone()[0] - self.failUnlessEqual(val, None) + try: + cur.execute("select nofinalize(t) from test") + val = cur.fetchone()[0] + self.fail("should have raised an OperationalError") + except sqlite.OperationalError, e: + self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error") def CheckAggrExceptionInInit(self): cur = self.con.cursor() - cur.execute("select excInit(t) from test") - val = cur.fetchone()[0] - self.failUnlessEqual(val, None) + try: + cur.execute("select excInit(t) from test") + val = cur.fetchone()[0] + self.fail("should have raised an OperationalError") + except sqlite.OperationalError, e: + self.failUnlessEqual(e.args[0], "user-defined aggregate's '__init__' method raised error") def CheckAggrExceptionInStep(self): cur = self.con.cursor() - cur.execute("select excStep(t) from test") - val = cur.fetchone()[0] - self.failUnlessEqual(val, 42) + try: + cur.execute("select excStep(t) from test") + val = cur.fetchone()[0] + self.fail("should have raised an OperationalError") + except sqlite.OperationalError, e: + self.failUnlessEqual(e.args[0], "user-defined aggregate's 'step' method raised error") def CheckAggrExceptionInFinalize(self): cur = self.con.cursor() - cur.execute("select excFinalize(t) from test") - val = cur.fetchone()[0] - self.failUnlessEqual(val, None) + try: + cur.execute("select excFinalize(t) from test") + val = cur.fetchone()[0] + self.fail("should have raised an OperationalError") + except sqlite.OperationalError, e: + self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error") def CheckAggrCheckParamStr(self): cur = self.con.cursor() @@ -331,10 +356,54 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.failUnlessEqual(val, 60) +def authorizer_cb(action, arg1, arg2, dbname, source): + if action != sqlite.SQLITE_SELECT: + return sqlite.SQLITE_DENY + if arg2 == 'c2' or arg1 == 't2': + return sqlite.SQLITE_DENY + return sqlite.SQLITE_OK + +class AuthorizerTests(unittest.TestCase): + def setUp(self): + self.con = sqlite.connect(":memory:") + self.con.executescript(""" + create table t1 (c1, c2); + create table t2 (c1, c2); + insert into t1 (c1, c2) values (1, 2); + insert into t2 (c1, c2) values (4, 5); + """) + + # For our security test: + self.con.execute("select c2 from t2") + + self.con.set_authorizer(authorizer_cb) + + def tearDown(self): + pass + + def CheckTableAccess(self): + try: + self.con.execute("select * from t2") + except sqlite.DatabaseError, e: + if not e.args[0].endswith("prohibited"): + self.fail("wrong exception text: %s" % e.args[0]) + return + self.fail("should have raised an exception due to missing privileges") + + def CheckColumnAccess(self): + try: + self.con.execute("select c2 from t1") + except sqlite.DatabaseError, e: + if not e.args[0].endswith("prohibited"): + self.fail("wrong exception text: %s" % e.args[0]) + return + self.fail("should have raised an exception due to missing privileges") + def suite(): function_suite = unittest.makeSuite(FunctionTests, "Check") aggregate_suite = unittest.makeSuite(AggregateTests, "Check") - return unittest.TestSuite((function_suite, aggregate_suite)) + authorizer_suite = unittest.makeSuite(AuthorizerTests, "Check") + return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite)) def test(): runner = unittest.TextTestRunner() |