diff options
Diffstat (limited to 'Lib/test/test_socket.py')
-rw-r--r-- | Lib/test/test_socket.py | 282 |
1 files changed, 279 insertions, 3 deletions
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index b1ce058..0712a4f 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -19,6 +19,8 @@ import signal import math import pickle import struct +import random +import string try: import multiprocessing except ImportError: @@ -1344,10 +1346,13 @@ class GeneralModuleTests(unittest.TestCase): def test_listen_backlog(self): for backlog in 0, -1: - srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + srv.bind((HOST, 0)) + srv.listen(backlog) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: srv.bind((HOST, 0)) - srv.listen(backlog) - srv.close() + srv.listen() @support.cpython_only def test_listen_backlog_overflow(self): @@ -5075,6 +5080,275 @@ class TestSocketSharing(SocketTCPTest): source.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileUsingSendTest(ThreadedTCPSocketTest): + """ + Test the send() implementation of socket.sendfile(). + """ + + FILESIZE = (10 * 1024 * 1024) # 10MB + BUFSIZE = 8192 + FILEDATA = b"" + TIMEOUT = 2 + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step + total -= step + if total: + yield total + + chunk = b"".join([random.choice(string.ascii_letters).encode() + for i in range(cls.BUFSIZE)]) + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, cls.BUFSIZE): + f.write(chunk) + with open(support.TESTFN, 'rb') as f: + cls.FILEDATA = f.read() + assert len(cls.FILEDATA) == cls.FILESIZE + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def accept_conn(self): + self.serv.settimeout(self.TIMEOUT) + conn, addr = self.serv.accept() + conn.settimeout(self.TIMEOUT) + self.addCleanup(conn.close) + return conn + + def recv_data(self, conn): + received = [] + while True: + chunk = conn.recv(self.BUFSIZE) + if not chunk: + break + received.append(chunk) + return b''.join(received) + + def meth_from_sock(self, sock): + # Depending on the mixin class being run return either send() + # or sendfile() method implementation. + return getattr(sock, "_sendfile_use_send") + + # regular file + + def _testRegularFile(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + + def testRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(self.FILEDATA) + with socket.create_connection(address) as sock, file as file: + sent = sock.sendfile(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + self.assertRaises(socket._GiveupOnSendfile, + sock._sendfile_use_sendfile, file) + + def testNonRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # empty file + + def _testEmptyFileSend(self): + address = self.serv.getsockname() + filename = support.TESTFN + "2" + with open(filename, 'wb'): + self.addCleanup(support.unlink, filename) + file = open(filename, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, 0) + self.assertEqual(file.tell(), 0) + + def testEmptyFileSend(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(data, b"") + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file, offset=5000) + self.assertEqual(sent, self.FILESIZE - 5000) + self.assertEqual(file.tell(), self.FILESIZE) + + def testOffset(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE - 5000) + self.assertEqual(data, self.FILEDATA[5000:]) + + # count + + def _testCount(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 5000007 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCount(self): + count = 5000007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count small + + def _testCountSmall(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 1 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCountSmall(self): + count = 1 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count + offset + + def _testCountWithOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 100007 + meth = self.meth_from_sock(sock) + sent = meth(file, offset=2007, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count + 2007) + + def testCountWithOffset(self): + count = 100007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[2007:count+2007]) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + meth = self.meth_from_sock(sock) + self.assertRaises(ValueError, meth, file) + self.assertRaises(ValueError, sock.sendfile, file) + + def testNonBlocking(self): + conn = self.accept_conn() + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + + def testWithTimeout(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # timeout (triggered) + + def _testWithTimeoutTriggeredSend(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + meth = self.meth_from_sock(sock) + self.assertRaises(socket.timeout, meth, file) + + def testWithTimeoutTriggeredSend(self): + conn = self.accept_conn() + conn.recv(88192) + + # errors + + def _test_errors(self): + pass + + def test_errors(self): + with open(support.TESTFN, 'rb') as file: + with socket.socket(type=socket.SOCK_DGRAM) as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "SOCK_STREAM", meth, file) + with open(support.TESTFN, 'rt') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "binary mode", meth, file) + with open(support.TESTFN, 'rb') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count='2') + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count=0.1) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=0) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=-1) + + +@unittest.skipUnless(thread, 'Threading required for this test.') +@unittest.skipUnless(hasattr(os, "sendfile"), + 'os.sendfile() required for this test.') +class SendfileUsingSendfileTest(SendfileUsingSendTest): + """ + Test the sendfile() implementation of socket.sendfile(). + """ + def meth_from_sock(self, sock): + return getattr(sock, "_sendfile_use_sendfile") + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -5127,6 +5401,8 @@ def test_main(): InterruptedRecvTimeoutTest, InterruptedSendTimeoutTest, TestSocketSharing, + SendfileUsingSendTest, + SendfileUsingSendfileTest, ]) thread_info = support.threading_setup() |