summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_socket.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_socket.py')
-rw-r--r--Lib/test/test_socket.py67
1 files changed, 62 insertions, 5 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 01b9b5b..356b801 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -11,6 +11,7 @@ import Queue
import sys
import array
from weakref import proxy
+import signal
PORT = 50007
HOST = 'localhost'
@@ -21,7 +22,8 @@ class SocketTCPTest(unittest.TestCase):
def setUp(self):
self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.serv.bind((HOST, PORT))
+ global PORT
+ PORT = test_support.bind_port(self.serv, HOST, PORT)
self.serv.listen(1)
def tearDown(self):
@@ -33,7 +35,8 @@ class SocketUDPTest(unittest.TestCase):
def setUp(self):
self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.serv.bind((HOST, PORT))
+ global PORT
+ PORT = test_support.bind_port(self.serv, HOST, PORT)
def tearDown(self):
self.serv.close()
@@ -447,7 +450,12 @@ class GeneralModuleTests(unittest.TestCase):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("0.0.0.0", PORT+1))
name = sock.getsockname()
- self.assertEqual(name, ("0.0.0.0", PORT+1))
+ # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate
+ # it reasonable to get the host's addr in addition to 0.0.0.0.
+ # At least for eCos. This is required for the S/390 to pass.
+ my_ip_addr = socket.gethostbyname(socket.gethostname())
+ self.assert_(name[0] in ("0.0.0.0", my_ip_addr), '%s invalid' % name[0])
+ self.assertEqual(name[1], PORT+1)
def testGetSockOpt(self):
# Testing getsockopt()
@@ -575,6 +583,21 @@ class BasicUDPTest(ThreadedUDPSocketTest):
def _testRecvFrom(self):
self.cli.sendto(MSG, 0, (HOST, PORT))
+class TCPCloserTest(ThreadedTCPSocketTest):
+
+ def testClose(self):
+ conn, addr = self.serv.accept()
+ conn.close()
+
+ sd = self.cli
+ read, write, err = select.select([sd], [], [], 1.0)
+ self.assertEqual(read, [sd])
+ self.assertEqual(sd.recv(1), '')
+
+ def _testClose(self):
+ self.cli.connect((HOST, PORT))
+ time.sleep(1.0)
+
class BasicSocketPairTest(SocketPairTest):
def __init__(self, methodName='runTest'):
@@ -795,6 +818,37 @@ class TCPTimeoutTest(SocketTCPTest):
if not ok:
self.fail("accept() returned success when we did not expect it")
+ def testInterruptedTimeout(self):
+ # XXX I don't know how to do this test on MSWindows or any other
+ # plaform that doesn't support signal.alarm() or os.kill(), though
+ # the bug should have existed on all platforms.
+ if not hasattr(signal, "alarm"):
+ return # can only test on *nix
+ self.serv.settimeout(5.0) # must be longer than alarm
+ class Alarm(Exception):
+ pass
+ def alarm_handler(signal, frame):
+ raise Alarm
+ old_alarm = signal.signal(signal.SIGALRM, alarm_handler)
+ try:
+ signal.alarm(2) # POSIX allows alarm to be up to 1 second early
+ try:
+ foo = self.serv.accept()
+ except socket.timeout:
+ self.fail("caught timeout instead of Alarm")
+ except Alarm:
+ pass
+ except:
+ self.fail("caught other exception instead of Alarm")
+ else:
+ self.fail("nothing caught")
+ signal.alarm(0) # shut off alarm
+ except Alarm:
+ self.fail("got Alarm in wrong place")
+ finally:
+ # no alarm can be pending. Safe to restore old handler.
+ signal.signal(signal.SIGALRM, old_alarm)
+
class UDPTimeoutTest(SocketTCPTest):
def testUDPTimeout(self):
@@ -883,8 +937,8 @@ class BufferIOTest(SocketConnectedTest):
self.serv_conn.send(buf)
def test_main():
- tests = [GeneralModuleTests, BasicTCPTest, TCPTimeoutTest, TestExceptions,
- BufferIOTest]
+ tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
+ TestExceptions, BufferIOTest]
if sys.platform != 'mac':
tests.extend([ BasicUDPTest, UDPTimeoutTest ])
@@ -899,7 +953,10 @@ def test_main():
tests.append(BasicSocketPairTest)
if sys.platform == 'linux2':
tests.append(TestLinuxAbstractNamespace)
+
+ thread_info = test_support.threading_setup()
test_support.run_unittest(*tests)
+ test_support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()