summaryrefslogtreecommitdiffstats
path: root/Objects/longobject.c
diff options
context:
space:
mode:
Diffstat (limited to 'Objects/longobject.c')
-rw-r--r--Objects/longobject.c100
1 files changed, 79 insertions, 21 deletions
diff --git a/Objects/longobject.c b/Objects/longobject.c
index f246bd2..2f6d103 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -12,7 +12,8 @@
* both operands contain more than KARATSUBA_CUTOFF digits (this
* being an internal Python long digit, in base BASE).
*/
-#define KARATSUBA_CUTOFF 35
+#define KARATSUBA_CUTOFF 70
+#define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
#define ABS(x) ((x) < 0 ? -(x) : (x))
@@ -1717,26 +1718,72 @@ x_mul(PyLongObject *a, PyLongObject *b)
return NULL;
memset(z->ob_digit, 0, z->ob_size * sizeof(digit));
- for (i = 0; i < size_a; ++i) {
- twodigits carry = 0;
- twodigits f = a->ob_digit[i];
- int j;
- digit *pz = z->ob_digit + i;
+ if (a == b) {
+ /* Efficient squaring per HAC, Algorithm 14.16:
+ * http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
+ * Gives slightly less than a 2x speedup when a == b,
+ * via exploiting that each entry in the multiplication
+ * pyramid appears twice (except for the size_a squares).
+ */
+ for (i = 0; i < size_a; ++i) {
+ twodigits carry;
+ twodigits f = a->ob_digit[i];
+ digit *pz = z->ob_digit + (i << 1);
+ digit *pa = a->ob_digit + i + 1;
+ digit *paend = a->ob_digit + size_a;
- SIGCHECK({
- Py_DECREF(z);
- return NULL;
- })
- for (j = 0; j < size_b; ++j) {
- carry += *pz + b->ob_digit[j] * f;
- *pz++ = (digit) (carry & MASK);
+ SIGCHECK({
+ Py_DECREF(z);
+ return NULL;
+ })
+
+ carry = *pz + f * f;
+ *pz++ = (digit)(carry & MASK);
carry >>= SHIFT;
+ assert(carry <= MASK);
+
+ /* Now f is added in twice in each column of the
+ * pyramid it appears. Same as adding f<<1 once.
+ */
+ f <<= 1;
+ while (pa < paend) {
+ carry += *pz + *pa++ * f;
+ *pz++ = (digit)(carry & MASK);
+ carry >>= SHIFT;
+ assert(carry <= (MASK << 1));
+ }
+ if (carry) {
+ carry += *pz;
+ *pz++ = (digit)(carry & MASK);
+ carry >>= SHIFT;
+ }
+ if (carry)
+ *pz += (digit)(carry & MASK);
+ assert((carry >> SHIFT) == 0);
}
- for (; carry != 0; ++j) {
- assert(i+j < z->ob_size);
- carry += *pz;
- *pz++ = (digit) (carry & MASK);
- carry >>= SHIFT;
+ }
+ else { /* a is not the same as b -- gradeschool long mult */
+ for (i = 0; i < size_a; ++i) {
+ twodigits carry = 0;
+ twodigits f = a->ob_digit[i];
+ digit *pz = z->ob_digit + i;
+ digit *pb = b->ob_digit;
+ digit *pbend = b->ob_digit + size_b;
+
+ SIGCHECK({
+ Py_DECREF(z);
+ return NULL;
+ })
+
+ while (pb < pbend) {
+ carry += *pz + *pb++ * f;
+ *pz++ = (digit)(carry & MASK);
+ carry >>= SHIFT;
+ assert(carry <= MASK);
+ }
+ if (carry)
+ *pz += (digit)(carry & MASK);
+ assert((carry >> SHIFT) == 0);
}
}
return long_normalize(z);
@@ -1816,7 +1863,8 @@ k_mul(PyLongObject *a, PyLongObject *b)
}
/* Use gradeschool math when either number is too small. */
- if (asize <= KARATSUBA_CUTOFF) {
+ i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF;
+ if (asize <= i) {
if (asize == 0)
return _PyLong_New(0);
else
@@ -1837,7 +1885,13 @@ k_mul(PyLongObject *a, PyLongObject *b)
if (kmul_split(a, shift, &ah, &al) < 0) goto fail;
assert(ah->ob_size > 0); /* the split isn't degenerate */
- if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
+ if (a == b) {
+ bh = ah;
+ bl = al;
+ Py_INCREF(bh);
+ Py_INCREF(bl);
+ }
+ else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail;
/* The plan:
* 1. Allocate result space (asize + bsize digits: that's always
@@ -1906,7 +1960,11 @@ k_mul(PyLongObject *a, PyLongObject *b)
Py_DECREF(al);
ah = al = NULL;
- if ((t2 = x_add(bh, bl)) == NULL) {
+ if (a == b) {
+ t2 = t1;
+ Py_INCREF(t2);
+ }
+ else if ((t2 = x_add(bh, bl)) == NULL) {
Py_DECREF(t1);
goto fail;
}