diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2015-03-21 07:40:26 (GMT) |
---|---|---|
committer | Serhiy Storchaka <storchaka@gmail.com> | 2015-03-21 07:40:26 (GMT) |
commit | 52027c301a0b3675bb5db23d33eede3f6b19395f (patch) | |
tree | 1d43b46d6863740f5b776465f459c493bd0bca52 /Lib | |
parent | 63998a3520ac4c2217946baf99574d9e6a6a959d (diff) | |
download | cpython-52027c301a0b3675bb5db23d33eede3f6b19395f.zip cpython-52027c301a0b3675bb5db23d33eede3f6b19395f.tar.gz cpython-52027c301a0b3675bb5db23d33eede3f6b19395f.tar.bz2 |
Issue #22351: The nntplib.NNTP constructor no longer leaves the connection
and socket open until the garbage collector cleans them up. Patch by
Martin Panter.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/nntplib.py | 36 | ||||
-rw-r--r-- | Lib/test/test_nntplib.py | 102 |
2 files changed, 123 insertions, 15 deletions
diff --git a/Lib/nntplib.py b/Lib/nntplib.py index bcf7d1b..3413610 100644 --- a/Lib/nntplib.py +++ b/Lib/nntplib.py @@ -1041,11 +1041,18 @@ class NNTP(_NNTPBase): self.host = host self.port = port self.sock = socket.create_connection((host, port), timeout) - file = self.sock.makefile("rwb") - _NNTPBase.__init__(self, file, host, - readermode, timeout) - if user or usenetrc: - self.login(user, password, usenetrc) + file = None + try: + file = self.sock.makefile("rwb") + _NNTPBase.__init__(self, file, host, + readermode, timeout) + if user or usenetrc: + self.login(user, password, usenetrc) + except: + if file: + file.close() + self.sock.close() + raise def _close(self): try: @@ -1065,12 +1072,19 @@ if _have_ssl: in default port and the `ssl_context` argument for SSL connections. """ self.sock = socket.create_connection((host, port), timeout) - self.sock = _encrypt_on(self.sock, ssl_context, host) - file = self.sock.makefile("rwb") - _NNTPBase.__init__(self, file, host, - readermode=readermode, timeout=timeout) - if user or usenetrc: - self.login(user, password, usenetrc) + file = None + try: + self.sock = _encrypt_on(self.sock, ssl_context, host) + file = self.sock.makefile("rwb") + _NNTPBase.__init__(self, file, host, + readermode=readermode, timeout=timeout) + if user or usenetrc: + self.login(user, password, usenetrc) + except: + if file: + file.close() + self.sock.close() + raise def _close(self): try: diff --git a/Lib/test/test_nntplib.py b/Lib/test/test_nntplib.py index fb216c1..9e88ddb 100644 --- a/Lib/test/test_nntplib.py +++ b/Lib/test/test_nntplib.py @@ -8,6 +8,7 @@ import contextlib from test import support from nntplib import NNTP, GroupInfo import nntplib +from unittest.mock import patch try: import ssl except ImportError: @@ -370,6 +371,14 @@ class _NNTPServerIO(io.RawIOBase): return n +def make_mock_file(handler): + sio = _NNTPServerIO(handler) + # Using BufferedRWPair instead of BufferedRandom ensures the file + # isn't seekable. + file = io.BufferedRWPair(sio, sio) + return (sio, file) + + class MockedNNTPTestsMixin: # Override in derived classes handler_class = None @@ -384,10 +393,7 @@ class MockedNNTPTestsMixin: def make_server(self, *args, **kwargs): self.handler = self.handler_class() - self.sio = _NNTPServerIO(self.handler) - # Using BufferedRWPair instead of BufferedRandom ensures the file - # isn't seekable. - file = io.BufferedRWPair(self.sio, self.sio) + self.sio, file = make_mock_file(self.handler) self.server = nntplib._NNTPBase(file, 'test.server', *args, **kwargs) return self.server @@ -1425,5 +1431,93 @@ class PublicAPITests(unittest.TestCase): target_api.append('NNTP_SSL') self.assertEqual(set(nntplib.__all__), set(target_api)) +class MockSocketTests(unittest.TestCase): + """Tests involving a mock socket object + + Used where the _NNTPServerIO file object is not enough.""" + + nntp_class = nntplib.NNTP + + def check_constructor_error_conditions( + self, handler_class, + expected_error_type, expected_error_msg, + login=None, password=None): + + class mock_socket_module: + def create_connection(address, timeout): + return MockSocket() + + class MockSocket: + def close(self): + nonlocal socket_closed + socket_closed = True + + def makefile(socket, mode): + handler = handler_class() + _, file = make_mock_file(handler) + files.append(file) + return file + + socket_closed = False + files = [] + with patch('nntplib.socket', mock_socket_module), \ + self.assertRaisesRegex(expected_error_type, expected_error_msg): + self.nntp_class('dummy', user=login, password=password) + self.assertTrue(socket_closed) + for f in files: + self.assertTrue(f.closed) + + def test_bad_welcome(self): + #Test a bad welcome message + class Handler(NNTPv1Handler): + welcome = 'Bad Welcome' + self.check_constructor_error_conditions( + Handler, nntplib.NNTPProtocolError, Handler.welcome) + + def test_service_temporarily_unavailable(self): + #Test service temporarily unavailable + class Handler(NNTPv1Handler): + welcome = '400 Service temporarily unavilable' + self.check_constructor_error_conditions( + Handler, nntplib.NNTPTemporaryError, Handler.welcome) + + def test_service_permanently_unavailable(self): + #Test service permanently unavailable + class Handler(NNTPv1Handler): + welcome = '502 Service permanently unavilable' + self.check_constructor_error_conditions( + Handler, nntplib.NNTPPermanentError, Handler.welcome) + + def test_bad_capabilities(self): + #Test a bad capabilities response + class Handler(NNTPv1Handler): + def handle_CAPABILITIES(self): + self.push_lit(capabilities_response) + capabilities_response = '201 bad capability' + self.check_constructor_error_conditions( + Handler, nntplib.NNTPReplyError, capabilities_response) + + def test_login_aborted(self): + #Test a bad authinfo response + login = 't@e.com' + password = 'python' + class Handler(NNTPv1Handler): + def handle_AUTHINFO(self, *args): + self.push_lit(authinfo_response) + authinfo_response = '503 Mechanism not recognized' + self.check_constructor_error_conditions( + Handler, nntplib.NNTPPermanentError, authinfo_response, + login, password) + +@unittest.skipUnless(ssl, 'requires SSL support') +class MockSslTests(MockSocketTests): + class nntp_class(nntplib.NNTP_SSL): + def __init__(self, *pos, **kw): + class bypass_context: + """Bypass encryption and actual SSL module""" + def wrap_socket(sock, **args): + return sock + return super().__init__(*pos, ssl_context=bypass_context, **kw) + if __name__ == "__main__": unittest.main() |