summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_iter.py53
-rw-r--r--Misc/NEWS2
-rw-r--r--Python/ceval.c70
3 files changed, 93 insertions, 32 deletions
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
index ddc58a7..63e488e 100644
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -594,4 +594,57 @@ class TestCase(unittest.TestCase):
except OSError:
pass
+ # Test iterators on RHS of unpacking assignments.
+ def test_unpack_iter(self):
+ a, b = 1, 2
+ self.assertEqual((a, b), (1, 2))
+
+ a, b, c = IteratingSequenceClass(3)
+ self.assertEqual((a, b, c), (0, 1, 2))
+
+ try: # too many values
+ a, b = IteratingSequenceClass(3)
+ except ValueError:
+ pass
+ else:
+ self.fail("should have raised ValueError")
+
+ try: # not enough values
+ a, b, c = IteratingSequenceClass(2)
+ except ValueError:
+ pass
+ else:
+ self.fail("should have raised ValueError")
+
+ try: # not iterable
+ a, b, c = len
+ except TypeError:
+ pass
+ else:
+ self.fail("should have raised TypeError")
+
+ a, b, c = {1: 42, 2: 42, 3: 42}.itervalues()
+ self.assertEqual((a, b, c), (42, 42, 42))
+
+ f = open(TESTFN, "w")
+ lines = ("a\n", "bb\n", "ccc\n")
+ try:
+ for line in lines:
+ f.write(line)
+ finally:
+ f.close()
+ f = open(TESTFN, "r")
+ try:
+ a, b, c = f
+ self.assertEqual((a, b, c), lines)
+ finally:
+ f.close()
+ try:
+ unlink(TESTFN)
+ except OSError:
+ pass
+
+ (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
+ self.assertEqual((a, b, c), (0, 1, 42))
+
run_unittest(TestCase)
diff --git a/Misc/NEWS b/Misc/NEWS
index bfbcc5f..7ca09dd 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -108,6 +108,8 @@ Core
extend() method of lists
'x in y' and 'x not in y' (PySequence_Contains() in C API)
operator.countOf() (PySequence_Count() in C API)
+ right-hand side of assignment statements with multiple targets, such as
+ x, y, z = some_iterable_object_returning_exactly_3_values
- Accessing module attributes is significantly faster (for example,
random.random or os.path or yourPythonModule.yourAttribute).
diff --git a/Python/ceval.c b/Python/ceval.c
index a71d48f..fcb8fc3 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -476,7 +476,7 @@ enum why_code {
};
static enum why_code do_raise(PyObject *, PyObject *, PyObject *);
-static int unpack_sequence(PyObject *, int, PyObject **);
+static int unpack_iterable(PyObject *, int, PyObject **);
PyObject *
@@ -1488,18 +1488,11 @@ eval_frame(PyFrameObject *f)
}
}
}
- else if (PySequence_Check(v)) {
- if (unpack_sequence(v, oparg,
- stack_pointer + oparg))
- stack_pointer += oparg;
- else
- why = WHY_EXCEPTION;
- }
- else {
- PyErr_SetString(PyExc_TypeError,
- "unpack non-sequence");
+ else if (unpack_iterable(v, oparg,
+ stack_pointer + oparg))
+ stack_pointer += oparg;
+ else
why = WHY_EXCEPTION;
- }
Py_DECREF(v);
break;
@@ -2694,37 +2687,50 @@ do_raise(PyObject *type, PyObject *value, PyObject *tb)
return WHY_EXCEPTION;
}
+/* Iterate v argcnt times and store the results on the stack (via decreasing
+ sp). Return 1 for success, 0 if error. */
+
static int
-unpack_sequence(PyObject *v, int argcnt, PyObject **sp)
+unpack_iterable(PyObject *v, int argcnt, PyObject **sp)
{
- int i;
+ int i = 0;
+ PyObject *it; /* iter(v) */
PyObject *w;
- for (i = 0; i < argcnt; i++) {
- if (! (w = PySequence_GetItem(v, i))) {
- if (PyErr_ExceptionMatches(PyExc_IndexError))
- PyErr_SetString(PyExc_ValueError,
- "unpack sequence of wrong size");
- goto finally;
+ assert(v != NULL);
+
+ it = PyObject_GetIter(v);
+ if (it == NULL)
+ goto Error;
+
+ for (; i < argcnt; i++) {
+ w = PyIter_Next(it);
+ if (w == NULL) {
+ /* Iterator done, via error or exhaustion. */
+ if (!PyErr_Occurred()) {
+ PyErr_Format(PyExc_ValueError,
+ "need more than %d value%s to unpack",
+ i, i == 1 ? "" : "s");
+ }
+ goto Error;
}
*--sp = w;
}
- /* we better get an IndexError now */
- if (PySequence_GetItem(v, i) == NULL) {
- if (PyErr_ExceptionMatches(PyExc_IndexError)) {
- PyErr_Clear();
- return 1;
- }
- /* some other exception occurred. fall through to finally */
+
+ /* We better have exhausted the iterator now. */
+ w = PyIter_Next(it);
+ if (w == NULL) {
+ if (PyErr_Occurred())
+ goto Error;
+ Py_DECREF(it);
+ return 1;
}
- else
- PyErr_SetString(PyExc_ValueError,
- "unpack sequence of wrong size");
+ PyErr_SetString(PyExc_ValueError, "too many values to unpack");
/* fall through */
-finally:
+Error:
for (; i > 0; i--, sp++)
Py_DECREF(*sp);
-
+ Py_XDECREF(it);
return 0;
}