diff options
Diffstat (limited to 'Objects/longobject.c')
-rw-r--r-- | Objects/longobject.c | 100 |
1 files changed, 79 insertions, 21 deletions
diff --git a/Objects/longobject.c b/Objects/longobject.c index f246bd2..2f6d103 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -12,7 +12,8 @@ * both operands contain more than KARATSUBA_CUTOFF digits (this * being an internal Python long digit, in base BASE). */ -#define KARATSUBA_CUTOFF 35 +#define KARATSUBA_CUTOFF 70 +#define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF) #define ABS(x) ((x) < 0 ? -(x) : (x)) @@ -1717,26 +1718,72 @@ x_mul(PyLongObject *a, PyLongObject *b) return NULL; memset(z->ob_digit, 0, z->ob_size * sizeof(digit)); - for (i = 0; i < size_a; ++i) { - twodigits carry = 0; - twodigits f = a->ob_digit[i]; - int j; - digit *pz = z->ob_digit + i; + if (a == b) { + /* Efficient squaring per HAC, Algorithm 14.16: + * http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf + * Gives slightly less than a 2x speedup when a == b, + * via exploiting that each entry in the multiplication + * pyramid appears twice (except for the size_a squares). + */ + for (i = 0; i < size_a; ++i) { + twodigits carry; + twodigits f = a->ob_digit[i]; + digit *pz = z->ob_digit + (i << 1); + digit *pa = a->ob_digit + i + 1; + digit *paend = a->ob_digit + size_a; - SIGCHECK({ - Py_DECREF(z); - return NULL; - }) - for (j = 0; j < size_b; ++j) { - carry += *pz + b->ob_digit[j] * f; - *pz++ = (digit) (carry & MASK); + SIGCHECK({ + Py_DECREF(z); + return NULL; + }) + + carry = *pz + f * f; + *pz++ = (digit)(carry & MASK); carry >>= SHIFT; + assert(carry <= MASK); + + /* Now f is added in twice in each column of the + * pyramid it appears. Same as adding f<<1 once. + */ + f <<= 1; + while (pa < paend) { + carry += *pz + *pa++ * f; + *pz++ = (digit)(carry & MASK); + carry >>= SHIFT; + assert(carry <= (MASK << 1)); + } + if (carry) { + carry += *pz; + *pz++ = (digit)(carry & MASK); + carry >>= SHIFT; + } + if (carry) + *pz += (digit)(carry & MASK); + assert((carry >> SHIFT) == 0); } - for (; carry != 0; ++j) { - assert(i+j < z->ob_size); - carry += *pz; - *pz++ = (digit) (carry & MASK); - carry >>= SHIFT; + } + else { /* a is not the same as b -- gradeschool long mult */ + for (i = 0; i < size_a; ++i) { + twodigits carry = 0; + twodigits f = a->ob_digit[i]; + digit *pz = z->ob_digit + i; + digit *pb = b->ob_digit; + digit *pbend = b->ob_digit + size_b; + + SIGCHECK({ + Py_DECREF(z); + return NULL; + }) + + while (pb < pbend) { + carry += *pz + *pb++ * f; + *pz++ = (digit)(carry & MASK); + carry >>= SHIFT; + assert(carry <= MASK); + } + if (carry) + *pz += (digit)(carry & MASK); + assert((carry >> SHIFT) == 0); } } return long_normalize(z); @@ -1816,7 +1863,8 @@ k_mul(PyLongObject *a, PyLongObject *b) } /* Use gradeschool math when either number is too small. */ - if (asize <= KARATSUBA_CUTOFF) { + i = a == b ? KARATSUBA_SQUARE_CUTOFF : KARATSUBA_CUTOFF; + if (asize <= i) { if (asize == 0) return _PyLong_New(0); else @@ -1837,7 +1885,13 @@ k_mul(PyLongObject *a, PyLongObject *b) if (kmul_split(a, shift, &ah, &al) < 0) goto fail; assert(ah->ob_size > 0); /* the split isn't degenerate */ - if (kmul_split(b, shift, &bh, &bl) < 0) goto fail; + if (a == b) { + bh = ah; + bl = al; + Py_INCREF(bh); + Py_INCREF(bl); + } + else if (kmul_split(b, shift, &bh, &bl) < 0) goto fail; /* The plan: * 1. Allocate result space (asize + bsize digits: that's always @@ -1906,7 +1960,11 @@ k_mul(PyLongObject *a, PyLongObject *b) Py_DECREF(al); ah = al = NULL; - if ((t2 = x_add(bh, bl)) == NULL) { + if (a == b) { + t2 = t1; + Py_INCREF(t2); + } + else if ((t2 = x_add(bh, bl)) == NULL) { Py_DECREF(t1); goto fail; } |