summaryrefslogtreecommitdiffstats
path: root/Modules
diff options
context:
space:
mode:
authorMark Dickinson <mdickinson@enthought.com>2019-05-19 16:51:56 (GMT)
committerGitHub <noreply@github.com>2019-05-19 16:51:56 (GMT)
commit5c08ce9bf712acbb3f05a3a57baf51fcb534cdf0 (patch)
treef31a368cd3f31951303766f605111f9589612d44 /Modules
parent7c59362a15dfce538512ff1fce4e07d33a925cfb (diff)
downloadcpython-5c08ce9bf712acbb3f05a3a57baf51fcb534cdf0.zip
cpython-5c08ce9bf712acbb3f05a3a57baf51fcb534cdf0.tar.gz
cpython-5c08ce9bf712acbb3f05a3a57baf51fcb534cdf0.tar.bz2
bpo-36957: Speed up math.isqrt (#13405)
* Add math.isqrt function computing the integer square root. * Code cleanup: remove redundant comments, rename some variables. * Tighten up code a bit more; use Py_XDECREF to simplify error handling. * Update Modules/mathmodule.c Co-Authored-By: Serhiy Storchaka <storchaka@gmail.com> * Update Modules/mathmodule.c Use real argument clinic type instead of an alias Co-Authored-By: Serhiy Storchaka <storchaka@gmail.com> * Add proof sketch * Updates from review. * Correct and expand documentation. * Fix bad reference handling on error; make some variables block-local; other tidying. * Style and consistency fixes. * Add missing error check; don't try to DECREF a NULL a * Simplify some error returns. * Another two test cases: - clarify that floats are rejected even if they happen to be squares of small integers - TypeError beats ValueError for a negative float * Add fast path for small inputs. Needs tests. * Speed up isqrt for n >= 2**64 as well; add extra tests. * Reduce number of test-cases to avoid dominating the run-time of test_math. * Don't perform unnecessary extra iterations when computing c_bit_length. * Abstract common uint64_t code out into a separate function. * Cleanup. * Add a missing Py_DECREF in an error branch. More cleanup. * Update Modules/mathmodule.c Add missing `static` declaration to helper function. Co-Authored-By: Serhiy Storchaka <storchaka@gmail.com> * Add missing backtick.
Diffstat (limited to 'Modules')
-rw-r--r--Modules/mathmodule.c64
1 files changed, 56 insertions, 8 deletions
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 7a0044a..a153e98 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -1620,6 +1620,22 @@ completes the proof sketch.
*/
+
+/* Approximate square root of a large 64-bit integer.
+
+ Given `n` satisfying `2**62 <= n < 2**64`, return `a`
+ satisfying `(a - 1)**2 < n < (a + 1)**2`. */
+
+static uint64_t
+_approximate_isqrt(uint64_t n)
+{
+ uint32_t u = 1U + (n >> 62);
+ u = (u << 1) + (n >> 59) / u;
+ u = (u << 3) + (n >> 53) / u;
+ u = (u << 7) + (n >> 41) / u;
+ return (u << 15) + (n >> 17) / u;
+}
+
/*[clinic input]
math.isqrt
@@ -1633,8 +1649,9 @@ static PyObject *
math_isqrt(PyObject *module, PyObject *n)
/*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/
{
- int a_too_large, s;
+ int a_too_large, c_bit_length;
size_t c, d;
+ uint64_t m, u;
PyObject *a = NULL, *b;
n = PyNumber_Index(n);
@@ -1653,24 +1670,55 @@ math_isqrt(PyObject *module, PyObject *n)
return PyLong_FromLong(0);
}
+ /* c = (n.bit_length() - 1) // 2 */
c = _PyLong_NumBits(n);
if (c == (size_t)(-1)) {
goto error;
}
c = (c - 1U) / 2U;
- /* s = c.bit_length() */
- s = 0;
- while ((c >> s) > 0) {
- ++s;
+ /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a
+ fast, almost branch-free algorithm. In the final correction, we use `u*u
+ - 1 >= m` instead of the simpler `u*u > m` in order to get the correct
+ result in the corner case where `u=2**32`. */
+ if (c <= 31U) {
+ m = (uint64_t)PyLong_AsUnsignedLongLong(n);
+ Py_DECREF(n);
+ if (m == (uint64_t)(-1) && PyErr_Occurred()) {
+ return NULL;
+ }
+ u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c);
+ u -= u * u - 1U >= m;
+ return PyLong_FromUnsignedLongLong((unsigned long long)u);
}
- a = PyLong_FromLong(1);
+ /* Slow path: n >= 2**64. We perform the first five iterations in C integer
+ arithmetic, then switch to using Python long integers. */
+
+ /* From n >= 2**64 it follows that c.bit_length() >= 6. */
+ c_bit_length = 6;
+ while ((c >> c_bit_length) > 0U) {
+ ++c_bit_length;
+ }
+
+ /* Initialise d and a. */
+ d = c >> (c_bit_length - 5);
+ b = _PyLong_Rshift(n, 2U*c - 62U);
+ if (b == NULL) {
+ goto error;
+ }
+ m = (uint64_t)PyLong_AsUnsignedLongLong(b);
+ Py_DECREF(b);
+ if (m == (uint64_t)(-1) && PyErr_Occurred()) {
+ goto error;
+ }
+ u = _approximate_isqrt(m) >> (31U - d);
+ a = PyLong_FromUnsignedLongLong((unsigned long long)u);
if (a == NULL) {
goto error;
}
- d = 0;
- while (--s >= 0) {
+
+ for (int s = c_bit_length - 6; s >= 0; --s) {
PyObject *q;
size_t e = d;