summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_builtin.py42
-rw-r--r--Objects/floatobject.c67
-rw-r--r--Objects/longobject.c36
-rw-r--r--Python/bltinmodule.c88
4 files changed, 173 insertions, 60 deletions
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
index 37ea8ba..f77cf78 100644
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -1440,6 +1440,7 @@ class BuiltinTest(unittest.TestCase):
def test_round(self):
self.assertEqual(round(0.0), 0.0)
+ self.assertEqual(type(round(0.0)), int)
self.assertEqual(round(1.0), 1.0)
self.assertEqual(round(10.0), 10.0)
self.assertEqual(round(1000000000.0), 1000000000.0)
@@ -1468,6 +1469,25 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(round(-999999999.9), -1000000000.0)
self.assertEqual(round(-8.0, -1), -10.0)
+ self.assertEqual(type(round(-8.0, -1)), float)
+
+ self.assertEqual(type(round(-8.0, 0)), float)
+ self.assertEqual(type(round(-8.0, 1)), float)
+
+ # Check even / odd rounding behaviour
+ self.assertEqual(round(5.5), 6)
+ self.assertEqual(round(6.5), 6)
+ self.assertEqual(round(-5.5), -6)
+ self.assertEqual(round(-6.5), -6)
+
+ # Check behavior on ints
+ self.assertEqual(round(0), 0)
+ self.assertEqual(round(8), 8)
+ self.assertEqual(round(-8), -8)
+ self.assertEqual(type(round(0)), int)
+ self.assertEqual(type(round(-8, -1)), float)
+ self.assertEqual(type(round(-8, 0)), float)
+ self.assertEqual(type(round(-8, 1)), float)
# test new kwargs
self.assertEqual(round(number=-8.0, ndigits=-1), -10.0)
@@ -1487,6 +1507,11 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(TypeError, round, 1, 2, 3)
self.assertRaises(TypeError, round, TestNoRound())
+ t = TestNoRound()
+ t.__round__ = lambda *args: args
+ self.assertRaises(TypeError, round, t)
+ self.assertRaises(TypeError, round, t, 0)
+
def test_setattr(self):
setattr(sys, 'spam', 1)
self.assertEqual(sys.spam, 1)
@@ -1529,6 +1554,18 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(ValueError, sum, BadSeq())
def test_trunc(self):
+
+ self.assertEqual(trunc(1), 1)
+ self.assertEqual(trunc(-1), -1)
+ self.assertEqual(type(trunc(1)), int)
+ self.assertEqual(type(trunc(1.5)), int)
+ self.assertEqual(trunc(1.5), 1)
+ self.assertEqual(trunc(-1.5), -1)
+ self.assertEqual(trunc(1.999999), 1)
+ self.assertEqual(trunc(-1.999999), -1)
+ self.assertEqual(trunc(-0.999999), -0)
+ self.assertEqual(trunc(-100.999), -100)
+
class TestTrunc:
def __trunc__(self):
return 23
@@ -1542,6 +1579,11 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(TypeError, trunc, 1, 2)
self.assertRaises(TypeError, trunc, TestNoTrunc())
+ t = TestNoTrunc()
+ t.__trunc__ = lambda *args: args
+ self.assertRaises(TypeError, trunc, t)
+ self.assertRaises(TypeError, trunc, t, 0)
+
def test_tuple(self):
self.assertEqual(tuple(()), ())
t0_3 = (0, 1, 2, 3)
diff --git a/Objects/floatobject.c b/Objects/floatobject.c
index 908258c..09efa12 100644
--- a/Objects/floatobject.c
+++ b/Objects/floatobject.c
@@ -743,14 +743,7 @@ float_bool(PyFloatObject *v)
}
static PyObject *
-float_long(PyObject *v)
-{
- double x = PyFloat_AsDouble(v);
- return PyLong_FromDouble(x);
-}
-
-static PyObject *
-float_int(PyObject *v)
+float_trunc(PyObject *v)
{
double x = PyFloat_AsDouble(v);
double wholepart; /* integral portion of x, rounded toward 0 */
@@ -776,6 +769,55 @@ float_int(PyObject *v)
}
static PyObject *
+float_round(PyObject *v, PyObject *args)
+{
+#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
+ double x;
+ double f;
+ double flr, cil;
+ double rounded;
+ int i;
+ int ndigits = UNDEF_NDIGITS;
+
+ if (!PyArg_ParseTuple(args, "|i", &ndigits))
+ return NULL;
+
+ x = PyFloat_AsDouble(v);
+
+ if (ndigits != UNDEF_NDIGITS) {
+ f = 1.0;
+ i = abs(ndigits);
+ while (--i >= 0)
+ f = f*10.0;
+ if (ndigits < 0)
+ x /= f;
+ else
+ x *= f;
+ }
+
+ flr = floor(x);
+ cil = ceil(x);
+
+ if (x-flr > 0.5)
+ rounded = cil;
+ else if (x-flr == 0.5)
+ rounded = fmod(flr, 2) == 0 ? flr : cil;
+ else
+ rounded = flr;
+
+ if (ndigits != UNDEF_NDIGITS) {
+ if (ndigits < 0)
+ rounded *= f;
+ else
+ rounded /= f;
+ return PyFloat_FromDouble(rounded);
+ }
+
+ return PyLong_FromDouble(rounded);
+#undef UNDEF_NDIGITS
+}
+
+static PyObject *
float_float(PyObject *v)
{
if (PyFloat_CheckExact(v))
@@ -976,6 +1018,11 @@ float_getzero(PyObject *v, void *closure)
static PyMethodDef float_methods[] = {
{"conjugate", (PyCFunction)float_float, METH_NOARGS,
"Returns self, the complex conjugate of any float."},
+ {"__trunc__", (PyCFunction)float_trunc, METH_NOARGS,
+ "Returns the Integral closest to x between 0 and x."},
+ {"__round__", (PyCFunction)float_round, METH_VARARGS,
+ "Returns the Integral closest to x, rounding half toward even.\n"
+ "When an argument is passed, works like built-in round(x, ndigits)."},
{"__getnewargs__", (PyCFunction)float_getnewargs, METH_NOARGS},
{"__getformat__", (PyCFunction)float_getformat,
METH_O|METH_CLASS, float_getformat_doc},
@@ -1020,8 +1067,8 @@ static PyNumberMethods float_as_number = {
0, /*nb_xor*/
0, /*nb_or*/
(coercion)0, /*nb_coerce*/
- float_int, /*nb_int*/
- float_long, /*nb_long*/
+ float_trunc, /*nb_int*/
+ float_trunc, /*nb_long*/
float_float, /*nb_float*/
0, /* nb_oct */
0, /* nb_hex */
diff --git a/Objects/longobject.c b/Objects/longobject.c
index 518e607..ddf359d 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -3592,9 +3592,45 @@ long_getN(PyLongObject *v, void *context) {
return PyLong_FromLong((intptr_t)context);
}
+static PyObject *
+long_round(PyObject *self, PyObject *args)
+{
+#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
+ int ndigits = UNDEF_NDIGITS;
+ double x;
+ PyObject *res;
+
+ if (!PyArg_ParseTuple(args, "|i", &ndigits))
+ return NULL;
+
+ if (ndigits == UNDEF_NDIGITS)
+ return long_long(self);
+
+ /* If called with two args, defer to float.__round__(). */
+ x = PyLong_AsDouble(self);
+ if (x == -1.0 && PyErr_Occurred())
+ return NULL;
+ self = PyFloat_FromDouble(x);
+ if (self == NULL)
+ return NULL;
+ res = PyObject_CallMethod(self, "__round__", "i", ndigits);
+ Py_DECREF(self);
+ return res;
+#undef UNDEF_NDIGITS
+}
+
static PyMethodDef long_methods[] = {
{"conjugate", (PyCFunction)long_long, METH_NOARGS,
"Returns self, the complex conjugate of any int."},
+ {"__trunc__", (PyCFunction)long_long, METH_NOARGS,
+ "Truncating an Integral returns itself."},
+ {"__floor__", (PyCFunction)long_long, METH_NOARGS,
+ "Flooring an Integral returns itself."},
+ {"__ceil__", (PyCFunction)long_long, METH_NOARGS,
+ "Ceiling of an Integral returns itself."},
+ {"__round__", (PyCFunction)long_round, METH_VARARGS,
+ "Rounding an Integral returns itself.\n"
+ "Rounding with an ndigits arguments defers to float.__round__."},
{"__getnewargs__", (PyCFunction)long_getnewargs, METH_NOARGS},
{NULL, NULL} /* sentinel */
};
diff --git a/Python/bltinmodule.c b/Python/bltinmodule.c
index b55dd51..9bbf64b 100644
--- a/Python/bltinmodule.c
+++ b/Python/bltinmodule.c
@@ -1373,63 +1373,44 @@ For most object types, eval(repr(object)) == object.");
static PyObject *
builtin_round(PyObject *self, PyObject *args, PyObject *kwds)
{
- double number;
- double f;
- int ndigits = 0;
- int i;
+#define UNDEF_NDIGITS (-0x7fffffff) /* Unlikely ndigits value */
+ static PyObject *round_str = NULL;
+ int ndigits = UNDEF_NDIGITS;
static char *kwlist[] = {"number", "ndigits", 0};
- PyObject* real;
+ PyObject *number, *round;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|i:round",
- kwlist, &real, &ndigits))
+ kwlist, &number, &ndigits))
return NULL;
- if (ndigits == 0) {
- PyObject *res;
- PyObject *d = PyObject_GetAttrString(real, "__round__");
- if (d == NULL && !PyFloat_Check(real)) {
- PyErr_SetString(PyExc_TypeError,
- "round() argument must have __round__ attribute or be a float");
+ if (round_str == NULL) {
+ round_str = PyUnicode_FromString("__round__");
+ if (round_str == NULL)
return NULL;
- }
- if (d == NULL) {
- PyErr_Clear();
- } else {
- res = PyObject_CallFunction(d, "");
- Py_DECREF(d);
- return res;
- }
- } else if (!PyFloat_Check(real)) {
- PyErr_SetString(PyExc_TypeError,
- "round() argument must have __round__ attribute or be a float");
+ }
+
+ round = _PyType_Lookup(Py_Type(number), round_str);
+ if (round == NULL) {
+ PyErr_Format(PyExc_TypeError,
+ "type %.100s doesn't define __round__ method",
+ Py_Type(number)->tp_name);
return NULL;
}
- number = PyFloat_AsDouble(real);
- f = 1.0;
- i = abs(ndigits);
- while (--i >= 0)
- f = f*10.0;
- if (ndigits < 0)
- number /= f;
- else
- number *= f;
- if (number >= 0.0)
- number = floor(number + 0.5);
+ if (ndigits == UNDEF_NDIGITS)
+ return PyObject_CallFunction(round, "O", number);
else
- number = ceil(number - 0.5);
- if (ndigits < 0)
- number *= f;
- else
- number /= f;
- return PyFloat_FromDouble(number);
+ return PyObject_CallFunction(round, "Oi", number, ndigits);
+#undef UNDEF_NDIGITS
}
PyDoc_STRVAR(round_doc,
"round(number[, ndigits]) -> floating point number\n\
\n\
Round a number to a given precision in decimal digits (default 0 digits).\n\
-This always returns a floating point number. Precision may be negative.");
+This returns an int when called with one argument, otherwise a float.\n\
+Precision may be negative.");
+
static PyObject *
builtin_sorted(PyObject *self, PyObject *args, PyObject *kwds)
@@ -1511,18 +1492,25 @@ Without arguments, equivalent to locals().\n\
With an argument, equivalent to object.__dict__.");
static PyObject *
-builtin_trunc(PyObject *self, PyObject *v)
+builtin_trunc(PyObject *self, PyObject *number)
{
- PyObject *res;
- PyObject *d = PyObject_GetAttrString(v, "__trunc__");
- if (d == NULL) {
- PyErr_SetString(PyExc_TypeError,
- "trunc() argument must have __trunc__ attribute");
+ static PyObject *trunc_str = NULL;
+ PyObject *trunc;
+
+ if (trunc_str == NULL) {
+ trunc_str = PyUnicode_FromString("__trunc__");
+ if (trunc_str == NULL)
+ return NULL;
+ }
+
+ trunc = _PyType_Lookup(Py_Type(number), trunc_str);
+ if (trunc == NULL) {
+ PyErr_Format(PyExc_TypeError,
+ "type %.100s doesn't define __trunc__ method",
+ Py_Type(number)->tp_name);
return NULL;
}
- res = PyObject_CallFunction(d, "");
- Py_DECREF(d);
- return res;
+ return PyObject_CallFunction(trunc, "O", number);
}
PyDoc_STRVAR(trunc_doc,