diff options
author | Berker Peksag <berker.peksag@gmail.com> | 2016-03-27 19:39:14 (GMT) |
---|---|---|
committer | Berker Peksag <berker.peksag@gmail.com> | 2016-03-27 19:39:14 (GMT) |
commit | fa0f62d6ab3d4acf949bd0160bca16f0f973c323 (patch) | |
tree | ab21e061c2c0ce17e9fb65f48d31c35642a95d0b | |
parent | f70fe6f6cfde596264ed6fdd626b8c2964443f3e (diff) | |
download | cpython-fa0f62d6ab3d4acf949bd0160bca16f0f973c323.zip cpython-fa0f62d6ab3d4acf949bd0160bca16f0f973c323.tar.gz cpython-fa0f62d6ab3d4acf949bd0160bca16f0f973c323.tar.bz2 |
Issue #23758: Improve num_params docs of create_{function,aggregate} functions
If you pass -1, the callable can take any number of arguments.
Added tests to verify the behavior.
Initial patch by Cédric Krier.
-rw-r--r-- | Doc/library/sqlite3.rst | 8 | ||||
-rw-r--r-- | Lib/sqlite3/test/userfunctions.py | 31 |
2 files changed, 36 insertions, 3 deletions
diff --git a/Doc/library/sqlite3.rst b/Doc/library/sqlite3.rst index 4890fc5..b037b45 100644 --- a/Doc/library/sqlite3.rst +++ b/Doc/library/sqlite3.rst @@ -324,8 +324,9 @@ Connection Objects Creates a user-defined function that you can later use from within SQL statements under the function name *name*. *num_params* is the number of - parameters the function accepts, and *func* is a Python callable that is called - as the SQL function. + parameters the function accepts (if *num_params* is -1, the function may + take any number of arguments), and *func* is a Python callable that is + called as the SQL function. The function can return any of the types supported by SQLite: bytes, str, int, float and None. @@ -340,7 +341,8 @@ Connection Objects Creates a user-defined aggregate function. The aggregate class must implement a ``step`` method, which accepts the number - of parameters *num_params*, and a ``finalize`` method which will return the + of parameters *num_params* (if *num_params* is -1, the function may take + any number of arguments), and a ``finalize`` method which will return the final result of the aggregate. The ``finalize`` method can return any of the types supported by SQLite: diff --git a/Lib/sqlite3/test/userfunctions.py b/Lib/sqlite3/test/userfunctions.py index 69e2ec2..8a4bbab 100644 --- a/Lib/sqlite3/test/userfunctions.py +++ b/Lib/sqlite3/test/userfunctions.py @@ -55,6 +55,9 @@ def func_isblob(v): def func_islonglong(v): return isinstance(v, int) and v >= 1<<31 +def func(*args): + return len(args) + class AggrNoStep: def __init__(self): pass @@ -111,6 +114,19 @@ class AggrCheckType: def finalize(self): return self.val +class AggrCheckTypes: + def __init__(self): + self.val = 0 + + def step(self, whichType, *vals): + theType = {"str": str, "int": int, "float": float, "None": type(None), + "blob": bytes} + for val in vals: + self.val += int(theType[whichType] is type(val)) + + def finalize(self): + return self.val + class AggrSum: def __init__(self): self.val = 0.0 @@ -140,6 +156,7 @@ class FunctionTests(unittest.TestCase): self.con.create_function("isnone", 1, func_isnone) self.con.create_function("isblob", 1, func_isblob) self.con.create_function("islonglong", 1, func_islonglong) + self.con.create_function("spam", -1, func) def tearDown(self): self.con.close() @@ -257,6 +274,13 @@ class FunctionTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(val, 1) + def CheckAnyArguments(self): + cur = self.con.cursor() + cur.execute("select spam(?, ?)", (1, 2)) + val = cur.fetchone()[0] + self.assertEqual(val, 2) + + class AggregateTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") @@ -279,6 +303,7 @@ class AggregateTests(unittest.TestCase): self.con.create_aggregate("excStep", 1, AggrExceptionInStep) self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize) self.con.create_aggregate("checkType", 2, AggrCheckType) + self.con.create_aggregate("checkTypes", -1, AggrCheckTypes) self.con.create_aggregate("mysum", 1, AggrSum) def tearDown(self): @@ -349,6 +374,12 @@ class AggregateTests(unittest.TestCase): val = cur.fetchone()[0] self.assertEqual(val, 1) + def CheckAggrCheckParamsInt(self): + cur = self.con.cursor() + cur.execute("select checkTypes('int', ?, ?)", (42, 24)) + val = cur.fetchone()[0] + self.assertEqual(val, 2) + def CheckAggrCheckParamFloat(self): cur = self.con.cursor() cur.execute("select checkType('float', ?)", (3.14,)) |