summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTim Peters <tim.peters@gmail.com>2001-05-05 05:36:48 (GMT)
committerTim Peters <tim.peters@gmail.com>2001-05-05 05:36:48 (GMT)
commit2cfe36828342e16cd274b968736a01aed5c49557 (patch)
tree8ee8bf38509e6abf06a98d561973d3a3eccee01d
parent432b42aa4c31fd473690ffeee446dcd493f8a8aa (diff)
downloadcpython-2cfe36828342e16cd274b968736a01aed5c49557.zip
cpython-2cfe36828342e16cd274b968736a01aed5c49557.tar.gz
cpython-2cfe36828342e16cd274b968736a01aed5c49557.tar.bz2
Make unicode.join() work nice with iterators. This also required a change
to string.join(), so that when the latter figures out in midstream that it really needs unicode.join() instead, unicode.join() can actually get all the sequence elements (i.e., there's no guarantee that the sequence passed to string.join() can be iterated over *again* by unicode.join(), so string.join() must not pass on the original sequence object anymore).
-rw-r--r--Lib/test/test_iter.py41
-rw-r--r--Misc/NEWS2
-rw-r--r--Objects/stringobject.c9
-rw-r--r--Objects/unicodeobject.c26
4 files changed, 65 insertions, 13 deletions
diff --git a/Lib/test/test_iter.py b/Lib/test/test_iter.py
index bfe032f..073ffb4 100644
--- a/Lib/test/test_iter.py
+++ b/Lib/test/test_iter.py
@@ -431,4 +431,45 @@ class TestCase(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(reduce(add, d), "".join(d.keys()))
+ def test_unicode_join_endcase(self):
+
+ # This class inserts a Unicode object into its argument's natural
+ # iteration, in the 3rd position.
+ class OhPhooey:
+ def __init__(self, seq):
+ self.it = iter(seq)
+ self.i = 0
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ i = self.i
+ self.i = i+1
+ if i == 2:
+ return u"fooled you!"
+ return self.it.next()
+
+ f = open(TESTFN, "w")
+ try:
+ f.write("a\n" + "b\n" + "c\n")
+ finally:
+ f.close()
+
+ f = open(TESTFN, "r")
+ # Nasty: string.join(s) can't know whether unicode.join() is needed
+ # until it's seen all of s's elements. But in this case, f's
+ # iterator cannot be restarted. So what we're testing here is
+ # whether string.join() can manage to remember everything it's seen
+ # and pass that on to unicode.join().
+ try:
+ got = " - ".join(OhPhooey(f))
+ self.assertEqual(got, u"a\n - b\n - fooled you! - c\n")
+ finally:
+ f.close()
+ try:
+ unlink(TESTFN)
+ except OSError:
+ pass
+
run_unittest(TestCase)
diff --git a/Misc/NEWS b/Misc/NEWS
index 0d7857f..d556afa 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -25,7 +25,7 @@ Core
reduce()
string.join()
tuple()
- XXX TODO unicode.join()
+ unicode.join()
XXX TODO zip()
XXX TODO 'x in y'
diff --git a/Objects/stringobject.c b/Objects/stringobject.c
index b905679..87d7c195 100644
--- a/Objects/stringobject.c
+++ b/Objects/stringobject.c
@@ -861,8 +861,15 @@ string_join(PyStringObject *self, PyObject *args)
item = PySequence_Fast_GET_ITEM(seq, i);
if (!PyString_Check(item)){
if (PyUnicode_Check(item)) {
+ /* Defer to Unicode join.
+ * CAUTION: There's no gurantee that the
+ * original sequence can be iterated over
+ * again, so we must pass seq here.
+ */
+ PyObject *result;
+ result = PyUnicode_Join((PyObject *)self, seq);
Py_DECREF(seq);
- return PyUnicode_Join((PyObject *)self, orig);
+ return result;
}
PyErr_Format(PyExc_TypeError,
"sequence item %i: expected string,"
diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c
index e52d628..5da4d2f 100644
--- a/Objects/unicodeobject.c
+++ b/Objects/unicodeobject.c
@@ -2724,10 +2724,11 @@ PyObject *PyUnicode_Join(PyObject *separator,
int seqlen = 0;
int sz = 100;
int i;
+ PyObject *it;
- seqlen = PySequence_Size(seq);
- if (seqlen < 0 && PyErr_Occurred())
- return NULL;
+ it = PyObject_GetIter(seq);
+ if (it == NULL)
+ return NULL;
if (separator == NULL) {
Py_UNICODE blank = ' ';
@@ -2737,7 +2738,7 @@ PyObject *PyUnicode_Join(PyObject *separator,
else {
separator = PyUnicode_FromObject(separator);
if (separator == NULL)
- return NULL;
+ goto onError;
sep = PyUnicode_AS_UNICODE(separator);
seplen = PyUnicode_GET_SIZE(separator);
}
@@ -2748,13 +2749,14 @@ PyObject *PyUnicode_Join(PyObject *separator,
p = PyUnicode_AS_UNICODE(res);
reslen = 0;
- for (i = 0; i < seqlen; i++) {
+ for (i = 0; ; ++i) {
int itemlen;
- PyObject *item;
-
- item = PySequence_GetItem(seq, i);
- if (item == NULL)
- goto onError;
+ PyObject *item = PyIter_Next(it);
+ if (item == NULL) {
+ if (PyErr_Occurred())
+ goto onError;
+ break;
+ }
if (!PyUnicode_Check(item)) {
PyObject *v;
v = PyUnicode_FromObject(item);
@@ -2784,11 +2786,13 @@ PyObject *PyUnicode_Join(PyObject *separator,
goto onError;
Py_XDECREF(separator);
+ Py_DECREF(it);
return (PyObject *)res;
onError:
Py_XDECREF(separator);
- Py_DECREF(res);
+ Py_XDECREF(res);
+ Py_DECREF(it);
return NULL;
}