summaryrefslogtreecommitdiffstats
path: root/Lib/sqlite3/test/userfunctions.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/sqlite3/test/userfunctions.py')
-rw-r--r--Lib/sqlite3/test/userfunctions.py330
1 files changed, 330 insertions, 0 deletions
diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py
new file mode 100644
index 0000000..ff7db9c
--- /dev/null
+++ b/Lib/sqlite3/test/userfunctions.py
@@ -0,0 +1,330 @@
+#-*- coding: ISO-8859-1 -*-
+# pysqlite2/test/userfunctions.py: tests for user-defined functions and
+# aggregates.
+#
+# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
+#
+# This file is part of pysqlite.
+#
+# This software is provided 'as-is', without any express or implied
+# warranty. In no event will the authors be held liable for any damages
+# arising from the use of this software.
+#
+# Permission is granted to anyone to use this software for any purpose,
+# including commercial applications, and to alter it and redistribute it
+# freely, subject to the following restrictions:
+#
+# 1. The origin of this software must not be misrepresented; you must not
+# claim that you wrote the original software. If you use this software
+# in a product, an acknowledgment in the product documentation would be
+# appreciated but is not required.
+# 2. Altered source versions must be plainly marked as such, and must not be
+# misrepresented as being the original software.
+# 3. This notice may not be removed or altered from any source distribution.
+
+import unittest
+import sqlite3 as sqlite
+
+def func_returntext():
+ return "foo"
+def func_returnunicode():
+ return u"bar"
+def func_returnint():
+ return 42
+def func_returnfloat():
+ return 3.14
+def func_returnnull():
+ return None
+def func_returnblob():
+ return buffer("blob")
+def func_raiseexception():
+ 5/0
+
+def func_isstring(v):
+ return type(v) is unicode
+def func_isint(v):
+ return type(v) is int
+def func_isfloat(v):
+ return type(v) is float
+def func_isnone(v):
+ return type(v) is type(None)
+def func_isblob(v):
+ return type(v) is buffer
+
+class AggrNoStep:
+ def __init__(self):
+ pass
+
+class AggrNoFinalize:
+ def __init__(self):
+ pass
+
+ def step(self, x):
+ pass
+
+class AggrExceptionInInit:
+ def __init__(self):
+ 5/0
+
+ def step(self, x):
+ pass
+
+ def finalize(self):
+ pass
+
+class AggrExceptionInStep:
+ def __init__(self):
+ pass
+
+ def step(self, x):
+ 5/0
+
+ def finalize(self):
+ return 42
+
+class AggrExceptionInFinalize:
+ def __init__(self):
+ pass
+
+ def step(self, x):
+ pass
+
+ def finalize(self):
+ 5/0
+
+class AggrCheckType:
+ def __init__(self):
+ self.val = None
+
+ def step(self, whichType, val):
+ theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}
+ self.val = int(theType[whichType] is type(val))
+
+ def finalize(self):
+ return self.val
+
+class AggrSum:
+ def __init__(self):
+ self.val = 0.0
+
+ def step(self, val):
+ self.val += val
+
+ def finalize(self):
+ return self.val
+
+class FunctionTests(unittest.TestCase):
+ def setUp(self):
+ self.con = sqlite.connect(":memory:")
+
+ self.con.create_function("returntext", 0, func_returntext)
+ self.con.create_function("returnunicode", 0, func_returnunicode)
+ self.con.create_function("returnint", 0, func_returnint)
+ self.con.create_function("returnfloat", 0, func_returnfloat)
+ self.con.create_function("returnnull", 0, func_returnnull)
+ self.con.create_function("returnblob", 0, func_returnblob)
+ self.con.create_function("raiseexception", 0, func_raiseexception)
+
+ self.con.create_function("isstring", 1, func_isstring)
+ self.con.create_function("isint", 1, func_isint)
+ self.con.create_function("isfloat", 1, func_isfloat)
+ self.con.create_function("isnone", 1, func_isnone)
+ self.con.create_function("isblob", 1, func_isblob)
+
+ def tearDown(self):
+ self.con.close()
+
+ def CheckFuncRefCount(self):
+ def getfunc():
+ def f():
+ return val
+ return f
+ self.con.create_function("reftest", 0, getfunc())
+ cur = self.con.cursor()
+ cur.execute("select reftest()")
+
+ def CheckFuncReturnText(self):
+ cur = self.con.cursor()
+ cur.execute("select returntext()")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(type(val), unicode)
+ self.failUnlessEqual(val, "foo")
+
+ def CheckFuncReturnUnicode(self):
+ cur = self.con.cursor()
+ cur.execute("select returnunicode()")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(type(val), unicode)
+ self.failUnlessEqual(val, u"bar")
+
+ def CheckFuncReturnInt(self):
+ cur = self.con.cursor()
+ cur.execute("select returnint()")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(type(val), int)
+ self.failUnlessEqual(val, 42)
+
+ def CheckFuncReturnFloat(self):
+ cur = self.con.cursor()
+ cur.execute("select returnfloat()")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(type(val), float)
+ if val < 3.139 or val > 3.141:
+ self.fail("wrong value")
+
+ def CheckFuncReturnNull(self):
+ cur = self.con.cursor()
+ cur.execute("select returnnull()")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(type(val), type(None))
+ self.failUnlessEqual(val, None)
+
+ def CheckFuncReturnBlob(self):
+ cur = self.con.cursor()
+ cur.execute("select returnblob()")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(type(val), buffer)
+ self.failUnlessEqual(val, buffer("blob"))
+
+ def CheckFuncException(self):
+ cur = self.con.cursor()
+ cur.execute("select raiseexception()")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, None)
+
+ def CheckParamString(self):
+ cur = self.con.cursor()
+ cur.execute("select isstring(?)", ("foo",))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckParamInt(self):
+ cur = self.con.cursor()
+ cur.execute("select isint(?)", (42,))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckParamFloat(self):
+ cur = self.con.cursor()
+ cur.execute("select isfloat(?)", (3.14,))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckParamNone(self):
+ cur = self.con.cursor()
+ cur.execute("select isnone(?)", (None,))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckParamBlob(self):
+ cur = self.con.cursor()
+ cur.execute("select isblob(?)", (buffer("blob"),))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+class AggregateTests(unittest.TestCase):
+ def setUp(self):
+ self.con = sqlite.connect(":memory:")
+ cur = self.con.cursor()
+ cur.execute("""
+ create table test(
+ t text,
+ i integer,
+ f float,
+ n,
+ b blob
+ )
+ """)
+ cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
+ ("foo", 5, 3.14, None, buffer("blob"),))
+
+ self.con.create_aggregate("nostep", 1, AggrNoStep)
+ self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
+ self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
+ self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
+ self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
+ self.con.create_aggregate("checkType", 2, AggrCheckType)
+ self.con.create_aggregate("mysum", 1, AggrSum)
+
+ def tearDown(self):
+ #self.cur.close()
+ #self.con.close()
+ pass
+
+ def CheckAggrNoStep(self):
+ cur = self.con.cursor()
+ cur.execute("select nostep(t) from test")
+
+ def CheckAggrNoFinalize(self):
+ cur = self.con.cursor()
+ cur.execute("select nofinalize(t) from test")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, None)
+
+ def CheckAggrExceptionInInit(self):
+ cur = self.con.cursor()
+ cur.execute("select excInit(t) from test")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, None)
+
+ def CheckAggrExceptionInStep(self):
+ cur = self.con.cursor()
+ cur.execute("select excStep(t) from test")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 42)
+
+ def CheckAggrExceptionInFinalize(self):
+ cur = self.con.cursor()
+ cur.execute("select excFinalize(t) from test")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, None)
+
+ def CheckAggrCheckParamStr(self):
+ cur = self.con.cursor()
+ cur.execute("select checkType('str', ?)", ("foo",))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckAggrCheckParamInt(self):
+ cur = self.con.cursor()
+ cur.execute("select checkType('int', ?)", (42,))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckAggrCheckParamFloat(self):
+ cur = self.con.cursor()
+ cur.execute("select checkType('float', ?)", (3.14,))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckAggrCheckParamNone(self):
+ cur = self.con.cursor()
+ cur.execute("select checkType('None', ?)", (None,))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckAggrCheckParamBlob(self):
+ cur = self.con.cursor()
+ cur.execute("select checkType('blob', ?)", (buffer("blob"),))
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 1)
+
+ def CheckAggrCheckAggrSum(self):
+ cur = self.con.cursor()
+ cur.execute("delete from test")
+ cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
+ cur.execute("select mysum(i) from test")
+ val = cur.fetchone()[0]
+ self.failUnlessEqual(val, 60)
+
+def suite():
+ function_suite = unittest.makeSuite(FunctionTests, "Check")
+ aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
+ return unittest.TestSuite((function_suite, aggregate_suite))
+
+def test():
+ runner = unittest.TextTestRunner()
+ runner.run(suite())
+
+if __name__ == "__main__":
+ test()