From 2fa33db12b8cb6ec1dd1b87df6911e311d98457b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 23 Aug 2007 22:07:24 +0000 Subject: Finish the work on __round__ and __trunc__. With Alex Martelli and Keir Mierle. --- Lib/test/test_builtin.py | 42 +++++++++++++++++++++++ Objects/floatobject.c | 67 ++++++++++++++++++++++++++++++------ Objects/longobject.c | 36 ++++++++++++++++++++ Python/bltinmodule.c | 88 +++++++++++++++++++++--------------------------- 4 files changed, 173 insertions(+), 60 deletions(-) diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 37ea8ba..f77cf78 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -1440,6 +1440,7 @@ class BuiltinTest(unittest.TestCase): def test_round(self): self.assertEqual(round(0.0), 0.0) + self.assertEqual(type(round(0.0)), int) self.assertEqual(round(1.0), 1.0) self.assertEqual(round(10.0), 10.0) self.assertEqual(round(1000000000.0), 1000000000.0) @@ -1468,6 +1469,25 @@ class BuiltinTest(unittest.TestCase): self.assertEqual(round(-999999999.9), -1000000000.0) self.assertEqual(round(-8.0, -1), -10.0) + self.assertEqual(type(round(-8.0, -1)), float) + + self.assertEqual(type(round(-8.0, 0)), float) + self.assertEqual(type(round(-8.0, 1)), float) + + # Check even / odd rounding behaviour + self.assertEqual(round(5.5), 6) + self.assertEqual(round(6.5), 6) + self.assertEqual(round(-5.5), -6) + self.assertEqual(round(-6.5), -6) + + # Check behavior on ints + self.assertEqual(round(0), 0) + self.assertEqual(round(8), 8) + self.assertEqual(round(-8), -8) + self.assertEqual(type(round(0)), int) + self.assertEqual(type(round(-8, -1)), float) + self.assertEqual(type(round(-8, 0)), float) + self.assertEqual(type(round(-8, 1)), float) # test new kwargs self.assertEqual(round(number=-8.0, ndigits=-1), -10.0) @@ -1487,6 +1507,11 @@ class BuiltinTest(unittest.TestCase): self.assertRaises(TypeError, round, 1, 2, 3) self.assertRaises(TypeError, round, TestNoRound()) + t = TestNoRound() + t.__round__ = lambda *args: args + self.assertRaises(TypeError, round, t) + self.assertRaises(TypeError, round, t, 0) + def test_setattr(self): setattr(sys, 'spam', 1) self.assertEqual(sys.spam, 1) @@ -1529,6 +1554,18 @@ class BuiltinTest(unittest.TestCase): self.assertRaises(ValueError, sum, BadSeq()) def test_trunc(self): + + self.assertEqual(trunc(1), 1) + self.assertEqual(trunc(-1), -1) + self.assertEqual(type(trunc(1)), int) + self.assertEqual(type(trunc(1.5)), int) + self.assertEqual(trunc(1.5), 1) + self.assertEqual(trunc(-1.5), -1) + self.assertEqual(trunc(1.999999), 1) + self.assertEqual(trunc(-1.999999), -1) + self.assertEqual(trunc(-0.999999), -0) + self.assertEqual(trunc(-100.999), -100) + class TestTrunc: def __trunc__(self): return 23 @@ -1542,6 +1579,11 @@ class BuiltinTest(unittest.TestCase): self.assertRaises(TypeError, trunc, 1, 2) self.assertRaises(TypeError, trunc, TestNoTrunc()) + t = TestNoTrunc() + t.__trunc__ = lambda *args: args + self.assertRaises(TypeError, trunc, t) + self.assertRaises(TypeError, trunc, t, 0) + def test_tuple(self): self.assertEqual(tuple(()), ()) t0_3 = (0, 1, 2, 3) diff --git a/Objects/floatobject.c b/Objects/floatobject.c index 908258c..09efa12 100644 --- a/Objects/floatobject.c +++ b/Objects/floatobject.c @@ -743,14 +743,7 @@ float_bool(PyFloatObject *v) } static PyObject * -float_long(PyObject *v) -{ - double x = PyFloat_AsDouble(v); - return PyLong_FromDouble(x); -} - -static PyObject * -float_int(PyObject *v) +float_trunc(PyObject *v) { double x = PyFloat_AsDouble(v); double wholepart; /* integral portion of x, rounded toward 0 */ @@ -776,6 +769,55 @@ float_int(PyObject *v) } static PyObject * +float_round(PyObject *v, PyObject *args) +{ +#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */ + double x; + double f; + double flr, cil; + double rounded; + int i; + int ndigits = UNDEF_NDIGITS; + + if (!PyArg_ParseTuple(args, "|i", &ndigits)) + return NULL; + + x = PyFloat_AsDouble(v); + + if (ndigits != UNDEF_NDIGITS) { + f = 1.0; + i = abs(ndigits); + while (--i >= 0) + f = f*10.0; + if (ndigits < 0) + x /= f; + else + x *= f; + } + + flr = floor(x); + cil = ceil(x); + + if (x-flr > 0.5) + rounded = cil; + else if (x-flr == 0.5) + rounded = fmod(flr, 2) == 0 ? flr : cil; + else + rounded = flr; + + if (ndigits != UNDEF_NDIGITS) { + if (ndigits < 0) + rounded *= f; + else + rounded /= f; + return PyFloat_FromDouble(rounded); + } + + return PyLong_FromDouble(rounded); +#undef UNDEF_NDIGITS +} + +static PyObject * float_float(PyObject *v) { if (PyFloat_CheckExact(v)) @@ -976,6 +1018,11 @@ float_getzero(PyObject *v, void *closure) static PyMethodDef float_methods[] = { {"conjugate", (PyCFunction)float_float, METH_NOARGS, "Returns self, the complex conjugate of any float."}, + {"__trunc__", (PyCFunction)float_trunc, METH_NOARGS, + "Returns the Integral closest to x between 0 and x."}, + {"__round__", (PyCFunction)float_round, METH_VARARGS, + "Returns the Integral closest to x, rounding half toward even.\n" + "When an argument is passed, works like built-in round(x, ndigits)."}, {"__getnewargs__", (PyCFunction)float_getnewargs, METH_NOARGS}, {"__getformat__", (PyCFunction)float_getformat, METH_O|METH_CLASS, float_getformat_doc}, @@ -1020,8 +1067,8 @@ static PyNumberMethods float_as_number = { 0, /*nb_xor*/ 0, /*nb_or*/ (coercion)0, /*nb_coerce*/ - float_int, /*nb_int*/ - float_long, /*nb_long*/ + float_trunc, /*nb_int*/ + float_trunc, /*nb_long*/ float_float, /*nb_float*/ 0, /* nb_oct */ 0, /* nb_hex */ diff --git a/Objects/longobject.c b/Objects/longobject.c index 518e607..ddf359d 100644 --- a/Objects/longobject.c +++ b/Objects/longobject.c @@ -3592,9 +3592,45 @@ long_getN(PyLongObject *v, void *context) { return PyLong_FromLong((intptr_t)context); } +static PyObject * +long_round(PyObject *self, PyObject *args) +{ +#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */ + int ndigits = UNDEF_NDIGITS; + double x; + PyObject *res; + + if (!PyArg_ParseTuple(args, "|i", &ndigits)) + return NULL; + + if (ndigits == UNDEF_NDIGITS) + return long_long(self); + + /* If called with two args, defer to float.__round__(). */ + x = PyLong_AsDouble(self); + if (x == -1.0 && PyErr_Occurred()) + return NULL; + self = PyFloat_FromDouble(x); + if (self == NULL) + return NULL; + res = PyObject_CallMethod(self, "__round__", "i", ndigits); + Py_DECREF(self); + return res; +#undef UNDEF_NDIGITS +} + static PyMethodDef long_methods[] = { {"conjugate", (PyCFunction)long_long, METH_NOARGS, "Returns self, the complex conjugate of any int."}, + {"__trunc__", (PyCFunction)long_long, METH_NOARGS, + "Truncating an Integral returns itself."}, + {"__floor__", (PyCFunction)long_long, METH_NOARGS, + "Flooring an Integral returns itself."}, + {"__ceil__", (PyCFunction)long_long, METH_NOARGS, + "Ceiling of an Integral returns itself."}, + {"__round__", (PyCFunction)long_round, METH_VARARGS, + "Rounding an Integral returns itself.\n" + "Rounding with an ndigits arguments defers to float.__round__."}, {"__getnewargs__", (PyCFunction)long_getnewargs, METH_NOARGS}, {NULL, NULL} /* sentinel */ }; diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c index b55dd51..9bbf64b 100644 --- a/Python/bltinmodule.c +++ b/Python/bltinmodule.c @@ -1373,63 +1373,44 @@ For most object types, eval(repr(object)) == object."); static PyObject * builtin_round(PyObject *self, PyObject *args, PyObject *kwds) { - double number; - double f; - int ndigits = 0; - int i; +#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */ + static PyObject *round_str = NULL; + int ndigits = UNDEF_NDIGITS; static char *kwlist[] = {"number", "ndigits", 0}; - PyObject* real; + PyObject *number, *round; if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i:round", - kwlist, &real, &ndigits)) + kwlist, &number, &ndigits)) return NULL; - if (ndigits == 0) { - PyObject *res; - PyObject *d = PyObject_GetAttrString(real, "__round__"); - if (d == NULL && !PyFloat_Check(real)) { - PyErr_SetString(PyExc_TypeError, - "round() argument must have __round__ attribute or be a float"); + if (round_str == NULL) { + round_str = PyUnicode_FromString("__round__"); + if (round_str == NULL) return NULL; - } - if (d == NULL) { - PyErr_Clear(); - } else { - res = PyObject_CallFunction(d, ""); - Py_DECREF(d); - return res; - } - } else if (!PyFloat_Check(real)) { - PyErr_SetString(PyExc_TypeError, - "round() argument must have __round__ attribute or be a float"); + } + + round = _PyType_Lookup(Py_Type(number), round_str); + if (round == NULL) { + PyErr_Format(PyExc_TypeError, + "type %.100s doesn't define __round__ method", + Py_Type(number)->tp_name); return NULL; } - number = PyFloat_AsDouble(real); - f = 1.0; - i = abs(ndigits); - while (--i >= 0) - f = f*10.0; - if (ndigits < 0) - number /= f; - else - number *= f; - if (number >= 0.0) - number = floor(number + 0.5); + if (ndigits == UNDEF_NDIGITS) + return PyObject_CallFunction(round, "O", number); else - number = ceil(number - 0.5); - if (ndigits < 0) - number *= f; - else - number /= f; - return PyFloat_FromDouble(number); + return PyObject_CallFunction(round, "Oi", number, ndigits); +#undef UNDEF_NDIGITS } PyDoc_STRVAR(round_doc, "round(number[, ndigits]) -> floating point number\n\ \n\ Round a number to a given precision in decimal digits (default 0 digits).\n\ -This always returns a floating point number. Precision may be negative."); +This returns an int when called with one argument, otherwise a float.\n\ +Precision may be negative."); + static PyObject * builtin_sorted(PyObject *self, PyObject *args, PyObject *kwds) @@ -1511,18 +1492,25 @@ Without arguments, equivalent to locals().\n\ With an argument, equivalent to object.__dict__."); static PyObject * -builtin_trunc(PyObject *self, PyObject *v) +builtin_trunc(PyObject *self, PyObject *number) { - PyObject *res; - PyObject *d = PyObject_GetAttrString(v, "__trunc__"); - if (d == NULL) { - PyErr_SetString(PyExc_TypeError, - "trunc() argument must have __trunc__ attribute"); + static PyObject *trunc_str = NULL; + PyObject *trunc; + + if (trunc_str == NULL) { + trunc_str = PyUnicode_FromString("__trunc__"); + if (trunc_str == NULL) + return NULL; + } + + trunc = _PyType_Lookup(Py_Type(number), trunc_str); + if (trunc == NULL) { + PyErr_Format(PyExc_TypeError, + "type %.100s doesn't define __trunc__ method", + Py_Type(number)->tp_name); return NULL; } - res = PyObject_CallFunction(d, ""); - Py_DECREF(d); - return res; + return PyObject_CallFunction(trunc, "O", number); } PyDoc_STRVAR(trunc_doc, -- cgit v0.12