summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRichard Jones <richard@commonground.com.au>2010-08-03 06:39:33 (GMT)
committerRichard Jones <richard@commonground.com.au>2010-08-03 06:39:33 (GMT)
commit64b02de01070d5a05d25b5d6be1cdf7889837ce4 (patch)
tree2c479c9e5a4ceda44c796494d6c6aac2c63aae85
parent0db85e5d46d0c0e57377c0bdbfdccb91eb36ec64 (diff)
downloadcpython-64b02de01070d5a05d25b5d6be1cdf7889837ce4.zip
cpython-64b02de01070d5a05d25b5d6be1cdf7889837ce4.tar.gz
cpython-64b02de01070d5a05d25b5d6be1cdf7889837ce4.tar.bz2
improvements to test_smtplib per issue2423
merged the socket mock introduced in test_smtpd
-rw-r--r--Lib/test/mock_socket.py153
-rw-r--r--Lib/test/test_smtpd.py49
-rw-r--r--Lib/test/test_smtplib.py61
3 files changed, 188 insertions, 75 deletions
diff --git a/Lib/test/mock_socket.py b/Lib/test/mock_socket.py
new file mode 100644
index 0000000..1512642
--- /dev/null
+++ b/Lib/test/mock_socket.py
@@ -0,0 +1,153 @@
+"""Mock socket module used by the smtpd and smtplib tests.
+"""
+
+# imported for _GLOBAL_DEFAULT_TIMEOUT
+import socket as socket_module
+
+# Mock socket module
+_defaulttimeout = None
+_reply_data = None
+
+# This is used to queue up data to be read through socket.makefile, typically
+# *before* the socket object is even created. It is intended to handle a single
+# line which the socket will feed on recv() or makefile().
+def reply_with(line):
+ global _reply_data
+ _reply_data = line
+
+
+class MockFile:
+ """Mock file object returned by MockSocket.makefile().
+ """
+ def __init__(self, lines):
+ self.lines = lines
+ def readline(self):
+ return self.lines.pop(0) + b'\r\n'
+ def close(self):
+ pass
+
+
+class MockSocket:
+ """Mock socket object used by smtpd and smtplib tests.
+ """
+ def __init__(self):
+ global _reply_data
+ self.output = []
+ self.lines = []
+ if _reply_data:
+ self.lines.append(_reply_data)
+ self.conn = None
+ self.timeout = None
+
+ def queue_recv(self, line):
+ self.lines.append(line)
+
+ def recv(self, bufsize, flags=None):
+ data = self.lines.pop(0) + b'\r\n'
+ return data
+
+ def fileno(self):
+ return 0
+
+ def settimeout(self, timeout):
+ if timeout is None:
+ self.timeout = _defaulttimeout
+ else:
+ self.timeout = timeout
+
+ def gettimeout(self):
+ return self.timeout
+
+ def setsockopt(self, level, optname, value):
+ pass
+
+ def getsockopt(self, level, optname, buflen=None):
+ return 0
+
+ def bind(self, address):
+ pass
+
+ def accept(self):
+ self.conn = MockSocket()
+ return self.conn, 'c'
+
+ def getsockname(self):
+ return ('0.0.0.0', 0)
+
+ def setblocking(self, flag):
+ pass
+
+ def listen(self, backlog):
+ pass
+
+ def makefile(self, mode='r', bufsize=-1):
+ handle = MockFile(self.lines)
+ return handle
+
+ def sendall(self, buffer, flags=None):
+ self.last = data
+ self.output.append(data)
+ return len(data)
+
+ def send(self, data, flags=None):
+ self.last = data
+ self.output.append(data)
+ return len(data)
+
+ def getpeername(self):
+ return 'peer'
+
+ def close(self):
+ pass
+
+
+def socket(family=None, type=None, proto=None):
+ return MockSocket()
+
+
+def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT):
+ try:
+ int_port = int(address[1])
+ except ValueError:
+ raise error
+ ms = MockSocket()
+ if timeout is socket_module._GLOBAL_DEFAULT_TIMEOUT:
+ timeout = getdefaulttimeout()
+ ms.settimeout(timeout)
+ return ms
+
+
+def setdefaulttimeout(timeout):
+ global _defaulttimeout
+ _defaulttimeout = timeout
+
+
+def getdefaulttimeout():
+ return _defaulttimeout
+
+
+def getfqdn():
+ return ""
+
+
+def gethostname():
+ pass
+
+
+def gethostbyname(name):
+ return ""
+
+
+class gaierror(Exception):
+ pass
+
+
+class error(Exception):
+ pass
+
+
+# Constants
+AF_INET = None
+SOCK_STREAM = None
+SOL_SOCKET = None
+SO_REUSEADDR = None
diff --git a/Lib/test/test_smtpd.py b/Lib/test/test_smtpd.py
index 2b781bd..506ac99 100644
--- a/Lib/test/test_smtpd.py
+++ b/Lib/test/test_smtpd.py
@@ -1,53 +1,16 @@
-import asynchat
from unittest import TestCase
+from test import support, mock_socket
import socket
-from test import support
-import asyncore
import io
import smtpd
+import asyncore
-# mock-ish socket to sit underneath asyncore
-class DummySocket:
- def __init__(self):
- self.output = []
- self.queue = []
- self.conn = None
- def queue_recv(self, line):
- self.queue.append(line)
- def recv(self, *args):
- data = self.queue.pop(0) + b'\r\n'
- return data
- def fileno(self):
- return 0
- def setsockopt(self, *args):
- pass
- def getsockopt(self, *args):
- return 0
- def bind(self, *args):
- pass
- def accept(self):
- self.conn = DummySocket()
- return self.conn, 'c'
- def listen(self, *args):
- pass
- def setblocking(self, *args):
- pass
- def send(self, data):
- self.last = data
- self.output.append(data)
- return len(data)
- def getpeername(self):
- return 'peer'
- def close(self):
- pass
class DummyServer(smtpd.SMTPServer):
def __init__(self, *args):
smtpd.SMTPServer.__init__(self, *args)
self.messages = []
- def create_socket(self, family, type):
- self.family_and_type = (socket.AF_INET, socket.SOCK_STREAM)
- self.set_socket(DummySocket())
+
def process_message(self, peer, mailfrom, rcpttos, data):
self.messages.append((peer, mailfrom, rcpttos, data))
if data == 'return status':
@@ -62,11 +25,15 @@ class BrokenDummyServer(DummyServer):
class SMTPDChannelTest(TestCase):
def setUp(self):
+ smtpd.socket = asyncore.socket = mock_socket
self.debug = smtpd.DEBUGSTREAM = io.StringIO()
self.server = DummyServer('a', 'b')
conn, addr = self.server.accept()
self.channel = smtpd.SMTPChannel(self.server, conn, addr)
+ def tearDown(self):
+ asyncore.socket = smtpd.socket = socket
+
def write_line(self, line):
self.channel.socket.queue_recv(line)
self.channel.handle_read()
@@ -88,7 +55,7 @@ class SMTPDChannelTest(TestCase):
b'502 Error: command "EHLO" not implemented\r\n')
def test_HELO(self):
- name = socket.getfqdn()
+ name = smtpd.socket.getfqdn()
self.write_line(b'HELO test.example')
self.assertEqual(self.channel.socket.last,
'250 {}\r\n'.format(name).encode('ascii'))
diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py
index d0b2b27..57faf6e 100644
--- a/Lib/test/test_smtplib.py
+++ b/Lib/test/test_smtplib.py
@@ -9,7 +9,7 @@ import time
import select
import unittest
-from test import support
+from test import support, mock_socket
try:
import threading
@@ -48,27 +48,17 @@ def server(evt, buf, serv):
serv.close()
evt.set()
-@unittest.skipUnless(threading, 'Threading required for this test.')
class GeneralTests(unittest.TestCase):
def setUp(self):
- self._threads = support.threading_setup()
- self.evt = threading.Event()
- self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.sock.settimeout(15)
- self.port = support.bind_port(self.sock)
- servargs = (self.evt, b"220 Hola mundo\n", self.sock)
- self.thread = threading.Thread(target=server, args=servargs)
- self.thread.start()
- self.evt.wait()
- self.evt.clear()
+ smtplib.socket = mock_socket
+ self.port = 25
def tearDown(self):
- self.evt.wait()
- self.thread.join()
- support.threading_cleanup(*self._threads)
+ smtplib.socket = socket
def testBasic1(self):
+ mock_socket.reply_with(b"220 Hola mundo")
# connects
smtp = smtplib.SMTP(HOST, self.port)
smtp.close()
@@ -85,12 +75,13 @@ class GeneralTests(unittest.TestCase):
smtp.close()
def testTimeoutDefault(self):
- self.assertTrue(socket.getdefaulttimeout() is None)
- socket.setdefaulttimeout(30)
+ self.assertTrue(mock_socket.getdefaulttimeout() is None)
+ mock_socket.setdefaulttimeout(30)
+ self.assertEqual(mock_socket.getdefaulttimeout(), 30)
try:
smtp = smtplib.SMTP(HOST, self.port)
finally:
- socket.setdefaulttimeout(None)
+ mock_socket.setdefaulttimeout(None)
self.assertEqual(smtp.sock.gettimeout(), 30)
smtp.close()
@@ -155,6 +146,8 @@ MSG_END = '------------ END MESSAGE ------------\n'
class DebuggingServerTests(unittest.TestCase):
def setUp(self):
+ self.real_getfqdn = socket.getfqdn
+ socket.getfqdn = mock_socket.getfqdn
# temporarily replace sys.stdout to capture DebuggingServer output
self.old_stdout = sys.stdout
self.output = io.StringIO()
@@ -176,6 +169,7 @@ class DebuggingServerTests(unittest.TestCase):
self.serv_evt.clear()
def tearDown(self):
+ socket.getfqdn = self.real_getfqdn
# indicate that the client is finished
self.client_evt.set()
# wait for the server thread to terminate
@@ -251,6 +245,12 @@ class DebuggingServerTests(unittest.TestCase):
class NonConnectingTests(unittest.TestCase):
+ def setUp(self):
+ smtplib.socket = mock_socket
+
+ def tearDown(self):
+ smtplib.socket = socket
+
def testNotConnected(self):
# Test various operations on an unconnected SMTP object that
# should raise exceptions (at present the attempt in SMTP.send
@@ -263,9 +263,9 @@ class NonConnectingTests(unittest.TestCase):
def testNonnumericPort(self):
# check that non-numeric port raises socket.error
- self.assertRaises(socket.error, smtplib.SMTP,
+ self.assertRaises(mock_socket.error, smtplib.SMTP,
"localhost", "bogus")
- self.assertRaises(socket.error, smtplib.SMTP,
+ self.assertRaises(mock_socket.error, smtplib.SMTP,
"localhost:bogus")
@@ -274,25 +274,15 @@ class NonConnectingTests(unittest.TestCase):
class BadHELOServerTests(unittest.TestCase):
def setUp(self):
+ smtplib.socket = mock_socket
+ mock_socket.reply_with(b"199 no hello for you!")
self.old_stdout = sys.stdout
self.output = io.StringIO()
sys.stdout = self.output
-
- self._threads = support.threading_setup()
- self.evt = threading.Event()
- self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.sock.settimeout(15)
- self.port = support.bind_port(self.sock)
- servargs = (self.evt, b"199 no hello for you!\n", self.sock)
- self.thread = threading.Thread(target=server, args=servargs)
- self.thread.start()
- self.evt.wait()
- self.evt.clear()
+ self.port = 25
def tearDown(self):
- self.evt.wait()
- self.thread.join()
- support.threading_cleanup(*self._threads)
+ smtplib.socket = socket
sys.stdout = self.old_stdout
def testFailingHELO(self):
@@ -405,6 +395,8 @@ class SimSMTPServer(smtpd.SMTPServer):
class SMTPSimTests(unittest.TestCase):
def setUp(self):
+ self.real_getfqdn = socket.getfqdn
+ socket.getfqdn = mock_socket.getfqdn
self._threads = support.threading_setup()
self.serv_evt = threading.Event()
self.client_evt = threading.Event()
@@ -421,6 +413,7 @@ class SMTPSimTests(unittest.TestCase):
self.serv_evt.clear()
def tearDown(self):
+ socket.getfqdn = self.real_getfqdn
# indicate that the client is finished
self.client_evt.set()
# wait for the server thread to terminate