summaryrefslogtreecommitdiffstats
path: root/Lib/xmlrpc/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/xmlrpc/client.py')
-rw-r--r--Lib/xmlrpc/client.py52
1 files changed, 36 insertions, 16 deletions
diff --git a/Lib/xmlrpc/client.py b/Lib/xmlrpc/client.py
index ff931e7..59f7546 100644
--- a/Lib/xmlrpc/client.py
+++ b/Lib/xmlrpc/client.py
@@ -386,8 +386,8 @@ class Binary:
if data is None:
data = b""
else:
- if not isinstance(data, bytes):
- raise TypeError("expected bytes, not %s" %
+ if not isinstance(data, (bytes, bytearray)):
+ raise TypeError("expected bytes or bytearray, not %s" %
data.__class__.__name__)
data = bytes(data) # Make a copy of the bytes!
self.data = data
@@ -559,6 +559,14 @@ class Marshaller:
write("</string></value>\n")
dispatch[str] = dump_unicode
+ def dump_bytes(self, value, write):
+ write("<value><base64>\n")
+ encoded = base64.encodebytes(value)
+ write(encoded.decode('ascii'))
+ write("</base64></value>\n")
+ dispatch[bytes] = dump_bytes
+ dispatch[bytearray] = dump_bytes
+
def dump_array(self, value, write):
i = id(value)
if i in self.memo:
@@ -629,7 +637,7 @@ class Unmarshaller:
# and again, if you don't understand what's going on in here,
# that's perfectly ok.
- def __init__(self, use_datetime=False):
+ def __init__(self, use_datetime=False, use_builtin_types=False):
self._type = None
self._stack = []
self._marks = []
@@ -637,7 +645,8 @@ class Unmarshaller:
self._methodname = None
self._encoding = "utf-8"
self.append = self._stack.append
- self._use_datetime = use_datetime
+ self._use_datetime = use_builtin_types or use_datetime
+ self._use_bytes = use_builtin_types
def close(self):
# return response tuple and target method
@@ -749,6 +758,8 @@ class Unmarshaller:
def end_base64(self, data):
value = Binary()
value.decode(data.encode("ascii"))
+ if self._use_bytes:
+ value = value.data
self.append(value)
self._value = 0
dispatch["base64"] = end_base64
@@ -860,21 +871,26 @@ FastMarshaller = FastParser = FastUnmarshaller = None
#
# return A (parser, unmarshaller) tuple.
-def getparser(use_datetime=False):
+def getparser(use_datetime=False, use_builtin_types=False):
"""getparser() -> parser, unmarshaller
Create an instance of the fastest available parser, and attach it
to an unmarshalling object. Return both objects.
"""
if FastParser and FastUnmarshaller:
- if use_datetime:
+ if use_builtin_types:
+ mkdatetime = _datetime_type
+ mkbytes = base64.decodebytes
+ elif use_datetime:
mkdatetime = _datetime_type
+ mkbytes = _binary
else:
mkdatetime = _datetime
- target = FastUnmarshaller(True, False, _binary, mkdatetime, Fault)
+ mkbytes = _binary
+ target = FastUnmarshaller(True, False, mkbytes, mkdatetime, Fault)
parser = FastParser(target)
else:
- target = Unmarshaller(use_datetime=use_datetime)
+ target = Unmarshaller(use_datetime=use_datetime, use_builtin_types=use_builtin_types)
if FastParser:
parser = FastParser(target)
else:
@@ -912,7 +928,7 @@ def dumps(params, methodname=None, methodresponse=None, encoding=None,
encoding: the packet encoding (default is UTF-8)
- All 8-bit strings in the data structure are assumed to use the
+ All byte strings in the data structure are assumed to use the
packet encoding. Unicode strings are automatically converted,
where necessary.
"""
@@ -971,7 +987,7 @@ def dumps(params, methodname=None, methodresponse=None, encoding=None,
# (None if not present).
# @see Fault
-def loads(data, use_datetime=False):
+def loads(data, use_datetime=False, use_builtin_types=False):
"""data -> unmarshalled data, method name
Convert an XML-RPC packet to unmarshalled data plus a method
@@ -980,7 +996,7 @@ def loads(data, use_datetime=False):
If the XML-RPC packet represents a fault condition, this function
raises a Fault exception.
"""
- p, u = getparser(use_datetime=use_datetime)
+ p, u = getparser(use_datetime=use_datetime, use_builtin_types=use_builtin_types)
p.feed(data)
p.close()
return u.close(), u.getmethodname()
@@ -1092,8 +1108,9 @@ class Transport:
# that they can decode such a request
encode_threshold = None #None = don't encode
- def __init__(self, use_datetime=False):
+ def __init__(self, use_datetime=False, use_builtin_types=False):
self._use_datetime = use_datetime
+ self._use_builtin_types = use_builtin_types
self._connection = (None, None)
self._extra_headers = []
@@ -1154,7 +1171,8 @@ class Transport:
def getparser(self):
# get parser and unmarshaller
- return getparser(use_datetime=self._use_datetime)
+ return getparser(use_datetime=self._use_datetime,
+ use_builtin_types=self._use_builtin_types)
##
# Get authorization info from host parameter
@@ -1361,7 +1379,7 @@ class ServerProxy:
"""
def __init__(self, uri, transport=None, encoding=None, verbose=False,
- allow_none=False, use_datetime=False):
+ allow_none=False, use_datetime=False, use_builtin_types=False):
# establish a "logical" server connection
# get the url
@@ -1375,9 +1393,11 @@ class ServerProxy:
if transport is None:
if type == "https":
- transport = SafeTransport(use_datetime=use_datetime)
+ handler = SafeTransport
else:
- transport = Transport(use_datetime=use_datetime)
+ handler = Transport
+ transport = handler(use_datetime=use_datetime,
+ use_builtin_types=use_builtin_types)
self.__transport = transport
self.__encoding = encoding or 'utf-8'