diff options
author | Tim Peters <tim.peters@gmail.com> | 2002-08-12 15:08:20 (GMT) |
---|---|---|
committer | Tim Peters <tim.peters@gmail.com> | 2002-08-12 15:08:20 (GMT) |
commit | 738eda742cd17739d65594b557fba442f65ff18a (patch) | |
tree | 330012d5cd4f8eb265b114f185ba96d705d24271 /Objects | |
parent | 5d546674d18b3e96a473454167f1e07d313e89ca (diff) | |
download | cpython-738eda742cd17739d65594b557fba442f65ff18a.zip cpython-738eda742cd17739d65594b557fba442f65ff18a.tar.gz cpython-738eda742cd17739d65594b557fba442f65ff18a.tar.bz2 |
k_mul: Rearranged computation for better cache use. Ignored overflow
(it's possible, but should be harmless -- this requires more thought,
and allocating enough space in advance to prevent it requires exactly
as much thought, to know exactly how much that is -- the end result
certainly fits in the allocated space -- hmm, but that's really all
the thought it needs! borrows/carries out of the high digits really
are harmless).
Diffstat (limited to 'Objects')
-rw-r--r-- | Objects/longobject.c | 110 |
1 files changed, 50 insertions, 60 deletions
diff --git a/Objects/longobject.c b/Objects/longobject.c index 6dedd38..bf82d73 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -1598,20 +1598,17 @@ kmul_split(PyLongObject *n, int size, PyLongObject **high, PyLongObject **low) static PyLongObject * k_mul(PyLongObject *a, PyLongObject *b) { + int asize = ABS(a->ob_size); + int bsize = ABS(b->ob_size); PyLongObject *ah = NULL; PyLongObject *al = NULL; PyLongObject *bh = NULL; PyLongObject *bl = NULL; - PyLongObject *albl = NULL; - PyLongObject *ahbh = NULL; - PyLongObject *k = NULL; PyLongObject *ret = NULL; - PyLongObject *t1, *t2; + PyLongObject *t1, *t2, *t3; int shift; /* the number of digits we split off */ int i; -#ifdef Py_DEBUG - digit d; -#endif + /* (ah*X+al)(bh*X+bl) = ah*bh*X*X + (ah*bl + al*bh)*X + al*bl * Let k = (ah+al)*(bh+bl) = ah*bl + al*bh + ah*bh + al*bl * Then the original product is @@ -1623,59 +1620,75 @@ k_mul(PyLongObject *a, PyLongObject *b) /* We want to split based on the larger number; fiddle so that b * is largest. */ - if (ABS(a->ob_size) > ABS(b->ob_size)) { + if (asize > bsize) { t1 = a; a = b; b = t1; + + i = asize; + asize = bsize; + bsize = i; } /* Use gradeschool math when either number is too small. */ - if (ABS(a->ob_size) <= KARATSUBA_CUTOFF) { + if (asize <= KARATSUBA_CUTOFF) { /* 0 is inevitable if one kmul arg has more than twice * the digits of another, so it's worth special-casing. */ - if (a->ob_size == 0) + if (asize == 0) return _PyLong_New(0); else return x_mul(a, b); } - shift = ABS(b->ob_size) >> 1; + shift = bsize >> 1; if (kmul_split(a, shift, &ah, &al) < 0) goto fail; if (kmul_split(b, shift, &bh, &bl) < 0) goto fail; - if ((ahbh = k_mul(ah, bh)) == NULL) goto fail; - assert(ahbh->ob_size >= 0); - - /* Allocate result space, and copy ahbh into the high digits. */ - ret = _PyLong_New(ABS(a->ob_size) + ABS(b->ob_size)); + /* Allocate result space. */ + ret = _PyLong_New(asize + bsize); if (ret == NULL) goto fail; #ifdef Py_DEBUG /* Fill with trash, to catch reference to uninitialized digits. */ memset(ret->ob_digit, 0xDF, ret->ob_size * sizeof(digit)); #endif - assert(2*shift + ahbh->ob_size <= ret->ob_size); - memcpy(ret->ob_digit + 2*shift, ahbh->ob_digit, - ahbh->ob_size * sizeof(digit)); - /* Zero-out the digits higher than the ahbh copy. */ - i = ret->ob_size - 2*shift - ahbh->ob_size; + /* t1 <- ah*bh, and copy into high digits of result. */ + if ((t1 = k_mul(ah, bh)) == NULL) goto fail; + assert(t1->ob_size >= 0); + assert(2*shift + t1->ob_size <= ret->ob_size); + memcpy(ret->ob_digit + 2*shift, t1->ob_digit, + t1->ob_size * sizeof(digit)); + + /* Zero-out the digits higher than the ah*bh copy. */ + i = ret->ob_size - 2*shift - t1->ob_size; if (i) - memset(ret->ob_digit + 2*shift + ahbh->ob_size, 0, + memset(ret->ob_digit + 2*shift + t1->ob_size, 0, i * sizeof(digit)); - /* Compute al*bl, and copy into the low digits. */ - if ((albl = k_mul(al, bl)) == NULL) goto fail; - assert(albl->ob_size >= 0); - assert(albl->ob_size <= 2*shift); /* no overlap with high digits */ - memcpy(ret->ob_digit, albl->ob_digit, albl->ob_size * sizeof(digit)); + /* t2 <- al*bl, and copy into the low digits. */ + if ((t2 = k_mul(al, bl)) == NULL) { + Py_DECREF(t1); + goto fail; + } + assert(t2->ob_size >= 0); + assert(t2->ob_size <= 2*shift); /* no overlap with high digits */ + memcpy(ret->ob_digit, t2->ob_digit, t2->ob_size * sizeof(digit)); /* Zero out remaining digits. */ - i = 2*shift - albl->ob_size; /* number of uninitialized digits */ + i = 2*shift - t2->ob_size; /* number of uninitialized digits */ if (i) - memset(ret->ob_digit + albl->ob_size, 0, i * sizeof(digit)); + memset(ret->ob_digit + t2->ob_size, 0, i * sizeof(digit)); + + /* Subtract ah*bh (t1) and al*bl (t2) from "the middle" digits. */ + i = ret->ob_size - shift; /* # digits after shift */ + v_isub(ret->ob_digit + shift, i, t2->ob_digit, t2->ob_size); + Py_DECREF(t2); - /* k = (ah+al)(bh+bl) */ + v_isub(ret->ob_digit + shift, i, t1->ob_digit, t1->ob_size); + Py_DECREF(t1); + + /* t3 <- (ah+al)(bh+bl) */ if ((t1 = x_add(ah, al)) == NULL) goto fail; Py_DECREF(ah); Py_DECREF(al); @@ -1689,36 +1702,16 @@ k_mul(PyLongObject *a, PyLongObject *b) Py_DECREF(bl); bh = bl = NULL; - k = k_mul(t1, t2); + t3 = k_mul(t1, t2); + assert(t3->ob_size >= 0); Py_DECREF(t1); Py_DECREF(t2); - if (k == NULL) goto fail; - - /* Add k into the result, starting at the shift'th LSD. */ - i = ret->ob_size - shift; /* # digits after shift */ -#ifdef Py_DEBUG - d = -#endif - v_iadd(ret->ob_digit + shift, i, k->ob_digit, k->ob_size); - assert(d == 0); - Py_DECREF(k); + if (t3 == NULL) goto fail; - /* Subtract ahbh and albl from the result. Note that this can't - * become negative, since k = ahbh + albl + other stuff. - */ -#ifdef Py_DEBUG - d = -#endif - v_isub(ret->ob_digit + shift, i, ahbh->ob_digit, ahbh->ob_size); - assert(d == 0); - Py_DECREF(ahbh); - -#ifdef Py_DEBUG - d = -#endif - v_isub(ret->ob_digit + shift, i, albl->ob_digit, albl->ob_size); - assert(d == 0); - Py_DECREF(albl); + /* Add t3. */ + v_iadd(ret->ob_digit + shift, ret->ob_size - shift, + t3->ob_digit, t3->ob_size); + Py_DECREF(t3); return long_normalize(ret); @@ -1728,9 +1721,6 @@ k_mul(PyLongObject *a, PyLongObject *b) Py_XDECREF(al); Py_XDECREF(bh); Py_XDECREF(bl); - Py_XDECREF(ahbh); - Py_XDECREF(albl); - Py_XDECREF(k); return NULL; } |