diff options
Diffstat (limited to 'Lib/test/test_socket.py')
| -rw-r--r-- | Lib/test/test_socket.py | 426 | 
1 files changed, 292 insertions, 134 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index b412386..cf45b73 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -20,6 +20,8 @@ import signal  import math  import pickle  import struct +import random +import string  try:      import multiprocessing  except ImportError: @@ -76,7 +78,7 @@ class SocketTCPTest(unittest.TestCase):      def setUp(self):          self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)          self.port = support.bind_port(self.serv) -        self.serv.listen(1) +        self.serv.listen()      def tearDown(self):          self.serv.close() @@ -445,7 +447,7 @@ class SocketListeningTestMixin(SocketTestBase):      def setUp(self):          super().setUp() -        self.serv.listen(1) +        self.serv.listen()  class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase, @@ -716,11 +718,11 @@ class GeneralModuleTests(unittest.TestCase):          with self.assertRaises(TypeError) as cm:              s.sendto('\u2620', sockname)          self.assertEqual(str(cm.exception), -                         "'str' does not support the buffer interface") +                         "a bytes-like object is required, not 'str'")          with self.assertRaises(TypeError) as cm:              s.sendto(5j, sockname)          self.assertEqual(str(cm.exception), -                         "'complex' does not support the buffer interface") +                         "a bytes-like object is required, not 'complex'")          with self.assertRaises(TypeError) as cm:              s.sendto(b'foo', None)          self.assertIn('not NoneType',str(cm.exception)) @@ -728,11 +730,11 @@ class GeneralModuleTests(unittest.TestCase):          with self.assertRaises(TypeError) as cm:              s.sendto('\u2620', 0, sockname)          self.assertEqual(str(cm.exception), -                         "'str' does not support the buffer interface") +                         "a bytes-like object is required, not 'str'")          with self.assertRaises(TypeError) as cm:              s.sendto(5j, 0, sockname)          self.assertEqual(str(cm.exception), -                         "'complex' does not support the buffer interface") +                         "a bytes-like object is required, not 'complex'")          with self.assertRaises(TypeError) as cm:              s.sendto(b'foo', 0, None)          self.assertIn('not NoneType', str(cm.exception)) @@ -1378,10 +1380,13 @@ class GeneralModuleTests(unittest.TestCase):      def test_listen_backlog(self):          for backlog in 0, -1: -            srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: +                srv.bind((HOST, 0)) +                srv.listen(backlog) + +        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:              srv.bind((HOST, 0)) -            srv.listen(backlog) -            srv.close() +            srv.listen()      @support.cpython_only      def test_listen_backlog_overflow(self): @@ -3585,7 +3590,7 @@ class InterruptedTimeoutBase(unittest.TestCase):      def setUp(self):          super().setUp()          orig_alrm_handler = signal.signal(signal.SIGALRM, -                                          lambda signum, frame: None) +                                          lambda signum, frame: 1 / 0)          self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler)          self.addCleanup(self.setAlarm, 0) @@ -3622,13 +3627,11 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase):          self.serv.settimeout(self.timeout)      def checkInterruptedRecv(self, func, *args, **kwargs): -        # Check that func(*args, **kwargs) raises OSError with an +        # Check that func(*args, **kwargs) raises          # errno of EINTR when interrupted by a signal.          self.setAlarm(self.alarm_time) -        with self.assertRaises(OSError) as cm: +        with self.assertRaises(ZeroDivisionError) as cm:              func(*args, **kwargs) -        self.assertNotIsInstance(cm.exception, socket.timeout) -        self.assertEqual(cm.exception.errno, errno.EINTR)      def testInterruptedRecvTimeout(self):          self.checkInterruptedRecv(self.serv.recv, 1024) @@ -3684,12 +3687,10 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase,          # Check that func(*args, **kwargs), run in a loop, raises          # OSError with an errno of EINTR when interrupted by a          # signal. -        with self.assertRaises(OSError) as cm: +        with self.assertRaises(ZeroDivisionError) as cm:              while True:                  self.setAlarm(self.alarm_time)                  func(*args, **kwargs) -        self.assertNotIsInstance(cm.exception, socket.timeout) -        self.assertEqual(cm.exception.errno, errno.EINTR)      # Issue #12958: The following tests have problems on OS X prior to 10.7      @support.requires_mac_ver(10, 7) @@ -3731,8 +3732,6 @@ class TCPCloserTest(ThreadedTCPSocketTest):          self.cli.connect((HOST, self.port))          time.sleep(1.0) -@unittest.skipUnless(hasattr(socket, 'socketpair'), -                     'test needs socket.socketpair()')  @unittest.skipUnless(thread, 'Threading required for this test.')  class BasicSocketPairTest(SocketPairTest): @@ -3813,7 +3812,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest):          self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM |                                                    socket.SOCK_NONBLOCK)          self.port = support.bind_port(self.serv) -        self.serv.listen(1) +        self.serv.listen()          # actual testing          start = time.time()          try: @@ -4059,117 +4058,6 @@ class FileObjectClassTestCase(SocketConnectedTest):          pass -class FileObjectInterruptedTestCase(unittest.TestCase): -    """Test that the file object correctly handles EINTR internally.""" - -    class MockSocket(object): -        def __init__(self, recv_funcs=()): -            # A generator that returns callables that we'll call for each -            # call to recv(). -            self._recv_step = iter(recv_funcs) - -        def recv_into(self, buffer): -            data = next(self._recv_step)() -            assert len(buffer) >= len(data) -            buffer[:len(data)] = data -            return len(data) - -        def _decref_socketios(self): -            pass - -        def _textiowrap_for_test(self, buffering=-1): -            raw = socket.SocketIO(self, "r") -            if buffering < 0: -                buffering = io.DEFAULT_BUFFER_SIZE -            if buffering == 0: -                return raw -            buffer = io.BufferedReader(raw, buffering) -            text = io.TextIOWrapper(buffer, None, None) -            text.mode = "rb" -            return text - -    @staticmethod -    def _raise_eintr(): -        raise OSError(errno.EINTR, "interrupted") - -    def _textiowrap_mock_socket(self, mock, buffering=-1): -        raw = socket.SocketIO(mock, "r") -        if buffering < 0: -            buffering = io.DEFAULT_BUFFER_SIZE -        if buffering == 0: -            return raw -        buffer = io.BufferedReader(raw, buffering) -        text = io.TextIOWrapper(buffer, None, None) -        text.mode = "rb" -        return text - -    def _test_readline(self, size=-1, buffering=-1): -        mock_sock = self.MockSocket(recv_funcs=[ -                lambda : b"This is the first line\nAnd the sec", -                self._raise_eintr, -                lambda : b"ond line is here\n", -                lambda : b"", -                lambda : b"",  # XXX(gps): io library does an extra EOF read -            ]) -        fo = mock_sock._textiowrap_for_test(buffering=buffering) -        self.assertEqual(fo.readline(size), "This is the first line\n") -        self.assertEqual(fo.readline(size), "And the second line is here\n") - -    def _test_read(self, size=-1, buffering=-1): -        mock_sock = self.MockSocket(recv_funcs=[ -                lambda : b"This is the first line\nAnd the sec", -                self._raise_eintr, -                lambda : b"ond line is here\n", -                lambda : b"", -                lambda : b"",  # XXX(gps): io library does an extra EOF read -            ]) -        expecting = (b"This is the first line\n" -                     b"And the second line is here\n") -        fo = mock_sock._textiowrap_for_test(buffering=buffering) -        if buffering == 0: -            data = b'' -        else: -            data = '' -            expecting = expecting.decode('utf-8') -        while len(data) != len(expecting): -            part = fo.read(size) -            if not part: -                break -            data += part -        self.assertEqual(data, expecting) - -    def test_default(self): -        self._test_readline() -        self._test_readline(size=100) -        self._test_read() -        self._test_read(size=100) - -    def test_with_1k_buffer(self): -        self._test_readline(buffering=1024) -        self._test_readline(size=100, buffering=1024) -        self._test_read(buffering=1024) -        self._test_read(size=100, buffering=1024) - -    def _test_readline_no_buffer(self, size=-1): -        mock_sock = self.MockSocket(recv_funcs=[ -                lambda : b"a", -                lambda : b"\n", -                lambda : b"B", -                self._raise_eintr, -                lambda : b"b", -                lambda : b"", -            ]) -        fo = mock_sock._textiowrap_for_test(buffering=0) -        self.assertEqual(fo.readline(size), b"a\n") -        self.assertEqual(fo.readline(size), b"Bb") - -    def test_no_buffer(self): -        self._test_readline_no_buffer() -        self._test_readline_no_buffer(size=4) -        self._test_read(buffering=0) -        self._test_read(size=100, buffering=0) - -  class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase):      """Repeat the tests from FileObjectClassTestCase with bufsize==0. @@ -4588,7 +4476,7 @@ class TestLinuxAbstractNamespace(unittest.TestCase):          address = b"\x00python-test-hello\x00\xff"          with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1:              s1.bind(address) -            s1.listen(1) +            s1.listen()              with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2:                  s2.connect(s1.getsockname())                  with s1.accept()[0] as s3: @@ -4820,7 +4708,7 @@ class TIPCThreadableTest(unittest.TestCase, ThreadableTest):          srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE,                  TIPC_LOWER, TIPC_UPPER)          self.srv.bind(srvaddr) -        self.srv.listen(5) +        self.srv.listen()          self.serverExplicitReady()          self.conn, self.connaddr = self.srv.accept()          self.addCleanup(self.conn.close) @@ -5109,6 +4997,275 @@ class TestSocketSharing(SocketTCPTest):                      source.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileUsingSendTest(ThreadedTCPSocketTest): +    """ +    Test the send() implementation of socket.sendfile(). +    """ + +    FILESIZE = (10 * 1024 * 1024)  # 10MB +    BUFSIZE = 8192 +    FILEDATA = b"" +    TIMEOUT = 2 + +    @classmethod +    def setUpClass(cls): +        def chunks(total, step): +            assert total >= step +            while total > step: +                yield step +                total -= step +            if total: +                yield total + +        chunk = b"".join([random.choice(string.ascii_letters).encode() +                          for i in range(cls.BUFSIZE)]) +        with open(support.TESTFN, 'wb') as f: +            for csize in chunks(cls.FILESIZE, cls.BUFSIZE): +                f.write(chunk) +        with open(support.TESTFN, 'rb') as f: +            cls.FILEDATA = f.read() +            assert len(cls.FILEDATA) == cls.FILESIZE + +    @classmethod +    def tearDownClass(cls): +        support.unlink(support.TESTFN) + +    def accept_conn(self): +        self.serv.settimeout(self.TIMEOUT) +        conn, addr = self.serv.accept() +        conn.settimeout(self.TIMEOUT) +        self.addCleanup(conn.close) +        return conn + +    def recv_data(self, conn): +        received = [] +        while True: +            chunk = conn.recv(self.BUFSIZE) +            if not chunk: +                break +            received.append(chunk) +        return b''.join(received) + +    def meth_from_sock(self, sock): +        # Depending on the mixin class being run return either send() +        # or sendfile() method implementation. +        return getattr(sock, "_sendfile_use_send") + +    # regular file + +    def _testRegularFile(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address) as sock, file as file: +            meth = self.meth_from_sock(sock) +            sent = meth(file) +            self.assertEqual(sent, self.FILESIZE) +            self.assertEqual(file.tell(), self.FILESIZE) + +    def testRegularFile(self): +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(len(data), self.FILESIZE) +        self.assertEqual(data, self.FILEDATA) + +    # non regular file + +    def _testNonRegularFile(self): +        address = self.serv.getsockname() +        file = io.BytesIO(self.FILEDATA) +        with socket.create_connection(address) as sock, file as file: +            sent = sock.sendfile(file) +            self.assertEqual(sent, self.FILESIZE) +            self.assertEqual(file.tell(), self.FILESIZE) +            self.assertRaises(socket._GiveupOnSendfile, +                              sock._sendfile_use_sendfile, file) + +    def testNonRegularFile(self): +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(len(data), self.FILESIZE) +        self.assertEqual(data, self.FILEDATA) + +    # empty file + +    def _testEmptyFileSend(self): +        address = self.serv.getsockname() +        filename = support.TESTFN + "2" +        with open(filename, 'wb'): +            self.addCleanup(support.unlink, filename) +        file = open(filename, 'rb') +        with socket.create_connection(address) as sock, file as file: +            meth = self.meth_from_sock(sock) +            sent = meth(file) +            self.assertEqual(sent, 0) +            self.assertEqual(file.tell(), 0) + +    def testEmptyFileSend(self): +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(data, b"") + +    # offset + +    def _testOffset(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address) as sock, file as file: +            meth = self.meth_from_sock(sock) +            sent = meth(file, offset=5000) +            self.assertEqual(sent, self.FILESIZE - 5000) +            self.assertEqual(file.tell(), self.FILESIZE) + +    def testOffset(self): +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(len(data), self.FILESIZE - 5000) +        self.assertEqual(data, self.FILEDATA[5000:]) + +    # count + +    def _testCount(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address, timeout=2) as sock, file as file: +            count = 5000007 +            meth = self.meth_from_sock(sock) +            sent = meth(file, count=count) +            self.assertEqual(sent, count) +            self.assertEqual(file.tell(), count) + +    def testCount(self): +        count = 5000007 +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(len(data), count) +        self.assertEqual(data, self.FILEDATA[:count]) + +    # count small + +    def _testCountSmall(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address, timeout=2) as sock, file as file: +            count = 1 +            meth = self.meth_from_sock(sock) +            sent = meth(file, count=count) +            self.assertEqual(sent, count) +            self.assertEqual(file.tell(), count) + +    def testCountSmall(self): +        count = 1 +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(len(data), count) +        self.assertEqual(data, self.FILEDATA[:count]) + +    # count + offset + +    def _testCountWithOffset(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address, timeout=2) as sock, file as file: +            count = 100007 +            meth = self.meth_from_sock(sock) +            sent = meth(file, offset=2007, count=count) +            self.assertEqual(sent, count) +            self.assertEqual(file.tell(), count + 2007) + +    def testCountWithOffset(self): +        count = 100007 +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(len(data), count) +        self.assertEqual(data, self.FILEDATA[2007:count+2007]) + +    # non blocking sockets are not supposed to work + +    def _testNonBlocking(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address) as sock, file as file: +            sock.setblocking(False) +            meth = self.meth_from_sock(sock) +            self.assertRaises(ValueError, meth, file) +            self.assertRaises(ValueError, sock.sendfile, file) + +    def testNonBlocking(self): +        conn = self.accept_conn() +        if conn.recv(8192): +            self.fail('was not supposed to receive any data') + +    # timeout (non-triggered) + +    def _testWithTimeout(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address, timeout=2) as sock, file as file: +            meth = self.meth_from_sock(sock) +            sent = meth(file) +            self.assertEqual(sent, self.FILESIZE) + +    def testWithTimeout(self): +        conn = self.accept_conn() +        data = self.recv_data(conn) +        self.assertEqual(len(data), self.FILESIZE) +        self.assertEqual(data, self.FILEDATA) + +    # timeout (triggered) + +    def _testWithTimeoutTriggeredSend(self): +        address = self.serv.getsockname() +        file = open(support.TESTFN, 'rb') +        with socket.create_connection(address, timeout=0.01) as sock, \ +                file as file: +            meth = self.meth_from_sock(sock) +            self.assertRaises(socket.timeout, meth, file) + +    def testWithTimeoutTriggeredSend(self): +        conn = self.accept_conn() +        conn.recv(88192) + +    # errors + +    def _test_errors(self): +        pass + +    def test_errors(self): +        with open(support.TESTFN, 'rb') as file: +            with socket.socket(type=socket.SOCK_DGRAM) as s: +                meth = self.meth_from_sock(s) +                self.assertRaisesRegex( +                    ValueError, "SOCK_STREAM", meth, file) +        with open(support.TESTFN, 'rt') as file: +            with socket.socket() as s: +                meth = self.meth_from_sock(s) +                self.assertRaisesRegex( +                    ValueError, "binary mode", meth, file) +        with open(support.TESTFN, 'rb') as file: +            with socket.socket() as s: +                meth = self.meth_from_sock(s) +                self.assertRaisesRegex(TypeError, "positive integer", +                                       meth, file, count='2') +                self.assertRaisesRegex(TypeError, "positive integer", +                                       meth, file, count=0.1) +                self.assertRaisesRegex(ValueError, "positive integer", +                                       meth, file, count=0) +                self.assertRaisesRegex(ValueError, "positive integer", +                                       meth, file, count=-1) + + +@unittest.skipUnless(thread, 'Threading required for this test.') +@unittest.skipUnless(hasattr(os, "sendfile"), +                     'os.sendfile() required for this test.') +class SendfileUsingSendfileTest(SendfileUsingSendTest): +    """ +    Test the sendfile() implementation of socket.sendfile(). +    """ +    def meth_from_sock(self, sock): +        return getattr(sock, "_sendfile_use_sendfile") + +  def test_main():      tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,               TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -5116,7 +5273,6 @@ def test_main():      tests.extend([          NonBlockingTCPTests,          FileObjectClassTestCase, -        FileObjectInterruptedTestCase,          UnbufferedFileObjectClassTestCase,          LineBufferedFileObjectClassTestCase,          SmallBufferedFileObjectClassTestCase, @@ -5161,6 +5317,8 @@ def test_main():          InterruptedRecvTimeoutTest,          InterruptedSendTimeoutTest,          TestSocketSharing, +        SendfileUsingSendTest, +        SendfileUsingSendfileTest,      ])      thread_info = support.threading_setup()  | 
