summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBerker Peksag <berker.peksag@gmail.com>2016-03-27 19:39:14 (GMT)
committerBerker Peksag <berker.peksag@gmail.com>2016-03-27 19:39:14 (GMT)
commitfa0f62d6ab3d4acf949bd0160bca16f0f973c323 (patch)
treeab21e061c2c0ce17e9fb65f48d31c35642a95d0b
parentf70fe6f6cfde596264ed6fdd626b8c2964443f3e (diff)
downloadcpython-fa0f62d6ab3d4acf949bd0160bca16f0f973c323.zip
cpython-fa0f62d6ab3d4acf949bd0160bca16f0f973c323.tar.gz
cpython-fa0f62d6ab3d4acf949bd0160bca16f0f973c323.tar.bz2
Issue #23758: Improve num_params docs of create_{function,aggregate} functions
If you pass -1, the callable can take any number of arguments. Added tests to verify the behavior. Initial patch by Cédric Krier.
-rw-r--r--Doc/library/sqlite3.rst8
-rw-r--r--Lib/sqlite3/test/userfunctions.py31
2 files changed, 36 insertions, 3 deletions
diff --git a/Doc/library/sqlite3.rst b/Doc/library/sqlite3.rst
index 4890fc5..b037b45 100644
--- a/Doc/library/sqlite3.rst
+++ b/Doc/library/sqlite3.rst
@@ -324,8 +324,9 @@ Connection Objects
Creates a user-defined function that you can later use from within SQL
statements under the function name *name*. *num_params* is the number of
- parameters the function accepts, and *func* is a Python callable that is called
- as the SQL function.
+ parameters the function accepts (if *num_params* is -1, the function may
+ take any number of arguments), and *func* is a Python callable that is
+ called as the SQL function.
The function can return any of the types supported by SQLite: bytes, str, int,
float and None.
@@ -340,7 +341,8 @@ Connection Objects
Creates a user-defined aggregate function.
The aggregate class must implement a ``step`` method, which accepts the number
- of parameters *num_params*, and a ``finalize`` method which will return the
+ of parameters *num_params* (if *num_params* is -1, the function may take
+ any number of arguments), and a ``finalize`` method which will return the
final result of the aggregate.
The ``finalize`` method can return any of the types supported by SQLite:
diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py
index 69e2ec2..8a4bbab 100644
--- a/Lib/sqlite3/test/userfunctions.py
+++ b/Lib/sqlite3/test/userfunctions.py
@@ -55,6 +55,9 @@ def func_isblob(v):
def func_islonglong(v):
return isinstance(v, int) and v >= 1<<31
+def func(*args):
+ return len(args)
+
class AggrNoStep:
def __init__(self):
pass
@@ -111,6 +114,19 @@ class AggrCheckType:
def finalize(self):
return self.val
+class AggrCheckTypes:
+ def __init__(self):
+ self.val = 0
+
+ def step(self, whichType, *vals):
+ theType = {"str": str, "int": int, "float": float, "None": type(None),
+ "blob": bytes}
+ for val in vals:
+ self.val += int(theType[whichType] is type(val))
+
+ def finalize(self):
+ return self.val
+
class AggrSum:
def __init__(self):
self.val = 0.0
@@ -140,6 +156,7 @@ class FunctionTests(unittest.TestCase):
self.con.create_function("isnone", 1, func_isnone)
self.con.create_function("isblob", 1, func_isblob)
self.con.create_function("islonglong", 1, func_islonglong)
+ self.con.create_function("spam", -1, func)
def tearDown(self):
self.con.close()
@@ -257,6 +274,13 @@ class FunctionTests(unittest.TestCase):
val = cur.fetchone()[0]
self.assertEqual(val, 1)
+ def CheckAnyArguments(self):
+ cur = self.con.cursor()
+ cur.execute("select spam(?, ?)", (1, 2))
+ val = cur.fetchone()[0]
+ self.assertEqual(val, 2)
+
+
class AggregateTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
@@ -279,6 +303,7 @@ class AggregateTests(unittest.TestCase):
self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
self.con.create_aggregate("checkType", 2, AggrCheckType)
+ self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
self.con.create_aggregate("mysum", 1, AggrSum)
def tearDown(self):
@@ -349,6 +374,12 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0]
self.assertEqual(val, 1)
+ def CheckAggrCheckParamsInt(self):
+ cur = self.con.cursor()
+ cur.execute("select checkTypes('int', ?, ?)", (42, 24))
+ val = cur.fetchone()[0]
+ self.assertEqual(val, 2)
+
def CheckAggrCheckParamFloat(self):
cur = self.con.cursor()
cur.execute("select checkType('float', ?)", (3.14,))