diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ftplib.py | 22 | ||||
-rw-r--r-- | Lib/test/test_ftplib.py | 115 |
2 files changed, 105 insertions, 32 deletions
diff --git a/Lib/ftplib.py b/Lib/ftplib.py index 0a69b7a..e1c3d99 100644 --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -641,9 +641,21 @@ else: ssl_version = ssl.PROTOCOL_SSLv23 def __init__(self, host='', user='', passwd='', acct='', keyfile=None, - certfile=None, timeout=_GLOBAL_DEFAULT_TIMEOUT): + certfile=None, context=None, + timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None): + if context is not None and keyfile is not None: + raise ValueError("context and keyfile arguments are mutually " + "exclusive") + if context is not None and certfile is not None: + raise ValueError("context and certfile arguments are mutually " + "exclusive") self.keyfile = keyfile self.certfile = certfile + if context is None: + context = ssl._create_stdlib_context(self.ssl_version, + certfile=certfile, + keyfile=keyfile) + self.context = context self._prot_p = False FTP.__init__(self, host, user, passwd, acct, timeout) @@ -660,8 +672,8 @@ else: resp = self.voidcmd('AUTH TLS') else: resp = self.voidcmd('AUTH SSL') - self.sock = ssl.wrap_socket(self.sock, self.keyfile, self.certfile, - ssl_version=self.ssl_version) + self.sock = self.context.wrap_socket(self.sock, + server_hostname=self.host) self.file = self.sock.makefile(mode='rb') return resp @@ -692,8 +704,8 @@ else: def ntransfercmd(self, cmd, rest=None): conn, size = FTP.ntransfercmd(self, cmd, rest) if self._prot_p: - conn = ssl.wrap_socket(conn, self.keyfile, self.certfile, - ssl_version=self.ssl_version) + conn = self.context.wrap_socket(conn, + server_hostname=self.host) return conn, size def retrbinary(self, cmd, callback, blocksize=8192, rest=None): diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index 4c229c0..cc1c19b 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -20,7 +20,7 @@ from test import test_support from test.test_support import HOST, HOSTv6 threading = test_support.import_module('threading') - +TIMEOUT = 3 # the dummy data returned by server over the data channel when # RETR, LIST and NLST commands are issued RETR_DATA = 'abcde12345\r\n' * 1000 @@ -223,6 +223,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): self.active = False self.active_lock = threading.Lock() self.host, self.port = self.socket.getsockname()[:2] + self.handler_instance = None def start(self): assert not self.active @@ -246,8 +247,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): def handle_accept(self): conn, addr = self.accept() - self.handler = self.handler(conn) - self.close() + self.handler_instance = self.handler(conn) def handle_connect(self): self.close() @@ -262,7 +262,8 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): if ssl is not None: - CERTFILE = os.path.join(os.path.dirname(__file__), "keycert.pem") + CERTFILE = os.path.join(os.path.dirname(__file__), "keycert3.pem") + CAFILE = os.path.join(os.path.dirname(__file__), "pycacert.pem") class SSLConnection(object, asyncore.dispatcher): """An asyncore.dispatcher subclass supporting TLS/SSL.""" @@ -271,23 +272,25 @@ if ssl is not None: _ssl_closing = False def secure_connection(self): - self.socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False, - certfile=CERTFILE, server_side=True, - do_handshake_on_connect=False, - ssl_version=ssl.PROTOCOL_SSLv23) + socket = ssl.wrap_socket(self.socket, suppress_ragged_eofs=False, + certfile=CERTFILE, server_side=True, + do_handshake_on_connect=False, + ssl_version=ssl.PROTOCOL_SSLv23) + self.del_channel() + self.set_socket(socket) self._ssl_accepting = True def _do_ssl_handshake(self): try: self.socket.do_handshake() - except ssl.SSLError, err: + except ssl.SSLError as err: if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): return elif err.args[0] == ssl.SSL_ERROR_EOF: return self.handle_close() raise - except socket.error, err: + except socket.error as err: if err.args[0] == errno.ECONNABORTED: return self.handle_close() else: @@ -297,18 +300,21 @@ if ssl is not None: self._ssl_closing = True try: self.socket = self.socket.unwrap() - except ssl.SSLError, err: + except ssl.SSLError as err: if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): return - except socket.error, err: + except socket.error as err: # Any "socket error" corresponds to a SSL_ERROR_SYSCALL return # from OpenSSL's SSL_shutdown(), corresponding to a # closed socket condition. See also: # http://www.mail-archive.com/openssl-users@openssl.org/msg60710.html pass self._ssl_closing = False - super(SSLConnection, self).close() + if getattr(self, '_ccc', False) is False: + super(SSLConnection, self).close() + else: + pass def handle_read_event(self): if self._ssl_accepting: @@ -329,7 +335,7 @@ if ssl is not None: def send(self, data): try: return super(SSLConnection, self).send(data) - except ssl.SSLError, err: + except ssl.SSLError as err: if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN, ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): @@ -339,13 +345,13 @@ if ssl is not None: def recv(self, buffer_size): try: return super(SSLConnection, self).recv(buffer_size) - except ssl.SSLError, err: + except ssl.SSLError as err: if err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): - return '' + return b'' if err.args[0] in (ssl.SSL_ERROR_EOF, ssl.SSL_ERROR_ZERO_RETURN): self.handle_close() - return '' + return b'' raise def handle_error(self): @@ -355,6 +361,8 @@ if ssl is not None: if (isinstance(self.socket, ssl.SSLSocket) and self.socket._sslobj is not None): self._do_ssl_shutdown() + else: + super(SSLConnection, self).close() class DummyTLS_DTPHandler(SSLConnection, DummyDTPHandler): @@ -462,12 +470,12 @@ class TestFTPClass(TestCase): def test_rename(self): self.client.rename('a', 'b') - self.server.handler.next_response = '200' + self.server.handler_instance.next_response = '200' self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b') def test_delete(self): self.client.delete('foo') - self.server.handler.next_response = '199' + self.server.handler_instance.next_response = '199' self.assertRaises(ftplib.error_reply, self.client.delete, 'foo') def test_size(self): @@ -515,7 +523,7 @@ class TestFTPClass(TestCase): def test_storbinary(self): f = StringIO.StringIO(RETR_DATA) self.client.storbinary('stor', f) - self.assertEqual(self.server.handler.last_received_data, RETR_DATA) + self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg flag = [] f.seek(0) @@ -527,12 +535,12 @@ class TestFTPClass(TestCase): for r in (30, '30'): f.seek(0) self.client.storbinary('stor', f, rest=r) - self.assertEqual(self.server.handler.rest, str(r)) + self.assertEqual(self.server.handler_instance.rest, str(r)) def test_storlines(self): f = StringIO.StringIO(RETR_DATA.replace('\r\n', '\n')) self.client.storlines('stor', f) - self.assertEqual(self.server.handler.last_received_data, RETR_DATA) + self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg flag = [] f.seek(0) @@ -551,14 +559,14 @@ class TestFTPClass(TestCase): def test_makeport(self): self.client.makeport() # IPv4 is in use, just make sure send_eprt has not been used - self.assertEqual(self.server.handler.last_received_cmd, 'port') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'port') def test_makepasv(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), 10) conn.close() # IPv4 is in use, just make sure send_epsv has not been used - self.assertEqual(self.server.handler.last_received_cmd, 'pasv') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv') def test_line_too_long(self): self.assertRaises(ftplib.Error, self.client.sendcmd, @@ -600,13 +608,13 @@ class TestIPv6Environment(TestCase): def test_makeport(self): self.client.makeport() - self.assertEqual(self.server.handler.last_received_cmd, 'eprt') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'eprt') def test_makepasv(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), 10) conn.close() - self.assertEqual(self.server.handler.last_received_cmd, 'epsv') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv') def test_transfer(self): def retr(): @@ -642,7 +650,7 @@ class TestTLS_FTPClass(TestCase): def setUp(self): self.server = DummyTLS_FTPServer((HOST, 0)) self.server.start() - self.client = ftplib.FTP_TLS(timeout=10) + self.client = ftplib.FTP_TLS(timeout=TIMEOUT) self.client.connect(self.server.host, self.server.port) def tearDown(self): @@ -695,6 +703,59 @@ class TestTLS_FTPClass(TestCase): finally: self.client.ssl_version = ssl.PROTOCOL_TLSv1 + def test_context(self): + self.client.quit() + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + self.assertRaises(ValueError, ftplib.FTP_TLS, keyfile=CERTFILE, + context=ctx) + self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + context=ctx) + self.assertRaises(ValueError, ftplib.FTP_TLS, certfile=CERTFILE, + keyfile=CERTFILE, context=ctx) + + self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT) + self.client.connect(self.server.host, self.server.port) + self.assertNotIsInstance(self.client.sock, ssl.SSLSocket) + self.client.auth() + self.assertIs(self.client.sock.context, ctx) + self.assertIsInstance(self.client.sock, ssl.SSLSocket) + + self.client.prot_p() + sock = self.client.transfercmd('list') + try: + self.assertIs(sock.context, ctx) + self.assertIsInstance(sock, ssl.SSLSocket) + finally: + sock.close() + + def test_check_hostname(self): + self.client.quit() + ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.check_hostname = True + ctx.load_verify_locations(CAFILE) + self.client = ftplib.FTP_TLS(context=ctx, timeout=TIMEOUT) + + # 127.0.0.1 doesn't match SAN + self.client.connect(self.server.host, self.server.port) + with self.assertRaises(ssl.CertificateError): + self.client.auth() + # exception quits connection + + self.client.connect(self.server.host, self.server.port) + self.client.prot_p() + with self.assertRaises(ssl.CertificateError): + self.client.transfercmd("list").close() + self.client.quit() + + self.client.connect("localhost", self.server.port) + self.client.auth() + self.client.quit() + + self.client.connect("localhost", self.server.port) + self.client.prot_p() + self.client.transfercmd("list").close() + class TestTimeouts(TestCase): |