diff options
author | Giampaolo RodolĂ <g.rodola@gmail.com> | 2010-05-10 14:53:29 (GMT) |
---|---|---|
committer | Giampaolo RodolĂ <g.rodola@gmail.com> | 2010-05-10 14:53:29 (GMT) |
commit | bd576b75b7bed253b7bf4af5a967e3ee4dc1af8a (patch) | |
tree | fe5924f78b230670672d6097bc3be6724fed9b27 /Lib | |
parent | f95a1b3c53bdd678b64aa608d4375660033460c3 (diff) | |
download | cpython-bd576b75b7bed253b7bf4af5a967e3ee4dc1af8a.zip cpython-bd576b75b7bed253b7bf4af5a967e3ee4dc1af8a.tar.gz cpython-bd576b75b7bed253b7bf4af5a967e3ee4dc1af8a.tar.bz2 |
Fix issue #4972: adds ftplib.FTP context manager protocol
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/ftplib.py | 14 | ||||
-rw-r--r-- | Lib/test/test_ftplib.py | 71 |
2 files changed, 74 insertions, 11 deletions
diff --git a/Lib/ftplib.py b/Lib/ftplib.py index a0c3578..c25ae2a 100644 --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -120,6 +120,20 @@ class FTP: if user: self.login(user, passwd, acct) + def __enter__(self): + return self + + # Context management protocol: try to quit() if active + def __exit__(self, *args): + if self.sock is not None: + try: + self.quit() + except (socket.error, EOFError): + pass + finally: + if self.sock is not None: + self.close() + def connect(self, host='', port=0, timeout=-999): '''Connect to host. Arguments are: - host: hostname to connect to (string, default previous host) diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index b949b69..eb33526 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -10,6 +10,7 @@ import socket import io import errno import os +import time try: import ssl except ImportError: @@ -137,6 +138,9 @@ class DummyFTPHandler(asynchat.async_chat): # sends back the received string (used by the test suite) self.push(arg) + def cmd_noop(self, arg): + self.push('200 noop ok') + def cmd_user(self, arg): self.push('331 username ok') @@ -218,6 +222,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): self.active = False self.active_lock = threading.Lock() self.host, self.port = self.socket.getsockname()[:2] + self.handler_instance = None def start(self): assert not self.active @@ -241,8 +246,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): def handle_accept(self): conn, addr = self.accept() - self.handler = self.handler(conn) - self.close() + self.handler_instance = self.handler(conn) def handle_connect(self): self.close() @@ -459,12 +463,12 @@ class TestFTPClass(TestCase): def test_rename(self): self.client.rename('a', 'b') - self.server.handler.next_response = '200' + self.server.handler_instance.next_response = '200' self.assertRaises(ftplib.error_reply, self.client.rename, 'a', 'b') def test_delete(self): self.client.delete('foo') - self.server.handler.next_response = '199' + self.server.handler_instance.next_response = '199' self.assertRaises(ftplib.error_reply, self.client.delete, 'foo') def test_size(self): @@ -512,7 +516,7 @@ class TestFTPClass(TestCase): def test_storbinary(self): f = io.BytesIO(RETR_DATA.encode('ascii')) self.client.storbinary('stor', f) - self.assertEqual(self.server.handler.last_received_data, RETR_DATA) + self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg flag = [] f.seek(0) @@ -524,12 +528,12 @@ class TestFTPClass(TestCase): for r in (30, '30'): f.seek(0) self.client.storbinary('stor', f, rest=r) - self.assertEqual(self.server.handler.rest, str(r)) + self.assertEqual(self.server.handler_instance.rest, str(r)) def test_storlines(self): f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii')) self.client.storlines('stor', f) - self.assertEqual(self.server.handler.last_received_data, RETR_DATA) + self.assertEqual(self.server.handler_instance.last_received_data, RETR_DATA) # test new callback arg flag = [] f.seek(0) @@ -548,14 +552,59 @@ class TestFTPClass(TestCase): def test_makeport(self): self.client.makeport() # IPv4 is in use, just make sure send_eprt has not been used - self.assertEqual(self.server.handler.last_received_cmd, 'port') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'port') def test_makepasv(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), 2) conn.close() # IPv4 is in use, just make sure send_epsv has not been used - self.assertEqual(self.server.handler.last_received_cmd, 'pasv') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'pasv') + + def test_with_statement(self): + self.client.quit() + + def is_client_connected(): + if self.client.sock is None: + return False + try: + self.client.sendcmd('noop') + except (socket.error, EOFError): + return False + return True + + # base test + with ftplib.FTP(timeout=2) as self.client: + self.client.connect(self.server.host, self.server.port) + self.client.sendcmd('noop') + self.assertTrue(is_client_connected()) + self.assertEqual(self.server.handler_instance.last_received_cmd, 'quit') + self.assertFalse(is_client_connected()) + + # QUIT sent inside the with block + with ftplib.FTP(timeout=2) as self.client: + self.client.connect(self.server.host, self.server.port) + self.client.sendcmd('noop') + self.client.quit() + self.assertEqual(self.server.handler_instance.last_received_cmd, 'quit') + self.assertFalse(is_client_connected()) + + # force a wrong response code to be sent on QUIT: error_perm + # is expected and the connection is supposed to be closed + try: + with ftplib.FTP(timeout=2) as self.client: + self.client.connect(self.server.host, self.server.port) + self.client.sendcmd('noop') + self.server.handler_instance.next_response = '550 error on quit' + except ftplib.error_perm as err: + self.assertEqual(str(err), '550 error on quit') + else: + self.fail('Exception not raised') + # needed to give the threaded server some time to set the attribute + # which otherwise would still be == 'noop' + time.sleep(0.1) + self.assertEqual(self.server.handler_instance.last_received_cmd, 'quit') + self.assertFalse(is_client_connected()) class TestIPv6Environment(TestCase): @@ -575,13 +624,13 @@ class TestIPv6Environment(TestCase): def test_makeport(self): self.client.makeport() - self.assertEqual(self.server.handler.last_received_cmd, 'eprt') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'eprt') def test_makepasv(self): host, port = self.client.makepasv() conn = socket.create_connection((host, port), 2) conn.close() - self.assertEqual(self.server.handler.last_received_cmd, 'epsv') + self.assertEqual(self.server.handler_instance.last_received_cmd, 'epsv') def test_transfer(self): def retr(): |