summaryrefslogtreecommitdiffstats
path: root/Lib/urllib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/urllib')
-rw-r--r--Lib/urllib/parse.py290
-rw-r--r--Lib/urllib/request.py87
2 files changed, 299 insertions, 78 deletions
diff --git a/Lib/urllib/parse.py b/Lib/urllib/parse.py
index b3494fa..9a3e42e 100644
--- a/Lib/urllib/parse.py
+++ b/Lib/urllib/parse.py
@@ -5,6 +5,9 @@ urlparse module is based upon the following RFC specifications.
RFC 3986 (STD66): "Uniform Resource Identifiers" by T. Berners-Lee, R. Fielding
and L. Masinter, January 2005.
+RFC 2732 : "Format for Literal IPv6 Addresses in URL's" by R.Hinden, B.Carpenter
+and L.Masinter, December 1999.
+
RFC 2396: "Uniform Resource Identifiers (URI)": Generic Syntax by T.
Berners-Lee, R. Fielding, and L. Masinter, August 1998.
@@ -28,8 +31,8 @@ import sys
import collections
__all__ = ["urlparse", "urlunparse", "urljoin", "urldefrag",
- "urlsplit", "urlunsplit", "parse_qs", "parse_qsl",
- "quote", "quote_plus", "quote_from_bytes",
+ "urlsplit", "urlunsplit", "urlencode", "parse_qs",
+ "parse_qsl", "quote", "quote_plus", "quote_from_bytes",
"unquote", "unquote_plus", "unquote_to_bytes"]
# A classification of schemes ('' means apply by default)
@@ -57,6 +60,7 @@ scheme_chars = ('abcdefghijklmnopqrstuvwxyz'
'0123456789'
'+-.')
+# XXX: Consider replacing with functools.lru_cache
MAX_CACHE_SIZE = 20
_parse_cache = {}
@@ -66,64 +70,210 @@ def clear_cache():
_safe_quoters.clear()
-class ResultMixin(object):
- """Shared methods for the parsed result objects."""
+# Helpers for bytes handling
+# For 3.2, we deliberately require applications that
+# handle improperly quoted URLs to do their own
+# decoding and encoding. If valid use cases are
+# presented, we may relax this by using latin-1
+# decoding internally for 3.3
+_implicit_encoding = 'ascii'
+_implicit_errors = 'strict'
+
+def _noop(obj):
+ return obj
+
+def _encode_result(obj, encoding=_implicit_encoding,
+ errors=_implicit_errors):
+ return obj.encode(encoding, errors)
+
+def _decode_args(args, encoding=_implicit_encoding,
+ errors=_implicit_errors):
+ return tuple(x.decode(encoding, errors) if x else '' for x in args)
+
+def _coerce_args(*args):
+ # Invokes decode if necessary to create str args
+ # and returns the coerced inputs along with
+ # an appropriate result coercion function
+ # - noop for str inputs
+ # - encoding function otherwise
+ str_input = isinstance(args[0], str)
+ for arg in args[1:]:
+ # We special-case the empty string to support the
+ # "scheme=''" default argument to some functions
+ if arg and isinstance(arg, str) != str_input:
+ raise TypeError("Cannot mix str and non-str arguments")
+ if str_input:
+ return args + (_noop,)
+ return _decode_args(args) + (_encode_result,)
+
+# Result objects are more helpful than simple tuples
+class _ResultMixinStr(object):
+ """Standard approach to encoding parsed results from str to bytes"""
+ __slots__ = ()
+
+ def encode(self, encoding='ascii', errors='strict'):
+ return self._encoded_counterpart(*(x.encode(encoding, errors) for x in self))
+
+
+class _ResultMixinBytes(object):
+ """Standard approach to decoding parsed results from bytes to str"""
+ __slots__ = ()
+
+ def decode(self, encoding='ascii', errors='strict'):
+ return self._decoded_counterpart(*(x.decode(encoding, errors) for x in self))
+
+
+class _NetlocResultMixinBase(object):
+ """Shared methods for the parsed result objects containing a netloc element"""
+ __slots__ = ()
@property
def username(self):
- netloc = self.netloc
- if "@" in netloc:
- userinfo = netloc.rsplit("@", 1)[0]
- if ":" in userinfo:
- userinfo = userinfo.split(":", 1)[0]
- return userinfo
- return None
+ return self._userinfo[0]
@property
def password(self):
- netloc = self.netloc
- if "@" in netloc:
- userinfo = netloc.rsplit("@", 1)[0]
- if ":" in userinfo:
- return userinfo.split(":", 1)[1]
- return None
+ return self._userinfo[1]
@property
def hostname(self):
- netloc = self.netloc
- if "@" in netloc:
- netloc = netloc.rsplit("@", 1)[1]
- if ":" in netloc:
- netloc = netloc.split(":", 1)[0]
- return netloc.lower() or None
+ hostname = self._hostinfo[0]
+ if not hostname:
+ hostname = None
+ elif hostname is not None:
+ hostname = hostname.lower()
+ return hostname
@property
def port(self):
+ port = self._hostinfo[1]
+ if port is not None:
+ port = int(port, 10)
+ return port
+
+
+class _NetlocResultMixinStr(_NetlocResultMixinBase, _ResultMixinStr):
+ __slots__ = ()
+
+ @property
+ def _userinfo(self):
+ netloc = self.netloc
+ userinfo, have_info, hostinfo = netloc.rpartition('@')
+ if have_info:
+ username, have_password, password = userinfo.partition(':')
+ if not have_password:
+ password = None
+ else:
+ username = password = None
+ return username, password
+
+ @property
+ def _hostinfo(self):
+ netloc = self.netloc
+ _, _, hostinfo = netloc.rpartition('@')
+ _, have_open_br, bracketed = hostinfo.partition('[')
+ if have_open_br:
+ hostname, _, port = bracketed.partition(']')
+ _, have_port, port = port.partition(':')
+ else:
+ hostname, have_port, port = hostinfo.partition(':')
+ if not have_port:
+ port = None
+ return hostname, port
+
+
+class _NetlocResultMixinBytes(_NetlocResultMixinBase, _ResultMixinBytes):
+ __slots__ = ()
+
+ @property
+ def _userinfo(self):
netloc = self.netloc
- if "@" in netloc:
- netloc = netloc.rsplit("@", 1)[1]
- if ":" in netloc:
- port = netloc.split(":", 1)[1]
- return int(port, 10)
- return None
+ userinfo, have_info, hostinfo = netloc.rpartition(b'@')
+ if have_info:
+ username, have_password, password = userinfo.partition(b':')
+ if not have_password:
+ password = None
+ else:
+ username = password = None
+ return username, password
+
+ @property
+ def _hostinfo(self):
+ netloc = self.netloc
+ _, _, hostinfo = netloc.rpartition(b'@')
+ _, have_open_br, bracketed = hostinfo.partition(b'[')
+ if have_open_br:
+ hostname, _, port = bracketed.partition(b']')
+ _, have_port, port = port.partition(b':')
+ else:
+ hostname, have_port, port = hostinfo.partition(b':')
+ if not have_port:
+ port = None
+ return hostname, port
+
from collections import namedtuple
-class SplitResult(namedtuple('SplitResult', 'scheme netloc path query fragment'), ResultMixin):
+_DefragResultBase = namedtuple('DefragResult', 'url fragment')
+_SplitResultBase = namedtuple('SplitResult', 'scheme netloc path query fragment')
+_ParseResultBase = namedtuple('ParseResult', 'scheme netloc path params query fragment')
+
+# For backwards compatibility, alias _NetlocResultMixinStr
+# ResultBase is no longer part of the documented API, but it is
+# retained since deprecating it isn't worth the hassle
+ResultBase = _NetlocResultMixinStr
+# Structured result objects for string data
+class DefragResult(_DefragResultBase, _ResultMixinStr):
__slots__ = ()
+ def geturl(self):
+ if self.fragment:
+ return self.url + '#' + self.fragment
+ else:
+ return self.url
+class SplitResult(_SplitResultBase, _NetlocResultMixinStr):
+ __slots__ = ()
def geturl(self):
return urlunsplit(self)
+class ParseResult(_ParseResultBase, _NetlocResultMixinStr):
+ __slots__ = ()
+ def geturl(self):
+ return urlunparse(self)
-class ParseResult(namedtuple('ParseResult', 'scheme netloc path params query fragment'), ResultMixin):
+# Structured result objects for bytes data
+class DefragResultBytes(_DefragResultBase, _ResultMixinBytes):
+ __slots__ = ()
+ def geturl(self):
+ if self.fragment:
+ return self.url + b'#' + self.fragment
+ else:
+ return self.url
+class SplitResultBytes(_SplitResultBase, _NetlocResultMixinBytes):
__slots__ = ()
+ def geturl(self):
+ return urlunsplit(self)
+class ParseResultBytes(_ParseResultBase, _NetlocResultMixinBytes):
+ __slots__ = ()
def geturl(self):
return urlunparse(self)
+# Set up the encode/decode result pairs
+def _fix_result_transcoding():
+ _result_pairs = (
+ (DefragResult, DefragResultBytes),
+ (SplitResult, SplitResultBytes),
+ (ParseResult, ParseResultBytes),
+ )
+ for _decoded, _encoded in _result_pairs:
+ _decoded._encoded_counterpart = _encoded
+ _encoded._decoded_counterpart = _decoded
+
+_fix_result_transcoding()
+del _fix_result_transcoding
def urlparse(url, scheme='', allow_fragments=True):
"""Parse a URL into 6 components:
@@ -131,13 +281,15 @@ def urlparse(url, scheme='', allow_fragments=True):
Return a 6-tuple: (scheme, netloc, path, params, query, fragment).
Note that we don't break the components up in smaller bits
(e.g. netloc is a single string) and we don't expand % escapes."""
+ url, scheme, _coerce_result = _coerce_args(url, scheme)
tuple = urlsplit(url, scheme, allow_fragments)
scheme, netloc, url, query, fragment = tuple
if scheme in uses_params and ';' in url:
url, params = _splitparams(url)
else:
params = ''
- return ParseResult(scheme, netloc, url, params, query, fragment)
+ result = ParseResult(scheme, netloc, url, params, query, fragment)
+ return _coerce_result(result)
def _splitparams(url):
if '/' in url:
@@ -162,11 +314,12 @@ def urlsplit(url, scheme='', allow_fragments=True):
Return a 5-tuple: (scheme, netloc, path, query, fragment).
Note that we don't break the components up in smaller bits
(e.g. netloc is a single string) and we don't expand % escapes."""
+ url, scheme, _coerce_result = _coerce_args(url, scheme)
allow_fragments = bool(allow_fragments)
key = url, scheme, allow_fragments, type(url), type(scheme)
cached = _parse_cache.get(key, None)
if cached:
- return cached
+ return _coerce_result(cached)
if len(_parse_cache) >= MAX_CACHE_SIZE: # avoid runaway growth
clear_cache()
netloc = query = fragment = ''
@@ -177,13 +330,16 @@ def urlsplit(url, scheme='', allow_fragments=True):
url = url[i+1:]
if url[:2] == '//':
netloc, url = _splitnetloc(url, 2)
+ if (('[' in netloc and ']' not in netloc) or
+ (']' in netloc and '[' not in netloc)):
+ raise ValueError("Invalid IPv6 URL")
if allow_fragments and '#' in url:
url, fragment = url.split('#', 1)
if '?' in url:
url, query = url.split('?', 1)
v = SplitResult(scheme, netloc, url, query, fragment)
_parse_cache[key] = v
- return v
+ return _coerce_result(v)
if url.endswith(':') or not url[i+1].isdigit():
for c in url[:i]:
if c not in scheme_chars:
@@ -192,23 +348,27 @@ def urlsplit(url, scheme='', allow_fragments=True):
scheme, url = url[:i].lower(), url[i+1:]
if url[:2] == '//':
netloc, url = _splitnetloc(url, 2)
+ if (('[' in netloc and ']' not in netloc) or
+ (']' in netloc and '[' not in netloc)):
+ raise ValueError("Invalid IPv6 URL")
if allow_fragments and scheme in uses_fragment and '#' in url:
url, fragment = url.split('#', 1)
if scheme in uses_query and '?' in url:
url, query = url.split('?', 1)
v = SplitResult(scheme, netloc, url, query, fragment)
_parse_cache[key] = v
- return v
+ return _coerce_result(v)
def urlunparse(components):
"""Put a parsed URL back together again. This may result in a
slightly different, but equivalent URL, if the URL that was parsed
originally had redundant delimiters, e.g. a ? with an empty query
(the draft states that these are equivalent)."""
- scheme, netloc, url, params, query, fragment = components
+ scheme, netloc, url, params, query, fragment, _coerce_result = (
+ _coerce_args(*components))
if params:
url = "%s;%s" % (url, params)
- return urlunsplit((scheme, netloc, url, query, fragment))
+ return _coerce_result(urlunsplit((scheme, netloc, url, query, fragment)))
def urlunsplit(components):
"""Combine the elements of a tuple as returned by urlsplit() into a
@@ -216,7 +376,8 @@ def urlunsplit(components):
This may result in a slightly different, but equivalent URL, if the URL that
was parsed originally had unnecessary delimiters (for example, a ? with an
empty query; the RFC states that these are equivalent)."""
- scheme, netloc, url, query, fragment = components
+ scheme, netloc, url, query, fragment, _coerce_result = (
+ _coerce_args(*components))
if netloc or (scheme and scheme in uses_netloc and url[:2] != '//'):
if url and url[:1] != '/': url = '/' + url
url = '//' + (netloc or '') + url
@@ -226,7 +387,7 @@ def urlunsplit(components):
url = url + '?' + query
if fragment:
url = url + '#' + fragment
- return url
+ return _coerce_result(url)
def urljoin(base, url, allow_fragments=True):
"""Join a base URL and a possibly relative URL to form an absolute
@@ -235,27 +396,28 @@ def urljoin(base, url, allow_fragments=True):
return url
if not url:
return base
+ base, url, _coerce_result = _coerce_args(base, url)
bscheme, bnetloc, bpath, bparams, bquery, bfragment = \
urlparse(base, '', allow_fragments)
scheme, netloc, path, params, query, fragment = \
urlparse(url, bscheme, allow_fragments)
if scheme != bscheme or scheme not in uses_relative:
- return url
+ return _coerce_result(url)
if scheme in uses_netloc:
if netloc:
- return urlunparse((scheme, netloc, path,
- params, query, fragment))
+ return _coerce_result(urlunparse((scheme, netloc, path,
+ params, query, fragment)))
netloc = bnetloc
if path[:1] == '/':
- return urlunparse((scheme, netloc, path,
- params, query, fragment))
+ return _coerce_result(urlunparse((scheme, netloc, path,
+ params, query, fragment)))
if not path and not params:
path = bpath
params = bparams
if not query:
query = bquery
- return urlunparse((scheme, netloc, path,
- params, query, fragment))
+ return _coerce_result(urlunparse((scheme, netloc, path,
+ params, query, fragment)))
segments = bpath.split('/')[:-1] + path.split('/')
# XXX The stuff below is bogus in various ways...
if segments[-1] == '.':
@@ -277,8 +439,8 @@ def urljoin(base, url, allow_fragments=True):
segments[-1] = ''
elif len(segments) >= 2 and segments[-1] == '..':
segments[-2:] = ['']
- return urlunparse((scheme, netloc, '/'.join(segments),
- params, query, fragment))
+ return _coerce_result(urlunparse((scheme, netloc, '/'.join(segments),
+ params, query, fragment)))
def urldefrag(url):
"""Removes any existing fragment from URL.
@@ -287,12 +449,14 @@ def urldefrag(url):
the URL contained no fragments, the second element is the
empty string.
"""
+ url, _coerce_result = _coerce_args(url)
if '#' in url:
s, n, p, a, q, frag = urlparse(url)
defrag = urlunparse((s, n, p, a, q, ''))
- return defrag, frag
else:
- return url, ''
+ frag = ''
+ defrag = url
+ return _coerce_result(DefragResult(defrag, frag))
def unquote_to_bytes(string):
"""unquote_to_bytes('abc%20def') -> b'abc def'."""
@@ -359,7 +523,8 @@ def unquote(string, encoding='utf-8', errors='replace'):
string += pct_sequence.decode(encoding, errors)
return string
-def parse_qs(qs, keep_blank_values=False, strict_parsing=False):
+def parse_qs(qs, keep_blank_values=False, strict_parsing=False,
+ encoding='utf-8', errors='replace'):
"""Parse a query given as a string argument.
Arguments:
@@ -376,16 +541,22 @@ def parse_qs(qs, keep_blank_values=False, strict_parsing=False):
strict_parsing: flag indicating what to do with parsing errors.
If false (the default), errors are silently ignored.
If true, errors raise a ValueError exception.
+
+ encoding and errors: specify how to decode percent-encoded sequences
+ into Unicode characters, as accepted by the bytes.decode() method.
"""
dict = {}
- for name, value in parse_qsl(qs, keep_blank_values, strict_parsing):
+ pairs = parse_qsl(qs, keep_blank_values, strict_parsing,
+ encoding=encoding, errors=errors)
+ for name, value in pairs:
if name in dict:
dict[name].append(value)
else:
dict[name] = [value]
return dict
-def parse_qsl(qs, keep_blank_values=False, strict_parsing=False):
+def parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
+ encoding='utf-8', errors='replace'):
"""Parse a query given as a string argument.
Arguments:
@@ -402,8 +573,12 @@ def parse_qsl(qs, keep_blank_values=False, strict_parsing=False):
false (the default), errors are silently ignored. If true,
errors raise a ValueError exception.
+ encoding and errors: specify how to decode percent-encoded sequences
+ into Unicode characters, as accepted by the bytes.decode() method.
+
Returns a list, as G-d intended.
"""
+ qs, _coerce_result = _coerce_args(qs)
pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
r = []
for name_value in pairs:
@@ -419,10 +594,13 @@ def parse_qsl(qs, keep_blank_values=False, strict_parsing=False):
else:
continue
if len(nv[1]) or keep_blank_values:
- name = unquote(nv[0].replace('+', ' '))
- value = unquote(nv[1].replace('+', ' '))
+ name = nv[0].replace('+', ' ')
+ name = unquote(name, encoding=encoding, errors=errors)
+ name = _coerce_result(name)
+ value = nv[1].replace('+', ' ')
+ value = unquote(value, encoding=encoding, errors=errors)
+ value = _coerce_result(value)
r.append((name, value))
-
return r
def unquote_plus(string, encoding='utf-8', errors='replace'):
diff --git a/Lib/urllib/request.py b/Lib/urllib/request.py
index 220dfe4..ebbebe9 100644
--- a/Lib/urllib/request.py
+++ b/Lib/urllib/request.py
@@ -94,6 +94,7 @@ import re
import socket
import sys
import time
+import collections
from urllib.error import URLError, HTTPError, ContentTooShortError
from urllib.parse import (
@@ -114,11 +115,27 @@ else:
__version__ = sys.version[:3]
_opener = None
-def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT):
+def urlopen(url, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
+ *, cafile=None, capath=None):
global _opener
- if _opener is None:
- _opener = build_opener()
- return _opener.open(url, data, timeout)
+ if cafile or capath:
+ if not _have_ssl:
+ raise ValueError('SSL support not available')
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.options |= ssl.OP_NO_SSLv2
+ if cafile or capath:
+ context.verify_mode = ssl.CERT_REQUIRED
+ context.load_verify_locations(cafile, capath)
+ check_hostname = True
+ else:
+ check_hostname = False
+ https_handler = HTTPSHandler(context=context, check_hostname=check_hostname)
+ opener = build_opener(https_handler)
+ elif _opener is None:
+ _opener = opener = build_opener()
+ else:
+ opener = _opener
+ return opener.open(url, data, timeout)
def install_opener(opener):
global _opener
@@ -1042,13 +1059,24 @@ class AbstractHTTPHandler(BaseHandler):
if request.data is not None: # POST
data = request.data
+ if isinstance(data, str):
+ raise TypeError("POST data should be bytes"
+ " or an iterable of bytes. It cannot be str.")
if not request.has_header('Content-type'):
request.add_unredirected_header(
'Content-type',
'application/x-www-form-urlencoded')
if not request.has_header('Content-length'):
- request.add_unredirected_header(
- 'Content-length', '%d' % len(data))
+ try:
+ mv = memoryview(data)
+ except TypeError:
+ if isinstance(data, collections.Iterable):
+ raise ValueError("Content-Length should be specified "
+ "for iterable data of type %r %r" % (type(data),
+ data))
+ else:
+ request.add_unredirected_header(
+ 'Content-length', '%d' % (len(mv) * mv.itemsize))
sel_host = host
if request.has_proxy():
@@ -1063,7 +1091,7 @@ class AbstractHTTPHandler(BaseHandler):
return request
- def do_open(self, http_class, req):
+ def do_open(self, http_class, req, **http_conn_args):
"""Return an HTTPResponse object for the request, using http_class.
http_class must implement the HTTPConnection API from http.client.
@@ -1072,7 +1100,8 @@ class AbstractHTTPHandler(BaseHandler):
if not host:
raise URLError('no host given')
- h = http_class(host, timeout=req.timeout) # will parse host:port
+ # will parse host:port
+ h = http_class(host, timeout=req.timeout, **http_conn_args)
headers = dict(req.unredirected_hdrs)
headers.update(dict((k, v) for k, v in req.headers.items()
@@ -1098,7 +1127,7 @@ class AbstractHTTPHandler(BaseHandler):
# Proxy-Authorization should not be sent to origin
# server.
del headers[proxy_auth_hdr]
- h._set_tunnel(req._tunnel_host, headers=tunnel_headers)
+ h.set_tunnel(req._tunnel_host, headers=tunnel_headers)
try:
h.request(req.get_method(), req.selector, req.data, headers)
@@ -1124,10 +1153,18 @@ class HTTPHandler(AbstractHTTPHandler):
http_request = AbstractHTTPHandler.do_request_
if hasattr(http.client, 'HTTPSConnection'):
+ import ssl
+
class HTTPSHandler(AbstractHTTPHandler):
+ def __init__(self, debuglevel=0, context=None, check_hostname=None):
+ AbstractHTTPHandler.__init__(self, debuglevel)
+ self._context = context
+ self._check_hostname = check_hostname
+
def https_open(self, req):
- return self.do_open(http.client.HTTPSConnection, req)
+ return self.do_open(http.client.HTTPSConnection, req,
+ context=self._context, check_hostname=self._check_hostname)
https_request = AbstractHTTPHandler.do_request_
@@ -1213,8 +1250,8 @@ class FileHandler(BaseHandler):
url = req.selector
if url[:2] == '//' and url[2:3] != '/' and (req.host and
req.host != 'localhost'):
- req.type = 'ftp'
- return self.parent.open(req)
+ if not req.host is self.get_names():
+ raise URLError("file:// scheme is supported only on localhost")
else:
return self.open_local_file(req)
@@ -1375,9 +1412,7 @@ class CacheFTPHandler(FTPHandler):
MAXFTPCACHE = 10 # Trim the ftp cache beyond this size
# Helper for non-unix systems
-if os.name == 'mac':
- from macurl2path import url2pathname, pathname2url
-elif os.name == 'nt':
+if os.name == 'nt':
from nturl2path import url2pathname, pathname2url
else:
def url2pathname(pathname):
@@ -1516,7 +1551,7 @@ class URLopener:
try:
fp = self.open_local_file(url1)
hdrs = fp.info()
- del fp
+ fp.close()
return url2pathname(splithost(url1)[1]), hdrs
except IOError as msg:
pass
@@ -1560,8 +1595,6 @@ class URLopener:
tfp.close()
finally:
fp.close()
- del fp
- del tfp
# raise exception if actual size does not match content-length header
if size >= 0 and read < size:
@@ -1635,6 +1668,12 @@ class URLopener:
headers["Authorization"] = "Basic %s" % auth
if realhost:
headers["Host"] = realhost
+
+ # Add Connection:close as we don't support persistent connections yet.
+ # This helps in closing the socket and avoiding ResourceWarning
+
+ headers["Connection"] = "close"
+
for header, value in self.addheaders:
headers[header] = value
@@ -1701,7 +1740,7 @@ class URLopener:
if not isinstance(url, str):
raise URLError('file error', 'proxy support for file protocol currently not implemented')
if url[:2] == '//' and url[2:3] != '/' and url[2:12].lower() != 'localhost/':
- return self.open_ftp(url)
+ raise ValueError("file:// scheme is supported only on localhost")
else:
return self.open_local_file(url)
@@ -2124,7 +2163,7 @@ class ftpwrapper:
# Try to retrieve as a file
try:
cmd = 'RETR ' + file
- conn = self.ftp.ntransfercmd(cmd)
+ conn, retrlen = self.ftp.ntransfercmd(cmd)
except ftplib.error_perm as reason:
if str(reason)[:3] != '550':
raise URLError('ftp error', reason).with_traceback(
@@ -2145,10 +2184,14 @@ class ftpwrapper:
cmd = 'LIST ' + file
else:
cmd = 'LIST'
- conn = self.ftp.ntransfercmd(cmd)
+ conn, retrlen = self.ftp.ntransfercmd(cmd)
self.busy = 1
+
+ ftpobj = addclosehook(conn.makefile('rb'), self.endtransfer)
+ conn.close()
# Pass back both a suitably decorated object and a retrieval length
- return (addclosehook(conn[0].makefile('rb'), self.endtransfer), conn[1])
+ return (ftpobj, retrlen)
+
def endtransfer(self):
if not self.busy:
return