summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/ftplib.py22
-rw-r--r--Lib/test/test_ftplib.py115
2 files changed, 105 insertions, 32 deletions
diff --git a/Lib/ftplib.py b/Lib/ftplib.py
index 0a69b7a..e1c3d99 100644
--- a/Lib/ftplib.py
+++ b/Lib/ftplib.py
@@ -641,9 +641,21 @@ else:
ssl_version = ssl.PROTOCOL_SSLv23
def __init__(self, host='', user='', passwd='', acct='', keyfile=None,
- certfile=None, timeout=_GLOBAL_DEFAULT_TIMEOUT):
+ certfile=None, context=None,
+ timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
+ if context is not None and keyfile is not None:
+ raise ValueError("context and keyfile arguments are mutually "
+ "exclusive")
+ if context is not None and certfile is not None:
+ raise ValueError("context and certfile arguments are mutually "
+ "exclusive")
self.keyfile = keyfile
self.certfile = certfile
+ if context is None:
+ context = ssl._create_stdlib_context(self.ssl_version,
+ certfile=certfile,
+ keyfile=keyfile)
+ self.context = context
self._prot_p = False
FTP.__init__(self, host, user, passwd, acct, timeout)
@@ -660,8 +672,8 @@ else:
resp = self.voidcmd('AUTH TLS')
else:
resp = self.voidcmd('AUTH SSL')
- self.sock = ssl.wrap_socket(self.sock, self.keyfile, self.certfile,
- ssl_version=self.ssl_version)
+ self.sock = self.context.wrap_socket(self.sock,
+ server_hostname=self.host)
self.file = self.sock.makefile(mode='rb')
return resp
@@ -692,8 +704,8 @@ else:
def ntransfercmd(self, cmd, rest=None):
conn, size = FTP.ntransfercmd(self, cmd, rest)
if self._prot_p:
- conn = ssl.wrap_socket(conn, self.keyfile, self.certfile,
- ssl_version=self.ssl_version)
+ conn = self.context.wrap_socket(conn,
+ server_hostname=self.host)
return conn, size
def retrbinary(self, cmd, callback, blocksize=8192, rest=None):
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py
index 4c229c0..cc1c19b 100644
--- a/Lib/test/test_ftplib.py
+++ b/Lib/test/test_ftplib.py
@@ -20,7 +20,7 @@ from test import test_support
from test.test_support import HOST, HOSTv6
threading = test_support.import_module('threading')
-
+TIMEOUT = 3
# the dummy data returned by server over the data channel when
# RETR, LIST and NLST commands are issued
RETR_DATA = 'abcde12345\r\n' * 1000
@@ -223,6 +223,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
self.active = False
self.active_lock = threading.Lock()
self.host, self.port = self.socket.getsockname()[:2]
+ self.handler_instance = None
def start(self):
assert not self.active
@@ -246,8 +247,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
def handle_accept(self):
conn, addr = self.accept()
- self.handler = self.handler(conn)
- self.close()
+ self.handler_instance = self.handler(conn)
def handle_connect(self):
self.close()
@@ -262,7 +262,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread):
if ssl is not None:
- CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem")
+ CERTFILE = os.path.join(os.path.dirname(__file__), "keycert3.pem")
+ CAFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem")
class SSLConnection(object, asyncore.dispatcher):
"""An asyncore.dispatcher subclass supporting TLS/SSL."""
@@ -271,23 +272,25 @@ if ssl is not None:
_ssl_closing = False
def secure_connection(self):
- self.socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
- certfile=CERTFILE, server_side=True,
- do_handshake_on_connect=False,
- ssl_version=ssl.PROTOCOL_SSLv23)
+ socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False,
+ certfile=CERTFILE, server_side=True,
+ do_handshake_on_connect=False,
+ ssl_version=ssl.PROTOCOL_SSLv23)
+ self.del_channel()
+ self.set_socket(socket)
self._ssl_accepting = True
def _do_ssl_handshake(self):
try:
self.socket.do_handshake()
- except ssl.SSLError, err:
+ except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
elif err.args[0] == ssl.SSL_ERROR_EOF:
return self.handle_close()
raise
- except socket.error, err:
+ except socket.error as err:
if err.args[0] == errno.ECONNABORTED:
return self.handle_close()
else:
@@ -297,18 +300,21 @@ if ssl is not None:
self._ssl_closing = True
try:
self.socket = self.socket.unwrap()
- except ssl.SSLError, err:
+ except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
return
- except socket.error, err:
+ except socket.error as err:
# Any "socket error" corresponds to a SSL_ERROR_SYSCALL return
# from OpenSSL's SSL_shutdown(), corresponding to a
# closed socket condition. See also:
# http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html
pass
self._ssl_closing = False
- super(SSLConnection, self).close()
+ if getattr(self, '_ccc', False) is False:
+ super(SSLConnection, self).close()
+ else:
+ pass
def handle_read_event(self):
if self._ssl_accepting:
@@ -329,7 +335,7 @@ if ssl is not None:
def send(self, data):
try:
return super(SSLConnection, self).send(data)
- except ssl.SSLError, err:
+ except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN,
ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
@@ -339,13 +345,13 @@ if ssl is not None:
def recv(self, buffer_size):
try:
return super(SSLConnection, self).recv(buffer_size)
- except ssl.SSLError, err:
+ except ssl.SSLError as err:
if err.args[0] in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE):
- return ''
+ return b''
if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN):
self.handle_close()
- return ''
+ return b''
raise
def handle_error(self):
@@ -355,6 +361,8 @@ if ssl is not None:
if (isinstance(self.socket, ssl.SSLSocket) and
self.socket._sslobj is not None):
self._do_ssl_shutdown()
+ else:
+ super(SSLConnection, self).close()
class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler):
@@ -462,12 +470,12 @@ class TestFTPClass(TestCase):
def test_rename(self):
self.client.rename('a', 'b')
- self.server.handler.next_response = '200'
+ self.server.handler_instance.next_response = '200'
self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b')
def test_delete(self):
self.client.delete('foo')
- self.server.handler.next_response = '199'
+ self.server.handler_instance.next_response = '199'
self.assertRaises(ftplib.error_reply, self.client.delete, 'foo')
def test_size(self):
@@ -515,7 +523,7 @@ class TestFTPClass(TestCase):
def test_storbinary(self):
f = StringIO.StringIO(RETR_DATA)
self.client.storbinary('stor', f)
- self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
+ self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
@@ -527,12 +535,12 @@ class TestFTPClass(TestCase):
for r in (30, '30'):
f.seek(0)
self.client.storbinary('stor', f, rest=r)
- self.assertEqual(self.server.handler.rest, str(r))
+ self.assertEqual(self.server.handler_instance.rest, str(r))
def test_storlines(self):
f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n'))
self.client.storlines('stor', f)
- self.assertEqual(self.server.handler.last_received_data, RETR_DATA)
+ self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA)
# test new callback arg
flag = []
f.seek(0)
@@ -551,14 +559,14 @@ class TestFTPClass(TestCase):
def test_makeport(self):
self.client.makeport()
# IPv4 is in use, just make sure send_eprt has not been used
- self.assertEqual(self.server.handler.last_received_cmd, 'port')
+ self.assertEqual(self.server.handler_instance.last_received_cmd, 'port')
def test_makepasv(self):
host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10)
conn.close()
# IPv4 is in use, just make sure send_epsv has not been used
- self.assertEqual(self.server.handler.last_received_cmd, 'pasv')
+ self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv')
def test_line_too_long(self):
self.assertRaises(ftplib.Error, self.client.sendcmd,
@@ -600,13 +608,13 @@ class TestIPv6Environment(TestCase):
def test_makeport(self):
self.client.makeport()
- self.assertEqual(self.server.handler.last_received_cmd, 'eprt')
+ self.assertEqual(self.server.handler_instance.last_received_cmd, 'eprt')
def test_makepasv(self):
host, port = self.client.makepasv()
conn = socket.create_connection((host, port), 10)
conn.close()
- self.assertEqual(self.server.handler.last_received_cmd, 'epsv')
+ self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv')
def test_transfer(self):
def retr():
@@ -642,7 +650,7 @@ class TestTLS_FTPClass(TestCase):
def setUp(self):
self.server = DummyTLS_FTPServer((HOST, 0))
self.server.start()
- self.client = ftplib.FTP_TLS(timeout=10)
+ self.client = ftplib.FTP_TLS(timeout=TIMEOUT)
self.client.connect(self.server.host, self.server.port)
def tearDown(self):
@@ -695,6 +703,59 @@ class TestTLS_FTPClass(TestCase):
finally:
self.client.ssl_version = ssl.PROTOCOL_TLSv1
+ def test_context(self):
+ self.client.quit()
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE,
+ context=ctx)
+ self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
+ context=ctx)
+ self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE,
+ keyfile=CERTFILE, context=ctx)
+
+ self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
+ self.client.connect(self.server.host, self.server.port)
+ self.assertNotIsInstance(self.client.sock, ssl.SSLSocket)
+ self.client.auth()
+ self.assertIs(self.client.sock.context, ctx)
+ self.assertIsInstance(self.client.sock, ssl.SSLSocket)
+
+ self.client.prot_p()
+ sock = self.client.transfercmd('list')
+ try:
+ self.assertIs(sock.context, ctx)
+ self.assertIsInstance(sock, ssl.SSLSocket)
+ finally:
+ sock.close()
+
+ def test_check_hostname(self):
+ self.client.quit()
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ ctx.verify_mode = ssl.CERT_REQUIRED
+ ctx.check_hostname = True
+ ctx.load_verify_locations(CAFILE)
+ self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT)
+
+ # 127.0.0.1 doesn't match SAN
+ self.client.connect(self.server.host, self.server.port)
+ with self.assertRaises(ssl.CertificateError):
+ self.client.auth()
+ # exception quits connection
+
+ self.client.connect(self.server.host, self.server.port)
+ self.client.prot_p()
+ with self.assertRaises(ssl.CertificateError):
+ self.client.transfercmd("list").close()
+ self.client.quit()
+
+ self.client.connect("localhost", self.server.port)
+ self.client.auth()
+ self.client.quit()
+
+ self.client.connect("localhost", self.server.port)
+ self.client.prot_p()
+ self.client.transfercmd("list").close()
+
class TestTimeouts(TestCase):