From 4355a47903f3242222b5807c71ec3dda4a8c8d5c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 4 May 2007 05:00:04 +0000 Subject: Make all of test_bytes pass (except pickling, which is too badly busted). --- Lib/test/test_bytes.py | 76 ++++++++++++++++++++++++------------------------- Objects/bytesobject.c | 30 +++++++------------ Objects/unicodeobject.c | 6 ++++ 3 files changed, 55 insertions(+), 57 deletions(-) diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 7178c06..102eb46 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -102,35 +102,35 @@ class BytesTest(unittest.TestCase): self.failIf(b3 <= b2) def test_compare_to_str(self): - self.assertEqual(b"abc" == "abc", True) - self.assertEqual(b"ab" != "abc", True) - self.assertEqual(b"ab" <= "abc", True) - self.assertEqual(b"ab" < "abc", True) - self.assertEqual(b"abc" >= "ab", True) - self.assertEqual(b"abc" > "ab", True) - - self.assertEqual(b"abc" != "abc", False) - self.assertEqual(b"ab" == "abc", False) - self.assertEqual(b"ab" > "abc", False) - self.assertEqual(b"ab" >= "abc", False) - self.assertEqual(b"abc" < "ab", False) - self.assertEqual(b"abc" <= "ab", False) - - self.assertEqual("abc" == b"abc", True) - self.assertEqual("ab" != b"abc", True) - self.assertEqual("ab" <= b"abc", True) - self.assertEqual("ab" < b"abc", True) - self.assertEqual("abc" >= b"ab", True) - self.assertEqual("abc" > b"ab", True) - - self.assertEqual("abc" != b"abc", False) - self.assertEqual("ab" == b"abc", False) - self.assertEqual("ab" > b"abc", False) - self.assertEqual("ab" >= b"abc", False) - self.assertEqual("abc" < b"ab", False) - self.assertEqual("abc" <= b"ab", False) - - # But they should never compare equal to Unicode! + self.assertEqual(b"abc" == str8("abc"), True) + self.assertEqual(b"ab" != str8("abc"), True) + self.assertEqual(b"ab" <= str8("abc"), True) + self.assertEqual(b"ab" < str8("abc"), True) + self.assertEqual(b"abc" >= str8("ab"), True) + self.assertEqual(b"abc" > str8("ab"), True) + + self.assertEqual(b"abc" != str8("abc"), False) + self.assertEqual(b"ab" == str8("abc"), False) + self.assertEqual(b"ab" > str8("abc"), False) + self.assertEqual(b"ab" >= str8("abc"), False) + self.assertEqual(b"abc" < str8("ab"), False) + self.assertEqual(b"abc" <= str8("ab"), False) + + self.assertEqual(str8("abc") == b"abc", True) + self.assertEqual(str8("ab") != b"abc", True) + self.assertEqual(str8("ab") <= b"abc", True) + self.assertEqual(str8("ab") < b"abc", True) + self.assertEqual(str8("abc") >= b"ab", True) + self.assertEqual(str8("abc") > b"ab", True) + + self.assertEqual(str8("abc") != b"abc", False) + self.assertEqual(str8("ab") == b"abc", False) + self.assertEqual(str8("ab") > b"abc", False) + self.assertEqual(str8("ab") >= b"abc", False) + self.assertEqual(str8("abc") < b"ab", False) + self.assertEqual(str8("abc") <= b"ab", False) + + # Bytes should never compare equal to Unicode! # Test this for all expected byte orders and Unicode character sizes self.assertEqual(b"\0a\0b\0c" == "abc", False) self.assertEqual(b"\0\0\0a\0\0\0b\0\0\0c" == "abc", False) @@ -326,7 +326,7 @@ class BytesTest(unittest.TestCase): sample = "Hello world\n\u1234\u5678\u9abc\udef0" for enc in ("utf8", "utf16"): b = bytes(sample, enc) - self.assertEqual(b, bytes(map(ord, sample.encode(enc)))) + self.assertEqual(b, bytes(sample.encode(enc))) self.assertRaises(UnicodeEncodeError, bytes, sample, "latin1") b = bytes(sample, "latin1", "ignore") self.assertEqual(b, bytes(sample[:-4])) @@ -342,7 +342,7 @@ class BytesTest(unittest.TestCase): self.assertEqual(b.decode("utf8", "ignore"), "Hello world\n") def test_from_buffer(self): - sample = "Hello world\n\x80\x81\xfe\xff" + sample = str8("Hello world\n\x80\x81\xfe\xff") buf = buffer(sample) b = bytes(buf) self.assertEqual(b, bytes(map(ord, sample))) @@ -364,8 +364,8 @@ class BytesTest(unittest.TestCase): b1 = bytes("abc") b2 = bytes("def") self.assertEqual(b1 + b2, bytes("abcdef")) - self.assertEqual(b1 + "def", bytes("abcdef")) - self.assertEqual("def" + b1, bytes("defabc")) + self.assertEqual(b1 + str8("def"), bytes("abcdef")) + self.assertEqual(str8("def") + b1, bytes("defabc")) self.assertRaises(TypeError, lambda: b1 + "def") self.assertRaises(TypeError, lambda: "abc" + b2) @@ -388,7 +388,7 @@ class BytesTest(unittest.TestCase): self.assertEqual(b, bytes("abcdef")) self.assertEqual(b, b1) self.failUnless(b is b1) - b += "xyz" + b += str8("xyz") self.assertEqual(b, b"abcdefxyz") try: b += "" @@ -456,8 +456,8 @@ class BytesTest(unittest.TestCase): b = bytes([0x1a, 0x2b, 0x30]) self.assertEquals(bytes.fromhex('1a2B30'), b) self.assertEquals(bytes.fromhex(' 1A 2B 30 '), b) - self.assertEquals(bytes.fromhex(buffer('')), bytes()) - self.assertEquals(bytes.fromhex(buffer('0000')), bytes([0, 0])) + self.assertEquals(bytes.fromhex(buffer(b'')), bytes()) + self.assertEquals(bytes.fromhex(buffer(b'0000')), bytes([0, 0])) self.assertRaises(ValueError, bytes.fromhex, 'a') self.assertRaises(ValueError, bytes.fromhex, 'rt') self.assertRaises(ValueError, bytes.fromhex, '1a b cd') @@ -717,5 +717,5 @@ def test_main(): if __name__ == "__main__": - test_main() - ##unittest.main() + ##test_main() + unittest.main() diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c index 987a3c5..cb830e3 100644 --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -218,6 +218,7 @@ bytes_iconcat(PyBytesObject *self, PyObject *other) Py_ssize_t mysize; Py_ssize_t size; + /* XXX What if other == self? */ osize = _getbuffer(other, &optr); if (osize < 0) { PyErr_Format(PyExc_TypeError, @@ -698,34 +699,25 @@ bytes_init(PyBytesObject *self, PyObject *args, PyObject *kwds) if (PyUnicode_Check(arg)) { /* Encode via the codec registry */ - PyObject *encoded; - char *bytes; - Py_ssize_t size; + PyObject *encoded, *new; if (encoding == NULL) encoding = PyUnicode_GetDefaultEncoding(); encoded = PyCodec_Encode(arg, encoding, errors); if (encoded == NULL) return -1; - if (!PyString_Check(encoded)) { + if (!PyBytes_Check(encoded) && !PyString_Check(encoded)) { PyErr_Format(PyExc_TypeError, - "encoder did not return a string object (type=%.400s)", + "encoder did not return a str8 or bytes object (type=%.400s)", encoded->ob_type->tp_name); Py_DECREF(encoded); return -1; } - bytes = PyString_AS_STRING(encoded); - size = PyString_GET_SIZE(encoded); - if (size < self->ob_alloc) { - self->ob_size = size; - self->ob_bytes[self->ob_size] = '\0'; /* Trailing null byte */ - } - else if (PyBytes_Resize((PyObject *)self, size) < 0) { - Py_DECREF(encoded); - return -1; - } - memcpy(self->ob_bytes, bytes, size); - Py_DECREF(encoded); - return 0; + new = bytes_iconcat(self, encoded); + Py_DECREF(encoded); + if (new == NULL) + return -1; + Py_DECREF(new); + return 0; } /* If it's not unicode, there can't be encoding or errors */ @@ -2689,7 +2681,7 @@ bytes_fromhex(PyObject *cls, PyObject *args) return NULL; buf = PyBytes_AS_STRING(newbytes); - for (i = j = 0; ; i += 2) { + for (i = j = 0; i < len; i += 2) { /* skip over spaces in the input */ while (Py_CHARMASK(hex[i]) == ' ') i++; diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c index 26d6fc6..d4a17ce 100644 --- a/Objects/unicodeobject.c +++ b/Objects/unicodeobject.c @@ -5634,6 +5634,12 @@ unicode_encode(PyUnicodeObject *self, PyObject *args) if (v == NULL) goto onError; if (!PyBytes_Check(v)) { + if (PyString_Check(v)) { + /* Old codec, turn it into bytes */ + PyObject *b = PyBytes_FromObject(v); + Py_DECREF(v); + return b; + } PyErr_Format(PyExc_TypeError, "encoder did not return a bytes object " "(type=%.400s)", -- cgit v0.12