diff options
Diffstat (limited to 'Lib/base64.py')
| -rwxr-xr-x | Lib/base64.py | 64 |
1 files changed, 33 insertions, 31 deletions
diff --git a/Lib/base64.py b/Lib/base64.py index 895d813..b6e82b6 100755 --- a/Lib/base64.py +++ b/Lib/base64.py @@ -29,14 +29,16 @@ __all__ = [ bytes_types = (bytes, bytearray) # Types acceptable as binary data - -def _translate(s, altchars): - if not isinstance(s, bytes_types): - raise TypeError("expected bytes, not %s" % s.__class__.__name__) - translation = bytearray(range(256)) - for k, v in altchars.items(): - translation[ord(k)] = v[0] - return s.translate(translation) +def _bytes_from_decode_data(s): + if isinstance(s, str): + try: + return s.encode('ascii') + except UnicodeEncodeError: + raise ValueError('string argument should contain only ASCII characters') + elif isinstance(s, bytes_types): + return s + else: + raise TypeError("argument should be bytes or ASCII string, not %s" % s.__class__.__name__) @@ -61,7 +63,7 @@ def b64encode(s, altchars=None): raise TypeError("expected bytes, not %s" % altchars.__class__.__name__) assert len(altchars) == 2, repr(altchars) - return _translate(encoded, {'+': altchars[0:1], '/': altchars[1:2]}) + return encoded.translate(bytes.maketrans(b'+/', altchars)) return encoded @@ -79,14 +81,11 @@ def b64decode(s, altchars=None, validate=False): discarded prior to the padding check. If validate is True, non-base64-alphabet characters in the input result in a binascii.Error. """ - if not isinstance(s, bytes_types): - raise TypeError("expected bytes, not %s" % s.__class__.__name__) + s = _bytes_from_decode_data(s) if altchars is not None: - if not isinstance(altchars, bytes_types): - raise TypeError("expected bytes, not %s" - % altchars.__class__.__name__) + altchars = _bytes_from_decode_data(altchars) assert len(altchars) == 2, repr(altchars) - s = _translate(s, {chr(altchars[0]): b'+', chr(altchars[1]): b'/'}) + s = s.translate(bytes.maketrans(altchars, b'+/')) if validate and not re.match(b'^[A-Za-z0-9+/]*={0,2}$', s): raise binascii.Error('Non-base64 digit found') return binascii.a2b_base64(s) @@ -109,6 +108,10 @@ def standard_b64decode(s): """ return b64decode(s) + +_urlsafe_encode_translation = bytes.maketrans(b'+/', b'-_') +_urlsafe_decode_translation = bytes.maketrans(b'-_', b'+/') + def urlsafe_b64encode(s): """Encode a byte string using a url-safe Base64 alphabet. @@ -116,7 +119,7 @@ def urlsafe_b64encode(s): returned. The alphabet uses '-' instead of '+' and '_' instead of '/'. """ - return b64encode(s, b'-_') + return b64encode(s).translate(_urlsafe_encode_translation) def urlsafe_b64decode(s): """Decode a byte string encoded with the standard Base64 alphabet. @@ -128,7 +131,9 @@ def urlsafe_b64decode(s): The alphabet uses '-' instead of '+' and '_' instead of '/'. """ - return b64decode(s, b'-_') + s = _bytes_from_decode_data(s) + s = s.translate(_urlsafe_decode_translation) + return b64decode(s) @@ -161,7 +166,7 @@ def b32encode(s): if leftover: s = s + bytes(5 - leftover) # Don't use += ! quanta += 1 - encoded = bytes() + encoded = bytearray() for i in range(quanta): # c1 and c2 are 16 bits wide, c3 is 8 bits wide. The intent of this # code is to process the 40 bits in units of 5 bits. So we take the 1 @@ -182,14 +187,14 @@ def b32encode(s): ]) # Adjust for any leftover partial quanta if leftover == 1: - return encoded[:-6] + b'======' + encoded[-6:] = b'======' elif leftover == 2: - return encoded[:-4] + b'====' + encoded[-4:] = b'====' elif leftover == 3: - return encoded[:-3] + b'===' + encoded[-3:] = b'===' elif leftover == 4: - return encoded[:-1] + b'=' - return encoded + encoded[-1:] = b'=' + return bytes(encoded) def b32decode(s, casefold=False, map01=None): @@ -211,8 +216,7 @@ def b32decode(s, casefold=False, map01=None): the input is incorrectly padded or if there are non-alphabet characters present in the input. """ - if not isinstance(s, bytes_types): - raise TypeError("expected bytes, not %s" % s.__class__.__name__) + s = _bytes_from_decode_data(s) quanta, leftover = divmod(len(s), 8) if leftover: raise binascii.Error('Incorrect padding') @@ -220,10 +224,9 @@ def b32decode(s, casefold=False, map01=None): # False, or the character to map the digit 1 (one) to. It should be # either L (el) or I (eye). if map01 is not None: - if not isinstance(map01, bytes_types): - raise TypeError("expected bytes, not %s" % map01.__class__.__name__) + map01 = _bytes_from_decode_data(map01) assert len(map01) == 1, repr(map01) - s = _translate(s, {b'0': b'O', b'1': map01}) + s = s.translate(bytes.maketrans(b'01', b'O' + map01)) if casefold: s = s.upper() # Strip off pad characters from the right. We need to count the pad @@ -242,7 +245,7 @@ def b32decode(s, casefold=False, map01=None): for c in s: val = _b32rev.get(c) if val is None: - raise TypeError('Non-base32 digit found') + raise binascii.Error('Non-base32 digit found') acc += _b32rev[c] << shift shift -= 5 if shift < 0: @@ -292,8 +295,7 @@ def b16decode(s, casefold=False): s were incorrectly padded or if there are non-alphabet characters present in the string. """ - if not isinstance(s, bytes_types): - raise TypeError("expected bytes, not %s" % s.__class__.__name__) + s = _bytes_from_decode_data(s) if casefold: s = s.upper() if re.search(b'[^0-9A-F]', s): |
