From a6c04bed1ed51e87bf9a24bc4b9ab9364821aef5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 3 Nov 2007 00:24:24 +0000 Subject: Patch 1171 by mfenniak -- allow subclassing of bytes. I suspect this has some problems when the subclass is evil, but that's for later. --- Lib/test/test_bytes.py | 83 ++++++++++++++++++++++++++++++++++++++++++++++++-- Objects/bytesobject.c | 15 ++++++--- 2 files changed, 92 insertions(+), 6 deletions(-) diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 112cb32..932fa44 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -3,6 +3,7 @@ import os import re import sys +import copy import pickle import tempfile import unittest @@ -782,11 +783,89 @@ class BytesAsStringTest(test.string_tests.BaseTest): pass +class BytesSubclass(bytes): + pass + +class BytesSubclassTest(unittest.TestCase): + + def test_basic(self): + self.assert_(issubclass(BytesSubclass, bytes)) + self.assert_(isinstance(BytesSubclass(), bytes)) + + a, b = b"abcd", b"efgh" + _a, _b = BytesSubclass(a), BytesSubclass(b) + + # test comparison operators with subclass instances + self.assert_(_a == _a) + self.assert_(_a != _b) + self.assert_(_a < _b) + self.assert_(_a <= _b) + self.assert_(_b >= _a) + self.assert_(_b > _a) + self.assert_(_a is not a) + + # test concat of subclass instances + self.assertEqual(a + b, _a + _b) + self.assertEqual(a + b, a + _b) + self.assertEqual(a + b, _a + b) + + # test repeat + self.assert_(a*5 == _a*5) + + def test_join(self): + # Make sure join returns a NEW object for single item sequences + # involving a subclass. + # Make sure that it is of the appropriate type. + s1 = BytesSubclass(b"abcd") + s2 = b"".join([s1]) + self.assert_(s1 is not s2) + self.assert_(type(s2) is bytes) + + # Test reverse, calling join on subclass + s3 = s1.join([b"abcd"]) + self.assert_(type(s3) is bytes) + + def test_pickle(self): + a = BytesSubclass(b"abcd") + a.x = 10 + a.y = BytesSubclass(b"efgh") + for proto in range(pickle.HIGHEST_PROTOCOL): + b = pickle.loads(pickle.dumps(a, proto)) + self.assertNotEqual(id(a), id(b)) + self.assertEqual(a, b) + self.assertEqual(a.x, b.x) + self.assertEqual(a.y, b.y) + self.assertEqual(type(a), type(b)) + self.assertEqual(type(a.y), type(b.y)) + + def test_copy(self): + a = BytesSubclass(b"abcd") + a.x = 10 + a.y = BytesSubclass(b"efgh") + for copy_method in (copy.copy, copy.deepcopy): + b = copy_method(a) + self.assertNotEqual(id(a), id(b)) + self.assertEqual(a, b) + self.assertEqual(a.x, b.x) + self.assertEqual(a.y, b.y) + self.assertEqual(type(a), type(b)) + self.assertEqual(type(a.y), type(b.y)) + + def test_init_override(self): + class subclass(bytes): + def __init__(self, newarg=1, *args, **kwargs): + bytes.__init__(self, *args, **kwargs) + x = subclass(4, source=b"abcd") + self.assertEqual(x, b"abcd") + x = subclass(newarg=4, source=b"abcd") + self.assertEqual(x, b"abcd") + + def test_main(): test.test_support.run_unittest(BytesTest) test.test_support.run_unittest(BytesAsStringTest) + test.test_support.run_unittest(BytesSubclassTest) test.test_support.run_unittest(BufferPEP3137Test) if __name__ == "__main__": - ##test_main() - unittest.main() + test_main() diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c index 2595ff2..3f2dbc2 100644 --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -2921,13 +2921,21 @@ PyDoc_STRVAR(reduce_doc, "Return state information for pickling."); static PyObject * bytes_reduce(PyBytesObject *self) { - PyObject *latin1; + PyObject *latin1, *dict; if (self->ob_bytes) latin1 = PyUnicode_DecodeLatin1(self->ob_bytes, Py_Size(self), NULL); else latin1 = PyUnicode_FromString(""); - return Py_BuildValue("(O(Ns))", Py_Type(self), latin1, "latin-1"); + + dict = PyObject_GetAttrString((PyObject *)self, "__dict__"); + if (dict == NULL) { + PyErr_Clear(); + dict = Py_None; + Py_INCREF(dict); + } + + return Py_BuildValue("(O(Ns)N)", Py_Type(self), latin1, "latin-1", dict); } static PySequenceMethods bytes_as_sequence = { @@ -3045,8 +3053,7 @@ PyTypeObject PyBytes_Type = { PyObject_GenericGetAttr, /* tp_getattro */ 0, /* tp_setattro */ &bytes_as_buffer, /* tp_as_buffer */ - /* bytes is 'final' or 'sealed' */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ bytes_doc, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ -- cgit v0.12