summaryrefslogtreecommitdiffstats
path: root/Objects/longobject.c
diff options
context:
space:
mode:
authorMark Dickinson <mdickinson@enthought.com>2019-06-02 09:24:06 (GMT)
committerGitHub <noreply@github.com>2019-06-02 09:24:06 (GMT)
commitc52996785a45d4693857ea219e040777a14584f8 (patch)
treeee1d54ea597b45ef0d57407f8affaffe57f93be6 /Objects/longobject.c
parent5ae299ac78abb628803ab7dee0997364547f5cc8 (diff)
downloadcpython-c52996785a45d4693857ea219e040777a14584f8.zip
cpython-c52996785a45d4693857ea219e040777a14584f8.tar.gz
cpython-c52996785a45d4693857ea219e040777a14584f8.tar.bz2
bpo-36027: Extend three-argument pow to negative second argument (GH-13266)
Diffstat (limited to 'Objects/longobject.c')
-rw-r--r--Objects/longobject.c130
1 files changed, 118 insertions, 12 deletions
diff --git a/Objects/longobject.c b/Objects/longobject.c
index 5d2b595..49f1420 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -4174,6 +4174,98 @@ long_divmod(PyObject *a, PyObject *b)
return z;
}
+
+/* Compute an inverse to a modulo n, or raise ValueError if a is not
+ invertible modulo n. Assumes n is positive. The inverse returned
+ is whatever falls out of the extended Euclidean algorithm: it may
+ be either positive or negative, but will be smaller than n in
+ absolute value.
+
+ Pure Python equivalent for long_invmod:
+
+ def invmod(a, n):
+ b, c = 1, 0
+ while n:
+ q, r = divmod(a, n)
+ a, b, c, n = n, c, b - q*c, r
+
+ # at this point a is the gcd of the original inputs
+ if a == 1:
+ return b
+ raise ValueError("Not invertible")
+*/
+
+static PyLongObject *
+long_invmod(PyLongObject *a, PyLongObject *n)
+{
+ PyLongObject *b, *c;
+
+ /* Should only ever be called for positive n */
+ assert(Py_SIZE(n) > 0);
+
+ b = (PyLongObject *)PyLong_FromLong(1L);
+ if (b == NULL) {
+ return NULL;
+ }
+ c = (PyLongObject *)PyLong_FromLong(0L);
+ if (c == NULL) {
+ Py_DECREF(b);
+ return NULL;
+ }
+ Py_INCREF(a);
+ Py_INCREF(n);
+
+ /* references now owned: a, b, c, n */
+ while (Py_SIZE(n) != 0) {
+ PyLongObject *q, *r, *s, *t;
+
+ if (l_divmod(a, n, &q, &r) == -1) {
+ goto Error;
+ }
+ Py_DECREF(a);
+ a = n;
+ n = r;
+ t = (PyLongObject *)long_mul(q, c);
+ Py_DECREF(q);
+ if (t == NULL) {
+ goto Error;
+ }
+ s = (PyLongObject *)long_sub(b, t);
+ Py_DECREF(t);
+ if (s == NULL) {
+ goto Error;
+ }
+ Py_DECREF(b);
+ b = c;
+ c = s;
+ }
+ /* references now owned: a, b, c, n */
+
+ Py_DECREF(c);
+ Py_DECREF(n);
+ if (long_compare(a, _PyLong_One)) {
+ /* a != 1; we don't have an inverse. */
+ Py_DECREF(a);
+ Py_DECREF(b);
+ PyErr_SetString(PyExc_ValueError,
+ "base is not invertible for the given modulus");
+ return NULL;
+ }
+ else {
+ /* a == 1; b gives an inverse modulo n */
+ Py_DECREF(a);
+ return b;
+ }
+
+ Error:
+ Py_DECREF(a);
+ Py_DECREF(b);
+ Py_DECREF(c);
+ Py_DECREF(n);
+ return NULL;
+}
+
+
/* pow(v, w, x) */
static PyObject *
long_pow(PyObject *v, PyObject *w, PyObject *x)
@@ -4207,20 +4299,14 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
Py_RETURN_NOTIMPLEMENTED;
}
- if (Py_SIZE(b) < 0) { /* if exponent is negative */
- if (c) {
- PyErr_SetString(PyExc_ValueError, "pow() 2nd argument "
- "cannot be negative when 3rd argument specified");
- goto Error;
- }
- else {
- /* else return a float. This works because we know
+ if (Py_SIZE(b) < 0 && c == NULL) {
+ /* if exponent is negative and there's no modulus:
+ return a float. This works because we know
that this calls float_pow() which converts its
arguments to double. */
- Py_DECREF(a);
- Py_DECREF(b);
- return PyFloat_Type.tp_as_number->nb_power(v, w, x);
- }
+ Py_DECREF(a);
+ Py_DECREF(b);
+ return PyFloat_Type.tp_as_number->nb_power(v, w, x);
}
if (c) {
@@ -4255,6 +4341,26 @@ long_pow(PyObject *v, PyObject *w, PyObject *x)
goto Done;
}
+ /* if exponent is negative, negate the exponent and
+ replace the base with a modular inverse */
+ if (Py_SIZE(b) < 0) {
+ temp = (PyLongObject *)_PyLong_Copy(b);
+ if (temp == NULL)
+ goto Error;
+ Py_DECREF(b);
+ b = temp;
+ temp = NULL;
+ _PyLong_Negate(&b);
+ if (b == NULL)
+ goto Error;
+
+ temp = long_invmod(a, c);
+ if (temp == NULL)
+ goto Error;
+ Py_DECREF(a);
+ a = temp;
+ }
+
/* Reduce base by modulus in some cases:
1. If base < 0. Forcing the base non-negative makes things easier.
2. If base is obviously larger than the modulus. The "small