From 09549f44076e84083cbb15dbd6da9d1a3fd6d7f1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 27 Aug 2007 20:40:10 +0000 Subject: Changes in anticipation of stricter str vs. bytes enforcement. --- Lib/base64.py | 12 ++--- Lib/binhex.py | 4 +- Lib/test/string_tests.py | 51 ++++++++++--------- Lib/test/test_bytes.py | 107 +++++++++++++++++----------------------- Lib/test/test_cgi.py | 2 +- Lib/test/test_codeccallbacks.py | 16 +++--- Lib/test/test_codecs.py | 14 +++--- Lib/test/test_complex.py | 6 +-- Lib/wave.py | 2 +- 9 files changed, 98 insertions(+), 116 deletions(-) diff --git a/Lib/base64.py b/Lib/base64.py index de3f184..cec6422 100755 --- a/Lib/base64.py +++ b/Lib/base64.py @@ -54,7 +54,7 @@ def b64encode(s, altchars=None): encoded = binascii.b2a_base64(s)[:-1] if altchars is not None: if not isinstance(altchars, bytes): - altchars = bytes(altchars) + altchars = bytes(altchars, "ascii") assert len(altchars) == 2, repr(altchars) return _translate(encoded, {'+': altchars[0:1], '/': altchars[1:2]}) return encoded @@ -75,7 +75,7 @@ def b64decode(s, altchars=None): s = bytes(s) if altchars is not None: if not isinstance(altchars, bytes): - altchars = bytes(altchars) + altchars = bytes(altchars, "ascii") assert len(altchars) == 2, repr(altchars) s = _translate(s, {chr(altchars[0]): b'+', chr(altchars[1]): b'/'}) return binascii.a2b_base64(s) @@ -239,7 +239,7 @@ def b32decode(s, casefold=False, map01=None): acc = 0 shift = 35 # Process the last, partial quanta - last = binascii.unhexlify(bytes('%010x' % acc)) + last = binascii.unhexlify(bytes('%010x' % acc, "ascii")) if padchars == 0: last = b'' # No characters elif padchars == 1: @@ -323,8 +323,7 @@ def decode(input, output): def encodestring(s): """Encode a string into multiple lines of base-64 data.""" - if not isinstance(s, bytes): - s = bytes(s) + assert isinstance(s, bytes), repr(s) pieces = [] for i in range(0, len(s), MAXBINSIZE): chunk = s[i : i + MAXBINSIZE] @@ -334,8 +333,7 @@ def encodestring(s): def decodestring(s): """Decode a string.""" - if not isinstance(s, bytes): - s = bytes(s) + assert isinstance(s, bytes), repr(s) return binascii.a2b_base64(s) diff --git a/Lib/binhex.py b/Lib/binhex.py index 8da8961..0dfa475 100644 --- a/Lib/binhex.py +++ b/Lib/binhex.py @@ -191,8 +191,8 @@ class BinHex: nl = len(name) if nl > 63: raise Error, 'Filename too long' - d = bytes([nl]) + bytes(name) + b'\0' - d2 = bytes(finfo.Type) + bytes(finfo.Creator) + d = bytes([nl]) + name.encode("latin-1") + b'\0' + d2 = bytes(finfo.Type, "ascii") + bytes(finfo.Creator, "ascii") # Force all structs to be packed with big-endian d3 = struct.pack('>h', finfo.Flags) diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 56275d6..bafa23b 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -45,40 +45,44 @@ class BaseTest(unittest.TestCase): else: return obj - # check that object.method(*args) returns result - def checkequal(self, result, object, methodname, *args): + # check that obj.method(*args) returns result + def checkequal(self, result, obj, methodname, *args): result = self.fixtype(result) - object = self.fixtype(object) + obj = self.fixtype(obj) args = self.fixtype(args) - realresult = getattr(object, methodname)(*args) + realresult = getattr(obj, methodname)(*args) self.assertEqual( result, realresult ) # if the original is returned make sure that # this doesn't happen with subclasses - if object == realresult: - class subtype(self.__class__.type2test): - pass - object = subtype(object) - realresult = getattr(object, methodname)(*args) - self.assert_(object is not realresult) - - # check that object.method(*args) raises exc - def checkraises(self, exc, object, methodname, *args): - object = self.fixtype(object) + if obj is realresult: + try: + class subtype(self.__class__.type2test): + pass + except TypeError: + pass # Skip this if we can't subclass + else: + obj = subtype(obj) + realresult = getattr(obj, methodname)(*args) + self.assert_(obj is not realresult) + + # check that obj.method(*args) raises exc + def checkraises(self, exc, obj, methodname, *args): + obj = self.fixtype(obj) args = self.fixtype(args) self.assertRaises( exc, - getattr(object, methodname), + getattr(obj, methodname), *args ) - # call object.method(*args) without any checks - def checkcall(self, object, methodname, *args): - object = self.fixtype(object) + # call obj.method(*args) without any checks + def checkcall(self, obj, methodname, *args): + obj = self.fixtype(obj) args = self.fixtype(args) - getattr(object, methodname)(*args) + getattr(obj, methodname)(*args) def test_count(self): self.checkequal(3, 'aaa', 'count', 'a') @@ -118,14 +122,14 @@ class BaseTest(unittest.TestCase): i, m = divmod(i, base) entry.append(charset[m]) teststrings.add(''.join(entry)) - teststrings = list(teststrings) + teststrings = [self.fixtype(ts) for ts in teststrings] for i in teststrings: - i = self.fixtype(i) n = len(i) for j in teststrings: r1 = i.count(j) if j: - r2, rem = divmod(n - len(i.replace(j, '')), len(j)) + r2, rem = divmod(n - len(i.replace(j, self.fixtype(''))), + len(j)) else: r2, rem = len(i)+1, 0 if rem or r1 != r2: @@ -157,9 +161,8 @@ class BaseTest(unittest.TestCase): i, m = divmod(i, base) entry.append(charset[m]) teststrings.add(''.join(entry)) - teststrings = list(teststrings) + teststrings = [self.fixtype(ts) for ts in teststrings] for i in teststrings: - i = self.fixtype(i) for j in teststrings: loc = i.find(j) r1 = (loc != -1) diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 39d92ed..b7e6800 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -528,9 +528,9 @@ class BytesTest(unittest.TestCase): def test_count(self): b = b'mississippi' - self.assertEqual(b.count('i'), 4) - self.assertEqual(b.count('ss'), 2) - self.assertEqual(b.count('w'), 0) + self.assertEqual(b.count(b'i'), 4) + self.assertEqual(b.count(b'ss'), 2) + self.assertEqual(b.count(b'w'), 0) def test_append(self): b = b'hell' @@ -551,58 +551,58 @@ class BytesTest(unittest.TestCase): def test_startswith(self): b = b'hello' - self.assertFalse(bytes().startswith("anything")) - self.assertTrue(b.startswith("hello")) - self.assertTrue(b.startswith("hel")) - self.assertTrue(b.startswith("h")) - self.assertFalse(b.startswith("hellow")) - self.assertFalse(b.startswith("ha")) + self.assertFalse(bytes().startswith(b"anything")) + self.assertTrue(b.startswith(b"hello")) + self.assertTrue(b.startswith(b"hel")) + self.assertTrue(b.startswith(b"h")) + self.assertFalse(b.startswith(b"hellow")) + self.assertFalse(b.startswith(b"ha")) def test_endswith(self): b = b'hello' - self.assertFalse(bytes().endswith("anything")) - self.assertTrue(b.endswith("hello")) - self.assertTrue(b.endswith("llo")) - self.assertTrue(b.endswith("o")) - self.assertFalse(b.endswith("whello")) - self.assertFalse(b.endswith("no")) + self.assertFalse(bytes().endswith(b"anything")) + self.assertTrue(b.endswith(b"hello")) + self.assertTrue(b.endswith(b"llo")) + self.assertTrue(b.endswith(b"o")) + self.assertFalse(b.endswith(b"whello")) + self.assertFalse(b.endswith(b"no")) def test_find(self): b = b'mississippi' - self.assertEqual(b.find('ss'), 2) - self.assertEqual(b.find('ss', 3), 5) - self.assertEqual(b.find('ss', 1, 7), 2) - self.assertEqual(b.find('ss', 1, 3), -1) - self.assertEqual(b.find('w'), -1) - self.assertEqual(b.find('mississippian'), -1) + self.assertEqual(b.find(b'ss'), 2) + self.assertEqual(b.find(b'ss', 3), 5) + self.assertEqual(b.find(b'ss', 1, 7), 2) + self.assertEqual(b.find(b'ss', 1, 3), -1) + self.assertEqual(b.find(b'w'), -1) + self.assertEqual(b.find(b'mississippian'), -1) def test_rfind(self): b = b'mississippi' - self.assertEqual(b.rfind('ss'), 5) - self.assertEqual(b.rfind('ss', 3), 5) - self.assertEqual(b.rfind('ss', 0, 6), 2) - self.assertEqual(b.rfind('w'), -1) - self.assertEqual(b.rfind('mississippian'), -1) + self.assertEqual(b.rfind(b'ss'), 5) + self.assertEqual(b.rfind(b'ss', 3), 5) + self.assertEqual(b.rfind(b'ss', 0, 6), 2) + self.assertEqual(b.rfind(b'w'), -1) + self.assertEqual(b.rfind(b'mississippian'), -1) def test_index(self): b = b'world' - self.assertEqual(b.index('w'), 0) - self.assertEqual(b.index('orl'), 1) - self.assertRaises(ValueError, lambda: b.index('worm')) - self.assertRaises(ValueError, lambda: b.index('ldo')) + self.assertEqual(b.index(b'w'), 0) + self.assertEqual(b.index(b'orl'), 1) + self.assertRaises(ValueError, b.index, b'worm') + self.assertRaises(ValueError, b.index, b'ldo') def test_rindex(self): # XXX could be more rigorous b = b'world' - self.assertEqual(b.rindex('w'), 0) - self.assertEqual(b.rindex('orl'), 1) - self.assertRaises(ValueError, lambda: b.rindex('worm')) - self.assertRaises(ValueError, lambda: b.rindex('ldo')) + self.assertEqual(b.rindex(b'w'), 0) + self.assertEqual(b.rindex(b'orl'), 1) + self.assertRaises(ValueError, b.rindex, b'worm') + self.assertRaises(ValueError, b.rindex, b'ldo') def test_replace(self): b = b'mississippi' - self.assertEqual(b.replace('i', 'a'), b'massassappa') - self.assertEqual(b.replace('ss', 'x'), b'mixixippi') + self.assertEqual(b.replace(b'i', b'a'), b'massassappa') + self.assertEqual(b.replace(b'ss', b'x'), b'mixixippi') def test_translate(self): b = b'hello' @@ -614,19 +614,19 @@ class BytesTest(unittest.TestCase): def test_split(self): b = b'mississippi' - self.assertEqual(b.split('i'), [b'm', b'ss', b'ss', b'pp', b'']) - self.assertEqual(b.split('ss'), [b'mi', b'i', b'ippi']) - self.assertEqual(b.split('w'), [b]) + self.assertEqual(b.split(b'i'), [b'm', b'ss', b'ss', b'pp', b'']) + self.assertEqual(b.split(b'ss'), [b'mi', b'i', b'ippi']) + self.assertEqual(b.split(b'w'), [b]) # require an arg (no magic whitespace split) - self.assertRaises(TypeError, lambda: b.split()) + self.assertRaises(TypeError, b.split) def test_rsplit(self): b = b'mississippi' - self.assertEqual(b.rsplit('i'), [b'm', b'ss', b'ss', b'pp', b'']) - self.assertEqual(b.rsplit('ss'), [b'mi', b'i', b'ippi']) - self.assertEqual(b.rsplit('w'), [b]) + self.assertEqual(b.rsplit(b'i'), [b'm', b'ss', b'ss', b'pp', b'']) + self.assertEqual(b.rsplit(b'ss'), [b'mi', b'i', b'ippi']) + self.assertEqual(b.rsplit(b'w'), [b]) # require an arg (no magic whitespace split) - self.assertRaises(TypeError, lambda: b.rsplit()) + self.assertRaises(TypeError, b.rsplit) def test_partition(self): b = b'mississippi' @@ -695,30 +695,11 @@ class BytesAsStringTest(test.string_tests.BaseTest): return obj.encode("utf-8") return super().fixtype(obj) - def checkequal(self, result, object, methodname, *args): - object = bytes(object, "utf-8") - realresult = getattr(bytes, methodname)(object, *args) - self.assertEqual( - self.fixtype(result), - realresult - ) - - def checkraises(self, exc, object, methodname, *args): - object = bytes(object, "utf-8") - self.assertRaises( - exc, - getattr(bytes, methodname), - object, - *args - ) - # Currently the bytes containment testing uses a single integer # value. This may not be the final design, but until then the # bytes section with in a bytes containment not valid def test_contains(self): pass - def test_find(self): - pass def test_expandtabs(self): pass def test_upper(self): diff --git a/Lib/test/test_cgi.py b/Lib/test/test_cgi.py index 652b36e..9b3df21 100644 --- a/Lib/test/test_cgi.py +++ b/Lib/test/test_cgi.py @@ -232,7 +232,7 @@ class CgiTests(unittest.TestCase): return a f = TestReadlineFile(tempfile.TemporaryFile()) - f.write('x' * 256 * 1024) + f.write(b'x' * 256 * 1024) f.seek(0) env = {'REQUEST_METHOD':'PUT'} fs = cgi.FieldStorage(fp=f, environ=env) diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py index 9b731d5..6d7e98e 100644 --- a/Lib/test/test_codeccallbacks.py +++ b/Lib/test/test_codeccallbacks.py @@ -675,22 +675,22 @@ class CodecCallbackTest(unittest.TestCase): # enhance coverage of: # Objects/unicodeobject.c::unicode_decode_call_errorhandler() # and callers - self.assertRaises(LookupError, "\xff".decode, "ascii", "test.unknown") + self.assertRaises(LookupError, b"\xff".decode, "ascii", "test.unknown") def baddecodereturn1(exc): return 42 codecs.register_error("test.baddecodereturn1", baddecodereturn1) - self.assertRaises(TypeError, "\xff".decode, "ascii", "test.baddecodereturn1") - self.assertRaises(TypeError, "\\".decode, "unicode-escape", "test.baddecodereturn1") - self.assertRaises(TypeError, "\\x0".decode, "unicode-escape", "test.baddecodereturn1") - self.assertRaises(TypeError, "\\x0y".decode, "unicode-escape", "test.baddecodereturn1") - self.assertRaises(TypeError, "\\Uffffeeee".decode, "unicode-escape", "test.baddecodereturn1") - self.assertRaises(TypeError, "\\uyyyy".decode, "raw-unicode-escape", "test.baddecodereturn1") + self.assertRaises(TypeError, b"\xff".decode, "ascii", "test.baddecodereturn1") + self.assertRaises(TypeError, b"\\".decode, "unicode-escape", "test.baddecodereturn1") + self.assertRaises(TypeError, b"\\x0".decode, "unicode-escape", "test.baddecodereturn1") + self.assertRaises(TypeError, b"\\x0y".decode, "unicode-escape", "test.baddecodereturn1") + self.assertRaises(TypeError, b"\\Uffffeeee".decode, "unicode-escape", "test.baddecodereturn1") + self.assertRaises(TypeError, b"\\uyyyy".decode, "raw-unicode-escape", "test.baddecodereturn1") def baddecodereturn2(exc): return ("?", None) codecs.register_error("test.baddecodereturn2", baddecodereturn2) - self.assertRaises(TypeError, "\xff".decode, "ascii", "test.baddecodereturn2") + self.assertRaises(TypeError, b"\xff".decode, "ascii", "test.baddecodereturn2") handler = PosReturn() codecs.register_error("test.posreturn", handler.handle) diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index f2ee524..cb78048 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -559,7 +559,7 @@ class ReadBufferTest(unittest.TestCase): def test_array(self): import array self.assertEqual( - codecs.readbuffer_encode(array.array("b", bytes("spam"))), + codecs.readbuffer_encode(array.array("b", b"spam")), (b"spam", 4) ) @@ -573,10 +573,10 @@ class ReadBufferTest(unittest.TestCase): class CharBufferTest(unittest.TestCase): def test_string(self): - self.assertEqual(codecs.charbuffer_encode("spam"), (b"spam", 4)) + self.assertEqual(codecs.charbuffer_encode(b"spam"), (b"spam", 4)) def test_empty(self): - self.assertEqual(codecs.charbuffer_encode(""), (b"", 0)) + self.assertEqual(codecs.charbuffer_encode(b""), (b"", 0)) def test_bad_args(self): self.assertRaises(TypeError, codecs.charbuffer_encode) @@ -999,19 +999,19 @@ class IDNACodecTest(unittest.TestCase): def test_incremental_decode(self): self.assertEquals( - "".join(codecs.iterdecode((bytes(chr(c)) for c in b"python.org"), "idna")), + "".join(codecs.iterdecode((bytes([c]) for c in b"python.org"), "idna")), "python.org" ) self.assertEquals( - "".join(codecs.iterdecode((bytes(chr(c)) for c in b"python.org."), "idna")), + "".join(codecs.iterdecode((bytes([c]) for c in b"python.org."), "idna")), "python.org." ) self.assertEquals( - "".join(codecs.iterdecode((bytes(chr(c)) for c in b"xn--pythn-mua.org."), "idna")), + "".join(codecs.iterdecode((bytes([c]) for c in b"xn--pythn-mua.org."), "idna")), "pyth\xf6n.org." ) self.assertEquals( - "".join(codecs.iterdecode((bytes(chr(c)) for c in b"xn--pythn-mua.org."), "idna")), + "".join(codecs.iterdecode((bytes([c]) for c in b"xn--pythn-mua.org."), "idna")), "pyth\xf6n.org." ) diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py index 116ce54..b4082d9 100644 --- a/Lib/test/test_complex.py +++ b/Lib/test/test_complex.py @@ -318,11 +318,11 @@ class ComplexTest(unittest.TestCase): fo = None try: - fo = open(test_support.TESTFN, "wb") + fo = open(test_support.TESTFN, "w") print(a, b, file=fo) fo.close() - fo = open(test_support.TESTFN, "rb") - self.assertEqual(fo.read(), ("%s %s\n" % (a, b)).encode("ascii")) + fo = open(test_support.TESTFN, "r") + self.assertEqual(fo.read(), ("%s %s\n" % (a, b))) finally: if (fo is not None) and (not fo.closed): fo.close() diff --git a/Lib/wave.py b/Lib/wave.py index e0025ec..6627bcf 100644 --- a/Lib/wave.py +++ b/Lib/wave.py @@ -459,7 +459,7 @@ class Wave_write: self._write_header(datasize) def _write_header(self, initlength): - self._file.write('RIFF') + self._file.write(b'RIFF') if not self._nframes: self._nframes = initlength / (self._nchannels * self._sampwidth) self._datalength = self._nframes * self._nchannels * self._sampwidth -- cgit v0.12