diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/mock_socket.py | 153 | ||||
-rw-r--r-- | Lib/test/test_smtpd.py | 49 | ||||
-rw-r--r-- | Lib/test/test_smtplib.py | 61 |
3 files changed, 188 insertions, 75 deletions
diff --git a/Lib/test/mock_socket.py b/Lib/test/mock_socket.py new file mode 100644 index 0000000..1512642 --- /dev/null +++ b/Lib/test/mock_socket.py @@ -0,0 +1,153 @@ +"""Mock socket module used by the smtpd and smtplib tests. +""" + +# imported for _GLOBAL_DEFAULT_TIMEOUT +import socket as socket_module + +# Mock socket module +_defaulttimeout = None +_reply_data = None + +# This is used to queue up data to be read through socket.makefile, typically +# *before* the socket object is even created. It is intended to handle a single +# line which the socket will feed on recv() or makefile(). +def reply_with(line): + global _reply_data + _reply_data = line + + +class MockFile: + """Mock file object returned by MockSocket.makefile(). + """ + def __init__(self, lines): + self.lines = lines + def readline(self): + return self.lines.pop(0) + b'\r\n' + def close(self): + pass + + +class MockSocket: + """Mock socket object used by smtpd and smtplib tests. + """ + def __init__(self): + global _reply_data + self.output = [] + self.lines = [] + if _reply_data: + self.lines.append(_reply_data) + self.conn = None + self.timeout = None + + def queue_recv(self, line): + self.lines.append(line) + + def recv(self, bufsize, flags=None): + data = self.lines.pop(0) + b'\r\n' + return data + + def fileno(self): + return 0 + + def settimeout(self, timeout): + if timeout is None: + self.timeout = _defaulttimeout + else: + self.timeout = timeout + + def gettimeout(self): + return self.timeout + + def setsockopt(self, level, optname, value): + pass + + def getsockopt(self, level, optname, buflen=None): + return 0 + + def bind(self, address): + pass + + def accept(self): + self.conn = MockSocket() + return self.conn, 'c' + + def getsockname(self): + return ('0.0.0.0', 0) + + def setblocking(self, flag): + pass + + def listen(self, backlog): + pass + + def makefile(self, mode='r', bufsize=-1): + handle = MockFile(self.lines) + return handle + + def sendall(self, buffer, flags=None): + self.last = data + self.output.append(data) + return len(data) + + def send(self, data, flags=None): + self.last = data + self.output.append(data) + return len(data) + + def getpeername(self): + return 'peer' + + def close(self): + pass + + +def socket(family=None, type=None, proto=None): + return MockSocket() + + +def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT): + try: + int_port = int(address[1]) + except ValueError: + raise error + ms = MockSocket() + if timeout is socket_module._GLOBAL_DEFAULT_TIMEOUT: + timeout = getdefaulttimeout() + ms.settimeout(timeout) + return ms + + +def setdefaulttimeout(timeout): + global _defaulttimeout + _defaulttimeout = timeout + + +def getdefaulttimeout(): + return _defaulttimeout + + +def getfqdn(): + return "" + + +def gethostname(): + pass + + +def gethostbyname(name): + return "" + + +class gaierror(Exception): + pass + + +class error(Exception): + pass + + +# Constants +AF_INET = None +SOCK_STREAM = None +SOL_SOCKET = None +SO_REUSEADDR = None diff --git a/Lib/test/test_smtpd.py b/Lib/test/test_smtpd.py index 2b781bd..506ac99 100644 --- a/Lib/test/test_smtpd.py +++ b/Lib/test/test_smtpd.py @@ -1,53 +1,16 @@ -import asynchat from unittest import TestCase +from test import support, mock_socket import socket -from test import support -import asyncore import io import smtpd +import asyncore -# mock-ish socket to sit underneath asyncore -class DummySocket: - def __init__(self): - self.output = [] - self.queue = [] - self.conn = None - def queue_recv(self, line): - self.queue.append(line) - def recv(self, *args): - data = self.queue.pop(0) + b'\r\n' - return data - def fileno(self): - return 0 - def setsockopt(self, *args): - pass - def getsockopt(self, *args): - return 0 - def bind(self, *args): - pass - def accept(self): - self.conn = DummySocket() - return self.conn, 'c' - def listen(self, *args): - pass - def setblocking(self, *args): - pass - def send(self, data): - self.last = data - self.output.append(data) - return len(data) - def getpeername(self): - return 'peer' - def close(self): - pass class DummyServer(smtpd.SMTPServer): def __init__(self, *args): smtpd.SMTPServer.__init__(self, *args) self.messages = [] - def create_socket(self, family, type): - self.family_and_type = (socket.AF_INET, socket.SOCK_STREAM) - self.set_socket(DummySocket()) + def process_message(self, peer, mailfrom, rcpttos, data): self.messages.append((peer, mailfrom, rcpttos, data)) if data == 'return status': @@ -62,11 +25,15 @@ class BrokenDummyServer(DummyServer): class SMTPDChannelTest(TestCase): def setUp(self): + smtpd.socket = asyncore.socket = mock_socket self.debug = smtpd.DEBUGSTREAM = io.StringIO() self.server = DummyServer('a', 'b') conn, addr = self.server.accept() self.channel = smtpd.SMTPChannel(self.server, conn, addr) + def tearDown(self): + asyncore.socket = smtpd.socket = socket + def write_line(self, line): self.channel.socket.queue_recv(line) self.channel.handle_read() @@ -88,7 +55,7 @@ class SMTPDChannelTest(TestCase): b'502 Error: command "EHLO" not implemented\r\n') def test_HELO(self): - name = socket.getfqdn() + name = smtpd.socket.getfqdn() self.write_line(b'HELO test.example') self.assertEqual(self.channel.socket.last, '250 {}\r\n'.format(name).encode('ascii')) diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index d0b2b27..57faf6e 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -9,7 +9,7 @@ import time import select import unittest -from test import support +from test import support, mock_socket try: import threading @@ -48,27 +48,17 @@ def server(evt, buf, serv): serv.close() evt.set() -@unittest.skipUnless(threading, 'Threading required for this test.') class GeneralTests(unittest.TestCase): def setUp(self): - self._threads = support.threading_setup() - self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(15) - self.port = support.bind_port(self.sock) - servargs = (self.evt, b"220 Hola mundo\n", self.sock) - self.thread = threading.Thread(target=server, args=servargs) - self.thread.start() - self.evt.wait() - self.evt.clear() + smtplib.socket = mock_socket + self.port = 25 def tearDown(self): - self.evt.wait() - self.thread.join() - support.threading_cleanup(*self._threads) + smtplib.socket = socket def testBasic1(self): + mock_socket.reply_with(b"220 Hola mundo") # connects smtp = smtplib.SMTP(HOST, self.port) smtp.close() @@ -85,12 +75,13 @@ class GeneralTests(unittest.TestCase): smtp.close() def testTimeoutDefault(self): - self.assertTrue(socket.getdefaulttimeout() is None) - socket.setdefaulttimeout(30) + self.assertTrue(mock_socket.getdefaulttimeout() is None) + mock_socket.setdefaulttimeout(30) + self.assertEqual(mock_socket.getdefaulttimeout(), 30) try: smtp = smtplib.SMTP(HOST, self.port) finally: - socket.setdefaulttimeout(None) + mock_socket.setdefaulttimeout(None) self.assertEqual(smtp.sock.gettimeout(), 30) smtp.close() @@ -155,6 +146,8 @@ MSG_END = '------------ END MESSAGE ------------\n' class DebuggingServerTests(unittest.TestCase): def setUp(self): + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn # temporarily replace sys.stdout to capture DebuggingServer output self.old_stdout = sys.stdout self.output = io.StringIO() @@ -176,6 +169,7 @@ class DebuggingServerTests(unittest.TestCase): self.serv_evt.clear() def tearDown(self): + socket.getfqdn = self.real_getfqdn # indicate that the client is finished self.client_evt.set() # wait for the server thread to terminate @@ -251,6 +245,12 @@ class DebuggingServerTests(unittest.TestCase): class NonConnectingTests(unittest.TestCase): + def setUp(self): + smtplib.socket = mock_socket + + def tearDown(self): + smtplib.socket = socket + def testNotConnected(self): # Test various operations on an unconnected SMTP object that # should raise exceptions (at present the attempt in SMTP.send @@ -263,9 +263,9 @@ class NonConnectingTests(unittest.TestCase): def testNonnumericPort(self): # check that non-numeric port raises socket.error - self.assertRaises(socket.error, smtplib.SMTP, + self.assertRaises(mock_socket.error, smtplib.SMTP, "localhost", "bogus") - self.assertRaises(socket.error, smtplib.SMTP, + self.assertRaises(mock_socket.error, smtplib.SMTP, "localhost:bogus") @@ -274,25 +274,15 @@ class NonConnectingTests(unittest.TestCase): class BadHELOServerTests(unittest.TestCase): def setUp(self): + smtplib.socket = mock_socket + mock_socket.reply_with(b"199 no hello for you!") self.old_stdout = sys.stdout self.output = io.StringIO() sys.stdout = self.output - - self._threads = support.threading_setup() - self.evt = threading.Event() - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(15) - self.port = support.bind_port(self.sock) - servargs = (self.evt, b"199 no hello for you!\n", self.sock) - self.thread = threading.Thread(target=server, args=servargs) - self.thread.start() - self.evt.wait() - self.evt.clear() + self.port = 25 def tearDown(self): - self.evt.wait() - self.thread.join() - support.threading_cleanup(*self._threads) + smtplib.socket = socket sys.stdout = self.old_stdout def testFailingHELO(self): @@ -405,6 +395,8 @@ class SimSMTPServer(smtpd.SMTPServer): class SMTPSimTests(unittest.TestCase): def setUp(self): + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn self._threads = support.threading_setup() self.serv_evt = threading.Event() self.client_evt = threading.Event() @@ -421,6 +413,7 @@ class SMTPSimTests(unittest.TestCase): self.serv_evt.clear() def tearDown(self): + socket.getfqdn = self.real_getfqdn # indicate that the client is finished self.client_evt.set() # wait for the server thread to terminate |