diff options
Diffstat (limited to 'Lib/test')
381 files changed, 20341 insertions, 4719 deletions
diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 204d894..e9120ab 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -19,7 +19,7 @@ import logging import struct import operator import test.support -import test.script_helper +import test.support.script_helper # Skip tests if _multiprocessing wasn't built. @@ -2624,7 +2624,7 @@ class _TestPicklingConnections(BaseTestCase): l = socket.socket() l.bind((test.support.HOST, 0)) - l.listen(1) + l.listen() conn.send(l.getsockname()) new_conn, addr = l.accept() conn.send(new_conn) @@ -3271,7 +3271,7 @@ class TestWait(unittest.TestCase): from multiprocessing.connection import wait l = socket.socket() l.bind((test.support.HOST, 0)) - l.listen(4) + l.listen() addr = l.getsockname() readers = [] procs = [] @@ -3483,11 +3483,11 @@ class TestNoForkBomb(unittest.TestCase): sm = multiprocessing.get_start_method() name = os.path.join(os.path.dirname(__file__), 'mp_fork_bomb.py') if sm != 'fork': - rc, out, err = test.script_helper.assert_python_failure(name, sm) + rc, out, err = test.support.script_helper.assert_python_failure(name, sm) self.assertEqual(out, b'') self.assertIn(b'RuntimeError', err) else: - rc, out, err = test.script_helper.assert_python_ok(name, sm) + rc, out, err = test.support.script_helper.assert_python_ok(name, sm) self.assertEqual(out.rstrip(), b'123') self.assertEqual(err, b'') diff --git a/Lib/test/badsyntax_async1.py b/Lib/test/badsyntax_async1.py new file mode 100644 index 0000000..fb85e29 --- /dev/null +++ b/Lib/test/badsyntax_async1.py @@ -0,0 +1,2 @@ +async def foo(a=await something()): + pass diff --git a/Lib/test/badsyntax_async2.py b/Lib/test/badsyntax_async2.py new file mode 100644 index 0000000..fb85e29 --- /dev/null +++ b/Lib/test/badsyntax_async2.py @@ -0,0 +1,2 @@ +async def foo(a=await something()): + pass diff --git a/Lib/test/badsyntax_async3.py b/Lib/test/badsyntax_async3.py new file mode 100644 index 0000000..dde1bc5 --- /dev/null +++ b/Lib/test/badsyntax_async3.py @@ -0,0 +1,2 @@ +async def foo(): + [i async for i in els] diff --git a/Lib/test/badsyntax_async4.py b/Lib/test/badsyntax_async4.py new file mode 100644 index 0000000..d033b28 --- /dev/null +++ b/Lib/test/badsyntax_async4.py @@ -0,0 +1,2 @@ +async def foo(): + await diff --git a/Lib/test/badsyntax_async5.py b/Lib/test/badsyntax_async5.py new file mode 100644 index 0000000..9d19af6 --- /dev/null +++ b/Lib/test/badsyntax_async5.py @@ -0,0 +1,2 @@ +def foo(): + await something() diff --git a/Lib/test/badsyntax_async6.py b/Lib/test/badsyntax_async6.py new file mode 100644 index 0000000..cb0a23d --- /dev/null +++ b/Lib/test/badsyntax_async6.py @@ -0,0 +1,2 @@ +async def foo(): + yield diff --git a/Lib/test/badsyntax_async7.py b/Lib/test/badsyntax_async7.py new file mode 100644 index 0000000..51e4bf9 --- /dev/null +++ b/Lib/test/badsyntax_async7.py @@ -0,0 +1,2 @@ +async def foo(): + yield from [] diff --git a/Lib/test/badsyntax_async8.py b/Lib/test/badsyntax_async8.py new file mode 100644 index 0000000..3c636f9 --- /dev/null +++ b/Lib/test/badsyntax_async8.py @@ -0,0 +1,2 @@ +async def foo(): + await await fut diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py index 3d50fc1..63c3ae8 100644 --- a/Lib/test/datetimetester.py +++ b/Lib/test/datetimetester.py @@ -3,6 +3,7 @@ See http://www.zope.org/Members/fdrake/DateTimeWiki/TestCases """ +import decimal import sys import pickle import random @@ -50,6 +51,17 @@ class TestModule(unittest.TestCase): self.assertEqual(datetime.MINYEAR, 1) self.assertEqual(datetime.MAXYEAR, 9999) + def test_name_cleanup(self): + if '_Fast' not in str(self): + return + datetime = datetime_module + names = set(name for name in dir(datetime) + if not name.startswith('__') and not name.endswith('__')) + allowed = set(['MAXYEAR', 'MINYEAR', 'date', 'datetime', + 'datetime_CAPI', 'time', 'timedelta', 'timezone', + 'tzinfo']) + self.assertEqual(names - allowed, set([])) + def test_divide_and_round(self): if '_Fast' in str(self): return @@ -1196,11 +1208,13 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase): #check that this standard extension works t.strftime("%f") - def test_format(self): dt = self.theclass(2007, 9, 10) self.assertEqual(dt.__format__(''), str(dt)) + with self.assertRaisesRegex(TypeError, '^must be str, not int$'): + dt.__format__(123) + # check that a derived class's __str__() gets called class A(self.theclass): def __str__(self): @@ -1352,8 +1366,6 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase): return isinstance(other, LargerThanAnything) def __eq__(self, other): return isinstance(other, LargerThanAnything) - def __ne__(self, other): - return not isinstance(other, LargerThanAnything) def __gt__(self, other): return not isinstance(other, LargerThanAnything) def __ge__(self, other): @@ -1456,9 +1468,10 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase): for month_byte in b'9', b'\0', b'\r', b'\xff': self.assertRaises(TypeError, self.theclass, base[:2] + month_byte + base[3:]) - # Good bytes, but bad tzinfo: - self.assertRaises(TypeError, self.theclass, - bytes([1] * len(base)), 'EST') + if issubclass(self.theclass, datetime): + # Good bytes, but bad tzinfo: + with self.assertRaisesRegex(TypeError, '^bad tzinfo state arg$'): + self.theclass(bytes([1] * len(base)), 'EST') for ord_byte in range(1, 13): # This shouldn't blow up because of the month byte alone. If @@ -1534,6 +1547,9 @@ class TestDateTime(TestDate): dt = self.theclass(2007, 9, 10, 4, 5, 1, 123) self.assertEqual(dt.__format__(''), str(dt)) + with self.assertRaisesRegex(TypeError, '^must be str, not int$'): + dt.__format__(123) + # check that a derived class's __str__() gets called class A(self.theclass): def __str__(self): @@ -1913,6 +1929,7 @@ class TestDateTime(TestDate): for insane in -1e200, 1e200: self.assertRaises(OverflowError, self.theclass.utcfromtimestamp, insane) + @unittest.skipIf(sys.platform == "win32", "Windows doesn't accept negative timestamps") def test_negative_float_fromtimestamp(self): # The result is tz-dependent; at least test that this doesn't @@ -2292,6 +2309,9 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase): t = self.theclass(1, 2, 3, 4) self.assertEqual(t.__format__(''), str(t)) + with self.assertRaisesRegex(TypeError, '^must be str, not int$'): + t.__format__(123) + # check that a derived class's __str__() gets called class A(self.theclass): def __str__(self): @@ -2355,13 +2375,14 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase): self.assertEqual(orig, derived) def test_bool(self): + # time is always True. cls = self.theclass self.assertTrue(cls(1)) self.assertTrue(cls(0, 1)) self.assertTrue(cls(0, 0, 1)) self.assertTrue(cls(0, 0, 0, 1)) - self.assertFalse(cls(0)) - self.assertFalse(cls()) + self.assertTrue(cls(0)) + self.assertTrue(cls()) def test_replace(self): cls = self.theclass @@ -2420,6 +2441,9 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase): for hour_byte in ' ', '9', chr(24), '\xff': self.assertRaises(TypeError, self.theclass, hour_byte + base[1:]) + # Good bytes, but bad tzinfo: + with self.assertRaisesRegex(TypeError, '^bad tzinfo state arg$'): + self.theclass(bytes([1] * len(base)), 'EST') # A mixin for classes with a tzinfo= argument. Subclasses must define # theclass as a class atribute, and theclass(1, 1, 1, tzinfo=whatever) @@ -2679,7 +2703,7 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase): self.assertRaises(TypeError, t.strftime, "%Z") # Issue #6697: - if '_Fast' in str(type(self)): + if '_Fast' in str(self): Badtzname.tz = '\ud800' self.assertRaises(ValueError, t.strftime, "%Z") @@ -2714,7 +2738,7 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase): self.assertEqual(derived.tzname(), 'cookie') def test_more_bool(self): - # Test cases with non-None tzinfo. + # time is always True. cls = self.theclass t = cls(0, tzinfo=FixedOffset(-300, "")) @@ -2724,23 +2748,11 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase): self.assertTrue(t) t = cls(5, tzinfo=FixedOffset(300, "")) - self.assertFalse(t) + self.assertTrue(t) t = cls(23, 59, tzinfo=FixedOffset(23*60 + 59, "")) - self.assertFalse(t) - - # Mostly ensuring this doesn't overflow internally. - t = cls(0, tzinfo=FixedOffset(23*60 + 59, "")) self.assertTrue(t) - # But this should yield a value error -- the utcoffset is bogus. - t = cls(0, tzinfo=FixedOffset(24*60, "")) - self.assertRaises(ValueError, lambda: bool(t)) - - # Likewise. - t = cls(0, tzinfo=FixedOffset(-24*60, "")) - self.assertRaises(ValueError, lambda: bool(t)) - def test_replace(self): cls = self.theclass z100 = FixedOffset(100, "+100") @@ -3853,8 +3865,59 @@ class Oddballs(unittest.TestCase): self.assertEqual(as_datetime, datetime_sc) self.assertEqual(datetime_sc, as_datetime) -def test_main(): - support.run_unittest(__name__) + def test_extra_attributes(self): + for x in [date.today(), + time(), + datetime.utcnow(), + timedelta(), + tzinfo(), + timezone(timedelta())]: + with self.assertRaises(AttributeError): + x.abc = 1 + + def test_check_arg_types(self): + class Number: + def __init__(self, value): + self.value = value + def __int__(self): + return self.value + + for xx in [decimal.Decimal(10), + decimal.Decimal('10.9'), + Number(10)]: + self.assertEqual(datetime(10, 10, 10, 10, 10, 10, 10), + datetime(xx, xx, xx, xx, xx, xx, xx)) + + with self.assertRaisesRegex(TypeError, '^an integer is required ' + '\(got type str\)$'): + datetime(10, 10, '10') + + f10 = Number(10.9) + with self.assertRaisesRegex(TypeError, '^__int__ returned non-int ' + '\(type float\)$'): + datetime(10, 10, f10) + + class Float(float): + pass + s10 = Float(10.9) + with self.assertRaisesRegex(TypeError, '^integer argument expected, ' + 'got float$'): + datetime(10, 10, s10) + + with self.assertRaises(TypeError): + datetime(10., 10, 10) + with self.assertRaises(TypeError): + datetime(10, 10., 10) + with self.assertRaises(TypeError): + datetime(10, 10, 10.) + with self.assertRaises(TypeError): + datetime(10, 10, 10, 10.) + with self.assertRaises(TypeError): + datetime(10, 10, 10, 10, 10.) + with self.assertRaises(TypeError): + datetime(10, 10, 10, 10, 10, 10.) + with self.assertRaises(TypeError): + datetime(10, 10, 10, 10, 10, 10, 10.) if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/eintrdata/eintr_tester.py b/Lib/test/eintrdata/eintr_tester.py new file mode 100644 index 0000000..e1e0d91 --- /dev/null +++ b/Lib/test/eintrdata/eintr_tester.py @@ -0,0 +1,469 @@ +""" +This test suite exercises some system calls subject to interruption with EINTR, +to check that it is actually handled transparently. +It is intended to be run by the main test suite within a child process, to +ensure there is no background thread running (so that signals are delivered to +the correct thread). +Signals are generated in-process using setitimer(ITIMER_REAL), which allows +sub-second periodicity (contrarily to signal()). +""" + +import contextlib +import io +import os +import select +import signal +import socket +import subprocess +import sys +import time +import unittest + +from test import support + +@contextlib.contextmanager +def kill_on_error(proc): + """Context manager killing the subprocess if a Python exception is raised.""" + with proc: + try: + yield proc + except: + proc.kill() + raise + + +@unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") +class EINTRBaseTest(unittest.TestCase): + """ Base class for EINTR tests. """ + + # delay for initial signal delivery + signal_delay = 0.1 + # signal delivery periodicity + signal_period = 0.1 + # default sleep time for tests - should obviously have: + # sleep_time > signal_period + sleep_time = 0.2 + + @classmethod + def setUpClass(cls): + cls.orig_handler = signal.signal(signal.SIGALRM, lambda *args: None) + signal.setitimer(signal.ITIMER_REAL, cls.signal_delay, + cls.signal_period) + + @classmethod + def stop_alarm(cls): + signal.setitimer(signal.ITIMER_REAL, 0, 0) + + @classmethod + def tearDownClass(cls): + cls.stop_alarm() + signal.signal(signal.SIGALRM, cls.orig_handler) + + @classmethod + def _sleep(cls): + # default sleep time + time.sleep(cls.sleep_time) + + def subprocess(self, *args, **kw): + cmd_args = (sys.executable, '-c') + args + return subprocess.Popen(cmd_args, **kw) + + +@unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") +class OSEINTRTest(EINTRBaseTest): + """ EINTR tests for the os module. """ + + def new_sleep_process(self): + code = 'import time; time.sleep(%r)' % self.sleep_time + return self.subprocess(code) + + def _test_wait_multiple(self, wait_func): + num = 3 + processes = [self.new_sleep_process() for _ in range(num)] + for _ in range(num): + wait_func() + + def test_wait(self): + self._test_wait_multiple(os.wait) + + @unittest.skipUnless(hasattr(os, 'wait3'), 'requires wait3()') + def test_wait3(self): + self._test_wait_multiple(lambda: os.wait3(0)) + + def _test_wait_single(self, wait_func): + proc = self.new_sleep_process() + wait_func(proc.pid) + + def test_waitpid(self): + self._test_wait_single(lambda pid: os.waitpid(pid, 0)) + + @unittest.skipUnless(hasattr(os, 'wait4'), 'requires wait4()') + def test_wait4(self): + self._test_wait_single(lambda pid: os.wait4(pid, 0)) + + def test_read(self): + rd, wr = os.pipe() + self.addCleanup(os.close, rd) + # wr closed explicitly by parent + + # the payload below are smaller than PIPE_BUF, hence the writes will be + # atomic + datas = [b"hello", b"world", b"spam"] + + code = '\n'.join(( + 'import os, sys, time', + '', + 'wr = int(sys.argv[1])', + 'datas = %r' % datas, + 'sleep_time = %r' % self.sleep_time, + '', + 'for data in datas:', + ' # let the parent block on read()', + ' time.sleep(sleep_time)', + ' os.write(wr, data)', + )) + + proc = self.subprocess(code, str(wr), pass_fds=[wr]) + with kill_on_error(proc): + os.close(wr) + for data in datas: + self.assertEqual(data, os.read(rd, len(data))) + self.assertEqual(proc.wait(), 0) + + def test_write(self): + rd, wr = os.pipe() + self.addCleanup(os.close, wr) + # rd closed explicitly by parent + + # we must write enough data for the write() to block + data = b"x" * support.PIPE_MAX_SIZE + + code = '\n'.join(( + 'import io, os, sys, time', + '', + 'rd = int(sys.argv[1])', + 'sleep_time = %r' % self.sleep_time, + 'data = b"x" * %s' % support.PIPE_MAX_SIZE, + 'data_len = len(data)', + '', + '# let the parent block on write()', + 'time.sleep(sleep_time)', + '', + 'read_data = io.BytesIO()', + 'while len(read_data.getvalue()) < data_len:', + ' chunk = os.read(rd, 2 * data_len)', + ' read_data.write(chunk)', + '', + 'value = read_data.getvalue()', + 'if value != data:', + ' raise Exception("read error: %s vs %s bytes"', + ' % (len(value), data_len))', + )) + + proc = self.subprocess(code, str(rd), pass_fds=[rd]) + with kill_on_error(proc): + os.close(rd) + written = 0 + while written < len(data): + written += os.write(wr, memoryview(data)[written:]) + self.assertEqual(proc.wait(), 0) + + +@unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") +class SocketEINTRTest(EINTRBaseTest): + """ EINTR tests for the socket module. """ + + @unittest.skipUnless(hasattr(socket, 'socketpair'), 'needs socketpair()') + def _test_recv(self, recv_func): + rd, wr = socket.socketpair() + self.addCleanup(rd.close) + # wr closed explicitly by parent + + # single-byte payload guard us against partial recv + datas = [b"x", b"y", b"z"] + + code = '\n'.join(( + 'import os, socket, sys, time', + '', + 'fd = int(sys.argv[1])', + 'family = %s' % int(wr.family), + 'sock_type = %s' % int(wr.type), + 'datas = %r' % datas, + 'sleep_time = %r' % self.sleep_time, + '', + 'wr = socket.fromfd(fd, family, sock_type)', + 'os.close(fd)', + '', + 'with wr:', + ' for data in datas:', + ' # let the parent block on recv()', + ' time.sleep(sleep_time)', + ' wr.sendall(data)', + )) + + fd = wr.fileno() + proc = self.subprocess(code, str(fd), pass_fds=[fd]) + with kill_on_error(proc): + wr.close() + for data in datas: + self.assertEqual(data, recv_func(rd, len(data))) + self.assertEqual(proc.wait(), 0) + + def test_recv(self): + self._test_recv(socket.socket.recv) + + @unittest.skipUnless(hasattr(socket.socket, 'recvmsg'), 'needs recvmsg()') + def test_recvmsg(self): + self._test_recv(lambda sock, data: sock.recvmsg(data)[0]) + + def _test_send(self, send_func): + rd, wr = socket.socketpair() + self.addCleanup(wr.close) + # rd closed explicitly by parent + + # we must send enough data for the send() to block + data = b"xyz" * (support.SOCK_MAX_SIZE // 3) + + code = '\n'.join(( + 'import os, socket, sys, time', + '', + 'fd = int(sys.argv[1])', + 'family = %s' % int(rd.family), + 'sock_type = %s' % int(rd.type), + 'sleep_time = %r' % self.sleep_time, + 'data = b"xyz" * %s' % (support.SOCK_MAX_SIZE // 3), + 'data_len = len(data)', + '', + 'rd = socket.fromfd(fd, family, sock_type)', + 'os.close(fd)', + '', + 'with rd:', + ' # let the parent block on send()', + ' time.sleep(sleep_time)', + '', + ' received_data = bytearray(data_len)', + ' n = 0', + ' while n < data_len:', + ' n += rd.recv_into(memoryview(received_data)[n:])', + '', + 'if received_data != data:', + ' raise Exception("recv error: %s vs %s bytes"', + ' % (len(received_data), data_len))', + )) + + fd = rd.fileno() + proc = self.subprocess(code, str(fd), pass_fds=[fd]) + with kill_on_error(proc): + rd.close() + written = 0 + while written < len(data): + sent = send_func(wr, memoryview(data)[written:]) + # sendall() returns None + written += len(data) if sent is None else sent + self.assertEqual(proc.wait(), 0) + + def test_send(self): + self._test_send(socket.socket.send) + + def test_sendall(self): + self._test_send(socket.socket.sendall) + + @unittest.skipUnless(hasattr(socket.socket, 'sendmsg'), 'needs sendmsg()') + def test_sendmsg(self): + self._test_send(lambda sock, data: sock.sendmsg([data])) + + def test_accept(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.addCleanup(sock.close) + + sock.bind((support.HOST, 0)) + port = sock.getsockname()[1] + sock.listen() + + code = '\n'.join(( + 'import socket, time', + '', + 'host = %r' % support.HOST, + 'port = %s' % port, + 'sleep_time = %r' % self.sleep_time, + '', + '# let parent block on accept()', + 'time.sleep(sleep_time)', + 'with socket.create_connection((host, port)):', + ' time.sleep(sleep_time)', + )) + + proc = self.subprocess(code) + with kill_on_error(proc): + client_sock, _ = sock.accept() + client_sock.close() + self.assertEqual(proc.wait(), 0) + + # Issue #25122: There is a race condition in the FreeBSD kernel on + # handling signals in the FIFO device. Skip the test until the bug is + # fixed in the kernel. + # https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=203162 + @support.requires_freebsd_version(10, 3) + @unittest.skipUnless(hasattr(os, 'mkfifo'), 'needs mkfifo()') + def _test_open(self, do_open_close_reader, do_open_close_writer): + filename = support.TESTFN + + # Use a fifo: until the child opens it for reading, the parent will + # block when trying to open it for writing. + support.unlink(filename) + os.mkfifo(filename) + self.addCleanup(support.unlink, filename) + + code = '\n'.join(( + 'import os, time', + '', + 'path = %a' % filename, + 'sleep_time = %r' % self.sleep_time, + '', + '# let the parent block', + 'time.sleep(sleep_time)', + '', + do_open_close_reader, + )) + + proc = self.subprocess(code) + with kill_on_error(proc): + do_open_close_writer(filename) + self.assertEqual(proc.wait(), 0) + + def python_open(self, path): + fp = open(path, 'w') + fp.close() + + def test_open(self): + self._test_open("fp = open(path, 'r')\nfp.close()", + self.python_open) + + def os_open(self, path): + fd = os.open(path, os.O_WRONLY) + os.close(fd) + + def test_os_open(self): + self._test_open("fd = os.open(path, os.O_RDONLY)\nos.close(fd)", + self.os_open) + + +@unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") +class TimeEINTRTest(EINTRBaseTest): + """ EINTR tests for the time module. """ + + def test_sleep(self): + t0 = time.monotonic() + time.sleep(self.sleep_time) + self.stop_alarm() + dt = time.monotonic() - t0 + self.assertGreaterEqual(dt, self.sleep_time) + + +@unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") +class SignalEINTRTest(EINTRBaseTest): + """ EINTR tests for the signal module. """ + + @unittest.skipUnless(hasattr(signal, 'sigtimedwait'), + 'need signal.sigtimedwait()') + def test_sigtimedwait(self): + t0 = time.monotonic() + signal.sigtimedwait([signal.SIGUSR1], self.sleep_time) + dt = time.monotonic() - t0 + self.assertGreaterEqual(dt, self.sleep_time) + + @unittest.skipUnless(hasattr(signal, 'sigwaitinfo'), + 'need signal.sigwaitinfo()') + def test_sigwaitinfo(self): + signum = signal.SIGUSR1 + pid = os.getpid() + + old_handler = signal.signal(signum, lambda *args: None) + self.addCleanup(signal.signal, signum, old_handler) + + code = '\n'.join(( + 'import os, time', + 'pid = %s' % os.getpid(), + 'signum = %s' % int(signum), + 'sleep_time = %r' % self.sleep_time, + 'time.sleep(sleep_time)', + 'os.kill(pid, signum)', + )) + + t0 = time.monotonic() + proc = self.subprocess(code) + with kill_on_error(proc): + # parent + signal.sigwaitinfo([signum]) + dt = time.monotonic() - t0 + self.assertEqual(proc.wait(), 0) + + self.assertGreaterEqual(dt, self.sleep_time) + + +@unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") +class SelectEINTRTest(EINTRBaseTest): + """ EINTR tests for the select module. """ + + def test_select(self): + t0 = time.monotonic() + select.select([], [], [], self.sleep_time) + dt = time.monotonic() - t0 + self.stop_alarm() + self.assertGreaterEqual(dt, self.sleep_time) + + @unittest.skipUnless(hasattr(select, 'poll'), 'need select.poll') + def test_poll(self): + poller = select.poll() + + t0 = time.monotonic() + poller.poll(self.sleep_time * 1e3) + dt = time.monotonic() - t0 + self.stop_alarm() + self.assertGreaterEqual(dt, self.sleep_time) + + @unittest.skipUnless(hasattr(select, 'epoll'), 'need select.epoll') + def test_epoll(self): + poller = select.epoll() + self.addCleanup(poller.close) + + t0 = time.monotonic() + poller.poll(self.sleep_time) + dt = time.monotonic() - t0 + self.stop_alarm() + self.assertGreaterEqual(dt, self.sleep_time) + + @unittest.skipUnless(hasattr(select, 'kqueue'), 'need select.kqueue') + def test_kqueue(self): + kqueue = select.kqueue() + self.addCleanup(kqueue.close) + + t0 = time.monotonic() + kqueue.control(None, 1, self.sleep_time) + dt = time.monotonic() - t0 + self.stop_alarm() + self.assertGreaterEqual(dt, self.sleep_time) + + @unittest.skipUnless(hasattr(select, 'devpoll'), 'need select.devpoll') + def test_devpoll(self): + poller = select.devpoll() + self.addCleanup(poller.close) + + t0 = time.monotonic() + poller.poll(self.sleep_time * 1e3) + dt = time.monotonic() - t0 + self.stop_alarm() + self.assertGreaterEqual(dt, self.sleep_time) + + +def test_main(): + support.run_unittest( + OSEINTRTest, + SocketEINTRTest, + TimeEINTRTest, + SignalEINTRTest, + SelectEINTRTest) + + +if __name__ == "__main__": + test_main() diff --git a/Lib/test/exception_hierarchy.txt b/Lib/test/exception_hierarchy.txt index 1c1f69f..0513765 100644 --- a/Lib/test/exception_hierarchy.txt +++ b/Lib/test/exception_hierarchy.txt @@ -4,6 +4,7 @@ BaseException +-- GeneratorExit +-- Exception +-- StopIteration + +-- StopAsyncIteration +-- ArithmeticError | +-- FloatingPointError | +-- OverflowError @@ -38,6 +39,7 @@ BaseException +-- ReferenceError +-- RuntimeError | +-- NotImplementedError + | +-- RecursionError +-- SyntaxError | +-- IndentationError | +-- TabError diff --git a/Lib/test/fork_wait.py b/Lib/test/fork_wait.py index 19b54ec..713039d 100644 --- a/Lib/test/fork_wait.py +++ b/Lib/test/fork_wait.py @@ -48,7 +48,12 @@ class ForkWait(unittest.TestCase): for i in range(NUM_THREADS): _thread.start_new(self.f, (i,)) - time.sleep(LONGSLEEP) + # busy-loop to wait for threads + deadline = time.monotonic() + 10.0 + while len(self.alive) < NUM_THREADS: + time.sleep(0.1) + if deadline < time.monotonic(): + break a = sorted(self.alive.keys()) self.assertEqual(a, list(range(NUM_THREADS))) diff --git a/Lib/test/imghdrdata/python.exr b/Lib/test/imghdrdata/python.exr Binary files differnew file mode 100644 index 0000000..773c81e --- /dev/null +++ b/Lib/test/imghdrdata/python.exr diff --git a/Lib/test/imghdrdata/python.webp b/Lib/test/imghdrdata/python.webp Binary files differnew file mode 100644 index 0000000..e824ec7 --- /dev/null +++ b/Lib/test/imghdrdata/python.webp diff --git a/Lib/test/imp_dummy.py b/Lib/test/imp_dummy.py new file mode 100644 index 0000000..2a4deb4 --- /dev/null +++ b/Lib/test/imp_dummy.py @@ -0,0 +1,3 @@ +# Fodder for test of issue24748 in test_imp + +dummy_name = True diff --git a/Lib/test/inspect_fodder.py b/Lib/test/inspect_fodder.py index 0c1d810..068d825 100644 --- a/Lib/test/inspect_fodder.py +++ b/Lib/test/inspect_fodder.py @@ -45,9 +45,16 @@ class StupidGit: self.ex = sys.exc_info() self.tr = inspect.trace() + def contradiction(self): + 'The automatic gainsaying.' + pass + # line 48 class MalodorousPervert(StupidGit): - pass + def abuse(self, a, b, c): + pass + def contradiction(self): + pass Tit = MalodorousPervert @@ -55,4 +62,10 @@ class ParrotDroppings: pass class FesteringGob(MalodorousPervert, ParrotDroppings): + def abuse(self, a, b, c): + pass + def contradiction(self): + pass + +async def lobbest(grenade): pass diff --git a/Lib/test/inspect_fodder2.py b/Lib/test/inspect_fodder2.py index bd7106f..c6987ea 100644 --- a/Lib/test/inspect_fodder2.py +++ b/Lib/test/inspect_fodder2.py @@ -109,3 +109,31 @@ def annotated(arg1: list): #line 109 def keyword_only_arg(*, arg): pass + +@wrap(lambda: None) +def func114(): + return 115 + +class ClassWithMethod: + def method(self): + pass + +from functools import wraps + +def decorator(func): + @wraps(func) + def fake(): + return 42 + return fake + +#line 129 +@decorator +def real(): + return 20 + +#line 134 +class cls135: + def func136(): + def func137(): + never_reached1 + never_reached2 diff --git a/Lib/test/list_tests.py b/Lib/test/list_tests.py index 42e118b..1adfc75 100644 --- a/Lib/test/list_tests.py +++ b/Lib/test/list_tests.py @@ -30,6 +30,12 @@ class CommonTest(seq_tests.CommonTest): self.assertNotEqual(id(a), id(b)) self.assertEqual(a, b) + def test_getitem_error(self): + msg = "list indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + a = [] + a['a'] = "python" + def test_repr(self): l0 = [] l2 = [0, 1, 2] @@ -50,7 +56,7 @@ class CommonTest(seq_tests.CommonTest): l0 = [] for i in range(sys.getrecursionlimit() + 100): l0 = [l0] - self.assertRaises(RuntimeError, repr, l0) + self.assertRaises(RecursionError, repr, l0) def test_print(self): d = self.type2test(range(200)) @@ -120,6 +126,10 @@ class CommonTest(seq_tests.CommonTest): a[-1] = 9 self.assertEqual(a, self.type2test([5,6,7,8,9])) + msg = "list indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + a['a'] = "python" + def test_delitem(self): a = self.type2test([0, 1]) del a[1] diff --git a/Lib/test/lock_tests.py b/Lib/test/lock_tests.py index 42a7d82..b325bce 100644 --- a/Lib/test/lock_tests.py +++ b/Lib/test/lock_tests.py @@ -86,7 +86,13 @@ class BaseLockTests(BaseTestCase): def test_repr(self): lock = self.locktype() - repr(lock) + self.assertRegex(repr(lock), "<unlocked .* object (.*)?at .*>") + del lock + + def test_locked_repr(self): + lock = self.locktype() + lock.acquire() + self.assertRegex(repr(lock), "<locked .* object (.*)?at .*>") del lock def test_acquire_destroy(self): diff --git a/Lib/test/mock_socket.py b/Lib/test/mock_socket.py index e36724f..b28c473 100644 --- a/Lib/test/mock_socket.py +++ b/Lib/test/mock_socket.py @@ -35,8 +35,9 @@ class MockFile: class MockSocket: """Mock socket object used by smtpd and smtplib tests. """ - def __init__(self): + def __init__(self, family=None): global _reply_data + self.family = family self.output = [] self.lines = [] if _reply_data: @@ -101,15 +102,14 @@ class MockSocket: return len(data) def getpeername(self): - return 'peer' + return ('peer-address', 'peer-port') def close(self): pass def socket(family=None, type=None, proto=None): - return MockSocket() - + return MockSocket(family) def create_connection(address, timeout=socket_module._GLOBAL_DEFAULT_TIMEOUT, source_address=None): @@ -144,13 +144,16 @@ def gethostname(): def gethostbyname(name): return "" +def getaddrinfo(*args, **kw): + return socket_module.getaddrinfo(*args, **kw) gaierror = socket_module.gaierror error = socket_module.error # Constants -AF_INET = None -SOCK_STREAM = None +AF_INET = socket_module.AF_INET +AF_INET6 = socket_module.AF_INET6 +SOCK_STREAM = socket_module.SOCK_STREAM SOL_SOCKET = None SO_REUSEADDR = None diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index c6f4f6c..2ef48e6 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -1540,16 +1540,62 @@ class AbstractPickleTests(unittest.TestCase): self.assertGreaterEqual(num_additems, 2) def test_simple_newobj(self): - x = object.__new__(SimpleNewObj) # avoid __init__ + x = SimpleNewObj.__new__(SimpleNewObj, 0xface) # avoid __init__ x.abc = 666 for proto in protocols: - s = self.dumps(x, proto) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), - 2 <= proto < 4) - self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), - proto >= 4) - y = self.loads(s) # will raise TypeError if __init__ called - self.assert_is_copy(x, y) + with self.subTest(proto=proto): + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + else: + self.assertIn(b'M\xce\xfa', s) # BININT2 + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) + + def test_complex_newobj(self): + x = ComplexNewObj.__new__(ComplexNewObj, 0xface) # avoid __init__ + x.abc = 666 + for proto in protocols: + with self.subTest(proto=proto): + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + elif proto < 2: + self.assertIn(b'M\xce\xfa', s) # BININT2 + elif proto < 4: + self.assertIn(b'X\x04\x00\x00\x00FACE', s) # BINUNICODE + else: + self.assertIn(b'\x8c\x04FACE', s) # SHORT_BINUNICODE + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), + 2 <= proto) + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ_EX, s)) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) + + def test_complex_newobj_ex(self): + x = ComplexNewObjEx.__new__(ComplexNewObjEx, 0xface) # avoid __init__ + x.abc = 666 + for proto in protocols: + with self.subTest(proto=proto): + if 2 <= proto < 4: + self.assertRaises(ValueError, self.dumps, x, proto) + continue + s = self.dumps(x, proto) + if proto < 1: + self.assertIn(b'\nL64206', s) # LONG + elif proto < 2: + self.assertIn(b'M\xce\xfa', s) # BININT2 + else: + assert proto >= 4 + self.assertIn(b'\x8c\x04FACE', s) # SHORT_BINUNICODE + self.assertFalse(opcode_in_pickle(pickle.NEWOBJ, s)) + self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s), + 4 <= proto) + y = self.loads(s) # will raise TypeError if __init__ called + self.assert_is_copy(x, y) def test_newobj_list_slots(self): x = SlotList([1, 2, 3]) @@ -1813,13 +1859,24 @@ class AbstractPickleTests(unittest.TestCase): class B: class C: pass - - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for obj in [Nested.A, Nested.A.B, Nested.A.B.C]: with self.subTest(proto=proto, obj=obj): unpickled = self.loads(self.dumps(obj, proto)) self.assertIs(obj, unpickled) + def test_recursive_nested_names(self): + global Recursive + class Recursive: + pass + Recursive.mod = sys.modules[Recursive.__module__] + Recursive.__qualname__ = 'Recursive.mod.Recursive' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + unpickled = self.loads(self.dumps(Recursive, proto)) + self.assertIs(unpickled, Recursive) + del Recursive.mod # break reference loop + def test_py_methods(self): global PyMethodsTest class PyMethodsTest: @@ -1858,7 +1915,7 @@ class AbstractPickleTests(unittest.TestCase): (PyMethodsTest.biscuits, PyMethodsTest), (PyMethodsTest.Nested.pie, PyMethodsTest.Nested) ) - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for method in py_methods: with self.subTest(proto=proto, method=method): unpickled = self.loads(self.dumps(method, proto)) @@ -1898,7 +1955,7 @@ class AbstractPickleTests(unittest.TestCase): (Subclass.Nested("sweet").count, ("e",)), (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), ) - for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): for method, args in c_methods: with self.subTest(proto=proto, method=method): unpickled = self.loads(self.dumps(method, proto)) @@ -1922,6 +1979,27 @@ class AbstractPickleTests(unittest.TestCase): self.assertIn(('c%s\n%s' % (mod, name)).encode(), pickled) self.assertIs(type(self.loads(pickled)), type(val)) + def test_local_lookup_error(self): + # Test that whichmodule() errors out cleanly when looking up + # an assumed globally-reachable object fails. + def f(): + pass + # Since the function is local, lookup will fail + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + # Same without a __module__ attribute (exercises a different path + # in _pickle.c). + del f.__module__ + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + # Yet a different path. + f.__name__ = f.__qualname__ + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises((AttributeError, pickle.PicklingError)): + pickletools.dis(self.dumps(f, proto)) + class BigmemPickleTests(unittest.TestCase): @@ -2156,12 +2234,20 @@ myclasses = [MyInt, MyFloat, class SlotList(MyList): __slots__ = ["foo"] -class SimpleNewObj(object): - def __init__(self, a, b, c): +class SimpleNewObj(int): + def __init__(self, *args, **kwargs): # raise an error, to make sure this isn't called raise TypeError("SimpleNewObj.__init__() didn't expect to get called") def __eq__(self, other): - return self.__dict__ == other.__dict__ + return int(self) == int(other) and self.__dict__ == other.__dict__ + +class ComplexNewObj(SimpleNewObj): + def __getnewargs__(self): + return ('%X' % self, 16) + +class ComplexNewObjEx(SimpleNewObj): + def __getnewargs_ex__(self): + return ('%X' % self,), {'base': 16} class BadGetattr: def __getattr__(self, key): diff --git a/Lib/test/pystone.py b/Lib/test/pystone.py index a41f1e5..1f67e66 100755 --- a/Lib/test/pystone.py +++ b/Lib/test/pystone.py @@ -41,7 +41,7 @@ Version History: LOOPS = 50000 -from time import clock +from time import time __version__ = "1.2" @@ -93,10 +93,10 @@ def Proc0(loops=LOOPS): global PtrGlb global PtrGlbNext - starttime = clock() + starttime = time() for i in range(loops): pass - nulltime = clock() - starttime + nulltime = time() - starttime PtrGlbNext = Record() PtrGlb = Record() @@ -108,7 +108,7 @@ def Proc0(loops=LOOPS): String1Loc = "DHRYSTONE PROGRAM, 1'ST STRING" Array2Glob[8][7] = 10 - starttime = clock() + starttime = time() for i in range(loops): Proc5() @@ -134,7 +134,7 @@ def Proc0(loops=LOOPS): IntLoc2 = 7 * (IntLoc3 - IntLoc2) - IntLoc1 IntLoc1 = Proc2(IntLoc1) - benchtime = clock() - starttime - nulltime + benchtime = time() - starttime - nulltime if benchtime == 0.0: loopsPerBenchtime = 0.0 else: diff --git a/Lib/test/re_tests.py b/Lib/test/re_tests.py index 7f8075e..8c158f8 100755 --- a/Lib/test/re_tests.py +++ b/Lib/test/re_tests.py @@ -87,7 +87,7 @@ tests = [ (r'[\a][\b][\f][\n][\r][\t][\v]', '\a\b\f\n\r\t\v', SUCCEED, 'found', '\a\b\f\n\r\t\v'), # NOTE: not an error under PCRE/PRE: (r'\u', '', SYNTAX_ERROR), # A Perl escape - (r'\c\e\g\h\i\j\k\m\o\p\q\y\z', 'ceghijkmopqyz', SUCCEED, 'found', 'ceghijkmopqyz'), + # (r'\c\e\g\h\i\j\k\m\o\p\q\y\z', 'ceghijkmopqyz', SUCCEED, 'found', 'ceghijkmopqyz'), (r'\xff', '\377', SUCCEED, 'found', chr(255)), # new \x semantics (r'\x00ffffffffffffff', '\377', FAIL, 'found', chr(255)), @@ -607,8 +607,8 @@ xyzabc # new \x semantics (r'\x00ff', '\377', FAIL), # (r'\x00ff', '\377', SUCCEED, 'found', chr(255)), - (r'\t\n\v\r\f\a\g', '\t\n\v\r\f\ag', SUCCEED, 'found', '\t\n\v\r\f\ag'), - ('\t\n\v\r\f\a\g', '\t\n\v\r\f\ag', SUCCEED, 'found', '\t\n\v\r\f\ag'), + (r'\t\n\v\r\f\a', '\t\n\v\r\f\a', SUCCEED, 'found', '\t\n\v\r\f\a'), + ('\t\n\v\r\f\a', '\t\n\v\r\f\a', SUCCEED, 'found', '\t\n\v\r\f\a'), (r'\t\n\v\r\f\a', '\t\n\v\r\f\a', SUCCEED, 'found', chr(9)+chr(10)+chr(11)+chr(13)+chr(12)+chr(7)), (r'[\t][\n][\v][\r][\f][\b]', '\t\n\v\r\f\b', SUCCEED, 'found', '\t\n\v\r\f\b'), diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py index 6708876..a04c3f7 100755 --- a/Lib/test/regrtest.py +++ b/Lib/test/regrtest.py @@ -322,6 +322,8 @@ def _create_parser(): group.add_argument('-F', '--forever', action='store_true', help='run the specified tests in a loop, until an ' 'error happens') + group.add_argument('-P', '--pgo', dest='pgo', action='store_true', + help='enable Profile Guided Optimization training') parser.add_argument('args', nargs=argparse.REMAINDER, help=argparse.SUPPRESS) @@ -361,7 +363,7 @@ def _parse_args(args, **kwargs): findleaks=False, use_resources=None, trace=False, coverdir='coverage', runleaks=False, huntrleaks=False, verbose2=False, print_slow=False, random_seed=None, use_mp=None, verbose3=False, forever=False, - header=False, failfast=False, match_tests=None) + header=False, failfast=False, match_tests=None, pgo=False) for k, v in kwargs.items(): if not hasattr(ns, k): raise TypeError('%r is an invalid keyword argument ' @@ -435,14 +437,16 @@ def run_test_in_subprocess(testname, ns): from subprocess import Popen, PIPE base_cmd = ([sys.executable] + support.args_from_interpreter_flags() + ['-X', 'faulthandler', '-m', 'test.regrtest']) - + # required to spawn a new process with PGO flag on/off + if ns.pgo: + base_cmd = base_cmd + ['--pgo'] slaveargs = ( (testname, ns.verbose, ns.quiet), dict(huntrleaks=ns.huntrleaks, use_resources=ns.use_resources, output_on_failure=ns.verbose3, timeout=ns.timeout, failfast=ns.failfast, - match_tests=ns.match_tests)) + match_tests=ns.match_tests, pgo=ns.pgo)) # Running the child from the same working directory as regrtest's original # invocation ensures that TEMPDIR for the child is the same when # sysconfig.is_python_build() is true. See issue 15300. @@ -596,13 +600,14 @@ def main(tests=None, **kwargs): ns.args = [] # For a partial run, we do not need to clutter the output. - if ns.verbose or ns.header or not (ns.quiet or ns.single or tests or ns.args): + if (ns.verbose or ns.header or + not (ns.pgo or ns.quiet or ns.single or tests or ns.args)): # Print basic platform information print("==", platform.python_implementation(), *sys.version.split()) print("== ", platform.platform(aliased=True), - "%s-endian" % sys.byteorder) + "%s-endian" % sys.byteorder) print("== ", "hash algorithm:", sys.hash_info.algorithm, - "64bit" if sys.maxsize > 2**32 else "32bit") + "64bit" if sys.maxsize > 2**32 else "32bit") print("== ", os.getcwd()) print("Testing with flags:", sys.flags) @@ -722,13 +727,16 @@ def main(tests=None, **kwargs): continue accumulate_result(test, result) if not ns.quiet: - fmt = "[{1:{0}}{2}/{3}] {4}" if bad else "[{1:{0}}{2}] {4}" + if bad and not ns.pgo: + fmt = "[{1:{0}}{2}/{3}] {4}" + else: + fmt = "[{1:{0}}{2}] {4}" print(fmt.format( test_count_width, test_index, test_count, len(bad), test)) if stdout: print(stdout) - if stderr: + if stderr and not ns.pgo: print(stderr, file=sys.stderr) sys.stdout.flush() sys.stderr.flush() @@ -745,7 +753,10 @@ def main(tests=None, **kwargs): else: for test_index, test in enumerate(tests, 1): if not ns.quiet: - fmt = "[{1:{0}}{2}/{3}] {4}" if bad else "[{1:{0}}{2}] {4}" + if bad and not ns.pgo: + fmt = "[{1:{0}}{2}/{3}] {4}" + else: + fmt = "[{1:{0}}{2}] {4}" print(fmt.format( test_count_width, test_index, test_count, len(bad), test)) sys.stdout.flush() @@ -760,13 +771,11 @@ def main(tests=None, **kwargs): ns.huntrleaks, output_on_failure=ns.verbose3, timeout=ns.timeout, failfast=ns.failfast, - match_tests=ns.match_tests) + match_tests=ns.match_tests, pgo=ns.pgo) accumulate_result(test, result) except KeyboardInterrupt: interrupted = True break - except: - raise if ns.findleaks: gc.collect() if gc.garbage: @@ -781,14 +790,14 @@ def main(tests=None, **kwargs): if module not in save_modules and module.startswith("test."): support.unload(module) - if interrupted: + if interrupted and not ns.pgo: # print a newline after ^C print() print("Test suite interrupted by signal SIGINT.") omitted = set(selected) - set(good) - set(bad) - set(skipped) print(count(len(omitted), "test"), "omitted:") printlist(omitted) - if good and not ns.quiet: + if good and not ns.quiet and not ns.pgo: if not bad and not skipped and not interrupted and len(good) > 1: print("All", end=' ') print(count(len(good), "test"), "OK.") @@ -797,26 +806,27 @@ def main(tests=None, **kwargs): print("10 slowest tests:") for time, test in test_times[:10]: print("%s: %.1fs" % (test, time)) - if bad: + if bad and not ns.pgo: print(count(len(bad), "test"), "failed:") printlist(bad) - if environment_changed: + if environment_changed and not ns.pgo: print("{} altered the execution environment:".format( count(len(environment_changed), "test"))) printlist(environment_changed) - if skipped and not ns.quiet: + if skipped and not ns.quiet and not ns.pgo: print(count(len(skipped), "test"), "skipped:") printlist(skipped) if ns.verbose2 and bad: print("Re-running failed tests in verbose mode") for test in bad[:]: - print("Re-running test %r in verbose mode" % test) + if not ns.pgo: + print("Re-running test %r in verbose mode" % test) sys.stdout.flush() try: ns.verbose = True ok = runtest(test, True, ns.quiet, ns.huntrleaks, - timeout=ns.timeout) + timeout=ns.timeout, pgo=ns.pgo) except KeyboardInterrupt: # print a newline separate from the ^C print() @@ -915,7 +925,7 @@ def replace_stdout(): def runtest(test, verbose, quiet, huntrleaks=False, use_resources=None, output_on_failure=False, failfast=False, match_tests=None, - timeout=None): + timeout=None, *, pgo=False): """Run a single test. test -- the name of the test @@ -928,6 +938,8 @@ def runtest(test, verbose, quiet, timeout -- dump the traceback and exit if a test takes more than timeout seconds failfast, match_tests -- See regrtest command-line flags for these. + pgo -- if true, do not print unnecessary info when running the test + for Profile Guided Optimization build Returns the tuple result, test_time, where result is one of the constants: INTERRUPTED KeyboardInterrupt when run under -j @@ -937,7 +949,6 @@ def runtest(test, verbose, quiet, FAILED test failed PASSED test passed """ - if use_resources is not None: support.use_resources = use_resources use_timeout = (timeout is not None) @@ -967,8 +978,8 @@ def runtest(test, verbose, quiet, sys.stdout = stream sys.stderr = stream result = runtest_inner(test, verbose, quiet, huntrleaks, - display_failure=False) - if result[0] == FAILED: + display_failure=False, pgo=pgo) + if result[0] == FAILED and not pgo: output = stream.getvalue() orig_stderr.write(output) orig_stderr.flush() @@ -978,7 +989,7 @@ def runtest(test, verbose, quiet, else: support.verbose = verbose # Tell tests to be moderately quiet result = runtest_inner(test, verbose, quiet, huntrleaks, - display_failure=not verbose) + display_failure=not verbose, pgo=pgo) return result finally: if use_timeout: @@ -1010,10 +1021,11 @@ class saved_test_environment: changed = False - def __init__(self, testname, verbose=0, quiet=False): + def __init__(self, testname, verbose=0, quiet=False, *, pgo=False): self.testname = testname self.verbose = verbose self.quiet = quiet + self.pgo = pgo # To add things to save and restore, add a name XXX to the resources list # and add corresponding get_XXX/restore_XXX functions. get_XXX should @@ -1242,11 +1254,11 @@ class saved_test_environment: if current != original: self.changed = True restore(original) - if not self.quiet: + if not self.quiet and not self.pgo: print("Warning -- {} was modified by {}".format( name, self.testname), file=sys.stderr) - if self.verbose > 1: + if self.verbose > 1 and not self.pgo: print(" Before: {}\n After: {} ".format( original, current), file=sys.stderr) @@ -1254,7 +1266,7 @@ class saved_test_environment: def runtest_inner(test, verbose, quiet, - huntrleaks=False, display_failure=True): + huntrleaks=False, display_failure=True, pgo=False): support.unload(test) test_time = 0.0 @@ -1265,7 +1277,7 @@ def runtest_inner(test, verbose, quiet, else: # Always import it from the test package abstest = 'test.' + test - with saved_test_environment(test, verbose, quiet) as environment: + with saved_test_environment(test, verbose, quiet, pgo=pgo) as environment: start_time = time.time() the_module = importlib.import_module(abstest) # If the test has a test_main, that will run the appropriate @@ -1275,33 +1287,39 @@ def runtest_inner(test, verbose, quiet, def test_runner(): loader = unittest.TestLoader() tests = loader.loadTestsFromModule(the_module) + for error in loader.errors: + print(error, file=sys.stderr) + if loader.errors: + raise Exception("errors while loading tests") support.run_unittest(tests) test_runner() if huntrleaks: refleak = dash_R(the_module, test, test_runner, huntrleaks) test_time = time.time() - start_time except support.ResourceDenied as msg: - if not quiet: + if not quiet and not pgo: print(test, "skipped --", msg) sys.stdout.flush() return RESOURCE_DENIED, test_time except unittest.SkipTest as msg: - if not quiet: + if not quiet and not pgo: print(test, "skipped --", msg) sys.stdout.flush() return SKIPPED, test_time except KeyboardInterrupt: raise except support.TestFailed as msg: - if display_failure: - print("test", test, "failed --", msg, file=sys.stderr) - else: - print("test", test, "failed", file=sys.stderr) + if not pgo: + if display_failure: + print("test", test, "failed --", msg, file=sys.stderr) + else: + print("test", test, "failed", file=sys.stderr) sys.stderr.flush() return FAILED, test_time except: msg = traceback.format_exc() - print("test", test, "crashed --", msg, file=sys.stderr) + if not pgo: + print("test", test, "crashed --", msg, file=sys.stderr) sys.stderr.flush() return FAILED, test_time else: diff --git a/Lib/test/ssl_servers.py b/Lib/test/ssl_servers.py index 759b3f4..f9d30cf 100644 --- a/Lib/test/ssl_servers.py +++ b/Lib/test/ssl_servers.py @@ -150,7 +150,7 @@ class HTTPSServerThread(threading.Thread): def make_https_server(case, *, context=None, certfile=CERTFILE, host=HOST, handler_class=None): if context is None: - context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) # We assume the certfile contains both private key and certificate context.load_cert_chain(certfile) server = HTTPSServerThread(context, host, handler_class) @@ -182,6 +182,8 @@ if __name__ == "__main__": parser.add_argument('--curve-name', dest='curve_name', type=str, action='store', help='curve name for EC-based Diffie-Hellman') + parser.add_argument('--ciphers', dest='ciphers', type=str, + help='allowed cipher list') parser.add_argument('--dh', dest='dh_file', type=str, action='store', help='PEM file containing DH parameters') args = parser.parse_args() @@ -192,12 +194,14 @@ if __name__ == "__main__": else: handler_class = RootedHTTPRequestHandler handler_class.root = os.getcwd() - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) context.load_cert_chain(CERTFILE) if args.curve_name: context.set_ecdh_curve(args.curve_name) if args.dh_file: context.load_dh_params(args.dh_file) + if args.ciphers: + context.set_ciphers(args.ciphers) server = HTTPSServer(("", args.port), handler_class, context) if args.verbose: diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 242a931..e086994 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -976,6 +976,9 @@ class MixinStrUnicodeUserStringTest: self.checkequal(True, 'helloworld', 'startswith', 'lowo', 3) self.checkequal(True, 'helloworld', 'startswith', 'lowo', 3, 7) self.checkequal(False, 'helloworld', 'startswith', 'lowo', 3, 6) + self.checkequal(True, '', 'startswith', '', 0, 1) + self.checkequal(True, '', 'startswith', '', 0, 0) + self.checkequal(False, '', 'startswith', '', 1, 0) # test negative indices self.checkequal(True, 'hello', 'startswith', 'he', 0, -1) @@ -1022,6 +1025,9 @@ class MixinStrUnicodeUserStringTest: self.checkequal(False, 'helloworld', 'endswith', 'lowo', 3, 8) self.checkequal(False, 'ab', 'endswith', 'ab', 0, 1) self.checkequal(False, 'ab', 'endswith', 'ab', 0, 0) + self.checkequal(True, '', 'endswith', '', 0, 1) + self.checkequal(True, '', 'endswith', '', 0, 0) + self.checkequal(False, '', 'endswith', '', 1, 0) # test negative indices self.checkequal(True, 'hello', 'endswith', 'lo', -2) @@ -1176,8 +1182,7 @@ class MixinStrUnicodeUserStringTest: self.checkraises(TypeError, 'abc', '__mod__') self.checkraises(TypeError, '%(foo)s', '__mod__', 42) self.checkraises(TypeError, '%s%s', '__mod__', (42,)) - with self.assertWarns(DeprecationWarning): - self.checkraises(TypeError, '%c', '__mod__', (None,)) + self.checkraises(TypeError, '%c', '__mod__', (None,)) self.checkraises(ValueError, '%(foo', '__mod__', {}) self.checkraises(TypeError, '%(foo)s %(bar)s', '__mod__', ('foo', 42)) self.checkraises(TypeError, '%d', '__mod__', "42") # not numeric diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py index 01ca2f8..8b180b5 100644 --- a/Lib/test/support/__init__.py +++ b/Lib/test/support/__init__.py @@ -88,7 +88,7 @@ __all__ = [ "skip_unless_symlink", "requires_gzip", "requires_bz2", "requires_lzma", "bigmemtest", "bigaddrspacetest", "cpython_only", "get_attribute", "requires_IEEE_754", "skip_unless_xattr", "requires_zlib", - "anticipate_failure", "load_package_tests", + "anticipate_failure", "load_package_tests", "detect_api_mismatch", # sys "is_jython", "check_impl_detail", # network @@ -376,36 +376,32 @@ def rmtree(path): pass def make_legacy_pyc(source): - """Move a PEP 3147 pyc/pyo file to its legacy pyc/pyo location. - - The choice of .pyc or .pyo extension is done based on the __debug__ flag - value. + """Move a PEP 3147/488 pyc file to its legacy pyc location. :param source: The file system path to the source file. The source file - does not need to exist, however the PEP 3147 pyc file must exist. + does not need to exist, however the PEP 3147/488 pyc file must exist. :return: The file system path to the legacy pyc file. """ pyc_file = importlib.util.cache_from_source(source) up_one = os.path.dirname(os.path.abspath(source)) - legacy_pyc = os.path.join(up_one, source + ('c' if __debug__ else 'o')) + legacy_pyc = os.path.join(up_one, source + 'c') os.rename(pyc_file, legacy_pyc) return legacy_pyc def forget(modname): """'Forget' a module was ever imported. - This removes the module from sys.modules and deletes any PEP 3147 or - legacy .pyc and .pyo files. + This removes the module from sys.modules and deletes any PEP 3147/488 or + legacy .pyc files. """ unload(modname) for dirname in sys.path: source = os.path.join(dirname, modname + '.py') # It doesn't matter if they exist or not, unlink all possible - # combinations of PEP 3147 and legacy pyc and pyo files. + # combinations of PEP 3147/488 and legacy pyc files. unlink(source + 'c') - unlink(source + 'o') - unlink(importlib.util.cache_from_source(source, debug_override=True)) - unlink(importlib.util.cache_from_source(source, debug_override=False)) + for opt in ('', 1, 2): + unlink(importlib.util.cache_from_source(source, optimization=opt)) # Check whether a gui is actually available def _is_gui_available(): @@ -1042,7 +1038,8 @@ def open_urlresource(url, *args, **kw): # Verify the requirement before downloading the file requires('urlfetch') - print('\tfetching %s ...' % url, file=get_original_stdout()) + if verbose: + print('\tfetching %s ...' % url, file=get_original_stdout()) opener = urllib.request.build_opener() if gzip: opener.addheaders.append(('Accept-Encoding', 'gzip')) @@ -2187,6 +2184,21 @@ def fs_is_case_insensitive(directory): return False +def detect_api_mismatch(ref_api, other_api, *, ignore=()): + """Returns the set of items in ref_api not in other_api, except for a + defined list of items to be ignored in this check. + + By default this skips private attributes beginning with '_' but + includes all magic methods, i.e. those starting and ending in '__'. + """ + missing_items = set(dir(ref_api)) - set(dir(other_api)) + if ignore: + missing_items -= set(ignore) + missing_items = set(m for m in missing_items + if not m.startswith('_') or m.endswith('__')) + return missing_items + + class SuppressCrashReport: """Try to prevent a crash report from popping up. @@ -2194,6 +2206,7 @@ class SuppressCrashReport: disable the creation of coredump file. """ old_value = None + old_modes = None def __enter__(self): """On Windows, disable Windows Error Reporting dialogs using @@ -2211,6 +2224,26 @@ class SuppressCrashReport: SEM_NOGPFAULTERRORBOX = 0x02 self.old_value = self._k32.SetErrorMode(SEM_NOGPFAULTERRORBOX) self._k32.SetErrorMode(self.old_value | SEM_NOGPFAULTERRORBOX) + + # Suppress assert dialogs in debug builds + # (see http://bugs.python.org/issue23314) + try: + import msvcrt + msvcrt.CrtSetReportMode + except (AttributeError, ImportError): + # no msvcrt or a release build + pass + else: + self.old_modes = {} + for report_type in [msvcrt.CRT_WARN, + msvcrt.CRT_ERROR, + msvcrt.CRT_ASSERT]: + old_mode = msvcrt.CrtSetReportMode(report_type, + msvcrt.CRTDBG_MODE_FILE) + old_file = msvcrt.CrtSetReportFile(report_type, + msvcrt.CRTDBG_FILE_STDERR) + self.old_modes[report_type] = old_mode, old_file + else: if resource is not None: try: @@ -2242,6 +2275,12 @@ class SuppressCrashReport: if sys.platform.startswith('win'): self._k32.SetErrorMode(self.old_value) + + if self.old_modes: + import msvcrt + for report_type, (old_mode, old_file) in self.old_modes.items(): + msvcrt.CrtSetReportMode(report_type, old_mode) + msvcrt.CrtSetReportFile(report_type, old_file) else: if resource is not None: try: diff --git a/Lib/test/script_helper.py b/Lib/test/support/script_helper.py index b29392f..584b0e8 100644 --- a/Lib/test/script_helper.py +++ b/Lib/test/support/script_helper.py @@ -14,13 +14,13 @@ import shutil import zipfile from importlib.util import source_from_cache -from test.support import make_legacy_pyc, strip_python_stderr, temp_dir +from test.support import make_legacy_pyc, strip_python_stderr # Cached result of the expensive test performed in the function below. __cached_interp_requires_environment = None -def _interpreter_requires_environment(): +def interpreter_requires_environment(): """ Returns True if our sys.executable interpreter requires environment variables in order to be able to run at all. @@ -57,7 +57,7 @@ _PythonRunResult = collections.namedtuple("_PythonRunResult", # Executing the interpreter in a subprocess def run_python_until_end(*args, **env_vars): - env_required = _interpreter_requires_environment() + env_required = interpreter_requires_environment() if '__isolated' in env_vars: isolated = env_vars.pop('__isolated') else: @@ -95,10 +95,30 @@ def run_python_until_end(*args, **env_vars): def _assert_python(expected_success, *args, **env_vars): res, cmd_line = run_python_until_end(*args, **env_vars) if (res.rc and expected_success) or (not res.rc and not expected_success): - raise AssertionError( - "Process return code is %d, command line was: %r, " - "stderr follows:\n%s" % (res.rc, cmd_line, - res.err.decode('ascii', 'ignore'))) + # Limit to 80 lines to ASCII characters + maxlen = 80 * 100 + out, err = res.out, res.err + if len(out) > maxlen: + out = b'(... truncated stdout ...)' + out[-maxlen:] + if len(err) > maxlen: + err = b'(... truncated stderr ...)' + err[-maxlen:] + out = out.decode('ascii', 'replace').rstrip() + err = err.decode('ascii', 'replace').rstrip() + raise AssertionError("Process return code is %d\n" + "command line: %r\n" + "\n" + "stdout:\n" + "---\n" + "%s\n" + "---\n" + "\n" + "stderr:\n" + "---\n" + "%s\n" + "---" + % (res.rc, cmd_line, + out, + err)) return res def assert_python_ok(*args, **env_vars): diff --git a/Lib/test/test___future__.py b/Lib/test/test___future__.py index 6f73c7f..559a187 100644 --- a/Lib/test/test___future__.py +++ b/Lib/test/test___future__.py @@ -1,5 +1,4 @@ import unittest -from test import support import __future__ GOOD_SERIALS = ("alpha", "beta", "candidate", "final") @@ -58,8 +57,5 @@ class FutureTest(unittest.TestCase): ".compiler_flag isn't int") -def test_main(): - support.run_unittest(FutureTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test__opcode.py b/Lib/test/test__opcode.py index 0152e9d..1075dec 100644 --- a/Lib/test/test__opcode.py +++ b/Lib/test/test__opcode.py @@ -1,5 +1,5 @@ import dis -from test.support import run_unittest, import_module +from test.support import import_module import unittest _opcode = import_module("_opcode") @@ -16,8 +16,5 @@ class OpcodeTests(unittest.TestCase): self.assertRaises(ValueError, _opcode.stack_effect, dis.opmap['BUILD_SLICE']) self.assertRaises(ValueError, _opcode.stack_effect, dis.opmap['POP_TOP'], 0) -def test_main(): - run_unittest(OpcodeTests) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test__osx_support.py b/Lib/test/test__osx_support.py index 5dcadf7..ac6325a 100644 --- a/Lib/test/test__osx_support.py +++ b/Lib/test/test__osx_support.py @@ -273,9 +273,5 @@ class Test_OSXSupport(unittest.TestCase): result = _osx_support.get_platform_osx(config_vars, ' ', ' ', ' ') self.assertEqual(('macosx', '10.6', 'fat'), result) -def test_main(): - if sys.platform == 'darwin': - test.support.run_unittest(Test_OSXSupport) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_argparse.py b/Lib/test/test_argparse.py index ecc5507..27bfad5 100644 --- a/Lib/test/test_argparse.py +++ b/Lib/test/test_argparse.py @@ -20,15 +20,6 @@ class StdIOBuffer(StringIO): class TestCase(unittest.TestCase): - def assertEqual(self, obj1, obj2): - if obj1 != obj2: - print('') - print(repr(obj1)) - print(repr(obj2)) - print(obj1) - print(obj2) - super(TestCase, self).assertEqual(obj1, obj2) - def setUp(self): # The tests assume that line wrapping occurs at 80 columns, but this # behaviour can be overridden by setting the COLUMNS environment @@ -78,9 +69,6 @@ class NS(object): def __eq__(self, other): return vars(self) == vars(other) - def __ne__(self, other): - return not (self == other) - class ArgumentParserError(Exception): @@ -765,6 +753,39 @@ class TestOptionalsActionCount(ParserTestCase): ] +class TestOptionalsAllowLongAbbreviation(ParserTestCase): + """Allow long options to be abbreviated unambiguously""" + + argument_signatures = [ + Sig('--foo'), + Sig('--foobaz'), + Sig('--fooble', action='store_true'), + ] + failures = ['--foob 5', '--foob'] + successes = [ + ('', NS(foo=None, foobaz=None, fooble=False)), + ('--foo 7', NS(foo='7', foobaz=None, fooble=False)), + ('--fooba a', NS(foo=None, foobaz='a', fooble=False)), + ('--foobl --foo g', NS(foo='g', foobaz=None, fooble=True)), + ] + + +class TestOptionalsDisallowLongAbbreviation(ParserTestCase): + """Do not allow abbreviations of long options at all""" + + parser_signature = Sig(allow_abbrev=False) + argument_signatures = [ + Sig('--foo'), + Sig('--foodle', action='store_true'), + Sig('--foonly'), + ] + failures = ['-foon 3', '--foon 3', '--food', '--food --foo 2'] + successes = [ + ('', NS(foo=None, foodle=False, foonly=None)), + ('--foo 3', NS(foo='3', foodle=False, foonly=None)), + ('--foonly 7 --foodle --foo 2', NS(foo='2', foodle=True, foonly='7')), + ] + # ================ # Positional tests # ================ @@ -1993,14 +2014,9 @@ class TestAddSubparsers(TestCase): ''')) def _test_subparser_help(self, args_str, expected_help): - try: + with self.assertRaises(ArgumentParserError) as cm: self.parser.parse_args(args_str.split()) - except ArgumentParserError: - err = sys.exc_info()[1] - if err.stdout != expected_help: - print(repr(expected_help)) - print(repr(err.stdout)) - self.assertEqual(err.stdout, expected_help) + self.assertEqual(expected_help, cm.exception.stdout) def test_subparser1_help(self): self._test_subparser_help('5.0 1 -h', textwrap.dedent('''\ @@ -2846,15 +2862,15 @@ class TestGetDefault(TestCase): def test_get_default(self): parser = ErrorRaisingArgumentParser() - self.assertEqual(None, parser.get_default("foo")) - self.assertEqual(None, parser.get_default("bar")) + self.assertIsNone(parser.get_default("foo")) + self.assertIsNone(parser.get_default("bar")) parser.add_argument("--foo") - self.assertEqual(None, parser.get_default("foo")) - self.assertEqual(None, parser.get_default("bar")) + self.assertIsNone(parser.get_default("foo")) + self.assertIsNone(parser.get_default("bar")) parser.add_argument("--bar", type=int, default=42) - self.assertEqual(None, parser.get_default("foo")) + self.assertIsNone(parser.get_default("foo")) self.assertEqual(42, parser.get_default("bar")) parser.set_defaults(foo="badger") @@ -2869,18 +2885,16 @@ class TestNamespaceContainsSimple(TestCase): def test_empty(self): ns = argparse.Namespace() - self.assertEqual('' in ns, False) - self.assertEqual('' not in ns, True) - self.assertEqual('x' in ns, False) + self.assertNotIn('', ns) + self.assertNotIn('x', ns) def test_non_empty(self): ns = argparse.Namespace(x=1, y=2) - self.assertEqual('x' in ns, True) - self.assertEqual('x' not in ns, False) - self.assertEqual('y' in ns, True) - self.assertEqual('' in ns, False) - self.assertEqual('xx' in ns, False) - self.assertEqual('z' in ns, False) + self.assertNotIn('', ns) + self.assertIn('x', ns) + self.assertIn('y', ns) + self.assertNotIn('xx', ns) + self.assertNotIn('z', ns) # ===================== # Help formatting tests @@ -2936,13 +2950,6 @@ class TestHelpFormattingMetaclass(type): def _test(self, tester, parser_text): expected_text = getattr(tester, self.func_suffix) expected_text = textwrap.dedent(expected_text) - if expected_text != parser_text: - print(repr(expected_text)) - print(repr(parser_text)) - for char1, char2 in zip(expected_text, parser_text): - if char1 != char2: - print('first diff: %r %r' % (char1, char2)) - break tester.assertEqual(expected_text, parser_text) def test_format(self, tester): @@ -4221,24 +4228,17 @@ class TestInvalidArgumentConstructors(TestCase): self.assertValueError('foo', action='baz') self.assertValueError('--foo', action=('store', 'append')) parser = argparse.ArgumentParser() - try: + with self.assertRaises(ValueError) as cm: parser.add_argument("--foo", action="store-true") - except ValueError: - e = sys.exc_info()[1] - expected = 'unknown action' - msg = 'expected %r, found %r' % (expected, e) - self.assertTrue(expected in str(e), msg) + self.assertIn('unknown action', str(cm.exception)) def test_multiple_dest(self): parser = argparse.ArgumentParser() parser.add_argument(dest='foo') - try: + with self.assertRaises(ValueError) as cm: parser.add_argument('bar', dest='baz') - except ValueError: - e = sys.exc_info()[1] - expected = 'dest supplied twice for positional argument' - msg = 'expected %r, found %r' % (expected, e) - self.assertTrue(expected in str(e), msg) + self.assertIn('dest supplied twice for positional argument', + str(cm.exception)) def test_no_argument_actions(self): for action in ['store_const', 'store_true', 'store_false', @@ -4395,18 +4395,10 @@ class TestConflictHandling(TestCase): class TestOptionalsHelpVersionActions(TestCase): """Test the help and version actions""" - def _get_error(self, func, *args, **kwargs): - try: - func(*args, **kwargs) - except ArgumentParserError: - return sys.exc_info()[1] - else: - self.assertRaises(ArgumentParserError, func, *args, **kwargs) - def assertPrintHelpExit(self, parser, args_str): - self.assertEqual( - parser.format_help(), - self._get_error(parser.parse_args, args_str.split()).stdout) + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(args_str.split()) + self.assertEqual(parser.format_help(), cm.exception.stdout) def assertArgumentParserError(self, parser, *args): self.assertRaises(ArgumentParserError, parser.parse_args, args) @@ -4421,8 +4413,9 @@ class TestOptionalsHelpVersionActions(TestCase): def test_version_format(self): parser = ErrorRaisingArgumentParser(prog='PPP') parser.add_argument('-v', '--version', action='version', version='%(prog)s 3.5') - msg = self._get_error(parser.parse_args, ['-v']).stdout - self.assertEqual('PPP 3.5\n', msg) + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['-v']) + self.assertEqual('PPP 3.5\n', cm.exception.stdout) def test_version_no_help(self): parser = ErrorRaisingArgumentParser(add_help=False) @@ -4434,8 +4427,9 @@ class TestOptionalsHelpVersionActions(TestCase): def test_version_action(self): parser = ErrorRaisingArgumentParser(prog='XXX') parser.add_argument('-V', action='version', version='%(prog)s 3.7') - msg = self._get_error(parser.parse_args, ['-V']).stdout - self.assertEqual('XXX 3.7\n', msg) + with self.assertRaises(ArgumentParserError) as cm: + parser.parse_args(['-V']) + self.assertEqual('XXX 3.7\n', cm.exception.stdout) def test_no_help(self): parser = ErrorRaisingArgumentParser(add_help=False) @@ -4605,14 +4599,10 @@ class TestArgumentTypeError(TestCase): parser = ErrorRaisingArgumentParser(prog='PROG', add_help=False) parser.add_argument('x', type=spam) - try: + with self.assertRaises(ArgumentParserError) as cm: parser.parse_args(['XXX']) - except ArgumentParserError: - expected = 'usage: PROG x\nPROG: error: argument x: spam!\n' - msg = sys.exc_info()[1].stderr - self.assertEqual(expected, msg) - else: - self.fail() + self.assertEqual('usage: PROG x\nPROG: error: argument x: spam!\n', + cm.exception.stderr) # ========================= # MessageContentError tests diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py index 07c9bf9..10d9946 100644 --- a/Lib/test/test_array.py +++ b/Lib/test/test_array.py @@ -394,7 +394,9 @@ class BaseTest: self.assertEqual(a, b) def test_tofromstring(self): - nb_warnings = 4 + # Warnings not raised when arguments are incorrect as Argument Clinic + # handles that before the warning can be raised. + nb_warnings = 2 with warnings.catch_warnings(record=True) as r: warnings.filterwarnings("always", message=r"(to|from)string\(\) is deprecated", @@ -1039,6 +1041,11 @@ class BaseTest: a = array.array(self.typecode, "foo") a = array.array(self.typecode, array.array('u', 'foo')) + @support.cpython_only + def test_obsolete_write_lock(self): + from _testcapi import getbuffer_with_null_view + a = array.array('B', b"") + self.assertRaises(BufferError, getbuffer_with_null_view, a) class StringTest(BaseTest): diff --git a/Lib/test/test_asdl_parser.py b/Lib/test/test_asdl_parser.py new file mode 100644 index 0000000..7a6426a --- /dev/null +++ b/Lib/test/test_asdl_parser.py @@ -0,0 +1,122 @@ +"""Tests for the asdl parser in Parser/asdl.py""" + +import importlib.machinery +import os +from os.path import dirname +import sys +import sysconfig +import unittest + + +# This test is only relevant for from-source builds of Python. +if not sysconfig.is_python_build(): + raise unittest.SkipTest('test irrelevant for an installed Python') + +src_base = dirname(dirname(dirname(__file__))) +parser_dir = os.path.join(src_base, 'Parser') + + +class TestAsdlParser(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Loads the asdl module dynamically, since it's not in a real importable + # package. + # Parses Python.asdl into a ast.Module and run the check on it. + # There's no need to do this for each test method, hence setUpClass. + sys.path.insert(0, parser_dir) + loader = importlib.machinery.SourceFileLoader( + 'asdl', os.path.join(parser_dir, 'asdl.py')) + cls.asdl = loader.load_module() + cls.mod = cls.asdl.parse(os.path.join(parser_dir, 'Python.asdl')) + cls.assertTrue(cls.asdl.check(cls.mod), 'Module validation failed') + + @classmethod + def tearDownClass(cls): + del sys.path[0] + + def setUp(self): + # alias stuff from the class, for convenience + self.asdl = TestAsdlParser.asdl + self.mod = TestAsdlParser.mod + self.types = self.mod.types + + def test_module(self): + self.assertEqual(self.mod.name, 'Python') + self.assertIn('stmt', self.types) + self.assertIn('expr', self.types) + self.assertIn('mod', self.types) + + def test_definitions(self): + defs = self.mod.dfns + self.assertIsInstance(defs[0], self.asdl.Type) + self.assertIsInstance(defs[0].value, self.asdl.Sum) + + self.assertIsInstance(self.types['withitem'], self.asdl.Product) + self.assertIsInstance(self.types['alias'], self.asdl.Product) + + def test_product(self): + alias = self.types['alias'] + self.assertEqual( + str(alias), + 'Product([Field(identifier, name), Field(identifier, asname, opt=True)])') + + def test_attributes(self): + stmt = self.types['stmt'] + self.assertEqual(len(stmt.attributes), 2) + self.assertEqual(str(stmt.attributes[0]), 'Field(int, lineno)') + self.assertEqual(str(stmt.attributes[1]), 'Field(int, col_offset)') + + def test_constructor_fields(self): + ehandler = self.types['excepthandler'] + self.assertEqual(len(ehandler.types), 1) + self.assertEqual(len(ehandler.attributes), 2) + + cons = ehandler.types[0] + self.assertIsInstance(cons, self.asdl.Constructor) + self.assertEqual(len(cons.fields), 3) + + f0 = cons.fields[0] + self.assertEqual(f0.type, 'expr') + self.assertEqual(f0.name, 'type') + self.assertTrue(f0.opt) + + f1 = cons.fields[1] + self.assertEqual(f1.type, 'identifier') + self.assertEqual(f1.name, 'name') + self.assertTrue(f1.opt) + + f2 = cons.fields[2] + self.assertEqual(f2.type, 'stmt') + self.assertEqual(f2.name, 'body') + self.assertFalse(f2.opt) + self.assertTrue(f2.seq) + + def test_visitor(self): + class CustomVisitor(self.asdl.VisitorBase): + def __init__(self): + super().__init__() + self.names_with_seq = [] + + def visitModule(self, mod): + for dfn in mod.dfns: + self.visit(dfn) + + def visitType(self, type): + self.visit(type.value) + + def visitSum(self, sum): + for t in sum.types: + self.visit(t) + + def visitConstructor(self, cons): + for f in cons.fields: + if f.seq: + self.names_with_seq.append(cons.name) + + v = CustomVisitor() + v.visit(self.types['mod']) + self.assertEqual(v.names_with_seq, ['Module', 'Interactive', 'Suite']) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index a533f86..d3e6d35 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -78,9 +78,9 @@ exec_tests = [ # Pass, "pass", # Break - "break", + "for v in v:break", # Continue - "continue", + "for v in v:continue", # for statements with naked tuples (see http://bugs.python.org/issue6704) "for a,b in c: pass", "[(a,b) for a,b in c]", @@ -106,6 +106,15 @@ exec_tests = [ "{r for l in x if g}", # setcomp with naked tuple "{r for l,m in x}", + # AsyncFunctionDef + "async def f():\n await something()", + # AsyncFor + "async def f():\n async for e in i: 1\n else: 2", + # AsyncWith + "async def f():\n async with a as b: 1", + # PEP 448: Additional Unpacking Generalizations + "{**{1:2}, 2:3}", + "{*{1, 2}, 3}", ] # These are compiled through "single" @@ -225,9 +234,12 @@ class AST_Tests(unittest.TestCase): (single_tests, single_results, "single"), (eval_tests, eval_results, "eval")): for i, o in zip(input, output): - ast_tree = compile(i, "?", kind, ast.PyCF_ONLY_AST) - self.assertEqual(to_tuple(ast_tree), o) - self._assertTrueorder(ast_tree, (0, 0)) + with self.subTest(action="parsing", input=i): + ast_tree = compile(i, "?", kind, ast.PyCF_ONLY_AST) + self.assertEqual(to_tuple(ast_tree), o) + self._assertTrueorder(ast_tree, (0, 0)) + with self.subTest(action="compiling", input=i): + compile(ast_tree, "?", kind) def test_slice(self): slc = ast.parse("x[::]").body[0].value.slice @@ -427,17 +439,17 @@ class ASTHelpers_Test(unittest.TestCase): self.assertEqual(ast.dump(node), "Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load()), " "args=[Name(id='eggs', ctx=Load()), Str(s='and cheese')], " - "keywords=[], starargs=None, kwargs=None))])" + "keywords=[]))])" ) self.assertEqual(ast.dump(node, annotate_fields=False), "Module([Expr(Call(Name('spam', Load()), [Name('eggs', Load()), " - "Str('and cheese')], [], None, None))])" + "Str('and cheese')], []))])" ) self.assertEqual(ast.dump(node, include_attributes=True), "Module(body=[Expr(value=Call(func=Name(id='spam', ctx=Load(), " "lineno=1, col_offset=0), args=[Name(id='eggs', ctx=Load(), " "lineno=1, col_offset=5), Str(s='and cheese', lineno=1, " - "col_offset=11)], keywords=[], starargs=None, kwargs=None, " + "col_offset=11)], keywords=[], " "lineno=1, col_offset=0), lineno=1, col_offset=0)])" ) @@ -453,16 +465,16 @@ class ASTHelpers_Test(unittest.TestCase): def test_fix_missing_locations(self): src = ast.parse('write("spam")') src.body.append(ast.Expr(ast.Call(ast.Name('spam', ast.Load()), - [ast.Str('eggs')], [], None, None))) + [ast.Str('eggs')], []))) self.assertEqual(src, ast.fix_missing_locations(src)) self.assertEqual(ast.dump(src, include_attributes=True), "Module(body=[Expr(value=Call(func=Name(id='write', ctx=Load(), " "lineno=1, col_offset=0), args=[Str(s='spam', lineno=1, " - "col_offset=6)], keywords=[], starargs=None, kwargs=None, " + "col_offset=6)], keywords=[], " "lineno=1, col_offset=0), lineno=1, col_offset=0), " "Expr(value=Call(func=Name(id='spam', ctx=Load(), lineno=1, " "col_offset=0), args=[Str(s='eggs', lineno=1, col_offset=0)], " - "keywords=[], starargs=None, kwargs=None, lineno=1, " + "keywords=[], lineno=1, " "col_offset=0), lineno=1, col_offset=0)])" ) @@ -487,8 +499,7 @@ class ASTHelpers_Test(unittest.TestCase): node = ast.parse('foo()', mode='eval') d = dict(ast.iter_fields(node.body)) self.assertEqual(d.pop('func').id, 'foo') - self.assertEqual(d, {'keywords': [], 'kwargs': None, - 'args': [], 'starargs': None}) + self.assertEqual(d, {'keywords': [], 'args': []}) def test_iter_child_nodes(self): node = ast.parse("spam(23, 42, eggs='leek')", mode='eval') @@ -506,6 +517,9 @@ class ASTHelpers_Test(unittest.TestCase): self.assertEqual(ast.get_docstring(node.body[0]), 'line one\nline two') + node = ast.parse('async def foo():\n """spam\n ham"""') + self.assertEqual(ast.get_docstring(node.body[0]), 'spam\nham') + def test_literal_eval(self): self.assertEqual(ast.literal_eval('[1, 2, 3]'), [1, 2, 3]) self.assertEqual(ast.literal_eval('{"foo": 42}'), {"foo": 42}) @@ -604,8 +618,7 @@ class ASTValidatorTests(unittest.TestCase): self._check_arguments(fac, self.stmt) def test_classdef(self): - def cls(bases=None, keywords=None, starargs=None, kwargs=None, - body=None, decorator_list=None): + def cls(bases=None, keywords=None, body=None, decorator_list=None): if bases is None: bases = [] if keywords is None: @@ -614,16 +627,12 @@ class ASTValidatorTests(unittest.TestCase): body = [ast.Pass()] if decorator_list is None: decorator_list = [] - return ast.ClassDef("myclass", bases, keywords, starargs, - kwargs, body, decorator_list) + return ast.ClassDef("myclass", bases, keywords, + body, decorator_list) self.stmt(cls(bases=[ast.Name("x", ast.Store())]), "must have Load context") self.stmt(cls(keywords=[ast.keyword("x", ast.Name("x", ast.Store()))]), "must have Load context") - self.stmt(cls(starargs=ast.Name("x", ast.Store())), - "must have Load context") - self.stmt(cls(kwargs=ast.Name("x", ast.Store())), - "must have Load context") self.stmt(cls(body=[]), "empty body on ClassDef") self.stmt(cls(body=[None]), "None disallowed") self.stmt(cls(decorator_list=[ast.Name("x", ast.Store())]), @@ -777,8 +786,6 @@ class ASTValidatorTests(unittest.TestCase): def test_dict(self): d = ast.Dict([], [ast.Name("x", ast.Load())]) self.expr(d, "same number of keys as values") - d = ast.Dict([None], [ast.Name("x", ast.Load())]) - self.expr(d, "None disallowed") d = ast.Dict([ast.Name("x", ast.Load())], [None]) self.expr(d, "None disallowed") @@ -854,20 +861,12 @@ class ASTValidatorTests(unittest.TestCase): func = ast.Name("x", ast.Load()) args = [ast.Name("y", ast.Load())] keywords = [ast.keyword("w", ast.Name("z", ast.Load()))] - stararg = ast.Name("p", ast.Load()) - kwarg = ast.Name("q", ast.Load()) - call = ast.Call(ast.Name("x", ast.Store()), args, keywords, stararg, - kwarg) + call = ast.Call(ast.Name("x", ast.Store()), args, keywords) self.expr(call, "must have Load context") - call = ast.Call(func, [None], keywords, stararg, kwarg) + call = ast.Call(func, [None], keywords) self.expr(call, "None disallowed") bad_keywords = [ast.keyword("w", ast.Name("z", ast.Store()))] - call = ast.Call(func, args, bad_keywords, stararg, kwarg) - self.expr(call, "must have Load context") - call = ast.Call(func, args, keywords, ast.Name("z", ast.Store()), kwarg) - self.expr(call, "must have Load context") - call = ast.Call(func, args, keywords, stararg, - ast.Name("w", ast.Store())) + call = ast.Call(func, args, bad_keywords) self.expr(call, "must have Load context") def test_num(self): @@ -957,8 +956,8 @@ exec_results = [ ('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], ('arg', (1, 7), 'args', None), [], [], None, []), [('Pass', (1, 14))], [], None)]), ('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], ('arg', (1, 8), 'kwargs', None), []), [('Pass', (1, 17))], [], None)]), ('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [('arg', (1, 6), 'a', None), ('arg', (1, 9), 'b', None), ('arg', (1, 14), 'c', None), ('arg', (1, 22), 'd', None), ('arg', (1, 28), 'e', None)], ('arg', (1, 35), 'args', None), [('arg', (1, 41), 'f', None)], [('Num', (1, 43), 42)], ('arg', (1, 49), 'kwargs', None), [('Num', (1, 11), 1), ('NameConstant', (1, 16), None), ('List', (1, 24), [], ('Load',)), ('Dict', (1, 30), [], [])]), [('Pass', (1, 58))], [], None)]), -('Module', [('ClassDef', (1, 0), 'C', [], [], None, None, [('Pass', (1, 8))], [])]), -('Module', [('ClassDef', (1, 0), 'C', [('Name', (1, 8), 'object', ('Load',))], [], None, None, [('Pass', (1, 17))], [])]), +('Module', [('ClassDef', (1, 0), 'C', [], [], [('Pass', (1, 8))], [])]), +('Module', [('ClassDef', (1, 0), 'C', [('Name', (1, 8), 'object', ('Load',))], [], [('Pass', (1, 17))], [])]), ('Module', [('FunctionDef', (1, 0), 'f', ('arguments', [], None, [], [], None, []), [('Return', (1, 8), ('Num', (1, 15), 1))], [], None)]), ('Module', [('Delete', (1, 0), [('Name', (1, 4), 'v', ('Del',))])]), ('Module', [('Assign', (1, 0), [('Name', (1, 0), 'v', ('Store',))], ('Num', (1, 4), 1))]), @@ -968,7 +967,7 @@ exec_results = [ ('Module', [('If', (1, 0), ('Name', (1, 3), 'v', ('Load',)), [('Pass', (1, 5))], [])]), ('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',)))], [('Pass', (1, 13))])]), ('Module', [('With', (1, 0), [('withitem', ('Name', (1, 5), 'x', ('Load',)), ('Name', (1, 10), 'y', ('Store',))), ('withitem', ('Name', (1, 13), 'z', ('Load',)), ('Name', (1, 18), 'q', ('Store',)))], [('Pass', (1, 21))])]), -('Module', [('Raise', (1, 0), ('Call', (1, 6), ('Name', (1, 6), 'Exception', ('Load',)), [('Str', (1, 16), 'string')], [], None, None), None)]), +('Module', [('Raise', (1, 0), ('Call', (1, 6), ('Name', (1, 6), 'Exception', ('Load',)), [('Str', (1, 16), 'string')], []), None)]), ('Module', [('Try', (1, 0), [('Pass', (2, 2))], [('ExceptHandler', (3, 0), ('Name', (3, 7), 'Exception', ('Load',)), None, [('Pass', (4, 2))])], [], [])]), ('Module', [('Try', (1, 0), [('Pass', (2, 2))], [], [], [('Pass', (4, 2))])]), ('Module', [('Assert', (1, 0), ('Name', (1, 7), 'v', ('Load',)), None)]), @@ -977,17 +976,22 @@ exec_results = [ ('Module', [('Global', (1, 0), ['v'])]), ('Module', [('Expr', (1, 0), ('Num', (1, 0), 1))]), ('Module', [('Pass', (1, 0))]), -('Module', [('Break', (1, 0))]), -('Module', [('Continue', (1, 0))]), +('Module', [('For', (1, 0), ('Name', (1, 4), 'v', ('Store',)), ('Name', (1, 9), 'v', ('Load',)), [('Break', (1, 11))], [])]), +('Module', [('For', (1, 0), ('Name', (1, 4), 'v', ('Store',)), ('Name', (1, 9), 'v', ('Load',)), [('Continue', (1, 11))], [])]), ('Module', [('For', (1, 0), ('Tuple', (1, 4), [('Name', (1, 4), 'a', ('Store',)), ('Name', (1, 6), 'b', ('Store',))], ('Store',)), ('Name', (1, 11), 'c', ('Load',)), [('Pass', (1, 14))], [])]), ('Module', [('Expr', (1, 0), ('ListComp', (1, 1), ('Tuple', (1, 2), [('Name', (1, 2), 'a', ('Load',)), ('Name', (1, 4), 'b', ('Load',))], ('Load',)), [('comprehension', ('Tuple', (1, 11), [('Name', (1, 11), 'a', ('Store',)), ('Name', (1, 13), 'b', ('Store',))], ('Store',)), ('Name', (1, 18), 'c', ('Load',)), [])]))]), ('Module', [('Expr', (1, 0), ('GeneratorExp', (1, 1), ('Tuple', (1, 2), [('Name', (1, 2), 'a', ('Load',)), ('Name', (1, 4), 'b', ('Load',))], ('Load',)), [('comprehension', ('Tuple', (1, 11), [('Name', (1, 11), 'a', ('Store',)), ('Name', (1, 13), 'b', ('Store',))], ('Store',)), ('Name', (1, 18), 'c', ('Load',)), [])]))]), ('Module', [('Expr', (1, 0), ('GeneratorExp', (1, 1), ('Tuple', (1, 2), [('Name', (1, 2), 'a', ('Load',)), ('Name', (1, 4), 'b', ('Load',))], ('Load',)), [('comprehension', ('Tuple', (1, 12), [('Name', (1, 12), 'a', ('Store',)), ('Name', (1, 14), 'b', ('Store',))], ('Store',)), ('Name', (1, 20), 'c', ('Load',)), [])]))]), ('Module', [('Expr', (1, 0), ('GeneratorExp', (2, 4), ('Tuple', (3, 4), [('Name', (3, 4), 'Aa', ('Load',)), ('Name', (5, 7), 'Bb', ('Load',))], ('Load',)), [('comprehension', ('Tuple', (8, 4), [('Name', (8, 4), 'Aa', ('Store',)), ('Name', (10, 4), 'Bb', ('Store',))], ('Store',)), ('Name', (10, 10), 'Cc', ('Load',)), [])]))]), -('Module', [('Expr', (1, 0), ('DictComp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), ('Name', (1, 5), 'b', ('Load',)), [('comprehension', ('Name', (1, 11), 'w', ('Store',)), ('Name', (1, 16), 'x', ('Load',)), []), ('comprehension', ('Name', (1, 22), 'm', ('Store',)), ('Name', (1, 27), 'p', ('Load',)), [('Name', (1, 32), 'g', ('Load',))])]))]), -('Module', [('Expr', (1, 0), ('DictComp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), ('Name', (1, 5), 'b', ('Load',)), [('comprehension', ('Tuple', (1, 11), [('Name', (1, 11), 'v', ('Store',)), ('Name', (1, 13), 'w', ('Store',))], ('Store',)), ('Name', (1, 18), 'x', ('Load',)), [])]))]), -('Module', [('Expr', (1, 0), ('SetComp', (1, 1), ('Name', (1, 1), 'r', ('Load',)), [('comprehension', ('Name', (1, 7), 'l', ('Store',)), ('Name', (1, 12), 'x', ('Load',)), [('Name', (1, 17), 'g', ('Load',))])]))]), -('Module', [('Expr', (1, 0), ('SetComp', (1, 1), ('Name', (1, 1), 'r', ('Load',)), [('comprehension', ('Tuple', (1, 7), [('Name', (1, 7), 'l', ('Store',)), ('Name', (1, 9), 'm', ('Store',))], ('Store',)), ('Name', (1, 14), 'x', ('Load',)), [])]))]), +('Module', [('Expr', (1, 0), ('DictComp', (1, 0), ('Name', (1, 1), 'a', ('Load',)), ('Name', (1, 5), 'b', ('Load',)), [('comprehension', ('Name', (1, 11), 'w', ('Store',)), ('Name', (1, 16), 'x', ('Load',)), []), ('comprehension', ('Name', (1, 22), 'm', ('Store',)), ('Name', (1, 27), 'p', ('Load',)), [('Name', (1, 32), 'g', ('Load',))])]))]), +('Module', [('Expr', (1, 0), ('DictComp', (1, 0), ('Name', (1, 1), 'a', ('Load',)), ('Name', (1, 5), 'b', ('Load',)), [('comprehension', ('Tuple', (1, 11), [('Name', (1, 11), 'v', ('Store',)), ('Name', (1, 13), 'w', ('Store',))], ('Store',)), ('Name', (1, 18), 'x', ('Load',)), [])]))]), +('Module', [('Expr', (1, 0), ('SetComp', (1, 0), ('Name', (1, 1), 'r', ('Load',)), [('comprehension', ('Name', (1, 7), 'l', ('Store',)), ('Name', (1, 12), 'x', ('Load',)), [('Name', (1, 17), 'g', ('Load',))])]))]), +('Module', [('Expr', (1, 0), ('SetComp', (1, 0), ('Name', (1, 1), 'r', ('Load',)), [('comprehension', ('Tuple', (1, 7), [('Name', (1, 7), 'l', ('Store',)), ('Name', (1, 9), 'm', ('Store',))], ('Store',)), ('Name', (1, 14), 'x', ('Load',)), [])]))]), +('Module', [('AsyncFunctionDef', (1, 6), 'f', ('arguments', [], None, [], [], None, []), [('Expr', (2, 1), ('Await', (2, 1), ('Call', (2, 7), ('Name', (2, 7), 'something', ('Load',)), [], [])))], [], None)]), +('Module', [('AsyncFunctionDef', (1, 6), 'f', ('arguments', [], None, [], [], None, []), [('AsyncFor', (2, 7), ('Name', (2, 11), 'e', ('Store',)), ('Name', (2, 16), 'i', ('Load',)), [('Expr', (2, 19), ('Num', (2, 19), 1))], [('Expr', (3, 7), ('Num', (3, 7), 2))])], [], None)]), +('Module', [('AsyncFunctionDef', (1, 6), 'f', ('arguments', [], None, [], [], None, []), [('AsyncWith', (2, 7), [('withitem', ('Name', (2, 12), 'a', ('Load',)), ('Name', (2, 17), 'b', ('Store',)))], [('Expr', (2, 20), ('Num', (2, 20), 1))])], [], None)]), +('Module', [('Expr', (1, 0), ('Dict', (1, 0), [None, ('Num', (1, 10), 2)], [('Dict', (1, 3), [('Num', (1, 4), 1)], [('Num', (1, 6), 2)]), ('Num', (1, 12), 3)]))]), +('Module', [('Expr', (1, 0), ('Set', (1, 0), [('Starred', (1, 1), ('Set', (1, 2), [('Num', (1, 3), 1), ('Num', (1, 6), 2)]), ('Load',)), ('Num', (1, 10), 3)]))]), ] single_results = [ ('Interactive', [('Expr', (1, 0), ('BinOp', (1, 0), ('Num', (1, 0), 1), ('Add',), ('Num', (1, 2), 2)))]), @@ -1005,7 +1009,7 @@ eval_results = [ ('Expression', ('ListComp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), [('comprehension', ('Name', (1, 7), 'b', ('Store',)), ('Name', (1, 12), 'c', ('Load',)), [('Name', (1, 17), 'd', ('Load',))])])), ('Expression', ('GeneratorExp', (1, 1), ('Name', (1, 1), 'a', ('Load',)), [('comprehension', ('Name', (1, 7), 'b', ('Store',)), ('Name', (1, 12), 'c', ('Load',)), [('Name', (1, 17), 'd', ('Load',))])])), ('Expression', ('Compare', (1, 0), ('Num', (1, 0), 1), [('Lt',), ('Lt',)], [('Num', (1, 4), 2), ('Num', (1, 8), 3)])), -('Expression', ('Call', (1, 0), ('Name', (1, 0), 'f', ('Load',)), [('Num', (1, 2), 1), ('Num', (1, 4), 2)], [('keyword', 'c', ('Num', (1, 8), 3))], ('Name', (1, 11), 'd', ('Load',)), ('Name', (1, 15), 'e', ('Load',)))), +('Expression', ('Call', (1, 0), ('Name', (1, 0), 'f', ('Load',)), [('Num', (1, 2), 1), ('Num', (1, 4), 2), ('Starred', (1, 10), ('Name', (1, 11), 'd', ('Load',)), ('Load',))], [('keyword', 'c', ('Num', (1, 8), 3)), ('keyword', None, ('Name', (1, 15), 'e', ('Load',)))])), ('Expression', ('Num', (1, 0), 10)), ('Expression', ('Str', (1, 0), 'string')), ('Expression', ('Attribute', (1, 0), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',))), @@ -1016,6 +1020,6 @@ eval_results = [ ('Expression', ('Tuple', (1, 0), [('Num', (1, 0), 1), ('Num', (1, 2), 2), ('Num', (1, 4), 3)], ('Load',))), ('Expression', ('Tuple', (1, 1), [('Num', (1, 1), 1), ('Num', (1, 3), 2), ('Num', (1, 5), 3)], ('Load',))), ('Expression', ('Tuple', (1, 0), [], ('Load',))), -('Expression', ('Call', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',)), 'd', ('Load',)), [('Subscript', (1, 8), ('Attribute', (1, 8), ('Name', (1, 8), 'a', ('Load',)), 'b', ('Load',)), ('Slice', ('Num', (1, 12), 1), ('Num', (1, 14), 2), None), ('Load',))], [], None, None)), +('Expression', ('Call', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Attribute', (1, 0), ('Name', (1, 0), 'a', ('Load',)), 'b', ('Load',)), 'c', ('Load',)), 'd', ('Load',)), [('Subscript', (1, 8), ('Attribute', (1, 8), ('Name', (1, 8), 'a', ('Load',)), 'b', ('Load',)), ('Slice', ('Num', (1, 12), 1), ('Num', (1, 14), 2), None), ('Load',))], [])), ] main() diff --git a/Lib/test/test_asynchat.py b/Lib/test/test_asynchat.py index 2dc9d0c..3a33fc8 100644 --- a/Lib/test/test_asynchat.py +++ b/Lib/test/test_asynchat.py @@ -12,6 +12,7 @@ import socket import sys import time import unittest +import warnings import unittest.mock try: import threading @@ -38,7 +39,7 @@ if threading: self.start_resend_event = None def run(self): - self.sock.listen(1) + self.sock.listen() self.event.set() conn, client = self.sock.accept() self.buffer = b"" @@ -298,7 +299,10 @@ class TestHelperFunctions(unittest.TestCase): class TestFifo(unittest.TestCase): def test_basic(self): - f = asynchat.fifo() + with self.assertWarns(DeprecationWarning) as cm: + f = asynchat.fifo() + self.assertEqual(str(cm.warning), + "fifo class will be removed in Python 3.6") f.push(7) f.push(b'a') self.assertEqual(len(f), 2) @@ -313,7 +317,10 @@ class TestFifo(unittest.TestCase): self.assertEqual(f.pop(), (0, None)) def test_given_list(self): - f = asynchat.fifo([b'x', 17, 3]) + with self.assertWarns(DeprecationWarning) as cm: + f = asynchat.fifo([b'x', 17, 3]) + self.assertEqual(str(cm.warning), + "fifo class will be removed in Python 3.6") self.assertEqual(len(f), 3) self.assertEqual(f.pop(), (1, b'x')) self.assertEqual(f.pop(), (1, 17)) diff --git a/Lib/test/test_asyncio/test_pep492.py b/Lib/test/test_asyncio/test_pep492.py new file mode 100644 index 0000000..41e1b8a --- /dev/null +++ b/Lib/test/test_asyncio/test_pep492.py @@ -0,0 +1,208 @@ +"""Tests support for new syntax introduced by PEP 492.""" + +import collections.abc +import types +import unittest + +from test import support +from unittest import mock + +import asyncio +from asyncio import test_utils + + +class BaseTest(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.BaseEventLoop() + self.loop._process_events = mock.Mock() + self.loop._selector = mock.Mock() + self.loop._selector.select.return_value = () + self.set_event_loop(self.loop) + + +class LockTests(BaseTest): + + def test_context_manager_async_with(self): + primitives = [ + asyncio.Lock(loop=self.loop), + asyncio.Condition(loop=self.loop), + asyncio.Semaphore(loop=self.loop), + asyncio.BoundedSemaphore(loop=self.loop), + ] + + async def test(lock): + await asyncio.sleep(0.01, loop=self.loop) + self.assertFalse(lock.locked()) + async with lock as _lock: + self.assertIs(_lock, None) + self.assertTrue(lock.locked()) + await asyncio.sleep(0.01, loop=self.loop) + self.assertTrue(lock.locked()) + self.assertFalse(lock.locked()) + + for primitive in primitives: + self.loop.run_until_complete(test(primitive)) + self.assertFalse(primitive.locked()) + + def test_context_manager_with_await(self): + primitives = [ + asyncio.Lock(loop=self.loop), + asyncio.Condition(loop=self.loop), + asyncio.Semaphore(loop=self.loop), + asyncio.BoundedSemaphore(loop=self.loop), + ] + + async def test(lock): + await asyncio.sleep(0.01, loop=self.loop) + self.assertFalse(lock.locked()) + with await lock as _lock: + self.assertIs(_lock, None) + self.assertTrue(lock.locked()) + await asyncio.sleep(0.01, loop=self.loop) + self.assertTrue(lock.locked()) + self.assertFalse(lock.locked()) + + for primitive in primitives: + self.loop.run_until_complete(test(primitive)) + self.assertFalse(primitive.locked()) + + +class StreamReaderTests(BaseTest): + + def test_readline(self): + DATA = b'line1\nline2\nline3' + + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(DATA) + stream.feed_eof() + + async def reader(): + data = [] + async for line in stream: + data.append(line) + return data + + data = self.loop.run_until_complete(reader()) + self.assertEqual(data, [b'line1\n', b'line2\n', b'line3']) + + +class CoroutineTests(BaseTest): + + def test_iscoroutine(self): + async def foo(): pass + + f = foo() + try: + self.assertTrue(asyncio.iscoroutine(f)) + finally: + f.close() # silence warning + + # Test that asyncio.iscoroutine() uses collections.abc.Coroutine + class FakeCoro: + def send(self, value): pass + def throw(self, typ, val=None, tb=None): pass + def close(self): pass + def __await__(self): yield + + self.assertTrue(asyncio.iscoroutine(FakeCoro())) + + def test_iscoroutinefunction(self): + async def foo(): pass + self.assertTrue(asyncio.iscoroutinefunction(foo)) + + def test_function_returning_awaitable(self): + class Awaitable: + def __await__(self): + return ('spam',) + + @asyncio.coroutine + def func(): + return Awaitable() + + coro = func() + self.assertEqual(coro.send(None), 'spam') + coro.close() + + def test_async_def_coroutines(self): + async def bar(): + return 'spam' + async def foo(): + return await bar() + + # production mode + data = self.loop.run_until_complete(foo()) + self.assertEqual(data, 'spam') + + # debug mode + self.loop.set_debug(True) + data = self.loop.run_until_complete(foo()) + self.assertEqual(data, 'spam') + + @mock.patch('asyncio.coroutines.logger') + def test_async_def_wrapped(self, m_log): + async def foo(): + pass + async def start(): + foo_coro = foo() + self.assertRegex( + repr(foo_coro), + r'<CoroWrapper .*\.foo\(\) running at .*pep492.*>') + + with support.check_warnings((r'.*foo.*was never', + RuntimeWarning)): + foo_coro = None + support.gc_collect() + self.assertTrue(m_log.error.called) + message = m_log.error.call_args[0][0] + self.assertRegex(message, + r'CoroWrapper.*foo.*was never') + + self.loop.set_debug(True) + self.loop.run_until_complete(start()) + + async def start(): + foo_coro = foo() + task = asyncio.ensure_future(foo_coro, loop=self.loop) + self.assertRegex(repr(task), r'Task.*foo.*running') + + self.loop.run_until_complete(start()) + + + def test_types_coroutine(self): + def gen(): + yield from () + return 'spam' + + @types.coroutine + def func(): + return gen() + + async def coro(): + wrapper = func() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + return await wrapper + + data = self.loop.run_until_complete(coro()) + self.assertEqual(data, 'spam') + + def test_task_print_stack(self): + T = None + + async def foo(): + f = T.get_stack(limit=1) + try: + self.assertEqual(f[0].f_code.co_name, 'foo') + finally: + f = None + + async def runner(): + nonlocal T + T = asyncio.ensure_future(foo(), loop=self.loop) + await T + + self.loop.run_until_complete(runner()) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py index d44726d..3857916 100644 --- a/Lib/test/test_asyncore.py +++ b/Lib/test/test_asyncore.py @@ -7,7 +7,6 @@ import sys import time import errno import struct -import warnings from test import support from io import BytesIO @@ -65,7 +64,7 @@ class crashingdummy: # used when testing senders; just collects what it gets until newline is sent def capture_server(evt, buf, serv): try: - serv.listen(5) + serv.listen() conn, addr = serv.accept() except socket.timeout: pass @@ -298,23 +297,6 @@ class DispatcherTests(unittest.TestCase): 'warning: unhandled connect event'] self.assertEqual(lines, expected) - def test_issue_8594(self): - # XXX - this test is supposed to be removed in next major Python - # version - d = asyncore.dispatcher(socket.socket()) - # make sure the error message no longer refers to the socket - # object but the dispatcher instance instead - self.assertRaisesRegex(AttributeError, 'dispatcher instance', - getattr, d, 'foo') - # cheap inheritance with the underlying socket is supposed - # to still work but a DeprecationWarning is expected - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - family = d.family - self.assertEqual(family, socket.AF_INET) - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - def test_strerror(self): # refers to bug #8573 err = asyncore._strerror(errno.EPERM) @@ -331,9 +313,8 @@ class dispatcherwithsend_noread(asyncore.dispatcher_with_send): def handle_connect(self): pass -class DispatcherWithSendTests(unittest.TestCase): - usepoll = False +class DispatcherWithSendTests(unittest.TestCase): def setUp(self): pass @@ -383,10 +364,6 @@ class DispatcherWithSendTests(unittest.TestCase): self.fail("join() timed out") - -class DispatcherWithSendTests_UsePoll(DispatcherWithSendTests): - usepoll = True - @unittest.skipUnless(hasattr(asyncore, 'file_wrapper'), 'asyncore.file_wrapper required') class FileWrapperTest(unittest.TestCase): diff --git a/Lib/test/test_atexit.py b/Lib/test/test_atexit.py index 70d2f1c..172bd25 100644 --- a/Lib/test/test_atexit.py +++ b/Lib/test/test_atexit.py @@ -177,9 +177,5 @@ class SubinterpreterTest(unittest.TestCase): self.assertEqual(atexit._ncallbacks(), n) -def test_main(): - support.run_unittest(__name__) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_augassign.py b/Lib/test/test_augassign.py index 0e75c6b..5093e9d 100644 --- a/Lib/test/test_augassign.py +++ b/Lib/test/test_augassign.py @@ -1,6 +1,5 @@ # Augmented assignment test. -from test.support import run_unittest import unittest @@ -136,6 +135,14 @@ class AugAssignTest(unittest.TestCase): output.append("__imul__ called") return self + def __matmul__(self, val): + output.append("__matmul__ called") + def __rmatmul__(self, val): + output.append("__rmatmul__ called") + def __imatmul__(self, val): + output.append("__imatmul__ called") + return self + def __floordiv__(self, val): output.append("__floordiv__ called") return self @@ -225,6 +232,10 @@ class AugAssignTest(unittest.TestCase): 1 * x x *= 1 + x @ 1 + 1 @ x + x @= 1 + x / 1 1 / x x /= 1 @@ -271,6 +282,9 @@ __isub__ called __mul__ called __rmul__ called __imul__ called +__matmul__ called +__rmatmul__ called +__imatmul__ called __truediv__ called __rtruediv__ called __itruediv__ called @@ -300,8 +314,5 @@ __rlshift__ called __ilshift__ called '''.splitlines()) -def test_main(): - run_unittest(AugAssignTest) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py index b9738dd..a0f548d 100644 --- a/Lib/test/test_base64.py +++ b/Lib/test/test_base64.py @@ -3,10 +3,8 @@ from test import support import base64 import binascii import os -import sys -import subprocess -import struct from array import array +from test.support import script_helper class LegacyBase64TestCase(unittest.TestCase): @@ -622,15 +620,13 @@ class BaseXYTestCase(unittest.TestCase): self.assertTrue(issubclass(binascii.Error, ValueError)) - class TestMain(unittest.TestCase): def tearDown(self): if os.path.exists(support.TESTFN): os.unlink(support.TESTFN) - def get_output(self, *args, **options): - args = (sys.executable, '-m', 'base64') + args - return subprocess.check_output(args, **options) + def get_output(self, *args): + return script_helper.assert_python_ok('-m', 'base64', *args).out def test_encode_decode(self): output = self.get_output('-t') @@ -643,13 +639,14 @@ class TestMain(unittest.TestCase): def test_encode_file(self): with open(support.TESTFN, 'wb') as fp: fp.write(b'a\xffb\n') - output = self.get_output('-e', support.TESTFN) self.assertEqual(output.rstrip(), b'Yf9iCg==') - with open(support.TESTFN, 'rb') as fp: - output = self.get_output('-e', stdin=fp) - self.assertEqual(output.rstrip(), b'Yf9iCg==') + def test_encode_from_stdin(self): + with script_helper.spawn_python('-m', 'base64', '-e') as proc: + out, err = proc.communicate(b'a\xffb\n') + self.assertEqual(out.rstrip(), b'Yf9iCg==') + self.assertIsNone(err) def test_decode(self): with open(support.TESTFN, 'wb') as fp: @@ -657,10 +654,5 @@ class TestMain(unittest.TestCase): output = self.get_output('-d', support.TESTFN) self.assertEqual(output.rstrip(), b'a\xffb') - - -def test_main(): - support.run_unittest(__name__) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_binascii.py b/Lib/test/test_binascii.py index 389daa0..8367afe 100644 --- a/Lib/test/test_binascii.py +++ b/Lib/test/test_binascii.py @@ -1,6 +1,5 @@ """Test the binascii C module.""" -from test import support import unittest import binascii import array @@ -277,11 +276,5 @@ class MemoryviewBinASCIITest(BinASCIITest): type2test = memoryview -def test_main(): - support.run_unittest(BinASCIITest, - ArrayBinASCIITest, - BytearrayBinASCIITest, - MemoryviewBinASCIITest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_binop.py b/Lib/test/test_binop.py index 963aa01..823740c 100644 --- a/Lib/test/test_binop.py +++ b/Lib/test/test_binop.py @@ -389,9 +389,5 @@ class OperationOrderTests(unittest.TestCase): self.assertEqual(op_sequence(le, B, V), ['B.__le__', 'V.__ge__']) -def test_main(): - support.run_unittest(RatTestCase, OperationOrderTests) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py index aa15377..a653390 100644 --- a/Lib/test/test_buffer.py +++ b/Lib/test/test_buffer.py @@ -11,6 +11,7 @@ # memoryview tests is now in this module. # +import contextlib import unittest from test import support from itertools import permutations, product @@ -216,7 +217,7 @@ def iter_format(nitems, testobj='ndarray'): for t in iter_mode(nitems, testobj): yield t if testobj != 'ndarray': - raise StopIteration + return yield struct_items(nitems, testobj) @@ -1007,6 +1008,7 @@ class TestBufferProtocol(unittest.TestCase): # shape, strides, offset structure = ( ([], [], 0), + ([1,3,1], [], 0), ([12], [], 0), ([12], [-1], 11), ([6], [2], 0), @@ -1078,6 +1080,18 @@ class TestBufferProtocol(unittest.TestCase): self.assertRaises(BufferError, ndarray, ex, getbuf=PyBUF_ANY_CONTIGUOUS) nd = ndarray(ex, getbuf=PyBUF_SIMPLE) + # Issue #22445: New precise contiguity definition. + for shape in [1,12,1], [7,0,7]: + for order in 0, ND_FORTRAN: + ex = ndarray(items, shape=shape, flags=order|ND_WRITABLE) + self.assertTrue(is_contiguous(ex, 'F')) + self.assertTrue(is_contiguous(ex, 'C')) + + for flags in requests: + nd = ndarray(ex, getbuf=flags) + self.assertTrue(is_contiguous(nd, 'F')) + self.assertTrue(is_contiguous(nd, 'C')) + def test_ndarray_exceptions(self): nd = ndarray([9], [1]) ndm = ndarray([9], [1], flags=ND_VAREXPORT) @@ -2454,7 +2468,7 @@ class TestBufferProtocol(unittest.TestCase): def test_memoryview_sizeof(self): check = self.check_sizeof vsize = support.calcvobjsize - base_struct = 'Pnin 2P2n2i5P 3cP' + base_struct = 'Pnin 2P2n2i5P P' per_dim = '3n' items = list(range(8)) @@ -2545,8 +2559,7 @@ class TestBufferProtocol(unittest.TestCase): ex = ndarray(sitems, shape=[1], format=sfmt) msrc = memoryview(ex) for dfmt, _, _ in iter_format(1): - if (not is_memoryview_format(sfmt) or - not is_memoryview_format(dfmt)): + if not is_memoryview_format(dfmt): self.assertRaises(ValueError, msrc.cast, dfmt, [32//dsize]) else: @@ -2759,6 +2772,32 @@ class TestBufferProtocol(unittest.TestCase): ndim=ndim, shape=shape, strides=strides, lst=lst, cast=True) + if ctypes: + # format: "T{>l:x:>d:y:}" + class BEPoint(ctypes.BigEndianStructure): + _fields_ = [("x", ctypes.c_long), ("y", ctypes.c_double)] + point = BEPoint(100, 200.1) + m1 = memoryview(point) + m2 = m1.cast('B') + self.assertEqual(m2.obj, point) + self.assertEqual(m2.itemsize, 1) + self.assertEqual(m2.readonly, 0) + self.assertEqual(m2.ndim, 1) + self.assertEqual(m2.shape, (m2.nbytes,)) + self.assertEqual(m2.strides, (1,)) + self.assertEqual(m2.suboffsets, ()) + + x = ctypes.c_double(1.2) + m1 = memoryview(x) + m2 = m1.cast('c') + self.assertEqual(m2.obj, x) + self.assertEqual(m2.itemsize, 1) + self.assertEqual(m2.readonly, 0) + self.assertEqual(m2.ndim, 1) + self.assertEqual(m2.shape, (m2.nbytes,)) + self.assertEqual(m2.strides, (1,)) + self.assertEqual(m2.suboffsets, ()) + def test_memoryview_tolist(self): # Most tolist() tests are in self.verify() etc. @@ -2812,6 +2851,13 @@ class TestBufferProtocol(unittest.TestCase): m = memoryview(ex) self.assertRaises(TypeError, eval, "9.0 in m", locals()) + @contextlib.contextmanager + def assert_out_of_bounds_error(self, dim): + with self.assertRaises(IndexError) as cm: + yield + self.assertEqual(str(cm.exception), + "index out of bounds on dimension %d" % (dim,)) + def test_memoryview_index(self): # ndim = 0 @@ -2838,12 +2884,31 @@ class TestBufferProtocol(unittest.TestCase): self.assertRaises(IndexError, m.__getitem__, -8) self.assertRaises(IndexError, m.__getitem__, 8) - # Not implemented: multidimensional sub-views + # multi-dimensional ex = ndarray(list(range(12)), shape=[3,4], flags=ND_WRITABLE) m = memoryview(ex) - self.assertRaises(NotImplementedError, m.__getitem__, 0) - self.assertRaises(NotImplementedError, m.__setitem__, 0, 9) + self.assertEqual(m[0, 0], 0) + self.assertEqual(m[2, 0], 8) + self.assertEqual(m[2, 3], 11) + self.assertEqual(m[-1, -1], 11) + self.assertEqual(m[-3, -4], 0) + + # out of bounds + for index in (3, -4): + with self.assert_out_of_bounds_error(dim=1): + m[index, 0] + for index in (4, -5): + with self.assert_out_of_bounds_error(dim=2): + m[0, index] + self.assertRaises(IndexError, m.__getitem__, (2**64, 0)) + self.assertRaises(IndexError, m.__getitem__, (0, 2**64)) + + self.assertRaises(TypeError, m.__getitem__, (0, 0, 0)) + self.assertRaises(TypeError, m.__getitem__, (0.0, 0.0)) + + # Not implemented: multidimensional sub-views + self.assertRaises(NotImplementedError, m.__getitem__, ()) self.assertRaises(NotImplementedError, m.__getitem__, 0) def test_memoryview_assign(self): @@ -2932,10 +2997,27 @@ class TestBufferProtocol(unittest.TestCase): m = memoryview(ex) self.assertRaises(NotImplementedError, m.__setitem__, 0, 1) - # Not implemented: multidimensional sub-views + # multi-dimensional ex = ndarray(list(range(12)), shape=[3,4], flags=ND_WRITABLE) m = memoryview(ex) + m[0,1] = 42 + self.assertEqual(ex[0][1], 42) + m[-1,-1] = 43 + self.assertEqual(ex[2][3], 43) + # errors + for index in (3, -4): + with self.assert_out_of_bounds_error(dim=1): + m[index, 0] = 0 + for index in (4, -5): + with self.assert_out_of_bounds_error(dim=2): + m[0, index] = 0 + self.assertRaises(IndexError, m.__setitem__, (2**64, 0), 0) + self.assertRaises(IndexError, m.__setitem__, (0, 2**64), 0) + + self.assertRaises(TypeError, m.__setitem__, (0, 0, 0), 0) + self.assertRaises(TypeError, m.__setitem__, (0.0, 0.0), 0) + # Not implemented: multidimensional sub-views self.assertRaises(NotImplementedError, m.__setitem__, 0, [2, 3]) def test_memoryview_slice(self): @@ -2948,8 +3030,8 @@ class TestBufferProtocol(unittest.TestCase): self.assertRaises(ValueError, m.__setitem__, slice(0,2,0), bytearray([1,2])) - # invalid slice key - self.assertRaises(TypeError, m.__getitem__, ()) + # 0-dim slicing (identity function) + self.assertRaises(NotImplementedError, m.__getitem__, ()) # multidimensional slices ex = ndarray(list(range(12)), shape=[12], flags=ND_WRITABLE) diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 14366c6..cdbb2cb 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -16,7 +16,7 @@ import unittest import warnings from operator import neg from test.support import TESTFN, unlink, run_unittest, check_warnings -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok try: import pty, signal except ImportError: @@ -312,11 +312,11 @@ class BuiltinTest(unittest.TestCase): self.assertRaises(TypeError, compile) self.assertRaises(ValueError, compile, 'print(42)\n', '<string>', 'badmode') self.assertRaises(ValueError, compile, 'print(42)\n', '<string>', 'single', 0xff) - self.assertRaises(TypeError, compile, chr(0), 'f', 'exec') + self.assertRaises(ValueError, compile, chr(0), 'f', 'exec') self.assertRaises(TypeError, compile, 'pass', '?', 'exec', mode='eval', source='0', filename='tmp') compile('print("\xe5")\n', '', 'exec') - self.assertRaises(TypeError, compile, chr(0), 'f', 'exec') + self.assertRaises(ValueError, compile, chr(0), 'f', 'exec') self.assertRaises(ValueError, compile, str('a = 1'), 'f', 'bad') # test the optimize argument @@ -1094,7 +1094,7 @@ class BuiltinTest(unittest.TestCase): self.assertAlmostEqual(pow(-1, 0.5), 1j) self.assertAlmostEqual(pow(-1, 1/3), 0.5 + 0.8660254037844386j) - self.assertRaises(TypeError, pow, -1, -2, 3) + self.assertRaises(ValueError, pow, -1, -2, 3) self.assertRaises(ValueError, pow, 1, 2, 0) self.assertRaises(TypeError, pow) diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index b00573f..53a80f4 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -301,6 +301,14 @@ class BaseBytesTest: self.assertRaises(ValueError, self.type2test.fromhex, '\x00') self.assertRaises(ValueError, self.type2test.fromhex, '12 \x00 34') + def test_hex(self): + self.assertRaises(TypeError, self.type2test.hex) + self.assertRaises(TypeError, self.type2test.hex, 1) + self.assertEqual(self.type2test(b"").hex(), "") + self.assertEqual(bytearray([0x1a, 0x2b, 0x30]).hex(), '1a2b30') + self.assertEqual(self.type2test(b"\x1a\x2b\x30").hex(), '1a2b30') + self.assertEqual(memoryview(b"\x1a\x2b\x30").hex(), '1a2b30') + def test_join(self): self.assertEqual(self.type2test(b"").join([]), b"") self.assertEqual(self.type2test(b"").join([b""]), b"") @@ -461,6 +469,28 @@ class BaseBytesTest: self.assertEqual(b.rindex(i, 3, 9), 7) self.assertRaises(ValueError, b.rindex, w, 1, 3) + def test_mod(self): + b = b'hello, %b!' + orig = b + b = b % b'world' + self.assertEqual(b, b'hello, world!') + self.assertEqual(orig, b'hello, %b!') + self.assertFalse(b is orig) + b = b'%s / 100 = %d%%' + a = b % (b'seventy-nine', 79) + self.assertEqual(a, b'seventy-nine / 100 = 79%') + + def test_imod(self): + b = b'hello, %b!' + orig = b + b %= b'world' + self.assertEqual(b, b'hello, world!') + self.assertEqual(orig, b'hello, %b!') + self.assertFalse(b is orig) + b = b'%s / 100 = %d%%' + b %= (b'seventy-nine', 79) + self.assertEqual(b, b'seventy-nine / 100 = 79%') + def test_replace(self): b = self.type2test(b'mississippi') self.assertEqual(b.replace(b'i', b'a'), b'massassappa') @@ -722,6 +752,11 @@ class BaseBytesTest: class BytesTest(BaseBytesTest, unittest.TestCase): type2test = bytes + def test_getitem_error(self): + msg = "byte indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + b'python'['a'] + def test_buffer_is_readonly(self): fd = os.open(__file__, os.O_RDONLY) with open(fd, "rb", buffering=0) as f: @@ -776,6 +811,17 @@ class BytesTest(BaseBytesTest, unittest.TestCase): class ByteArrayTest(BaseBytesTest, unittest.TestCase): type2test = bytearray + def test_getitem_error(self): + msg = "bytearray indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + bytearray(b'python')['a'] + + def test_setitem_error(self): + msg = "bytearray indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + b = bytearray(b'python') + b['a'] = "python" + def test_nohash(self): self.assertRaises(TypeError, hash, bytearray()) @@ -990,6 +1036,28 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase): b[8:] = b self.assertEqual(b, bytearray(list(range(8)) + list(range(256)))) + def test_mod(self): + b = bytearray(b'hello, %b!') + orig = b + b = b % b'world' + self.assertEqual(b, b'hello, world!') + self.assertEqual(orig, bytearray(b'hello, %b!')) + self.assertFalse(b is orig) + b = bytearray(b'%s / 100 = %d%%') + a = b % (b'seventy-nine', 79) + self.assertEqual(a, bytearray(b'seventy-nine / 100 = 79%')) + + def test_imod(self): + b = bytearray(b'hello, %b!') + orig = b + b %= b'world' + self.assertEqual(b, b'hello, world!') + self.assertEqual(orig, bytearray(b'hello, %b!')) + self.assertFalse(b is orig) + b = bytearray(b'%s / 100 = %d%%') + b %= (b'seventy-nine', 79) + self.assertEqual(b, bytearray(b'seventy-nine / 100 = 79%')) + def test_iconcat(self): b = bytearray(b"abc") b1 = b @@ -1197,6 +1265,10 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase): self.assertRaises(BufferError, delslice) self.assertEqual(b, orig) + @test.support.cpython_only + def test_obsolete_write_lock(self): + from _testcapi import getbuffer_with_null_view + self.assertRaises(BufferError, getbuffer_with_null_view, bytearray()) class AssortedBytesTest(unittest.TestCase): # @@ -1307,20 +1379,35 @@ class AssortedBytesTest(unittest.TestCase): b = bytearray() self.assertFalse(b.replace(b'', b'') is b) + @unittest.skipUnless(sys.flags.bytes_warning, + "BytesWarning is needed for this test: use -bb option") def test_compare(self): - if sys.flags.bytes_warning: - def bytes_warning(): - return test.support.check_warnings(('', BytesWarning)) - with bytes_warning(): - b'' == '' - with bytes_warning(): - b'' != '' - with bytes_warning(): - bytearray(b'') == '' - with bytes_warning(): - bytearray(b'') != '' - else: - self.skipTest("BytesWarning is needed for this test: use -bb option") + def bytes_warning(): + return test.support.check_warnings(('', BytesWarning)) + with bytes_warning(): + b'' == '' + with bytes_warning(): + '' == b'' + with bytes_warning(): + b'' != '' + with bytes_warning(): + '' != b'' + with bytes_warning(): + bytearray(b'') == '' + with bytes_warning(): + '' == bytearray(b'') + with bytes_warning(): + bytearray(b'') != '' + with bytes_warning(): + '' != bytearray(b'') + with bytes_warning(): + b'\0' == 0 + with bytes_warning(): + 0 == b'\0' + with bytes_warning(): + b'\0' != 0 + with bytes_warning(): + 0 != b'\0' # Optimizations: # __iter__? (optimization) diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index beef275..a1e4b8d 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -2,13 +2,15 @@ from test import support from test.support import bigmemtest, _4G import unittest -from io import BytesIO +from io import BytesIO, DEFAULT_BUFFER_SIZE import os import pickle +import glob import random import subprocess import sys from test.support import unlink +import _compression try: import threading @@ -51,6 +53,19 @@ class BaseTest(unittest.TestCase): EMPTY_DATA = b'BZh9\x17rE8P\x90\x00\x00\x00\x00' BAD_DATA = b'this is not a valid bzip2 file' + # Some tests need more than one block of uncompressed data. Since one block + # is at least 100 kB, we gather some data dynamically and compress it. + # Note that this assumes that compression works correctly, so we cannot + # simply use the bigger test data for all tests. + test_size = 0 + BIG_TEXT = bytearray(128*1024) + for fname in glob.glob(os.path.join(os.path.dirname(__file__), '*.py')): + with open(fname, 'rb') as fh: + test_size += fh.readinto(memoryview(BIG_TEXT)[test_size:]) + if test_size > 128*1024: + break + BIG_DATA = bz2.compress(BIG_TEXT, compresslevel=1) + def setUp(self): self.filename = support.TESTFN @@ -96,7 +111,7 @@ class BZ2FileTest(BaseTest): def testRead(self): self.createTempFile() with BZ2File(self.filename) as bz2f: - self.assertRaises(TypeError, bz2f.read, None) + self.assertRaises(TypeError, bz2f.read, float()) self.assertEqual(bz2f.read(), self.TEXT) def testReadBadFile(self): @@ -107,21 +122,21 @@ class BZ2FileTest(BaseTest): def testReadMultiStream(self): self.createTempFile(streams=5) with BZ2File(self.filename) as bz2f: - self.assertRaises(TypeError, bz2f.read, None) + self.assertRaises(TypeError, bz2f.read, float()) self.assertEqual(bz2f.read(), self.TEXT * 5) def testReadMonkeyMultiStream(self): # Test BZ2File.read() on a multi-stream archive where a stream # boundary coincides with the end of the raw read buffer. - buffer_size = bz2._BUFFER_SIZE - bz2._BUFFER_SIZE = len(self.DATA) + buffer_size = _compression.BUFFER_SIZE + _compression.BUFFER_SIZE = len(self.DATA) try: self.createTempFile(streams=5) with BZ2File(self.filename) as bz2f: - self.assertRaises(TypeError, bz2f.read, None) + self.assertRaises(TypeError, bz2f.read, float()) self.assertEqual(bz2f.read(), self.TEXT * 5) finally: - bz2._BUFFER_SIZE = buffer_size + _compression.BUFFER_SIZE = buffer_size def testReadTrailingJunk(self): self.createTempFile(suffix=self.BAD_DATA) @@ -136,7 +151,7 @@ class BZ2FileTest(BaseTest): def testRead0(self): self.createTempFile() with BZ2File(self.filename) as bz2f: - self.assertRaises(TypeError, bz2f.read, None) + self.assertRaises(TypeError, bz2f.read, float()) self.assertEqual(bz2f.read(0), b"") def testReadChunk10(self): @@ -545,13 +560,24 @@ class BZ2FileTest(BaseTest): with BZ2File(str_filename, "rb") as f: self.assertEqual(f.read(), self.DATA) + def testDecompressLimited(self): + """Decompressed data buffering should be limited""" + bomb = bz2.compress(bytes(int(2e6)), compresslevel=9) + self.assertLess(len(bomb), _compression.BUFFER_SIZE) + + decomp = BZ2File(BytesIO(bomb)) + self.assertEqual(bytes(1), decomp.read(1)) + max_decomp = 1 + DEFAULT_BUFFER_SIZE + self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, + "Excessive amount of data was decompressed") + # Tests for a BZ2File wrapping another file object: def testReadBytesIO(self): with BytesIO(self.DATA) as bio: with BZ2File(bio) as bz2f: - self.assertRaises(TypeError, bz2f.read, None) + self.assertRaises(TypeError, bz2f.read, float()) self.assertEqual(bz2f.read(), self.TEXT) self.assertFalse(bio.closed) @@ -705,6 +731,95 @@ class BZ2DecompressorTest(BaseTest): with self.assertRaises(TypeError): pickle.dumps(BZ2Decompressor(), proto) + def testDecompressorChunksMaxsize(self): + bzd = BZ2Decompressor() + max_length = 100 + out = [] + + # Feed some input + len_ = len(self.BIG_DATA) - 64 + out.append(bzd.decompress(self.BIG_DATA[:len_], + max_length=max_length)) + self.assertFalse(bzd.needs_input) + self.assertEqual(len(out[-1]), max_length) + + # Retrieve more data without providing more input + out.append(bzd.decompress(b'', max_length=max_length)) + self.assertFalse(bzd.needs_input) + self.assertEqual(len(out[-1]), max_length) + + # Retrieve more data while providing more input + out.append(bzd.decompress(self.BIG_DATA[len_:], + max_length=max_length)) + self.assertLessEqual(len(out[-1]), max_length) + + # Retrieve remaining uncompressed data + while not bzd.eof: + out.append(bzd.decompress(b'', max_length=max_length)) + self.assertLessEqual(len(out[-1]), max_length) + + out = b"".join(out) + self.assertEqual(out, self.BIG_TEXT) + self.assertEqual(bzd.unused_data, b"") + + def test_decompressor_inputbuf_1(self): + # Test reusing input buffer after moving existing + # contents to beginning + bzd = BZ2Decompressor() + out = [] + + # Create input buffer and fill it + self.assertEqual(bzd.decompress(self.DATA[:100], + max_length=0), b'') + + # Retrieve some results, freeing capacity at beginning + # of input buffer + out.append(bzd.decompress(b'', 2)) + + # Add more data that fits into input buffer after + # moving existing data to beginning + out.append(bzd.decompress(self.DATA[100:105], 15)) + + # Decompress rest of data + out.append(bzd.decompress(self.DATA[105:])) + self.assertEqual(b''.join(out), self.TEXT) + + def test_decompressor_inputbuf_2(self): + # Test reusing input buffer by appending data at the + # end right away + bzd = BZ2Decompressor() + out = [] + + # Create input buffer and empty it + self.assertEqual(bzd.decompress(self.DATA[:200], + max_length=0), b'') + out.append(bzd.decompress(b'')) + + # Fill buffer with new data + out.append(bzd.decompress(self.DATA[200:280], 2)) + + # Append some more data, not enough to require resize + out.append(bzd.decompress(self.DATA[280:300], 2)) + + # Decompress rest of data + out.append(bzd.decompress(self.DATA[300:])) + self.assertEqual(b''.join(out), self.TEXT) + + def test_decompressor_inputbuf_3(self): + # Test reusing input buffer after extending it + + bzd = BZ2Decompressor() + out = [] + + # Create almost full input buffer + out.append(bzd.decompress(self.DATA[:200], 5)) + + # Add even more data to it, requiring resize + out.append(bzd.decompress(self.DATA[200:300], 5)) + + # Decompress rest of data + out.append(bzd.decompress(self.DATA[300:])) + self.assertEqual(b''.join(out), self.TEXT) class CompressDecompressTest(BaseTest): def testCompress(self): diff --git a/Lib/test/test_calendar.py b/Lib/test/test_calendar.py index 9193857..80ed632 100644 --- a/Lib/test/test_calendar.py +++ b/Lib/test/test_calendar.py @@ -2,7 +2,7 @@ import calendar import unittest from test import support -from test.script_helper import assert_python_ok, assert_python_failure +from test.support.script_helper import assert_python_ok, assert_python_failure import time import locale import sys diff --git a/Lib/test/test_call.py b/Lib/test/test_call.py index c00ccba..e2b8e0f 100644 --- a/Lib/test/test_call.py +++ b/Lib/test/test_call.py @@ -1,5 +1,4 @@ import unittest -from test import support # The test cases here cover several paths through the function calling # code. They depend on the METH_XXX flag that is used to define a C @@ -123,9 +122,5 @@ class CFunctionCalls(unittest.TestCase): self.assertRaises(TypeError, [].count, x=2, y=2) -def test_main(): - support.run_unittest(CFunctionCalls) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_capi.py b/Lib/test/test_capi.py index 36c62376..eae3add 100644 --- a/Lib/test/test_capi.py +++ b/Lib/test/test_capi.py @@ -6,10 +6,12 @@ import pickle import random import subprocess import sys +import textwrap import time import unittest from test import support from test.support import MISSING_C_DOCSTRINGS +from test.support.script_helper import assert_python_failure try: import _posixsubprocess except ImportError: @@ -21,6 +23,9 @@ except ImportError: # Skip this test if the _testcapi module isn't available. _testcapi = support.import_module('_testcapi') +# Were we compiled --with-pydebug or with #define Py_DEBUG? +Py_DEBUG = hasattr(sys, 'gettotalrefcount') + def testfunction(self): """some doc""" @@ -118,7 +123,7 @@ class CAPITest(unittest.TestCase): self.assertEqual(_testcapi.no_docstring.__doc__, None) self.assertEqual(_testcapi.no_docstring.__text_signature__, None) - self.assertEqual(_testcapi.docstring_empty.__doc__, "") + self.assertEqual(_testcapi.docstring_empty.__doc__, None) self.assertEqual(_testcapi.docstring_empty.__text_signature__, None) self.assertEqual(_testcapi.docstring_no_signature.__doc__, @@ -145,11 +150,92 @@ class CAPITest(unittest.TestCase): "This docstring has a valid signature.") self.assertEqual(_testcapi.docstring_with_signature.__text_signature__, "($module, /, sig)") + self.assertEqual(_testcapi.docstring_with_signature_but_no_doc.__doc__, None) + self.assertEqual(_testcapi.docstring_with_signature_but_no_doc.__text_signature__, + "($module, /, sig)") + self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__doc__, "\nThis docstring has a valid signature and some extra newlines.") self.assertEqual(_testcapi.docstring_with_signature_and_extra_newlines.__text_signature__, "($module, /, parameter)") + def test_c_type_with_matrix_multiplication(self): + M = _testcapi.matmulType + m1 = M() + m2 = M() + self.assertEqual(m1 @ m2, ("matmul", m1, m2)) + self.assertEqual(m1 @ 42, ("matmul", m1, 42)) + self.assertEqual(42 @ m1, ("matmul", 42, m1)) + o = m1 + o @= m2 + self.assertEqual(o, ("imatmul", m1, m2)) + o = m1 + o @= 42 + self.assertEqual(o, ("imatmul", m1, 42)) + o = 42 + o @= m1 + self.assertEqual(o, ("matmul", 42, m1)) + + def test_return_null_without_error(self): + # Issue #23571: A function must not return NULL without setting an + # error + if Py_DEBUG: + code = textwrap.dedent(""" + import _testcapi + from test import support + + with support.SuppressCrashReport(): + _testcapi.return_null_without_error() + """) + rc, out, err = assert_python_failure('-c', code) + self.assertRegex(err.replace(b'\r', b''), + br'Fatal Python error: a function returned NULL ' + br'without setting an error\n' + br'SystemError: <built-in function ' + br'return_null_without_error> returned NULL ' + br'without setting an error\n' + br'\n' + br'Current thread.*:\n' + br' File .*", line 6 in <module>') + else: + with self.assertRaises(SystemError) as cm: + _testcapi.return_null_without_error() + self.assertRegex(str(cm.exception), + 'return_null_without_error.* ' + 'returned NULL without setting an error') + + def test_return_result_with_error(self): + # Issue #23571: A function must not return a result with an error set + if Py_DEBUG: + code = textwrap.dedent(""" + import _testcapi + from test import support + + with support.SuppressCrashReport(): + _testcapi.return_result_with_error() + """) + rc, out, err = assert_python_failure('-c', code) + self.assertRegex(err.replace(b'\r', b''), + br'Fatal Python error: a function returned a ' + br'result with an error set\n' + br'ValueError\n' + br'\n' + br'During handling of the above exception, ' + br'another exception occurred:\n' + br'\n' + br'SystemError: <built-in ' + br'function return_result_with_error> ' + br'returned a result with an error set\n' + br'\n' + br'Current thread.*:\n' + br' File .*, line 6 in <module>') + else: + with self.assertRaises(SystemError) as cm: + _testcapi.return_result_with_error() + self.assertRegex(str(cm.exception), + 'return_result_with_error.* ' + 'returned a result with an error set') + @unittest.skipUnless(threading, 'Threading required for this test.') class TestPendingCalls(unittest.TestCase): @@ -264,7 +350,7 @@ class EmbeddingTests(unittest.TestCase): exename += ext exepath = os.path.dirname(sys.executable) else: - exepath = os.path.join(basepath, "Modules") + exepath = os.path.join(basepath, "Programs") self.test_exe = exe = os.path.join(exepath, exename) if not os.path.exists(exe): self.skipTest("%r doesn't exist" % exe) @@ -283,12 +369,13 @@ class EmbeddingTests(unittest.TestCase): cmd.extend(args) p = subprocess.Popen(cmd, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + stderr=subprocess.PIPE, + universal_newlines=True) (out, err) = p.communicate() self.assertEqual(p.returncode, 0, "bad returncode %d, stderr is %r" % (p.returncode, err)) - return out.decode("latin1"), err.decode("latin1") + return out, err def test_subinterps(self): # This is just a "don't crash" test @@ -315,34 +402,38 @@ class EmbeddingTests(unittest.TestCase): print() print(out) print(err) + expected_errors = sys.__stdout__.errors expected_stdin_encoding = sys.__stdin__.encoding expected_pipe_encoding = self._get_default_pipe_encoding() - expected_output = os.linesep.join([ + expected_output = '\n'.join([ "--- Use defaults ---", "Expected encoding: default", "Expected errors: default", - "stdin: {0}:strict", - "stdout: {1}:strict", - "stderr: {1}:backslashreplace", + "stdin: {in_encoding}:{errors}", + "stdout: {out_encoding}:{errors}", + "stderr: {out_encoding}:backslashreplace", "--- Set errors only ---", "Expected encoding: default", - "Expected errors: surrogateescape", - "stdin: {0}:surrogateescape", - "stdout: {1}:surrogateescape", - "stderr: {1}:backslashreplace", + "Expected errors: ignore", + "stdin: {in_encoding}:ignore", + "stdout: {out_encoding}:ignore", + "stderr: {out_encoding}:backslashreplace", "--- Set encoding only ---", "Expected encoding: latin-1", "Expected errors: default", - "stdin: latin-1:strict", - "stdout: latin-1:strict", + "stdin: latin-1:{errors}", + "stdout: latin-1:{errors}", "stderr: latin-1:backslashreplace", "--- Set encoding and errors ---", "Expected encoding: latin-1", - "Expected errors: surrogateescape", - "stdin: latin-1:surrogateescape", - "stdout: latin-1:surrogateescape", - "stderr: latin-1:backslashreplace"]).format(expected_stdin_encoding, - expected_pipe_encoding) + "Expected errors: replace", + "stdin: latin-1:replace", + "stdout: latin-1:replace", + "stderr: latin-1:backslashreplace"]) + expected_output = expected_output.format( + in_encoding=expected_stdin_encoding, + out_encoding=expected_pipe_encoding, + errors=expected_errors) # This is useful if we ever trip over odd platform behaviour self.maxDiff = None self.assertEqual(out.strip(), expected_output) diff --git a/Lib/test/test_cgi.py b/Lib/test/test_cgi.py index 6b28106..ab9f6ab 100644 --- a/Lib/test/test_cgi.py +++ b/Lib/test/test_cgi.py @@ -1,4 +1,4 @@ -from test.support import run_unittest, check_warnings +from test.support import check_warnings import cgi import os import sys @@ -344,6 +344,16 @@ Larry self.assertEqual(fs.list[0].name, 'submit-name') self.assertEqual(fs.list[0].value, 'Larry') + def test_fieldstorage_as_context_manager(self): + fp = BytesIO(b'x' * 10) + env = {'REQUEST_METHOD': 'PUT'} + with cgi.FieldStorage(fp=fp, environ=env) as fs: + content = fs.file.read() + self.assertFalse(fs.file.closed) + self.assertTrue(fs.file.closed) + self.assertEqual(content, 'x' * 10) + with self.assertRaisesRegex(ValueError, 'I/O operation on closed file'): + fs.file.read() _qs_result = { 'key1': 'value1', @@ -519,9 +529,5 @@ Content-Transfer-Encoding: binary --AaB03x-- """ - -def test_main(): - run_unittest(CgiTests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_cgitb.py b/Lib/test/test_cgitb.py index 2e072a9..a87a422 100644 --- a/Lib/test/test_cgitb.py +++ b/Lib/test/test_cgitb.py @@ -1,5 +1,5 @@ -from test.support import run_unittest -from test.script_helper import assert_python_failure, temp_dir +from test.support import temp_dir +from test.support.script_helper import assert_python_failure import unittest import sys import cgitb @@ -63,8 +63,5 @@ class TestCgitb(unittest.TestCase): self.assertNotIn('</p>', out) -def test_main(): - run_unittest(TestCgitb) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_charmapcodec.py b/Lib/test/test_charmapcodec.py index 6226587..4064aef 100644 --- a/Lib/test/test_charmapcodec.py +++ b/Lib/test/test_charmapcodec.py @@ -49,8 +49,5 @@ class CharmapCodecTest(unittest.TestCase): def test_maptoundefined(self): self.assertRaises(UnicodeError, str, b'abc\001', codecname) -def test_main(): - test.support.run_unittest(CharmapCodecTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_class.py b/Lib/test/test_class.py index e3883d6..4d554a3 100644 --- a/Lib/test/test_class.py +++ b/Lib/test/test_class.py @@ -2,7 +2,6 @@ import unittest -from test import support testmeths = [ @@ -13,6 +12,8 @@ testmeths = [ "rsub", "mul", "rmul", + "matmul", + "rmatmul", "truediv", "rtruediv", "floordiv", @@ -177,6 +178,14 @@ class ClassTests(unittest.TestCase): self.assertCallStack([("__rmul__", (testme, 1))]) callLst[:] = [] + testme @ 1 + self.assertCallStack([("__matmul__", (testme, 1))]) + + callLst[:] = [] + 1 @ testme + self.assertCallStack([("__rmatmul__", (testme, 1))]) + + callLst[:] = [] testme / 1 self.assertCallStack([("__truediv__", (testme, 1))]) @@ -491,10 +500,10 @@ class ClassTests(unittest.TestCase): try: a() # This should not segfault - except RuntimeError: + except RecursionError: pass else: - self.fail("Failed to raise RuntimeError") + self.fail("Failed to raise RecursionError") def testForExceptionsRaisedInInstanceGetattr2(self): # Tests for exceptions raised in instance_getattr2(). @@ -559,8 +568,5 @@ class ClassTests(unittest.TestCase): a = A(hash(A.f)^(-1)) hash(a.f) -def test_main(): - support.run_unittest(ClassTests) - -if __name__=='__main__': - test_main() +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_cmath.py b/Lib/test/test_cmath.py index 68bf16e..1f884e5 100644 --- a/Lib/test/test_cmath.py +++ b/Lib/test/test_cmath.py @@ -1,5 +1,6 @@ -from test.support import run_unittest, requires_IEEE_754, cpython_only +from test.support import requires_IEEE_754, cpython_only from test.test_math import parse_testfile, test_file +import test.test_math as test_math import unittest import cmath, math from cmath import phase, polar, rect, pi @@ -560,8 +561,46 @@ class CMathTests(unittest.TestCase): self.assertComplexIdentical(cmath.atanh(z), z) -def test_main(): - run_unittest(CMathTests) +class IsCloseTests(test_math.IsCloseTests): + isclose = cmath.isclose + + def test_reject_complex_tolerances(self): + with self.assertRaises(TypeError): + self.isclose(1j, 1j, rel_tol=1j) + + with self.assertRaises(TypeError): + self.isclose(1j, 1j, abs_tol=1j) + + with self.assertRaises(TypeError): + self.isclose(1j, 1j, rel_tol=1j, abs_tol=1j) + + def test_complex_values(self): + # test complex values that are close to within 12 decimal places + complex_examples = [(1.0+1.0j, 1.000000000001+1.0j), + (1.0+1.0j, 1.0+1.000000000001j), + (-1.0+1.0j, -1.000000000001+1.0j), + (1.0-1.0j, 1.0-0.999999999999j), + ] + + self.assertAllClose(complex_examples, rel_tol=1e-12) + self.assertAllNotClose(complex_examples, rel_tol=1e-13) + + def test_complex_near_zero(self): + # test values near zero that are near to within three decimal places + near_zero_examples = [(0.001j, 0), + (0.001, 0), + (0.001+0.001j, 0), + (-0.001+0.001j, 0), + (0.001-0.001j, 0), + (-0.001-0.001j, 0), + ] + + self.assertAllClose(near_zero_examples, abs_tol=1.5e-03) + self.assertAllNotClose(near_zero_examples, abs_tol=0.5e-03) + + self.assertIsClose(0.001-0.001j, 0.001+0.001j, abs_tol=2e-03) + self.assertIsNotClose(0.001-0.001j, 0.001+0.001j, abs_tol=1e-03) + if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py index cb9bbdd..0feb63f 100644 --- a/Lib/test/test_cmd_line.py +++ b/Lib/test/test_cmd_line.py @@ -8,8 +8,8 @@ import shutil import sys import subprocess import tempfile -from test import script_helper -from test.script_helper import (spawn_python, kill_python, assert_python_ok, +from test.support import script_helper +from test.support.script_helper import (spawn_python, kill_python, assert_python_ok, assert_python_failure) @@ -59,7 +59,7 @@ class CmdLineTest(unittest.TestCase): def test_xoptions(self): def get_xoptions(*args): - # use subprocess module directly because test.script_helper adds + # use subprocess module directly because test.support.script_helper adds # "-X faulthandler" to the command line args = (sys.executable, '-E') + args args += ('-c', 'import sys; print(sys._xoptions)') @@ -344,7 +344,8 @@ class CmdLineTest(unittest.TestCase): # Issue #5319: if stdout.flush() fails at shutdown, an error should # be printed out. code = """if 1: - import os, sys + import os, sys, test.support + test.support.SuppressCrashReport().__enter__() sys.stdout.write('x') os.close(sys.stdout.fileno())""" rc, out, err = assert_python_ok('-c', code) @@ -444,7 +445,7 @@ class CmdLineTest(unittest.TestCase): self.assertEqual(err.splitlines().count(b'Unknown option: -a'), 1) self.assertEqual(b'', out) - @unittest.skipIf(script_helper._interpreter_requires_environment(), + @unittest.skipIf(script_helper.interpreter_requires_environment(), 'Cannot run -I tests when PYTHON env vars are required.') def test_isolatedmode(self): self.verify_valid_flag('-I') diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index 7350164..fda3e62 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -13,10 +13,9 @@ import subprocess import textwrap from test import support -from test.script_helper import ( +from test.support.script_helper import ( make_pkg, make_script, make_zip_pkg, make_zip_script, - assert_python_ok, assert_python_failure, temp_dir, - spawn_python, kill_python) + assert_python_ok, assert_python_failure, spawn_python, kill_python) verbose = support.verbose @@ -223,14 +222,14 @@ class CmdLineTest(unittest.TestCase): self.check_repl_stderr_flush(True) def test_basic_script(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script') self._check_script(script_name, script_name, script_name, script_dir, None, importlib.machinery.SourceFileLoader) def test_script_compiled(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script') py_compile.compile(script_name, doraise=True) os.remove(script_name) @@ -240,14 +239,14 @@ class CmdLineTest(unittest.TestCase): importlib.machinery.SourcelessFileLoader) def test_directory(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__') self._check_script(script_dir, script_name, script_dir, script_dir, '', importlib.machinery.SourceFileLoader) def test_directory_compiled(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__') py_compile.compile(script_name, doraise=True) os.remove(script_name) @@ -257,19 +256,19 @@ class CmdLineTest(unittest.TestCase): importlib.machinery.SourcelessFileLoader) def test_directory_error(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: msg = "can't find '__main__' module in %r" % script_dir self._check_import_error(script_dir, msg) def test_zipfile(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__') zip_name, run_name = make_zip_script(script_dir, 'test_zip', script_name) self._check_script(zip_name, run_name, zip_name, zip_name, '', zipimport.zipimporter) def test_zipfile_compiled(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__') compiled_name = py_compile.compile(script_name, doraise=True) zip_name, run_name = make_zip_script(script_dir, 'test_zip', compiled_name) @@ -277,14 +276,14 @@ class CmdLineTest(unittest.TestCase): zipimport.zipimporter) def test_zipfile_error(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'not_main') zip_name, run_name = make_zip_script(script_dir, 'test_zip', script_name) msg = "can't find '__main__' module in %r" % zip_name self._check_import_error(zip_name, msg) def test_module_in_package(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) script_name = _make_test_script(pkg_dir, 'script') @@ -294,14 +293,14 @@ class CmdLineTest(unittest.TestCase): importlib.machinery.SourceFileLoader) def test_module_in_package_in_zipfile(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script') launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.script', zip_name) self._check_script(launch_name, run_name, run_name, zip_name, 'test_pkg', zipimport.zipimporter) def test_module_in_subpackage_in_zipfile(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script', depth=2) launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.test_pkg.script', zip_name) self._check_script(launch_name, run_name, run_name, @@ -309,7 +308,7 @@ class CmdLineTest(unittest.TestCase): zipimport.zipimporter) def test_package(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) script_name = _make_test_script(pkg_dir, '__main__') @@ -319,7 +318,7 @@ class CmdLineTest(unittest.TestCase): importlib.machinery.SourceFileLoader) def test_package_compiled(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) script_name = _make_test_script(pkg_dir, '__main__') @@ -332,7 +331,7 @@ class CmdLineTest(unittest.TestCase): importlib.machinery.SourcelessFileLoader) def test_package_error(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) msg = ("'test_pkg' is a package and cannot " @@ -341,7 +340,7 @@ class CmdLineTest(unittest.TestCase): self._check_import_error(launch_name, msg) def test_package_recursion(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) main_dir = os.path.join(pkg_dir, '__main__') @@ -355,7 +354,7 @@ class CmdLineTest(unittest.TestCase): def test_issue8202(self): # Make sure package __init__ modules see "-m" in sys.argv0 while # searching for the module to execute - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: with support.change_cwd(path=script_dir): pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir, "import sys; print('init_argv0==%r' % sys.argv[0])") @@ -372,7 +371,7 @@ class CmdLineTest(unittest.TestCase): def test_issue8202_dash_c_file_ignored(self): # Make sure a "-c" file in the current directory # does not alter the value of sys.path[0] - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: with support.change_cwd(path=script_dir): with open("-c", "w") as f: f.write("data") @@ -387,7 +386,7 @@ class CmdLineTest(unittest.TestCase): def test_issue8202_dash_m_file_ignored(self): # Make sure a "-m" file in the current directory # does not alter the value of sys.path[0] - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'other') with support.change_cwd(path=script_dir): with open("-m", "w") as f: @@ -402,7 +401,7 @@ class CmdLineTest(unittest.TestCase): # If a module is invoked with the -m command line flag # and results in an error that the return code to the # shell is '1' - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: with support.change_cwd(path=script_dir): pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) @@ -422,7 +421,7 @@ class CmdLineTest(unittest.TestCase): except: raise NameError from None """) - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script', script) exitcode, stdout, stderr = assert_python_failure(script_name) text = stderr.decode('ascii').split('\n') @@ -466,7 +465,7 @@ class CmdLineTest(unittest.TestCase): if error: sys.exit(error) """) - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script', script) exitcode, stdout, stderr = assert_python_failure(script_name) text = stderr.decode('ascii') diff --git a/Lib/test/test_code_module.py b/Lib/test/test_code_module.py index 7a80a80..3394b39 100644 --- a/Lib/test/test_code_module.py +++ b/Lib/test/test_code_module.py @@ -1,6 +1,7 @@ "Test InteractiveConsole and InteractiveInterpreter from code module" import sys import unittest +from textwrap import dedent from contextlib import ExitStack from unittest import mock from test import support @@ -78,9 +79,40 @@ class TestInteractiveConsole(unittest.TestCase): self.console.interact(banner='') self.assertEqual(len(self.stderr.method_calls), 1) + def test_cause_tb(self): + self.infunc.side_effect = ["raise ValueError('') from AttributeError", + EOFError('Finished')] + self.console.interact() + output = ''.join(''.join(call[1]) for call in self.stderr.method_calls) + expected = dedent(""" + AttributeError + + The above exception was the direct cause of the following exception: + + Traceback (most recent call last): + File "<console>", line 1, in <module> + ValueError + """) + self.assertIn(expected, output) + + def test_context_tb(self): + self.infunc.side_effect = ["try: ham\nexcept: eggs\n", + EOFError('Finished')] + self.console.interact() + output = ''.join(''.join(call[1]) for call in self.stderr.method_calls) + expected = dedent(""" + Traceback (most recent call last): + File "<console>", line 1, in <module> + NameError: name 'ham' is not defined + + During handling of the above exception, another exception occurred: + + Traceback (most recent call last): + File "<console>", line 2, in <module> + NameError: name 'eggs' is not defined + """) + self.assertIn(expected, output) -def test_main(): - support.run_unittest(TestInteractiveConsole) if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py index 1327f11..ee1e28a 100644 --- a/Lib/test/test_codeccallbacks.py +++ b/Lib/test/test_codeccallbacks.py @@ -150,6 +150,22 @@ class CodecCallbackTest(unittest.TestCase): sout = b"a\xac\\u1234\xa4\\u8000\\U0010ffff" self.assertEqual(sin.encode("iso-8859-15", "backslashreplace"), sout) + def test_nameescape(self): + # Does the same as backslashescape, but prefers ``\N{...}`` escape + # sequences. + sin = "a\xac\u1234\u20ac\u8000\U0010ffff" + sout = (b'a\\N{NOT SIGN}\\N{ETHIOPIC SYLLABLE SEE}\\N{EURO SIGN}' + b'\\N{CJK UNIFIED IDEOGRAPH-8000}\\U0010ffff') + self.assertEqual(sin.encode("ascii", "namereplace"), sout) + + sout = (b'a\xac\\N{ETHIOPIC SYLLABLE SEE}\\N{EURO SIGN}' + b'\\N{CJK UNIFIED IDEOGRAPH-8000}\\U0010ffff') + self.assertEqual(sin.encode("latin-1", "namereplace"), sout) + + sout = (b'a\xac\\N{ETHIOPIC SYLLABLE SEE}\xa4' + b'\\N{CJK UNIFIED IDEOGRAPH-8000}\\U0010ffff') + self.assertEqual(sin.encode("iso-8859-15", "namereplace"), sout) + def test_decoding_callbacks(self): # This is a test for a decoding callback handler # that allows the decoding of the invalid sequence @@ -220,6 +236,11 @@ class CodecCallbackTest(unittest.TestCase): "\u0000\ufffd" ) + self.assertEqual( + b"\x00\x00\x00\x00\x00".decode("unicode-internal", "backslashreplace"), + "\u0000\\x00" + ) + codecs.register_error("test.hui", handler_unicodeinternal) self.assertEqual( @@ -287,7 +308,7 @@ class CodecCallbackTest(unittest.TestCase): def test_longstrings(self): # test long strings to check for memory overflow problems errors = [ "strict", "ignore", "replace", "xmlcharrefreplace", - "backslashreplace"] + "backslashreplace", "namereplace"] # register the handlers under different names, # to prevent the codec from recognizing the name for err in errors: @@ -550,17 +571,6 @@ class CodecCallbackTest(unittest.TestCase): codecs.backslashreplace_errors, UnicodeError("ouch") ) - # "backslashreplace" can only be used for encoding - self.assertRaises( - TypeError, - codecs.backslashreplace_errors, - UnicodeDecodeError("ascii", bytearray(b"\xff"), 0, 1, "ouch") - ) - self.assertRaises( - TypeError, - codecs.backslashreplace_errors, - UnicodeTranslateError("\u3042", 0, 1, "ouch") - ) # Use the correct exception tests = [ ("\u3042", "\\u3042"), @@ -585,6 +595,72 @@ class CodecCallbackTest(unittest.TestCase): 1, 1 + len(s), "ouch")), (r, 1 + len(s)) ) + self.assertEqual( + codecs.backslashreplace_errors( + UnicodeTranslateError("a" + s + "b", + 1, 1 + len(s), "ouch")), + (r, 1 + len(s)) + ) + tests = [ + (b"a", "\\x61"), + (b"\n", "\\x0a"), + (b"\x00", "\\x00"), + (b"\xff", "\\xff"), + ] + for b, r in tests: + with self.subTest(bytes=b): + self.assertEqual( + codecs.backslashreplace_errors( + UnicodeDecodeError("ascii", bytearray(b"a" + b + b"b"), + 1, 2, "ouch")), + (r, 2) + ) + + def test_badandgoodnamereplaceexceptions(self): + # "namereplace" complains about a non-exception passed in + self.assertRaises( + TypeError, + codecs.namereplace_errors, + 42 + ) + # "namereplace" complains about the wrong exception types + self.assertRaises( + TypeError, + codecs.namereplace_errors, + UnicodeError("ouch") + ) + # "namereplace" can only be used for encoding + self.assertRaises( + TypeError, + codecs.namereplace_errors, + UnicodeDecodeError("ascii", bytearray(b"\xff"), 0, 1, "ouch") + ) + self.assertRaises( + TypeError, + codecs.namereplace_errors, + UnicodeTranslateError("\u3042", 0, 1, "ouch") + ) + # Use the correct exception + tests = [ + ("\u3042", "\\N{HIRAGANA LETTER A}"), + ("\x00", "\\x00"), + ("\ufbf9", "\\N{ARABIC LIGATURE UIGHUR KIRGHIZ YEH WITH " + "HAMZA ABOVE WITH ALEF MAKSURA ISOLATED FORM}"), + ("\U000e007f", "\\N{CANCEL TAG}"), + ("\U0010ffff", "\\U0010ffff"), + # Lone surrogates + ("\ud800", "\\ud800"), + ("\udfff", "\\udfff"), + ("\ud800\udfff", "\\ud800\\udfff"), + ] + for s, r in tests: + with self.subTest(str=s): + self.assertEqual( + codecs.namereplace_errors( + UnicodeEncodeError("ascii", "a" + s + "b", + 1, 1 + len(s), "ouch")), + (r, 1 + len(s)) + ) def test_badandgoodsurrogateescapeexceptions(self): surrogateescape_errors = codecs.lookup_error('surrogateescape') @@ -663,20 +739,24 @@ class CodecCallbackTest(unittest.TestCase): surrogatepass_errors, UnicodeDecodeError(enc, "a".encode(enc), 0, 1, "ouch") ) + for s in ("\ud800", "\udfff", "\ud800\udfff"): + with self.subTest(str=s): + self.assertRaises( + UnicodeEncodeError, + surrogatepass_errors, + UnicodeEncodeError("ascii", s, 0, len(s), "ouch") + ) tests = [ - ("ascii", "\ud800", b'\xed\xa0\x80', 3), ("utf-8", "\ud800", b'\xed\xa0\x80', 3), ("utf-16le", "\ud800", b'\x00\xd8', 2), ("utf-16be", "\ud800", b'\xd8\x00', 2), ("utf-32le", "\ud800", b'\x00\xd8\x00\x00', 4), ("utf-32be", "\ud800", b'\x00\x00\xd8\x00', 4), - ("ascii", "\udfff", b'\xed\xbf\xbf', 3), ("utf-8", "\udfff", b'\xed\xbf\xbf', 3), ("utf-16le", "\udfff", b'\xff\xdf', 2), ("utf-16be", "\udfff", b'\xdf\xff', 2), ("utf-32le", "\udfff", b'\xff\xdf\x00\x00', 4), ("utf-32be", "\udfff", b'\x00\x00\xdf\xff', 4), - ("ascii", "\ud800\udfff", b'\xed\xa0\x80\xed\xbf\xbf', 3), ("utf-8", "\ud800\udfff", b'\xed\xa0\x80\xed\xbf\xbf', 3), ("utf-16le", "\ud800\udfff", b'\x00\xd8\xff\xdf', 2), ("utf-16be", "\ud800\udfff", b'\xd8\x00\xdf\xff', 2), @@ -694,7 +774,7 @@ class CodecCallbackTest(unittest.TestCase): self.assertEqual( surrogatepass_errors( UnicodeDecodeError(enc, bytearray(b"a" + b[:n] + b"b"), - 1, n, "ouch")), + 1, 1 + n, "ouch")), (s[:1], 1 + n) ) @@ -738,6 +818,10 @@ class CodecCallbackTest(unittest.TestCase): codecs.backslashreplace_errors, codecs.lookup_error("backslashreplace") ) + self.assertEqual( + codecs.namereplace_errors, + codecs.lookup_error("namereplace") + ) def test_unencodablereplacement(self): def unencrepl(exc): @@ -890,7 +974,8 @@ class CodecCallbackTest(unittest.TestCase): class D(dict): def __getitem__(self, key): raise ValueError - for err in ("strict", "replace", "xmlcharrefreplace", "backslashreplace", "test.posreturn"): + for err in ("strict", "replace", "xmlcharrefreplace", + "backslashreplace", "namereplace", "test.posreturn"): self.assertRaises(UnicodeError, codecs.charmap_encode, "\xff", err, {0xff: None}) self.assertRaises(ValueError, codecs.charmap_encode, "\xff", err, D()) self.assertRaises(TypeError, codecs.charmap_encode, "\xff", err, {0xff: 300}) @@ -905,7 +990,7 @@ class CodecCallbackTest(unittest.TestCase): def __getitem__(self, key): raise ValueError #self.assertRaises(ValueError, "\xff".translate, D()) - self.assertRaises(TypeError, "\xff".translate, {0xff: sys.maxunicode+1}) + self.assertRaises(ValueError, "\xff".translate, {0xff: sys.maxunicode+1}) self.assertRaises(TypeError, "\xff".translate, {0xff: ()}) def test_bug828737(self): @@ -967,6 +1052,7 @@ class CodecCallbackTest(unittest.TestCase): codecs.ignore_errors, codecs.replace_errors, codecs.backslashreplace_errors, + codecs.namereplace_errors, codecs.xmlcharrefreplace_errors, codecs.lookup_error('surrogateescape'), codecs.lookup_error('surrogatepass'), diff --git a/Lib/test/test_codecencodings_cn.py b/Lib/test/test_codecencodings_cn.py index 60e69eb..d0e3a15 100644 --- a/Lib/test/test_codecencodings_cn.py +++ b/Lib/test/test_codecencodings_cn.py @@ -83,8 +83,5 @@ class Test_HZ(multibytecodec_support.TestBase, unittest.TestCase): (b"ab~{\x79\x79\x41\x44~}cd", "replace", "ab\ufffd\ufffd\u804acd"), ) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_codecencodings_hk.py b/Lib/test/test_codecencodings_hk.py index 25c05b6..bb9be11 100644 --- a/Lib/test/test_codecencodings_hk.py +++ b/Lib/test/test_codecencodings_hk.py @@ -19,8 +19,5 @@ class Test_Big5HKSCS(multibytecodec_support.TestBase, unittest.TestCase): (b"abc\x80\x80\xc1\xc4", "ignore", "abc\u8b10"), ) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_codecencodings_iso2022.py b/Lib/test/test_codecencodings_iso2022.py index 8776864..8a3ca70 100644 --- a/Lib/test/test_codecencodings_iso2022.py +++ b/Lib/test/test_codecencodings_iso2022.py @@ -38,8 +38,5 @@ class Test_ISO2022_KR(multibytecodec_support.TestBase, unittest.TestCase): def test_chunkcoding(self): pass -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_codecencodings_jp.py b/Lib/test/test_codecencodings_jp.py index 4091948..44b63a0 100644 --- a/Lib/test/test_codecencodings_jp.py +++ b/Lib/test/test_codecencodings_jp.py @@ -123,8 +123,5 @@ class Test_SJISX0213(multibytecodec_support.TestBase, unittest.TestCase): b"\x85Gℜ\x85Q = ⟨ሴ⟩" ) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_codecencodings_kr.py b/Lib/test/test_codecencodings_kr.py index cd7696a..b6a74fb 100644 --- a/Lib/test/test_codecencodings_kr.py +++ b/Lib/test/test_codecencodings_kr.py @@ -66,8 +66,5 @@ class Test_JOHAB(multibytecodec_support.TestBase, unittest.TestCase): (b"\x8CBxy", "replace", "\uFFFDBxy"), ) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_codecencodings_tw.py b/Lib/test/test_codecencodings_tw.py index ea6e1c1..9174296 100644 --- a/Lib/test/test_codecencodings_tw.py +++ b/Lib/test/test_codecencodings_tw.py @@ -19,8 +19,5 @@ class Test_Big5(multibytecodec_support.TestBase, unittest.TestCase): (b"abc\x80\x80\xc1\xc4", "ignore", "abc\u8b10"), ) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index a1079a1..cc1f11a 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -349,6 +349,8 @@ class ReadTest(MixInCheckStateHandling): self.assertRaises(UnicodeEncodeError, "\ud800".encode, self.encoding) self.assertEqual("[\uDC80]".encode(self.encoding, "backslashreplace"), "[\\udc80]".encode(self.encoding)) + self.assertEqual("[\uDC80]".encode(self.encoding, "namereplace"), + "[\\udc80]".encode(self.encoding)) self.assertEqual("[\uDC80]".encode(self.encoding, "xmlcharrefreplace"), "[�]".encode(self.encoding)) self.assertEqual("[\uDC80]".encode(self.encoding, "ignore"), @@ -376,6 +378,10 @@ class ReadTest(MixInCheckStateHandling): before + after) self.assertEqual(test_sequence.decode(self.encoding, "replace"), before + self.ill_formed_sequence_replace + after) + backslashreplace = ''.join('\\x%02x' % b + for b in self.ill_formed_sequence) + self.assertEqual(test_sequence.decode(self.encoding, "backslashreplace"), + before + backslashreplace + after) class UTF32Test(ReadTest, unittest.TestCase): encoding = "utf-32" @@ -808,6 +814,7 @@ class CP65001Test(ReadTest, unittest.TestCase): ('\udc80', 'ignore', b''), ('\udc80', 'replace', b'?'), ('\udc80', 'backslashreplace', b'\\udc80'), + ('\udc80', 'namereplace', b'\\udc80'), ('\udc80', 'surrogatepass', b'\xed\xb2\x80'), )) else: @@ -869,6 +876,8 @@ class CP65001Test(ReadTest, unittest.TestCase): self.assertRaises(UnicodeDecodeError, b"\xed\xa0\x80".decode, "cp65001") self.assertEqual("[\uDC80]".encode("cp65001", "backslashreplace"), b'[\\udc80]') + self.assertEqual("[\uDC80]".encode("cp65001", "namereplace"), + b'[\\udc80]') self.assertEqual("[\uDC80]".encode("cp65001", "xmlcharrefreplace"), b'[�]') self.assertEqual("[\uDC80]".encode("cp65001", "surrogateescape"), @@ -890,10 +899,6 @@ class CP65001Test(ReadTest, unittest.TestCase): "\U00010fff\uD800") self.assertTrue(codecs.lookup_error("surrogatepass")) - def test_readline(self): - self.skipTest("issue #20571: code page 65001 codec does not " - "support partial decoder yet") - class UTF7Test(ReadTest, unittest.TestCase): encoding = "utf-7" @@ -1139,6 +1144,7 @@ class UTF8SigTest(UTF8Test, unittest.TestCase): class EscapeDecodeTest(unittest.TestCase): def test_empty(self): self.assertEqual(codecs.escape_decode(b""), (b"", 0)) + self.assertEqual(codecs.escape_decode(bytearray()), (b"", 0)) def test_raw(self): decode = codecs.escape_decode @@ -1357,14 +1363,19 @@ class UnicodeInternalTest(unittest.TestCase): "unicode_internal") if sys.byteorder == "little": invalid = b"\x00\x00\x11\x00" + invalid_backslashreplace = r"\x00\x00\x11\x00" else: invalid = b"\x00\x11\x00\x00" + invalid_backslashreplace = r"\x00\x11\x00\x00" with support.check_warnings(): self.assertRaises(UnicodeDecodeError, invalid.decode, "unicode_internal") with support.check_warnings(): self.assertEqual(invalid.decode("unicode_internal", "replace"), '\ufffd') + with support.check_warnings(): + self.assertEqual(invalid.decode("unicode_internal", "backslashreplace"), + invalid_backslashreplace) @unittest.skipUnless(SIZEOF_WCHAR_T == 4, 'specific to 32-bit wchar_t') def test_decode_error_attributes(self): @@ -1670,6 +1681,12 @@ class CodecsModuleTest(unittest.TestCase): self.assertEqual(codecs.decode(b'abc'), 'abc') self.assertRaises(UnicodeDecodeError, codecs.decode, b'\xff', 'ascii') + # test keywords + self.assertEqual(codecs.decode(obj=b'\xe4\xf6\xfc', encoding='latin-1'), + '\xe4\xf6\xfc') + self.assertEqual(codecs.decode(b'[\xff]', 'ascii', errors='ignore'), + '[]') + def test_encode(self): self.assertEqual(codecs.encode('\xe4\xf6\xfc', 'latin-1'), b'\xe4\xf6\xfc') @@ -1678,6 +1695,12 @@ class CodecsModuleTest(unittest.TestCase): self.assertEqual(codecs.encode('abc'), b'abc') self.assertRaises(UnicodeEncodeError, codecs.encode, '\xffff', 'ascii') + # test keywords + self.assertEqual(codecs.encode(obj='\xe4\xf6\xfc', encoding='latin-1'), + b'\xe4\xf6\xfc') + self.assertEqual(codecs.encode('[\xff]', 'ascii', errors='ignore'), + b'[]') + def test_register(self): self.assertRaises(TypeError, codecs.register) self.assertRaises(TypeError, codecs.register, 42) @@ -1726,6 +1749,7 @@ class CodecsModuleTest(unittest.TestCase): "register_error", "lookup_error", "strict_errors", "replace_errors", "ignore_errors", "xmlcharrefreplace_errors", "backslashreplace_errors", + "namereplace_errors", "open", "EncodedFile", "iterencode", "iterdecode", "BOM", "BOM_BE", "BOM_LE", @@ -1856,7 +1880,9 @@ all_unicode_encodings = [ "iso8859_9", "johab", "koi8_r", + "koi8_t", "koi8_u", + "kz1048", "latin_1", "mac_cyrillic", "mac_greek", @@ -2087,6 +2113,16 @@ class CharmapTest(unittest.TestCase): ) self.assertEqual( + codecs.charmap_decode(b"\x00\x01\x02", "backslashreplace", "ab"), + ("ab\\x02", 3) + ) + + self.assertEqual( + codecs.charmap_decode(b"\x00\x01\x02", "backslashreplace", "ab\ufffe"), + ("ab\\x02", 3) + ) + + self.assertEqual( codecs.charmap_decode(b"\x00\x01\x02", "ignore", "ab"), ("ab", 3) ) @@ -2163,6 +2199,25 @@ class CharmapTest(unittest.TestCase): ) self.assertEqual( + codecs.charmap_decode(b"\x00\x01\x02", "backslashreplace", + {0: 'a', 1: 'b'}), + ("ab\\x02", 3) + ) + + self.assertEqual( + codecs.charmap_decode(b"\x00\x01\x02", "backslashreplace", + {0: 'a', 1: 'b', 2: None}), + ("ab\\x02", 3) + ) + + # Issue #14850 + self.assertEqual( + codecs.charmap_decode(b"\x00\x01\x02", "backslashreplace", + {0: 'a', 1: 'b', 2: '\ufffe'}), + ("ab\\x02", 3) + ) + + self.assertEqual( codecs.charmap_decode(b"\x00\x01\x02", "ignore", {0: 'a', 1: 'b'}), ("ab", 3) @@ -2239,6 +2294,18 @@ class CharmapTest(unittest.TestCase): ) self.assertEqual( + codecs.charmap_decode(b"\x00\x01\x02", "backslashreplace", + {0: a, 1: b}), + ("ab\\x02", 3) + ) + + self.assertEqual( + codecs.charmap_decode(b"\x00\x01\x02", "backslashreplace", + {0: a, 1: b, 2: 0xFFFE}), + ("ab\\x02", 3) + ) + + self.assertEqual( codecs.charmap_decode(b"\x00\x01\x02", "ignore", {0: a, 1: b}), ("ab", 3) @@ -2297,9 +2364,13 @@ class TypesTest(unittest.TestCase): self.assertRaises(UnicodeDecodeError, codecs.unicode_escape_decode, br"\U00110000") self.assertEqual(codecs.unicode_escape_decode(r"\U00110000", "replace"), ("\ufffd", 10)) + self.assertEqual(codecs.unicode_escape_decode(r"\U00110000", "backslashreplace"), + (r"\x5c\x55\x30\x30\x31\x31\x30\x30\x30\x30", 10)) self.assertRaises(UnicodeDecodeError, codecs.raw_unicode_escape_decode, br"\U00110000") self.assertEqual(codecs.raw_unicode_escape_decode(r"\U00110000", "replace"), ("\ufffd", 10)) + self.assertEqual(codecs.raw_unicode_escape_decode(r"\U00110000", "backslashreplace"), + (r"\x5c\x55\x30\x30\x31\x31\x30\x30\x30\x30", 10)) class UnicodeEscapeTest(unittest.TestCase): @@ -2884,15 +2955,15 @@ class CodePageTest(unittest.TestCase): self.assertRaisesRegex(UnicodeEncodeError, 'cp932', codecs.code_page_encode, 932, '\xff') self.assertRaisesRegex(UnicodeDecodeError, 'cp932', - codecs.code_page_decode, 932, b'\x81\x00') + codecs.code_page_decode, 932, b'\x81\x00', 'strict', True) self.assertRaisesRegex(UnicodeDecodeError, 'CP_UTF8', - codecs.code_page_decode, self.CP_UTF8, b'\xff') + codecs.code_page_decode, self.CP_UTF8, b'\xff', 'strict', True) def check_decode(self, cp, tests): for raw, errors, expected in tests: if expected is not None: try: - decoded = codecs.code_page_decode(cp, raw, errors) + decoded = codecs.code_page_decode(cp, raw, errors, True) except UnicodeDecodeError as err: self.fail('Unable to decode %a from "cp%s" with ' 'errors=%r: %s' % (raw, cp, errors, err)) @@ -2904,7 +2975,7 @@ class CodePageTest(unittest.TestCase): self.assertLessEqual(decoded[1], len(raw)) else: self.assertRaises(UnicodeDecodeError, - codecs.code_page_decode, cp, raw, errors) + codecs.code_page_decode, cp, raw, errors, True) def check_encode(self, cp, tests): for text, errors, expected in tests: @@ -2932,7 +3003,12 @@ class CodePageTest(unittest.TestCase): ('[\xff]', 'replace', b'[y]'), ('[\u20ac]', 'replace', b'[?]'), ('[\xff]', 'backslashreplace', b'[\\xff]'), + ('[\xff]', 'namereplace', + b'[\\N{LATIN SMALL LETTER Y WITH DIAERESIS}]'), ('[\xff]', 'xmlcharrefreplace', b'[ÿ]'), + ('\udcff', 'strict', None), + ('[\udcff]', 'surrogateescape', b'[\xff]'), + ('[\udcff]', 'surrogatepass', None), )) self.check_decode(932, ( (b'abc', 'strict', 'abc'), @@ -2941,10 +3017,13 @@ class CodePageTest(unittest.TestCase): (b'[\xff]', 'strict', None), (b'[\xff]', 'ignore', '[]'), (b'[\xff]', 'replace', '[\ufffd]'), + (b'[\xff]', 'backslashreplace', '[\\xff]'), (b'[\xff]', 'surrogateescape', '[\udcff]'), + (b'[\xff]', 'surrogatepass', None), (b'\x81\x00abc', 'strict', None), (b'\x81\x00abc', 'ignore', '\x00abc'), (b'\x81\x00abc', 'replace', '\ufffd\x00abc'), + (b'\x81\x00abc', 'backslashreplace', '\\x81\x00abc'), )) def test_cp1252(self): @@ -2952,9 +3031,12 @@ class CodePageTest(unittest.TestCase): ('abc', 'strict', b'abc'), ('\xe9\u20ac', 'strict', b'\xe9\x80'), ('\xff', 'strict', b'\xff'), + # test error handlers ('\u0141', 'strict', None), ('\u0141', 'ignore', b''), ('\u0141', 'replace', b'L'), + ('\udc98', 'surrogateescape', b'\x98'), + ('\udc98', 'surrogatepass', None), )) self.check_decode(1252, ( (b'abc', 'strict', 'abc'), diff --git a/Lib/test/test_codeop.py b/Lib/test/test_codeop.py index b65423b..509bf5d 100644 --- a/Lib/test/test_codeop.py +++ b/Lib/test/test_codeop.py @@ -3,7 +3,7 @@ Nick Mathewson """ import unittest -from test.support import run_unittest, is_jython +from test.support import is_jython from codeop import compile_command, PyCF_DONT_IMPLY_DEDENT import io @@ -296,9 +296,5 @@ class CodeopTests(unittest.TestCase): compile("a = 1\n", "def", 'single').co_filename) -def test_main(): - run_unittest(CodeopTests) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 66db90f..4124f91 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -1,7 +1,8 @@ """Unit tests for collections.py.""" import unittest, doctest, operator -from test.support import TESTFN, forget, unlink +from test.support import TESTFN, forget, unlink, import_fresh_module +import contextlib import inspect from test import support from collections import namedtuple, Counter, OrderedDict, _count_elements @@ -11,9 +12,12 @@ from random import randrange, shuffle import keyword import re import sys -from collections import UserDict +import types +from collections import UserDict, UserString, UserList from collections import ChainMap -from collections.abc import Hashable, Iterable, Iterator +from collections import deque +from collections.abc import Awaitable, Coroutine, AsyncIterator, AsyncIterable +from collections.abc import Hashable, Iterable, Iterator, Generator from collections.abc import Sized, Container, Callable from collections.abc import Set, MutableSet from collections.abc import Mapping, MutableMapping, KeysView, ItemsView @@ -21,6 +25,26 @@ from collections.abc import Sequence, MutableSequence from collections.abc import ByteString +class TestUserObjects(unittest.TestCase): + def _superset_test(self, a, b): + self.assertGreaterEqual( + set(dir(a)), + set(dir(b)), + '{a} should have all the methods of {b}'.format( + a=a.__name__, + b=b.__name__, + ), + ) + def test_str_protocol(self): + self._superset_test(UserString, str) + + def test_list_protocol(self): + self._superset_test(UserList, list) + + def test_dict_protocol(self): + self._superset_test(UserDict, dict) + + ################################################################################ ### ChainMap (helper class for configparser and the string module) ################################################################################ @@ -196,6 +220,14 @@ class TestNamedTuple(unittest.TestCase): Point = namedtuple('Point', 'x y') self.assertEqual(Point.__doc__, 'Point(x, y)') + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_doc_writable(self): + Point = namedtuple('Point', 'x y') + self.assertEqual(Point.x.__doc__, 'Alias for field number 0') + Point.x.__doc__ = 'docstring for Point.x' + self.assertEqual(Point.x.__doc__, 'docstring for Point.x') + def test_name_fixer(self): for spec, renamed in [ [('efg', 'g%hi'), ('efg', '_1')], # field with non-alpha char @@ -455,6 +487,121 @@ class ABCTestCase(unittest.TestCase): class TestOneTrickPonyABCs(ABCTestCase): + def test_Awaitable(self): + def gen(): + yield + + @types.coroutine + def coro(): + yield + + async def new_coro(): + pass + + class Bar: + def __await__(self): + yield + + class MinimalCoro(Coroutine): + def send(self, value): + return value + def throw(self, typ, val=None, tb=None): + super().throw(typ, val, tb) + def __await__(self): + yield + + non_samples = [None, int(), gen(), object()] + for x in non_samples: + self.assertNotIsInstance(x, Awaitable) + self.assertFalse(issubclass(type(x), Awaitable), repr(type(x))) + + samples = [Bar(), MinimalCoro()] + for x in samples: + self.assertIsInstance(x, Awaitable) + self.assertTrue(issubclass(type(x), Awaitable)) + + c = coro() + # Iterable coroutines (generators with CO_ITERABLE_COROUTINE + # flag don't have '__await__' method, hence can't be instances + # of Awaitable. Use inspect.isawaitable to detect them. + self.assertNotIsInstance(c, Awaitable) + + c = new_coro() + self.assertIsInstance(c, Awaitable) + c.close() # awoid RuntimeWarning that coro() was not awaited + + class CoroLike: pass + Coroutine.register(CoroLike) + self.assertTrue(isinstance(CoroLike(), Awaitable)) + self.assertTrue(issubclass(CoroLike, Awaitable)) + CoroLike = None + support.gc_collect() # Kill CoroLike to clean-up ABCMeta cache + + def test_Coroutine(self): + def gen(): + yield + + @types.coroutine + def coro(): + yield + + async def new_coro(): + pass + + class Bar: + def __await__(self): + yield + + class MinimalCoro(Coroutine): + def send(self, value): + return value + def throw(self, typ, val=None, tb=None): + super().throw(typ, val, tb) + def __await__(self): + yield + + non_samples = [None, int(), gen(), object(), Bar()] + for x in non_samples: + self.assertNotIsInstance(x, Coroutine) + self.assertFalse(issubclass(type(x), Coroutine), repr(type(x))) + + samples = [MinimalCoro()] + for x in samples: + self.assertIsInstance(x, Awaitable) + self.assertTrue(issubclass(type(x), Awaitable)) + + c = coro() + # Iterable coroutines (generators with CO_ITERABLE_COROUTINE + # flag don't have '__await__' method, hence can't be instances + # of Coroutine. Use inspect.isawaitable to detect them. + self.assertNotIsInstance(c, Coroutine) + + c = new_coro() + self.assertIsInstance(c, Coroutine) + c.close() # awoid RuntimeWarning that coro() was not awaited + + class CoroLike: + def send(self, value): + pass + def throw(self, typ, val=None, tb=None): + pass + def close(self): + pass + def __await__(self): + pass + self.assertTrue(isinstance(CoroLike(), Coroutine)) + self.assertTrue(issubclass(CoroLike, Coroutine)) + + class CoroLike: + def send(self, value): + pass + def close(self): + pass + def __await__(self): + pass + self.assertFalse(isinstance(CoroLike(), Coroutine)) + self.assertFalse(issubclass(CoroLike, Coroutine)) + def test_Hashable(self): # Check some non-hashables non_samples = [bytearray(), list(), set(), dict()] @@ -481,6 +628,40 @@ class TestOneTrickPonyABCs(ABCTestCase): self.validate_abstract_methods(Hashable, '__hash__') self.validate_isinstance(Hashable, '__hash__') + def test_AsyncIterable(self): + class AI: + async def __aiter__(self): + return self + self.assertTrue(isinstance(AI(), AsyncIterable)) + self.assertTrue(issubclass(AI, AsyncIterable)) + # Check some non-iterables + non_samples = [None, object, []] + for x in non_samples: + self.assertNotIsInstance(x, AsyncIterable) + self.assertFalse(issubclass(type(x), AsyncIterable), repr(type(x))) + self.validate_abstract_methods(AsyncIterable, '__aiter__') + self.validate_isinstance(AsyncIterable, '__aiter__') + + def test_AsyncIterator(self): + class AI: + async def __aiter__(self): + return self + async def __anext__(self): + raise StopAsyncIteration + self.assertTrue(isinstance(AI(), AsyncIterator)) + self.assertTrue(issubclass(AI, AsyncIterator)) + non_samples = [None, object, []] + # Check some non-iterables + for x in non_samples: + self.assertNotIsInstance(x, AsyncIterator) + self.assertFalse(issubclass(type(x), AsyncIterator), repr(type(x))) + # Similarly to regular iterators (see issue 10565) + class AnextOnly: + async def __anext__(self): + raise StopAsyncIteration + self.assertNotIsInstance(AnextOnly(), AsyncIterator) + self.validate_abstract_methods(AsyncIterator, '__anext__', '__aiter__') + def test_Iterable(self): # Check some non-iterables non_samples = [None, 42, 3.14, 1j] @@ -528,9 +709,80 @@ class TestOneTrickPonyABCs(ABCTestCase): class NextOnly: def __next__(self): yield 1 - raise StopIteration + return self.assertNotIsInstance(NextOnly(), Iterator) + def test_Generator(self): + class NonGen1: + def __iter__(self): return self + def __next__(self): return None + def close(self): pass + def throw(self, typ, val=None, tb=None): pass + + class NonGen2: + def __iter__(self): return self + def __next__(self): return None + def close(self): pass + def send(self, value): return value + + class NonGen3: + def close(self): pass + def send(self, value): return value + def throw(self, typ, val=None, tb=None): pass + + non_samples = [ + None, 42, 3.14, 1j, b"", "", (), [], {}, set(), + iter(()), iter([]), NonGen1(), NonGen2(), NonGen3()] + for x in non_samples: + self.assertNotIsInstance(x, Generator) + self.assertFalse(issubclass(type(x), Generator), repr(type(x))) + + class Gen: + def __iter__(self): return self + def __next__(self): return None + def close(self): pass + def send(self, value): return value + def throw(self, typ, val=None, tb=None): pass + + class MinimalGen(Generator): + def send(self, value): + return value + def throw(self, typ, val=None, tb=None): + super().throw(typ, val, tb) + + def gen(): + yield 1 + + samples = [gen(), (lambda: (yield))(), Gen(), MinimalGen()] + for x in samples: + self.assertIsInstance(x, Iterator) + self.assertIsInstance(x, Generator) + self.assertTrue(issubclass(type(x), Generator), repr(type(x))) + self.validate_abstract_methods(Generator, 'send', 'throw') + + # mixin tests + mgen = MinimalGen() + self.assertIs(mgen, iter(mgen)) + self.assertIs(mgen.send(None), next(mgen)) + self.assertEqual(2, mgen.send(2)) + self.assertIsNone(mgen.close()) + self.assertRaises(ValueError, mgen.throw, ValueError) + self.assertRaisesRegex(ValueError, "^huhu$", + mgen.throw, ValueError, ValueError("huhu")) + self.assertRaises(StopIteration, mgen.throw, StopIteration()) + + class FailOnClose(Generator): + def send(self, value): return value + def throw(self, *args): raise ValueError + + self.assertRaises(ValueError, FailOnClose().close) + + class IgnoreGeneratorExit(Generator): + def send(self, value): return value + def throw(self, *args): pass + + self.assertRaises(RuntimeError, IgnoreGeneratorExit().close) + def test_Sized(self): non_samples = [None, 42, 3.14, 1j, (lambda: (yield))(), @@ -657,6 +909,59 @@ class TestCollectionABCs(ABCTestCase): a, b = OneTwoThreeSet(), OneTwoThreeSet() self.assertTrue(hash(a) == hash(b)) + def test_isdisjoint_Set(self): + class MySet(Set): + def __init__(self, itr): + self.contents = itr + def __contains__(self, x): + return x in self.contents + def __iter__(self): + return iter(self.contents) + def __len__(self): + return len([x for x in self.contents]) + s1 = MySet((1, 2, 3)) + s2 = MySet((4, 5, 6)) + s3 = MySet((1, 5, 6)) + self.assertTrue(s1.isdisjoint(s2)) + self.assertFalse(s1.isdisjoint(s3)) + + def test_equality_Set(self): + class MySet(Set): + def __init__(self, itr): + self.contents = itr + def __contains__(self, x): + return x in self.contents + def __iter__(self): + return iter(self.contents) + def __len__(self): + return len([x for x in self.contents]) + s1 = MySet((1,)) + s2 = MySet((1, 2)) + s3 = MySet((3, 4)) + s4 = MySet((3, 4)) + self.assertTrue(s2 > s1) + self.assertTrue(s1 < s2) + self.assertFalse(s2 <= s1) + self.assertFalse(s2 <= s3) + self.assertFalse(s1 >= s2) + self.assertEqual(s3, s4) + self.assertNotEqual(s2, s3) + + def test_arithmetic_Set(self): + class MySet(Set): + def __init__(self, itr): + self.contents = itr + def __contains__(self, x): + return x in self.contents + def __iter__(self): + return iter(self.contents) + def __len__(self): + return len([x for x in self.contents]) + s1 = MySet((1, 2, 3)) + s2 = MySet((3, 4, 5)) + s3 = s1 & s2 + self.assertEqual(s3, MySet((3,))) + def test_MutableSet(self): self.assertIsInstance(set(), MutableSet) self.assertTrue(issubclass(set, MutableSet)) @@ -957,6 +1262,41 @@ class TestCollectionABCs(ABCTestCase): self.validate_abstract_methods(Sequence, '__contains__', '__iter__', '__len__', '__getitem__') + def test_Sequence_mixins(self): + class SequenceSubclass(Sequence): + def __init__(self, seq=()): + self.seq = seq + + def __getitem__(self, index): + return self.seq[index] + + def __len__(self): + return len(self.seq) + + # Compare Sequence.index() behavior to (list|str).index() behavior + def assert_index_same(seq1, seq2, index_args): + try: + expected = seq1.index(*index_args) + except ValueError: + with self.assertRaises(ValueError): + seq2.index(*index_args) + else: + actual = seq2.index(*index_args) + self.assertEqual( + actual, expected, '%r.index%s' % (seq1, index_args)) + + for ty in list, str: + nativeseq = ty('abracadabra') + indexes = [-10000, -9999] + list(range(-3, len(nativeseq) + 3)) + seqseq = SequenceSubclass(nativeseq) + for letter in set(nativeseq) | {'z'}: + assert_index_same(nativeseq, seqseq, (letter,)) + for start in range(-3, len(nativeseq) + 3): + assert_index_same(nativeseq, seqseq, (letter, start)) + for stop in range(-3, len(nativeseq) + 3): + assert_index_same( + nativeseq, seqseq, (letter, start, stop)) + def test_ByteString(self): for sample in [bytes, bytearray]: self.assertIsInstance(sample(), ByteString) @@ -971,7 +1311,7 @@ class TestCollectionABCs(ABCTestCase): for sample in [tuple, str, bytes]: self.assertNotIsInstance(sample(), MutableSequence) self.assertFalse(issubclass(sample, MutableSequence)) - for sample in [list, bytearray]: + for sample in [list, bytearray, deque]: self.assertIsInstance(sample(), MutableSequence) self.assertTrue(issubclass(sample, MutableSequence)) self.assertFalse(issubclass(str, MutableSequence)) @@ -1284,9 +1624,24 @@ class TestCounter(unittest.TestCase): ### OrderedDict ################################################################################ -class TestOrderedDict(unittest.TestCase): +py_coll = import_fresh_module('collections', blocked=['_collections']) +c_coll = import_fresh_module('collections', fresh=['_collections']) + + +@contextlib.contextmanager +def replaced_module(name, replacement): + original_module = sys.modules[name] + sys.modules[name] = replacement + try: + yield + finally: + sys.modules[name] = original_module + + +class OrderedDictTests: def test_init(self): + OrderedDict = self.module.OrderedDict with self.assertRaises(TypeError): OrderedDict([('a', 1), ('b', 2)], None) # too many args pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] @@ -1310,6 +1665,7 @@ class TestOrderedDict(unittest.TestCase): [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) def test_update(self): + OrderedDict = self.module.OrderedDict with self.assertRaises(TypeError): OrderedDict().update([('a', 1), ('b', 2)], None) # too many args pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] @@ -1350,11 +1706,26 @@ class TestOrderedDict(unittest.TestCase): self.assertRaises(TypeError, OrderedDict().update, (), ()) self.assertRaises(TypeError, OrderedDict.update) + self.assertRaises(TypeError, OrderedDict().update, 42) + self.assertRaises(TypeError, OrderedDict().update, (), ()) + self.assertRaises(TypeError, OrderedDict.update) + + def test_fromkeys(self): + OrderedDict = self.module.OrderedDict + od = OrderedDict.fromkeys('abc') + self.assertEqual(list(od.items()), [(c, None) for c in 'abc']) + od = OrderedDict.fromkeys('abc', value=None) + self.assertEqual(list(od.items()), [(c, None) for c in 'abc']) + od = OrderedDict.fromkeys('abc', value=0) + self.assertEqual(list(od.items()), [(c, 0) for c in 'abc']) + def test_abc(self): + OrderedDict = self.module.OrderedDict self.assertIsInstance(OrderedDict(), MutableMapping) self.assertTrue(issubclass(OrderedDict, MutableMapping)) def test_clear(self): + OrderedDict = self.module.OrderedDict pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) od = OrderedDict(pairs) @@ -1363,6 +1734,7 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(len(od), 0) def test_delitem(self): + OrderedDict = self.module.OrderedDict pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] od = OrderedDict(pairs) del od['a'] @@ -1372,6 +1744,7 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(list(od.items()), pairs[:2] + pairs[3:]) def test_setitem(self): + OrderedDict = self.module.OrderedDict od = OrderedDict([('d', 1), ('b', 2), ('c', 3), ('a', 4), ('e', 5)]) od['c'] = 10 # existing element od['f'] = 20 # new element @@ -1379,6 +1752,7 @@ class TestOrderedDict(unittest.TestCase): [('d', 1), ('b', 2), ('c', 10), ('a', 4), ('e', 5), ('f', 20)]) def test_iterators(self): + OrderedDict = self.module.OrderedDict pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) od = OrderedDict(pairs) @@ -1388,8 +1762,51 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(list(od.items()), pairs) self.assertEqual(list(reversed(od)), [t[0] for t in reversed(pairs)]) + self.assertEqual(list(reversed(od.keys())), + [t[0] for t in reversed(pairs)]) + self.assertEqual(list(reversed(od.values())), + [t[1] for t in reversed(pairs)]) + self.assertEqual(list(reversed(od.items())), list(reversed(pairs))) + + def test_detect_deletion_during_iteration(self): + OrderedDict = self.module.OrderedDict + od = OrderedDict.fromkeys('abc') + it = iter(od) + key = next(it) + del od[key] + with self.assertRaises(Exception): + # Note, the exact exception raised is not guaranteed + # The only guarantee that the next() will not succeed + next(it) + + def test_sorted_iterators(self): + OrderedDict = self.module.OrderedDict + with self.assertRaises(TypeError): + OrderedDict([('a', 1), ('b', 2)], None) + pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] + od = OrderedDict(pairs) + self.assertEqual(sorted(od), [t[0] for t in pairs]) + self.assertEqual(sorted(od.keys()), [t[0] for t in pairs]) + self.assertEqual(sorted(od.values()), [t[1] for t in pairs]) + self.assertEqual(sorted(od.items()), pairs) + self.assertEqual(sorted(reversed(od)), + sorted([t[0] for t in reversed(pairs)])) + + def test_iterators_empty(self): + OrderedDict = self.module.OrderedDict + od = OrderedDict() + empty = [] + self.assertEqual(list(od), empty) + self.assertEqual(list(od.keys()), empty) + self.assertEqual(list(od.values()), empty) + self.assertEqual(list(od.items()), empty) + self.assertEqual(list(reversed(od)), empty) + self.assertEqual(list(reversed(od.keys())), empty) + self.assertEqual(list(reversed(od.values())), empty) + self.assertEqual(list(reversed(od.items())), empty) def test_popitem(self): + OrderedDict = self.module.OrderedDict pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) od = OrderedDict(pairs) @@ -1399,7 +1816,19 @@ class TestOrderedDict(unittest.TestCase): od.popitem() self.assertEqual(len(od), 0) + def test_popitem_last(self): + OrderedDict = self.module.OrderedDict + pairs = [(i, i) for i in range(30)] + + obj = OrderedDict(pairs) + for i in range(8): + obj.popitem(True) + obj.popitem(True) + obj.popitem(last=True) + self.assertEqual(len(obj), 20) + def test_pop(self): + OrderedDict = self.module.OrderedDict pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) od = OrderedDict(pairs) @@ -1420,10 +1849,12 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(m.pop('b', 5), 5) self.assertEqual(m.pop('a', 6), 1) self.assertEqual(m.pop('a', 6), 6) + self.assertEqual(m.pop('a', default=6), 6) with self.assertRaises(KeyError): m.pop('a') def test_equality(self): + OrderedDict = self.module.OrderedDict pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) od1 = OrderedDict(pairs) @@ -1439,6 +1870,7 @@ class TestOrderedDict(unittest.TestCase): self.assertNotEqual(od1, OrderedDict(pairs[:-1])) def test_copying(self): + OrderedDict = self.module.OrderedDict # Check that ordered dicts are copyable, deepcopyable, picklable, # and have a repr/eval round-trip pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] @@ -1447,12 +1879,17 @@ class TestOrderedDict(unittest.TestCase): msg = "\ncopy: %s\nod: %s" % (dup, od) self.assertIsNot(dup, od, msg) self.assertEqual(dup, od) + self.assertEqual(list(dup.items()), list(od.items())) + self.assertEqual(len(dup), len(od)) + self.assertEqual(type(dup), type(od)) check(od.copy()) check(copy.copy(od)) check(copy.deepcopy(od)) - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - with self.subTest(proto=proto): - check(pickle.loads(pickle.dumps(od, proto))) + # pickle directly pulls the module, so we have to fake it + with replaced_module('collections', self.module): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + check(pickle.loads(pickle.dumps(od, proto))) check(eval(repr(od))) update_test = OrderedDict() update_test.update(od) @@ -1460,6 +1897,7 @@ class TestOrderedDict(unittest.TestCase): check(OrderedDict(od)) def test_yaml_linkage(self): + OrderedDict = self.module.OrderedDict # Verify that __reduce__ is setup in a way that supports PyYAML's dump() feature. # In yaml, lists are native but tuples are not. pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] @@ -1469,6 +1907,7 @@ class TestOrderedDict(unittest.TestCase): self.assertTrue(all(type(pair)==list for pair in od.__reduce__()[1])) def test_reduce_not_too_fat(self): + OrderedDict = self.module.OrderedDict # do not save instance dictionary if not needed pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] od = OrderedDict(pairs) @@ -1477,15 +1916,20 @@ class TestOrderedDict(unittest.TestCase): self.assertIsNotNone(od.__reduce__()[2]) def test_pickle_recursive(self): + OrderedDict = self.module.OrderedDict od = OrderedDict() od[1] = od - for proto in range(-1, pickle.HIGHEST_PROTOCOL + 1): - dup = pickle.loads(pickle.dumps(od, proto)) - self.assertIsNot(dup, od) - self.assertEqual(list(dup.keys()), [1]) - self.assertIs(dup[1], dup) + + # pickle directly pulls the module, so we have to fake it + with replaced_module('collections', self.module): + for proto in range(-1, pickle.HIGHEST_PROTOCOL + 1): + dup = pickle.loads(pickle.dumps(od, proto)) + self.assertIsNot(dup, od) + self.assertEqual(list(dup.keys()), [1]) + self.assertIs(dup[1], dup) def test_repr(self): + OrderedDict = self.module.OrderedDict od = OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]) self.assertEqual(repr(od), "OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)])") @@ -1493,6 +1937,7 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(repr(OrderedDict()), "OrderedDict()") def test_repr_recursive(self): + OrderedDict = self.module.OrderedDict # See issue #9826 od = OrderedDict.fromkeys('abc') od['x'] = od @@ -1500,6 +1945,7 @@ class TestOrderedDict(unittest.TestCase): "OrderedDict([('a', None), ('b', None), ('c', None), ('x', ...)])") def test_setdefault(self): + OrderedDict = self.module.OrderedDict pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] shuffle(pairs) od = OrderedDict(pairs) @@ -1510,6 +1956,7 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(od.setdefault('x', 10), 10) # make sure 'x' is added to the end self.assertEqual(list(od.items())[-1], ('x', 10)) + self.assertEqual(od.setdefault('g', default=9), 9) # make sure setdefault still works when __missing__ is defined class Missing(OrderedDict): @@ -1518,16 +1965,19 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(Missing().setdefault(5, 9), 9) def test_reinsert(self): + OrderedDict = self.module.OrderedDict # Given insert a, insert b, delete a, re-insert a, # verify that a is now later than b. od = OrderedDict() od['a'] = 1 od['b'] = 2 del od['a'] + self.assertEqual(list(od.items()), [('b', 2)]) od['a'] = 1 self.assertEqual(list(od.items()), [('b', 2), ('a', 1)]) def test_move_to_end(self): + OrderedDict = self.module.OrderedDict od = OrderedDict.fromkeys('abcde') self.assertEqual(list(od), list('abcde')) od.move_to_end('c') @@ -1538,16 +1988,22 @@ class TestOrderedDict(unittest.TestCase): self.assertEqual(list(od), list('cabde')) od.move_to_end('e') self.assertEqual(list(od), list('cabde')) + od.move_to_end('b', last=False) + self.assertEqual(list(od), list('bcade')) with self.assertRaises(KeyError): od.move_to_end('x') + with self.assertRaises(KeyError): + od.move_to_end('x', 0) def test_sizeof(self): + OrderedDict = self.module.OrderedDict # Wimpy test: Just verify the reported size is larger than a regular dict d = dict(a=1) od = OrderedDict(**d) self.assertGreater(sys.getsizeof(od), sys.getsizeof(d)) def test_override_update(self): + OrderedDict = self.module.OrderedDict # Verify that subclasses can override update() without breaking __init__() class MyOD(OrderedDict): def update(self, *args, **kwds): @@ -1555,18 +2011,171 @@ class TestOrderedDict(unittest.TestCase): items = [('a', 1), ('c', 3), ('b', 2)] self.assertEqual(list(MyOD(items).items()), items) -class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): - type2test = OrderedDict + +class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): + + module = py_coll + + +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CPythonOrderedDictTests(OrderedDictTests, unittest.TestCase): + + module = c_coll + + def test_delitem_hash_collision(self): + OrderedDict = self.module.OrderedDict + + class Key: + def __init__(self, hash): + self._hash = hash + self.value = str(id(self)) + def __hash__(self): + return self._hash + def __eq__(self, other): + try: + return self.value == other.value + except AttributeError: + return False + def __repr__(self): + return self.value + + def blocking_hash(hash): + # See the collision-handling in lookdict (in Objects/dictobject.c). + MINSIZE = 8 + i = (hash & MINSIZE-1) + return (i << 2) + i + hash + 1 + + COLLIDING = 1 + + key = Key(COLLIDING) + colliding = Key(COLLIDING) + blocking = Key(blocking_hash(COLLIDING)) + + od = OrderedDict() + od[key] = ... + od[blocking] = ... + od[colliding] = ... + od['after'] = ... + + del od[blocking] + del od[colliding] + self.assertEqual(list(od.items()), [(key, ...), ('after', ...)]) + + def test_key_change_during_iteration(self): + OrderedDict = self.module.OrderedDict + + od = OrderedDict.fromkeys('abcde') + self.assertEqual(list(od), list('abcde')) + with self.assertRaises(RuntimeError): + for i, k in enumerate(od): + od.move_to_end(k) + self.assertLess(i, 5) + with self.assertRaises(RuntimeError): + for k in od: + od['f'] = None + with self.assertRaises(RuntimeError): + for k in od: + del od['c'] + self.assertEqual(list(od), list('bdeaf')) + + def test_issue24347(self): + OrderedDict = self.module.OrderedDict + + class Key: + def __hash__(self): + return randrange(100000) + + od = OrderedDict() + for i in range(100): + key = Key() + od[key] = i + + # These should not crash. + with self.assertRaises(KeyError): + repr(od) + with self.assertRaises(KeyError): + od.copy() + + def test_issue24348(self): + OrderedDict = self.module.OrderedDict + + class Key: + def __hash__(self): + return 1 + + od = OrderedDict() + od[Key()] = 0 + # This should not crash. + od.popitem() + + def test_issue24667(self): + """ + dict resizes after a certain number of insertion operations, + whether or not there were deletions that freed up slots in the + hash table. During fast node lookup, OrderedDict must correctly + respond to all resizes, even if the current "size" is the same + as the old one. We verify that here by forcing a dict resize + on a sparse odict and then perform an operation that should + trigger an odict resize (e.g. popitem). One key aspect here is + that we will keep the size of the odict the same at each popitem + call. This verifies that we handled the dict resize properly. + """ + OrderedDict = self.module.OrderedDict + + od = OrderedDict() + for c0 in '0123456789ABCDEF': + for c1 in '0123456789ABCDEF': + if len(od) == 4: + # This should not raise a KeyError. + od.popitem(last=False) + key = c0 + c1 + od[key] = key + + +class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + cls.type2test = py_coll.OrderedDict + + def test_popitem(self): + d = self._empty_mapping() + self.assertRaises(KeyError, d.popitem) + + +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CPythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + cls.type2test = c_coll.OrderedDict + + def test_popitem(self): + d = self._empty_mapping() + self.assertRaises(KeyError, d.popitem) + + +class PurePythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + class MyOrderedDict(py_coll.OrderedDict): + pass + cls.type2test = MyOrderedDict def test_popitem(self): d = self._empty_mapping() self.assertRaises(KeyError, d.popitem) -class MyOrderedDict(OrderedDict): - pass -class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): - type2test = MyOrderedDict +@unittest.skipUnless(c_coll, 'requires the C version of the collections module') +class CPythonSubclassMappingTests(mapping_tests.BasicTestMappingProtocol): + + @classmethod + def setUpClass(cls): + class MyOrderedDict(c_coll.OrderedDict): + pass + cls.type2test = MyOrderedDict def test_popitem(self): d = self._empty_mapping() @@ -1583,7 +2192,11 @@ def test_main(verbose=None): NamedTupleDocs = doctest.DocTestSuite(module=collections) test_classes = [TestNamedTuple, NamedTupleDocs, TestOneTrickPonyABCs, TestCollectionABCs, TestCounter, TestChainMap, - TestOrderedDict, GeneralMappingTests, SubclassMappingTests] + PurePythonOrderedDictTests, CPythonOrderedDictTests, + PurePythonGeneralMappingTests, CPythonGeneralMappingTests, + PurePythonSubclassMappingTests, CPythonSubclassMappingTests, + TestUserObjects, + ] support.run_unittest(*test_classes) support.run_doctest(collections, verbose) diff --git a/Lib/test/test_compare.py b/Lib/test/test_compare.py index a663832..471c8da 100644 --- a/Lib/test/test_compare.py +++ b/Lib/test/test_compare.py @@ -1,5 +1,4 @@ import unittest -from test import support class Empty: def __repr__(self): @@ -121,8 +120,5 @@ class ComparisonTest(unittest.TestCase): self.assertEqual(Anything(), y) -def test_main(): - support.run_unittest(ComparisonTest) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py index cff3c9e..db821be 100644 --- a/Lib/test/test_compile.py +++ b/Lib/test/test_compile.py @@ -5,7 +5,8 @@ import sys import _ast import tempfile import types -from test import support, script_helper +from test import support +from test.support import script_helper class TestSpecifics(unittest.TestCase): @@ -427,7 +428,7 @@ if 1: def test_compile_ast(self): fname = __file__ - if fname.lower().endswith(('pyc', 'pyo')): + if fname.lower().endswith('pyc'): fname = fname[:-1] with open(fname, 'r') as f: fcontents = f.read() @@ -460,6 +461,17 @@ if 1: ast.body = [_ast.BoolOp()] self.assertRaises(TypeError, compile, ast, '<ast>', 'exec') + def test_dict_evaluation_order(self): + i = 0 + + def f(): + nonlocal i + i += 1 + return i + + d = {f(): f(), f(): f()} + self.assertEqual(d, {1: 2, 3: 4}) + @support.cpython_only def test_same_filename_used(self): s = """def f(): pass\ndef g(): pass""" @@ -522,7 +534,7 @@ if 1: broken = prefix + repeated * fail_depth details = "Compiling ({!r} + {!r} * {})".format( prefix, repeated, fail_depth) - with self.assertRaises(RuntimeError, msg=details): + with self.assertRaises(RecursionError, msg=details): self.compile_single(broken) check_limit("a", "()") diff --git a/Lib/test/test_compileall.py b/Lib/test/test_compileall.py index 2a42238..9479776 100644 --- a/Lib/test/test_compileall.py +++ b/Lib/test/test_compileall.py @@ -10,7 +10,15 @@ import time import unittest import io -from test import support, script_helper +from unittest import mock, skipUnless +try: + from concurrent.futures import ProcessPoolExecutor + _have_multiprocessing = True +except ImportError: + _have_multiprocessing = False + +from test import support +from test.support import script_helper class CompileallTests(unittest.TestCase): @@ -94,18 +102,45 @@ class CompileallTests(unittest.TestCase): def test_optimize(self): # make sure compiling with different optimization settings than the # interpreter's creates the correct file names - optimize = 1 if __debug__ else 0 + optimize, opt = (1, 1) if __debug__ else (0, '') compileall.compile_dir(self.directory, quiet=True, optimize=optimize) cached = importlib.util.cache_from_source(self.source_path, - debug_override=not optimize) + optimization=opt) self.assertTrue(os.path.isfile(cached)) cached2 = importlib.util.cache_from_source(self.source_path2, - debug_override=not optimize) + optimization=opt) self.assertTrue(os.path.isfile(cached2)) cached3 = importlib.util.cache_from_source(self.source_path3, - debug_override=not optimize) + optimization=opt) self.assertTrue(os.path.isfile(cached3)) + @mock.patch('compileall.ProcessPoolExecutor') + def test_compile_pool_called(self, pool_mock): + compileall.compile_dir(self.directory, quiet=True, workers=5) + self.assertTrue(pool_mock.called) + + def test_compile_workers_non_positive(self): + with self.assertRaisesRegex(ValueError, + "workers must be greater or equal to 0"): + compileall.compile_dir(self.directory, workers=-1) + + @mock.patch('compileall.ProcessPoolExecutor') + def test_compile_workers_cpu_count(self, pool_mock): + compileall.compile_dir(self.directory, quiet=True, workers=0) + self.assertEqual(pool_mock.call_args[1]['max_workers'], None) + + @mock.patch('compileall.ProcessPoolExecutor') + @mock.patch('compileall.compile_file') + def test_compile_one_worker(self, compile_file_mock, pool_mock): + compileall.compile_dir(self.directory, quiet=True) + self.assertFalse(pool_mock.called) + self.assertTrue(compile_file_mock.called) + + @mock.patch('compileall.ProcessPoolExecutor', new=None) + @mock.patch('compileall.compile_file') + def test_compile_missing_multiprocessing(self, compile_file_mock): + compileall.compile_dir(self.directory, quiet=True, workers=5) + self.assertTrue(compile_file_mock.called) class EncodingTest(unittest.TestCase): """Issue 6716: compileall should escape source code when printing errors @@ -203,11 +238,11 @@ class CommandLineTests(unittest.TestCase): self.assertNotIn(b'Listing ', quiet) # Ensure that the default behavior of compileall's CLI is to create - # PEP 3147 pyc/pyo files. + # PEP 3147/PEP 488 pyc files. for name, ext, switch in [ ('normal', 'pyc', []), - ('optimize', 'pyo', ['-O']), - ('doubleoptimize', 'pyo', ['-OO']), + ('optimize', 'opt-1.pyc', ['-O']), + ('doubleoptimize', 'opt-2.pyc', ['-OO']), ]: def f(self, ext=ext, switch=switch): script_helper.assert_python_ok(*(switch + @@ -224,13 +259,12 @@ class CommandLineTests(unittest.TestCase): def test_legacy_paths(self): # Ensure that with the proper switch, compileall leaves legacy - # pyc/pyo files, and no __pycache__ directory. + # pyc files, and no __pycache__ directory. self.assertRunOK('-b', '-q', self.pkgdir) # Verify the __pycache__ directory contents. self.assertFalse(os.path.exists(self.pkgdir_cachedir)) - opt = 'c' if __debug__ else 'o' - expected = sorted(['__init__.py', '__init__.py' + opt, 'bar.py', - 'bar.py' + opt]) + expected = sorted(['__init__.py', '__init__.pyc', 'bar.py', + 'bar.pyc']) self.assertEqual(sorted(os.listdir(self.pkgdir)), expected) def test_multiple_runs(self): @@ -273,12 +307,53 @@ class CommandLineTests(unittest.TestCase): self.assertCompiled(subinitfn) self.assertCompiled(hamfn) + def test_recursion_limit(self): + subpackage = os.path.join(self.pkgdir, 'spam') + subpackage2 = os.path.join(subpackage, 'ham') + subpackage3 = os.path.join(subpackage2, 'eggs') + for pkg in (subpackage, subpackage2, subpackage3): + script_helper.make_pkg(pkg) + + subinitfn = os.path.join(subpackage, '__init__.py') + hamfn = script_helper.make_script(subpackage, 'ham', '') + spamfn = script_helper.make_script(subpackage2, 'spam', '') + eggfn = script_helper.make_script(subpackage3, 'egg', '') + + self.assertRunOK('-q', '-r 0', self.pkgdir) + self.assertNotCompiled(subinitfn) + self.assertFalse( + os.path.exists(os.path.join(subpackage, '__pycache__'))) + + self.assertRunOK('-q', '-r 1', self.pkgdir) + self.assertCompiled(subinitfn) + self.assertCompiled(hamfn) + self.assertNotCompiled(spamfn) + + self.assertRunOK('-q', '-r 2', self.pkgdir) + self.assertCompiled(subinitfn) + self.assertCompiled(hamfn) + self.assertCompiled(spamfn) + self.assertNotCompiled(eggfn) + + self.assertRunOK('-q', '-r 5', self.pkgdir) + self.assertCompiled(subinitfn) + self.assertCompiled(hamfn) + self.assertCompiled(spamfn) + self.assertCompiled(eggfn) + def test_quiet(self): noisy = self.assertRunOK(self.pkgdir) quiet = self.assertRunOK('-q', self.pkgdir) self.assertNotEqual(b'', noisy) self.assertEqual(b'', quiet) + def test_silent(self): + script_helper.make_script(self.pkgdir, 'crunchyfrog', 'bad(syntax') + _, quiet, _ = self.assertRunNotOK('-q', self.pkgdir) + _, silent, _ = self.assertRunNotOK('-qq', self.pkgdir) + self.assertNotEqual(b'', quiet) + self.assertEqual(b'', silent) + def test_regexp(self): self.assertRunOK('-q', '-x', r'ba[^\\/]*$', self.pkgdir) self.assertNotCompiled(self.barfn) @@ -379,6 +454,29 @@ class CommandLineTests(unittest.TestCase): out = self.assertRunOK('badfilename') self.assertRegex(out, b"Can't list 'badfilename'") + @skipUnless(_have_multiprocessing, "requires multiprocessing") + def test_workers(self): + bar2fn = script_helper.make_script(self.directory, 'bar2', '') + files = [] + for suffix in range(5): + pkgdir = os.path.join(self.directory, 'foo{}'.format(suffix)) + os.mkdir(pkgdir) + fn = script_helper.make_script(pkgdir, '__init__', '') + files.append(script_helper.make_script(pkgdir, 'bar2', '')) + + self.assertRunOK(self.directory, '-j', '0') + self.assertCompiled(bar2fn) + for file in files: + self.assertCompiled(file) + + @mock.patch('compileall.compile_dir') + def test_workers_available_cores(self, compile_dir): + with mock.patch("sys.argv", + new=[sys.executable, self.directory, "-j0"]): + compileall.main() + self.assertTrue(compile_dir.called) + self.assertEqual(compile_dir.call_args[-1]['workers'], None) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_concurrent_futures.py b/Lib/test/test_concurrent_futures.py index c74b2ca..b99740b 100644 --- a/Lib/test/test_concurrent_futures.py +++ b/Lib/test/test_concurrent_futures.py @@ -9,8 +9,9 @@ test.support.import_module('multiprocessing.synchronize') # without thread support. test.support.import_module('threading') -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok +import os import sys import threading import time @@ -425,6 +426,13 @@ class ExecutorTest: self.assertTrue(collected, "Stale reference not collected within timeout.") + def test_max_workers_negative(self): + for number in (0, -1): + with self.assertRaisesRegex(ValueError, + "max_workers must be greater " + "than 0"): + self.executor_type(max_workers=number) + class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest, unittest.TestCase): def test_map_submits_without_iteration(self): @@ -437,6 +445,11 @@ class ThreadPoolExecutorTest(ThreadPoolMixin, ExecutorTest, unittest.TestCase): self.executor.shutdown(wait=True) self.assertCountEqual(finished, range(10)) + def test_default_workers(self): + executor = self.executor_type() + self.assertEqual(executor._max_workers, + (os.cpu_count() or 1) * 5) + class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest, unittest.TestCase): def test_killed_child(self): @@ -451,6 +464,48 @@ class ProcessPoolExecutorTest(ProcessPoolMixin, ExecutorTest, unittest.TestCase) # Submitting other jobs fails as well. self.assertRaises(BrokenProcessPool, self.executor.submit, pow, 2, 8) + def test_map_chunksize(self): + def bad_map(): + list(self.executor.map(pow, range(40), range(40), chunksize=-1)) + + ref = list(map(pow, range(40), range(40))) + self.assertEqual( + list(self.executor.map(pow, range(40), range(40), chunksize=6)), + ref) + self.assertEqual( + list(self.executor.map(pow, range(40), range(40), chunksize=50)), + ref) + self.assertEqual( + list(self.executor.map(pow, range(40), range(40), chunksize=40)), + ref) + self.assertRaises(ValueError, bad_map) + + @classmethod + def _test_traceback(cls): + raise RuntimeError(123) # some comment + + def test_traceback(self): + # We want ensure that the traceback from the child process is + # contained in the traceback raised in the main process. + future = self.executor.submit(self._test_traceback) + with self.assertRaises(Exception) as cm: + future.result() + + exc = cm.exception + self.assertIs(type(exc), RuntimeError) + self.assertEqual(exc.args, (123,)) + cause = exc.__cause__ + self.assertIs(type(cause), futures.process._RemoteTraceback) + self.assertIn('raise RuntimeError(123) # some comment', cause.tb) + + with test.support.captured_stderr() as f1: + try: + raise exc + except RuntimeError: + sys.excepthook(*sys.exc_info()) + self.assertIn('raise RuntimeError(123) # some comment', + f1.getvalue()) + class FutureTests(unittest.TestCase): def test_done_callback_with_result(self): diff --git a/Lib/test/test_configparser.py b/Lib/test/test_configparser.py index 3b03500..71a8f3f 100644 --- a/Lib/test/test_configparser.py +++ b/Lib/test/test_configparser.py @@ -579,7 +579,7 @@ boolean {0[0]} NO return e else: self.fail("expected exception type %s.%s" - % (exc.__module__, exc.__name__)) + % (exc.__module__, exc.__qualname__)) def test_boolean(self): cf = self.fromstring( @@ -1585,6 +1585,34 @@ class CoverageOneHundredTestCase(unittest.TestCase): """) self.assertEqual(repr(parser['section']), '<Section: section>') + def test_inconsistent_converters_state(self): + parser = configparser.ConfigParser() + import decimal + parser.converters['decimal'] = decimal.Decimal + parser.read_string(""" + [s1] + one = 1 + [s2] + two = 2 + """) + self.assertIn('decimal', parser.converters) + self.assertEqual(parser.getdecimal('s1', 'one'), 1) + self.assertEqual(parser.getdecimal('s2', 'two'), 2) + self.assertEqual(parser['s1'].getdecimal('one'), 1) + self.assertEqual(parser['s2'].getdecimal('two'), 2) + del parser.getdecimal + with self.assertRaises(AttributeError): + parser.getdecimal('s1', 'one') + self.assertIn('decimal', parser.converters) + del parser.converters['decimal'] + self.assertNotIn('decimal', parser.converters) + with self.assertRaises(AttributeError): + parser.getdecimal('s1', 'one') + with self.assertRaises(AttributeError): + parser['s1'].getdecimal('one') + with self.assertRaises(AttributeError): + parser['s2'].getdecimal('two') + class ExceptionPicklingTestCase(unittest.TestCase): """Tests for issue #13760: ConfigParser exceptions are not picklable.""" @@ -1777,5 +1805,252 @@ class InlineCommentStrippingTestCase(unittest.TestCase): self.assertEqual(s['k3'], 'v3;#//still v3# and still v3') +class ExceptionContextTestCase(unittest.TestCase): + """ Test that implementation details doesn't leak + through raising exceptions. """ + + def test_get_basic_interpolation(self): + parser = configparser.ConfigParser() + parser.read_string(""" + [Paths] + home_dir: /Users + my_dir: %(home_dir1)s/lumberjack + my_pictures: %(my_dir)s/Pictures + """) + cm = self.assertRaises(configparser.InterpolationMissingOptionError) + with cm: + parser.get('Paths', 'my_dir') + self.assertIs(cm.exception.__suppress_context__, True) + + def test_get_extended_interpolation(self): + parser = configparser.ConfigParser( + interpolation=configparser.ExtendedInterpolation()) + parser.read_string(""" + [Paths] + home_dir: /Users + my_dir: ${home_dir1}/lumberjack + my_pictures: ${my_dir}/Pictures + """) + cm = self.assertRaises(configparser.InterpolationMissingOptionError) + with cm: + parser.get('Paths', 'my_dir') + self.assertIs(cm.exception.__suppress_context__, True) + + def test_missing_options(self): + parser = configparser.ConfigParser() + parser.read_string(""" + [Paths] + home_dir: /Users + """) + with self.assertRaises(configparser.NoSectionError) as cm: + parser.options('test') + self.assertIs(cm.exception.__suppress_context__, True) + + def test_missing_section(self): + config = configparser.ConfigParser() + with self.assertRaises(configparser.NoSectionError) as cm: + config.set('Section1', 'an_int', '15') + self.assertIs(cm.exception.__suppress_context__, True) + + def test_remove_option(self): + config = configparser.ConfigParser() + with self.assertRaises(configparser.NoSectionError) as cm: + config.remove_option('Section1', 'an_int') + self.assertIs(cm.exception.__suppress_context__, True) + + +class ConvertersTestCase(BasicTestCase, unittest.TestCase): + """Introduced in 3.5, issue #18159.""" + + config_class = configparser.ConfigParser + + def newconfig(self, defaults=None): + instance = super().newconfig(defaults=defaults) + instance.converters['list'] = lambda v: [e.strip() for e in v.split() + if e.strip()] + return instance + + def test_converters(self): + cfg = self.newconfig() + self.assertIn('boolean', cfg.converters) + self.assertIn('list', cfg.converters) + self.assertIsNone(cfg.converters['int']) + self.assertIsNone(cfg.converters['float']) + self.assertIsNone(cfg.converters['boolean']) + self.assertIsNotNone(cfg.converters['list']) + self.assertEqual(len(cfg.converters), 4) + with self.assertRaises(ValueError): + cfg.converters[''] = lambda v: v + with self.assertRaises(ValueError): + cfg.converters[None] = lambda v: v + cfg.read_string(""" + [s] + str = string + int = 1 + float = 0.5 + list = a b c d e f g + bool = yes + """) + s = cfg['s'] + self.assertEqual(s['str'], 'string') + self.assertEqual(s['int'], '1') + self.assertEqual(s['float'], '0.5') + self.assertEqual(s['list'], 'a b c d e f g') + self.assertEqual(s['bool'], 'yes') + self.assertEqual(cfg.get('s', 'str'), 'string') + self.assertEqual(cfg.get('s', 'int'), '1') + self.assertEqual(cfg.get('s', 'float'), '0.5') + self.assertEqual(cfg.get('s', 'list'), 'a b c d e f g') + self.assertEqual(cfg.get('s', 'bool'), 'yes') + self.assertEqual(cfg.get('s', 'str'), 'string') + self.assertEqual(cfg.getint('s', 'int'), 1) + self.assertEqual(cfg.getfloat('s', 'float'), 0.5) + self.assertEqual(cfg.getlist('s', 'list'), ['a', 'b', 'c', 'd', + 'e', 'f', 'g']) + self.assertEqual(cfg.getboolean('s', 'bool'), True) + self.assertEqual(s.get('str'), 'string') + self.assertEqual(s.getint('int'), 1) + self.assertEqual(s.getfloat('float'), 0.5) + self.assertEqual(s.getlist('list'), ['a', 'b', 'c', 'd', + 'e', 'f', 'g']) + self.assertEqual(s.getboolean('bool'), True) + with self.assertRaises(AttributeError): + cfg.getdecimal('s', 'float') + with self.assertRaises(AttributeError): + s.getdecimal('float') + import decimal + cfg.converters['decimal'] = decimal.Decimal + self.assertIn('decimal', cfg.converters) + self.assertIsNotNone(cfg.converters['decimal']) + self.assertEqual(len(cfg.converters), 5) + dec0_5 = decimal.Decimal('0.5') + self.assertEqual(cfg.getdecimal('s', 'float'), dec0_5) + self.assertEqual(s.getdecimal('float'), dec0_5) + del cfg.converters['decimal'] + self.assertNotIn('decimal', cfg.converters) + self.assertEqual(len(cfg.converters), 4) + with self.assertRaises(AttributeError): + cfg.getdecimal('s', 'float') + with self.assertRaises(AttributeError): + s.getdecimal('float') + with self.assertRaises(KeyError): + del cfg.converters['decimal'] + with self.assertRaises(KeyError): + del cfg.converters[''] + with self.assertRaises(KeyError): + del cfg.converters[None] + + +class BlatantOverrideConvertersTestCase(unittest.TestCase): + """What if somebody overrode a getboolean()? We want to make sure that in + this case the automatic converters do not kick in.""" + + config = """ + [one] + one = false + two = false + three = long story short + + [two] + one = false + two = false + three = four + """ + + def test_converters_at_init(self): + cfg = configparser.ConfigParser(converters={'len': len}) + cfg.read_string(self.config) + self._test_len(cfg) + self.assertIsNotNone(cfg.converters['len']) + + def test_inheritance(self): + class StrangeConfigParser(configparser.ConfigParser): + gettysburg = 'a historic borough in south central Pennsylvania' + + def getboolean(self, section, option, *, raw=False, vars=None, + fallback=configparser._UNSET): + if section == option: + return True + return super().getboolean(section, option, raw=raw, vars=vars, + fallback=fallback) + def getlen(self, section, option, *, raw=False, vars=None, + fallback=configparser._UNSET): + return self._get_conv(section, option, len, raw=raw, vars=vars, + fallback=fallback) + + cfg = StrangeConfigParser() + cfg.read_string(self.config) + self._test_len(cfg) + self.assertIsNone(cfg.converters['len']) + self.assertTrue(cfg.getboolean('one', 'one')) + self.assertTrue(cfg.getboolean('two', 'two')) + self.assertFalse(cfg.getboolean('one', 'two')) + self.assertFalse(cfg.getboolean('two', 'one')) + cfg.converters['boolean'] = cfg._convert_to_boolean + self.assertFalse(cfg.getboolean('one', 'one')) + self.assertFalse(cfg.getboolean('two', 'two')) + self.assertFalse(cfg.getboolean('one', 'two')) + self.assertFalse(cfg.getboolean('two', 'one')) + + def _test_len(self, cfg): + self.assertEqual(len(cfg.converters), 4) + self.assertIn('boolean', cfg.converters) + self.assertIn('len', cfg.converters) + self.assertNotIn('tysburg', cfg.converters) + self.assertIsNone(cfg.converters['int']) + self.assertIsNone(cfg.converters['float']) + self.assertIsNone(cfg.converters['boolean']) + self.assertEqual(cfg.getlen('one', 'one'), 5) + self.assertEqual(cfg.getlen('one', 'two'), 5) + self.assertEqual(cfg.getlen('one', 'three'), 16) + self.assertEqual(cfg.getlen('two', 'one'), 5) + self.assertEqual(cfg.getlen('two', 'two'), 5) + self.assertEqual(cfg.getlen('two', 'three'), 4) + self.assertEqual(cfg.getlen('two', 'four', fallback=0), 0) + with self.assertRaises(configparser.NoOptionError): + cfg.getlen('two', 'four') + self.assertEqual(cfg['one'].getlen('one'), 5) + self.assertEqual(cfg['one'].getlen('two'), 5) + self.assertEqual(cfg['one'].getlen('three'), 16) + self.assertEqual(cfg['two'].getlen('one'), 5) + self.assertEqual(cfg['two'].getlen('two'), 5) + self.assertEqual(cfg['two'].getlen('three'), 4) + self.assertEqual(cfg['two'].getlen('four', 0), 0) + self.assertEqual(cfg['two'].getlen('four'), None) + + def test_instance_assignment(self): + cfg = configparser.ConfigParser() + cfg.getboolean = lambda section, option: True + cfg.getlen = lambda section, option: len(cfg[section][option]) + cfg.read_string(self.config) + self.assertEqual(len(cfg.converters), 3) + self.assertIn('boolean', cfg.converters) + self.assertNotIn('len', cfg.converters) + self.assertIsNone(cfg.converters['int']) + self.assertIsNone(cfg.converters['float']) + self.assertIsNone(cfg.converters['boolean']) + self.assertTrue(cfg.getboolean('one', 'one')) + self.assertTrue(cfg.getboolean('two', 'two')) + self.assertTrue(cfg.getboolean('one', 'two')) + self.assertTrue(cfg.getboolean('two', 'one')) + cfg.converters['boolean'] = cfg._convert_to_boolean + self.assertFalse(cfg.getboolean('one', 'one')) + self.assertFalse(cfg.getboolean('two', 'two')) + self.assertFalse(cfg.getboolean('one', 'two')) + self.assertFalse(cfg.getboolean('two', 'one')) + self.assertEqual(cfg.getlen('one', 'one'), 5) + self.assertEqual(cfg.getlen('one', 'two'), 5) + self.assertEqual(cfg.getlen('one', 'three'), 16) + self.assertEqual(cfg.getlen('two', 'one'), 5) + self.assertEqual(cfg.getlen('two', 'two'), 5) + self.assertEqual(cfg.getlen('two', 'three'), 4) + # If a getter impl is assigned straight to the instance, it won't + # be available on the section proxies. + with self.assertRaises(AttributeError): + self.assertEqual(cfg['one'].getlen('one'), 5) + with self.assertRaises(AttributeError): + self.assertEqual(cfg['two'].getlen('one'), 5) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_contains.py b/Lib/test/test_contains.py index a667a16..3c6bdef 100644 --- a/Lib/test/test_contains.py +++ b/Lib/test/test_contains.py @@ -1,5 +1,4 @@ from collections import deque -from test.support import run_unittest import unittest @@ -86,8 +85,5 @@ class TestContains(unittest.TestCase): self.assertTrue(container == container) -def test_main(): - run_unittest(TestContains) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 8f849ae..78741f5 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -83,6 +83,42 @@ class ContextManagerTestCase(unittest.TestCase): raise ZeroDivisionError(999) self.assertEqual(state, [1, 42, 999]) + def test_contextmanager_except_stopiter(self): + stop_exc = StopIteration('spam') + @contextmanager + def woohoo(): + yield + try: + with self.assertWarnsRegex(PendingDeprecationWarning, + "StopIteration"): + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail('StopIteration was suppressed') + + def test_contextmanager_except_pep479(self): + code = """\ +from __future__ import generator_stop +from contextlib import contextmanager +@contextmanager +def woohoo(): + yield +""" + locals = {} + exec(code, locals, locals) + woohoo = locals['woohoo'] + + stop_exc = StopIteration('spam') + try: + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail('StopIteration was suppressed') + def _create_contextmanager_attribs(self): def attribs(**kw): def decorate(func): @@ -726,60 +762,76 @@ class TestExitStack(unittest.TestCase): stack.push(cm) self.assertIs(stack._exit_callbacks[-1], cm) -class TestRedirectStdout(unittest.TestCase): + +class TestRedirectStream: + + redirect_stream = None + orig_stream = None @support.requires_docstrings def test_instance_docs(self): # Issue 19330: ensure context manager instances have good docstrings - cm_docstring = redirect_stdout.__doc__ - obj = redirect_stdout(None) + cm_docstring = self.redirect_stream.__doc__ + obj = self.redirect_stream(None) self.assertEqual(obj.__doc__, cm_docstring) def test_no_redirect_in_init(self): - orig_stdout = sys.stdout - redirect_stdout(None) - self.assertIs(sys.stdout, orig_stdout) + orig_stdout = getattr(sys, self.orig_stream) + self.redirect_stream(None) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) def test_redirect_to_string_io(self): f = io.StringIO() msg = "Consider an API like help(), which prints directly to stdout" - orig_stdout = sys.stdout - with redirect_stdout(f): - print(msg) - self.assertIs(sys.stdout, orig_stdout) + orig_stdout = getattr(sys, self.orig_stream) + with self.redirect_stream(f): + print(msg, file=getattr(sys, self.orig_stream)) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) s = f.getvalue().strip() self.assertEqual(s, msg) def test_enter_result_is_target(self): f = io.StringIO() - with redirect_stdout(f) as enter_result: + with self.redirect_stream(f) as enter_result: self.assertIs(enter_result, f) def test_cm_is_reusable(self): f = io.StringIO() - write_to_f = redirect_stdout(f) - orig_stdout = sys.stdout + write_to_f = self.redirect_stream(f) + orig_stdout = getattr(sys, self.orig_stream) with write_to_f: - print("Hello", end=" ") + print("Hello", end=" ", file=getattr(sys, self.orig_stream)) with write_to_f: - print("World!") - self.assertIs(sys.stdout, orig_stdout) + print("World!", file=getattr(sys, self.orig_stream)) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) s = f.getvalue() self.assertEqual(s, "Hello World!\n") def test_cm_is_reentrant(self): f = io.StringIO() - write_to_f = redirect_stdout(f) - orig_stdout = sys.stdout + write_to_f = self.redirect_stream(f) + orig_stdout = getattr(sys, self.orig_stream) with write_to_f: - print("Hello", end=" ") + print("Hello", end=" ", file=getattr(sys, self.orig_stream)) with write_to_f: - print("World!") - self.assertIs(sys.stdout, orig_stdout) + print("World!", file=getattr(sys, self.orig_stream)) + self.assertIs(getattr(sys, self.orig_stream), orig_stdout) s = f.getvalue() self.assertEqual(s, "Hello World!\n") +class TestRedirectStdout(TestRedirectStream, unittest.TestCase): + + redirect_stream = redirect_stdout + orig_stream = "stdout" + + +class TestRedirectStderr(TestRedirectStream, unittest.TestCase): + + redirect_stream = redirect_stderr + orig_stream = "stderr" + + class TestSuppress(unittest.TestCase): @support.requires_docstrings diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index eb8d18c..b9eaddd 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -7,7 +7,6 @@ import abc from operator import le, lt, ge, gt, eq, ne import unittest -from test import support order_comparisons = le, lt, ge, gt equality_comparisons = eq, ne @@ -146,6 +145,40 @@ class TestCopy(unittest.TestCase): x = C(42) self.assertEqual(copy.copy(x), x) + def test_copy_inst_getnewargs(self): + class C(int): + def __new__(cls, foo): + self = int.__new__(cls) + self.foo = foo + return self + def __getnewargs__(self): + return self.foo, + def __eq__(self, other): + return self.foo == other.foo + x = C(42) + y = copy.copy(x) + self.assertIsInstance(y, C) + self.assertEqual(y, x) + self.assertIsNot(y, x) + self.assertEqual(y.foo, x.foo) + + def test_copy_inst_getnewargs_ex(self): + class C(int): + def __new__(cls, *, foo): + self = int.__new__(cls) + self.foo = foo + return self + def __getnewargs_ex__(self): + return (), {'foo': self.foo} + def __eq__(self, other): + return self.foo == other.foo + x = C(foo=42) + y = copy.copy(x) + self.assertIsInstance(y, C) + self.assertEqual(y, x) + self.assertIsNot(y, x) + self.assertEqual(y.foo, x.foo) + def test_copy_inst_getstate(self): class C: def __init__(self, foo): @@ -294,7 +327,7 @@ class TestCopy(unittest.TestCase): x.append(x) y = copy.deepcopy(x) for op in comparisons: - self.assertRaises(RuntimeError, op, y, x) + self.assertRaises(RecursionError, op, y, x) self.assertIsNot(y, x) self.assertIs(y[0], y) self.assertEqual(len(y), 1) @@ -321,7 +354,7 @@ class TestCopy(unittest.TestCase): x[0].append(x) y = copy.deepcopy(x) for op in comparisons: - self.assertRaises(RuntimeError, op, y, x) + self.assertRaises(RecursionError, op, y, x) self.assertIsNot(y, x) self.assertIsNot(y[0], x[0]) self.assertIs(y[0][0], y) @@ -340,7 +373,7 @@ class TestCopy(unittest.TestCase): for op in order_comparisons: self.assertRaises(TypeError, op, y, x) for op in equality_comparisons: - self.assertRaises(RuntimeError, op, y, x) + self.assertRaises(RecursionError, op, y, x) self.assertIsNot(y, x) self.assertIs(y['foo'], y) self.assertEqual(len(y), 1) @@ -405,6 +438,42 @@ class TestCopy(unittest.TestCase): self.assertIsNot(y, x) self.assertIsNot(y.foo, x.foo) + def test_deepcopy_inst_getnewargs(self): + class C(int): + def __new__(cls, foo): + self = int.__new__(cls) + self.foo = foo + return self + def __getnewargs__(self): + return self.foo, + def __eq__(self, other): + return self.foo == other.foo + x = C([42]) + y = copy.deepcopy(x) + self.assertIsInstance(y, C) + self.assertEqual(y, x) + self.assertIsNot(y, x) + self.assertEqual(y.foo, x.foo) + self.assertIsNot(y.foo, x.foo) + + def test_deepcopy_inst_getnewargs_ex(self): + class C(int): + def __new__(cls, *, foo): + self = int.__new__(cls) + self.foo = foo + return self + def __getnewargs_ex__(self): + return (), {'foo': self.foo} + def __eq__(self, other): + return self.foo == other.foo + x = C(foo=[42]) + y = copy.deepcopy(x) + self.assertIsInstance(y, C) + self.assertEqual(y, x) + self.assertIsNot(y, x) + self.assertEqual(y.foo, x.foo) + self.assertIsNot(y.foo, x.foo) + def test_deepcopy_inst_getstate(self): class C: def __init__(self, foo): @@ -752,8 +821,5 @@ class TestCopy(unittest.TestCase): def global_foo(x, y): return x+y -def test_main(): - support.run_unittest(TestCopy) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_copyreg.py b/Lib/test/test_copyreg.py index abe0748..52e887c 100644 --- a/Lib/test/test_copyreg.py +++ b/Lib/test/test_copyreg.py @@ -1,7 +1,6 @@ import copyreg import unittest -from test import support from test.pickletester import ExtensionSaver class C: @@ -113,9 +112,5 @@ class CopyRegTestCase(unittest.TestCase): self.assertEqual(result, expected) -def test_main(): - support.run_unittest(CopyRegTestCase) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_coroutines.py b/Lib/test/test_coroutines.py new file mode 100644 index 0000000..10de856 --- /dev/null +++ b/Lib/test/test_coroutines.py @@ -0,0 +1,1471 @@ +import contextlib +import inspect +import sys +import types +import unittest +import warnings +from test import support + + +class AsyncYieldFrom: + def __init__(self, obj): + self.obj = obj + + def __await__(self): + yield from self.obj + + +class AsyncYield: + def __init__(self, value): + self.value = value + + def __await__(self): + yield self.value + + +def run_async(coro): + assert coro.__class__ in {types.GeneratorType, types.CoroutineType} + + buffer = [] + result = None + while True: + try: + buffer.append(coro.send(None)) + except StopIteration as ex: + result = ex.args[0] if ex.args else None + break + return buffer, result + + +def run_async__await__(coro): + assert coro.__class__ is types.CoroutineType + aw = coro.__await__() + buffer = [] + result = None + i = 0 + while True: + try: + if i % 2: + buffer.append(next(aw)) + else: + buffer.append(aw.send(None)) + i += 1 + except StopIteration as ex: + result = ex.args[0] if ex.args else None + break + return buffer, result + + +@contextlib.contextmanager +def silence_coro_gc(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + yield + support.gc_collect() + + +class AsyncBadSyntaxTest(unittest.TestCase): + + def test_badsyntax_1(self): + with self.assertRaisesRegex(SyntaxError, "'await' outside"): + import test.badsyntax_async1 + + def test_badsyntax_2(self): + with self.assertRaisesRegex(SyntaxError, "'await' outside"): + import test.badsyntax_async2 + + def test_badsyntax_3(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async3 + + def test_badsyntax_4(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async4 + + def test_badsyntax_5(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async5 + + def test_badsyntax_6(self): + with self.assertRaisesRegex( + SyntaxError, "'yield' inside async function"): + + import test.badsyntax_async6 + + def test_badsyntax_7(self): + with self.assertRaisesRegex( + SyntaxError, "'yield from' inside async function"): + + import test.badsyntax_async7 + + def test_badsyntax_8(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async8 + + def test_badsyntax_9(self): + ns = {} + for comp in {'(await a for a in b)', + '[await a for a in b]', + '{await a for a in b}', + '{await a: c for a in b}'}: + + with self.assertRaisesRegex(SyntaxError, 'await.*in comprehen'): + exec('async def f():\n\t{}'.format(comp), ns, ns) + + def test_badsyntax_10(self): + # Tests for issue 24619 + + samples = [ + """async def foo(): + def bar(): pass + await = 1 + """, + + """async def foo(): + + def bar(): pass + await = 1 + """, + + """async def foo(): + def bar(): pass + if 1: + await = 1 + """, + + """def foo(): + async def bar(): pass + if 1: + await a + """, + + """def foo(): + async def bar(): pass + await a + """, + + """def foo(): + def baz(): pass + async def bar(): pass + await a + """, + + """def foo(): + def baz(): pass + # 456 + async def bar(): pass + # 123 + await a + """, + + """async def foo(): + def baz(): pass + # 456 + async def bar(): pass + # 123 + await = 2 + """, + + """def foo(): + + def baz(): pass + + async def bar(): pass + + await a + """, + + """async def foo(): + + def baz(): pass + + async def bar(): pass + + await = 2 + """, + + """async def foo(): + def async(): pass + """, + + """async def foo(): + def await(): pass + """, + + """async def foo(): + def bar(): + await + """, + + """async def foo(): + return lambda async: await + """, + + """async def foo(): + return lambda a: await + """, + + """await a()""", + + """async def foo(a=await b): + pass + """, + + """async def foo(a:await b): + pass + """, + + """def baz(): + async def foo(a=await b): + pass + """, + + """async def foo(async): + pass + """, + + """async def foo(): + def bar(): + def baz(): + async = 1 + """, + + """async def foo(): + def bar(): + def baz(): + pass + async = 1 + """, + + """def foo(): + async def bar(): + + async def baz(): + pass + + def baz(): + 42 + + async = 1 + """, + + """async def foo(): + def bar(): + def baz(): + pass\nawait foo() + """, + + """def foo(): + def bar(): + async def baz(): + pass\nawait foo() + """, + + """async def foo(await): + pass + """, + + """def foo(): + + async def bar(): pass + + await a + """, + + """def foo(): + async def bar(): + pass\nawait a + """] + + for code in samples: + with self.subTest(code=code), self.assertRaises(SyntaxError): + compile(code, "<test>", "exec") + + def test_goodsyntax_1(self): + # Tests for issue 24619 + + def foo(await): + async def foo(): pass + async def foo(): + pass + return await + 1 + self.assertEqual(foo(10), 11) + + def foo(await): + async def foo(): pass + async def foo(): pass + return await + 2 + self.assertEqual(foo(20), 22) + + def foo(await): + + async def foo(): pass + + async def foo(): pass + + return await + 2 + self.assertEqual(foo(20), 22) + + def foo(await): + """spam""" + async def foo(): \ + pass + # 123 + async def foo(): pass + # 456 + return await + 2 + self.assertEqual(foo(20), 22) + + def foo(await): + def foo(): pass + def foo(): pass + async def bar(): return await_ + await_ = await + try: + bar().send(None) + except StopIteration as ex: + return ex.args[0] + self.assertEqual(foo(42), 42) + + async def f(): + async def g(): pass + await z + await = 1 + self.assertTrue(inspect.iscoroutinefunction(f)) + + +class TokenizerRegrTest(unittest.TestCase): + + def test_oneline_defs(self): + buf = [] + for i in range(500): + buf.append('def i{i}(): return {i}'.format(i=i)) + buf = '\n'.join(buf) + + # Test that 500 consequent, one-line defs is OK + ns = {} + exec(buf, ns, ns) + self.assertEqual(ns['i499'](), 499) + + # Test that 500 consequent, one-line defs *and* + # one 'async def' following them is OK + buf += '\nasync def foo():\n return' + ns = {} + exec(buf, ns, ns) + self.assertEqual(ns['i499'](), 499) + self.assertTrue(inspect.iscoroutinefunction(ns['foo'])) + + +class CoroutineTest(unittest.TestCase): + + def test_gen_1(self): + def gen(): yield + self.assertFalse(hasattr(gen, '__await__')) + + def test_func_1(self): + async def foo(): + return 10 + + f = foo() + self.assertIsInstance(f, types.CoroutineType) + self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE)) + self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR)) + self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE)) + self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR)) + self.assertEqual(run_async(f), ([], 10)) + + self.assertEqual(run_async__await__(foo()), ([], 10)) + + def bar(): pass + self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE)) + + def test_func_2(self): + async def foo(): + raise StopIteration + + with self.assertRaisesRegex( + RuntimeError, "coroutine raised StopIteration"): + + run_async(foo()) + + def test_func_3(self): + async def foo(): + raise StopIteration + + with silence_coro_gc(): + self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$') + + def test_func_4(self): + async def foo(): + raise StopIteration + + check = lambda: self.assertRaisesRegex( + TypeError, "'coroutine' object is not iterable") + + with check(): + list(foo()) + + with check(): + tuple(foo()) + + with check(): + sum(foo()) + + with check(): + iter(foo()) + + with silence_coro_gc(), check(): + for i in foo(): + pass + + with silence_coro_gc(), check(): + [i for i in foo()] + + def test_func_5(self): + @types.coroutine + def bar(): + yield 1 + + async def foo(): + await bar() + + check = lambda: self.assertRaisesRegex( + TypeError, "'coroutine' object is not iterable") + + with check(): + for el in foo(): pass + + # the following should pass without an error + for el in bar(): + self.assertEqual(el, 1) + self.assertEqual([el for el in bar()], [1]) + self.assertEqual(tuple(bar()), (1,)) + self.assertEqual(next(iter(bar())), 1) + + def test_func_6(self): + @types.coroutine + def bar(): + yield 1 + yield 2 + + async def foo(): + await bar() + + f = foo() + self.assertEqual(f.send(None), 1) + self.assertEqual(f.send(None), 2) + with self.assertRaises(StopIteration): + f.send(None) + + def test_func_7(self): + async def bar(): + return 10 + + def foo(): + yield from bar() + + with silence_coro_gc(), self.assertRaisesRegex( + TypeError, + "cannot 'yield from' a coroutine object in a non-coroutine generator"): + + list(foo()) + + def test_func_8(self): + @types.coroutine + def bar(): + return (yield from foo()) + + async def foo(): + return 'spam' + + self.assertEqual(run_async(bar()), ([], 'spam') ) + + def test_func_9(self): + async def foo(): pass + + with self.assertWarnsRegex( + RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"): + + foo() + support.gc_collect() + + def test_func_10(self): + N = 0 + + @types.coroutine + def gen(): + nonlocal N + try: + a = yield + yield (a ** 2) + except ZeroDivisionError: + N += 100 + raise + finally: + N += 1 + + async def foo(): + await gen() + + coro = foo() + aw = coro.__await__() + self.assertIs(aw, iter(aw)) + next(aw) + self.assertEqual(aw.send(10), 100) + + self.assertEqual(N, 0) + aw.close() + self.assertEqual(N, 1) + + coro = foo() + aw = coro.__await__() + next(aw) + with self.assertRaises(ZeroDivisionError): + aw.throw(ZeroDivisionError, None, None) + self.assertEqual(N, 102) + + def test_func_11(self): + async def func(): pass + coro = func() + # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly + # initialized + self.assertIn('__await__', dir(coro)) + self.assertIn('__iter__', dir(coro.__await__())) + self.assertIn('coroutine_wrapper', repr(coro.__await__())) + coro.close() # avoid RuntimeWarning + + def test_func_12(self): + async def g(): + i = me.send(None) + await foo + me = g() + with self.assertRaisesRegex(ValueError, + "coroutine already executing"): + me.send(None) + + def test_func_13(self): + async def g(): + pass + with self.assertRaisesRegex( + TypeError, + "can't send non-None value to a just-started coroutine"): + + g().send('spam') + + def test_func_14(self): + @types.coroutine + def gen(): + yield + async def coro(): + try: + await gen() + except GeneratorExit: + await gen() + c = coro() + c.send(None) + with self.assertRaisesRegex(RuntimeError, + "coroutine ignored GeneratorExit"): + c.close() + + def test_cr_await(self): + @types.coroutine + def a(): + self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING) + self.assertIsNone(coro_b.cr_await) + yield + self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING) + self.assertIsNone(coro_b.cr_await) + + async def c(): + await a() + + async def b(): + self.assertIsNone(coro_b.cr_await) + await c() + self.assertIsNone(coro_b.cr_await) + + coro_b = b() + self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED) + self.assertIsNone(coro_b.cr_await) + + coro_b.send(None) + self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED) + self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a') + + with self.assertRaises(StopIteration): + coro_b.send(None) # complete coroutine + self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED) + self.assertIsNone(coro_b.cr_await) + + def test_corotype_1(self): + ct = types.CoroutineType + self.assertIn('into coroutine', ct.send.__doc__) + self.assertIn('inside coroutine', ct.close.__doc__) + self.assertIn('in coroutine', ct.throw.__doc__) + self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__) + self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__) + self.assertEqual(ct.__name__, 'coroutine') + + async def f(): pass + c = f() + self.assertIn('coroutine object', repr(c)) + c.close() + + def test_await_1(self): + + async def foo(): + await 1 + with self.assertRaisesRegex(TypeError, "object int can.t.*await"): + run_async(foo()) + + def test_await_2(self): + async def foo(): + await [] + with self.assertRaisesRegex(TypeError, "object list can.t.*await"): + run_async(foo()) + + def test_await_3(self): + async def foo(): + await AsyncYieldFrom([1, 2, 3]) + + self.assertEqual(run_async(foo()), ([1, 2, 3], None)) + self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None)) + + def test_await_4(self): + async def bar(): + return 42 + + async def foo(): + return await bar() + + self.assertEqual(run_async(foo()), ([], 42)) + + def test_await_5(self): + class Awaitable: + def __await__(self): + return + + async def foo(): + return (await Awaitable()) + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type"): + + run_async(foo()) + + def test_await_6(self): + class Awaitable: + def __await__(self): + return iter([52]) + + async def foo(): + return (await Awaitable()) + + self.assertEqual(run_async(foo()), ([52], None)) + + def test_await_7(self): + class Awaitable: + def __await__(self): + yield 42 + return 100 + + async def foo(): + return (await Awaitable()) + + self.assertEqual(run_async(foo()), ([42], 100)) + + def test_await_8(self): + class Awaitable: + pass + + async def foo(): return await Awaitable() + + with self.assertRaisesRegex( + TypeError, "object Awaitable can't be used in 'await' expression"): + + run_async(foo()) + + def test_await_9(self): + def wrap(): + return bar + + async def bar(): + return 42 + + async def foo(): + b = bar() + + db = {'b': lambda: wrap} + + class DB: + b = wrap + + return (await bar() + await wrap()() + await db['b']()()() + + await bar() * 1000 + await DB.b()()) + + async def foo2(): + return -await bar() + + self.assertEqual(run_async(foo()), ([], 42168)) + self.assertEqual(run_async(foo2()), ([], -42)) + + def test_await_10(self): + async def baz(): + return 42 + + async def bar(): + return baz() + + async def foo(): + return await (await bar()) + + self.assertEqual(run_async(foo()), ([], 42)) + + def test_await_11(self): + def ident(val): + return val + + async def bar(): + return 'spam' + + async def foo(): + return ident(val=await bar()) + + async def foo2(): + return await bar(), 'ham' + + self.assertEqual(run_async(foo2()), ([], ('spam', 'ham'))) + + def test_await_12(self): + async def coro(): + return 'spam' + + class Awaitable: + def __await__(self): + return coro() + + async def foo(): + return await Awaitable() + + with self.assertRaisesRegex( + TypeError, "__await__\(\) returned a coroutine"): + + run_async(foo()) + + def test_await_13(self): + class Awaitable: + def __await__(self): + return self + + async def foo(): + return await Awaitable() + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type"): + + run_async(foo()) + + def test_await_14(self): + class Wrapper: + # Forces the interpreter to use CoroutineType.__await__ + def __init__(self, coro): + assert coro.__class__ is types.CoroutineType + self.coro = coro + def __await__(self): + return self.coro.__await__() + + class FutureLike: + def __await__(self): + return (yield) + + class Marker(Exception): + pass + + async def coro1(): + try: + return await FutureLike() + except ZeroDivisionError: + raise Marker + async def coro2(): + return await Wrapper(coro1()) + + c = coro2() + c.send(None) + with self.assertRaisesRegex(StopIteration, 'spam'): + c.send('spam') + + c = coro2() + c.send(None) + with self.assertRaises(Marker): + c.throw(ZeroDivisionError) + + def test_with_1(self): + class Manager: + def __init__(self, name): + self.name = name + + async def __aenter__(self): + await AsyncYieldFrom(['enter-1-' + self.name, + 'enter-2-' + self.name]) + return self + + async def __aexit__(self, *args): + await AsyncYieldFrom(['exit-1-' + self.name, + 'exit-2-' + self.name]) + + if self.name == 'B': + return True + + + async def foo(): + async with Manager("A") as a, Manager("B") as b: + await AsyncYieldFrom([('managers', a.name, b.name)]) + 1/0 + + f = foo() + result, _ = run_async(f) + + self.assertEqual( + result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B', + ('managers', 'A', 'B'), + 'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A'] + ) + + async def foo(): + async with Manager("A") as a, Manager("C") as c: + await AsyncYieldFrom([('managers', a.name, c.name)]) + 1/0 + + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + + def test_with_2(self): + class CM: + def __aenter__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aexit__'): + run_async(foo()) + + def test_with_3(self): + class CM: + def __aexit__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aenter__'): + run_async(foo()) + + def test_with_4(self): + class CM: + def __enter__(self): + pass + + def __exit__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aexit__'): + run_async(foo()) + + def test_with_5(self): + # While this test doesn't make a lot of sense, + # it's a regression test for an early bug with opcodes + # generation + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + pass + + async def func(): + async with CM(): + assert (1, ) == 1 + + with self.assertRaises(AssertionError): + run_async(func()) + + def test_with_6(self): + class CM: + def __aenter__(self): + return 123 + + def __aexit__(self, *e): + return 456 + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex( + TypeError, "object int can't be used in 'await' expression"): + # it's important that __aexit__ wasn't called + run_async(foo()) + + def test_with_7(self): + class CM: + async def __aenter__(self): + return self + + def __aexit__(self, *e): + return 444 + + async def foo(): + async with CM(): + 1/0 + + try: + run_async(foo()) + except TypeError as exc: + self.assertRegex( + exc.args[0], "object int can't be used in 'await' expression") + self.assertTrue(exc.__context__ is not None) + self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) + else: + self.fail('invalid asynchronous context manager did not fail') + + + def test_with_8(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + def __aexit__(self, *e): + return 456 + + async def foo(): + nonlocal CNT + async with CM(): + CNT += 1 + + + with self.assertRaisesRegex( + TypeError, "object int can't be used in 'await' expression"): + + run_async(foo()) + + self.assertEqual(CNT, 1) + + + def test_with_9(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *e): + 1/0 + + async def foo(): + nonlocal CNT + async with CM(): + CNT += 1 + + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + + self.assertEqual(CNT, 1) + + def test_with_10(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *e): + 1/0 + + async def foo(): + nonlocal CNT + async with CM(): + async with CM(): + raise RuntimeError + + try: + run_async(foo()) + except ZeroDivisionError as exc: + self.assertTrue(exc.__context__ is not None) + self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) + self.assertTrue(isinstance(exc.__context__.__context__, + RuntimeError)) + else: + self.fail('exception from __aexit__ did not propagate') + + def test_with_11(self): + CNT = 0 + + class CM: + async def __aenter__(self): + raise NotImplementedError + + async def __aexit__(self, *e): + 1/0 + + async def foo(): + nonlocal CNT + async with CM(): + raise RuntimeError + + try: + run_async(foo()) + except NotImplementedError as exc: + self.assertTrue(exc.__context__ is None) + else: + self.fail('exception from __aenter__ did not propagate') + + def test_with_12(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *e): + return True + + async def foo(): + nonlocal CNT + async with CM() as cm: + self.assertIs(cm.__class__, CM) + raise RuntimeError + + run_async(foo()) + + def test_with_13(self): + CNT = 0 + + class CM: + async def __aenter__(self): + 1/0 + + async def __aexit__(self, *e): + return True + + async def foo(): + nonlocal CNT + CNT += 1 + async with CM(): + CNT += 1000 + CNT += 10000 + + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + self.assertEqual(CNT, 1) + + def test_for_1(self): + aiter_calls = 0 + + class AsyncIter: + def __init__(self): + self.i = 0 + + async def __aiter__(self): + nonlocal aiter_calls + aiter_calls += 1 + return self + + async def __anext__(self): + self.i += 1 + + if not (self.i % 10): + await AsyncYield(self.i * 10) + + if self.i > 100: + raise StopAsyncIteration + + return self.i, self.i + + + buffer = [] + async def test1(): + async for i1, i2 in AsyncIter(): + buffer.append(i1 + i2) + + yielded, _ = run_async(test1()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 1) + self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) + self.assertEqual(buffer, [i*2 for i in range(1, 101)]) + + + buffer = [] + async def test2(): + nonlocal buffer + async for i in AsyncIter(): + buffer.append(i[0]) + if i[0] == 20: + break + else: + buffer.append('what?') + buffer.append('end') + + yielded, _ = run_async(test2()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 2) + self.assertEqual(yielded, [100, 200]) + self.assertEqual(buffer, [i for i in range(1, 21)] + ['end']) + + + buffer = [] + async def test3(): + nonlocal buffer + async for i in AsyncIter(): + if i[0] > 20: + continue + buffer.append(i[0]) + else: + buffer.append('what?') + buffer.append('end') + + yielded, _ = run_async(test3()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 3) + self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) + self.assertEqual(buffer, [i for i in range(1, 21)] + + ['what?', 'end']) + + def test_for_2(self): + tup = (1, 2, 3) + refs_before = sys.getrefcount(tup) + + async def foo(): + async for i in tup: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, "async for' requires an object.*__aiter__.*tuple"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(tup), refs_before) + + def test_for_3(self): + class I: + def __aiter__(self): + return self + + aiter = I() + refs_before = sys.getrefcount(aiter) + + async def foo(): + async for i in aiter: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__aiter.*\: I"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(aiter), refs_before) + + def test_for_4(self): + class I: + async def __aiter__(self): + return self + + def __anext__(self): + return () + + aiter = I() + refs_before = sys.getrefcount(aiter) + + async def foo(): + async for i in aiter: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__anext__.*tuple"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(aiter), refs_before) + + def test_for_5(self): + class I: + async def __aiter__(self): + return self + + def __anext__(self): + return 123 + + async def foo(): + async for i in I(): + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__anext.*int"): + + run_async(foo()) + + def test_for_6(self): + I = 0 + + class Manager: + async def __aenter__(self): + nonlocal I + I += 10000 + + async def __aexit__(self, *args): + nonlocal I + I += 100000 + + class Iterable: + def __init__(self): + self.i = 0 + + async def __aiter__(self): + return self + + async def __anext__(self): + if self.i > 10: + raise StopAsyncIteration + self.i += 1 + return self.i + + ############## + + manager = Manager() + iterable = Iterable() + mrefs_before = sys.getrefcount(manager) + irefs_before = sys.getrefcount(iterable) + + async def main(): + nonlocal I + + async with manager: + async for i in iterable: + I += 1 + I += 1000 + + run_async(main()) + self.assertEqual(I, 111011) + + self.assertEqual(sys.getrefcount(manager), mrefs_before) + self.assertEqual(sys.getrefcount(iterable), irefs_before) + + ############## + + async def main(): + nonlocal I + + async with Manager(): + async for i in Iterable(): + I += 1 + I += 1000 + + async with Manager(): + async for i in Iterable(): + I += 1 + I += 1000 + + run_async(main()) + self.assertEqual(I, 333033) + + ############## + + async def main(): + nonlocal I + + async with Manager(): + I += 100 + async for i in Iterable(): + I += 1 + else: + I += 10000000 + I += 1000 + + async with Manager(): + I += 100 + async for i in Iterable(): + I += 1 + else: + I += 10000000 + I += 1000 + + run_async(main()) + self.assertEqual(I, 20555255) + + def test_for_7(self): + CNT = 0 + class AI: + async def __aiter__(self): + 1/0 + async def foo(): + nonlocal CNT + async for i in AI(): + CNT += 1 + CNT += 10 + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + self.assertEqual(CNT, 0) + + +class CoroAsyncIOCompatTest(unittest.TestCase): + + def test_asyncio_1(self): + import asyncio + + class MyException(Exception): + pass + + buffer = [] + + class CM: + async def __aenter__(self): + buffer.append(1) + await asyncio.sleep(0.01) + buffer.append(2) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await asyncio.sleep(0.01) + buffer.append(exc_type.__name__) + + async def f(): + async with CM() as c: + await asyncio.sleep(0.01) + raise MyException + buffer.append('unreachable') + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(f()) + except MyException: + pass + finally: + loop.close() + asyncio.set_event_loop(None) + + self.assertEqual(buffer, [1, 2, 'MyException']) + + +class SysSetCoroWrapperTest(unittest.TestCase): + + def test_set_wrapper_1(self): + async def foo(): + return 'spam' + + wrapped = None + def wrap(gen): + nonlocal wrapped + wrapped = gen + return gen + + self.assertIsNone(sys.get_coroutine_wrapper()) + + sys.set_coroutine_wrapper(wrap) + self.assertIs(sys.get_coroutine_wrapper(), wrap) + try: + f = foo() + self.assertTrue(wrapped) + + self.assertEqual(run_async(f), ([], 'spam')) + finally: + sys.set_coroutine_wrapper(None) + + self.assertIsNone(sys.get_coroutine_wrapper()) + + wrapped = None + with silence_coro_gc(): + foo() + self.assertFalse(wrapped) + + def test_set_wrapper_2(self): + self.assertIsNone(sys.get_coroutine_wrapper()) + with self.assertRaisesRegex(TypeError, "callable expected, got int"): + sys.set_coroutine_wrapper(1) + self.assertIsNone(sys.get_coroutine_wrapper()) + + def test_set_wrapper_3(self): + async def foo(): + return 'spam' + + def wrapper(coro): + async def wrap(coro): + return await coro + return wrap(coro) + + sys.set_coroutine_wrapper(wrapper) + try: + with silence_coro_gc(), self.assertRaisesRegex( + RuntimeError, + "coroutine wrapper.*\.wrapper at 0x.*attempted to " + "recursively wrap .* wrap .*"): + + foo() + finally: + sys.set_coroutine_wrapper(None) + + def test_set_wrapper_4(self): + @types.coroutine + def foo(): + return 'spam' + + wrapped = None + def wrap(gen): + nonlocal wrapped + wrapped = gen + return gen + + sys.set_coroutine_wrapper(wrap) + try: + foo() + self.assertIs( + wrapped, None, + "generator-based coroutine was wrapped via " + "sys.set_coroutine_wrapper") + finally: + sys.set_coroutine_wrapper(None) + + +class CAPITest(unittest.TestCase): + + def test_tp_await_1(self): + from _testcapi import awaitType as at + + async def foo(): + future = at(iter([1])) + return (await future) + + self.assertEqual(foo().send(None), 1) + + def test_tp_await_2(self): + # Test tp_await to __await__ mapping + from _testcapi import awaitType as at + future = at(iter([1])) + self.assertEqual(next(future.__await__()), 1) + + def test_tp_await_3(self): + from _testcapi import awaitType as at + + async def foo(): + future = at(1) + return (await future) + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type 'int'"): + self.assertEqual(foo().send(None), 1) + + +if __name__=="__main__": + unittest.main() diff --git a/Lib/test/test_cprofile.py b/Lib/test/test_cprofile.py index ce5d27e..f18983f 100644 --- a/Lib/test/test_cprofile.py +++ b/Lib/test/test_cprofile.py @@ -11,7 +11,7 @@ from test.profilee import testfunc class CProfileTest(ProfileTest): profilerclass = cProfile.Profile profilermodule = cProfile - expected_max_output = "{built-in method max}" + expected_max_output = "{built-in method builtins.max}" def get_expected_output(self): return _ProfileOutput @@ -72,9 +72,9 @@ profilee.py:84(helper2_indirect) <- 2 0.000 0.140 profilee.py:88(helper2) <- 6 0.234 0.300 profilee.py:55(helper) 2 0.078 0.100 profilee.py:84(helper2_indirect) profilee.py:98(subhelper) <- 8 0.064 0.080 profilee.py:88(helper2) -{built-in method exc_info} <- 4 0.000 0.000 profilee.py:73(helper1) -{built-in method hasattr} <- 4 0.000 0.004 profilee.py:73(helper1) +{built-in method builtins.hasattr} <- 4 0.000 0.004 profilee.py:73(helper1) 8 0.000 0.008 profilee.py:88(helper2) +{built-in method sys.exc_info} <- 4 0.000 0.000 profilee.py:73(helper1) {method 'append' of 'list' objects} <- 4 0.000 0.000 profilee.py:73(helper1)""" _ProfileOutput['print_callees'] = """\ <string>:1(<module>) -> 1 0.270 1.000 profilee.py:25(testfunc) @@ -87,12 +87,12 @@ profilee.py:48(mul) -> profilee.py:55(helper) -> 4 0.116 0.120 profilee.py:73(helper1) 2 0.000 0.140 profilee.py:84(helper2_indirect) 6 0.234 0.300 profilee.py:88(helper2) -profilee.py:73(helper1) -> 4 0.000 0.000 {built-in method exc_info} +profilee.py:73(helper1) -> 4 0.000 0.004 {built-in method builtins.hasattr} profilee.py:84(helper2_indirect) -> 2 0.006 0.040 profilee.py:35(factorial) 2 0.078 0.100 profilee.py:88(helper2) profilee.py:88(helper2) -> 8 0.064 0.080 profilee.py:98(subhelper) profilee.py:98(subhelper) -> 16 0.016 0.016 profilee.py:110(__getattr__) -{built-in method hasattr} -> 12 0.012 0.012 profilee.py:110(__getattr__)""" +{built-in method builtins.hasattr} -> 12 0.012 0.012 profilee.py:110(__getattr__)""" if __name__ == "__main__": main() diff --git a/Lib/test/test_crashers.py b/Lib/test/test_crashers.py index 336ccbe..58dfd00 100644 --- a/Lib/test/test_crashers.py +++ b/Lib/test/test_crashers.py @@ -8,7 +8,7 @@ import unittest import glob import os.path import test.support -from test.script_helper import assert_python_failure +from test.support.script_helper import assert_python_failure CRASHER_DIR = os.path.join(os.path.dirname(__file__), "crashers") CRASHER_FILES = os.path.join(CRASHER_DIR, "*.py") @@ -30,9 +30,8 @@ class CrasherTest(unittest.TestCase): assert_python_failure(fname) -def test_main(): - test.support.run_unittest(CrasherTest) +def tearDownModule(): test.support.reap_children() if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py index 65449ae..8e9c2b4 100644 --- a/Lib/test/test_csv.py +++ b/Lib/test/test_csv.py @@ -186,6 +186,14 @@ class Test_Csv(unittest.TestCase): self._write_test(['a',1,'p,q'], 'a,1,p\\,q', escapechar='\\', quoting = csv.QUOTE_NONE) + def test_write_iterable(self): + self._write_test(iter(['a', 1, 'p,q']), 'a,1,"p,q"') + self._write_test(iter(['a', 1, None]), 'a,1,') + self._write_test(iter([]), '') + self._write_test(iter([None]), '""') + self._write_error_test(csv.Error, iter([None]), quoting=csv.QUOTE_NONE) + self._write_test(iter([None, None]), ',') + def test_writerows(self): class BrokenFile: def write(self, buf): @@ -578,6 +586,16 @@ class TestDictFields(unittest.TestCase): fileobj.readline() # header self.assertEqual(fileobj.read(), "10,,abc\r\n") + def test_write_multiple_dict_rows(self): + fileobj = StringIO() + writer = csv.DictWriter(fileobj, fieldnames=["f1", "f2", "f3"]) + writer.writeheader() + self.assertEqual(fileobj.getvalue(), "f1,f2,f3\r\n") + writer.writerows([{"f1": 1, "f2": "abc", "f3": "f"}, + {"f1": 2, "f2": 5, "f3": "xyz"}]) + self.assertEqual(fileobj.getvalue(), + "f1,f2,f3\r\n1,abc,f\r\n2,5,xyz\r\n") + def test_write_no_fields(self): fileobj = StringIO() self.assertRaises(TypeError, csv.DictWriter, fileobj) @@ -776,7 +794,7 @@ class TestDialectValidity(unittest.TestCase): with self.assertRaises(csv.Error) as cm: mydialect() self.assertEqual(str(cm.exception), - '"quotechar" must be an 1-character string') + '"quotechar" must be a 1-character string') mydialect.quotechar = 4 with self.assertRaises(csv.Error) as cm: @@ -799,13 +817,13 @@ class TestDialectValidity(unittest.TestCase): with self.assertRaises(csv.Error) as cm: mydialect() self.assertEqual(str(cm.exception), - '"delimiter" must be an 1-character string') + '"delimiter" must be a 1-character string') mydialect.delimiter = "" with self.assertRaises(csv.Error) as cm: mydialect() self.assertEqual(str(cm.exception), - '"delimiter" must be an 1-character string') + '"delimiter" must be a 1-character string') mydialect.delimiter = b"," with self.assertRaises(csv.Error) as cm: @@ -1066,11 +1084,5 @@ class TestUnicode(unittest.TestCase): self.assertEqual(fileobj.read(), expected) -def test_main(): - mod = sys.modules[__name__] - support.run_unittest( - *[getattr(mod, name) for name in dir(mod) if name.startswith('Test')] - ) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_curses.py b/Lib/test/test_curses.py index bd7d4fc..2747041 100644 --- a/Lib/test/test_curses.py +++ b/Lib/test/test_curses.py @@ -370,6 +370,13 @@ class TestCurses(unittest.TestCase): offset = human_readable_signature.find("[y, x,]") assert offset >= 0, "" + def test_update_lines_cols(self): + # this doesn't actually test that LINES and COLS are updated, + # because we can't automate changing them. See Issue #4254 for + # a manual test script. We can only test that the function + # can be called. + curses.update_lines_cols() + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_datetime.py b/Lib/test/test_datetime.py index d9ddb32..2d4eb52 100644 --- a/Lib/test/test_datetime.py +++ b/Lib/test/test_datetime.py @@ -1,20 +1,20 @@ import unittest import sys + from test.support import import_fresh_module, run_unittest TESTS = 'test.datetimetester' -# XXX: import_fresh_module() is supposed to leave sys.module cache untouched, -# XXX: but it does not, so we have to save and restore it ourselves. -save_sys_modules = sys.modules.copy() try: pure_tests = import_fresh_module(TESTS, fresh=['datetime', '_strptime'], blocked=['_datetime']) fast_tests = import_fresh_module(TESTS, fresh=['datetime', '_datetime', '_strptime']) finally: - sys.modules.clear() - sys.modules.update(save_sys_modules) + # XXX: import_fresh_module() is supposed to leave sys.module cache untouched, + # XXX: but it does not, so we have to cleanup ourselves. + for modname in ['datetime', '_datetime', '_strptime']: + sys.modules.pop(modname, None) test_modules = [pure_tests, fast_tests] test_suffixes = ["_Pure", "_Fast"] # XXX(gb) First run all the _Pure tests, then all the _Fast tests. You might diff --git a/Lib/test/test_dbm_dumb.py b/Lib/test/test_dbm_dumb.py index dc88ca6..ff63c88 100644 --- a/Lib/test/test_dbm_dumb.py +++ b/Lib/test/test_dbm_dumb.py @@ -217,6 +217,14 @@ class DumbDBMTestCase(unittest.TestCase): self.assertEqual(str(cm.exception), "DBM object has already been closed") + def test_create_new(self): + with dumbdbm.open(_fname, 'n') as f: + for k in self._dict: + f[k] = self._dict[k] + + with dumbdbm.open(_fname, 'n') as f: + self.assertEqual(f.keys(), []) + def test_eval(self): with open(_fname + '.dir', 'w') as stream: stream.write("str(print('Hacked!')), 0\n") diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index a178f6f..137aaa5 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -33,12 +33,13 @@ import unittest import numbers import locale from test.support import (run_unittest, run_doctest, is_resource_enabled, - requires_IEEE_754) + requires_IEEE_754, requires_docstrings) from test.support import (check_warnings, import_fresh_module, TestFailed, run_with_locale, cpython_only) import random import time import warnings +import inspect try: import threading except ImportError: @@ -4174,9 +4175,7 @@ class CheckAttributes(unittest.TestCase): self.assertEqual(C.__version__, P.__version__) self.assertEqual(C.__libmpdec_version__, P.__libmpdec_version__) - x = dir(C) - y = [s for s in dir(P) if '__' in s or not s.startswith('_')] - self.assertEqual(set(x) - set(y), set()) + self.assertEqual(dir(C), dir(P)) def test_context_attributes(self): @@ -4455,18 +4454,6 @@ class PyCoverage(Coverage): class PyFunctionality(unittest.TestCase): """Extra functionality in decimal.py""" - def test_py_quantize_watchexp(self): - # watchexp functionality - Decimal = P.Decimal - localcontext = P.localcontext - - with localcontext() as c: - c.prec = 1 - c.Emax = 1 - c.Emin = -1 - x = Decimal(99999).quantize(Decimal("1e3"), watchexp=False) - self.assertEqual(x, Decimal('1.00E+5')) - def test_py_alternate_formatting(self): # triples giving a format, a Decimal, and the expected result Decimal = P.Decimal @@ -5409,6 +5396,143 @@ class CWhitebox(unittest.TestCase): y = Decimal(10**(9*25)).__sizeof__() self.assertEqual(y, x+4) +@requires_docstrings +@unittest.skipUnless(C, "test requires C version") +class SignatureTest(unittest.TestCase): + """Function signatures""" + + def test_inspect_module(self): + for attr in dir(P): + if attr.startswith('_'): + continue + p_func = getattr(P, attr) + c_func = getattr(C, attr) + if (attr == 'Decimal' or attr == 'Context' or + inspect.isfunction(p_func)): + p_sig = inspect.signature(p_func) + c_sig = inspect.signature(c_func) + + # parameter names: + c_names = list(c_sig.parameters.keys()) + p_names = [x for x in p_sig.parameters.keys() if not + x.startswith('_')] + + self.assertEqual(c_names, p_names, + msg="parameter name mismatch in %s" % p_func) + + c_kind = [x.kind for x in c_sig.parameters.values()] + p_kind = [x[1].kind for x in p_sig.parameters.items() if not + x[0].startswith('_')] + + # parameters: + if attr != 'setcontext': + self.assertEqual(c_kind, p_kind, + msg="parameter kind mismatch in %s" % p_func) + + def test_inspect_types(self): + + POS = inspect._ParameterKind.POSITIONAL_ONLY + POS_KWD = inspect._ParameterKind.POSITIONAL_OR_KEYWORD + + # Type heuristic (type annotations would help!): + pdict = {C: {'other': C.Decimal(1), + 'third': C.Decimal(1), + 'x': C.Decimal(1), + 'y': C.Decimal(1), + 'z': C.Decimal(1), + 'a': C.Decimal(1), + 'b': C.Decimal(1), + 'c': C.Decimal(1), + 'exp': C.Decimal(1), + 'modulo': C.Decimal(1), + 'num': "1", + 'f': 1.0, + 'rounding': C.ROUND_HALF_UP, + 'context': C.getcontext()}, + P: {'other': P.Decimal(1), + 'third': P.Decimal(1), + 'a': P.Decimal(1), + 'b': P.Decimal(1), + 'c': P.Decimal(1), + 'exp': P.Decimal(1), + 'modulo': P.Decimal(1), + 'num': "1", + 'f': 1.0, + 'rounding': P.ROUND_HALF_UP, + 'context': P.getcontext()}} + + def mkargs(module, sig): + args = [] + kwargs = {} + for name, param in sig.parameters.items(): + if name == 'self': continue + if param.kind == POS: + args.append(pdict[module][name]) + elif param.kind == POS_KWD: + kwargs[name] = pdict[module][name] + else: + raise TestFailed("unexpected parameter kind") + return args, kwargs + + def tr(s): + """The C Context docstrings use 'x' in order to prevent confusion + with the article 'a' in the descriptions.""" + if s == 'x': return 'a' + if s == 'y': return 'b' + if s == 'z': return 'c' + return s + + def doit(ty): + p_type = getattr(P, ty) + c_type = getattr(C, ty) + for attr in dir(p_type): + if attr.startswith('_'): + continue + p_func = getattr(p_type, attr) + c_func = getattr(c_type, attr) + if inspect.isfunction(p_func): + p_sig = inspect.signature(p_func) + c_sig = inspect.signature(c_func) + + # parameter names: + p_names = list(p_sig.parameters.keys()) + c_names = [tr(x) for x in c_sig.parameters.keys()] + + self.assertEqual(c_names, p_names, + msg="parameter name mismatch in %s" % p_func) + + p_kind = [x.kind for x in p_sig.parameters.values()] + c_kind = [x.kind for x in c_sig.parameters.values()] + + # 'self' parameter: + self.assertIs(p_kind[0], POS_KWD) + self.assertIs(c_kind[0], POS) + + # remaining parameters: + if ty == 'Decimal': + self.assertEqual(c_kind[1:], p_kind[1:], + msg="parameter kind mismatch in %s" % p_func) + else: # Context methods are positional only in the C version. + self.assertEqual(len(c_kind), len(p_kind), + msg="parameter kind mismatch in %s" % p_func) + + # Run the function: + args, kwds = mkargs(C, c_sig) + try: + getattr(c_type(9), attr)(*args, **kwds) + except Exception as err: + raise TestFailed("invalid signature for %s: %s %s" % (c_func, args, kwds)) + + args, kwds = mkargs(P, p_sig) + try: + getattr(p_type(9), attr)(*args, **kwds) + except Exception as err: + raise TestFailed("invalid signature for %s: %s %s" % (p_func, args, kwds)) + + doit('Decimal') + doit('Context') + + all_tests = [ CExplicitConstructionTest, PyExplicitConstructionTest, CImplicitConstructionTest, PyImplicitConstructionTest, @@ -5434,6 +5558,7 @@ if not C: all_tests = all_tests[1::2] else: all_tests.insert(0, CheckAttributes) + all_tests.insert(1, SignatureTest) def test_main(arith=None, verbose=None, todo_tests=None, debug=None): diff --git a/Lib/test/test_decorators.py b/Lib/test/test_decorators.py index 53c9469..d0a2ec9 100644 --- a/Lib/test/test_decorators.py +++ b/Lib/test/test_decorators.py @@ -1,5 +1,4 @@ import unittest -from test import support def funcattrs(**kwds): def decorate(func): @@ -301,9 +300,5 @@ class TestClassDecorators(unittest.TestCase): class C(object): pass self.assertEqual(C.extra, 'second') -def test_main(): - support.run_unittest(TestDecorators) - support.run_unittest(TestClassDecorators) - -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index 532d535..a90bc2b 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -5,7 +5,6 @@ import copy import pickle import tempfile import unittest -from test import support from collections import defaultdict @@ -157,8 +156,9 @@ class TestDefaultDict(unittest.TestCase): def _factory(self): return [] d = sub() - self.assertTrue(repr(d).startswith( - "defaultdict(<bound method sub._factory of defaultdict(...")) + self.assertRegex(repr(d), + r"defaultdict\(<bound method .*sub\._factory " + r"of defaultdict\(\.\.\., \{\}\)>, \{\}\)") # NOTE: printing a subclass of a builtin type does not call its # tp_print slot. So this part is essentially the same test as above. @@ -183,8 +183,5 @@ class TestDefaultDict(unittest.TestCase): o = pickle.loads(s) self.assertEqual(d, o) -def test_main(): - support.run_unittest(TestDefaultDict) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py index 5ecbc73..8718716 100644 --- a/Lib/test/test_deque.py +++ b/Lib/test/test_deque.py @@ -164,6 +164,26 @@ class TestBasic(unittest.TestCase): self.assertEqual(x > y, list(x) > list(y), (x,y)) self.assertEqual(x >= y, list(x) >= list(y), (x,y)) + def test_contains(self): + n = 200 + + d = deque(range(n)) + for i in range(n): + self.assertTrue(i in d) + self.assertTrue((n+1) not in d) + + # Test detection of mutation during iteration + d = deque(range(n)) + d[n//2] = MutateCmp(d, False) + with self.assertRaises(RuntimeError): + n in d + + # Test detection of comparison exceptions + d = deque(range(n)) + d[n//2] = BadCmp() + with self.assertRaises(RuntimeError): + n in d + def test_extend(self): d = deque('a') self.assertRaises(TypeError, d.extend, 1) @@ -172,6 +192,26 @@ class TestBasic(unittest.TestCase): d.extend(d) self.assertEqual(list(d), list('abcdabcd')) + def test_add(self): + d = deque() + e = deque('abc') + f = deque('def') + self.assertEqual(d + d, deque()) + self.assertEqual(e + f, deque('abcdef')) + self.assertEqual(e + e, deque('abcabc')) + self.assertEqual(e + d, deque('abc')) + self.assertEqual(d + e, deque('abc')) + self.assertIsNot(d + d, deque()) + self.assertIsNot(e + d, deque('abc')) + self.assertIsNot(d + e, deque('abc')) + + g = deque('abcdef', maxlen=4) + h = deque('gh') + self.assertEqual(g + h, deque('efgh')) + + with self.assertRaises(TypeError): + deque('abc') + 'def' + def test_iadd(self): d = deque('a') d += 'bcd' @@ -211,6 +251,116 @@ class TestBasic(unittest.TestCase): self.assertRaises(IndexError, d.__getitem__, 0) self.assertRaises(IndexError, d.__getitem__, -1) + def test_index(self): + for n in 1, 2, 30, 40, 200: + + d = deque(range(n)) + for i in range(n): + self.assertEqual(d.index(i), i) + + with self.assertRaises(ValueError): + d.index(n+1) + + # Test detection of mutation during iteration + d = deque(range(n)) + d[n//2] = MutateCmp(d, False) + with self.assertRaises(RuntimeError): + d.index(n) + + # Test detection of comparison exceptions + d = deque(range(n)) + d[n//2] = BadCmp() + with self.assertRaises(RuntimeError): + d.index(n) + + # Test start and stop arguments behavior matches list.index() + elements = 'ABCDEFGHI' + nonelement = 'Z' + d = deque(elements * 2) + s = list(elements * 2) + for start in range(-5 - len(s)*2, 5 + len(s) * 2): + for stop in range(-5 - len(s)*2, 5 + len(s) * 2): + for element in elements + 'Z': + try: + target = s.index(element, start, stop) + except ValueError: + with self.assertRaises(ValueError): + d.index(element, start, stop) + else: + self.assertEqual(d.index(element, start, stop), target) + + def test_insert_bug_24913(self): + d = deque('A' * 3) + with self.assertRaises(ValueError): + i = d.index("Hello world", 0, 4) + + def test_insert(self): + # Test to make sure insert behaves like lists + elements = 'ABCDEFGHI' + for i in range(-5 - len(elements)*2, 5 + len(elements) * 2): + d = deque('ABCDEFGHI') + s = list('ABCDEFGHI') + d.insert(i, 'Z') + s.insert(i, 'Z') + self.assertEqual(list(d), s) + + def test_imul(self): + for n in (-10, -1, 0, 1, 2, 10, 1000): + d = deque() + d *= n + self.assertEqual(d, deque()) + self.assertIsNone(d.maxlen) + + for n in (-10, -1, 0, 1, 2, 10, 1000): + d = deque('a') + d *= n + self.assertEqual(d, deque('a' * n)) + self.assertIsNone(d.maxlen) + + for n in (-10, -1, 0, 1, 2, 10, 499, 500, 501, 1000): + d = deque('a', 500) + d *= n + self.assertEqual(d, deque('a' * min(n, 500))) + self.assertEqual(d.maxlen, 500) + + for n in (-10, -1, 0, 1, 2, 10, 1000): + d = deque('abcdef') + d *= n + self.assertEqual(d, deque('abcdef' * n)) + self.assertIsNone(d.maxlen) + + for n in (-10, -1, 0, 1, 2, 10, 499, 500, 501, 1000): + d = deque('abcdef', 500) + d *= n + self.assertEqual(d, deque(('abcdef' * n)[-500:])) + self.assertEqual(d.maxlen, 500) + + def test_mul(self): + d = deque('abc') + self.assertEqual(d * -5, deque()) + self.assertEqual(d * 0, deque()) + self.assertEqual(d * 1, deque('abc')) + self.assertEqual(d * 2, deque('abcabc')) + self.assertEqual(d * 3, deque('abcabcabc')) + self.assertIsNot(d * 1, d) + + self.assertEqual(deque() * 0, deque()) + self.assertEqual(deque() * 1, deque()) + self.assertEqual(deque() * 5, deque()) + + self.assertEqual(-5 * d, deque()) + self.assertEqual(0 * d, deque()) + self.assertEqual(1 * d, deque('abc')) + self.assertEqual(2 * d, deque('abcabc')) + self.assertEqual(3 * d, deque('abcabcabc')) + + d = deque('abc', maxlen=5) + self.assertEqual(d * -5, deque()) + self.assertEqual(d * 0, deque()) + self.assertEqual(d * 1, deque('abc')) + self.assertEqual(d * 2, deque('bcabc')) + self.assertEqual(d * 30, deque('bcabc')) + def test_setitem(self): n = 200 d = deque(range(n)) @@ -504,10 +654,24 @@ class TestBasic(unittest.TestCase): self.assertNotEqual(id(d), id(e)) self.assertEqual(list(d), list(e)) + def test_copy_method(self): + mut = [10] + d = deque([mut]) + e = d.copy() + self.assertEqual(list(d), list(e)) + mut[0] = 11 + self.assertNotEqual(id(d), id(e)) + self.assertEqual(list(d), list(e)) + def test_reversed(self): for s in ('abcd', range(2000)): self.assertEqual(list(reversed(deque(s))), list(reversed(s))) + def test_reversed_new(self): + klass = type(reversed(deque())) + for s in ('abcd', range(2000)): + self.assertEqual(list(klass(deque(s))), list(reversed(s))) + def test_gc_doesnt_blowup(self): import gc # This used to assert-fail in deque_traverse() under a debug @@ -537,7 +701,7 @@ class TestBasic(unittest.TestCase): @support.cpython_only def test_sizeof(self): - BLOCKLEN = 62 + BLOCKLEN = 64 basesize = support.calcobjsize('2P4nlP') blocksize = struct.calcsize('2P%dP' % BLOCKLEN) self.assertEqual(object.__sizeof__(deque()), basesize) @@ -684,6 +848,21 @@ class TestSubclassWithKwargs(unittest.TestCase): # SF bug #1486663 -- this used to erroneously raise a TypeError SubclassWithKwargs(newarg=1) +class TestSequence(seq_tests.CommonTest): + type2test = deque + + def test_getitem(self): + # For now, bypass tests that require slicing + pass + + def test_getslice(self): + # For now, bypass tests that require slicing + pass + + def test_subscript(self): + # For now, bypass tests that require slicing + pass + #============================================================================== libreftest = """ @@ -798,6 +977,7 @@ def test_main(verbose=None): TestVariousIteratorArgs, TestSubclass, TestSubclassWithKwargs, + TestSequence, ) support.run_unittest(*test_classes) diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index 9a60a12..c74ebae 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -21,6 +21,7 @@ class OperatorsTest(unittest.TestCase): 'add': '+', 'sub': '-', 'mul': '*', + 'matmul': '@', 'truediv': '/', 'floordiv': '//', 'divmod': 'divmod', @@ -1019,6 +1020,67 @@ order (MRO) for bases """ self.assertEqual(x.foo, 1) self.assertEqual(x.__dict__, {'foo': 1}) + def test_object_class_assignment_between_heaptypes_and_nonheaptypes(self): + class SubType(types.ModuleType): + a = 1 + + m = types.ModuleType("m") + self.assertTrue(m.__class__ is types.ModuleType) + self.assertFalse(hasattr(m, "a")) + + m.__class__ = SubType + self.assertTrue(m.__class__ is SubType) + self.assertTrue(hasattr(m, "a")) + + m.__class__ = types.ModuleType + self.assertTrue(m.__class__ is types.ModuleType) + self.assertFalse(hasattr(m, "a")) + + # Make sure that builtin immutable objects don't support __class__ + # assignment, because the object instances may be interned. + # We set __slots__ = () to ensure that the subclasses are + # memory-layout compatible, and thus otherwise reasonable candidates + # for __class__ assignment. + + # The following types have immutable instances, but are not + # subclassable and thus don't need to be checked: + # NoneType, bool + + class MyInt(int): + __slots__ = () + with self.assertRaises(TypeError): + (1).__class__ = MyInt + + class MyFloat(float): + __slots__ = () + with self.assertRaises(TypeError): + (1.0).__class__ = MyFloat + + class MyComplex(complex): + __slots__ = () + with self.assertRaises(TypeError): + (1 + 2j).__class__ = MyComplex + + class MyStr(str): + __slots__ = () + with self.assertRaises(TypeError): + "a".__class__ = MyStr + + class MyBytes(bytes): + __slots__ = () + with self.assertRaises(TypeError): + b"a".__class__ = MyBytes + + class MyTuple(tuple): + __slots__ = () + with self.assertRaises(TypeError): + ().__class__ = MyTuple + + class MyFrozenSet(frozenset): + __slots__ = () + with self.assertRaises(TypeError): + frozenset().__class__ = MyFrozenSet + def test_slots(self): # Testing __slots__... class C0(object): @@ -2005,7 +2067,7 @@ order (MRO) for bases """ self.assertIs(raw.fset, C.__dict__['setx']) self.assertIs(raw.fdel, C.__dict__['delx']) - for attr in "__doc__", "fget", "fset", "fdel": + for attr in "fget", "fset", "fdel": try: setattr(raw, attr, 42) except AttributeError as msg: @@ -2016,6 +2078,9 @@ order (MRO) for bases """ self.fail("expected AttributeError from trying to set readonly %r " "attr on a property" % attr) + raw.__doc__ = 42 + self.assertEqual(raw.__doc__, 42) + class D(object): __getitem__ = property(lambda s: 1/0) @@ -3003,8 +3068,6 @@ order (MRO) for bases """ cant(object(), list) cant(list(), object) class Int(int): __slots__ = [] - cant(2, Int) - cant(Int(), int) cant(True, int) cant(2, bool) o = object() @@ -3324,7 +3387,7 @@ order (MRO) for bases """ A.__call__ = A() try: A()() - except RuntimeError: + except RecursionError: pass else: self.fail("Recursion limit should have been reached for __call__()") @@ -4153,6 +4216,7 @@ order (MRO) for bases """ ('__add__', 'x + y', 'x += y'), ('__sub__', 'x - y', 'x -= y'), ('__mul__', 'x * y', 'x *= y'), + ('__matmul__', 'x @ y', 'x @= y'), ('__truediv__', 'x / y', 'x /= y'), ('__floordiv__', 'x // y', 'x //= y'), ('__mod__', 'x % y', 'x %= y'), @@ -4298,8 +4362,8 @@ order (MRO) for bases """ pass Foo.__repr__ = Foo.__str__ foo = Foo() - self.assertRaises(RuntimeError, str, foo) - self.assertRaises(RuntimeError, repr, foo) + self.assertRaises(RecursionError, str, foo) + self.assertRaises(RecursionError, repr, foo) def test_mixing_slot_wrappers(self): class X(dict): @@ -4414,6 +4478,61 @@ order (MRO) for bases """ self.assertIn("__dict__", Base.__dict__) self.assertNotIn("__dict__", Sub.__dict__) + def test_bound_method_repr(self): + class Foo: + def method(self): + pass + self.assertRegex(repr(Foo().method), + r"<bound method .*Foo\.method of <.*Foo object at .*>>") + + + class Base: + def method(self): + pass + class Derived1(Base): + pass + class Derived2(Base): + def method(self): + pass + base = Base() + derived1 = Derived1() + derived2 = Derived2() + super_d2 = super(Derived2, derived2) + self.assertRegex(repr(base.method), + r"<bound method .*Base\.method of <.*Base object at .*>>") + self.assertRegex(repr(derived1.method), + r"<bound method .*Base\.method of <.*Derived1 object at .*>>") + self.assertRegex(repr(derived2.method), + r"<bound method .*Derived2\.method of <.*Derived2 object at .*>>") + self.assertRegex(repr(super_d2.method), + r"<bound method .*Base\.method of <.*Derived2 object at .*>>") + + class Foo: + @classmethod + def method(cls): + pass + foo = Foo() + self.assertRegex(repr(foo.method), # access via instance + r"<bound method .*Foo\.method of <class '.*Foo'>>") + self.assertRegex(repr(Foo.method), # access via the class + r"<bound method .*Foo\.method of <class '.*Foo'>>") + + + class MyCallable: + def __call__(self, arg): + pass + func = MyCallable() # func has no __name__ or __qualname__ attributes + instance = object() + method = types.MethodType(func, instance) + self.assertRegex(repr(method), + r"<bound method \? of <object object at .*>>") + func.__name__ = "name" + self.assertRegex(repr(method), + r"<bound method name of <object object at .*>>") + func.__qualname__ = "qualname" + self.assertRegex(repr(method), + r"<bound method qualname of <object object at .*>>") + class DictProxyTests(unittest.TestCase): def setUp(self): @@ -4528,26 +4647,15 @@ class PicklingTests(unittest.TestCase): def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None, listitems=None, dictitems=None): - if proto >= 4: + if proto >= 2: reduce_value = obj.__reduce_ex__(proto) - self.assertEqual(reduce_value[:3], - (copyreg.__newobj_ex__, - (type(obj), args, kwargs), - state)) - if listitems is not None: - self.assertListEqual(list(reduce_value[3]), listitems) + if kwargs: + self.assertEqual(reduce_value[0], copyreg.__newobj_ex__) + self.assertEqual(reduce_value[1], (type(obj), args, kwargs)) else: - self.assertIsNone(reduce_value[3]) - if dictitems is not None: - self.assertDictEqual(dict(reduce_value[4]), dictitems) - else: - self.assertIsNone(reduce_value[4]) - elif proto >= 2: - reduce_value = obj.__reduce_ex__(proto) - self.assertEqual(reduce_value[:3], - (copyreg.__newobj__, - (type(obj),) + args, - state)) + self.assertEqual(reduce_value[0], copyreg.__newobj__) + self.assertEqual(reduce_value[1], (type(obj),) + args) + self.assertEqual(reduce_value[2], state) if listitems is not None: self.assertListEqual(list(reduce_value[3]), listitems) else: diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index bd79728..2488b63 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -963,12 +963,5 @@ class Dict(dict): class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol): type2test = Dict -def test_main(): - support.run_unittest( - DictTest, - GeneralMappingTests, - SubclassMappingTests, - ) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_dictviews.py b/Lib/test/test_dictviews.py index 7b02ea9..8d33801 100644 --- a/Lib/test/test_dictviews.py +++ b/Lib/test/test_dictviews.py @@ -1,5 +1,4 @@ import unittest -from test import support class DictSetTest(unittest.TestCase): @@ -196,11 +195,8 @@ class DictSetTest(unittest.TestCase): def test_recursive_repr(self): d = {} d[42] = d.values() - self.assertRaises(RuntimeError, repr, d) + self.assertRaises(RecursionError, repr, d) -def test_main(): - support.run_unittest(DictSetTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_difflib.py b/Lib/test/test_difflib.py index 0ba8f0e..ab9debf 100644 --- a/Lib/test/test_difflib.py +++ b/Lib/test/test_difflib.py @@ -107,6 +107,20 @@ patch914575_to1 = """ 5. Flat is better than nested. """ +patch914575_nonascii_from1 = """ + 1. Beautiful is beTTer than ugly. + 2. Explicit is better than ımplıcıt. + 3. Simple is better than complex. + 4. Complex is better than complicated. +""" + +patch914575_nonascii_to1 = """ + 1. Beautiful is better than ügly. + 3. Sımple is better than complex. + 4. Complicated is better than cömplex. + 5. Flat is better than nested. +""" + patch914575_from2 = """ \t\tLine 1: preceeded by from:[tt] to:[ssss] \t\tLine 2: preceeded by from:[sstt] to:[sssst] @@ -223,6 +237,27 @@ class TestSFpatches(unittest.TestCase): new = [(i%2 and "K:%d" or "V:B:%d") % i for i in range(limit*2)] difflib.SequenceMatcher(None, old, new).get_opcodes() + def test_make_file_default_charset(self): + html_diff = difflib.HtmlDiff() + output = html_diff.make_file(patch914575_from1.splitlines(), + patch914575_to1.splitlines()) + self.assertIn('content="text/html; charset=utf-8"', output) + + def test_make_file_iso88591_charset(self): + html_diff = difflib.HtmlDiff() + output = html_diff.make_file(patch914575_from1.splitlines(), + patch914575_to1.splitlines(), + charset='iso-8859-1') + self.assertIn('content="text/html; charset=iso-8859-1"', output) + + def test_make_file_usascii_charset_with_nonascii_input(self): + html_diff = difflib.HtmlDiff() + output = html_diff.make_file(patch914575_nonascii_from1.splitlines(), + patch914575_nonascii_to1.splitlines(), + charset='us-ascii') + self.assertIn('content="text/html; charset=us-ascii"', output) + self.assertIn('ımplıcıt', output) + class TestOutputFormat(unittest.TestCase): def test_tab_delimiter(self): @@ -287,12 +322,157 @@ class TestOutputFormat(unittest.TestCase): self.assertEqual(fmt(0,0), '0') +class TestBytes(unittest.TestCase): + # don't really care about the content of the output, just the fact + # that it's bytes and we don't crash + def check(self, diff): + diff = list(diff) # trigger exceptions first + for line in diff: + self.assertIsInstance( + line, bytes, + "all lines of diff should be bytes, but got: %r" % line) + + def test_byte_content(self): + # if we receive byte strings, we return byte strings + a = [b'hello', b'andr\xe9'] # iso-8859-1 bytes + b = [b'hello', b'andr\xc3\xa9'] # utf-8 bytes + + unified = difflib.unified_diff + context = difflib.context_diff + + check = self.check + check(difflib.diff_bytes(unified, a, a)) + check(difflib.diff_bytes(unified, a, b)) + + # now with filenames (content and filenames are all bytes!) + check(difflib.diff_bytes(unified, a, a, b'a', b'a')) + check(difflib.diff_bytes(unified, a, b, b'a', b'b')) + + # and with filenames and dates + check(difflib.diff_bytes(unified, a, a, b'a', b'a', b'2005', b'2013')) + check(difflib.diff_bytes(unified, a, b, b'a', b'b', b'2005', b'2013')) + + # same all over again, with context diff + check(difflib.diff_bytes(context, a, a)) + check(difflib.diff_bytes(context, a, b)) + check(difflib.diff_bytes(context, a, a, b'a', b'a')) + check(difflib.diff_bytes(context, a, b, b'a', b'b')) + check(difflib.diff_bytes(context, a, a, b'a', b'a', b'2005', b'2013')) + check(difflib.diff_bytes(context, a, b, b'a', b'b', b'2005', b'2013')) + + def test_byte_filenames(self): + # somebody renamed a file from ISO-8859-2 to UTF-8 + fna = b'\xb3odz.txt' # "łodz.txt" + fnb = b'\xc5\x82odz.txt' + + # they transcoded the content at the same time + a = [b'\xa3odz is a city in Poland.'] + b = [b'\xc5\x81odz is a city in Poland.'] + + check = self.check + unified = difflib.unified_diff + context = difflib.context_diff + check(difflib.diff_bytes(unified, a, b, fna, fnb)) + check(difflib.diff_bytes(context, a, b, fna, fnb)) + + def assertDiff(expect, actual): + # do not compare expect and equal as lists, because unittest + # uses difflib to report difference between lists + actual = list(actual) + self.assertEqual(len(expect), len(actual)) + for e, a in zip(expect, actual): + self.assertEqual(e, a) + + expect = [ + b'--- \xb3odz.txt', + b'+++ \xc5\x82odz.txt', + b'@@ -1 +1 @@', + b'-\xa3odz is a city in Poland.', + b'+\xc5\x81odz is a city in Poland.', + ] + actual = difflib.diff_bytes(unified, a, b, fna, fnb, lineterm=b'') + assertDiff(expect, actual) + + # with dates (plain ASCII) + datea = b'2005-03-18' + dateb = b'2005-03-19' + check(difflib.diff_bytes(unified, a, b, fna, fnb, datea, dateb)) + check(difflib.diff_bytes(context, a, b, fna, fnb, datea, dateb)) + + expect = [ + # note the mixed encodings here: this is deeply wrong by every + # tenet of Unicode, but it doesn't crash, it's parseable by + # patch, and it's how UNIX(tm) diff behaves + b'--- \xb3odz.txt\t2005-03-18', + b'+++ \xc5\x82odz.txt\t2005-03-19', + b'@@ -1 +1 @@', + b'-\xa3odz is a city in Poland.', + b'+\xc5\x81odz is a city in Poland.', + ] + actual = difflib.diff_bytes(unified, a, b, fna, fnb, datea, dateb, + lineterm=b'') + assertDiff(expect, actual) + + def test_mixed_types_content(self): + # type of input content must be consistent: all str or all bytes + a = [b'hello'] + b = ['hello'] + + unified = difflib.unified_diff + context = difflib.context_diff + + expect = "lines to compare must be str, not bytes (b'hello')" + self._assert_type_error(expect, unified, a, b) + self._assert_type_error(expect, unified, b, a) + self._assert_type_error(expect, context, a, b) + self._assert_type_error(expect, context, b, a) + + expect = "all arguments must be bytes, not str ('hello')" + self._assert_type_error(expect, difflib.diff_bytes, unified, a, b) + self._assert_type_error(expect, difflib.diff_bytes, unified, b, a) + self._assert_type_error(expect, difflib.diff_bytes, context, a, b) + self._assert_type_error(expect, difflib.diff_bytes, context, b, a) + + def test_mixed_types_filenames(self): + # cannot pass filenames as bytes if content is str (this may not be + # the right behaviour, but at least the test demonstrates how + # things work) + a = ['hello\n'] + b = ['ohell\n'] + fna = b'ol\xe9.txt' # filename transcoded from ISO-8859-1 + fnb = b'ol\xc3a9.txt' # to UTF-8 + self._assert_type_error( + "all arguments must be str, not: b'ol\\xe9.txt'", + difflib.unified_diff, a, b, fna, fnb) + + def test_mixed_types_dates(self): + # type of dates must be consistent with type of contents + a = [b'foo\n'] + b = [b'bar\n'] + datea = '1 fév' + dateb = '3 fév' + self._assert_type_error( + "all arguments must be bytes, not str ('1 fév')", + difflib.diff_bytes, difflib.unified_diff, + a, b, b'a', b'b', datea, dateb) + + # if input is str, non-ASCII dates are fine + a = ['foo\n'] + b = ['bar\n'] + list(difflib.unified_diff(a, b, 'a', 'b', datea, dateb)) + + def _assert_type_error(self, msg, generator, *args): + with self.assertRaises(TypeError) as ctx: + list(generator(*args)) + self.assertEqual(msg, str(ctx.exception)) + + def test_main(): difflib.HtmlDiff._default_prefix = 0 Doctests = doctest.DocTestSuite(difflib) run_unittest( TestWithAscii, TestAutojunk, TestSFpatches, TestSFbugs, - TestOutputFormat, Doctests) + TestOutputFormat, TestBytes, Doctests) if __name__ == '__main__': test_main() diff --git a/Lib/test/test_difflib_expect.html b/Lib/test/test_difflib_expect.html index 71b6d7a..ea7a24e 100644 --- a/Lib/test/test_difflib_expect.html +++ b/Lib/test/test_difflib_expect.html @@ -6,7 +6,7 @@ <head> <meta http-equiv="Content-Type" - content="text/html; charset=ISO-8859-1" /> + content="text/html; charset=utf-8" /> <title></title> <style type="text/css"> table.diff {font-family:Courier; border:medium;} diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py index b8daff7..421bbad 100644 --- a/Lib/test/test_dis.py +++ b/Lib/test/test_dis.py @@ -1,6 +1,6 @@ # Minimal tests for dis module -from test.support import run_unittest, captured_stdout +from test.support import captured_stdout from test.bytecode_helper import BytecodeTestCase import difflib import unittest @@ -230,6 +230,9 @@ dis_traceback = """\ TRACEBACK_CODE.co_firstlineno + 4, TRACEBACK_CODE.co_firstlineno + 5) +def _g(x): + yield x + class DisTests(unittest.TestCase): def get_disassembly(self, func, lasti=-1, wrapper=True): @@ -315,6 +318,11 @@ class DisTests(unittest.TestCase): method_bytecode = _C(1).__init__.__code__.co_code self.do_disassembly_test(method_bytecode, dis_c_instance_method_bytes) + def test_disassemble_generator(self): + gen_func_disas = self.get_disassembly(_g) # Disassemble generator function + gen_disas = self.get_disassembly(_g(1)) # Disassemble generator itself + self.assertEqual(gen_disas, gen_func_disas) + def test_dis_none(self): try: del sys.last_traceback @@ -472,6 +480,24 @@ Constants: Names: 0: x""" + +async def async_def(): + await 1 + async for a in b: pass + async with c as d: pass + +code_info_async_def = """\ +Name: async_def +Filename: (.*) +Argument count: 0 +Kw-only arguments: 0 +Number of locals: 2 +Stack size: 17 +Flags: OPTIMIZED, NEWLOCALS, GENERATOR, NOFREE, COROUTINE +Constants: + 0: None + 1: 1""" + class CodeInfoTests(unittest.TestCase): test_pairs = [ (dis.code_info, code_info_code_info), @@ -480,6 +506,7 @@ class CodeInfoTests(unittest.TestCase): (expr_str, code_info_expr_str), (simple_stmt_str, code_info_simple_stmt_str), (compound_stmt_str, code_info_compound_stmt_str), + (async_def, code_info_async_def) ] def test_code_info(self): @@ -561,10 +588,10 @@ expected_jumpy_line = 1 #_instructions = dis.get_instructions(outer, first_line=expected_outer_line) #print('expected_opinfo_outer = [\n ', #',\n '.join(map(str, _instructions)), ',\n]', sep='') -#_instructions = dis.get_instructions(outer(), first_line=expected_outer_line) +#_instructions = dis.get_instructions(outer(), first_line=expected_f_line) #print('expected_opinfo_f = [\n ', #',\n '.join(map(str, _instructions)), ',\n]', sep='') -#_instructions = dis.get_instructions(outer()(), first_line=expected_outer_line) +#_instructions = dis.get_instructions(outer()(), first_line=expected_inner_line) #print('expected_opinfo_inner = [\n ', #',\n '.join(map(str, _instructions)), ',\n]', sep='') #_instructions = dis.get_instructions(jumpy, first_line=expected_jumpy_line) @@ -635,12 +662,12 @@ expected_opinfo_inner = [ ] expected_opinfo_jumpy = [ - Instruction(opname='SETUP_LOOP', opcode=120, arg=74, argval=77, argrepr='to 77', offset=0, starts_line=3, is_jump_target=False), + Instruction(opname='SETUP_LOOP', opcode=120, arg=68, argval=71, argrepr='to 71', offset=0, starts_line=3, is_jump_target=False), Instruction(opname='LOAD_GLOBAL', opcode=116, arg=0, argval='range', argrepr='range', offset=3, starts_line=None, is_jump_target=False), Instruction(opname='LOAD_CONST', opcode=100, arg=1, argval=10, argrepr='10', offset=6, starts_line=None, is_jump_target=False), Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=9, starts_line=None, is_jump_target=False), Instruction(opname='GET_ITER', opcode=68, arg=None, argval=None, argrepr='', offset=12, starts_line=None, is_jump_target=False), - Instruction(opname='FOR_ITER', opcode=93, arg=50, argval=66, argrepr='to 66', offset=13, starts_line=None, is_jump_target=True), + Instruction(opname='FOR_ITER', opcode=93, arg=44, argval=60, argrepr='to 60', offset=13, starts_line=None, is_jump_target=True), Instruction(opname='STORE_FAST', opcode=125, arg=0, argval='i', argrepr='i', offset=16, starts_line=None, is_jump_target=False), Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=19, starts_line=4, is_jump_target=False), Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=22, starts_line=None, is_jump_target=False), @@ -649,92 +676,89 @@ expected_opinfo_jumpy = [ Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=29, starts_line=5, is_jump_target=False), Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=32, starts_line=None, is_jump_target=False), Instruction(opname='COMPARE_OP', opcode=107, arg=0, argval='<', argrepr='<', offset=35, starts_line=None, is_jump_target=False), - Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=47, argval=47, argrepr='', offset=38, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=44, argval=44, argrepr='', offset=38, starts_line=None, is_jump_target=False), Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=41, starts_line=6, is_jump_target=False), - Instruction(opname='JUMP_FORWARD', opcode=110, arg=0, argval=47, argrepr='to 47', offset=44, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=47, starts_line=7, is_jump_target=True), - Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=50, starts_line=None, is_jump_target=False), - Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=53, starts_line=None, is_jump_target=False), - Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=13, argval=13, argrepr='', offset=56, starts_line=None, is_jump_target=False), - Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=59, starts_line=8, is_jump_target=False), - Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=60, starts_line=None, is_jump_target=False), - Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=63, starts_line=None, is_jump_target=False), - Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=66, starts_line=None, is_jump_target=True), - Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=67, starts_line=10, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=70, starts_line=None, is_jump_target=False), - Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=73, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=76, starts_line=None, is_jump_target=False), - Instruction(opname='SETUP_LOOP', opcode=120, arg=74, argval=154, argrepr='to 154', offset=77, starts_line=11, is_jump_target=True), - Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=80, starts_line=None, is_jump_target=True), - Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=143, argval=143, argrepr='', offset=83, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=86, starts_line=12, is_jump_target=False), - Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=89, starts_line=None, is_jump_target=False), - Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=92, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=95, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=96, starts_line=13, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=99, starts_line=None, is_jump_target=False), - Instruction(opname='INPLACE_SUBTRACT', opcode=56, arg=None, argval=None, argrepr='', offset=102, starts_line=None, is_jump_target=False), - Instruction(opname='STORE_FAST', opcode=125, arg=0, argval='i', argrepr='i', offset=103, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=106, starts_line=14, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=109, starts_line=None, is_jump_target=False), - Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=112, starts_line=None, is_jump_target=False), - Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=124, argval=124, argrepr='', offset=115, starts_line=None, is_jump_target=False), - Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=118, starts_line=15, is_jump_target=False), - Instruction(opname='JUMP_FORWARD', opcode=110, arg=0, argval=124, argrepr='to 124', offset=121, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=124, starts_line=16, is_jump_target=True), - Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=127, starts_line=None, is_jump_target=False), - Instruction(opname='COMPARE_OP', opcode=107, arg=0, argval='<', argrepr='<', offset=130, starts_line=None, is_jump_target=False), - Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=80, argval=80, argrepr='', offset=133, starts_line=None, is_jump_target=False), - Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=136, starts_line=17, is_jump_target=False), - Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=137, starts_line=None, is_jump_target=False), - Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=80, argval=80, argrepr='', offset=140, starts_line=None, is_jump_target=False), - Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=143, starts_line=None, is_jump_target=True), - Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=144, starts_line=19, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=6, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=147, starts_line=None, is_jump_target=False), - Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=150, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=153, starts_line=None, is_jump_target=False), - Instruction(opname='SETUP_FINALLY', opcode=122, arg=72, argval=229, argrepr='to 229', offset=154, starts_line=20, is_jump_target=True), - Instruction(opname='SETUP_EXCEPT', opcode=121, arg=12, argval=172, argrepr='to 172', offset=157, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=160, starts_line=21, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=7, argval=0, argrepr='0', offset=163, starts_line=None, is_jump_target=False), - Instruction(opname='BINARY_TRUE_DIVIDE', opcode=27, arg=None, argval=None, argrepr='', offset=166, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=167, starts_line=None, is_jump_target=False), - Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=168, starts_line=None, is_jump_target=False), - Instruction(opname='JUMP_FORWARD', opcode=110, arg=28, argval=200, argrepr='to 200', offset=169, starts_line=None, is_jump_target=False), - Instruction(opname='DUP_TOP', opcode=4, arg=None, argval=None, argrepr='', offset=172, starts_line=22, is_jump_target=True), - Instruction(opname='LOAD_GLOBAL', opcode=116, arg=2, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=173, starts_line=None, is_jump_target=False), - Instruction(opname='COMPARE_OP', opcode=107, arg=10, argval='exception match', argrepr='exception match', offset=176, starts_line=None, is_jump_target=False), - Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=199, argval=199, argrepr='', offset=179, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=44, starts_line=7, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=47, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=50, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=13, argval=13, argrepr='', offset=53, starts_line=None, is_jump_target=False), + Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=56, starts_line=8, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=13, argval=13, argrepr='', offset=57, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=60, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=61, starts_line=10, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=4, argval='I can haz else clause?', argrepr="'I can haz else clause?'", offset=64, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=67, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=70, starts_line=None, is_jump_target=False), + Instruction(opname='SETUP_LOOP', opcode=120, arg=68, argval=142, argrepr='to 142', offset=71, starts_line=11, is_jump_target=True), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=74, starts_line=None, is_jump_target=True), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=131, argval=131, argrepr='', offset=77, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=80, starts_line=12, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=83, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=86, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=89, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=90, starts_line=13, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=93, starts_line=None, is_jump_target=False), + Instruction(opname='INPLACE_SUBTRACT', opcode=56, arg=None, argval=None, argrepr='', offset=96, starts_line=None, is_jump_target=False), + Instruction(opname='STORE_FAST', opcode=125, arg=0, argval='i', argrepr='i', offset=97, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=100, starts_line=14, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=3, argval=6, argrepr='6', offset=103, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=4, argval='>', argrepr='>', offset=106, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=115, argval=115, argrepr='', offset=109, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=74, argval=74, argrepr='', offset=112, starts_line=15, is_jump_target=False), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=115, starts_line=16, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=2, argval=4, argrepr='4', offset=118, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=0, argval='<', argrepr='<', offset=121, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=74, argval=74, argrepr='', offset=124, starts_line=None, is_jump_target=False), + Instruction(opname='BREAK_LOOP', opcode=80, arg=None, argval=None, argrepr='', offset=127, starts_line=17, is_jump_target=False), + Instruction(opname='JUMP_ABSOLUTE', opcode=113, arg=74, argval=74, argrepr='', offset=128, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=131, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=132, starts_line=19, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=6, argval='Who let lolcatz into this test suite?', argrepr="'Who let lolcatz into this test suite?'", offset=135, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=138, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=141, starts_line=None, is_jump_target=False), + Instruction(opname='SETUP_FINALLY', opcode=122, arg=73, argval=218, argrepr='to 218', offset=142, starts_line=20, is_jump_target=True), + Instruction(opname='SETUP_EXCEPT', opcode=121, arg=12, argval=160, argrepr='to 160', offset=145, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=5, argval=1, argrepr='1', offset=148, starts_line=21, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=7, argval=0, argrepr='0', offset=151, starts_line=None, is_jump_target=False), + Instruction(opname='BINARY_TRUE_DIVIDE', opcode=27, arg=None, argval=None, argrepr='', offset=154, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=155, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=156, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_FORWARD', opcode=110, arg=28, argval=188, argrepr='to 188', offset=157, starts_line=None, is_jump_target=False), + Instruction(opname='DUP_TOP', opcode=4, arg=None, argval=None, argrepr='', offset=160, starts_line=22, is_jump_target=True), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=2, argval='ZeroDivisionError', argrepr='ZeroDivisionError', offset=161, starts_line=None, is_jump_target=False), + Instruction(opname='COMPARE_OP', opcode=107, arg=10, argval='exception match', argrepr='exception match', offset=164, starts_line=None, is_jump_target=False), + Instruction(opname='POP_JUMP_IF_FALSE', opcode=114, arg=187, argval=187, argrepr='', offset=167, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=170, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=171, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=172, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=173, starts_line=23, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=8, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=176, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=179, starts_line=None, is_jump_target=False), Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=182, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=183, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=184, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=185, starts_line=23, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=8, argval='Here we go, here we go, here we go...', argrepr="'Here we go, here we go, here we go...'", offset=188, starts_line=None, is_jump_target=False), - Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=191, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=194, starts_line=None, is_jump_target=False), - Instruction(opname='POP_EXCEPT', opcode=89, arg=None, argval=None, argrepr='', offset=195, starts_line=None, is_jump_target=False), - Instruction(opname='JUMP_FORWARD', opcode=110, arg=26, argval=225, argrepr='to 225', offset=196, starts_line=None, is_jump_target=False), - Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=199, starts_line=None, is_jump_target=True), - Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=200, starts_line=25, is_jump_target=True), - Instruction(opname='SETUP_WITH', opcode=143, arg=17, argval=223, argrepr='to 223', offset=203, starts_line=None, is_jump_target=False), - Instruction(opname='STORE_FAST', opcode=125, arg=1, argval='dodgy', argrepr='dodgy', offset=206, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=209, starts_line=26, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=9, argval='Never reach this', argrepr="'Never reach this'", offset=212, starts_line=None, is_jump_target=False), - Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=215, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=218, starts_line=None, is_jump_target=False), - Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=219, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=220, starts_line=None, is_jump_target=False), - Instruction(opname='WITH_CLEANUP', opcode=81, arg=None, argval=None, argrepr='', offset=223, starts_line=None, is_jump_target=True), - Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=224, starts_line=None, is_jump_target=False), - Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=225, starts_line=None, is_jump_target=True), - Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=226, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=229, starts_line=28, is_jump_target=True), - Instruction(opname='LOAD_CONST', opcode=100, arg=10, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=232, starts_line=None, is_jump_target=False), - Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=235, starts_line=None, is_jump_target=False), - Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=238, starts_line=None, is_jump_target=False), - Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=239, starts_line=None, is_jump_target=False), - Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=240, starts_line=None, is_jump_target=False), - Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=243, starts_line=None, is_jump_target=False), + Instruction(opname='POP_EXCEPT', opcode=89, arg=None, argval=None, argrepr='', offset=183, starts_line=None, is_jump_target=False), + Instruction(opname='JUMP_FORWARD', opcode=110, arg=27, argval=214, argrepr='to 214', offset=184, starts_line=None, is_jump_target=False), + Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=187, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_FAST', opcode=124, arg=0, argval='i', argrepr='i', offset=188, starts_line=25, is_jump_target=True), + Instruction(opname='SETUP_WITH', opcode=143, arg=17, argval=211, argrepr='to 211', offset=191, starts_line=None, is_jump_target=False), + Instruction(opname='STORE_FAST', opcode=125, arg=1, argval='dodgy', argrepr='dodgy', offset=194, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=197, starts_line=26, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=9, argval='Never reach this', argrepr="'Never reach this'", offset=200, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=203, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=206, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=207, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=208, starts_line=None, is_jump_target=False), + Instruction(opname='WITH_CLEANUP_START', opcode=81, arg=None, argval=None, argrepr='', offset=211, starts_line=None, is_jump_target=True), + Instruction(opname='WITH_CLEANUP_FINISH', opcode=82, arg=None, argval=None, argrepr='', offset=212, starts_line=None, is_jump_target=False), + Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=213, starts_line=None, is_jump_target=False), + Instruction(opname='POP_BLOCK', opcode=87, arg=None, argval=None, argrepr='', offset=214, starts_line=None, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=215, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_GLOBAL', opcode=116, arg=1, argval='print', argrepr='print', offset=218, starts_line=28, is_jump_target=True), + Instruction(opname='LOAD_CONST', opcode=100, arg=10, argval="OK, now we're done", argrepr='"OK, now we\'re done"', offset=221, starts_line=None, is_jump_target=False), + Instruction(opname='CALL_FUNCTION', opcode=131, arg=1, argval=1, argrepr='1 positional, 0 keyword pair', offset=224, starts_line=None, is_jump_target=False), + Instruction(opname='POP_TOP', opcode=1, arg=None, argval=None, argrepr='', offset=227, starts_line=None, is_jump_target=False), + Instruction(opname='END_FINALLY', opcode=88, arg=None, argval=None, argrepr='', offset=228, starts_line=None, is_jump_target=False), + Instruction(opname='LOAD_CONST', opcode=100, arg=0, argval=None, argrepr='None', offset=229, starts_line=None, is_jump_target=False), + Instruction(opname='RETURN_VALUE', opcode=83, arg=None, argval=None, argrepr='', offset=232, starts_line=None, is_jump_target=False), ] # One last piece of inspect fodder to check the default line number handling diff --git a/Lib/test/test_doctest.py b/Lib/test/test_doctest.py index 9292d92..73b4452 100644 --- a/Lib/test/test_doctest.py +++ b/Lib/test/test_doctest.py @@ -4,6 +4,7 @@ Test script for doctest. from test import support import doctest +import functools import os import sys @@ -434,7 +435,7 @@ We'll simulate a __file__ attr that ends in pyc: >>> tests = finder.find(sample_func) >>> print(tests) # doctest: +ELLIPSIS - [<DocTest sample_func from ...:18 (1 example)>] + [<DocTest sample_func from ...:19 (1 example)>] The exact name depends on how test_doctest was invoked, so allow for leading path components. @@ -658,7 +659,7 @@ plain ol' Python and is guaranteed to be available. >>> import builtins >>> tests = doctest.DocTestFinder().find(builtins) - >>> 790 < len(tests) < 800 # approximate number of objects with docstrings + >>> 790 < len(tests) < 810 # approximate number of objects with docstrings True >>> real_tests = [t for t in tests if len(t.examples) > 0] >>> len(real_tests) # objects that actually have doctests @@ -2096,22 +2097,9 @@ def test_DocTestSuite(): >>> suite.run(unittest.TestResult()) <unittest.result.TestResult run=0 errors=0 failures=0> - However, if DocTestSuite finds no docstrings, it raises an error: + The module need not contain any docstrings either: - >>> try: - ... doctest.DocTestSuite('test.sample_doctest_no_docstrings') - ... except ValueError as e: - ... error = e - - >>> print(error.args[1]) - has no docstrings - - You can prevent this error by passing a DocTestFinder instance with - the `exclude_empty` keyword argument set to False: - - >>> finder = doctest.DocTestFinder(exclude_empty=False) - >>> suite = doctest.DocTestSuite('test.sample_doctest_no_docstrings', - ... test_finder=finder) + >>> suite = doctest.DocTestSuite('test.sample_doctest_no_docstrings') >>> suite.run(unittest.TestResult()) <unittest.result.TestResult run=0 errors=0 failures=0> @@ -2121,6 +2109,22 @@ def test_DocTestSuite(): >>> suite.run(unittest.TestResult()) <unittest.result.TestResult run=9 errors=0 failures=4> + We can also provide a DocTestFinder: + + >>> finder = doctest.DocTestFinder() + >>> suite = doctest.DocTestSuite('test.sample_doctest', + ... test_finder=finder) + >>> suite.run(unittest.TestResult()) + <unittest.result.TestResult run=9 errors=0 failures=4> + + The DocTestFinder need not return any tests: + + >>> finder = doctest.DocTestFinder() + >>> suite = doctest.DocTestSuite('test.sample_doctest_no_docstrings', + ... test_finder=finder) + >>> suite.run(unittest.TestResult()) + <unittest.result.TestResult run=0 errors=0 failures=0> + We can supply global variables. If we pass globs, they will be used instead of the module globals. Here we'll pass an empty globals, triggering an extra error: @@ -2168,7 +2172,7 @@ def test_DocTestSuite(): >>> test.test_doctest.sillySetup Traceback (most recent call last): ... - AttributeError: 'module' object has no attribute 'sillySetup' + AttributeError: module 'test.test_doctest' has no attribute 'sillySetup' The setUp and tearDown functions are passed test objects. Here we'll use the setUp function to supply the missing variable y: @@ -2314,7 +2318,7 @@ def test_DocFileSuite(): >>> test.test_doctest.sillySetup Traceback (most recent call last): ... - AttributeError: 'module' object has no attribute 'sillySetup' + AttributeError: module 'test.test_doctest' has no attribute 'sillySetup' The setUp and tearDown functions are passed test objects. Here, we'll use a setUp function to set the favorite color in @@ -2361,6 +2365,22 @@ def test_trailing_space_in_test(): foo \n """ +class Wrapper: + def __init__(self, func): + self.func = func + functools.update_wrapper(self, func) + + def __call__(self, *args, **kwargs): + self.func(*args, **kwargs) + +@Wrapper +def test_look_in_unwrapped(): + """ + Docstrings in wrapped functions must be detected as well. + + >>> 'one other test' + 'one other test' + """ def test_unittest_reportflags(): """Default unittest reporting flags can be set to control reporting @@ -2709,8 +2729,8 @@ With those preliminaries out of the way, we'll start with a file with two simple tests and no errors. We'll run both the unadorned doctest command, and the verbose version, and then check the output: - >>> from test import script_helper - >>> with script_helper.temp_dir() as tmpdir: + >>> from test.support import script_helper, temp_dir + >>> with temp_dir() as tmpdir: ... fn = os.path.join(tmpdir, 'myfile.doc') ... with open(fn, 'w') as f: ... _ = f.write('This is a very simple test file.\n') @@ -2760,8 +2780,8 @@ ability to process more than one file on the command line and, since the second file ends in '.py', its handling of python module files (as opposed to straight text files). - >>> from test import script_helper - >>> with script_helper.temp_dir() as tmpdir: + >>> from test.support import script_helper, temp_dir + >>> with temp_dir() as tmpdir: ... fn = os.path.join(tmpdir, 'myfile.doc') ... with open(fn, 'w') as f: ... _ = f.write('This is another simple test file.\n') @@ -2927,7 +2947,7 @@ Invalid doctest option: def test_main(): # Check the doctest cases in doctest itself: - support.run_doctest(doctest, verbosity=True) + ret = support.run_doctest(doctest, verbosity=True) # Check the doctest cases defined here: from test import test_doctest support.run_doctest(test_doctest, verbosity=True) diff --git a/Lib/test/test_docxmlrpc.py b/Lib/test/test_docxmlrpc.py index 06161f2..e6ca961 100644 --- a/Lib/test/test_docxmlrpc.py +++ b/Lib/test/test_docxmlrpc.py @@ -87,10 +87,11 @@ class DocXMLRPCHTTPGETServer(unittest.TestCase): threading.Thread(target=server, args=(self.evt, 1)).start() # wait for port to be assigned - n = 1000 - while n > 0 and PORT is None: - time.sleep(0.001) - n -= 1 + deadline = time.monotonic() + 10.0 + while PORT is None: + time.sleep(0.010) + if time.monotonic() > deadline: + break self.client = http.client.HTTPConnection("localhost:%d" % PORT) @@ -212,8 +213,5 @@ class DocXMLRPCHTTPGETServer(unittest.TestCase): response.read()) -def test_main(): - support.run_unittest(DocXMLRPCHTTPGETServer) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_dummy_threading.py b/Lib/test/test_dummy_threading.py index 6ec5da3..a0c2972 100644 --- a/Lib/test/test_dummy_threading.py +++ b/Lib/test/test_dummy_threading.py @@ -56,9 +56,5 @@ class DummyThreadingTestCase(unittest.TestCase): if support.verbose: print('all tasks done') -def test_main(): - support.run_unittest(DummyThreadingTestCase) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_dynamic.py b/Lib/test/test_dynamic.py index beb7b1c..5080ec9 100644 --- a/Lib/test/test_dynamic.py +++ b/Lib/test/test_dynamic.py @@ -4,7 +4,7 @@ import builtins import contextlib import unittest -from test.support import run_unittest, swap_item, swap_attr +from test.support import swap_item, swap_attr class RebindBuiltinsTests(unittest.TestCase): @@ -135,9 +135,5 @@ class RebindBuiltinsTests(unittest.TestCase): self.assertEqual(foo(), 7) -def test_main(): - run_unittest(RebindBuiltinsTests) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_dynamicclassattribute.py b/Lib/test/test_dynamicclassattribute.py index bc6a39b..9f694d9 100644 --- a/Lib/test/test_dynamicclassattribute.py +++ b/Lib/test/test_dynamicclassattribute.py @@ -4,7 +4,6 @@ import abc import sys import unittest -from test.support import run_unittest from types import DynamicClassAttribute class PropertyBase(Exception): @@ -297,8 +296,5 @@ class PropertySubclassTests(unittest.TestCase): -def test_main(): - run_unittest(PropertyTests, PropertySubclassTests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_eintr.py b/Lib/test/test_eintr.py new file mode 100644 index 0000000..aabad83 --- /dev/null +++ b/Lib/test/test_eintr.py @@ -0,0 +1,30 @@ +import os +import signal +import subprocess +import sys +import unittest + +from test import support +from test.support import script_helper + + +@unittest.skipUnless(os.name == "posix", "only supported on Unix") +class EINTRTests(unittest.TestCase): + + @unittest.skipUnless(hasattr(signal, "setitimer"), "requires setitimer()") + def test_all(self): + # Run the tester in a sub-process, to make sure there is only one + # thread (for reliable signal delivery). + tester = support.findfile("eintr_tester.py", subdir="eintrdata") + + if support.verbose: + args = [sys.executable, tester] + with subprocess.Popen(args) as proc: + exitcode = proc.wait() + self.assertEqual(exitcode, 0) + else: + script_helper.assert_python_ok(tester) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_email/test_email.py b/Lib/test/test_email/test_email.py index 61e23fc..d7e3dca 100644 --- a/Lib/test/test_email/test_email.py +++ b/Lib/test/test_email/test_email.py @@ -590,6 +590,17 @@ class TestMessageAPI(TestEmailBase): eq(msg.values(), ['One Hundred', 'Twenty', 'Three', 'Eleven']) self.assertRaises(KeyError, msg.replace_header, 'Fourth', 'Missing') + def test_get_content_disposition(self): + msg = Message() + self.assertIsNone(msg.get_content_disposition()) + msg.add_header('Content-Disposition', 'attachment', + filename='random.avi') + self.assertEqual(msg.get_content_disposition(), 'attachment') + msg.replace_header('Content-Disposition', 'inline') + self.assertEqual(msg.get_content_disposition(), 'inline') + msg.replace_header('Content-Disposition', 'InlinE') + self.assertEqual(msg.get_content_disposition(), 'inline') + # test_defect_handling:test_invalid_chars_in_base64_payload def test_broken_base64_payload(self): x = 'AwDp0P7//y6LwKEAcPa/6Q=9' @@ -1640,6 +1651,10 @@ class TestMIMEText(unittest.TestCase): msg = MIMEText('hello there', _charset='us-ascii') eq(msg.get_charset().input_charset, 'us-ascii') eq(msg['content-type'], 'text/plain; charset="us-ascii"') + # Also accept a Charset instance + msg = MIMEText('hello there', _charset=Charset('utf-8')) + eq(msg.get_charset().input_charset, 'utf-8') + eq(msg['content-type'], 'text/plain; charset="utf-8"') def test_7bit_input(self): eq = self.assertEqual diff --git a/Lib/test/test_email/test_generator.py b/Lib/test/test_email/test_generator.py index 8917408..b1cbce2 100644 --- a/Lib/test/test_email/test_generator.py +++ b/Lib/test/test_email/test_generator.py @@ -2,6 +2,7 @@ import io import textwrap import unittest from email import message_from_string, message_from_bytes +from email.message import EmailMessage from email.generator import Generator, BytesGenerator from email import policy from test.test_email import TestEmailBase, parameterize @@ -139,6 +140,28 @@ class TestGeneratorBase: g.flatten(msg, linesep='\n') self.assertEqual(s.getvalue(), self.typ(expected)) + def test_set_mangle_from_via_policy(self): + source = textwrap.dedent("""\ + Subject: test that + from is mangeld in the body! + + From time to time I write a rhyme. + """) + variants = ( + (None, True), + (policy.compat32, True), + (policy.default, False), + (policy.default.clone(mangle_from_=True), True), + ) + for p, mangle in variants: + expected = source.replace('From ', '>From ') if mangle else source + with self.subTest(policy=p, mangle_from_=mangle): + msg = self.msgmaker(self.typ(source)) + s = self.ioclass() + g = self.genclass(s, policy=p) + g.flatten(msg) + self.assertEqual(s.getvalue(), self.typ(expected)) + class TestGenerator(TestGeneratorBase, TestEmailBase): @@ -194,6 +217,27 @@ class TestBytesGenerator(TestGeneratorBase, TestEmailBase): g.flatten(msg) self.assertEqual(s.getvalue(), expected) + def test_smtputf8_policy(self): + msg = EmailMessage() + msg['From'] = "Páolo <főo@bar.com>" + msg['To'] = 'Dinsdale' + msg['Subject'] = 'Nudge nudge, wink, wink \u1F609' + msg.set_content("oh là là, know what I mean, know what I mean?") + expected = textwrap.dedent("""\ + From: Páolo <főo@bar.com> + To: Dinsdale + Subject: Nudge nudge, wink, wink \u1F609 + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + MIME-Version: 1.0 + + oh là là, know what I mean, know what I mean? + """).encode('utf-8').replace(b'\n', b'\r\n') + s = io.BytesIO() + g = BytesGenerator(s, policy=policy.SMTPUTF8) + g.flatten(msg) + self.assertEqual(s.getvalue(), expected) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_email/test_message.py b/Lib/test/test_email/test_message.py index 50e1a63..d78049e 100644 --- a/Lib/test/test_email/test_message.py +++ b/Lib/test/test_email/test_message.py @@ -723,24 +723,14 @@ class TestEmailMessageBase: def test_is_attachment(self): m = self._make_message() self.assertFalse(m.is_attachment()) - with self.assertWarns(DeprecationWarning): - self.assertFalse(m.is_attachment) m['Content-Disposition'] = 'inline' self.assertFalse(m.is_attachment()) - with self.assertWarns(DeprecationWarning): - self.assertFalse(m.is_attachment) m.replace_header('Content-Disposition', 'attachment') self.assertTrue(m.is_attachment()) - with self.assertWarns(DeprecationWarning): - self.assertTrue(m.is_attachment) m.replace_header('Content-Disposition', 'AtTachMent') self.assertTrue(m.is_attachment()) - with self.assertWarns(DeprecationWarning): - self.assertTrue(m.is_attachment) m.set_param('filename', 'abc.png', 'Content-Disposition') self.assertTrue(m.is_attachment()) - with self.assertWarns(DeprecationWarning): - self.assertTrue(m.is_attachment) class TestEmailMessage(TestEmailMessageBase, TestEmailBase): diff --git a/Lib/test/test_email/test_policy.py b/Lib/test/test_email/test_policy.py index e797f36..9bb32f0 100644 --- a/Lib/test/test_email/test_policy.py +++ b/Lib/test/test_email/test_policy.py @@ -22,15 +22,18 @@ class PolicyAPITests(unittest.TestCase): 'linesep': '\n', 'cte_type': '8bit', 'raise_on_defect': False, + 'mangle_from_': True, } # These default values are the ones set on email.policy.default. # If any of these defaults change, the docs must be updated. policy_defaults = compat32_defaults.copy() policy_defaults.update({ + 'utf8': False, 'raise_on_defect': False, 'header_factory': email.policy.EmailPolicy.header_factory, 'refold_source': 'long', 'content_manager': email.policy.EmailPolicy.content_manager, + 'mangle_from_': False, }) # For each policy under test, we give here what we expect the defaults to @@ -42,6 +45,9 @@ class PolicyAPITests(unittest.TestCase): email.policy.default: make_defaults(policy_defaults, {}), email.policy.SMTP: make_defaults(policy_defaults, {'linesep': '\r\n'}), + email.policy.SMTPUTF8: make_defaults(policy_defaults, + {'linesep': '\r\n', + 'utf8': True}), email.policy.HTTP: make_defaults(policy_defaults, {'linesep': '\r\n', 'max_line_length': None}), diff --git a/Lib/test/test_ensurepip.py b/Lib/test/test_ensurepip.py index 6dc764b..a78ca14 100644 --- a/Lib/test/test_ensurepip.py +++ b/Lib/test/test_ensurepip.py @@ -357,4 +357,4 @@ class TestUninstallationMainFunction(EnsurepipMixin, unittest.TestCase): if __name__ == "__main__": - test.support.run_unittest(__name__) + unittest.main() diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 5db4040..4b5d0d0 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -66,18 +66,14 @@ try: except Exception: pass -def test_pickle_dump_load(assertion, source, target=None, - *, protocol=(0, HIGHEST_PROTOCOL)): - start, stop = protocol +def test_pickle_dump_load(assertion, source, target=None): if target is None: target = source - for protocol in range(start, stop+1): + for protocol in range(HIGHEST_PROTOCOL + 1): assertion(loads(dumps(source, protocol=protocol)), target) -def test_pickle_exception(assertion, exception, obj, - *, protocol=(0, HIGHEST_PROTOCOL)): - start, stop = protocol - for protocol in range(start, stop+1): +def test_pickle_exception(assertion, exception, obj): + for protocol in range(HIGHEST_PROTOCOL + 1): with assertion(exception): dumps(obj, protocol=protocol) @@ -575,11 +571,7 @@ class TestEnum(unittest.TestCase): self.__class__.NestedEnum = NestedEnum self.NestedEnum.__qualname__ = '%s.NestedEnum' % self.__class__.__name__ - test_pickle_exception( - self.assertRaises, PicklingError, self.NestedEnum.twigs, - protocol=(0, 3)) - test_pickle_dump_load(self.assertIs, self.NestedEnum.twigs, - protocol=(4, HIGHEST_PROTOCOL)) + test_pickle_dump_load(self.assertIs, self.NestedEnum.twigs) def test_pickle_by_name(self): class ReplaceGlobalInt(IntEnum): @@ -654,6 +646,23 @@ class TestEnum(unittest.TestCase): self.assertIn(e, SummerMonth) self.assertIs(type(e), SummerMonth) + def test_programatic_function_string_with_start(self): + SummerMonth = Enum('SummerMonth', 'june july august', start=10) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 10): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + def test_programatic_function_string_list(self): SummerMonth = Enum('SummerMonth', ['june', 'july', 'august']) lst = list(SummerMonth) @@ -671,6 +680,23 @@ class TestEnum(unittest.TestCase): self.assertIn(e, SummerMonth) self.assertIs(type(e), SummerMonth) + def test_programatic_function_string_list_with_start(self): + SummerMonth = Enum('SummerMonth', ['june', 'july', 'august'], start=20) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 20): + e = SummerMonth(i) + self.assertEqual(int(e.value), i) + self.assertNotEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + def test_programatic_function_iterable(self): SummerMonth = Enum( 'SummerMonth', @@ -727,6 +753,22 @@ class TestEnum(unittest.TestCase): self.assertIn(e, SummerMonth) self.assertIs(type(e), SummerMonth) + def test_programatic_function_type_with_start(self): + SummerMonth = Enum('SummerMonth', 'june july august', type=int, start=30) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 30): + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + def test_programatic_function_type_from_subclass(self): SummerMonth = IntEnum('SummerMonth', 'june july august') lst = list(SummerMonth) @@ -743,6 +785,22 @@ class TestEnum(unittest.TestCase): self.assertIn(e, SummerMonth) self.assertIs(type(e), SummerMonth) + def test_programatic_function_type_from_subclass_with_start(self): + SummerMonth = IntEnum('SummerMonth', 'june july august', start=40) + lst = list(SummerMonth) + self.assertEqual(len(lst), len(SummerMonth)) + self.assertEqual(len(SummerMonth), 3, SummerMonth) + self.assertEqual( + [SummerMonth.june, SummerMonth.july, SummerMonth.august], + lst, + ) + for i, month in enumerate('june july august'.split(), 40): + e = SummerMonth(i) + self.assertEqual(e, i) + self.assertEqual(e.name, month) + self.assertIn(e, SummerMonth) + self.assertIs(type(e), SummerMonth) + def test_subclassing(self): if isinstance(Name, Exception): raise Name @@ -1030,9 +1088,9 @@ class TestEnum(unittest.TestCase): globals()['NEI'] = NEI NI5 = NamedInt('test', 5) self.assertEqual(NI5, 5) - test_pickle_dump_load(self.assertEqual, NI5, 5, protocol=(4, 4)) + test_pickle_dump_load(self.assertEqual, NI5, 5) self.assertEqual(NEI.y.value, 2) - test_pickle_dump_load(self.assertIs, NEI.y, protocol=(4, 4)) + test_pickle_dump_load(self.assertIs, NEI.y) test_pickle_dump_load(self.assertIs, NEI) def test_subclasses_with_reduce(self): @@ -1498,10 +1556,12 @@ class TestUnique(unittest.TestCase): turkey = 3 -expected_help_output = """ +expected_help_output_with_docs = """\ Help on class Color in module %s: class Color(enum.Enum) + | An enumeration. + |\x20\x20 | Method resolution order: | Color | enum.Enum @@ -1531,11 +1591,41 @@ class Color(enum.Enum) | Returns a mapping of member name->value. |\x20\x20\x20\x20\x20\x20 | This mapping lists all enum members, including aliases. Note that this - | is a read-only view of the internal mapping. -""".strip() + | is a read-only view of the internal mapping.""" + +expected_help_output_without_docs = """\ +Help on class Color in module %s: + +class Color(enum.Enum) + | Method resolution order: + | Color + | enum.Enum + | builtins.object + |\x20\x20 + | Data and other attributes defined here: + |\x20\x20 + | blue = <Color.blue: 3> + |\x20\x20 + | green = <Color.green: 2> + |\x20\x20 + | red = <Color.red: 1> + |\x20\x20 + | ---------------------------------------------------------------------- + | Data descriptors inherited from enum.Enum: + |\x20\x20 + | name + |\x20\x20 + | value + |\x20\x20 + | ---------------------------------------------------------------------- + | Data descriptors inherited from enum.EnumMeta: + |\x20\x20 + | __members__""" class TestStdLib(unittest.TestCase): + maxDiff = None + class Color(Enum): red = 1 green = 2 @@ -1543,7 +1633,10 @@ class TestStdLib(unittest.TestCase): def test_pydoc(self): # indirectly test __objclass__ - expected_text = expected_help_output % __name__ + if StrEnum.__doc__ is None: + expected_text = expected_help_output_without_docs % __name__ + else: + expected_text = expected_help_output_with_docs % __name__ output = StringIO() helper = pydoc.Helper(output=output) helper(self.Color) @@ -1553,7 +1646,7 @@ class TestStdLib(unittest.TestCase): def test_inspect_getmembers(self): values = dict(( ('__class__', EnumMeta), - ('__doc__', None), + ('__doc__', 'An enumeration.'), ('__members__', self.Color.__members__), ('__module__', __name__), ('blue', self.Color.blue), @@ -1581,7 +1674,7 @@ class TestStdLib(unittest.TestCase): Attribute(name='__class__', kind='data', defining_class=object, object=EnumMeta), Attribute(name='__doc__', kind='data', - defining_class=self.Color, object=None), + defining_class=self.Color, object='An enumeration.'), Attribute(name='__members__', kind='property', defining_class=EnumMeta, object=EnumMeta.__members__), Attribute(name='__module__', kind='data', diff --git a/Lib/test/test_enumerate.py b/Lib/test/test_enumerate.py index e85254c..2630cf2 100644 --- a/Lib/test/test_enumerate.py +++ b/Lib/test/test_enumerate.py @@ -258,16 +258,5 @@ class TestLongStart(EnumerateStartTestCase): (sys.maxsize+3,'c')] -def test_main(verbose=None): - support.run_unittest(__name__) - - # verify reference counting - if verbose and hasattr(sys, "gettotalrefcount"): - counts = [None] * 5 - for i in range(len(counts)): - support.run_unittest(__name__) - counts[i] = sys.gettotalrefcount() - print(counts) - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_eof.py b/Lib/test/test_eof.py index 52e7932..7baa7ae 100644 --- a/Lib/test/test_eof.py +++ b/Lib/test/test_eof.py @@ -24,8 +24,5 @@ class EOFTestCase(unittest.TestCase): else: raise support.TestFailed -def test_main(): - support.run_unittest(EOFTestCase) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_epoll.py b/Lib/test/test_epoll.py index b37f033..a7359e9 100644 --- a/Lib/test/test_epoll.py +++ b/Lib/test/test_epoll.py @@ -44,7 +44,7 @@ class TestEPoll(unittest.TestCase): def setUp(self): self.serverSocket = socket.socket() self.serverSocket.bind(('127.0.0.1', 0)) - self.serverSocket.listen(1) + self.serverSocket.listen() self.connections = [self.serverSocket] def tearDown(self): @@ -252,8 +252,5 @@ class TestEPoll(unittest.TestCase): self.assertEqual(os.get_inheritable(epoll.fileno()), False) -def test_main(): - support.run_unittest(TestEPoll) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_errno.py b/Lib/test/test_errno.py index 058dcb9..5c437e9 100644 --- a/Lib/test/test_errno.py +++ b/Lib/test/test_errno.py @@ -3,7 +3,6 @@ """ import errno -from test import support import unittest std_c_errors = frozenset(['EDOM', 'ERANGE']) @@ -32,9 +31,5 @@ class ErrorcodeTests(unittest.TestCase): 'no %s attr in errno.errorcode' % attribute) -def test_main(): - support.run_unittest(ErrnoAttributeTests, ErrorcodeTests) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_exception_variations.py b/Lib/test/test_exception_variations.py index 11f5e5c..d874b0e 100644 --- a/Lib/test/test_exception_variations.py +++ b/Lib/test/test_exception_variations.py @@ -1,5 +1,4 @@ -from test.support import run_unittest import unittest class ExceptionTestCase(unittest.TestCase): @@ -173,8 +172,5 @@ class ExceptionTestCase(unittest.TestCase): self.assertTrue(hit_finally) self.assertTrue(hit_except) -def test_main(): - run_unittest(ExceptionTestCase) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py index 80d4f1a..32a66ea 100644 --- a/Lib/test/test_exceptions.py +++ b/Lib/test/test_exceptions.py @@ -84,6 +84,7 @@ class ExceptionTests(unittest.TestCase): x += x # this simply shouldn't blow up self.raise_catch(RuntimeError, "RuntimeError") + self.raise_catch(RecursionError, "RecursionError") self.raise_catch(SyntaxError, "SyntaxError") try: exec('/\n') @@ -117,6 +118,8 @@ class ExceptionTests(unittest.TestCase): try: x = 1/0 except Exception as e: pass + self.raise_catch(StopAsyncIteration, "StopAsyncIteration") + def testSyntaxErrorMessage(self): # make sure the right exception message is raised for each of # these code fragments @@ -474,14 +477,14 @@ class ExceptionTests(unittest.TestCase): def testInfiniteRecursion(self): def f(): return f() - self.assertRaises(RuntimeError, f) + self.assertRaises(RecursionError, f) def g(): try: return g() except ValueError: return -1 - self.assertRaises(RuntimeError, g) + self.assertRaises(RecursionError, g) def test_str(self): # Make sure both instances and classes have a str representation. @@ -887,10 +890,10 @@ class ExceptionTests(unittest.TestCase): def g(): try: return g() - except RuntimeError: + except RecursionError: return sys.exc_info() e, v, tb = g() - self.assertTrue(isinstance(v, RuntimeError), type(v)) + self.assertTrue(isinstance(v, RecursionError), type(v)) self.assertIn("maximum recursion depth exceeded", str(v)) @@ -989,10 +992,10 @@ class ExceptionTests(unittest.TestCase): # We cannot use assertRaises since it manually deletes the traceback try: inner() - except RuntimeError as e: + except RecursionError as e: self.assertNotEqual(wr(), None) else: - self.fail("RuntimeError not raised") + self.fail("RecursionError not raised") self.assertEqual(wr(), None) def test_errno_ENOTDIR(self): diff --git a/Lib/test/test_extcall.py b/Lib/test/test_extcall.py index 6b6c12d..654258e 100644 --- a/Lib/test/test_extcall.py +++ b/Lib/test/test_extcall.py @@ -34,17 +34,37 @@ Argument list examples (1, 2, 3, 4, 5) {} >>> f(1, 2, 3, *[4, 5]) (1, 2, 3, 4, 5) {} + >>> f(*[1, 2, 3], 4, 5) + (1, 2, 3, 4, 5) {} >>> f(1, 2, 3, *UserList([4, 5])) (1, 2, 3, 4, 5) {} + >>> f(1, 2, 3, *[4, 5], *[6, 7]) + (1, 2, 3, 4, 5, 6, 7) {} + >>> f(1, *[2, 3], 4, *[5, 6], 7) + (1, 2, 3, 4, 5, 6, 7) {} + >>> f(*UserList([1, 2]), *UserList([3, 4]), 5, *UserList([6, 7])) + (1, 2, 3, 4, 5, 6, 7) {} Here we add keyword arguments >>> f(1, 2, 3, **{'a':4, 'b':5}) (1, 2, 3) {'a': 4, 'b': 5} + >>> f(1, 2, **{'a': -1, 'b': 5}, **{'a': 4, 'c': 6}) + Traceback (most recent call last): + ... + TypeError: f() got multiple values for keyword argument 'a' + >>> f(1, 2, **{'a': -1, 'b': 5}, a=4, c=6) + Traceback (most recent call last): + ... + TypeError: f() got multiple values for keyword argument 'a' >>> f(1, 2, 3, *[4, 5], **{'a':6, 'b':7}) (1, 2, 3, 4, 5) {'a': 6, 'b': 7} >>> f(1, 2, 3, x=4, y=5, *(6, 7), **{'a':8, 'b': 9}) (1, 2, 3, 6, 7) {'a': 8, 'b': 9, 'x': 4, 'y': 5} + >>> f(1, 2, 3, *[4, 5], **{'c': 8}, **{'a':6, 'b':7}) + (1, 2, 3, 4, 5) {'a': 6, 'b': 7, 'c': 8} + >>> f(1, 2, 3, *(4, 5), x=6, y=7, **{'a':8, 'b': 9}) + (1, 2, 3, 4, 5) {'a': 8, 'b': 9, 'x': 6, 'y': 7} >>> f(1, 2, 3, **UserDict(a=4, b=5)) (1, 2, 3) {'a': 4, 'b': 5} @@ -52,6 +72,8 @@ Here we add keyword arguments (1, 2, 3, 4, 5) {'a': 6, 'b': 7} >>> f(1, 2, 3, x=4, y=5, *(6, 7), **UserDict(a=8, b=9)) (1, 2, 3, 6, 7) {'a': 8, 'b': 9, 'x': 4, 'y': 5} + >>> f(1, 2, 3, *(4, 5), x=6, y=7, **UserDict(a=8, b=9)) + (1, 2, 3, 4, 5) {'a': 8, 'b': 9, 'x': 6, 'y': 7} Examples with invalid arguments (TypeErrors). We're also testing the function names in the exception messages. diff --git a/Lib/test/test_faulthandler.py b/Lib/test/test_faulthandler.py index e68a09e..0d86cb5 100644 --- a/Lib/test/test_faulthandler.py +++ b/Lib/test/test_faulthandler.py @@ -6,8 +6,8 @@ import re import signal import subprocess import sys -from test import support, script_helper -from test.script_helper import assert_python_ok +from test import support +from test.support import script_helper import tempfile import unittest from textwrap import dedent @@ -17,6 +17,10 @@ try: HAVE_THREADS = True except ImportError: HAVE_THREADS = False +try: + import _testcapi +except ImportError: + _testcapi = None TIMEOUT = 0.5 @@ -38,7 +42,7 @@ def temporary_filename(): support.unlink(filename) class FaultHandlerTests(unittest.TestCase): - def get_output(self, code, filename=None): + def get_output(self, code, filename=None, fd=None): """ Run the specified code in Python (in a new child process) and read the output from the standard error or from a file (if filename is set). @@ -49,8 +53,11 @@ class FaultHandlerTests(unittest.TestCase): thread XXX". """ code = dedent(code).strip() + pass_fds = [] + if fd is not None: + pass_fds.append(fd) with support.SuppressCrashReport(): - process = script_helper.spawn_python('-c', code) + process = script_helper.spawn_python('-c', code, pass_fds=pass_fds) stdout, stderr = process.communicate() exitcode = process.wait() output = support.strip_python_stderr(stdout) @@ -60,13 +67,20 @@ class FaultHandlerTests(unittest.TestCase): with open(filename, "rb") as fp: output = fp.read() output = output.decode('ascii', 'backslashreplace') + elif fd is not None: + self.assertEqual(output, '') + os.lseek(fd, os.SEEK_SET, 0) + with open(fd, "rb", closefd=False) as fp: + output = fp.read() + output = output.decode('ascii', 'backslashreplace') output = re.sub('Current thread 0x[0-9a-f]+', 'Current thread XXX', output) return output.splitlines(), exitcode def check_fatal_error(self, code, line_number, name_regex, - filename=None, all_threads=True, other_regex=None): + filename=None, all_threads=True, other_regex=None, + fd=None): """ Check that the fault handler for fatal errors is enabled and check the traceback from the child process output. @@ -89,7 +103,7 @@ class FaultHandlerTests(unittest.TestCase): header=re.escape(header))).strip() if other_regex: regex += '|' + other_regex - output, exitcode = self.get_output(code, filename) + output, exitcode = self.get_output(code, filename=filename, fd=fd) output = '\n'.join(output) self.assertRegex(output, regex) self.assertNotEqual(exitcode, 0) @@ -135,26 +149,32 @@ class FaultHandlerTests(unittest.TestCase): 3, 'Floating point exception') - @unittest.skipIf(not hasattr(faulthandler, '_sigbus'), - "need faulthandler._sigbus()") + @unittest.skipIf(_testcapi is None, 'need _testcapi') + @unittest.skipUnless(hasattr(signal, 'SIGBUS'), 'need signal.SIGBUS') def test_sigbus(self): self.check_fatal_error(""" + import _testcapi import faulthandler + import signal + faulthandler.enable() - faulthandler._sigbus() + _testcapi.raise_signal(signal.SIGBUS) """, - 3, + 6, 'Bus error') - @unittest.skipIf(not hasattr(faulthandler, '_sigill'), - "need faulthandler._sigill()") + @unittest.skipIf(_testcapi is None, 'need _testcapi') + @unittest.skipUnless(hasattr(signal, 'SIGILL'), 'need signal.SIGILL') def test_sigill(self): self.check_fatal_error(""" + import _testcapi import faulthandler + import signal + faulthandler.enable() - faulthandler._sigill() + _testcapi.raise_signal(signal.SIGILL) """, - 3, + 6, 'Illegal instruction') def test_fatal_error(self): @@ -201,6 +221,21 @@ class FaultHandlerTests(unittest.TestCase): 'Segmentation fault', filename=filename) + @unittest.skipIf(sys.platform == "win32", + "subprocess doesn't support pass_fds on Windows") + def test_enable_fd(self): + with tempfile.TemporaryFile('wb+') as fp: + fd = fp.fileno() + self.check_fatal_error(""" + import faulthandler + import sys + faulthandler.enable(%s) + faulthandler._sigsegv() + """ % fd, + 4, + 'Segmentation fault', + fd=fd) + def test_enable_single_thread(self): self.check_fatal_error(""" import faulthandler @@ -287,7 +322,7 @@ class FaultHandlerTests(unittest.TestCase): output = subprocess.check_output(args, env=env) self.assertEqual(output.rstrip(), b"True") - def check_dump_traceback(self, filename): + def check_dump_traceback(self, *, filename=None, fd=None): """ Explicitly call dump_traceback() function and check its output. Raise an error if the output doesn't match the expected format. @@ -295,10 +330,16 @@ class FaultHandlerTests(unittest.TestCase): code = """ import faulthandler + filename = {filename!r} + fd = {fd} + def funcB(): - if {has_filename}: - with open({filename}, "wb") as fp: + if filename: + with open(filename, "wb") as fp: faulthandler.dump_traceback(fp, all_threads=False) + elif fd is not None: + faulthandler.dump_traceback(fd, + all_threads=False) else: faulthandler.dump_traceback(all_threads=False) @@ -308,29 +349,37 @@ class FaultHandlerTests(unittest.TestCase): funcA() """ code = code.format( - filename=repr(filename), - has_filename=bool(filename), + filename=filename, + fd=fd, ) if filename: - lineno = 6 + lineno = 9 + elif fd is not None: + lineno = 12 else: - lineno = 8 + lineno = 14 expected = [ 'Stack (most recent call first):', ' File "<string>", line %s in funcB' % lineno, - ' File "<string>", line 11 in funcA', - ' File "<string>", line 13 in <module>' + ' File "<string>", line 17 in funcA', + ' File "<string>", line 19 in <module>' ] - trace, exitcode = self.get_output(code, filename) + trace, exitcode = self.get_output(code, filename, fd) self.assertEqual(trace, expected) self.assertEqual(exitcode, 0) def test_dump_traceback(self): - self.check_dump_traceback(None) + self.check_dump_traceback() def test_dump_traceback_file(self): with temporary_filename() as filename: - self.check_dump_traceback(filename) + self.check_dump_traceback(filename=filename) + + @unittest.skipIf(sys.platform == "win32", + "subprocess doesn't support pass_fds on Windows") + def test_dump_traceback_fd(self): + with tempfile.TemporaryFile('wb+') as fp: + self.check_dump_traceback(fd=fp.fileno()) def test_truncate(self): maxlen = 500 @@ -423,7 +472,10 @@ class FaultHandlerTests(unittest.TestCase): with temporary_filename() as filename: self.check_dump_traceback_threads(filename) - def _check_dump_traceback_later(self, repeat, cancel, filename, loops): + @unittest.skipIf(not hasattr(faulthandler, 'dump_traceback_later'), + 'need faulthandler.dump_traceback_later()') + def check_dump_traceback_later(self, repeat=False, cancel=False, loops=1, + *, filename=None, fd=None): """ Check how many times the traceback is written in timeout x 2.5 seconds, or timeout x 3.5 seconds if cancel is True: 1, 2 or 3 times depending @@ -435,6 +487,14 @@ class FaultHandlerTests(unittest.TestCase): code = """ import faulthandler import time + import sys + + timeout = {timeout} + repeat = {repeat} + cancel = {cancel} + loops = {loops} + filename = {filename!r} + fd = {fd} def func(timeout, repeat, cancel, file, loops): for loop in range(loops): @@ -444,16 +504,14 @@ class FaultHandlerTests(unittest.TestCase): time.sleep(timeout * 5) faulthandler.cancel_dump_traceback_later() - timeout = {timeout} - repeat = {repeat} - cancel = {cancel} - loops = {loops} - if {has_filename}: - file = open({filename}, "wb") + if filename: + file = open(filename, "wb") + elif fd is not None: + file = sys.stderr.fileno() else: file = None func(timeout, repeat, cancel, file, loops) - if file is not None: + if filename: file.close() """ code = code.format( @@ -461,8 +519,8 @@ class FaultHandlerTests(unittest.TestCase): repeat=repeat, cancel=cancel, loops=loops, - has_filename=bool(filename), - filename=repr(filename), + filename=filename, + fd=fd, ) trace, exitcode = self.get_output(code, filename) trace = '\n'.join(trace) @@ -472,27 +530,12 @@ class FaultHandlerTests(unittest.TestCase): if repeat: count *= 2 header = r'Timeout \(%s\)!\nThread 0x[0-9a-f]+ \(most recent call first\):\n' % timeout_str - regex = expected_traceback(9, 20, header, min_count=count) + regex = expected_traceback(17, 26, header, min_count=count) self.assertRegex(trace, regex) else: self.assertEqual(trace, '') self.assertEqual(exitcode, 0) - @unittest.skipIf(not hasattr(faulthandler, 'dump_traceback_later'), - 'need faulthandler.dump_traceback_later()') - def check_dump_traceback_later(self, repeat=False, cancel=False, - file=False, twice=False): - if twice: - loops = 2 - else: - loops = 1 - if file: - with temporary_filename() as filename: - self._check_dump_traceback_later(repeat, cancel, - filename, loops) - else: - self._check_dump_traceback_later(repeat, cancel, None, loops) - def test_dump_traceback_later(self): self.check_dump_traceback_later() @@ -503,15 +546,22 @@ class FaultHandlerTests(unittest.TestCase): self.check_dump_traceback_later(cancel=True) def test_dump_traceback_later_file(self): - self.check_dump_traceback_later(file=True) + with temporary_filename() as filename: + self.check_dump_traceback_later(filename=filename) + + @unittest.skipIf(sys.platform == "win32", + "subprocess doesn't support pass_fds on Windows") + def test_dump_traceback_later_fd(self): + with tempfile.TemporaryFile('wb+') as fp: + self.check_dump_traceback_later(fd=fp.fileno()) def test_dump_traceback_later_twice(self): - self.check_dump_traceback_later(twice=True) + self.check_dump_traceback_later(loops=2) @unittest.skipIf(not hasattr(faulthandler, "register"), "need faulthandler.register") def check_register(self, filename=False, all_threads=False, - unregister=False, chain=False): + unregister=False, chain=False, fd=None): """ Register a handler displaying the traceback on a user signal. Raise the signal and check the written traceback. @@ -527,6 +577,13 @@ class FaultHandlerTests(unittest.TestCase): import signal import sys + all_threads = {all_threads} + signum = {signum} + unregister = {unregister} + chain = {chain} + filename = {filename!r} + fd = {fd} + def func(signum): os.kill(os.getpid(), signum) @@ -534,19 +591,16 @@ class FaultHandlerTests(unittest.TestCase): handler.called = True handler.called = False - exitcode = 0 - signum = {signum} - unregister = {unregister} - chain = {chain} - - if {has_filename}: - file = open({filename}, "wb") + if filename: + file = open(filename, "wb") + elif fd is not None: + file = sys.stderr.fileno() else: file = None if chain: signal.signal(signum, handler) faulthandler.register(signum, file=file, - all_threads={all_threads}, chain={chain}) + all_threads=all_threads, chain={chain}) if unregister: faulthandler.unregister(signum) func(signum) @@ -557,17 +611,19 @@ class FaultHandlerTests(unittest.TestCase): output = sys.stderr print("Error: signal handler not called!", file=output) exitcode = 1 - if file is not None: + else: + exitcode = 0 + if filename: file.close() sys.exit(exitcode) """ code = code.format( - filename=repr(filename), - has_filename=bool(filename), all_threads=all_threads, signum=signum, unregister=unregister, chain=chain, + filename=filename, + fd=fd, ) trace, exitcode = self.get_output(code, filename) trace = '\n'.join(trace) @@ -576,7 +632,7 @@ class FaultHandlerTests(unittest.TestCase): regex = 'Current thread XXX \(most recent call first\):\n' else: regex = 'Stack \(most recent call first\):\n' - regex = expected_traceback(7, 28, regex) + regex = expected_traceback(14, 32, regex) self.assertRegex(trace, regex) else: self.assertEqual(trace, '') @@ -595,6 +651,12 @@ class FaultHandlerTests(unittest.TestCase): with temporary_filename() as filename: self.check_register(filename=filename) + @unittest.skipIf(sys.platform == "win32", + "subprocess doesn't support pass_fds on Windows") + def test_register_fd(self): + with tempfile.TemporaryFile('wb+') as fp: + self.check_register(fd=fp.fileno()) + def test_register_threads(self): self.check_register(all_threads=True) diff --git a/Lib/test/test_file_eintr.py b/Lib/test/test_file_eintr.py index b4e18ce..f1efd26 100644 --- a/Lib/test/test_file_eintr.py +++ b/Lib/test/test_file_eintr.py @@ -13,16 +13,16 @@ import select import signal import subprocess import sys -from test.support import run_unittest import time import unittest # Test import all of the things we're about to try testing up front. -from _io import FileIO +import _io +import _pyio @unittest.skipUnless(os.name == 'posix', 'tests requires a posix system.') -class TestFileIOSignalInterrupt(unittest.TestCase): +class TestFileIOSignalInterrupt: def setUp(self): self._process = None @@ -38,8 +38,9 @@ class TestFileIOSignalInterrupt(unittest.TestCase): subclasseses should override this to test different IO objects. """ - return ('import _io ;' - 'infile = _io.FileIO(sys.stdin.fileno(), "rb")') + return ('import %s as io ;' + 'infile = io.FileIO(sys.stdin.fileno(), "rb")' % + self.modname) def fail_with_process_info(self, why, stdout=b'', stderr=b'', communicate=True): @@ -179,11 +180,19 @@ class TestFileIOSignalInterrupt(unittest.TestCase): expected=b'hello\nworld!\n')) +class CTestFileIOSignalInterrupt(TestFileIOSignalInterrupt, unittest.TestCase): + modname = '_io' + +class PyTestFileIOSignalInterrupt(TestFileIOSignalInterrupt, unittest.TestCase): + modname = '_pyio' + + class TestBufferedIOSignalInterrupt(TestFileIOSignalInterrupt): def _generate_infile_setup_code(self): """Returns the infile = ... line of code to make a BufferedReader.""" - return ('infile = open(sys.stdin.fileno(), "rb") ;' - 'import _io ;assert isinstance(infile, _io.BufferedReader)') + return ('import %s as io ;infile = io.open(sys.stdin.fileno(), "rb") ;' + 'assert isinstance(infile, io.BufferedReader)' % + self.modname) def test_readall(self): """BufferedReader.read() must handle signals and not lose data.""" @@ -193,12 +202,20 @@ class TestBufferedIOSignalInterrupt(TestFileIOSignalInterrupt): read_method_name='read', expected=b'hello\nworld!\n')) +class CTestBufferedIOSignalInterrupt(TestBufferedIOSignalInterrupt, unittest.TestCase): + modname = '_io' + +class PyTestBufferedIOSignalInterrupt(TestBufferedIOSignalInterrupt, unittest.TestCase): + modname = '_pyio' + class TestTextIOSignalInterrupt(TestFileIOSignalInterrupt): def _generate_infile_setup_code(self): """Returns the infile = ... line of code to make a TextIOWrapper.""" - return ('infile = open(sys.stdin.fileno(), "rt", newline=None) ;' - 'import _io ;assert isinstance(infile, _io.TextIOWrapper)') + return ('import %s as io ;' + 'infile = io.open(sys.stdin.fileno(), "rt", newline=None) ;' + 'assert isinstance(infile, io.TextIOWrapper)' % + self.modname) def test_readline(self): """readline() must handle signals and not lose data.""" @@ -224,13 +241,12 @@ class TestTextIOSignalInterrupt(TestFileIOSignalInterrupt): read_method_name='read', expected="hello\nworld!\n")) +class CTestTextIOSignalInterrupt(TestTextIOSignalInterrupt, unittest.TestCase): + modname = '_io' -def test_main(): - test_cases = [ - tc for tc in globals().values() - if isinstance(tc, type) and issubclass(tc, unittest.TestCase)] - run_unittest(*test_cases) +class PyTestTextIOSignalInterrupt(TestTextIOSignalInterrupt, unittest.TestCase): + modname = '_pyio' if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py index a4fd20d..59cc38f 100644 --- a/Lib/test/test_fileio.py +++ b/Lib/test/test_fileio.py @@ -12,13 +12,15 @@ from functools import wraps from test.support import TESTFN, check_warnings, run_unittest, make_bad_fd, cpython_only from collections import UserList -from _io import FileIO as _FileIO +import _io # C implementation of io +import _pyio # Python implementation of io -class AutoFileTests(unittest.TestCase): + +class AutoFileTests: # file tests for which a test file is automatically set up def setUp(self): - self.f = _FileIO(TESTFN, 'w') + self.f = self.FileIO(TESTFN, 'w') def tearDown(self): if self.f: @@ -60,20 +62,69 @@ class AutoFileTests(unittest.TestCase): self.assertRaises((AttributeError, TypeError), setattr, f, attr, 'oops') - def testReadinto(self): - # verify readinto - self.f.write(bytes([1, 2])) + def testBlksize(self): + # test private _blksize attribute + blksize = io.DEFAULT_BUFFER_SIZE + # try to get preferred blksize from stat.st_blksize, if available + if hasattr(os, 'fstat'): + fst = os.fstat(self.f.fileno()) + blksize = getattr(fst, 'st_blksize', blksize) + self.assertEqual(self.f._blksize, blksize) + + # verify readinto + def testReadintoByteArray(self): + self.f.write(bytes([1, 2, 0, 255])) self.f.close() - a = array('b', b'x'*10) - self.f = _FileIO(TESTFN, 'r') - n = self.f.readinto(a) - self.assertEqual(array('b', [1, 2]), a[:n]) + + ba = bytearray(b'abcdefgh') + with self.FileIO(TESTFN, 'r') as f: + n = f.readinto(ba) + self.assertEqual(ba, b'\x01\x02\x00\xffefgh') + self.assertEqual(n, 4) + + def _testReadintoMemoryview(self): + self.f.write(bytes([1, 2, 0, 255])) + self.f.close() + + m = memoryview(bytearray(b'abcdefgh')) + with self.FileIO(TESTFN, 'r') as f: + n = f.readinto(m) + self.assertEqual(m, b'\x01\x02\x00\xffefgh') + self.assertEqual(n, 4) + + m = memoryview(bytearray(b'abcdefgh')).cast('H', shape=[2, 2]) + with self.FileIO(TESTFN, 'r') as f: + n = f.readinto(m) + self.assertEqual(bytes(m), b'\x01\x02\x00\xffefgh') + self.assertEqual(n, 4) + + def _testReadintoArray(self): + self.f.write(bytes([1, 2, 0, 255])) + self.f.close() + + a = array('B', b'abcdefgh') + with self.FileIO(TESTFN, 'r') as f: + n = f.readinto(a) + self.assertEqual(a, array('B', [1, 2, 0, 255, 101, 102, 103, 104])) + self.assertEqual(n, 4) + + a = array('b', b'abcdefgh') + with self.FileIO(TESTFN, 'r') as f: + n = f.readinto(a) + self.assertEqual(a, array('b', [1, 2, 0, -1, 101, 102, 103, 104])) + self.assertEqual(n, 4) + + a = array('I', b'abcdefgh') + with self.FileIO(TESTFN, 'r') as f: + n = f.readinto(a) + self.assertEqual(a, array('I', b'\x01\x02\x00\xffefgh')) + self.assertEqual(n, 4) def testWritelinesList(self): l = [b'123', b'456'] self.f.writelines(l) self.f.close() - self.f = _FileIO(TESTFN, 'rb') + self.f = self.FileIO(TESTFN, 'rb') buf = self.f.read() self.assertEqual(buf, b'123456') @@ -81,7 +132,7 @@ class AutoFileTests(unittest.TestCase): l = UserList([b'123', b'456']) self.f.writelines(l) self.f.close() - self.f = _FileIO(TESTFN, 'rb') + self.f = self.FileIO(TESTFN, 'rb') buf = self.f.read() self.assertEqual(buf, b'123456') @@ -93,7 +144,7 @@ class AutoFileTests(unittest.TestCase): def test_none_args(self): self.f.write(b"hi\nbye\nabc") self.f.close() - self.f = _FileIO(TESTFN, 'r') + self.f = self.FileIO(TESTFN, 'r') self.assertEqual(self.f.read(None), b"hi\nbye\nabc") self.f.seek(0) self.assertEqual(self.f.readline(None), b"hi\n") @@ -103,13 +154,26 @@ class AutoFileTests(unittest.TestCase): self.assertRaises(TypeError, self.f.write, "Hello!") def testRepr(self): - self.assertEqual(repr(self.f), "<_io.FileIO name=%r mode=%r>" - % (self.f.name, self.f.mode)) + self.assertEqual(repr(self.f), + "<%s.FileIO name=%r mode=%r closefd=True>" % + (self.modulename, self.f.name, self.f.mode)) del self.f.name - self.assertEqual(repr(self.f), "<_io.FileIO fd=%r mode=%r>" - % (self.f.fileno(), self.f.mode)) + self.assertEqual(repr(self.f), + "<%s.FileIO fd=%r mode=%r closefd=True>" % + (self.modulename, self.f.fileno(), self.f.mode)) self.f.close() - self.assertEqual(repr(self.f), "<_io.FileIO [closed]>") + self.assertEqual(repr(self.f), + "<%s.FileIO [closed]>" % (self.modulename,)) + + def testReprNoCloseFD(self): + fd = os.open(TESTFN, os.O_RDONLY) + try: + with self.FileIO(fd, 'r', closefd=False) as f: + self.assertEqual(repr(f), + "<%s.FileIO name=%r mode=%r closefd=False>" % + (self.modulename, f.name, f.mode)) + finally: + os.close(fd) def testErrors(self): f = self.f @@ -119,7 +183,7 @@ class AutoFileTests(unittest.TestCase): self.assertRaises(ValueError, f.read, 10) # Open for reading f.close() self.assertTrue(f.closed) - f = _FileIO(TESTFN, 'r') + f = self.FileIO(TESTFN, 'r') self.assertRaises(TypeError, f.readinto, "") self.assertFalse(f.closed) f.close() @@ -138,11 +202,11 @@ class AutoFileTests(unittest.TestCase): # should raise on closed file self.assertRaises(ValueError, method) - self.assertRaises(ValueError, self.f.readinto) # XXX should be TypeError? + self.assertRaises(TypeError, self.f.readinto) self.assertRaises(ValueError, self.f.readinto, bytearray(1)) - self.assertRaises(ValueError, self.f.seek) + self.assertRaises(TypeError, self.f.seek) self.assertRaises(ValueError, self.f.seek, 0) - self.assertRaises(ValueError, self.f.write) + self.assertRaises(TypeError, self.f.write) self.assertRaises(ValueError, self.f.write, b'') self.assertRaises(TypeError, self.f.writelines) self.assertRaises(ValueError, self.f.writelines, b'') @@ -150,9 +214,9 @@ class AutoFileTests(unittest.TestCase): def testOpendir(self): # Issue 3703: opening a directory should fill the errno # Windows always returns "[Errno 13]: Permission denied - # Unix calls dircheck() and returns "[Errno 21]: Is a directory" + # Unix uses fstat and returns "[Errno 21]: Is a directory" try: - _FileIO('.', 'r') + self.FileIO('.', 'r') except OSError as e: self.assertNotEqual(e.errno, 0) self.assertEqual(e.filename, ".") @@ -163,7 +227,7 @@ class AutoFileTests(unittest.TestCase): def testOpenDirFD(self): fd = os.open('.', os.O_RDONLY) with self.assertRaises(OSError) as cm: - _FileIO(fd, 'r') + self.FileIO(fd, 'r') os.close(fd) self.assertEqual(cm.exception.errno, errno.EISDIR) @@ -248,7 +312,7 @@ class AutoFileTests(unittest.TestCase): self.f.close() except OSError: pass - self.f = _FileIO(TESTFN, 'r') + self.f = self.FileIO(TESTFN, 'r') os.close(self.f.fileno()) return self.f @@ -268,23 +332,32 @@ class AutoFileTests(unittest.TestCase): a = array('b', b'x'*10) f.readinto(a) -class OtherFileTests(unittest.TestCase): +class CAutoFileTests(AutoFileTests, unittest.TestCase): + FileIO = _io.FileIO + modulename = '_io' + +class PyAutoFileTests(AutoFileTests, unittest.TestCase): + FileIO = _pyio.FileIO + modulename = '_pyio' + + +class OtherFileTests: def testAbles(self): try: - f = _FileIO(TESTFN, "w") + f = self.FileIO(TESTFN, "w") self.assertEqual(f.readable(), False) self.assertEqual(f.writable(), True) self.assertEqual(f.seekable(), True) f.close() - f = _FileIO(TESTFN, "r") + f = self.FileIO(TESTFN, "r") self.assertEqual(f.readable(), True) self.assertEqual(f.writable(), False) self.assertEqual(f.seekable(), True) f.close() - f = _FileIO(TESTFN, "a+") + f = self.FileIO(TESTFN, "a+") self.assertEqual(f.readable(), True) self.assertEqual(f.writable(), True) self.assertEqual(f.seekable(), True) @@ -293,7 +366,7 @@ class OtherFileTests(unittest.TestCase): if sys.platform != "win32": try: - f = _FileIO("/dev/tty", "a") + f = self.FileIO("/dev/tty", "a") except OSError: # When run in a cron job there just aren't any # ttys, so skip the test. This also handles other @@ -316,7 +389,7 @@ class OtherFileTests(unittest.TestCase): # check invalid mode strings for mode in ("", "aU", "wU+", "rw", "rt"): try: - f = _FileIO(TESTFN, mode) + f = self.FileIO(TESTFN, mode) except ValueError: pass else: @@ -332,7 +405,7 @@ class OtherFileTests(unittest.TestCase): ('ab+', 'ab+'), ('a+b', 'ab+'), ('r', 'rb'), ('rb', 'rb'), ('rb+', 'rb+'), ('r+b', 'rb+')]: # read modes are last so that TESTFN will exist first - with _FileIO(TESTFN, modes[0]) as f: + with self.FileIO(TESTFN, modes[0]) as f: self.assertEqual(f.mode, modes[1]) finally: if os.path.exists(TESTFN): @@ -340,7 +413,7 @@ class OtherFileTests(unittest.TestCase): def testUnicodeOpen(self): # verify repr works for unicode too - f = _FileIO(str(TESTFN), "w") + f = self.FileIO(str(TESTFN), "w") f.close() os.unlink(TESTFN) @@ -350,7 +423,7 @@ class OtherFileTests(unittest.TestCase): fn = TESTFN.encode("ascii") except UnicodeEncodeError: self.skipTest('could not encode %r to ascii' % TESTFN) - f = _FileIO(fn, "w") + f = self.FileIO(fn, "w") try: f.write(b"abc") f.close() @@ -361,28 +434,21 @@ class OtherFileTests(unittest.TestCase): def testConstructorHandlesNULChars(self): fn_with_NUL = 'foo\0bar' - self.assertRaises(TypeError, _FileIO, fn_with_NUL, 'w') - self.assertRaises(TypeError, _FileIO, bytes(fn_with_NUL, 'ascii'), 'w') + self.assertRaises(ValueError, self.FileIO, fn_with_NUL, 'w') + self.assertRaises(ValueError, self.FileIO, bytes(fn_with_NUL, 'ascii'), 'w') def testInvalidFd(self): - self.assertRaises(ValueError, _FileIO, -10) - self.assertRaises(OSError, _FileIO, make_bad_fd()) + self.assertRaises(ValueError, self.FileIO, -10) + self.assertRaises(OSError, self.FileIO, make_bad_fd()) if sys.platform == 'win32': import msvcrt self.assertRaises(OSError, msvcrt.get_osfhandle, make_bad_fd()) - @cpython_only - def testInvalidFd_overflow(self): - # Issue 15989 - import _testcapi - self.assertRaises(TypeError, _FileIO, _testcapi.INT_MAX + 1) - self.assertRaises(TypeError, _FileIO, _testcapi.INT_MIN - 1) - def testBadModeArgument(self): # verify that we get a sensible error message for bad mode argument bad_mode = "qwerty" try: - f = _FileIO(TESTFN, bad_mode) + f = self.FileIO(TESTFN, bad_mode) except ValueError as msg: if msg.args[0] != 0: s = str(msg) @@ -395,7 +461,7 @@ class OtherFileTests(unittest.TestCase): self.fail("no error for invalid mode: %s" % bad_mode) def testTruncate(self): - f = _FileIO(TESTFN, 'w') + f = self.FileIO(TESTFN, 'w') f.write(bytes(bytearray(range(10)))) self.assertEqual(f.tell(), 10) f.truncate(5) @@ -410,11 +476,11 @@ class OtherFileTests(unittest.TestCase): def bug801631(): # SF bug <http://www.python.org/sf/801631> # "file.truncate fault on windows" - f = _FileIO(TESTFN, 'w') + f = self.FileIO(TESTFN, 'w') f.write(bytes(range(11))) f.close() - f = _FileIO(TESTFN,'r+') + f = self.FileIO(TESTFN,'r+') data = f.read(5) if data != bytes(range(5)): self.fail("Read on file opened for update failed %r" % data) @@ -454,19 +520,19 @@ class OtherFileTests(unittest.TestCase): pass def testInvalidInit(self): - self.assertRaises(TypeError, _FileIO, "1", 0, 0) + self.assertRaises(TypeError, self.FileIO, "1", 0, 0) def testWarnings(self): with check_warnings(quiet=True) as w: self.assertEqual(w.warnings, []) - self.assertRaises(TypeError, _FileIO, []) + self.assertRaises(TypeError, self.FileIO, []) self.assertEqual(w.warnings, []) - self.assertRaises(ValueError, _FileIO, "/some/invalid/name", "rt") + self.assertRaises(ValueError, self.FileIO, "/some/invalid/name", "rt") self.assertEqual(w.warnings, []) def testUnclosedFDOnException(self): class MyException(Exception): pass - class MyFileIO(_FileIO): + class MyFileIO(self.FileIO): def __setattr__(self, name, value): if name == "name": raise MyException("blocked setting name") @@ -475,12 +541,28 @@ class OtherFileTests(unittest.TestCase): self.assertRaises(MyException, MyFileIO, fd) os.close(fd) # should not raise OSError(EBADF) +class COtherFileTests(OtherFileTests, unittest.TestCase): + FileIO = _io.FileIO + modulename = '_io' + + @cpython_only + def testInvalidFd_overflow(self): + # Issue 15989 + import _testcapi + self.assertRaises(TypeError, self.FileIO, _testcapi.INT_MAX + 1) + self.assertRaises(TypeError, self.FileIO, _testcapi.INT_MIN - 1) + +class PyOtherFileTests(OtherFileTests, unittest.TestCase): + FileIO = _pyio.FileIO + modulename = '_pyio' + def test_main(): # Historically, these tests have been sloppy about removing TESTFN. # So get rid of it no matter what. try: - run_unittest(AutoFileTests, OtherFileTests) + run_unittest(CAutoFileTests, PyAutoFileTests, + COtherFileTests, PyOtherFileTests) finally: if os.path.exists(TESTFN): os.unlink(TESTFN) diff --git a/Lib/test/test_finalization.py b/Lib/test/test_finalization.py index 03ac1aa..35d7913 100644 --- a/Lib/test/test_finalization.py +++ b/Lib/test/test_finalization.py @@ -515,8 +515,5 @@ class LegacyFinalizationTest(TestBase, unittest.TestCase): self.assertIs(wr(), None) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_float.py b/Lib/test/test_float.py index e87aab0..4251090 100644 --- a/Lib/test/test_float.py +++ b/Lib/test/test_float.py @@ -773,6 +773,14 @@ class RoundTestCase(unittest.TestCase): test(sfmt, NAN, ' nan') test(sfmt, -NAN, ' nan') + def test_None_ndigits(self): + for x in round(1.23), round(1.23, None), round(1.23, ndigits=None): + self.assertEqual(x, 1) + self.assertIsInstance(x, int) + for x in round(1.78), round(1.78, None), round(1.78, ndigits=None): + self.assertEqual(x, 2) + self.assertIsInstance(x, int) + # Beginning with Python 2.6 float has cross platform compatible # ways to create and represent inf and nan @@ -1299,18 +1307,5 @@ class HexFloatTestCase(unittest.TestCase): self.identical(x, fromHex(toHex(x))) -def test_main(): - support.run_unittest( - GeneralFloatCases, - FormatFunctionsTestCase, - UnknownFormatTestCase, - IEEEFormatTestCase, - FormatTestCase, - ReprTestCase, - RoundTestCase, - InfNanTest, - HexFloatTestCase, - ) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_flufl.py b/Lib/test/test_flufl.py index 5a709bc..98b5bd6 100644 --- a/Lib/test/test_flufl.py +++ b/Lib/test/test_flufl.py @@ -18,10 +18,5 @@ class FLUFLTests(unittest.TestCase): '<FLUFL test>', 'exec') -def test_main(): - from test.support import run_unittest - run_unittest(FLUFLTests) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_fnmatch.py b/Lib/test/test_fnmatch.py index 482835d..fa37f90 100644 --- a/Lib/test/test_fnmatch.py +++ b/Lib/test/test_fnmatch.py @@ -1,6 +1,5 @@ """Test cases for the fnmatch module.""" -from test import support import unittest from fnmatch import fnmatch, fnmatchcase, translate, filter @@ -79,11 +78,5 @@ class FilterTestCase(unittest.TestCase): self.assertEqual(filter(['a', 'b'], 'a'), ['a']) -def test_main(): - support.run_unittest(FnmatchTestCase, - TranslateTestCase, - FilterTestCase) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_fork1.py b/Lib/test/test_fork1.py index e0626df..eeba306 100644 --- a/Lib/test/test_fork1.py +++ b/Lib/test/test_fork1.py @@ -8,7 +8,7 @@ import sys import time from test.fork_wait import ForkWait -from test.support import (run_unittest, reap_children, get_attribute, +from test.support import (reap_children, get_attribute, import_module, verbose) threading = import_module('threading') @@ -18,13 +18,14 @@ get_attribute(os, 'fork') class ForkTest(ForkWait): def wait_impl(self, cpid): - for i in range(10): + deadline = time.monotonic() + 10.0 + while time.monotonic() <= deadline: # waitpid() shouldn't hang, but some of the buildbots seem to hang # in the forking tests. This is an attempt to fix the problem. spid, status = os.waitpid(cpid, os.WNOHANG) if spid == cpid: break - time.sleep(1.0) + time.sleep(0.1) self.assertEqual(spid, cpid) self.assertEqual(status, 0, "cause = %d, exit = %d" % (status&0xff, status>>8)) @@ -103,9 +104,8 @@ class ForkTest(ForkWait): fork_with_import_lock(level) -def test_main(): - run_unittest(ForkTest) +def tearDownModule(): reap_children() if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index fc71e48..9b13632 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -9,7 +9,7 @@ maxsize = support.MAX_Py_ssize_t # test string formatting operator (I am not sure if this is being tested # elsewhere but, surely, some of the given cases are *not* tested because # they crash python) -# test on unicode strings as well +# test on bytes object as well def testformat(formatstr, args, output=None, limit=None, overflowok=False): if verbose: @@ -46,191 +46,209 @@ def testformat(formatstr, args, output=None, limit=None, overflowok=False): if verbose: print('yes') +def testcommon(formatstr, args, output=None, limit=None, overflowok=False): + # if formatstr is a str, test str, bytes, and bytearray; + # otherwise, test bytes and bytearry + if isinstance(formatstr, str): + testformat(formatstr, args, output, limit, overflowok) + b_format = formatstr.encode('ascii') + else: + b_format = formatstr + ba_format = bytearray(b_format) + b_args = [] + if not isinstance(args, tuple): + args = (args, ) + b_args = tuple(args) + if output is None: + b_output = ba_output = None + else: + if isinstance(output, str): + b_output = output.encode('ascii') + else: + b_output = output + ba_output = bytearray(b_output) + testformat(b_format, b_args, b_output, limit, overflowok) + testformat(ba_format, b_args, ba_output, limit, overflowok) + class FormatTest(unittest.TestCase): - def test_format(self): - testformat("%.1d", (1,), "1") - testformat("%.*d", (sys.maxsize,1), overflowok=True) # expect overflow - testformat("%.100d", (1,), '00000000000000000000000000000000000000' + + def test_common_format(self): + # test the format identifiers that work the same across + # str, bytes, and bytearrays (integer, float, oct, hex) + testcommon("%.1d", (1,), "1") + testcommon("%.*d", (sys.maxsize,1), overflowok=True) # expect overflow + testcommon("%.100d", (1,), '00000000000000000000000000000000000000' '000000000000000000000000000000000000000000000000000000' '00000001', overflowok=True) - testformat("%#.117x", (1,), '0x00000000000000000000000000000000000' + testcommon("%#.117x", (1,), '0x00000000000000000000000000000000000' '000000000000000000000000000000000000000000000000000000' '0000000000000000000000000001', overflowok=True) - testformat("%#.118x", (1,), '0x00000000000000000000000000000000000' + testcommon("%#.118x", (1,), '0x00000000000000000000000000000000000' '000000000000000000000000000000000000000000000000000000' '00000000000000000000000000001', overflowok=True) - testformat("%f", (1.0,), "1.000000") + testcommon("%f", (1.0,), "1.000000") # these are trying to test the limits of the internal magic-number-length # formatting buffer, if that number changes then these tests are less # effective - testformat("%#.*g", (109, -1.e+49/3.)) - testformat("%#.*g", (110, -1.e+49/3.)) - testformat("%#.*g", (110, -1.e+100/3.)) + testcommon("%#.*g", (109, -1.e+49/3.)) + testcommon("%#.*g", (110, -1.e+49/3.)) + testcommon("%#.*g", (110, -1.e+100/3.)) # test some ridiculously large precision, expect overflow - testformat('%12.*f', (123456, 1.0)) + testcommon('%12.*f', (123456, 1.0)) # check for internal overflow validation on length of precision # these tests should no longer cause overflow in Python # 2.7/3.1 and later. - testformat("%#.*g", (110, -1.e+100/3.)) - testformat("%#.*G", (110, -1.e+100/3.)) - testformat("%#.*f", (110, -1.e+100/3.)) - testformat("%#.*F", (110, -1.e+100/3.)) + testcommon("%#.*g", (110, -1.e+100/3.)) + testcommon("%#.*G", (110, -1.e+100/3.)) + testcommon("%#.*f", (110, -1.e+100/3.)) + testcommon("%#.*F", (110, -1.e+100/3.)) # Formatting of integers. Overflow is not ok - testformat("%x", 10, "a") - testformat("%x", 100000000000, "174876e800") - testformat("%o", 10, "12") - testformat("%o", 100000000000, "1351035564000") - testformat("%d", 10, "10") - testformat("%d", 100000000000, "100000000000") + testcommon("%x", 10, "a") + testcommon("%x", 100000000000, "174876e800") + testcommon("%o", 10, "12") + testcommon("%o", 100000000000, "1351035564000") + testcommon("%d", 10, "10") + testcommon("%d", 100000000000, "100000000000") big = 123456789012345678901234567890 - testformat("%d", big, "123456789012345678901234567890") - testformat("%d", -big, "-123456789012345678901234567890") - testformat("%5d", -big, "-123456789012345678901234567890") - testformat("%31d", -big, "-123456789012345678901234567890") - testformat("%32d", -big, " -123456789012345678901234567890") - testformat("%-32d", -big, "-123456789012345678901234567890 ") - testformat("%032d", -big, "-0123456789012345678901234567890") - testformat("%-032d", -big, "-123456789012345678901234567890 ") - testformat("%034d", -big, "-000123456789012345678901234567890") - testformat("%034d", big, "0000123456789012345678901234567890") - testformat("%0+34d", big, "+000123456789012345678901234567890") - testformat("%+34d", big, " +123456789012345678901234567890") - testformat("%34d", big, " 123456789012345678901234567890") - testformat("%.2d", big, "123456789012345678901234567890") - testformat("%.30d", big, "123456789012345678901234567890") - testformat("%.31d", big, "0123456789012345678901234567890") - testformat("%32.31d", big, " 0123456789012345678901234567890") - testformat("%d", float(big), "123456________________________", 6) + testcommon("%d", big, "123456789012345678901234567890") + testcommon("%d", -big, "-123456789012345678901234567890") + testcommon("%5d", -big, "-123456789012345678901234567890") + testcommon("%31d", -big, "-123456789012345678901234567890") + testcommon("%32d", -big, " -123456789012345678901234567890") + testcommon("%-32d", -big, "-123456789012345678901234567890 ") + testcommon("%032d", -big, "-0123456789012345678901234567890") + testcommon("%-032d", -big, "-123456789012345678901234567890 ") + testcommon("%034d", -big, "-000123456789012345678901234567890") + testcommon("%034d", big, "0000123456789012345678901234567890") + testcommon("%0+34d", big, "+000123456789012345678901234567890") + testcommon("%+34d", big, " +123456789012345678901234567890") + testcommon("%34d", big, " 123456789012345678901234567890") + testcommon("%.2d", big, "123456789012345678901234567890") + testcommon("%.30d", big, "123456789012345678901234567890") + testcommon("%.31d", big, "0123456789012345678901234567890") + testcommon("%32.31d", big, " 0123456789012345678901234567890") + testcommon("%d", float(big), "123456________________________", 6) big = 0x1234567890abcdef12345 # 21 hex digits - testformat("%x", big, "1234567890abcdef12345") - testformat("%x", -big, "-1234567890abcdef12345") - testformat("%5x", -big, "-1234567890abcdef12345") - testformat("%22x", -big, "-1234567890abcdef12345") - testformat("%23x", -big, " -1234567890abcdef12345") - testformat("%-23x", -big, "-1234567890abcdef12345 ") - testformat("%023x", -big, "-01234567890abcdef12345") - testformat("%-023x", -big, "-1234567890abcdef12345 ") - testformat("%025x", -big, "-0001234567890abcdef12345") - testformat("%025x", big, "00001234567890abcdef12345") - testformat("%0+25x", big, "+0001234567890abcdef12345") - testformat("%+25x", big, " +1234567890abcdef12345") - testformat("%25x", big, " 1234567890abcdef12345") - testformat("%.2x", big, "1234567890abcdef12345") - testformat("%.21x", big, "1234567890abcdef12345") - testformat("%.22x", big, "01234567890abcdef12345") - testformat("%23.22x", big, " 01234567890abcdef12345") - testformat("%-23.22x", big, "01234567890abcdef12345 ") - testformat("%X", big, "1234567890ABCDEF12345") - testformat("%#X", big, "0X1234567890ABCDEF12345") - testformat("%#x", big, "0x1234567890abcdef12345") - testformat("%#x", -big, "-0x1234567890abcdef12345") - testformat("%#.23x", -big, "-0x001234567890abcdef12345") - testformat("%#+.23x", big, "+0x001234567890abcdef12345") - testformat("%# .23x", big, " 0x001234567890abcdef12345") - testformat("%#+.23X", big, "+0X001234567890ABCDEF12345") - testformat("%#-+.23X", big, "+0X001234567890ABCDEF12345") - testformat("%#-+26.23X", big, "+0X001234567890ABCDEF12345") - testformat("%#-+27.23X", big, "+0X001234567890ABCDEF12345 ") - testformat("%#+27.23X", big, " +0X001234567890ABCDEF12345") + testcommon("%x", big, "1234567890abcdef12345") + testcommon("%x", -big, "-1234567890abcdef12345") + testcommon("%5x", -big, "-1234567890abcdef12345") + testcommon("%22x", -big, "-1234567890abcdef12345") + testcommon("%23x", -big, " -1234567890abcdef12345") + testcommon("%-23x", -big, "-1234567890abcdef12345 ") + testcommon("%023x", -big, "-01234567890abcdef12345") + testcommon("%-023x", -big, "-1234567890abcdef12345 ") + testcommon("%025x", -big, "-0001234567890abcdef12345") + testcommon("%025x", big, "00001234567890abcdef12345") + testcommon("%0+25x", big, "+0001234567890abcdef12345") + testcommon("%+25x", big, " +1234567890abcdef12345") + testcommon("%25x", big, " 1234567890abcdef12345") + testcommon("%.2x", big, "1234567890abcdef12345") + testcommon("%.21x", big, "1234567890abcdef12345") + testcommon("%.22x", big, "01234567890abcdef12345") + testcommon("%23.22x", big, " 01234567890abcdef12345") + testcommon("%-23.22x", big, "01234567890abcdef12345 ") + testcommon("%X", big, "1234567890ABCDEF12345") + testcommon("%#X", big, "0X1234567890ABCDEF12345") + testcommon("%#x", big, "0x1234567890abcdef12345") + testcommon("%#x", -big, "-0x1234567890abcdef12345") + testcommon("%#.23x", -big, "-0x001234567890abcdef12345") + testcommon("%#+.23x", big, "+0x001234567890abcdef12345") + testcommon("%# .23x", big, " 0x001234567890abcdef12345") + testcommon("%#+.23X", big, "+0X001234567890ABCDEF12345") + testcommon("%#-+.23X", big, "+0X001234567890ABCDEF12345") + testcommon("%#-+26.23X", big, "+0X001234567890ABCDEF12345") + testcommon("%#-+27.23X", big, "+0X001234567890ABCDEF12345 ") + testcommon("%#+27.23X", big, " +0X001234567890ABCDEF12345") # next one gets two leading zeroes from precision, and another from the # 0 flag and the width - testformat("%#+027.23X", big, "+0X0001234567890ABCDEF12345") + testcommon("%#+027.23X", big, "+0X0001234567890ABCDEF12345") # same, except no 0 flag - testformat("%#+27.23X", big, " +0X001234567890ABCDEF12345") - with self.assertWarns(DeprecationWarning): - testformat("%x", float(big), "123456_______________", 6) + testcommon("%#+27.23X", big, " +0X001234567890ABCDEF12345") big = 0o12345670123456701234567012345670 # 32 octal digits - testformat("%o", big, "12345670123456701234567012345670") - testformat("%o", -big, "-12345670123456701234567012345670") - testformat("%5o", -big, "-12345670123456701234567012345670") - testformat("%33o", -big, "-12345670123456701234567012345670") - testformat("%34o", -big, " -12345670123456701234567012345670") - testformat("%-34o", -big, "-12345670123456701234567012345670 ") - testformat("%034o", -big, "-012345670123456701234567012345670") - testformat("%-034o", -big, "-12345670123456701234567012345670 ") - testformat("%036o", -big, "-00012345670123456701234567012345670") - testformat("%036o", big, "000012345670123456701234567012345670") - testformat("%0+36o", big, "+00012345670123456701234567012345670") - testformat("%+36o", big, " +12345670123456701234567012345670") - testformat("%36o", big, " 12345670123456701234567012345670") - testformat("%.2o", big, "12345670123456701234567012345670") - testformat("%.32o", big, "12345670123456701234567012345670") - testformat("%.33o", big, "012345670123456701234567012345670") - testformat("%34.33o", big, " 012345670123456701234567012345670") - testformat("%-34.33o", big, "012345670123456701234567012345670 ") - testformat("%o", big, "12345670123456701234567012345670") - testformat("%#o", big, "0o12345670123456701234567012345670") - testformat("%#o", -big, "-0o12345670123456701234567012345670") - testformat("%#.34o", -big, "-0o0012345670123456701234567012345670") - testformat("%#+.34o", big, "+0o0012345670123456701234567012345670") - testformat("%# .34o", big, " 0o0012345670123456701234567012345670") - testformat("%#+.34o", big, "+0o0012345670123456701234567012345670") - testformat("%#-+.34o", big, "+0o0012345670123456701234567012345670") - testformat("%#-+37.34o", big, "+0o0012345670123456701234567012345670") - testformat("%#+37.34o", big, "+0o0012345670123456701234567012345670") + testcommon("%o", big, "12345670123456701234567012345670") + testcommon("%o", -big, "-12345670123456701234567012345670") + testcommon("%5o", -big, "-12345670123456701234567012345670") + testcommon("%33o", -big, "-12345670123456701234567012345670") + testcommon("%34o", -big, " -12345670123456701234567012345670") + testcommon("%-34o", -big, "-12345670123456701234567012345670 ") + testcommon("%034o", -big, "-012345670123456701234567012345670") + testcommon("%-034o", -big, "-12345670123456701234567012345670 ") + testcommon("%036o", -big, "-00012345670123456701234567012345670") + testcommon("%036o", big, "000012345670123456701234567012345670") + testcommon("%0+36o", big, "+00012345670123456701234567012345670") + testcommon("%+36o", big, " +12345670123456701234567012345670") + testcommon("%36o", big, " 12345670123456701234567012345670") + testcommon("%.2o", big, "12345670123456701234567012345670") + testcommon("%.32o", big, "12345670123456701234567012345670") + testcommon("%.33o", big, "012345670123456701234567012345670") + testcommon("%34.33o", big, " 012345670123456701234567012345670") + testcommon("%-34.33o", big, "012345670123456701234567012345670 ") + testcommon("%o", big, "12345670123456701234567012345670") + testcommon("%#o", big, "0o12345670123456701234567012345670") + testcommon("%#o", -big, "-0o12345670123456701234567012345670") + testcommon("%#.34o", -big, "-0o0012345670123456701234567012345670") + testcommon("%#+.34o", big, "+0o0012345670123456701234567012345670") + testcommon("%# .34o", big, " 0o0012345670123456701234567012345670") + testcommon("%#+.34o", big, "+0o0012345670123456701234567012345670") + testcommon("%#-+.34o", big, "+0o0012345670123456701234567012345670") + testcommon("%#-+37.34o", big, "+0o0012345670123456701234567012345670") + testcommon("%#+37.34o", big, "+0o0012345670123456701234567012345670") # next one gets one leading zero from precision - testformat("%.33o", big, "012345670123456701234567012345670") + testcommon("%.33o", big, "012345670123456701234567012345670") # base marker shouldn't change that, since "0" is redundant - testformat("%#.33o", big, "0o012345670123456701234567012345670") + testcommon("%#.33o", big, "0o012345670123456701234567012345670") # but reduce precision, and base marker should add a zero - testformat("%#.32o", big, "0o12345670123456701234567012345670") + testcommon("%#.32o", big, "0o12345670123456701234567012345670") # one leading zero from precision, and another from "0" flag & width - testformat("%034.33o", big, "0012345670123456701234567012345670") + testcommon("%034.33o", big, "0012345670123456701234567012345670") # base marker shouldn't change that - testformat("%0#34.33o", big, "0o012345670123456701234567012345670") - with self.assertWarns(DeprecationWarning): - testformat("%o", float(big), "123456__________________________", 6) + testcommon("%0#34.33o", big, "0o012345670123456701234567012345670") # Some small ints, in both Python int and flavors). - testformat("%d", 42, "42") - testformat("%d", -42, "-42") - testformat("%d", 42, "42") - testformat("%d", -42, "-42") - testformat("%d", 42.0, "42") - testformat("%#x", 1, "0x1") - testformat("%#x", 1, "0x1") - testformat("%#X", 1, "0X1") - testformat("%#X", 1, "0X1") - with self.assertWarns(DeprecationWarning): - testformat("%#x", 1.0, "0x1") - testformat("%#o", 1, "0o1") - testformat("%#o", 1, "0o1") - testformat("%#o", 0, "0o0") - testformat("%#o", 0, "0o0") - testformat("%o", 0, "0") - testformat("%o", 0, "0") - testformat("%d", 0, "0") - testformat("%d", 0, "0") - testformat("%#x", 0, "0x0") - testformat("%#x", 0, "0x0") - testformat("%#X", 0, "0X0") - testformat("%#X", 0, "0X0") - testformat("%x", 0x42, "42") - testformat("%x", -0x42, "-42") - testformat("%x", 0x42, "42") - testformat("%x", -0x42, "-42") - with self.assertWarns(DeprecationWarning): - testformat("%x", float(0x42), "42") - testformat("%o", 0o42, "42") - testformat("%o", -0o42, "-42") - testformat("%o", 0o42, "42") - testformat("%o", -0o42, "-42") - with self.assertWarns(DeprecationWarning): - testformat("%o", float(0o42), "42") + testcommon("%d", 42, "42") + testcommon("%d", -42, "-42") + testcommon("%d", 42, "42") + testcommon("%d", -42, "-42") + testcommon("%d", 42.0, "42") + testcommon("%#x", 1, "0x1") + testcommon("%#x", 1, "0x1") + testcommon("%#X", 1, "0X1") + testcommon("%#X", 1, "0X1") + testcommon("%#o", 1, "0o1") + testcommon("%#o", 1, "0o1") + testcommon("%#o", 0, "0o0") + testcommon("%#o", 0, "0o0") + testcommon("%o", 0, "0") + testcommon("%o", 0, "0") + testcommon("%d", 0, "0") + testcommon("%d", 0, "0") + testcommon("%#x", 0, "0x0") + testcommon("%#x", 0, "0x0") + testcommon("%#X", 0, "0X0") + testcommon("%#X", 0, "0X0") + testcommon("%x", 0x42, "42") + testcommon("%x", -0x42, "-42") + testcommon("%x", 0x42, "42") + testcommon("%x", -0x42, "-42") + testcommon("%o", 0o42, "42") + testcommon("%o", -0o42, "-42") + testcommon("%o", 0o42, "42") + testcommon("%o", -0o42, "-42") + # alternate float formatting + testcommon('%g', 1.1, '1.1') + testcommon('%#g', 1.1, '1.10000') + + def test_str_format(self): testformat("%r", "\u0378", "'\\u0378'") # non printable testformat("%a", "\u0378", "'\\u0378'") # non printable testformat("%r", "\u0374", "'\u0374'") # printable testformat("%a", "\u0374", "'\\u0374'") # printable - # alternate float formatting - testformat('%g', 1.1, '1.1') - testformat('%#g', 1.1, '1.10000') - - # Test exception for unknown format characters + # Test exception for unknown format characters, etc. if verbose: print('Testing exceptions') def test_exc(formatstr, args, exception, excmsg): @@ -254,11 +272,108 @@ class FormatTest(unittest.TestCase): #test_exc(unicode('abc %\u3000','raw-unicode-escape'), 1, ValueError, # "unsupported format character '?' (0x3000) at index 5") test_exc('%d', '1', TypeError, "%d format: a number is required, not str") + test_exc('%x', '1', TypeError, "%x format: an integer is required, not str") + test_exc('%x', 3.14, TypeError, "%x format: an integer is required, not float") test_exc('%g', '1', TypeError, "a float is required") test_exc('no format', '1', TypeError, "not all arguments converted during string formatting") - test_exc('no format', '1', TypeError, - "not all arguments converted during string formatting") + test_exc('%c', -1, OverflowError, "%c arg not in range(0x110000)") + test_exc('%c', sys.maxunicode+1, OverflowError, + "%c arg not in range(0x110000)") + #test_exc('%c', 2**128, OverflowError, "%c arg not in range(0x110000)") + test_exc('%c', 3.14, TypeError, "%c requires int or char") + test_exc('%c', 'ab', TypeError, "%c requires int or char") + test_exc('%c', b'x', TypeError, "%c requires int or char") + + if maxsize == 2**31-1: + # crashes 2.2.1 and earlier: + try: + "%*d"%(maxsize, -127) + except MemoryError: + pass + else: + raise TestFailed('"%*d"%(maxsize, -127) should fail') + + def test_bytes_and_bytearray_format(self): + # %c will insert a single byte, either from an int in range(256), or + # from a bytes argument of length 1, not from a str. + testcommon(b"%c", 7, b"\x07") + testcommon(b"%c", b"Z", b"Z") + testcommon(b"%c", bytearray(b"Z"), b"Z") + # %b will insert a series of bytes, either from a type that supports + # the Py_buffer protocol, or something that has a __bytes__ method + class FakeBytes(object): + def __bytes__(self): + return b'123' + fb = FakeBytes() + testcommon(b"%b", b"abc", b"abc") + testcommon(b"%b", bytearray(b"def"), b"def") + testcommon(b"%b", fb, b"123") + # # %s is an alias for %b -- should only be used for Py2/3 code + testcommon(b"%s", b"abc", b"abc") + testcommon(b"%s", bytearray(b"def"), b"def") + testcommon(b"%s", fb, b"123") + # %a will give the equivalent of + # repr(some_obj).encode('ascii', 'backslashreplace') + testcommon(b"%a", 3.14, b"3.14") + testcommon(b"%a", b"ghi", b"b'ghi'") + testcommon(b"%a", "jkl", b"'jkl'") + testcommon(b"%a", "\u0544", b"'\\u0544'") + # %r is an alias for %a + testcommon(b"%r", 3.14, b"3.14") + testcommon(b"%r", b"ghi", b"b'ghi'") + testcommon(b"%r", "jkl", b"'jkl'") + testcommon(b"%r", "\u0544", b"'\\u0544'") + + # Test exception for unknown format characters, etc. + if verbose: + print('Testing exceptions') + def test_exc(formatstr, args, exception, excmsg): + try: + testformat(formatstr, args) + except exception as exc: + if str(exc) == excmsg: + if verbose: + print("yes") + else: + if verbose: print('no') + print('Unexpected ', exception, ':', repr(str(exc))) + except: + if verbose: print('no') + print('Unexpected exception') + raise + else: + raise TestFailed('did not get expected exception: %s' % excmsg) + test_exc(b'%d', '1', TypeError, + "%d format: a number is required, not str") + test_exc(b'%d', b'1', TypeError, + "%d format: a number is required, not bytes") + test_exc(b'%x', 3.14, TypeError, + "%x format: an integer is required, not float") + test_exc(b'%g', '1', TypeError, "float argument required, not str") + test_exc(b'%g', b'1', TypeError, "float argument required, not bytes") + test_exc(b'no format', 7, TypeError, + "not all arguments converted during bytes formatting") + test_exc(b'no format', b'1', TypeError, + "not all arguments converted during bytes formatting") + test_exc(b'no format', bytearray(b'1'), TypeError, + "not all arguments converted during bytes formatting") + test_exc(b"%c", -1, OverflowError, + "%c arg not in range(256)") + test_exc(b"%c", 256, OverflowError, + "%c arg not in range(256)") + test_exc(b"%c", 2**128, OverflowError, + "%c arg not in range(256)") + test_exc(b"%c", b"Za", TypeError, + "%c requires an integer in range(256) or a single byte") + test_exc(b"%c", "Y", TypeError, + "%c requires an integer in range(256) or a single byte") + test_exc(b"%c", 3.14, TypeError, + "%c requires an integer in range(256) or a single byte") + test_exc(b"%b", "Xc", TypeError, + "%b requires bytes, or an object that implements __bytes__, not 'str'") + test_exc(b"%s", "Wd", TypeError, + "%b requires bytes, or an object that implements __bytes__, not 'str'") if maxsize == 2**31-1: # crashes 2.2.1 and earlier: diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py index 3336532..1699852 100644 --- a/Lib/test/test_fractions.py +++ b/Lib/test/test_fractions.py @@ -1,13 +1,14 @@ """Tests for Lib/fractions.py.""" from decimal import Decimal -from test.support import run_unittest, requires_IEEE_754 +from test.support import requires_IEEE_754 import math import numbers import operator import fractions import sys import unittest +import warnings from copy import copy, deepcopy from pickle import dumps, loads F = fractions.Fraction @@ -49,7 +50,7 @@ class DummyRational(object): """Test comparison of Fraction with a naive rational implementation.""" def __init__(self, num, den): - g = gcd(num, den) + g = math.gcd(num, den) self.num = num // g self.den = den // g @@ -83,16 +84,26 @@ class DummyFraction(fractions.Fraction): class GcdTest(unittest.TestCase): def testMisc(self): - self.assertEqual(0, gcd(0, 0)) - self.assertEqual(1, gcd(1, 0)) - self.assertEqual(-1, gcd(-1, 0)) - self.assertEqual(1, gcd(0, 1)) - self.assertEqual(-1, gcd(0, -1)) - self.assertEqual(1, gcd(7, 1)) - self.assertEqual(-1, gcd(7, -1)) - self.assertEqual(1, gcd(-23, 15)) - self.assertEqual(12, gcd(120, 84)) - self.assertEqual(-12, gcd(84, -120)) + # fractions.gcd() is deprecated + with self.assertWarnsRegex(DeprecationWarning, r'fractions\.gcd'): + gcd(1, 1) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', r'fractions\.gcd', + DeprecationWarning) + self.assertEqual(0, gcd(0, 0)) + self.assertEqual(1, gcd(1, 0)) + self.assertEqual(-1, gcd(-1, 0)) + self.assertEqual(1, gcd(0, 1)) + self.assertEqual(-1, gcd(0, -1)) + self.assertEqual(1, gcd(7, 1)) + self.assertEqual(-1, gcd(7, -1)) + self.assertEqual(1, gcd(-23, 15)) + self.assertEqual(12, gcd(120, 84)) + self.assertEqual(-12, gcd(84, -120)) + self.assertEqual(gcd(120.0, 84), 12.0) + self.assertEqual(gcd(120, 84.0), 12.0) + self.assertEqual(gcd(F(120), F(84)), F(12)) + self.assertEqual(gcd(F(120, 77), F(84, 55)), F(12, 385)) def _components(r): @@ -330,7 +341,6 @@ class FractionTest(unittest.TestCase): self.assertTypedEquals(F(-2, 10), round(F(-15, 100), 1)) self.assertTypedEquals(F(-2, 10), round(F(-25, 100), 1)) - def testArithmetic(self): self.assertEqual(F(1, 2), F(1, 10) + F(2, 5)) self.assertEqual(F(-3, 10), F(1, 10) - F(2, 5)) @@ -402,6 +412,8 @@ class FractionTest(unittest.TestCase): self.assertTypedEquals(2.0 , 4 ** F(1, 2)) self.assertTypedEquals(0.25, 2.0 ** F(-2, 1)) self.assertTypedEquals(1.0 + 0j, (1.0 + 0j) ** F(1, 10)) + self.assertRaises(ZeroDivisionError, operator.pow, + F(0, 1), -2) def testMixingWithDecimal(self): # Decimal refuses mixed arithmetic (but not mixed comparisons) @@ -605,8 +617,5 @@ class FractionTest(unittest.TestCase): r = F(13, 7) self.assertRaises(AttributeError, setattr, r, 'a', 10) -def test_main(): - run_unittest(FractionTest, GcdTest) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_frame.py b/Lib/test/test_frame.py index c402ec3..189fca9 100644 --- a/Lib/test/test_frame.py +++ b/Lib/test/test_frame.py @@ -161,8 +161,5 @@ class FrameLocalsTest(unittest.TestCase): self.assertEqual(inner.f_locals, {}) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py index d3be7d6..aef66da 100644 --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -73,7 +73,7 @@ class DummyDTPHandler(asynchat.async_chat): super(DummyDTPHandler, self).push(what.encode('ascii')) def handle_error(self): - raise + raise Exception class DummyFTPHandler(asynchat.async_chat): @@ -118,7 +118,7 @@ class DummyFTPHandler(asynchat.async_chat): self.push('550 command "%s" not understood.' %cmd) def handle_error(self): - raise + raise Exception def push(self, data): asynchat.async_chat.push(self, data.encode('ascii') + b'\r\n') @@ -134,7 +134,7 @@ class DummyFTPHandler(asynchat.async_chat): def cmd_pasv(self, arg): with socket.socket() as sock: sock.bind((self.socket.getsockname()[0], 0)) - sock.listen(5) + sock.listen() sock.settimeout(TIMEOUT) ip, port = sock.getsockname()[:2] ip = ip.replace('.', ','); p1 = port / 256; p2 = port % 256 @@ -152,7 +152,7 @@ class DummyFTPHandler(asynchat.async_chat): def cmd_epsv(self, arg): with socket.socket(socket.AF_INET6) as sock: sock.bind((self.socket.getsockname()[0], 0)) - sock.listen(5) + sock.listen() sock.settimeout(TIMEOUT) port = sock.getsockname()[1] self.push('229 entering extended passive mode (|||%d|)' %port) @@ -296,7 +296,7 @@ class DummyFTPServer(asyncore.dispatcher, threading.Thread): return 0 def handle_error(self): - raise + raise Exception if ssl is not None: @@ -394,7 +394,7 @@ if ssl is not None: raise def handle_error(self): - raise + raise Exception def close(self): if (isinstance(self.socket, ssl.SSLSocket) and @@ -670,7 +670,7 @@ class TestFTPClass(TestCase): self.assertRaises(StopIteration, next, self.client.mlsd()) set_data('') for x in self.client.mlsd(): - self.fail("unexpected data %s" % data) + self.fail("unexpected data %s" % x) def test_makeport(self): with self.client.makeport(): @@ -979,7 +979,7 @@ class TestTimeouts(TestCase): # 1) when the connection is ready to be accepted. # 2) when it is safe for the caller to close the connection # 3) when we have closed the socket - self.sock.listen(5) + self.sock.listen() # (1) Signal the caller that we are ready to accept the connection. self.evt.set() try: @@ -1049,19 +1049,8 @@ class TestTimeouts(TestCase): ftp.close() -class TestNetrcDeprecation(TestCase): - - def test_deprecation(self): - with support.temp_cwd(), support.EnvironmentVarGuard() as env: - env['HOME'] = os.getcwd() - open('.netrc', 'w').close() - with self.assertWarns(DeprecationWarning): - ftplib.Netrc() - - - def test_main(): - tests = [TestFTPClass, TestTimeouts, TestNetrcDeprecation, + tests = [TestFTPClass, TestTimeouts, TestIPv6Environment, TestTLS_FTPClassMixin, TestTLS_FTPClass] diff --git a/Lib/test/test_funcattrs.py b/Lib/test/test_funcattrs.py index 5094f7b..8f481bb 100644 --- a/Lib/test/test_funcattrs.py +++ b/Lib/test/test_funcattrs.py @@ -1,4 +1,3 @@ -from test import support import types import unittest @@ -374,12 +373,5 @@ class BuiltinFunctionPropertiesTest(unittest.TestCase): self.assertEqual({'foo': 'bar'}.pop.__qualname__, 'dict.pop') -def test_main(): - support.run_unittest(FunctionPropertiesTest, InstancemethodAttrTest, - ArbitraryFunctionAttrTest, FunctionDictsTest, - FunctionDocstringTest, CellTest, - StaticMethodAttrsTest, - BuiltinFunctionPropertiesTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index c0d24d8c..7ecf877 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -7,6 +7,10 @@ import sys from test import support import unittest from weakref import proxy +try: + import threading +except ImportError: + threading = None import functools @@ -133,6 +137,25 @@ class TestPartial: join = self.partial(''.join) self.assertEqual(join(data), '0123456789') + def test_nested_optimization(self): + partial = self.partial + inner = partial(signature, 'asdf') + nested = partial(inner, bar=True) + flat = partial(signature, 'asdf', bar=True) + self.assertEqual(signature(nested), signature(flat)) + + def test_nested_partial_with_attribute(self): + # see issue 25137 + partial = self.partial + + def foo(bar): + return bar + + p = partial(foo, 'first') + p2 = partial(p, 'second') + p2.new_attr = 'spam' + self.assertEqual(p2.new_attr, 'spam') + @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialC(TestPartial, unittest.TestCase): @@ -224,6 +247,9 @@ class TestPartialCSubclass(TestPartialC): if c_functools: partial = PartialSubclass + # partial subclasses are not optimized for nested calls + test_nested_optimization = None + class TestPartialMethod(unittest.TestCase): @@ -884,12 +910,30 @@ class TestTotalOrdering(unittest.TestCase): with self.assertRaises(TypeError): a <= b -class TestLRU(unittest.TestCase): + def test_pickle(self): + for proto in range(4, pickle.HIGHEST_PROTOCOL + 1): + for name in '__lt__', '__gt__', '__le__', '__ge__': + with self.subTest(method=name, proto=proto): + method = getattr(Orderable_LT, name) + method_copy = pickle.loads(pickle.dumps(method, proto)) + self.assertIs(method_copy, method) + +@functools.total_ordering +class Orderable_LT: + def __init__(self, value): + self.value = value + def __lt__(self, other): + return self.value < other.value + def __eq__(self, other): + return self.value == other.value + + +class TestLRU: def test_lru(self): def orig(x, y): return 3 * x + y - f = functools.lru_cache(maxsize=20)(orig) + f = self.module.lru_cache(maxsize=20)(orig) hits, misses, maxsize, currsize = f.cache_info() self.assertEqual(maxsize, 20) self.assertEqual(currsize, 0) @@ -927,7 +971,7 @@ class TestLRU(unittest.TestCase): self.assertEqual(currsize, 1) # test size zero (which means "never-cache") - @functools.lru_cache(0) + @self.module.lru_cache(0) def f(): nonlocal f_cnt f_cnt += 1 @@ -943,7 +987,7 @@ class TestLRU(unittest.TestCase): self.assertEqual(currsize, 0) # test size one - @functools.lru_cache(1) + @self.module.lru_cache(1) def f(): nonlocal f_cnt f_cnt += 1 @@ -959,7 +1003,7 @@ class TestLRU(unittest.TestCase): self.assertEqual(currsize, 1) # test size two - @functools.lru_cache(2) + @self.module.lru_cache(2) def f(x): nonlocal f_cnt f_cnt += 1 @@ -976,7 +1020,7 @@ class TestLRU(unittest.TestCase): self.assertEqual(currsize, 2) def test_lru_with_maxsize_none(self): - @functools.lru_cache(maxsize=None) + @self.module.lru_cache(maxsize=None) def fib(n): if n < 2: return n @@ -984,17 +1028,26 @@ class TestLRU(unittest.TestCase): self.assertEqual([fib(n) for n in range(16)], [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) self.assertEqual(fib.cache_info(), - functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) + self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) fib.cache_clear() self.assertEqual(fib.cache_info(), - functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + + def test_lru_with_maxsize_negative(self): + @self.module.lru_cache(maxsize=-10) + def eq(n): + return n + for i in (0, 1): + self.assertEqual([eq(n) for n in range(150)], list(range(150))) + self.assertEqual(eq.cache_info(), + self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1)) def test_lru_with_exceptions(self): # Verify that user_function exceptions get passed through without # creating a hard-to-read chained exception. # http://bugs.python.org/issue13177 for maxsize in (None, 128): - @functools.lru_cache(maxsize) + @self.module.lru_cache(maxsize) def func(i): return 'abc'[i] self.assertEqual(func(0), 'a') @@ -1007,7 +1060,7 @@ class TestLRU(unittest.TestCase): def test_lru_with_types(self): for maxsize in (None, 128): - @functools.lru_cache(maxsize=maxsize, typed=True) + @self.module.lru_cache(maxsize=maxsize, typed=True) def square(x): return x * x self.assertEqual(square(3), 9) @@ -1022,7 +1075,7 @@ class TestLRU(unittest.TestCase): self.assertEqual(square.cache_info().misses, 4) def test_lru_with_keyword_args(self): - @functools.lru_cache() + @self.module.lru_cache() def fib(n): if n < 2: return n @@ -1032,13 +1085,13 @@ class TestLRU(unittest.TestCase): [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] ) self.assertEqual(fib.cache_info(), - functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) + self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) fib.cache_clear() self.assertEqual(fib.cache_info(), - functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) + self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) def test_lru_with_keyword_args_maxsize_none(self): - @functools.lru_cache(maxsize=None) + @self.module.lru_cache(maxsize=None) def fib(n): if n < 2: return n @@ -1046,15 +1099,100 @@ class TestLRU(unittest.TestCase): self.assertEqual([fib(n=number) for number in range(16)], [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) self.assertEqual(fib.cache_info(), - functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) + self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) fib.cache_clear() self.assertEqual(fib.cache_info(), - functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) + + def test_lru_cache_decoration(self): + def f(zomg: 'zomg_annotation'): + """f doc string""" + return 42 + g = self.module.lru_cache()(f) + for attr in self.module.WRAPPER_ASSIGNMENTS: + self.assertEqual(getattr(g, attr), getattr(f, attr)) + + @unittest.skipUnless(threading, 'This test requires threading.') + def test_lru_cache_threaded(self): + n, m = 5, 11 + def orig(x, y): + return 3 * x + y + f = self.module.lru_cache(maxsize=n*m)(orig) + hits, misses, maxsize, currsize = f.cache_info() + self.assertEqual(currsize, 0) + + start = threading.Event() + def full(k): + start.wait(10) + for _ in range(m): + self.assertEqual(f(k, 0), orig(k, 0)) + + def clear(): + start.wait(10) + for _ in range(2*m): + f.cache_clear() + + orig_si = sys.getswitchinterval() + sys.setswitchinterval(1e-6) + try: + # create n threads in order to fill cache + threads = [threading.Thread(target=full, args=[k]) + for k in range(n)] + with support.start_threads(threads): + start.set() + + hits, misses, maxsize, currsize = f.cache_info() + if self.module is py_functools: + # XXX: Why can be not equal? + self.assertLessEqual(misses, n) + self.assertLessEqual(hits, m*n - misses) + else: + self.assertEqual(misses, n) + self.assertEqual(hits, m*n - misses) + self.assertEqual(currsize, n) + + # create n threads in order to fill cache and 1 to clear it + threads = [threading.Thread(target=clear)] + threads += [threading.Thread(target=full, args=[k]) + for k in range(n)] + start.clear() + with support.start_threads(threads): + start.set() + finally: + sys.setswitchinterval(orig_si) + + @unittest.skipUnless(threading, 'This test requires threading.') + def test_lru_cache_threaded2(self): + # Simultaneous call with the same arguments + n, m = 5, 7 + start = threading.Barrier(n+1) + pause = threading.Barrier(n+1) + stop = threading.Barrier(n+1) + @self.module.lru_cache(maxsize=m*n) + def f(x): + pause.wait(10) + return 3 * x + self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) + def test(): + for i in range(m): + start.wait(10) + self.assertEqual(f(i), 3 * i) + stop.wait(10) + threads = [threading.Thread(target=test) for k in range(n)] + with support.start_threads(threads): + for i in range(m): + start.wait(10) + stop.reset() + pause.wait(10) + start.reset() + stop.wait(10) + pause.reset() + self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) def test_need_for_rlock(self): # This will deadlock on an LRU cache that uses a regular lock - @functools.lru_cache(maxsize=10) + @self.module.lru_cache(maxsize=10) def test_func(x): 'Used to demonstrate a reentrant lru_cache call within a single thread' return x @@ -1082,6 +1220,43 @@ class TestLRU(unittest.TestCase): def f(): pass + def test_lru_method(self): + class X(int): + f_cnt = 0 + @self.module.lru_cache(2) + def f(self, x): + self.f_cnt += 1 + return x*10+self + a = X(5) + b = X(5) + c = X(7) + self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) + + for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: + self.assertEqual(a.f(x), x*10 + 5) + self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) + self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) + + for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: + self.assertEqual(b.f(x), x*10 + 5) + self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) + self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) + + for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: + self.assertEqual(c.f(x), x*10 + 7) + self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) + self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) + + self.assertEqual(a.f.cache_info(), X.f.cache_info()) + self.assertEqual(b.f.cache_info(), X.f.cache_info()) + self.assertEqual(c.f.cache_info(), X.f.cache_info()) + +class TestLRUC(TestLRU, unittest.TestCase): + module = c_functools + +class TestLRUPy(TestLRU, unittest.TestCase): + module = py_functools + class TestSingleDispatch(unittest.TestCase): def test_simple_overloads(self): @@ -1576,32 +1751,5 @@ class TestSingleDispatch(unittest.TestCase): functools.WeakKeyDictionary = _orig_wkd -def test_main(verbose=None): - test_classes = ( - TestPartialC, - TestPartialPy, - TestPartialCSubclass, - TestPartialMethod, - TestUpdateWrapper, - TestTotalOrdering, - TestCmpToKeyC, - TestCmpToKeyPy, - TestWraps, - TestReduce, - TestLRU, - TestSingleDispatch, - ) - support.run_unittest(*test_classes) - - # verify reference counting - if verbose and hasattr(sys, "gettotalrefcount"): - import gc - counts = [None] * 5 - for i in range(len(counts)): - support.run_unittest(*test_classes) - gc.collect() - counts[i] = sys.gettotalrefcount() - print(counts) - if __name__ == '__main__': - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py index 2ac1d4b..1f0867d 100644 --- a/Lib/test/test_gc.py +++ b/Lib/test/test_gc.py @@ -1,7 +1,8 @@ import unittest from test.support import (verbose, refcount_test, run_unittest, - strip_python_stderr, cpython_only, start_threads) -from test.script_helper import assert_python_ok, make_script, temp_dir + strip_python_stderr, cpython_only, start_threads, + temp_dir) +from test.support.script_helper import assert_python_ok, make_script import sys import time @@ -546,11 +547,31 @@ class GCTests(unittest.TestCase): class UserClass: pass + + class UserInt(int): + pass + + # Base class is object; no extra fields. + class UserClassSlots: + __slots__ = () + + # Base class is fixed size larger than object; no extra fields. + class UserFloatSlots(float): + __slots__ = () + + # Base class is variable size; no extra fields. + class UserIntSlots(int): + __slots__ = () + self.assertTrue(gc.is_tracked(gc)) self.assertTrue(gc.is_tracked(UserClass)) self.assertTrue(gc.is_tracked(UserClass())) + self.assertTrue(gc.is_tracked(UserInt())) self.assertTrue(gc.is_tracked([])) self.assertTrue(gc.is_tracked(set())) + self.assertFalse(gc.is_tracked(UserClassSlots())) + self.assertFalse(gc.is_tracked(UserFloatSlots())) + self.assertFalse(gc.is_tracked(UserIntSlots())) def test_bug1055820b(self): # Corresponds to temp2b.py in the bug report. diff --git a/Lib/test/test_gdb.py b/Lib/test/test_gdb.py index b5017b9..6c4a348 100644 --- a/Lib/test/test_gdb.py +++ b/Lib/test/test_gdb.py @@ -820,25 +820,27 @@ id(42) "Python was compiled without thread support") def test_pycfunction(self): 'Verify that "py-bt" displays invocations of PyCFunction instances' - cmd = ('from time import sleep\n' + # Tested function must not be defined with METH_NOARGS or METH_O, + # otherwise call_function() doesn't call PyCFunction_Call() + cmd = ('from time import gmtime\n' 'def foo():\n' - ' sleep(1)\n' + ' gmtime(1)\n' 'def bar():\n' ' foo()\n' 'bar()\n') # Verify with "py-bt": gdb_output = self.get_stack_trace(cmd, - breakpoint='time_sleep', + breakpoint='time_gmtime', cmds_after_breakpoint=['bt', 'py-bt'], ) - self.assertIn('<built-in method sleep', gdb_output) + self.assertIn('<built-in method gmtime', gdb_output) # Verify with "py-bt-full": gdb_output = self.get_stack_trace(cmd, - breakpoint='time_sleep', + breakpoint='time_gmtime', cmds_after_breakpoint=['py-bt-full'], ) - self.assertIn('#0 <built-in method sleep', gdb_output) + self.assertIn('#0 <built-in method gmtime', gdb_output) class PyPrintTests(DebuggerTests): diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py index 5c455cd..25cc628 100644 --- a/Lib/test/test_generators.py +++ b/Lib/test/test_generators.py @@ -1,7 +1,10 @@ import gc import sys import unittest +import warnings import weakref +import inspect +import types from test import support @@ -70,6 +73,45 @@ class FinalizationTest(unittest.TestCase): self.assertEqual(cm.exception.value, 2) +class GeneratorTest(unittest.TestCase): + + def test_name(self): + def func(): + yield 1 + + # check generator names + gen = func() + self.assertEqual(gen.__name__, "func") + self.assertEqual(gen.__qualname__, + "GeneratorTest.test_name.<locals>.func") + + # modify generator names + gen.__name__ = "name" + gen.__qualname__ = "qualname" + self.assertEqual(gen.__name__, "name") + self.assertEqual(gen.__qualname__, "qualname") + + # generator names must be a string and cannot be deleted + self.assertRaises(TypeError, setattr, gen, '__name__', 123) + self.assertRaises(TypeError, setattr, gen, '__qualname__', 123) + self.assertRaises(TypeError, delattr, gen, '__name__') + self.assertRaises(TypeError, delattr, gen, '__qualname__') + + # modify names of the function creating the generator + func.__qualname__ = "func_qualname" + func.__name__ = "func_name" + gen = func() + self.assertEqual(gen.__name__, "func_name") + self.assertEqual(gen.__qualname__, "func_qualname") + + # unnamed generator + gen = (x for x in range(10)) + self.assertEqual(gen.__name__, + "<genexpr>") + self.assertEqual(gen.__qualname__, + "GeneratorTest.test_name.<locals>.<genexpr>") + + class ExceptionTest(unittest.TestCase): # Tests for the issue #23353: check that the currently handled exception # is correctly saved/restored in PyEval_EvalFrameEx(). @@ -178,6 +220,79 @@ class ExceptionTest(unittest.TestCase): self.assertEqual(next(g), "done") self.assertEqual(sys.exc_info(), (None, None, None)) + def test_stopiteration_warning(self): + # See also PEP 479. + + def gen(): + raise StopIteration + yield + + with self.assertRaises(StopIteration), \ + self.assertWarnsRegex(PendingDeprecationWarning, "StopIteration"): + + next(gen()) + + with self.assertRaisesRegex(PendingDeprecationWarning, + "generator .* raised StopIteration"), \ + warnings.catch_warnings(): + + warnings.simplefilter('error') + next(gen()) + + + def test_tutorial_stopiteration(self): + # Raise StopIteration" stops the generator too: + + def f(): + yield 1 + raise StopIteration + yield 2 # never reached + + g = f() + self.assertEqual(next(g), 1) + + with self.assertWarnsRegex(PendingDeprecationWarning, "StopIteration"): + with self.assertRaises(StopIteration): + next(g) + + with self.assertRaises(StopIteration): + # This time StopIteration isn't raised from the generator's body, + # hence no warning. + next(g) + + +class YieldFromTests(unittest.TestCase): + def test_generator_gi_yieldfrom(self): + def a(): + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING) + self.assertIsNone(gen_b.gi_yieldfrom) + yield + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_RUNNING) + self.assertIsNone(gen_b.gi_yieldfrom) + + def b(): + self.assertIsNone(gen_b.gi_yieldfrom) + yield from a() + self.assertIsNone(gen_b.gi_yieldfrom) + yield + self.assertIsNone(gen_b.gi_yieldfrom) + + gen_b = b() + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_CREATED) + self.assertIsNone(gen_b.gi_yieldfrom) + + gen_b.send(None) + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_SUSPENDED) + self.assertEqual(gen_b.gi_yieldfrom.gi_code.co_name, 'a') + + gen_b.send(None) + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_SUSPENDED) + self.assertIsNone(gen_b.gi_yieldfrom) + + [] = gen_b # Exhaust generator + self.assertEqual(inspect.getgeneratorstate(gen_b), inspect.GEN_CLOSED) + self.assertIsNone(gen_b.gi_yieldfrom) + tutorial_tests = """ Let's try a simple generator: @@ -224,26 +339,7 @@ Let's try a simple generator: File "<stdin>", line 1, in ? StopIteration -"raise StopIteration" stops the generator too: - - >>> def f(): - ... yield 1 - ... raise StopIteration - ... yield 2 # never reached - ... - >>> g = f() - >>> next(g) - 1 - >>> next(g) - Traceback (most recent call last): - File "<stdin>", line 1, in ? - StopIteration - >>> next(g) - Traceback (most recent call last): - File "<stdin>", line 1, in ? - StopIteration - -However, they are not exactly equivalent: +However, "return" and StopIteration are not exactly equivalent: >>> def g1(): ... try: @@ -563,7 +659,7 @@ From the Iterators list, about the types of these things. >>> type(i) <class 'generator'> >>> [s for s in dir(i) if not s.startswith('_')] -['close', 'gi_code', 'gi_frame', 'gi_running', 'send', 'throw'] +['close', 'gi_code', 'gi_frame', 'gi_running', 'gi_yieldfrom', 'send', 'throw'] >>> from test.support import HAVE_DOCSTRINGS >>> print(i.__next__.__doc__ if HAVE_DOCSTRINGS else 'Implement next(self).') Implement next(self). diff --git a/Lib/test/test_genericpath.py b/Lib/test/test_genericpath.py index e59ed4d..6ba55df 100644 --- a/Lib/test/test_genericpath.py +++ b/Lib/test/test_genericpath.py @@ -434,6 +434,44 @@ class CommonTest(GenericTest): with support.temp_cwd(name): self.test_abspath() + def test_join_errors(self): + # Check join() raises friendly TypeErrors. + with support.check_warnings(('', BytesWarning), quiet=True): + errmsg = "Can't mix strings and bytes in path components" + with self.assertRaisesRegex(TypeError, errmsg): + self.pathmodule.join(b'bytes', 'str') + with self.assertRaisesRegex(TypeError, errmsg): + self.pathmodule.join('str', b'bytes') + # regression, see #15377 + errmsg = r'join\(\) argument must be str or bytes, not %r' + with self.assertRaisesRegex(TypeError, errmsg % 'int'): + self.pathmodule.join(42, 'str') + with self.assertRaisesRegex(TypeError, errmsg % 'int'): + self.pathmodule.join('str', 42) + with self.assertRaisesRegex(TypeError, errmsg % 'int'): + self.pathmodule.join(42) + with self.assertRaisesRegex(TypeError, errmsg % 'list'): + self.pathmodule.join([]) + with self.assertRaisesRegex(TypeError, errmsg % 'bytearray'): + self.pathmodule.join(bytearray(b'foo'), bytearray(b'bar')) + + def test_relpath_errors(self): + # Check relpath() raises friendly TypeErrors. + with support.check_warnings(('', (BytesWarning, DeprecationWarning)), + quiet=True): + errmsg = "Can't mix strings and bytes in path components" + with self.assertRaisesRegex(TypeError, errmsg): + self.pathmodule.relpath(b'bytes', 'str') + with self.assertRaisesRegex(TypeError, errmsg): + self.pathmodule.relpath('str', b'bytes') + errmsg = r'relpath\(\) argument must be str or bytes, not %r' + with self.assertRaisesRegex(TypeError, errmsg % 'int'): + self.pathmodule.relpath(42, 'str') + with self.assertRaisesRegex(TypeError, errmsg % 'int'): + self.pathmodule.relpath('str', 42) + with self.assertRaisesRegex(TypeError, errmsg % 'bytearray'): + self.pathmodule.relpath(bytearray(b'foo'), bytearray(b'bar')) + if __name__=="__main__": unittest.main() diff --git a/Lib/test/test_getargs2.py b/Lib/test/test_getargs2.py index 1853a2d..71472cd 100644 --- a/Lib/test/test_getargs2.py +++ b/Lib/test/test_getargs2.py @@ -34,8 +34,8 @@ except ImportError: # > ** Changed from previous "range-and-a-half" to "none"; the # > range-and-a-half checking wasn't particularly useful. # -# Plus a C API or two, e.g. PyInt_AsLongMask() -> -# unsigned long and PyInt_AsLongLongMask() -> unsigned +# Plus a C API or two, e.g. PyLong_AsUnsignedLongMask() -> +# unsigned long and PyLong_AsUnsignedLongLongMask() -> unsigned # long long (if that exists). LARGE = 0x7FFFFFFF @@ -482,7 +482,7 @@ class Bytes_TestCase(unittest.TestCase): def test_s(self): from _testcapi import getargs_s self.assertEqual(getargs_s('abc\xe9'), b'abc\xc3\xa9') - self.assertRaises(TypeError, getargs_s, 'nul:\0') + self.assertRaises(ValueError, getargs_s, 'nul:\0') self.assertRaises(TypeError, getargs_s, b'bytes') self.assertRaises(TypeError, getargs_s, bytearray(b'bytearray')) self.assertRaises(TypeError, getargs_s, memoryview(b'memoryview')) @@ -509,7 +509,7 @@ class Bytes_TestCase(unittest.TestCase): def test_z(self): from _testcapi import getargs_z self.assertEqual(getargs_z('abc\xe9'), b'abc\xc3\xa9') - self.assertRaises(TypeError, getargs_z, 'nul:\0') + self.assertRaises(ValueError, getargs_z, 'nul:\0') self.assertRaises(TypeError, getargs_z, b'bytes') self.assertRaises(TypeError, getargs_z, bytearray(b'bytearray')) self.assertRaises(TypeError, getargs_z, memoryview(b'memoryview')) @@ -537,7 +537,7 @@ class Bytes_TestCase(unittest.TestCase): from _testcapi import getargs_y self.assertRaises(TypeError, getargs_y, 'abc\xe9') self.assertEqual(getargs_y(b'bytes'), b'bytes') - self.assertRaises(TypeError, getargs_y, b'nul:\0') + self.assertRaises(ValueError, getargs_y, b'nul:\0') self.assertRaises(TypeError, getargs_y, bytearray(b'bytearray')) self.assertRaises(TypeError, getargs_y, memoryview(b'memoryview')) self.assertRaises(TypeError, getargs_y, None) @@ -577,7 +577,7 @@ class Unicode_TestCase(unittest.TestCase): def test_u(self): from _testcapi import getargs_u self.assertEqual(getargs_u('abc\xe9'), 'abc\xe9') - self.assertRaises(TypeError, getargs_u, 'nul:\0') + self.assertRaises(ValueError, getargs_u, 'nul:\0') self.assertRaises(TypeError, getargs_u, b'bytes') self.assertRaises(TypeError, getargs_u, bytearray(b'bytearray')) self.assertRaises(TypeError, getargs_u, memoryview(b'memoryview')) @@ -595,7 +595,7 @@ class Unicode_TestCase(unittest.TestCase): def test_Z(self): from _testcapi import getargs_Z self.assertEqual(getargs_Z('abc\xe9'), 'abc\xe9') - self.assertRaises(TypeError, getargs_Z, 'nul:\0') + self.assertRaises(ValueError, getargs_Z, 'nul:\0') self.assertRaises(TypeError, getargs_Z, b'bytes') self.assertRaises(TypeError, getargs_Z, bytearray(b'bytearray')) self.assertRaises(TypeError, getargs_Z, memoryview(b'memoryview')) diff --git a/Lib/test/test_getopt.py b/Lib/test/test_getopt.py index fa5701f..9275dc4 100644 --- a/Lib/test/test_getopt.py +++ b/Lib/test/test_getopt.py @@ -1,7 +1,7 @@ # test_getopt.py # David Goodger <dgoodger@bigfoot.com> 2000-08-19 -from test.support import verbose, run_doctest, run_unittest, EnvironmentVarGuard +from test.support import verbose, run_doctest, EnvironmentVarGuard import unittest import getopt @@ -180,8 +180,5 @@ class GetoptTests(unittest.TestCase): self.assertEqual(longopts, [('--help', 'x')]) self.assertRaises(getopt.GetoptError, getopt.getopt, ['--help='], '', ['help']) -def test_main(): - run_unittest(GetoptTests) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py index 2737e81..de610c7 100644 --- a/Lib/test/test_gettext.py +++ b/Lib/test/test_gettext.py @@ -33,6 +33,55 @@ IHNiZSBsYmhlIENsZ3ViYSBjZWJ0ZW56ZiBvbCBjZWJpdnF2YXQgbmEgdmFncmVzbnByIGdiIGd1 ciBUQUgKdHJnZ3JrZyB6cmZmbnRyIHBuZ255YnQgeXZvZW5lbC4AYmFjb24Ad2luayB3aW5rAA== ''' +# This data contains an invalid major version number (5) +# An unexpected major version number should be treated as an error when +# parsing a .mo file + +GNU_MO_DATA_BAD_MAJOR_VERSION = b'''\ +3hIElQAABQAGAAAAHAAAAEwAAAALAAAAfAAAAAAAAACoAAAAFQAAAKkAAAAjAAAAvwAAAKEAAADj +AAAABwAAAIUBAAALAAAAjQEAAEUBAACZAQAAFgAAAN8CAAAeAAAA9gIAAKEAAAAVAwAABQAAALcD +AAAJAAAAvQMAAAEAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEAAAABQAAAAYAAAACAAAAAFJh +eW1vbmQgTHV4dXJ5IFlhY2gtdABUaGVyZSBpcyAlcyBmaWxlAFRoZXJlIGFyZSAlcyBmaWxlcwBU +aGlzIG1vZHVsZSBwcm92aWRlcyBpbnRlcm5hdGlvbmFsaXphdGlvbiBhbmQgbG9jYWxpemF0aW9u +CnN1cHBvcnQgZm9yIHlvdXIgUHl0aG9uIHByb2dyYW1zIGJ5IHByb3ZpZGluZyBhbiBpbnRlcmZh +Y2UgdG8gdGhlIEdOVQpnZXR0ZXh0IG1lc3NhZ2UgY2F0YWxvZyBsaWJyYXJ5LgBtdWxsdXNrAG51 +ZGdlIG51ZGdlAFByb2plY3QtSWQtVmVyc2lvbjogMi4wClBPLVJldmlzaW9uLURhdGU6IDIwMDAt +MDgtMjkgMTI6MTktMDQ6MDAKTGFzdC1UcmFuc2xhdG9yOiBKLiBEYXZpZCBJYsOhw7FleiA8ai1k +YXZpZEBub29zLmZyPgpMYW5ndWFnZS1UZWFtOiBYWCA8cHl0aG9uLWRldkBweXRob24ub3JnPgpN +SU1FLVZlcnNpb246IDEuMApDb250ZW50LVR5cGU6IHRleHQvcGxhaW47IGNoYXJzZXQ9aXNvLTg4 +NTktMQpDb250ZW50LVRyYW5zZmVyLUVuY29kaW5nOiBub25lCkdlbmVyYXRlZC1CeTogcHlnZXR0 +ZXh0LnB5IDEuMQpQbHVyYWwtRm9ybXM6IG5wbHVyYWxzPTI7IHBsdXJhbD1uIT0xOwoAVGhyb2F0 +d29iYmxlciBNYW5ncm92ZQBIYXkgJXMgZmljaGVybwBIYXkgJXMgZmljaGVyb3MAR3V2ZiB6YnFo +eXIgY2ViaXZxcmYgdmFncmVhbmd2YmFueXZtbmd2YmEgbmFxIHlicG55dm1uZ3ZiYQpmaGNjYmVn +IHNiZSBsYmhlIENsZ3ViYSBjZWJ0ZW56ZiBvbCBjZWJpdnF2YXQgbmEgdmFncmVzbnByIGdiIGd1 +ciBUQUgKdHJnZ3JrZyB6cmZmbnRyIHBuZ255YnQgeXZvZW5lbC4AYmFjb24Ad2luayB3aW5rAA== +''' + +# This data contains an invalid minor version number (7) +# An unexpected minor version number only indicates that some of the file's +# contents may not be able to be read. It does not indicate an error. + +GNU_MO_DATA_BAD_MINOR_VERSION = b'''\ +3hIElQcAAAAGAAAAHAAAAEwAAAALAAAAfAAAAAAAAACoAAAAFQAAAKkAAAAjAAAAvwAAAKEAAADj +AAAABwAAAIUBAAALAAAAjQEAAEUBAACZAQAAFgAAAN8CAAAeAAAA9gIAAKEAAAAVAwAABQAAALcD +AAAJAAAAvQMAAAEAAAADAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEAAAABQAAAAYAAAACAAAAAFJh +eW1vbmQgTHV4dXJ5IFlhY2gtdABUaGVyZSBpcyAlcyBmaWxlAFRoZXJlIGFyZSAlcyBmaWxlcwBU +aGlzIG1vZHVsZSBwcm92aWRlcyBpbnRlcm5hdGlvbmFsaXphdGlvbiBhbmQgbG9jYWxpemF0aW9u +CnN1cHBvcnQgZm9yIHlvdXIgUHl0aG9uIHByb2dyYW1zIGJ5IHByb3ZpZGluZyBhbiBpbnRlcmZh +Y2UgdG8gdGhlIEdOVQpnZXR0ZXh0IG1lc3NhZ2UgY2F0YWxvZyBsaWJyYXJ5LgBtdWxsdXNrAG51 +ZGdlIG51ZGdlAFByb2plY3QtSWQtVmVyc2lvbjogMi4wClBPLVJldmlzaW9uLURhdGU6IDIwMDAt +MDgtMjkgMTI6MTktMDQ6MDAKTGFzdC1UcmFuc2xhdG9yOiBKLiBEYXZpZCBJYsOhw7FleiA8ai1k +YXZpZEBub29zLmZyPgpMYW5ndWFnZS1UZWFtOiBYWCA8cHl0aG9uLWRldkBweXRob24ub3JnPgpN +SU1FLVZlcnNpb246IDEuMApDb250ZW50LVR5cGU6IHRleHQvcGxhaW47IGNoYXJzZXQ9aXNvLTg4 +NTktMQpDb250ZW50LVRyYW5zZmVyLUVuY29kaW5nOiBub25lCkdlbmVyYXRlZC1CeTogcHlnZXR0 +ZXh0LnB5IDEuMQpQbHVyYWwtRm9ybXM6IG5wbHVyYWxzPTI7IHBsdXJhbD1uIT0xOwoAVGhyb2F0 +d29iYmxlciBNYW5ncm92ZQBIYXkgJXMgZmljaGVybwBIYXkgJXMgZmljaGVyb3MAR3V2ZiB6YnFo +eXIgY2ViaXZxcmYgdmFncmVhbmd2YmFueXZtbmd2YmEgbmFxIHlicG55dm1uZ3ZiYQpmaGNjYmVn +IHNiZSBsYmhlIENsZ3ViYSBjZWJ0ZW56ZiBvbCBjZWJpdnF2YXQgbmEgdmFncmVzbnByIGdiIGd1 +ciBUQUgKdHJnZ3JrZyB6cmZmbnRyIHBuZ255YnQgeXZvZW5lbC4AYmFjb24Ad2luayB3aW5rAA== +''' + + UMO_DATA = b'''\ 3hIElQAAAAACAAAAHAAAACwAAAAFAAAAPAAAAAAAAABQAAAABAAAAFEAAAAPAQAAVgAAAAQAAABm AQAAAQAAAAIAAAAAAAAAAAAAAAAAAAAAYWLDngBQcm9qZWN0LUlkLVZlcnNpb246IDIuMApQTy1S @@ -56,6 +105,8 @@ bGUKR2VuZXJhdGVkLUJ5OiBweWdldHRleHQucHkgMS4zCgA= LOCALEDIR = os.path.join('xx', 'LC_MESSAGES') MOFILE = os.path.join(LOCALEDIR, 'gettext.mo') +MOFILE_BAD_MAJOR_VERSION = os.path.join(LOCALEDIR, 'gettext_bad_major_version.mo') +MOFILE_BAD_MINOR_VERSION = os.path.join(LOCALEDIR, 'gettext_bad_minor_version.mo') UMOFILE = os.path.join(LOCALEDIR, 'ugettext.mo') MMOFILE = os.path.join(LOCALEDIR, 'metadata.mo') @@ -66,6 +117,10 @@ class GettextBaseTest(unittest.TestCase): os.makedirs(LOCALEDIR) with open(MOFILE, 'wb') as fp: fp.write(base64.decodebytes(GNU_MO_DATA)) + with open(MOFILE_BAD_MAJOR_VERSION, 'wb') as fp: + fp.write(base64.decodebytes(GNU_MO_DATA_BAD_MAJOR_VERSION)) + with open(MOFILE_BAD_MINOR_VERSION, 'wb') as fp: + fp.write(base64.decodebytes(GNU_MO_DATA_BAD_MINOR_VERSION)) with open(UMOFILE, 'wb') as fp: fp.write(base64.decodebytes(UMO_DATA)) with open(MMOFILE, 'wb') as fp: @@ -172,6 +227,21 @@ class GettextTestCase2(GettextBaseTest): def test_textdomain(self): self.assertEqual(gettext.textdomain(), 'gettext') + def test_bad_major_version(self): + with open(MOFILE_BAD_MAJOR_VERSION, 'rb') as fp: + with self.assertRaises(OSError) as cm: + gettext.GNUTranslations(fp) + + exception = cm.exception + self.assertEqual(exception.errno, 0) + self.assertEqual(exception.strerror, "Bad version number 5") + self.assertEqual(exception.filename, MOFILE_BAD_MAJOR_VERSION) + + def test_bad_minor_version(self): + with open(MOFILE_BAD_MINOR_VERSION, 'rb') as fp: + # Check that no error is thrown with a bad minor version number + gettext.GNUTranslations(fp) + def test_some_translations(self): eq = self.assertEqual # test some translations diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py index a5ab8d6..926588e 100644 --- a/Lib/test/test_glob.py +++ b/Lib/test/test_glob.py @@ -4,8 +4,8 @@ import shutil import sys import unittest -from test.support import (run_unittest, TESTFN, skip_unless_symlink, - can_symlink, create_empty_file) +from test.support import (TESTFN, skip_unless_symlink, + can_symlink, create_empty_file, change_cwd) class GlobTests(unittest.TestCase): @@ -13,6 +13,9 @@ class GlobTests(unittest.TestCase): def norm(self, *parts): return os.path.normpath(os.path.join(self.tempdir, *parts)) + def joins(self, *tuples): + return [os.path.join(self.tempdir, *parts) for parts in tuples] + def mktemp(self, *parts): filename = self.norm(*parts) base, file = os.path.split(filename) @@ -38,17 +41,17 @@ class GlobTests(unittest.TestCase): def tearDown(self): shutil.rmtree(self.tempdir) - def glob(self, *parts): + def glob(self, *parts, **kwargs): if len(parts) == 1: pattern = parts[0] else: pattern = os.path.join(*parts) p = os.path.join(self.tempdir, pattern) - res = glob.glob(p) - self.assertEqual(list(glob.iglob(p)), res) + res = glob.glob(p, **kwargs) + self.assertEqual(list(glob.iglob(p, **kwargs)), res) bres = [os.fsencode(x) for x in res] - self.assertEqual(glob.glob(os.fsencode(p)), bres) - self.assertEqual(list(glob.iglob(os.fsencode(p))), bres) + self.assertEqual(glob.glob(os.fsencode(p), **kwargs), bres) + self.assertEqual(list(glob.iglob(os.fsencode(p), **kwargs)), bres) return res def assertSequencesEqual_noorder(self, l1, l2): @@ -192,9 +195,114 @@ class GlobTests(unittest.TestCase): check('//?/c:/?', '//?/c:/[?]') check('//*/*/*', '//*/*/[*]') -def test_main(): - run_unittest(GlobTests) + def rglob(self, *parts, **kwargs): + return self.glob(*parts, recursive=True, **kwargs) + + def test_recursive_glob(self): + eq = self.assertSequencesEqual_noorder + full = [('ZZZ',), + ('a',), ('a', 'D'), + ('a', 'bcd'), + ('a', 'bcd', 'EF'), + ('a', 'bcd', 'efg'), + ('a', 'bcd', 'efg', 'ha'), + ('aaa',), ('aaa', 'zzzF'), + ('aab',), ('aab', 'F'), + ] + if can_symlink(): + full += [('sym1',), ('sym2',), + ('sym3',), + ('sym3', 'EF'), + ('sym3', 'efg'), + ('sym3', 'efg', 'ha'), + ] + eq(self.rglob('**'), self.joins(('',), *full)) + eq(self.rglob('.', '**'), self.joins(('.',''), + *(('.',) + i for i in full))) + dirs = [('a', ''), ('a', 'bcd', ''), ('a', 'bcd', 'efg', ''), + ('aaa', ''), ('aab', '')] + if can_symlink(): + dirs += [('sym3', ''), ('sym3', 'efg', '')] + eq(self.rglob('**', ''), self.joins(('',), *dirs)) + + eq(self.rglob('a', '**'), self.joins( + ('a', ''), ('a', 'D'), ('a', 'bcd'), ('a', 'bcd', 'EF'), + ('a', 'bcd', 'efg'), ('a', 'bcd', 'efg', 'ha'))) + eq(self.rglob('a**'), self.joins(('a',), ('aaa',), ('aab',))) + expect = [('a', 'bcd', 'EF')] + if can_symlink(): + expect += [('sym3', 'EF')] + eq(self.rglob('**', 'EF'), self.joins(*expect)) + expect = [('a', 'bcd', 'EF'), ('aaa', 'zzzF'), ('aab', 'F')] + if can_symlink(): + expect += [('sym3', 'EF')] + eq(self.rglob('**', '*F'), self.joins(*expect)) + eq(self.rglob('**', '*F', ''), []) + eq(self.rglob('**', 'bcd', '*'), self.joins( + ('a', 'bcd', 'EF'), ('a', 'bcd', 'efg'))) + eq(self.rglob('a', '**', 'bcd'), self.joins(('a', 'bcd'))) + + with change_cwd(self.tempdir): + join = os.path.join + eq(glob.glob('**', recursive=True), [join(*i) for i in full]) + eq(glob.glob(join('**', ''), recursive=True), + [join(*i) for i in dirs]) + eq(glob.glob(join('**','zz*F'), recursive=True), + [join('aaa', 'zzzF')]) + eq(glob.glob('**zz*F', recursive=True), []) + expect = [join('a', 'bcd', 'EF')] + if can_symlink(): + expect += [join('sym3', 'EF')] + eq(glob.glob(join('**', 'EF'), recursive=True), expect) + + +@skip_unless_symlink +class SymlinkLoopGlobTests(unittest.TestCase): + + def test_selflink(self): + tempdir = TESTFN + "_dir" + os.makedirs(tempdir) + self.addCleanup(shutil.rmtree, tempdir) + with change_cwd(tempdir): + os.makedirs('dir') + create_empty_file(os.path.join('dir', 'file')) + os.symlink(os.curdir, os.path.join('dir', 'link')) + + results = glob.glob('**', recursive=True) + self.assertEqual(len(results), len(set(results))) + results = set(results) + depth = 0 + while results: + path = os.path.join(*(['dir'] + ['link'] * depth)) + self.assertIn(path, results) + results.remove(path) + if not results: + break + path = os.path.join(path, 'file') + self.assertIn(path, results) + results.remove(path) + depth += 1 + + results = glob.glob(os.path.join('**', 'file'), recursive=True) + self.assertEqual(len(results), len(set(results))) + results = set(results) + depth = 0 + while results: + path = os.path.join(*(['dir'] + ['link'] * depth + ['file'])) + self.assertIn(path, results) + results.remove(path) + depth += 1 + + results = glob.glob(os.path.join('**', ''), recursive=True) + self.assertEqual(len(results), len(set(results))) + results = set(results) + depth = 0 + while results: + path = os.path.join(*(['dir'] + ['link'] * depth + [''])) + self.assertIn(path, results) + results.remove(path) + depth += 1 if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py index 2504523..ec3d783 100644 --- a/Lib/test/test_grammar.py +++ b/Lib/test/test_grammar.py @@ -1,7 +1,8 @@ # Python test set -- part 1, grammar. # This just tests whether the parser accepts them all. -from test.support import run_unittest, check_syntax_error +from test.support import check_syntax_error +import inspect import unittest import sys # testing import * @@ -205,6 +206,7 @@ class GrammarTests(unittest.TestCase): d01(1) d01(*(1,)) d01(*[] or [2]) + d01(*() or (), *{} and (), **() or {}) d01(**{'a':2}) d01(**{'a':2} or {}) def d11(a, b=1): pass @@ -298,8 +300,12 @@ class GrammarTests(unittest.TestCase): return args, kwargs self.assertEqual(f(1, x=2, *[3, 4], y=5), ((1, 3, 4), {'x':2, 'y':5})) - self.assertRaises(SyntaxError, eval, "f(1, *(2,3), 4)") + self.assertEqual(f(1, *(2,3), 4), ((1, 2, 3, 4), {})) self.assertRaises(SyntaxError, eval, "f(1, x=2, *(3,4), x=5)") + self.assertEqual(f(**{'eggs':'scrambled', 'spam':'fried'}), + ((), {'eggs':'scrambled', 'spam':'fried'})) + self.assertEqual(f(spam='fried', **{'eggs':'scrambled'}), + ((), {'eggs':'scrambled', 'spam':'fried'})) # argument annotation tests def f(x) -> list: pass @@ -530,7 +536,8 @@ class GrammarTests(unittest.TestCase): # Not allowed at class scope check_syntax_error(self, "class foo:yield 1") check_syntax_error(self, "class foo:yield from ()") - + # Check annotation refleak on SyntaxError + check_syntax_error(self, "def g(a:(yield)): pass") def test_raise(self): # 'raise' test [',' test] @@ -1018,9 +1025,103 @@ class GrammarTests(unittest.TestCase): self.assertFalse((False is 2) is 3) self.assertFalse(False is 2 is 3) + def test_matrix_mul(self): + # This is not intended to be a comprehensive test, rather just to be few + # samples of the @ operator in test_grammar.py. + class M: + def __matmul__(self, o): + return 4 + def __imatmul__(self, o): + self.other = o + return self + m = M() + self.assertEqual(m @ m, 4) + m @= 42 + self.assertEqual(m.other, 42) + + def test_async_await(self): + async = 1 + await = 2 + self.assertEqual(async, 1) + + def async(): + nonlocal await + await = 10 + async() + self.assertEqual(await, 10) + + self.assertFalse(bool(async.__code__.co_flags & inspect.CO_COROUTINE)) + + async def test(): + def sum(): + pass + if 1: + await someobj() + + self.assertEqual(test.__name__, 'test') + self.assertTrue(bool(test.__code__.co_flags & inspect.CO_COROUTINE)) + + def decorator(func): + setattr(func, '_marked', True) + return func + + @decorator + async def test2(): + return 22 + self.assertTrue(test2._marked) + self.assertEqual(test2.__name__, 'test2') + self.assertTrue(bool(test2.__code__.co_flags & inspect.CO_COROUTINE)) + + def test_async_for(self): + class Done(Exception): pass + + class AIter: + async def __aiter__(self): + return self + async def __anext__(self): + raise StopAsyncIteration + + async def foo(): + async for i in AIter(): + pass + async for i, j in AIter(): + pass + async for i in AIter(): + pass + else: + pass + raise Done + + with self.assertRaises(Done): + foo().send(None) + + def test_async_with(self): + class Done(Exception): pass + + class manager: + async def __aenter__(self): + return (1, 2) + async def __aexit__(self, *exc): + return False + + async def foo(): + async with manager(): + pass + async with manager() as x: + pass + async with manager() as (x, y): + pass + async with manager(), manager(): + pass + async with manager() as x, manager() as y: + pass + async with manager() as x, manager(): + pass + raise Done + + with self.assertRaises(Done): + foo().send(None) -def test_main(): - run_unittest(TokenTests, GrammarTests) if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_grp.py b/Lib/test/test_grp.py index 749041c..272b086 100644 --- a/Lib/test/test_grp.py +++ b/Lib/test/test_grp.py @@ -92,8 +92,5 @@ class GroupDatabaseTestCase(unittest.TestCase): self.assertRaises(KeyError, grp.getgrgid, fakegid) -def test_main(): - support.run_unittest(GroupDatabaseTestCase) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_gzip.py b/Lib/test/test_gzip.py index b417044..d8408e1 100644 --- a/Lib/test/test_gzip.py +++ b/Lib/test/test_gzip.py @@ -6,6 +6,7 @@ from test import support import os import io import struct +import array gzip = support.import_module('gzip') data1 = b""" int length=DEFAULTALLOC, err = Z_OK; @@ -77,15 +78,18 @@ class TestGzip(BaseTest): def test_write_bytearray(self): self.write_and_read_back(bytearray(data1 * 50)) + def test_write_array(self): + self.write_and_read_back(array.array('I', data1 * 40)) + def test_write_incompatible_type(self): # Test that non-bytes-like types raise TypeError. # Issue #21560: attempts to write incompatible types # should not affect the state of the fileobject with gzip.GzipFile(self.filename, 'wb') as f: with self.assertRaises(TypeError): - f.write('a') + f.write('') with self.assertRaises(TypeError): - f.write([1]) + f.write([]) f.write(data1) with gzip.GzipFile(self.filename, 'rb') as f: self.assertEqual(f.read(), data1) @@ -119,7 +123,10 @@ class TestGzip(BaseTest): # Write to a file, open it for reading, then close it. self.test_write() f = gzip.GzipFile(self.filename, 'r') + fileobj = f.fileobj + self.assertFalse(fileobj.closed) f.close() + self.assertTrue(fileobj.closed) with self.assertRaises(ValueError): f.read(1) with self.assertRaises(ValueError): @@ -128,7 +135,10 @@ class TestGzip(BaseTest): f.tell() # Open the file for writing, then close it. f = gzip.GzipFile(self.filename, 'w') + fileobj = f.fileobj + self.assertFalse(fileobj.closed) f.close() + self.assertTrue(fileobj.closed) with self.assertRaises(ValueError): f.write(b'') with self.assertRaises(ValueError): @@ -267,9 +277,10 @@ class TestGzip(BaseTest): with gzip.GzipFile(self.filename, 'w', mtime = mtime) as fWrite: fWrite.write(data1) with gzip.GzipFile(self.filename) as fRead: + self.assertTrue(hasattr(fRead, 'mtime')) + self.assertIsNone(fRead.mtime) dataRead = fRead.read() self.assertEqual(dataRead, data1) - self.assertTrue(hasattr(fRead, 'mtime')) self.assertEqual(fRead.mtime, mtime) def test_metadata(self): @@ -412,6 +423,18 @@ class TestGzip(BaseTest): with gzip.GzipFile(str_filename, "rb") as f: self.assertEqual(f.read(), data1 * 50) + def test_decompress_limited(self): + """Decompressed data buffering should be limited""" + bomb = gzip.compress(bytes(int(2e6)), compresslevel=9) + self.assertLess(len(bomb), io.DEFAULT_BUFFER_SIZE) + + bomb = io.BytesIO(bomb) + decomp = gzip.GzipFile(fileobj=bomb) + self.assertEqual(bytes(1), decomp.read(1)) + max_decomp = 1 + io.DEFAULT_BUFFER_SIZE + self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, + "Excessive amount of data was decompressed") + # Testing compress/decompress shortcut functions def test_compress(self): @@ -459,7 +482,7 @@ class TestGzip(BaseTest): with gzip.open(self.filename, "wb") as f: f.write(data1) with gzip.open(self.filename, "rb") as f: - f.fileobj.prepend() + f._buffer.raw._fp.prepend() class TestOpen(BaseTest): def test_binary_modes(self): diff --git a/Lib/test/test_hash.py b/Lib/test/test_hash.py index f647c6f..aa4efbf 100644 --- a/Lib/test/test_hash.py +++ b/Lib/test/test_hash.py @@ -7,7 +7,7 @@ import datetime import os import sys import unittest -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok from collections import Hashable IS_64BIT = sys.maxsize > 2**32 diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index b5a2fd8..b7e8259 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -6,14 +6,15 @@ import unittest from test import support from unittest import TestCase, skipUnless +from operator import itemgetter py_heapq = support.import_fresh_module('heapq', blocked=['_heapq']) c_heapq = support.import_fresh_module('heapq', fresh=['_heapq']) # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when # _heapq is imported, so check them there -func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', - 'heapreplace', '_nlargest', '_nsmallest'] +func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace', + '_heappop_max', '_heapreplace_max', '_heapify_max'] class TestModules(TestCase): def test_py_functions(self): @@ -64,7 +65,7 @@ class TestHeap: self.assertTrue(heap[parentpos] <= item) def test_heapify(self): - for size in range(30): + for size in list(range(30)) + [20000]: heap = [random.random() for dummy in range(size)] self.module.heapify(heap) self.check_invariant(heap) @@ -152,11 +153,21 @@ class TestHeap: def test_merge(self): inputs = [] - for i in range(random.randrange(5)): - row = sorted(random.randrange(1000) for j in range(random.randrange(10))) + for i in range(random.randrange(25)): + row = [] + for j in range(random.randrange(100)): + tup = random.choice('ABC'), random.randrange(-500, 500) + row.append(tup) inputs.append(row) - self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs))) - self.assertEqual(list(self.module.merge()), []) + + for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]: + for reverse in [False, True]: + seqs = [] + for seq in inputs: + seqs.append(sorted(seq, key=key, reverse=reverse)) + self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse), + list(self.module.merge(*seqs, key=key, reverse=reverse))) + self.assertEqual(list(self.module.merge()), []) def test_merge_does_not_suppress_index_error(self): # Issue 19018: Heapq.merge suppresses IndexError from user generator diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index cde56fd..98826b5 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -493,14 +493,5 @@ class CompareDigestTestCase(unittest.TestCase): self.assertFalse(hmac.compare_digest(a, b)) -def test_main(): - support.run_unittest( - TestVectorsTestCase, - ConstructorTestCase, - SanityTestCase, - CopyTestCase, - CompareDigestTestCase - ) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_html.py b/Lib/test/test_html.py index d6f0ae8..839e0a4 100644 --- a/Lib/test/test_html.py +++ b/Lib/test/test_html.py @@ -4,7 +4,6 @@ Tests for the html module functions. import html import unittest -from test.support import run_unittest class HtmlTests(unittest.TestCase): diff --git a/Lib/test/test_htmlparser.py b/Lib/test/test_htmlparser.py index 144f820..11420b2c 100644 --- a/Lib/test/test_htmlparser.py +++ b/Lib/test/test_htmlparser.py @@ -82,7 +82,7 @@ class EventCollectorCharrefs(EventCollector): class TestCaseBase(unittest.TestCase): def get_collector(self): - raise NotImplementedError + return EventCollector(convert_charrefs=False) def _run_check(self, source, expected_events, collector=None): if collector is None: @@ -102,21 +102,8 @@ class TestCaseBase(unittest.TestCase): self._run_check(source, events, EventCollectorExtra(convert_charrefs=False)) - def _parse_error(self, source): - def parse(source=source): - parser = self.get_collector() - parser.feed(source) - parser.close() - with self.assertRaises(html.parser.HTMLParseError): - with self.assertWarns(DeprecationWarning): - parse() - -class HTMLParserStrictTestCase(TestCaseBase): - - def get_collector(self): - with support.check_warnings(("", DeprecationWarning), quite=False): - return EventCollector(strict=True, convert_charrefs=False) +class HTMLParserTestCase(TestCaseBase): def test_processing_instruction_only(self): self._run_check("<?processing instruction>", [ @@ -198,9 +185,6 @@ text ("data", "this < text > contains < bare>pointy< brackets"), ]) - def test_illegal_declarations(self): - self._parse_error('<!spacer type="block" height="25">') - def test_starttag_end_boundary(self): self._run_check("""<a b='<'>""", [("starttag", "a", [("b", "<")])]) self._run_check("""<a b='>'>""", [("starttag", "a", [("b", ">")])]) @@ -235,25 +219,6 @@ text self._run_check(["<!--abc--", ">"], output) self._run_check(["<!--abc-->", ""], output) - def test_starttag_junk_chars(self): - self._parse_error("</>") - self._parse_error("</$>") - self._parse_error("</") - self._parse_error("</a") - self._parse_error("<a<a>") - self._parse_error("</a<a>") - self._parse_error("<!") - self._parse_error("<a") - self._parse_error("<a foo='bar'") - self._parse_error("<a foo='bar") - self._parse_error("<a foo='>'") - self._parse_error("<a foo='>") - self._parse_error("<a$>") - self._parse_error("<a$b>") - self._parse_error("<a$b/>") - self._parse_error("<a$b >") - self._parse_error("<a$b />") - def test_valid_doctypes(self): # from http://www.w3.org/QA/2002/04/valid-dtd-list.html dtds = ['HTML', # HTML5 doctype @@ -278,9 +243,6 @@ text self._run_check("<!DOCTYPE %s>" % dtd, [('decl', 'DOCTYPE ' + dtd)]) - def test_declaration_junk_chars(self): - self._parse_error("<!DOCTYPE foo $ >") - def test_startendtag(self): self._run_check("<p/>", [ ("startendtag", "p", []), @@ -381,7 +343,8 @@ text self._run_check(html, expected) def test_convert_charrefs(self): - collector = lambda: EventCollectorCharrefs(convert_charrefs=True) + # default value for convert_charrefs is now True + collector = lambda: EventCollectorCharrefs() self.assertTrue(collector().convert_charrefs) charrefs = ['"', '"', '"', '"', '"', '"'] # check charrefs in the middle of the text/attributes @@ -418,23 +381,8 @@ text self._run_check('no charrefs here', [('data', 'no charrefs here')], collector=collector()) - -class HTMLParserTolerantTestCase(HTMLParserStrictTestCase): - - def get_collector(self): - return EventCollector(convert_charrefs=False) - - def test_deprecation_warnings(self): - with self.assertWarns(DeprecationWarning): - EventCollector() # convert_charrefs not passed explicitly - with self.assertWarns(DeprecationWarning): - EventCollector(strict=True) - with self.assertWarns(DeprecationWarning): - EventCollector(strict=False) - with self.assertRaises(html.parser.HTMLParseError): - with self.assertWarns(DeprecationWarning): - EventCollector().error('test') - + # the remaining tests were for the "tolerant" parser (which is now + # the default), and check various kind of broken markup def test_tolerant_parsing(self): self._run_check('<html <html>te>>xt&a<<bc</a></html>\n' '<img src="URL><//img></html</html>', [ @@ -695,11 +643,7 @@ class HTMLParserTolerantTestCase(HTMLParserStrictTestCase): ) -class AttributesStrictTestCase(TestCaseBase): - - def get_collector(self): - with support.check_warnings(("", DeprecationWarning), quite=False): - return EventCollector(strict=True, convert_charrefs=False) +class AttributesTestCase(TestCaseBase): def test_attr_syntax(self): output = [ @@ -756,12 +700,6 @@ class AttributesStrictTestCase(TestCaseBase): [("starttag", "html", [("foo", "\u20AC&aa&unsupported;")])]) - -class AttributesTolerantTestCase(AttributesStrictTestCase): - - def get_collector(self): - return EventCollector(convert_charrefs=False) - def test_attr_funky_names2(self): self._run_check( "<a $><b $=%><c \=/>", diff --git a/Lib/test/test_http_cookies.py b/Lib/test/test_http_cookies.py index c7b680b..d3e06a4 100644 --- a/Lib/test/test_http_cookies.py +++ b/Lib/test/test_http_cookies.py @@ -1,5 +1,6 @@ # Simple test suite for http/cookies.py +import copy from test.support import run_unittest, run_doctest, check_warnings import unittest from http import cookies @@ -154,13 +155,6 @@ class CookieTests(unittest.TestCase): self.assertEqual(C['eggs']['httponly'], 'foo') self.assertEqual(C['eggs']['secure'], 'bar') - def test_bad_attrs(self): - # issue 16611: make sure we don't break backward compatibility. - C = cookies.SimpleCookie() - C.load('cookie=with; invalid; version; second=cookie;') - self.assertEqual(C.output(), - 'Set-Cookie: cookie=with\r\nSet-Cookie: second=cookie') - def test_extra_spaces(self): C = cookies.SimpleCookie() C.load('eggs = scrambled ; secure ; path = bar ; foo=foo ') @@ -195,7 +189,10 @@ class CookieTests(unittest.TestCase): def test_invalid_cookies(self): # Accepting these could be a security issue C = cookies.SimpleCookie() - for s in (']foo=x', '[foo=x', 'blah]foo=x', 'blah[foo=x'): + for s in (']foo=x', '[foo=x', 'blah]foo=x', 'blah[foo=x', + 'Set-Cookie: foo=bar', 'Set-Cookie: foo', + 'foo=bar; baz', 'baz; foo=bar', + 'secure;foo=bar', 'Version=1;foo=bar'): C.load(s) self.assertEqual(dict(C), {}) self.assertEqual(C.output(), '') @@ -217,6 +214,15 @@ class CookieTests(unittest.TestCase): class MorselTests(unittest.TestCase): """Tests for the Morsel object.""" + def test_defaults(self): + morsel = cookies.Morsel() + self.assertIsNone(morsel.key) + self.assertIsNone(morsel.value) + self.assertIsNone(morsel.coded_value) + self.assertEqual(morsel.keys(), cookies.Morsel._reserved.keys()) + for key, val in morsel.items(): + self.assertEqual(val, '', key) + def test_reserved_keys(self): M = cookies.Morsel() # tests valid and invalid reserved keys for Morsels @@ -260,6 +266,197 @@ class MorselTests(unittest.TestCase): self.assertRaises(cookies.CookieError, M.set, i, '%s_value' % i, '%s_value' % i) + def test_deprecation(self): + morsel = cookies.Morsel() + with self.assertWarnsRegex(DeprecationWarning, r'\bkey\b'): + morsel.key = '' + with self.assertWarnsRegex(DeprecationWarning, r'\bvalue\b'): + morsel.value = '' + with self.assertWarnsRegex(DeprecationWarning, r'\bcoded_value\b'): + morsel.coded_value = '' + with self.assertWarnsRegex(DeprecationWarning, r'\bLegalChars\b'): + morsel.set('key', 'value', 'coded_value', LegalChars='.*') + + def test_eq(self): + base_case = ('key', 'value', '"value"') + attribs = { + 'path': '/', + 'comment': 'foo', + 'domain': 'example.com', + 'version': 2, + } + morsel_a = cookies.Morsel() + morsel_a.update(attribs) + morsel_a.set(*base_case) + morsel_b = cookies.Morsel() + morsel_b.update(attribs) + morsel_b.set(*base_case) + self.assertTrue(morsel_a == morsel_b) + self.assertFalse(morsel_a != morsel_b) + cases = ( + ('key', 'value', 'mismatch'), + ('key', 'mismatch', '"value"'), + ('mismatch', 'value', '"value"'), + ) + for case_b in cases: + with self.subTest(case_b): + morsel_b = cookies.Morsel() + morsel_b.update(attribs) + morsel_b.set(*case_b) + self.assertFalse(morsel_a == morsel_b) + self.assertTrue(morsel_a != morsel_b) + + morsel_b = cookies.Morsel() + morsel_b.update(attribs) + morsel_b.set(*base_case) + morsel_b['comment'] = 'bar' + self.assertFalse(morsel_a == morsel_b) + self.assertTrue(morsel_a != morsel_b) + + # test mismatched types + self.assertFalse(cookies.Morsel() == 1) + self.assertTrue(cookies.Morsel() != 1) + self.assertFalse(cookies.Morsel() == '') + self.assertTrue(cookies.Morsel() != '') + items = list(cookies.Morsel().items()) + self.assertFalse(cookies.Morsel() == items) + self.assertTrue(cookies.Morsel() != items) + + # morsel/dict + morsel = cookies.Morsel() + morsel.set(*base_case) + morsel.update(attribs) + self.assertTrue(morsel == dict(morsel)) + self.assertFalse(morsel != dict(morsel)) + + def test_copy(self): + morsel_a = cookies.Morsel() + morsel_a.set('foo', 'bar', 'baz') + morsel_a.update({ + 'version': 2, + 'comment': 'foo', + }) + morsel_b = morsel_a.copy() + self.assertIsInstance(morsel_b, cookies.Morsel) + self.assertIsNot(morsel_a, morsel_b) + self.assertEqual(morsel_a, morsel_b) + + morsel_b = copy.copy(morsel_a) + self.assertIsInstance(morsel_b, cookies.Morsel) + self.assertIsNot(morsel_a, morsel_b) + self.assertEqual(morsel_a, morsel_b) + + def test_setitem(self): + morsel = cookies.Morsel() + morsel['expires'] = 0 + self.assertEqual(morsel['expires'], 0) + morsel['Version'] = 2 + self.assertEqual(morsel['version'], 2) + morsel['DOMAIN'] = 'example.com' + self.assertEqual(morsel['domain'], 'example.com') + + with self.assertRaises(cookies.CookieError): + morsel['invalid'] = 'value' + self.assertNotIn('invalid', morsel) + + def test_setdefault(self): + morsel = cookies.Morsel() + morsel.update({ + 'domain': 'example.com', + 'version': 2, + }) + # this shouldn't override the default value + self.assertEqual(morsel.setdefault('expires', 'value'), '') + self.assertEqual(morsel['expires'], '') + self.assertEqual(morsel.setdefault('Version', 1), 2) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel.setdefault('DOMAIN', 'value'), 'example.com') + self.assertEqual(morsel['domain'], 'example.com') + + with self.assertRaises(cookies.CookieError): + morsel.setdefault('invalid', 'value') + self.assertNotIn('invalid', morsel) + + def test_update(self): + attribs = {'expires': 1, 'Version': 2, 'DOMAIN': 'example.com'} + # test dict update + morsel = cookies.Morsel() + morsel.update(attribs) + self.assertEqual(morsel['expires'], 1) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel['domain'], 'example.com') + # test iterable update + morsel = cookies.Morsel() + morsel.update(list(attribs.items())) + self.assertEqual(morsel['expires'], 1) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel['domain'], 'example.com') + # test iterator update + morsel = cookies.Morsel() + morsel.update((k, v) for k, v in attribs.items()) + self.assertEqual(morsel['expires'], 1) + self.assertEqual(morsel['version'], 2) + self.assertEqual(morsel['domain'], 'example.com') + + with self.assertRaises(cookies.CookieError): + morsel.update({'invalid': 'value'}) + self.assertNotIn('invalid', morsel) + self.assertRaises(TypeError, morsel.update) + self.assertRaises(TypeError, morsel.update, 0) + + def test_pickle(self): + morsel_a = cookies.Morsel() + morsel_a.set('foo', 'bar', 'baz') + morsel_a.update({ + 'version': 2, + 'comment': 'foo', + }) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + morsel_b = pickle.loads(pickle.dumps(morsel_a, proto)) + self.assertIsInstance(morsel_b, cookies.Morsel) + self.assertEqual(morsel_b, morsel_a) + self.assertEqual(str(morsel_b), str(morsel_a)) + + def test_repr(self): + morsel = cookies.Morsel() + self.assertEqual(repr(morsel), '<Morsel: None=None>') + self.assertEqual(str(morsel), 'Set-Cookie: None=None') + morsel.set('key', 'val', 'coded_val') + self.assertEqual(repr(morsel), '<Morsel: key=coded_val>') + self.assertEqual(str(morsel), 'Set-Cookie: key=coded_val') + morsel.update({ + 'path': '/', + 'comment': 'foo', + 'domain': 'example.com', + 'max-age': 0, + 'secure': 0, + 'version': 1, + }) + self.assertEqual(repr(morsel), + '<Morsel: key=coded_val; Comment=foo; Domain=example.com; ' + 'Max-Age=0; Path=/; Version=1>') + self.assertEqual(str(morsel), + 'Set-Cookie: key=coded_val; Comment=foo; Domain=example.com; ' + 'Max-Age=0; Path=/; Version=1') + morsel['secure'] = True + morsel['httponly'] = 1 + self.assertEqual(repr(morsel), + '<Morsel: key=coded_val; Comment=foo; Domain=example.com; ' + 'HttpOnly; Max-Age=0; Path=/; Secure; Version=1>') + self.assertEqual(str(morsel), + 'Set-Cookie: key=coded_val; Comment=foo; Domain=example.com; ' + 'HttpOnly; Max-Age=0; Path=/; Secure; Version=1') + + morsel = cookies.Morsel() + morsel.set('key', 'val', 'coded_val') + morsel['expires'] = 0 + self.assertRegex(repr(morsel), + r'<Morsel: key=coded_val; ' + r'expires=\w+, \d+ \w+ \d+ \d+:\d+:\d+ \w+>') + self.assertRegex(str(morsel), + r'Set-Cookie: key=coded_val; ' + r'expires=\w+, \d+ \w+ \d+ \d+:\d+:\d+ \w+') def test_main(): run_unittest(CookieTests, MorselTests) diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py index df9a9e3..d809414 100644 --- a/Lib/test/test_httplib.py +++ b/Lib/test/test_httplib.py @@ -19,6 +19,26 @@ CERT_fakehostname = os.path.join(here, 'keycert2.pem') # Self-signed cert file for self-signed.pythontest.net CERT_selfsigned_pythontestdotnet = os.path.join(here, 'selfsigned_pythontestdotnet.pem') +# constants for testing chunked encoding +chunked_start = ( + 'HTTP/1.1 200 OK\r\n' + 'Transfer-Encoding: chunked\r\n\r\n' + 'a\r\n' + 'hello worl\r\n' + '3\r\n' + 'd! \r\n' + '8\r\n' + 'and now \r\n' + '22\r\n' + 'for something completely different\r\n' +) +chunked_expected = b'hello world! and now for something completely different' +chunk_extension = ";foo=bar" +last_chunk = "0\r\n" +last_chunk_extended = "0" + chunk_extension + "\r\n" +trailers = "X-Dummy: foo\r\nX-Dumm2: bar\r\n" +chunked_end = "\r\n" + HOST = support.HOST class FakeSocket: @@ -51,6 +71,9 @@ class FakeSocket: def close(self): pass + def setsockopt(self, level, optname, value): + pass + class EPipeSocket(FakeSocket): def __init__(self, text, pipe_trigger): @@ -84,6 +107,23 @@ class NoEOFBytesIO(io.BytesIO): raise AssertionError('caller tried to read past EOF') return data +class FakeSocketHTTPConnection(client.HTTPConnection): + """HTTPConnection subclass using FakeSocket; counts connect() calls""" + + def __init__(self, *args): + self.connections = 0 + super().__init__('example.com') + self.fake_socket_args = args + self._create_connection = self.create_connection + + def connect(self): + """Count the number of times connect() is invoked""" + self.connections += 1 + return super().connect() + + def create_connection(self, *pos, **kw): + return FakeSocket(*self.fake_socket_args) + class HeaderTests(TestCase): def test_auto_headers(self): # Some headers are added automatically, but should not be added by @@ -548,20 +588,8 @@ class BasicTest(TestCase): conn.request('POST', 'test', conn) def test_chunked(self): - chunked_start = ( - 'HTTP/1.1 200 OK\r\n' - 'Transfer-Encoding: chunked\r\n\r\n' - 'a\r\n' - 'hello worl\r\n' - '3\r\n' - 'd! \r\n' - '8\r\n' - 'and now \r\n' - '22\r\n' - 'for something completely different\r\n' - ) - expected = b'hello world! and now for something completely different' - sock = FakeSocket(chunked_start + '0\r\n') + expected = chunked_expected + sock = FakeSocket(chunked_start + last_chunk + chunked_end) resp = client.HTTPResponse(sock, method="GET") resp.begin() self.assertEqual(resp.read(), expected) @@ -569,7 +597,7 @@ class BasicTest(TestCase): # Various read sizes for n in range(1, 12): - sock = FakeSocket(chunked_start + '0\r\n') + sock = FakeSocket(chunked_start + last_chunk + chunked_end) resp = client.HTTPResponse(sock, method="GET") resp.begin() self.assertEqual(resp.read(n) + resp.read(n) + resp.read(), expected) @@ -592,23 +620,12 @@ class BasicTest(TestCase): resp.close() def test_readinto_chunked(self): - chunked_start = ( - 'HTTP/1.1 200 OK\r\n' - 'Transfer-Encoding: chunked\r\n\r\n' - 'a\r\n' - 'hello worl\r\n' - '3\r\n' - 'd! \r\n' - '8\r\n' - 'and now \r\n' - '22\r\n' - 'for something completely different\r\n' - ) - expected = b'hello world! and now for something completely different' + + expected = chunked_expected nexpected = len(expected) b = bytearray(128) - sock = FakeSocket(chunked_start + '0\r\n') + sock = FakeSocket(chunked_start + last_chunk + chunked_end) resp = client.HTTPResponse(sock, method="GET") resp.begin() n = resp.readinto(b) @@ -618,7 +635,7 @@ class BasicTest(TestCase): # Various read sizes for n in range(1, 12): - sock = FakeSocket(chunked_start + '0\r\n') + sock = FakeSocket(chunked_start + last_chunk + chunked_end) resp = client.HTTPResponse(sock, method="GET") resp.begin() m = memoryview(b) @@ -654,7 +671,7 @@ class BasicTest(TestCase): '1\r\n' 'd\r\n' ) - sock = FakeSocket(chunked_start + '0\r\n') + sock = FakeSocket(chunked_start + last_chunk + chunked_end) resp = client.HTTPResponse(sock, method="HEAD") resp.begin() self.assertEqual(resp.read(), b'') @@ -674,7 +691,7 @@ class BasicTest(TestCase): '1\r\n' 'd\r\n' ) - sock = FakeSocket(chunked_start + '0\r\n') + sock = FakeSocket(chunked_start + last_chunk + chunked_end) resp = client.HTTPResponse(sock, method="HEAD") resp.begin() b = bytearray(5) @@ -749,6 +766,7 @@ class BasicTest(TestCase): + '0' * 65536 + 'a\r\n' 'hello world\r\n' '0\r\n' + '\r\n' ) resp = client.HTTPResponse(FakeSocket(body)) resp.begin() @@ -766,28 +784,6 @@ class BasicTest(TestCase): resp.close() self.assertTrue(resp.closed) - def test_delayed_ack_opt(self): - # Test that Nagle/delayed_ack optimistaion works correctly. - - # For small payloads, it should coalesce the body with - # headers, resulting in a single sendall() call - conn = client.HTTPConnection('example.com') - sock = FakeSocket(None) - conn.sock = sock - body = b'x' * (conn.mss - 1) - conn.request('POST', '/', body) - self.assertEqual(sock.sendall_calls, 1) - - # For large payloads, it should send the headers and - # then the body, resulting in more than one sendall() - # call - conn = client.HTTPConnection('example.com') - sock = FakeSocket(None) - conn.sock = sock - body = b'x' * conn.mss - conn.request('POST', '/', body) - self.assertGreater(sock.sendall_calls, 1) - def test_error_leak(self): # Test that the socket is not leaked if getresponse() fails conn = client.HTTPConnection('example.com') @@ -798,12 +794,245 @@ class BasicTest(TestCase): response = self # Avoid garbage collector closing the socket client.HTTPResponse.__init__(self, *pos, **kw) conn.response_class = Response - conn.sock = FakeSocket('') # Emulate server dropping connection + conn.sock = FakeSocket('Invalid status line') conn.request('GET', '/') self.assertRaises(client.BadStatusLine, conn.getresponse) self.assertTrue(response.closed) self.assertTrue(conn.sock.file_closed) + def test_chunked_extension(self): + extra = '3;foo=bar\r\n' + 'abc\r\n' + expected = chunked_expected + b'abc' + + sock = FakeSocket(chunked_start + extra + last_chunk_extended + chunked_end) + resp = client.HTTPResponse(sock, method="GET") + resp.begin() + self.assertEqual(resp.read(), expected) + resp.close() + + def test_chunked_missing_end(self): + """some servers may serve up a short chunked encoding stream""" + expected = chunked_expected + sock = FakeSocket(chunked_start + last_chunk) #no terminating crlf + resp = client.HTTPResponse(sock, method="GET") + resp.begin() + self.assertEqual(resp.read(), expected) + resp.close() + + def test_chunked_trailers(self): + """See that trailers are read and ignored""" + expected = chunked_expected + sock = FakeSocket(chunked_start + last_chunk + trailers + chunked_end) + resp = client.HTTPResponse(sock, method="GET") + resp.begin() + self.assertEqual(resp.read(), expected) + # we should have reached the end of the file + self.assertEqual(sock.file.read(100), b"") #we read to the end + resp.close() + + def test_chunked_sync(self): + """Check that we don't read past the end of the chunked-encoding stream""" + expected = chunked_expected + extradata = "extradata" + sock = FakeSocket(chunked_start + last_chunk + trailers + chunked_end + extradata) + resp = client.HTTPResponse(sock, method="GET") + resp.begin() + self.assertEqual(resp.read(), expected) + # the file should now have our extradata ready to be read + self.assertEqual(sock.file.read(100), extradata.encode("ascii")) #we read to the end + resp.close() + + def test_content_length_sync(self): + """Check that we don't read past the end of the Content-Length stream""" + extradata = "extradata" + expected = b"Hello123\r\n" + sock = FakeSocket('HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\nHello123\r\n' + extradata) + resp = client.HTTPResponse(sock, method="GET") + resp.begin() + self.assertEqual(resp.read(), expected) + # the file should now have our extradata ready to be read + self.assertEqual(sock.file.read(100), extradata.encode("ascii")) #we read to the end + resp.close() + +class ExtendedReadTest(TestCase): + """ + Test peek(), read1(), readline() + """ + lines = ( + 'HTTP/1.1 200 OK\r\n' + '\r\n' + 'hello world!\n' + 'and now \n' + 'for something completely different\n' + 'foo' + ) + lines_expected = lines[lines.find('hello'):].encode("ascii") + lines_chunked = ( + 'HTTP/1.1 200 OK\r\n' + 'Transfer-Encoding: chunked\r\n\r\n' + 'a\r\n' + 'hello worl\r\n' + '3\r\n' + 'd!\n\r\n' + '9\r\n' + 'and now \n\r\n' + '23\r\n' + 'for something completely different\n\r\n' + '3\r\n' + 'foo\r\n' + '0\r\n' # terminating chunk + '\r\n' # end of trailers + ) + + def setUp(self): + sock = FakeSocket(self.lines) + resp = client.HTTPResponse(sock, method="GET") + resp.begin() + resp.fp = io.BufferedReader(resp.fp) + self.resp = resp + + + + def test_peek(self): + resp = self.resp + # patch up the buffered peek so that it returns not too much stuff + oldpeek = resp.fp.peek + def mypeek(n=-1): + p = oldpeek(n) + if n >= 0: + return p[:n] + return p[:10] + resp.fp.peek = mypeek + + all = [] + while True: + # try a short peek + p = resp.peek(3) + if p: + self.assertGreater(len(p), 0) + # then unbounded peek + p2 = resp.peek() + self.assertGreaterEqual(len(p2), len(p)) + self.assertTrue(p2.startswith(p)) + next = resp.read(len(p2)) + self.assertEqual(next, p2) + else: + next = resp.read() + self.assertFalse(next) + all.append(next) + if not next: + break + self.assertEqual(b"".join(all), self.lines_expected) + + def test_readline(self): + resp = self.resp + self._verify_readline(self.resp.readline, self.lines_expected) + + def _verify_readline(self, readline, expected): + all = [] + while True: + # short readlines + line = readline(5) + if line and line != b"foo": + if len(line) < 5: + self.assertTrue(line.endswith(b"\n")) + all.append(line) + if not line: + break + self.assertEqual(b"".join(all), expected) + + def test_read1(self): + resp = self.resp + def r(): + res = resp.read1(4) + self.assertLessEqual(len(res), 4) + return res + readliner = Readliner(r) + self._verify_readline(readliner.readline, self.lines_expected) + + def test_read1_unbounded(self): + resp = self.resp + all = [] + while True: + data = resp.read1() + if not data: + break + all.append(data) + self.assertEqual(b"".join(all), self.lines_expected) + + def test_read1_bounded(self): + resp = self.resp + all = [] + while True: + data = resp.read1(10) + if not data: + break + self.assertLessEqual(len(data), 10) + all.append(data) + self.assertEqual(b"".join(all), self.lines_expected) + + def test_read1_0(self): + self.assertEqual(self.resp.read1(0), b"") + + def test_peek_0(self): + p = self.resp.peek(0) + self.assertLessEqual(0, len(p)) + +class ExtendedReadTestChunked(ExtendedReadTest): + """ + Test peek(), read1(), readline() in chunked mode + """ + lines = ( + 'HTTP/1.1 200 OK\r\n' + 'Transfer-Encoding: chunked\r\n\r\n' + 'a\r\n' + 'hello worl\r\n' + '3\r\n' + 'd!\n\r\n' + '9\r\n' + 'and now \n\r\n' + '23\r\n' + 'for something completely different\n\r\n' + '3\r\n' + 'foo\r\n' + '0\r\n' # terminating chunk + '\r\n' # end of trailers + ) + + +class Readliner: + """ + a simple readline class that uses an arbitrary read function and buffering + """ + def __init__(self, readfunc): + self.readfunc = readfunc + self.remainder = b"" + + def readline(self, limit): + data = [] + datalen = 0 + read = self.remainder + try: + while True: + idx = read.find(b'\n') + if idx != -1: + break + if datalen + len(read) >= limit: + idx = limit - datalen - 1 + # read more data + data.append(read) + read = self.readfunc() + if not read: + idx = 0 #eof condition + break + idx += 1 + data.append(read[:idx]) + self.remainder = read[idx:] + return b"".join(data) + except: + self.remainder = b"".join(data) + raise + class OfflineTest(TestCase): def test_all(self): @@ -823,13 +1052,74 @@ class OfflineTest(TestCase): def test_responses(self): self.assertEqual(client.responses[client.NOT_FOUND], "Not Found") + def test_client_constants(self): + # Make sure we don't break backward compatibility with 3.4 + expected = [ + 'CONTINUE', + 'SWITCHING_PROTOCOLS', + 'PROCESSING', + 'OK', + 'CREATED', + 'ACCEPTED', + 'NON_AUTHORITATIVE_INFORMATION', + 'NO_CONTENT', + 'RESET_CONTENT', + 'PARTIAL_CONTENT', + 'MULTI_STATUS', + 'IM_USED', + 'MULTIPLE_CHOICES', + 'MOVED_PERMANENTLY', + 'FOUND', + 'SEE_OTHER', + 'NOT_MODIFIED', + 'USE_PROXY', + 'TEMPORARY_REDIRECT', + 'BAD_REQUEST', + 'UNAUTHORIZED', + 'PAYMENT_REQUIRED', + 'FORBIDDEN', + 'NOT_FOUND', + 'METHOD_NOT_ALLOWED', + 'NOT_ACCEPTABLE', + 'PROXY_AUTHENTICATION_REQUIRED', + 'REQUEST_TIMEOUT', + 'CONFLICT', + 'GONE', + 'LENGTH_REQUIRED', + 'PRECONDITION_FAILED', + 'REQUEST_ENTITY_TOO_LARGE', + 'REQUEST_URI_TOO_LONG', + 'UNSUPPORTED_MEDIA_TYPE', + 'REQUESTED_RANGE_NOT_SATISFIABLE', + 'EXPECTATION_FAILED', + 'UNPROCESSABLE_ENTITY', + 'LOCKED', + 'FAILED_DEPENDENCY', + 'UPGRADE_REQUIRED', + 'PRECONDITION_REQUIRED', + 'TOO_MANY_REQUESTS', + 'REQUEST_HEADER_FIELDS_TOO_LARGE', + 'INTERNAL_SERVER_ERROR', + 'NOT_IMPLEMENTED', + 'BAD_GATEWAY', + 'SERVICE_UNAVAILABLE', + 'GATEWAY_TIMEOUT', + 'HTTP_VERSION_NOT_SUPPORTED', + 'INSUFFICIENT_STORAGE', + 'NOT_EXTENDED', + 'NETWORK_AUTHENTICATION_REQUIRED', + ] + for const in expected: + with self.subTest(constant=const): + self.assertTrue(hasattr(client, const)) + class SourceAddressTest(TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.port = support.bind_port(self.serv) self.source_port = support.find_unused_port() - self.serv.listen(5) + self.serv.listen() self.conn = None def tearDown(self): @@ -861,7 +1151,7 @@ class TimeoutTest(TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) TimeoutTest.PORT = support.bind_port(self.serv) - self.serv.listen(5) + self.serv.listen() def tearDown(self): self.serv.close() @@ -901,6 +1191,78 @@ class TimeoutTest(TestCase): httpConn.close() +class PersistenceTest(TestCase): + + def test_reuse_reconnect(self): + # Should reuse or reconnect depending on header from server + tests = ( + ('1.0', '', False), + ('1.0', 'Connection: keep-alive\r\n', True), + ('1.1', '', True), + ('1.1', 'Connection: close\r\n', False), + ('1.0', 'Connection: keep-ALIVE\r\n', True), + ('1.1', 'Connection: cloSE\r\n', False), + ) + for version, header, reuse in tests: + with self.subTest(version=version, header=header): + msg = ( + 'HTTP/{} 200 OK\r\n' + '{}' + 'Content-Length: 12\r\n' + '\r\n' + 'Dummy body\r\n' + ).format(version, header) + conn = FakeSocketHTTPConnection(msg) + self.assertIsNone(conn.sock) + conn.request('GET', '/open-connection') + with conn.getresponse() as response: + self.assertEqual(conn.sock is None, not reuse) + response.read() + self.assertEqual(conn.sock is None, not reuse) + self.assertEqual(conn.connections, 1) + conn.request('GET', '/subsequent-request') + self.assertEqual(conn.connections, 1 if reuse else 2) + + def test_disconnected(self): + + def make_reset_reader(text): + """Return BufferedReader that raises ECONNRESET at EOF""" + stream = io.BytesIO(text) + def readinto(buffer): + size = io.BytesIO.readinto(stream, buffer) + if size == 0: + raise ConnectionResetError() + return size + stream.readinto = readinto + return io.BufferedReader(stream) + + tests = ( + (io.BytesIO, client.RemoteDisconnected), + (make_reset_reader, ConnectionResetError), + ) + for stream_factory, exception in tests: + with self.subTest(exception=exception): + conn = FakeSocketHTTPConnection(b'', stream_factory) + conn.request('GET', '/eof-response') + self.assertRaises(exception, conn.getresponse) + self.assertIsNone(conn.sock) + # HTTPConnection.connect() should be automatically invoked + conn.request('GET', '/reconnect') + self.assertEqual(conn.connections, 2) + + def test_100_close(self): + conn = FakeSocketHTTPConnection( + b'HTTP/1.1 100 Continue\r\n' + b'\r\n' + # Missing final response + ) + conn.request('GET', '/', headers={'Expect': '100-continue'}) + self.assertRaises(client.RemoteDisconnected, conn.getresponse) + self.assertIsNone(conn.sock) + conn.request('GET', '/reconnect') + self.assertEqual(conn.connections, 2) + + class HTTPSTest(TestCase): def setUp(self): @@ -1171,17 +1533,18 @@ class TunnelTests(TestCase): 'HTTP/1.1 200 OK\r\n' # Reply to HEAD 'Content-Length: 42\r\n\r\n' ) - - def create_connection(address, timeout=None, source_address=None): - return FakeSocket(response_text, host=address[0], port=address[1]) - self.host = 'proxy.com' self.conn = client.HTTPConnection(self.host) - self.conn._create_connection = create_connection + self.conn._create_connection = self._create_connection(response_text) def tearDown(self): self.conn.close() + def _create_connection(self, response_text): + def create_connection(address, timeout=None, source_address=None): + return FakeSocket(response_text, host=address[0], port=address[1]) + return create_connection + def test_set_tunnel_host_port_headers(self): tunnel_host = 'destination.com' tunnel_port = 8888 @@ -1222,13 +1585,27 @@ class TunnelTests(TestCase): self.assertIn(b'CONNECT destination.com', self.conn.sock.data) self.assertIn(b'Host: destination.com', self.conn.sock.data) + def test_tunnel_debuglog(self): + expected_header = 'X-Dummy: 1' + response_text = 'HTTP/1.0 200 OK\r\n{}\r\n\r\n'.format(expected_header) + + self.conn.set_debuglevel(1) + self.conn._create_connection = self._create_connection(response_text) + self.conn.set_tunnel('destination.com') + + with support.captured_stdout() as output: + self.conn.request('PUT', '/', '') + lines = output.getvalue().splitlines() + self.assertIn('header: {}'.format(expected_header), lines) @support.reap_threads def test_main(verbose=None): support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest, + PersistenceTest, HTTPSTest, RequestBodyTest, SourceAddressTest, - HTTPResponseTest, TunnelTests) + HTTPResponseTest, ExtendedReadTest, + ExtendedReadTestChunked, TunnelTests) if __name__ == '__main__': test_main() diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py index 6e5f2db..50244c6 100644 --- a/Lib/test/test_httpservers.py +++ b/Lib/test/test_httpservers.py @@ -6,7 +6,7 @@ Josip Dzolonga, and Michael Otteneder for the 2007/08 GHOP contest. from http.server import BaseHTTPRequestHandler, HTTPServer, \ SimpleHTTPRequestHandler, CGIHTTPRequestHandler -from http import server +from http import server, HTTPStatus import os import sys @@ -79,13 +79,13 @@ class BaseHTTPServerTestCase(BaseTestCase): default_request_version = 'HTTP/1.1' def do_TEST(self): - self.send_response(204) + self.send_response(HTTPStatus.NO_CONTENT) self.send_header('Content-Type', 'text/html') self.send_header('Connection', 'close') self.end_headers() def do_KEEP(self): - self.send_response(204) + self.send_response(HTTPStatus.NO_CONTENT) self.send_header('Content-Type', 'text/html') self.send_header('Connection', 'keep-alive') self.end_headers() @@ -94,7 +94,7 @@ class BaseHTTPServerTestCase(BaseTestCase): self.send_error(999) def do_NOTFOUND(self): - self.send_error(404) + self.send_error(HTTPStatus.NOT_FOUND) def do_EXPLAINERROR(self): self.send_error(999, "Short Message", @@ -122,35 +122,35 @@ class BaseHTTPServerTestCase(BaseTestCase): def test_command(self): self.con.request('GET', '/') res = self.con.getresponse() - self.assertEqual(res.status, 501) + self.assertEqual(res.status, HTTPStatus.NOT_IMPLEMENTED) def test_request_line_trimming(self): self.con._http_vsn_str = 'HTTP/1.1\n' self.con.putrequest('XYZBOGUS', '/') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 501) + self.assertEqual(res.status, HTTPStatus.NOT_IMPLEMENTED) def test_version_bogus(self): self.con._http_vsn_str = 'FUBAR' self.con.putrequest('GET', '/') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 400) + self.assertEqual(res.status, HTTPStatus.BAD_REQUEST) def test_version_digits(self): self.con._http_vsn_str = 'HTTP/9.9.9' self.con.putrequest('GET', '/') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 400) + self.assertEqual(res.status, HTTPStatus.BAD_REQUEST) def test_version_none_get(self): self.con._http_vsn_str = '' self.con.putrequest('GET', '/') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 501) + self.assertEqual(res.status, HTTPStatus.NOT_IMPLEMENTED) def test_version_none(self): # Test that a valid method is rejected when not HTTP/1.x @@ -158,7 +158,7 @@ class BaseHTTPServerTestCase(BaseTestCase): self.con.putrequest('CUSTOM', '/') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 400) + self.assertEqual(res.status, HTTPStatus.BAD_REQUEST) def test_version_invalid(self): self.con._http_vsn = 99 @@ -166,21 +166,21 @@ class BaseHTTPServerTestCase(BaseTestCase): self.con.putrequest('GET', '/') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 505) + self.assertEqual(res.status, HTTPStatus.HTTP_VERSION_NOT_SUPPORTED) def test_send_blank(self): self.con._http_vsn_str = '' self.con.putrequest('', '') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 400) + self.assertEqual(res.status, HTTPStatus.BAD_REQUEST) def test_header_close(self): self.con.putrequest('GET', '/') self.con.putheader('Connection', 'close') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 501) + self.assertEqual(res.status, HTTPStatus.NOT_IMPLEMENTED) def test_head_keep_alive(self): self.con._http_vsn_str = 'HTTP/1.1' @@ -188,12 +188,12 @@ class BaseHTTPServerTestCase(BaseTestCase): self.con.putheader('Connection', 'keep-alive') self.con.endheaders() res = self.con.getresponse() - self.assertEqual(res.status, 501) + self.assertEqual(res.status, HTTPStatus.NOT_IMPLEMENTED) def test_handler(self): self.con.request('TEST', '/') res = self.con.getresponse() - self.assertEqual(res.status, 204) + self.assertEqual(res.status, HTTPStatus.NO_CONTENT) def test_return_header_keep_alive(self): self.con.request('KEEP', '/') @@ -230,11 +230,48 @@ class BaseHTTPServerTestCase(BaseTestCase): # Issue #16088: standard error responses should have a content-length self.con.request('NOTFOUND', '/') res = self.con.getresponse() - self.assertEqual(res.status, 404) + self.assertEqual(res.status, HTTPStatus.NOT_FOUND) + data = res.read() self.assertEqual(int(res.getheader('Content-Length')), len(data)) +class RequestHandlerLoggingTestCase(BaseTestCase): + class request_handler(BaseHTTPRequestHandler): + protocol_version = 'HTTP/1.1' + default_request_version = 'HTTP/1.1' + + def do_GET(self): + self.send_response(HTTPStatus.OK) + self.end_headers() + + def do_ERROR(self): + self.send_error(HTTPStatus.NOT_FOUND, 'File not found') + + def test_get(self): + self.con = http.client.HTTPConnection(self.HOST, self.PORT) + self.con.connect() + + with support.captured_stderr() as err: + self.con.request('GET', '/') + self.con.getresponse() + + self.assertTrue( + err.getvalue().endswith('"GET / HTTP/1.1" 200 -\n')) + + def test_err(self): + self.con = http.client.HTTPConnection(self.HOST, self.PORT) + self.con.connect() + + with support.captured_stderr() as err: + self.con.request('ERROR', '/') + self.con.getresponse() + + lines = err.getvalue().split('\n') + self.assertTrue(lines[0].endswith('code 404, message File not found')) + self.assertTrue(lines[1].endswith('"ERROR / HTTP/1.1" 404 -')) + + class SimpleHTTPServerTestCase(BaseTestCase): class request_handler(NoLogRequestHandler, SimpleHTTPRequestHandler): pass @@ -261,12 +298,28 @@ class SimpleHTTPServerTestCase(BaseTestCase): BaseTestCase.tearDown(self) def check_status_and_reason(self, response, status, data=None): + def close_conn(): + """Don't close reader yet so we can check if there was leftover + buffered input""" + nonlocal reader + reader = response.fp + response.fp = None + reader = None + response._close_conn = close_conn + body = response.read() self.assertTrue(response) self.assertEqual(response.status, status) self.assertIsNotNone(response.reason) if data: self.assertEqual(data, body) + # Ensure the server has not set up a persistent connection, and has + # not sent any extra data + self.assertEqual(response.version, 10) + self.assertEqual(response.msg.get("Connection", "close"), "close") + self.assertEqual(reader.read(30), b'', 'Connection should be closed') + + reader.close() return body @support.requires_mac_ver(10, 5) @@ -285,52 +338,58 @@ class SimpleHTTPServerTestCase(BaseTestCase): if name != 'test': # Ignore a filename created in setUp(). filename = name break - body = self.check_status_and_reason(response, 200) + body = self.check_status_and_reason(response, HTTPStatus.OK) quotedname = urllib.parse.quote(filename, errors='surrogatepass') self.assertIn(('href="%s"' % quotedname) .encode(enc, 'surrogateescape'), body) self.assertIn(('>%s<' % html.escape(filename)) .encode(enc, 'surrogateescape'), body) response = self.request(self.tempdir_name + '/' + quotedname) - self.check_status_and_reason(response, 200, + self.check_status_and_reason(response, HTTPStatus.OK, data=support.TESTFN_UNDECODABLE) def test_get(self): #constructs the path relative to the root directory of the HTTPServer response = self.request(self.tempdir_name + '/test') - self.check_status_and_reason(response, 200, data=self.data) + self.check_status_and_reason(response, HTTPStatus.OK, data=self.data) # check for trailing "/" which should return 404. See Issue17324 response = self.request(self.tempdir_name + '/test/') - self.check_status_and_reason(response, 404) + self.check_status_and_reason(response, HTTPStatus.NOT_FOUND) response = self.request(self.tempdir_name + '/') - self.check_status_and_reason(response, 200) + self.check_status_and_reason(response, HTTPStatus.OK) response = self.request(self.tempdir_name) - self.check_status_and_reason(response, 301) + self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY) response = self.request(self.tempdir_name + '/?hi=2') - self.check_status_and_reason(response, 200) + self.check_status_and_reason(response, HTTPStatus.OK) response = self.request(self.tempdir_name + '?hi=1') - self.check_status_and_reason(response, 301) + self.check_status_and_reason(response, HTTPStatus.MOVED_PERMANENTLY) self.assertEqual(response.getheader("Location"), self.tempdir_name + "/?hi=1") response = self.request('/ThisDoesNotExist') - self.check_status_and_reason(response, 404) + self.check_status_and_reason(response, HTTPStatus.NOT_FOUND) response = self.request('/' + 'ThisDoesNotExist' + '/') - self.check_status_and_reason(response, 404) - with open(os.path.join(self.tempdir_name, 'index.html'), 'w') as f: - response = self.request('/' + self.tempdir_name + '/') - self.check_status_and_reason(response, 200) - # chmod() doesn't work as expected on Windows, and filesystem - # permissions are ignored by root on Unix. - if os.name == 'posix' and os.geteuid() != 0: - os.chmod(self.tempdir, 0) + self.check_status_and_reason(response, HTTPStatus.NOT_FOUND) + + data = b"Dummy index file\r\n" + with open(os.path.join(self.tempdir_name, 'index.html'), 'wb') as f: + f.write(data) + response = self.request('/' + self.tempdir_name + '/') + self.check_status_and_reason(response, HTTPStatus.OK, data) + + # chmod() doesn't work as expected on Windows, and filesystem + # permissions are ignored by root on Unix. + if os.name == 'posix' and os.geteuid() != 0: + os.chmod(self.tempdir, 0) + try: response = self.request(self.tempdir_name + '/') - self.check_status_and_reason(response, 404) + self.check_status_and_reason(response, HTTPStatus.NOT_FOUND) + finally: os.chmod(self.tempdir, 0o755) def test_head(self): response = self.request( self.tempdir_name + '/test', method='HEAD') - self.check_status_and_reason(response, 200) + self.check_status_and_reason(response, HTTPStatus.OK) self.assertEqual(response.getheader('content-length'), str(len(self.data))) self.assertEqual(response.getheader('content-type'), @@ -338,12 +397,12 @@ class SimpleHTTPServerTestCase(BaseTestCase): def test_invalid_requests(self): response = self.request('/', method='FOO') - self.check_status_and_reason(response, 501) + self.check_status_and_reason(response, HTTPStatus.NOT_IMPLEMENTED) # requests must be case sensitive,so this should fail too response = self.request('/', method='custom') - self.check_status_and_reason(response, 501) + self.check_status_and_reason(response, HTTPStatus.NOT_IMPLEMENTED) response = self.request('/', method='GETs') - self.check_status_and_reason(response, 501) + self.check_status_and_reason(response, HTTPStatus.NOT_IMPLEMENTED) cgi_file1 = """\ @@ -508,12 +567,13 @@ class CGIHTTPServerTestCase(BaseTestCase): def test_headers_and_content(self): res = self.request('/cgi-bin/file1.py') - self.assertEqual((b'Hello World' + self.linesep, 'text/html', 200), - (res.read(), res.getheader('Content-type'), res.status)) + self.assertEqual( + (res.read(), res.getheader('Content-type'), res.status), + (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK)) def test_issue19435(self): res = self.request('///////////nocgi.py/../cgi-bin/nothere.sh') - self.assertEqual(res.status, 404) + self.assertEqual(res.status, HTTPStatus.NOT_FOUND) def test_post(self): params = urllib.parse.urlencode( @@ -526,38 +586,43 @@ class CGIHTTPServerTestCase(BaseTestCase): def test_invaliduri(self): res = self.request('/cgi-bin/invalid') res.read() - self.assertEqual(res.status, 404) + self.assertEqual(res.status, HTTPStatus.NOT_FOUND) def test_authorization(self): headers = {b'Authorization' : b'Basic ' + base64.b64encode(b'username:pass')} res = self.request('/cgi-bin/file1.py', 'GET', headers=headers) - self.assertEqual((b'Hello World' + self.linesep, 'text/html', 200), - (res.read(), res.getheader('Content-type'), res.status)) + self.assertEqual( + (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK), + (res.read(), res.getheader('Content-type'), res.status)) def test_no_leading_slash(self): # http://bugs.python.org/issue2254 res = self.request('cgi-bin/file1.py') - self.assertEqual((b'Hello World' + self.linesep, 'text/html', 200), - (res.read(), res.getheader('Content-type'), res.status)) + self.assertEqual( + (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK), + (res.read(), res.getheader('Content-type'), res.status)) def test_os_environ_is_not_altered(self): signature = "Test CGI Server" os.environ['SERVER_SOFTWARE'] = signature res = self.request('/cgi-bin/file1.py') - self.assertEqual((b'Hello World' + self.linesep, 'text/html', 200), - (res.read(), res.getheader('Content-type'), res.status)) + self.assertEqual( + (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK), + (res.read(), res.getheader('Content-type'), res.status)) self.assertEqual(os.environ['SERVER_SOFTWARE'], signature) def test_urlquote_decoding_in_cgi_check(self): res = self.request('/cgi-bin%2ffile1.py') - self.assertEqual((b'Hello World' + self.linesep, 'text/html', 200), - (res.read(), res.getheader('Content-type'), res.status)) + self.assertEqual( + (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK), + (res.read(), res.getheader('Content-type'), res.status)) def test_nested_cgi_path_issue21323(self): res = self.request('/cgi-bin/child-dir/file3.py') - self.assertEqual((b'Hello World' + self.linesep, 'text/html', 200), - (res.read(), res.getheader('Content-type'), res.status)) + self.assertEqual( + (b'Hello World' + self.linesep, 'text/html', HTTPStatus.OK), + (res.read(), res.getheader('Content-type'), res.status)) def test_query_with_multiple_question_mark(self): res = self.request('/cgi-bin/file4.py?a=b?c=d') @@ -580,7 +645,7 @@ class SocketlessRequestHandler(SimpleHTTPRequestHandler): def do_GET(self): self.get_called = True - self.send_response(200) + self.send_response(HTTPStatus.OK) self.send_header('Content-Type', 'text/html') self.end_headers() self.wfile.write(b'<html><body>Data</body></html>\r\n') @@ -590,7 +655,7 @@ class SocketlessRequestHandler(SimpleHTTPRequestHandler): class RejectingSocketlessRequestHandler(SocketlessRequestHandler): def handle_expect_100(self): - self.send_error(417) + self.send_error(HTTPStatus.EXPECTATION_FAILED) return False @@ -847,6 +912,7 @@ def test_main(verbose=None): cwd = os.getcwd() try: support.run_unittest( + RequestHandlerLoggingTestCase, BaseHTTPRequestHandlerTestCase, BaseHTTPServerTestCase, SimpleHTTPServerTestCase, diff --git a/Lib/test/test_imaplib.py b/Lib/test/test_imaplib.py index 96b4f32..8248656 100644 --- a/Lib/test/test_imaplib.py +++ b/Lib/test/test_imaplib.py @@ -11,7 +11,8 @@ import socketserver import time import calendar -from test.support import reap_threads, verbose, transient_internet, run_with_tz, run_with_locale +from test.support import (reap_threads, verbose, transient_internet, + run_with_tz, run_with_locale) import unittest from datetime import datetime, timezone, timedelta try: @@ -19,8 +20,8 @@ try: except ImportError: ssl = None -CERTFILE = None -CAFILE = None +CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "keycert3.pem") +CAFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "pycacert.pem") class TestImaplib(unittest.TestCase): @@ -41,17 +42,15 @@ class TestImaplib(unittest.TestCase): def test_Internaldate2tuple_issue10941(self): self.assertNotEqual(imaplib.Internaldate2tuple( b'25 (INTERNALDATE "02-Apr-2000 02:30:00 +0000")'), - imaplib.Internaldate2tuple( - b'25 (INTERNALDATE "02-Apr-2000 03:30:00 +0000")')) - - + imaplib.Internaldate2tuple( + b'25 (INTERNALDATE "02-Apr-2000 03:30:00 +0000")')) def timevalues(self): return [2000000000, 2000000000.0, time.localtime(2000000000), (2033, 5, 18, 5, 33, 20, -1, -1, -1), (2033, 5, 18, 5, 33, 20, -1, -1, 1), datetime.fromtimestamp(2000000000, - timezone(timedelta(0, 2*60*60))), + timezone(timedelta(0, 2 * 60 * 60))), '"18-May-2033 05:33:20 +0200"'] @run_with_locale('LC_ALL', 'de_DE', 'fr_FR') @@ -72,7 +71,6 @@ class TestImaplib(unittest.TestCase): if ssl: - class SecureTCPServer(socketserver.TCPServer): def get_request(self): @@ -93,13 +91,17 @@ else: class SimpleIMAPHandler(socketserver.StreamRequestHandler): - timeout = 1 continuation = None capabilities = '' + def setup(self): + super().setup() + self.server.logged = None + def _send(self, message): - if verbose: print("SENT: %r" % message.strip()) + if verbose: + print("SENT: %r" % message.strip()) self.wfile.write(message) def _send_line(self, message): @@ -132,7 +134,8 @@ class SimpleIMAPHandler(socketserver.StreamRequestHandler): if line.endswith(b'\r\n'): break - if verbose: print('GOT: %r' % line.strip()) + if verbose: + print('GOT: %r' % line.strip()) if self.continuation: try: self.continuation.send(line) @@ -144,8 +147,8 @@ class SimpleIMAPHandler(socketserver.StreamRequestHandler): cmd = splitline[1] args = splitline[2:] - if hasattr(self, 'cmd_'+cmd): - continuation = getattr(self, 'cmd_'+cmd)(tag, args) + if hasattr(self, 'cmd_' + cmd): + continuation = getattr(self, 'cmd_' + cmd)(tag, args) if continuation: self.continuation = continuation next(continuation) @@ -153,16 +156,25 @@ class SimpleIMAPHandler(socketserver.StreamRequestHandler): self._send_tagged(tag, 'BAD', cmd + ' unknown') def cmd_CAPABILITY(self, tag, args): - caps = 'IMAP4rev1 ' + self.capabilities if self.capabilities else 'IMAP4rev1' + caps = ('IMAP4rev1 ' + self.capabilities + if self.capabilities + else 'IMAP4rev1') self._send_textline('* CAPABILITY ' + caps) self._send_tagged(tag, 'OK', 'CAPABILITY completed') def cmd_LOGOUT(self, tag, args): + self.server.logged = None self._send_textline('* BYE IMAP4ref1 Server logging out') self._send_tagged(tag, 'OK', 'LOGOUT completed') + def cmd_LOGIN(self, tag, args): + self.server.logged = args[0] + self._send_tagged(tag, 'OK', 'LOGIN completed') -class BaseThreadedNetworkedTests(unittest.TestCase): + +class ThreadedNetworkedTests(unittest.TestCase): + server_class = socketserver.TCPServer + imap_class = imaplib.IMAP4 def make_server(self, addr, hdlr): @@ -172,7 +184,8 @@ class BaseThreadedNetworkedTests(unittest.TestCase): self.server_close() raise - if verbose: print("creating server") + if verbose: + print("creating server") server = MyServer(addr, hdlr) self.assertEqual(server.server_address, server.socket.getsockname()) @@ -188,18 +201,21 @@ class BaseThreadedNetworkedTests(unittest.TestCase): # Short poll interval to make the test finish quickly. # Time between requests is short enough that we won't wake # up spuriously too many times. - kwargs={'poll_interval':0.01}) + kwargs={'poll_interval': 0.01}) t.daemon = True # In case this function raises. t.start() - if verbose: print("server running") + if verbose: + print("server running") return server, t def reap_server(self, server, thread): - if verbose: print("waiting for server") + if verbose: + print("waiting for server") server.shutdown() server.server_close() thread.join() - if verbose: print("done") + if verbose: + print("done") @contextmanager def reaped_server(self, hdlr): @@ -249,6 +265,84 @@ class BaseThreadedNetworkedTests(unittest.TestCase): self.assertRaises(imaplib.IMAP4.abort, self.imap_class, *server.server_address) + class UTF8Server(SimpleIMAPHandler): + capabilities = 'AUTH ENABLE UTF8=ACCEPT' + + def cmd_ENABLE(self, tag, args): + self._send_tagged(tag, 'OK', 'ENABLE successful') + + def cmd_AUTHENTICATE(self, tag, args): + self._send_textline('+') + self.server.response = yield + self._send_tagged(tag, 'OK', 'FAKEAUTH successful') + + @reap_threads + def test_enable_raises_error_if_not_AUTH(self): + with self.reaped_pair(self.UTF8Server) as (server, client): + self.assertFalse(client.utf8_enabled) + self.assertRaises(imaplib.IMAP4.error, client.enable, 'foo') + self.assertFalse(client.utf8_enabled) + + # XXX Also need a test that enable after SELECT raises an error. + + @reap_threads + def test_enable_raises_error_if_no_capability(self): + class NoEnableServer(self.UTF8Server): + capabilities = 'AUTH' + with self.reaped_pair(NoEnableServer) as (server, client): + self.assertRaises(imaplib.IMAP4.error, client.enable, 'foo') + + @reap_threads + def test_enable_UTF8_raises_error_if_not_supported(self): + class NonUTF8Server(SimpleIMAPHandler): + pass + with self.assertRaises(imaplib.IMAP4.error): + with self.reaped_pair(NonUTF8Server) as (server, client): + typ, data = client.login('user', 'pass') + self.assertEqual(typ, 'OK') + client.enable('UTF8=ACCEPT') + pass + + @reap_threads + def test_enable_UTF8_True_append(self): + + class UTF8AppendServer(self.UTF8Server): + def cmd_APPEND(self, tag, args): + self._send_textline('+') + self.server.response = yield + self._send_tagged(tag, 'OK', 'okay') + + with self.reaped_pair(UTF8AppendServer) as (server, client): + self.assertEqual(client._encoding, 'ascii') + code, _ = client.authenticate('MYAUTH', lambda x: b'fake') + self.assertEqual(code, 'OK') + self.assertEqual(server.response, + b'ZmFrZQ==\r\n') # b64 encoded 'fake' + code, _ = client.enable('UTF8=ACCEPT') + self.assertEqual(code, 'OK') + self.assertEqual(client._encoding, 'utf-8') + msg_string = 'Subject: üñí©öðé' + typ, data = client.append( + None, None, None, msg_string.encode('utf-8')) + self.assertEqual(typ, 'OK') + self.assertEqual( + server.response, + ('UTF8 (%s)\r\n' % msg_string).encode('utf-8') + ) + + # XXX also need a test that makes sure that the Literal and Untagged_status + # regexes uses unicode in UTF8 mode instead of the default ASCII. + + @reap_threads + def test_search_disallows_charset_in_utf8_mode(self): + with self.reaped_pair(self.UTF8Server) as (server, client): + typ, _ = client.authenticate('MYAUTH', lambda x: b'fake') + self.assertEqual(typ, 'OK') + typ, _ = client.enable('UTF8=ACCEPT') + self.assertEqual(typ, 'OK') + self.assertTrue(client.utf8_enabled) + self.assertRaises(imaplib.IMAP4.error, client.search, 'foo', 'bar') + @reap_threads def test_bad_auth_name(self): @@ -256,7 +350,7 @@ class BaseThreadedNetworkedTests(unittest.TestCase): def cmd_AUTHENTICATE(self, tag, args): self._send_tagged(tag, 'NO', 'unrecognized authentication ' - 'type {}'.format(args[0])) + 'type {}'.format(args[0])) with self.reaped_pair(MyServer) as (server, client): with self.assertRaises(imaplib.IMAP4.error): @@ -290,13 +384,13 @@ class BaseThreadedNetworkedTests(unittest.TestCase): code, data = client.authenticate('MYAUTH', lambda x: b'fake') self.assertEqual(code, 'OK') self.assertEqual(server.response, - b'ZmFrZQ==\r\n') #b64 encoded 'fake' + b'ZmFrZQ==\r\n') # b64 encoded 'fake' with self.reaped_pair(MyServer) as (server, client): code, data = client.authenticate('MYAUTH', lambda x: 'fake') self.assertEqual(code, 'OK') self.assertEqual(server.response, - b'ZmFrZQ==\r\n') #b64 encoded 'fake' + b'ZmFrZQ==\r\n') # b64 encoded 'fake' @reap_threads def test_login_cram_md5(self): @@ -307,9 +401,10 @@ class BaseThreadedNetworkedTests(unittest.TestCase): def cmd_AUTHENTICATE(self, tag, args): self._send_textline('+ PDE4OTYuNjk3MTcwOTUyQHBvc3RvZmZpY2Uucm' - 'VzdG9uLm1jaS5uZXQ=') + 'VzdG9uLm1jaS5uZXQ=') r = yield - if r == b'dGltIGYxY2E2YmU0NjRiOWVmYTFjY2E2ZmZkNmNmMmQ5ZjMy\r\n': + if (r == b'dGltIGYxY2E2YmU0NjRiOWVmYT' + b'FjY2E2ZmZkNmNmMmQ5ZjMy\r\n'): self._send_tagged(tag, 'OK', 'CRAM-MD5 successful') else: self._send_tagged(tag, 'NO', 'No access') @@ -325,7 +420,6 @@ class BaseThreadedNetworkedTests(unittest.TestCase): self.assertEqual(ret, "OK") - @reap_threads def test_aborted_authentication(self): @@ -344,26 +438,46 @@ class BaseThreadedNetworkedTests(unittest.TestCase): with self.assertRaises(imaplib.IMAP4.error): code, data = client.authenticate('MYAUTH', lambda x: None) + def test_linetoolong(self): class TooLongHandler(SimpleIMAPHandler): def handle(self): # Send a very long response line - self.wfile.write(b'* OK ' + imaplib._MAXLINE*b'x' + b'\r\n') + self.wfile.write(b'* OK ' + imaplib._MAXLINE * b'x' + b'\r\n') with self.reaped_server(TooLongHandler) as server: self.assertRaises(imaplib.IMAP4.error, self.imap_class, *server.server_address) + @reap_threads + def test_simple_with_statement(self): + # simplest call + with self.reaped_server(SimpleIMAPHandler) as server: + with self.imap_class(*server.server_address): + pass -class ThreadedNetworkedTests(BaseThreadedNetworkedTests): + @reap_threads + def test_with_statement(self): + with self.reaped_server(SimpleIMAPHandler) as server: + with self.imap_class(*server.server_address) as imap: + imap.login('user', 'pass') + self.assertEqual(server.logged, 'user') + self.assertIsNone(server.logged) - server_class = socketserver.TCPServer - imap_class = imaplib.IMAP4 + @reap_threads + def test_with_statement_logout(self): + # what happens if already logout in the block? + with self.reaped_server(SimpleIMAPHandler) as server: + with self.imap_class(*server.server_address) as imap: + imap.login('user', 'pass') + self.assertEqual(server.logged, 'user') + imap.logout() + self.assertIsNone(server.logged) + self.assertIsNone(server.logged) @unittest.skipUnless(ssl, "SSL not available") -class ThreadedNetworkedTestsSSL(BaseThreadedNetworkedTests): - +class ThreadedNetworkedTestsSSL(ThreadedNetworkedTests): server_class = SecureTCPServer imap_class = IMAP4_SSL @@ -374,8 +488,9 @@ class ThreadedNetworkedTestsSSL(BaseThreadedNetworkedTests): ssl_context.check_hostname = True ssl_context.load_verify_locations(CAFILE) - with self.assertRaisesRegex(ssl.CertificateError, - "hostname '127.0.0.1' doesn't match 'localhost'"): + with self.assertRaisesRegex( + ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): with self.reaped_server(SimpleIMAPHandler) as server: client = self.imap_class(*server.server_address, ssl_context=ssl_context) @@ -387,6 +502,8 @@ class ThreadedNetworkedTestsSSL(BaseThreadedNetworkedTests): client.shutdown() +@unittest.skipUnless( + support.is_resource_enabled('network'), 'network resource disabled') class RemoteIMAPTest(unittest.TestCase): host = 'cyrus.andrew.cmu.edu' port = 143 @@ -420,6 +537,8 @@ class RemoteIMAPTest(unittest.TestCase): @unittest.skipUnless(ssl, "SSL not available") +@unittest.skipUnless( + support.is_resource_enabled('network'), 'network resource disabled') class RemoteIMAP_STARTTLSTest(RemoteIMAPTest): def setUp(self): @@ -473,7 +592,8 @@ class RemoteIMAP_SSLTest(RemoteIMAPTest): def test_logincapa_with_client_ssl_context(self): with transient_internet(self.host): - _server = self.imap_class(self.host, self.port, ssl_context=self.create_ssl_context()) + _server = self.imap_class( + self.host, self.port, ssl_context=self.create_ssl_context()) self.check_logincapa(_server) def test_logout(self): @@ -484,35 +604,15 @@ class RemoteIMAP_SSLTest(RemoteIMAPTest): def test_ssl_context_certfile_exclusive(self): with transient_internet(self.host): - self.assertRaises(ValueError, self.imap_class, self.host, self.port, - certfile=CERTFILE, ssl_context=self.create_ssl_context()) + self.assertRaises( + ValueError, self.imap_class, self.host, self.port, + certfile=CERTFILE, ssl_context=self.create_ssl_context()) def test_ssl_context_keyfile_exclusive(self): with transient_internet(self.host): - self.assertRaises(ValueError, self.imap_class, self.host, self.port, - keyfile=CERTFILE, ssl_context=self.create_ssl_context()) - - -def load_tests(*args): - tests = [TestImaplib] - - if support.is_resource_enabled('network'): - if ssl: - global CERTFILE, CAFILE - CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, - "keycert3.pem") - if not os.path.exists(CERTFILE): - raise support.TestFailed("Can't read certificate files!") - CAFILE = os.path.join(os.path.dirname(__file__) or os.curdir, - "pycacert.pem") - if not os.path.exists(CAFILE): - raise support.TestFailed("Can't read CA file!") - tests.extend([ - ThreadedNetworkedTests, ThreadedNetworkedTestsSSL, - RemoteIMAPTest, RemoteIMAP_SSLTest, RemoteIMAP_STARTTLSTest, - ]) - - return unittest.TestSuite([unittest.makeSuite(test) for test in tests]) + self.assertRaises( + ValueError, self.imap_class, self.host, self.port, + keyfile=CERTFILE, ssl_context=self.create_ssl_context()) if __name__ == "__main__": diff --git a/Lib/test/test_imghdr.py b/Lib/test/test_imghdr.py index 0ad4343..b54daf8 100644 --- a/Lib/test/test_imghdr.py +++ b/Lib/test/test_imghdr.py @@ -16,7 +16,9 @@ TEST_FILES = ( ('python.ras', 'rast'), ('python.sgi', 'rgb'), ('python.tiff', 'tiff'), - ('python.xbm', 'xbm') + ('python.xbm', 'xbm'), + ('python.webp', 'webp'), + ('python.exr', 'exr'), ) class UnseekableIO(io.FileIO): diff --git a/Lib/test/test_imp.py b/Lib/test/test_imp.py index 80b9ec3..ee9ee1a 100644 --- a/Lib/test/test_imp.py +++ b/Lib/test/test_imp.py @@ -3,6 +3,7 @@ try: except ImportError: _thread = None import importlib +import importlib.util import os import os.path import shutil @@ -111,7 +112,6 @@ class ImportTests(unittest.TestCase): del sys.path[0] support.unlink(temp_mod_name + '.py') support.unlink(temp_mod_name + '.pyc') - support.unlink(temp_mod_name + '.pyo') def test_issue5604(self): # Test cannot cover imp.load_compiled function. @@ -194,7 +194,7 @@ class ImportTests(unittest.TestCase): self.assertEqual(package.b, 2) finally: del sys.path[0] - for ext in ('.py', '.pyc', '.pyo'): + for ext in ('.py', '.pyc'): support.unlink(temp_mod_name + ext) support.unlink(init_file_name + ext) support.rmtree(test_package_name) @@ -276,6 +276,29 @@ class ImportTests(unittest.TestCase): self.skipTest("found module doesn't appear to be a C extension") imp.load_module(name, None, *found[1:]) + @requires_load_dynamic + def test_issue24748_load_module_skips_sys_modules_check(self): + name = 'test.imp_dummy' + try: + del sys.modules[name] + except KeyError: + pass + try: + module = importlib.import_module(name) + spec = importlib.util.find_spec('_testmultiphase') + module = imp.load_dynamic(name, spec.origin) + self.assertEqual(module.__name__, name) + self.assertEqual(module.__spec__.name, name) + self.assertEqual(module.__spec__.origin, spec.origin) + self.assertRaises(AttributeError, getattr, module, 'dummy_name') + self.assertEqual(module.int_const, 1969) + self.assertIs(sys.modules[name], module) + finally: + try: + del sys.modules[name] + except KeyError: + pass + @unittest.skipIf(sys.dont_write_bytecode, "test meaningful only when writing bytecode") def test_bug7732(self): @@ -346,56 +369,6 @@ class PEP3147Tests(unittest.TestCase): 'qux.{}.pyc'.format(self.tag)) self.assertEqual(imp.cache_from_source(path, True), expect) - def test_cache_from_source_no_cache_tag(self): - # Non cache tag means NotImplementedError. - with support.swap_attr(sys.implementation, 'cache_tag', None): - with self.assertRaises(NotImplementedError): - imp.cache_from_source('whatever.py') - - def test_cache_from_source_no_dot(self): - # Directory with a dot, filename without dot. - path = os.path.join('foo.bar', 'file') - expect = os.path.join('foo.bar', '__pycache__', - 'file{}.pyc'.format(self.tag)) - self.assertEqual(imp.cache_from_source(path, True), expect) - - def test_cache_from_source_optimized(self): - # Given the path to a .py file, return the path to its PEP 3147 - # defined .pyo file (i.e. under __pycache__). - path = os.path.join('foo', 'bar', 'baz', 'qux.py') - expect = os.path.join('foo', 'bar', 'baz', '__pycache__', - 'qux.{}.pyo'.format(self.tag)) - self.assertEqual(imp.cache_from_source(path, False), expect) - - def test_cache_from_source_cwd(self): - path = 'foo.py' - expect = os.path.join('__pycache__', 'foo.{}.pyc'.format(self.tag)) - self.assertEqual(imp.cache_from_source(path, True), expect) - - def test_cache_from_source_override(self): - # When debug_override is not None, it can be any true-ish or false-ish - # value. - path = os.path.join('foo', 'bar', 'baz.py') - partial_expect = os.path.join('foo', 'bar', '__pycache__', - 'baz.{}.py'.format(self.tag)) - self.assertEqual(imp.cache_from_source(path, []), partial_expect + 'o') - self.assertEqual(imp.cache_from_source(path, [17]), - partial_expect + 'c') - # However if the bool-ishness can't be determined, the exception - # propagates. - class Bearish: - def __bool__(self): raise RuntimeError - with self.assertRaises(RuntimeError): - imp.cache_from_source('/foo/bar/baz.py', Bearish()) - - @unittest.skipUnless(os.sep == '\\' and os.altsep == '/', - 'test meaningful only where os.altsep is defined') - def test_sep_altsep_and_sep_cache_from_source(self): - # Windows path and PEP 3147 where sep is right of altsep. - self.assertEqual( - imp.cache_from_source('\\foo\\bar\\baz/qux.py', True), - '\\foo\\bar\\baz\\__pycache__\\qux.{}.pyc'.format(self.tag)) - @unittest.skipUnless(sys.implementation.cache_tag is not None, 'requires sys.implementation.cache_tag to not be ' 'None') @@ -407,68 +380,6 @@ class PEP3147Tests(unittest.TestCase): expect = os.path.join('foo', 'bar', 'baz', 'qux.py') self.assertEqual(imp.source_from_cache(path), expect) - def test_source_from_cache_no_cache_tag(self): - # If sys.implementation.cache_tag is None, raise NotImplementedError. - path = os.path.join('blah', '__pycache__', 'whatever.pyc') - with support.swap_attr(sys.implementation, 'cache_tag', None): - with self.assertRaises(NotImplementedError): - imp.source_from_cache(path) - - def test_source_from_cache_bad_path(self): - # When the path to a pyc file is not in PEP 3147 format, a ValueError - # is raised. - self.assertRaises( - ValueError, imp.source_from_cache, '/foo/bar/bazqux.pyc') - - def test_source_from_cache_no_slash(self): - # No slashes at all in path -> ValueError - self.assertRaises( - ValueError, imp.source_from_cache, 'foo.cpython-32.pyc') - - def test_source_from_cache_too_few_dots(self): - # Too few dots in final path component -> ValueError - self.assertRaises( - ValueError, imp.source_from_cache, '__pycache__/foo.pyc') - - def test_source_from_cache_too_many_dots(self): - # Too many dots in final path component -> ValueError - self.assertRaises( - ValueError, imp.source_from_cache, - '__pycache__/foo.cpython-32.foo.pyc') - - def test_source_from_cache_no__pycache__(self): - # Another problem with the path -> ValueError - self.assertRaises( - ValueError, imp.source_from_cache, - '/foo/bar/foo.cpython-32.foo.pyc') - - def test_package___file__(self): - try: - m = __import__('pep3147') - except ImportError: - pass - else: - self.fail("pep3147 module already exists: %r" % (m,)) - # Test that a package's __file__ points to the right source directory. - os.mkdir('pep3147') - sys.path.insert(0, os.curdir) - def cleanup(): - if sys.path[0] == os.curdir: - del sys.path[0] - shutil.rmtree('pep3147') - self.addCleanup(cleanup) - # Touch the __init__.py file. - support.create_empty_file('pep3147/__init__.py') - importlib.invalidate_caches() - expected___file__ = os.sep.join(('.', 'pep3147', '__init__.py')) - m = __import__('pep3147') - self.assertEqual(m.__file__, expected___file__, (m.__file__, m.__path__, sys.path, sys.path_importer_cache)) - # Ensure we load the pyc file. - support.unload('pep3147') - m = __import__('pep3147') - support.unload('pep3147') - self.assertEqual(m.__file__, expected___file__, (m.__file__, m.__path__, sys.path, sys.path_importer_cache)) - class NullImporterTests(unittest.TestCase): @unittest.skipIf(support.TESTFN_UNENCODABLE is None, diff --git a/Lib/test/test_import.py b/Lib/test/test_import/__init__.py index b4842c5..14a688d 100644 --- a/Lib/test/test_import.py +++ b/Lib/test/test_import/__init__.py @@ -1,7 +1,7 @@ # We import importlib *ASAP* in order to test #15386 import importlib import importlib.util -from importlib._bootstrap import _get_sourcefile +from importlib._bootstrap_external import _get_sourcefile import builtins import marshal import os @@ -21,8 +21,9 @@ import test.support from test.support import ( EnvironmentVarGuard, TESTFN, check_warnings, forget, is_jython, make_legacy_pyc, rmtree, run_unittest, swap_attr, swap_item, temp_umask, - unlink, unload, create_empty_file, cpython_only, TESTFN_UNENCODABLE) -from test import script_helper + unlink, unload, create_empty_file, cpython_only, TESTFN_UNENCODABLE, + temp_dir) +from test.support import script_helper skip_if_dont_write_bytecode = unittest.skipIf( @@ -32,7 +33,6 @@ skip_if_dont_write_bytecode = unittest.skipIf( def remove_files(name): for f in (name + ".py", name + ".pyc", - name + ".pyo", name + ".pyw", name + "$py.class"): unlink(f) @@ -46,7 +46,7 @@ def _ready_to_import(name=None, source=""): # temporarily clears the module from sys.modules (if any) # reverts or removes the module when cleaning up name = name or "spam" - with script_helper.temp_dir() as tempdir: + with temp_dir() as tempdir: path = script_helper.make_script(tempdir, name, source) old_module = sys.modules.pop(name, None) try: @@ -84,7 +84,6 @@ class ImportTests(unittest.TestCase): def test_with_extension(ext): # The extension is normally ".py", perhaps ".pyw". source = TESTFN + ext - pyo = TESTFN + ".pyo" if is_jython: pyc = TESTFN + "$py.class" else: @@ -115,7 +114,6 @@ class ImportTests(unittest.TestCase): forget(TESTFN) unlink(source) unlink(pyc) - unlink(pyo) sys.path.insert(0, os.curdir) try: @@ -138,7 +136,7 @@ class ImportTests(unittest.TestCase): f.write(']') try: - # Compile & remove .py file; we only need .pyc (or .pyo). + # Compile & remove .py file; we only need .pyc. # Bytecode must be relocated from the PEP 3147 bytecode-only location. py_compile.compile(filename) finally: @@ -252,7 +250,7 @@ class ImportTests(unittest.TestCase): importlib.invalidate_caches() mod = __import__(TESTFN) base, ext = os.path.splitext(mod.__file__) - self.assertIn(ext, ('.pyc', '.pyo')) + self.assertEqual(ext, '.pyc') finally: del sys.path[0] remove_files(TESTFN) @@ -294,7 +292,8 @@ class ImportTests(unittest.TestCase): except OverflowError: self.skipTest("cannot set modification time to large integer") except OSError as e: - if e.errno != getattr(errno, 'EOVERFLOW', None): + if e.errno not in (getattr(errno, 'EOVERFLOW', None), + getattr(errno, 'EINVAL', None)): raise self.skipTest("cannot set modification time to large integer ({})".format(e)) __import__(TESTFN) @@ -325,10 +324,23 @@ class ImportTests(unittest.TestCase): with self.assertRaisesRegex(ImportError, "^cannot import name 'bogus'"): from re import bogus + def test_from_import_AttributeError(self): + # Issue #24492: trying to import an attribute that raises an + # AttributeError should lead to an ImportError. + class AlwaysAttributeError: + def __getattr__(self, _): + raise AttributeError + + module_name = 'test_from_import_AttributeError' + self.addCleanup(unload, module_name) + sys.modules[module_name] = AlwaysAttributeError() + with self.assertRaises(ImportError): + from test_from_import_AttributeError import does_not_exist + @skip_if_dont_write_bytecode class FilePermissionTests(unittest.TestCase): - # tests for file mode on cached .pyc/.pyo files + # tests for file mode on cached .pyc files @unittest.skipUnless(os.name == 'posix', "test meaningful only on posix systems") @@ -339,7 +351,7 @@ class FilePermissionTests(unittest.TestCase): module = __import__(name) if not os.path.exists(cached_path): self.fail("__import__ did not result in creation of " - "either a .pyc or .pyo file") + "a .pyc file") stat_info = os.stat(cached_path) # Check that the umask is respected, and the executable bits @@ -358,7 +370,7 @@ class FilePermissionTests(unittest.TestCase): __import__(name) if not os.path.exists(cached_path): self.fail("__import__ did not result in creation of " - "either a .pyc or .pyo file") + "a .pyc file") stat_info = os.stat(cached_path) self.assertEqual(oct(stat.S_IMODE(stat_info.st_mode)), oct(mode)) @@ -373,7 +385,7 @@ class FilePermissionTests(unittest.TestCase): __import__(name) if not os.path.exists(cached_path): self.fail("__import__ did not result in creation of " - "either a .pyc or .pyo file") + "a .pyc file") stat_info = os.stat(cached_path) expected = mode | 0o200 # Account for fix for issue #6074 @@ -404,10 +416,7 @@ class FilePermissionTests(unittest.TestCase): unlink(path) unload(name) importlib.invalidate_caches() - if __debug__: - bytecode_only = path + "c" - else: - bytecode_only = path + "o" + bytecode_only = path + "c" os.rename(importlib.util.cache_from_source(path), bytecode_only) m = __import__(name) self.assertEqual(m.x, 'rewritten') @@ -568,7 +577,7 @@ class RelativeImportTests(unittest.TestCase): def test_relimport_star(self): # This will import * from .test_import. - from . import relimport + from .. import relimport self.assertTrue(hasattr(relimport, "RelativeImportTests")) def test_issue3221(self): @@ -631,9 +640,7 @@ class OverridingImportBuiltinTests(unittest.TestCase): class PycacheTests(unittest.TestCase): - # Test the various PEP 3147 related behaviors. - - tag = sys.implementation.cache_tag + # Test the various PEP 3147/488-related behaviors. def _clean(self): forget(TESTFN) @@ -658,9 +665,10 @@ class PycacheTests(unittest.TestCase): self.assertFalse(os.path.exists('__pycache__')) __import__(TESTFN) self.assertTrue(os.path.exists('__pycache__')) - self.assertTrue(os.path.exists(os.path.join( - '__pycache__', '{}.{}.py{}'.format( - TESTFN, self.tag, 'c' if __debug__ else 'o')))) + pyc_path = importlib.util.cache_from_source(self.source) + self.assertTrue(os.path.exists(pyc_path), + 'bytecode file {!r} for {!r} does not ' + 'exist'.format(pyc_path, TESTFN)) @unittest.skipUnless(os.name == 'posix', "test meaningful only on posix systems") @@ -673,8 +681,10 @@ class PycacheTests(unittest.TestCase): with temp_umask(0o222): __import__(TESTFN) self.assertTrue(os.path.exists('__pycache__')) - self.assertFalse(os.path.exists(os.path.join( - '__pycache__', '{}.{}.pyc'.format(TESTFN, self.tag)))) + pyc_path = importlib.util.cache_from_source(self.source) + self.assertFalse(os.path.exists(pyc_path), + 'bytecode file {!r} for {!r} ' + 'exists'.format(pyc_path, TESTFN)) @skip_if_dont_write_bytecode def test_missing_source(self): @@ -849,19 +859,27 @@ class ImportlibBootstrapTests(unittest.TestCase): self.assertEqual(mod.__package__, 'importlib') self.assertTrue(mod.__file__.endswith('_bootstrap.py'), mod.__file__) + def test_frozen_importlib_external_is_bootstrap_external(self): + from importlib import _bootstrap_external + mod = sys.modules['_frozen_importlib_external'] + self.assertIs(mod, _bootstrap_external) + self.assertEqual(mod.__name__, 'importlib._bootstrap_external') + self.assertEqual(mod.__package__, 'importlib') + self.assertTrue(mod.__file__.endswith('_bootstrap_external.py'), mod.__file__) + def test_there_can_be_only_one(self): # Issue #15386 revealed a tricky loophole in the bootstrapping # This test is technically redundant, since the bug caused importing # this test module to crash completely, but it helps prove the point from importlib import machinery mod = sys.modules['_frozen_importlib'] - self.assertIs(machinery.FileFinder, mod.FileFinder) + self.assertIs(machinery.ModuleSpec, mod.ModuleSpec) @cpython_only class GetSourcefileTests(unittest.TestCase): - """Test importlib._bootstrap._get_sourcefile() as used by the C API. + """Test importlib._bootstrap_external._get_sourcefile() as used by the C API. Because of the peculiarities of the need of this function, the tests are knowingly whitebox tests. @@ -871,7 +889,7 @@ class GetSourcefileTests(unittest.TestCase): def test_get_sourcefile(self): # Given a valid bytecode path, return the path to the corresponding # source file if it exists. - with mock.patch('importlib._bootstrap._path_isfile') as _path_isfile: + with mock.patch('importlib._bootstrap_external._path_isfile') as _path_isfile: _path_isfile.return_value = True; path = TESTFN + '.pyc' expect = TESTFN + '.py' @@ -880,7 +898,7 @@ class GetSourcefileTests(unittest.TestCase): def test_get_sourcefile_no_source(self): # Given a valid bytecode path without a corresponding source path, # return the original bytecode path. - with mock.patch('importlib._bootstrap._path_isfile') as _path_isfile: + with mock.patch('importlib._bootstrap_external._path_isfile') as _path_isfile: _path_isfile.return_value = False; path = TESTFN + '.pyc' self.assertEqual(_get_sourcefile(path), path) @@ -1035,7 +1053,7 @@ class ImportTracebackTests(unittest.TestCase): # We simulate a bug in importlib and check that it's not stripped # away from the traceback. self.create_module("foo", "") - importlib = sys.modules['_frozen_importlib'] + importlib = sys.modules['_frozen_importlib_external'] if 'load_module' in vars(importlib.SourceLoader): old_exec_module = importlib.SourceLoader.exec_module else: @@ -1068,6 +1086,46 @@ class ImportTracebackTests(unittest.TestCase): __isolated=False) +class CircularImportTests(unittest.TestCase): + + """See the docstrings of the modules being imported for the purpose of the + test.""" + + def tearDown(self): + """Make sure no modules pre-exist in sys.modules which are being used to + test.""" + for key in list(sys.modules.keys()): + if key.startswith('test.test_import.data.circular_imports'): + del sys.modules[key] + + def test_direct(self): + try: + import test.test_import.data.circular_imports.basic + except ImportError: + self.fail('circular import through relative imports failed') + + def test_indirect(self): + try: + import test.test_import.data.circular_imports.indirect + except ImportError: + self.fail('relative import in module contributing to circular ' + 'import failed') + + def test_subpackage(self): + try: + import test.test_import.data.circular_imports.subpackage + except ImportError: + self.fail('circular import involving a subpackage failed') + + def test_rebinding(self): + try: + import test.test_import.data.circular_imports.rebinding as rebinding + except ImportError: + self.fail('circular import with rebinding of module attribute failed') + from test.test_import.data.circular_imports.subpkg import util + self.assertIs(util.util, rebinding.util) + + if __name__ == '__main__': # Test needs to be a package, so we can do relative imports. unittest.main() diff --git a/Lib/test/test_import/__main__.py b/Lib/test/test_import/__main__.py new file mode 100644 index 0000000..24f02a1 --- /dev/null +++ b/Lib/test/test_import/__main__.py @@ -0,0 +1,3 @@ +import unittest + +unittest.main('test.test_import') diff --git a/Lib/test/test_import/data/circular_imports/basic.py b/Lib/test/test_import/data/circular_imports/basic.py new file mode 100644 index 0000000..3e41e39 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/basic.py @@ -0,0 +1,2 @@ +"""Circular imports through direct, relative imports.""" +from . import basic2 diff --git a/Lib/test/test_import/data/circular_imports/basic2.py b/Lib/test/test_import/data/circular_imports/basic2.py new file mode 100644 index 0000000..00bd2f2 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/basic2.py @@ -0,0 +1 @@ +from . import basic diff --git a/Lib/test/test_import/data/circular_imports/indirect.py b/Lib/test/test_import/data/circular_imports/indirect.py new file mode 100644 index 0000000..6925788 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/indirect.py @@ -0,0 +1 @@ +from . import basic, basic2 diff --git a/Lib/test/test_import/data/circular_imports/rebinding.py b/Lib/test/test_import/data/circular_imports/rebinding.py new file mode 100644 index 0000000..2b77375 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/rebinding.py @@ -0,0 +1,3 @@ +"""Test the binding of names when a circular import shares the same name as an +attribute.""" +from .rebinding2 import util diff --git a/Lib/test/test_import/data/circular_imports/rebinding2.py b/Lib/test/test_import/data/circular_imports/rebinding2.py new file mode 100644 index 0000000..57a9e69 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/rebinding2.py @@ -0,0 +1,3 @@ +from .subpkg import util +from . import rebinding +util = util.util diff --git a/Lib/test/test_import/data/circular_imports/subpackage.py b/Lib/test/test_import/data/circular_imports/subpackage.py new file mode 100644 index 0000000..7b412f7 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/subpackage.py @@ -0,0 +1,2 @@ +"""Circular import involving a sub-package.""" +from .subpkg import subpackage2 diff --git a/Lib/test/test_import/data/circular_imports/subpkg/subpackage2.py b/Lib/test/test_import/data/circular_imports/subpkg/subpackage2.py new file mode 100644 index 0000000..17b893a --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/subpkg/subpackage2.py @@ -0,0 +1,2 @@ +#from .util import util +from .. import subpackage diff --git a/Lib/test/test_import/data/circular_imports/subpkg/util.py b/Lib/test/test_import/data/circular_imports/subpkg/util.py new file mode 100644 index 0000000..343bd84 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/subpkg/util.py @@ -0,0 +1,2 @@ +def util(): + pass diff --git a/Lib/test/test_import/data/circular_imports/util.py b/Lib/test/test_import/data/circular_imports/util.py new file mode 100644 index 0000000..343bd84 --- /dev/null +++ b/Lib/test/test_import/data/circular_imports/util.py @@ -0,0 +1,2 @@ +def util(): + pass diff --git a/Lib/test/test_importlib/builtin/test_finder.py b/Lib/test/test_importlib/builtin/test_finder.py index 934562f..a2e6e1e 100644 --- a/Lib/test/test_importlib/builtin/test_finder.py +++ b/Lib/test/test_importlib/builtin/test_finder.py @@ -1,21 +1,21 @@ from .. import abc from .. import util -from . import util as builtin_util -frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') +machinery = util.import_importlib('importlib.machinery') import sys import unittest +@unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module') class FindSpecTests(abc.FinderTests): """Test find_spec() for built-in modules.""" def test_module(self): # Common case. - with util.uncache(builtin_util.NAME): - found = self.machinery.BuiltinImporter.find_spec(builtin_util.NAME) + with util.uncache(util.BUILTINS.good_name): + found = self.machinery.BuiltinImporter.find_spec(util.BUILTINS.good_name) self.assertTrue(found) self.assertEqual(found.origin, 'built-in') @@ -39,23 +39,26 @@ class FindSpecTests(abc.FinderTests): def test_ignore_path(self): # The value for 'path' should always trigger a failed import. - with util.uncache(builtin_util.NAME): - spec = self.machinery.BuiltinImporter.find_spec(builtin_util.NAME, + with util.uncache(util.BUILTINS.good_name): + spec = self.machinery.BuiltinImporter.find_spec(util.BUILTINS.good_name, ['pkg']) self.assertIsNone(spec) -Frozen_FindSpecTests, Source_FindSpecTests = util.test_both(FindSpecTests, - machinery=[frozen_machinery, source_machinery]) +(Frozen_FindSpecTests, + Source_FindSpecTests + ) = util.test_both(FindSpecTests, machinery=machinery) + +@unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module') class FinderTests(abc.FinderTests): """Test find_module() for built-in modules.""" def test_module(self): # Common case. - with util.uncache(builtin_util.NAME): - found = self.machinery.BuiltinImporter.find_module(builtin_util.NAME) + with util.uncache(util.BUILTINS.good_name): + found = self.machinery.BuiltinImporter.find_module(util.BUILTINS.good_name) self.assertTrue(found) self.assertTrue(hasattr(found, 'load_module')) @@ -72,13 +75,15 @@ class FinderTests(abc.FinderTests): def test_ignore_path(self): # The value for 'path' should always trigger a failed import. - with util.uncache(builtin_util.NAME): - loader = self.machinery.BuiltinImporter.find_module(builtin_util.NAME, + with util.uncache(util.BUILTINS.good_name): + loader = self.machinery.BuiltinImporter.find_module(util.BUILTINS.good_name, ['pkg']) self.assertIsNone(loader) -Frozen_FinderTests, Source_FinderTests = util.test_both(FinderTests, - machinery=[frozen_machinery, source_machinery]) + +(Frozen_FinderTests, + Source_FinderTests + ) = util.test_both(FinderTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/builtin/test_loader.py b/Lib/test/test_importlib/builtin/test_loader.py index 1f83574..b1349ec 100644 --- a/Lib/test/test_importlib/builtin/test_loader.py +++ b/Lib/test/test_importlib/builtin/test_loader.py @@ -1,14 +1,13 @@ from .. import abc from .. import util -from . import util as builtin_util -frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') +machinery = util.import_importlib('importlib.machinery') import sys import types import unittest - +@unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module') class LoaderTests(abc.LoaderTests): """Test load_module() for built-in modules.""" @@ -29,8 +28,8 @@ class LoaderTests(abc.LoaderTests): def test_module(self): # Common case. - with util.uncache(builtin_util.NAME): - module = self.load_module(builtin_util.NAME) + with util.uncache(util.BUILTINS.good_name): + module = self.load_module(util.BUILTINS.good_name) self.verify(module) # Built-in modules cannot be a package. @@ -41,9 +40,9 @@ class LoaderTests(abc.LoaderTests): def test_module_reuse(self): # Test that the same module is used in a reload. - with util.uncache(builtin_util.NAME): - module1 = self.load_module(builtin_util.NAME) - module2 = self.load_module(builtin_util.NAME) + with util.uncache(util.BUILTINS.good_name): + module1 = self.load_module(util.BUILTINS.good_name) + module2 = self.load_module(util.BUILTINS.good_name) self.assertIs(module1, module2) def test_unloadable(self): @@ -66,40 +65,43 @@ class LoaderTests(abc.LoaderTests): self.assertEqual(cm.exception.name, module_name) -Frozen_LoaderTests, Source_LoaderTests = util.test_both(LoaderTests, - machinery=[frozen_machinery, source_machinery]) +(Frozen_LoaderTests, + Source_LoaderTests + ) = util.test_both(LoaderTests, machinery=machinery) +@unittest.skipIf(util.BUILTINS.good_name is None, 'no reasonable builtin module') class InspectLoaderTests: """Tests for InspectLoader methods for BuiltinImporter.""" def test_get_code(self): # There is no code object. - result = self.machinery.BuiltinImporter.get_code(builtin_util.NAME) + result = self.machinery.BuiltinImporter.get_code(util.BUILTINS.good_name) self.assertIsNone(result) def test_get_source(self): # There is no source. - result = self.machinery.BuiltinImporter.get_source(builtin_util.NAME) + result = self.machinery.BuiltinImporter.get_source(util.BUILTINS.good_name) self.assertIsNone(result) def test_is_package(self): # Cannot be a package. - result = self.machinery.BuiltinImporter.is_package(builtin_util.NAME) + result = self.machinery.BuiltinImporter.is_package(util.BUILTINS.good_name) self.assertFalse(result) + @unittest.skipIf(util.BUILTINS.bad_name is None, 'all modules are built in') def test_not_builtin(self): # Modules not built-in should raise ImportError. for meth_name in ('get_code', 'get_source', 'is_package'): method = getattr(self.machinery.BuiltinImporter, meth_name) with self.assertRaises(ImportError) as cm: - method(builtin_util.BAD_NAME) - self.assertRaises(builtin_util.BAD_NAME) + method(util.BUILTINS.bad_name) + -Frozen_InspectLoaderTests, Source_InspectLoaderTests = util.test_both( - InspectLoaderTests, - machinery=[frozen_machinery, source_machinery]) +(Frozen_InspectLoaderTests, + Source_InspectLoaderTests + ) = util.test_both(InspectLoaderTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/builtin/util.py b/Lib/test/test_importlib/builtin/util.py deleted file mode 100644 index 5704699..0000000 --- a/Lib/test/test_importlib/builtin/util.py +++ /dev/null @@ -1,7 +0,0 @@ -import sys - -assert 'errno' in sys.builtin_module_names -NAME = 'errno' - -assert 'importlib' not in sys.builtin_module_names -BAD_NAME = 'importlib' diff --git a/Lib/test/test_importlib/extension/test_case_sensitivity.py b/Lib/test/test_importlib/extension/test_case_sensitivity.py index bb2528e..706c3e4 100644 --- a/Lib/test/test_importlib/extension/test_case_sensitivity.py +++ b/Lib/test/test_importlib/extension/test_case_sensitivity.py @@ -1,25 +1,24 @@ -from importlib import _bootstrap +from importlib import _bootstrap_external import sys from test import support import unittest from .. import util -from . import util as ext_util -frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') +machinery = util.import_importlib('importlib.machinery') # XXX find_spec tests -@unittest.skipIf(ext_util.FILENAME is None, '_testcapi not available') +@unittest.skipIf(util.EXTENSIONS.filename is None, '_testcapi not available') @util.case_insensitive_tests class ExtensionModuleCaseSensitivityTest: def find_module(self): - good_name = ext_util.NAME + good_name = util.EXTENSIONS.name bad_name = good_name.upper() assert good_name != bad_name - finder = self.machinery.FileFinder(ext_util.PATH, + finder = self.machinery.FileFinder(util.EXTENSIONS.path, (self.machinery.ExtensionFileLoader, self.machinery.EXTENSION_SUFFIXES)) return finder.find_module(bad_name) @@ -27,7 +26,7 @@ class ExtensionModuleCaseSensitivityTest: def test_case_sensitive(self): with support.EnvironmentVarGuard() as env: env.unset('PYTHONCASEOK') - if b'PYTHONCASEOK' in _bootstrap._os.environ: + if b'PYTHONCASEOK' in _bootstrap_external._os.environ: self.skipTest('os.environ changes not reflected in ' '_os.environ') loader = self.find_module() @@ -36,15 +35,16 @@ class ExtensionModuleCaseSensitivityTest: def test_case_insensitivity(self): with support.EnvironmentVarGuard() as env: env.set('PYTHONCASEOK', '1') - if b'PYTHONCASEOK' not in _bootstrap._os.environ: + if b'PYTHONCASEOK' not in _bootstrap_external._os.environ: self.skipTest('os.environ changes not reflected in ' '_os.environ') loader = self.find_module() self.assertTrue(hasattr(loader, 'load_module')) -Frozen_ExtensionCaseSensitivity, Source_ExtensionCaseSensitivity = util.test_both( - ExtensionModuleCaseSensitivityTest, - machinery=[frozen_machinery, source_machinery]) + +(Frozen_ExtensionCaseSensitivity, + Source_ExtensionCaseSensitivity + ) = util.test_both(ExtensionModuleCaseSensitivityTest, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/extension/test_finder.py b/Lib/test/test_importlib/extension/test_finder.py index 990f29c..71bf67f 100644 --- a/Lib/test/test_importlib/extension/test_finder.py +++ b/Lib/test/test_importlib/extension/test_finder.py @@ -1,8 +1,7 @@ from .. import abc -from .. import util as test_util -from . import util +from .. import util -machinery = test_util.import_importlib('importlib.machinery') +machinery = util.import_importlib('importlib.machinery') import unittest import warnings @@ -14,7 +13,7 @@ class FinderTests(abc.FinderTests): """Test the finder for extension modules.""" def find_module(self, fullname): - importer = self.machinery.FileFinder(util.PATH, + importer = self.machinery.FileFinder(util.EXTENSIONS.path, (self.machinery.ExtensionFileLoader, self.machinery.EXTENSION_SUFFIXES)) with warnings.catch_warnings(): @@ -22,7 +21,7 @@ class FinderTests(abc.FinderTests): return importer.find_module(fullname) def test_module(self): - self.assertTrue(self.find_module(util.NAME)) + self.assertTrue(self.find_module(util.EXTENSIONS.name)) # No extension module as an __init__ available for testing. test_package = test_package_in_package = None @@ -36,8 +35,10 @@ class FinderTests(abc.FinderTests): def test_failure(self): self.assertIsNone(self.find_module('asdfjkl;')) -Frozen_FinderTests, Source_FinderTests = test_util.test_both( - FinderTests, machinery=machinery) + +(Frozen_FinderTests, + Source_FinderTests + ) = util.test_both(FinderTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/extension/test_loader.py b/Lib/test/test_importlib/extension/test_loader.py index fd9abf2..154a793 100644 --- a/Lib/test/test_importlib/extension/test_loader.py +++ b/Lib/test/test_importlib/extension/test_loader.py @@ -1,4 +1,3 @@ -from . import util as ext_util from .. import abc from .. import util @@ -8,6 +7,8 @@ import os.path import sys import types import unittest +import importlib.util +import importlib class LoaderTests(abc.LoaderTests): @@ -15,8 +16,8 @@ class LoaderTests(abc.LoaderTests): """Test load_module() for extension modules.""" def setUp(self): - self.loader = self.machinery.ExtensionFileLoader(ext_util.NAME, - ext_util.FILEPATH) + self.loader = self.machinery.ExtensionFileLoader(util.EXTENSIONS.name, + util.EXTENSIONS.file_path) def load_module(self, fullname): return self.loader.load_module(fullname) @@ -29,23 +30,23 @@ class LoaderTests(abc.LoaderTests): self.load_module('XXX') def test_equality(self): - other = self.machinery.ExtensionFileLoader(ext_util.NAME, - ext_util.FILEPATH) + other = self.machinery.ExtensionFileLoader(util.EXTENSIONS.name, + util.EXTENSIONS.file_path) self.assertEqual(self.loader, other) def test_inequality(self): - other = self.machinery.ExtensionFileLoader('_' + ext_util.NAME, - ext_util.FILEPATH) + other = self.machinery.ExtensionFileLoader('_' + util.EXTENSIONS.name, + util.EXTENSIONS.file_path) self.assertNotEqual(self.loader, other) def test_module(self): - with util.uncache(ext_util.NAME): - module = self.load_module(ext_util.NAME) - for attr, value in [('__name__', ext_util.NAME), - ('__file__', ext_util.FILEPATH), + with util.uncache(util.EXTENSIONS.name): + module = self.load_module(util.EXTENSIONS.name) + for attr, value in [('__name__', util.EXTENSIONS.name), + ('__file__', util.EXTENSIONS.file_path), ('__package__', '')]: self.assertEqual(getattr(module, attr), value) - self.assertIn(ext_util.NAME, sys.modules) + self.assertIn(util.EXTENSIONS.name, sys.modules) self.assertIsInstance(module.__loader__, self.machinery.ExtensionFileLoader) @@ -56,9 +57,9 @@ class LoaderTests(abc.LoaderTests): test_lacking_parent = None def test_module_reuse(self): - with util.uncache(ext_util.NAME): - module1 = self.load_module(ext_util.NAME) - module2 = self.load_module(ext_util.NAME) + with util.uncache(util.EXTENSIONS.name): + module1 = self.load_module(util.EXTENSIONS.name) + module2 = self.load_module(util.EXTENSIONS.name) self.assertIs(module1, module2) # No easy way to trigger a failure after a successful import. @@ -71,15 +72,196 @@ class LoaderTests(abc.LoaderTests): self.assertEqual(cm.exception.name, name) def test_is_package(self): - self.assertFalse(self.loader.is_package(ext_util.NAME)) + self.assertFalse(self.loader.is_package(util.EXTENSIONS.name)) for suffix in self.machinery.EXTENSION_SUFFIXES: path = os.path.join('some', 'path', 'pkg', '__init__' + suffix) loader = self.machinery.ExtensionFileLoader('pkg', path) self.assertTrue(loader.is_package('pkg')) -Frozen_LoaderTests, Source_LoaderTests = util.test_both( - LoaderTests, machinery=machinery) +(Frozen_LoaderTests, + Source_LoaderTests + ) = util.test_both(LoaderTests, machinery=machinery) +class MultiPhaseExtensionModuleTests(abc.LoaderTests): + """Test loading extension modules with multi-phase initialization (PEP 489) + """ + + def setUp(self): + self.name = '_testmultiphase' + finder = self.machinery.FileFinder(None) + self.spec = importlib.util.find_spec(self.name) + assert self.spec + self.loader = self.machinery.ExtensionFileLoader( + self.name, self.spec.origin) + + # No extension module as __init__ available for testing. + test_package = None + + # No extension module in a package available for testing. + test_lacking_parent = None + + # Handling failure on reload is the up to the module. + test_state_after_failure = None + + def test_module(self): + '''Test loading an extension module''' + with util.uncache(self.name): + module = self.load_module() + for attr, value in [('__name__', self.name), + ('__file__', self.spec.origin), + ('__package__', '')]: + self.assertEqual(getattr(module, attr), value) + with self.assertRaises(AttributeError): + module.__path__ + self.assertIs(module, sys.modules[self.name]) + self.assertIsInstance(module.__loader__, + self.machinery.ExtensionFileLoader) + + def test_functionality(self): + '''Test basic functionality of stuff defined in an extension module''' + with util.uncache(self.name): + module = self.load_module() + self.assertIsInstance(module, types.ModuleType) + ex = module.Example() + self.assertEqual(ex.demo('abcd'), 'abcd') + self.assertEqual(ex.demo(), None) + with self.assertRaises(AttributeError): + ex.abc + ex.abc = 0 + self.assertEqual(ex.abc, 0) + self.assertEqual(module.foo(9, 9), 18) + self.assertIsInstance(module.Str(), str) + self.assertEqual(module.Str(1) + '23', '123') + with self.assertRaises(module.error): + raise module.error() + self.assertEqual(module.int_const, 1969) + self.assertEqual(module.str_const, 'something different') + + def test_reload(self): + '''Test that reload didn't re-set the module's attributes''' + with util.uncache(self.name): + module = self.load_module() + ex_class = module.Example + importlib.reload(module) + self.assertIs(ex_class, module.Example) + + def test_try_registration(self): + '''Assert that the PyState_{Find,Add,Remove}Module C API doesn't work''' + module = self.load_module() + with self.subTest('PyState_FindModule'): + self.assertEqual(module.call_state_registration_func(0), None) + with self.subTest('PyState_AddModule'): + with self.assertRaises(SystemError): + module.call_state_registration_func(1) + with self.subTest('PyState_RemoveModule'): + with self.assertRaises(SystemError): + module.call_state_registration_func(2) + + def load_module(self): + '''Load the module from the test extension''' + return self.loader.load_module(self.name) + + def load_module_by_name(self, fullname): + '''Load a module from the test extension by name''' + origin = self.spec.origin + loader = self.machinery.ExtensionFileLoader(fullname, origin) + spec = importlib.util.spec_from_loader(fullname, loader) + module = importlib.util.module_from_spec(spec) + loader.exec_module(module) + return module + + def test_load_submodule(self): + '''Test loading a simulated submodule''' + module = self.load_module_by_name('pkg.' + self.name) + self.assertIsInstance(module, types.ModuleType) + self.assertEqual(module.__name__, 'pkg.' + self.name) + self.assertEqual(module.str_const, 'something different') + + def test_load_short_name(self): + '''Test loading module with a one-character name''' + module = self.load_module_by_name('x') + self.assertIsInstance(module, types.ModuleType) + self.assertEqual(module.__name__, 'x') + self.assertEqual(module.str_const, 'something different') + self.assertNotIn('x', sys.modules) + + def test_load_twice(self): + '''Test that 2 loads result in 2 module objects''' + module1 = self.load_module_by_name(self.name) + module2 = self.load_module_by_name(self.name) + self.assertIsNot(module1, module2) + + def test_unloadable(self): + '''Test nonexistent module''' + name = 'asdfjkl;' + with self.assertRaises(ImportError) as cm: + self.load_module_by_name(name) + self.assertEqual(cm.exception.name, name) + + def test_unloadable_nonascii(self): + '''Test behavior with nonexistent module with non-ASCII name''' + name = 'fo\xf3' + with self.assertRaises(ImportError) as cm: + self.load_module_by_name(name) + self.assertEqual(cm.exception.name, name) + + def test_nonmodule(self): + '''Test returning a non-module object from create works''' + name = self.name + '_nonmodule' + mod = self.load_module_by_name(name) + self.assertNotEqual(type(mod), type(unittest)) + self.assertEqual(mod.three, 3) + + def test_null_slots(self): + '''Test that NULL slots aren't a problem''' + name = self.name + '_null_slots' + module = self.load_module_by_name(name) + self.assertIsInstance(module, types.ModuleType) + self.assertEqual(module.__name__, name) + + def test_bad_modules(self): + '''Test SystemError is raised for misbehaving extensions''' + for name_base in [ + 'bad_slot_large', + 'bad_slot_negative', + 'create_int_with_state', + 'negative_size', + 'export_null', + 'export_uninitialized', + 'export_raise', + 'export_unreported_exception', + 'create_null', + 'create_raise', + 'create_unreported_exception', + 'nonmodule_with_exec_slots', + 'exec_err', + 'exec_raise', + 'exec_unreported_exception', + ]: + with self.subTest(name_base): + name = self.name + '_' + name_base + with self.assertRaises(SystemError): + self.load_module_by_name(name) + + def test_nonascii(self): + '''Test that modules with non-ASCII names can be loaded''' + # punycode behaves slightly differently in some-ASCII and no-ASCII + # cases, so test both + cases = [ + (self.name + '_zkou\u0161ka_na\u010dten\xed', 'Czech'), + ('\uff3f\u30a4\u30f3\u30dd\u30fc\u30c8\u30c6\u30b9\u30c8', + 'Japanese'), + ] + for name, lang in cases: + with self.subTest(name): + module = self.load_module_by_name(name) + self.assertEqual(module.__name__, name) + self.assertEqual(module.__doc__, "Module named in %s" % lang) + + +(Frozen_MultiPhaseExtensionModuleTests, + Source_MultiPhaseExtensionModuleTests + ) = util.test_both(MultiPhaseExtensionModuleTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/extension/test_path_hook.py b/Lib/test/test_importlib/extension/test_path_hook.py index 49d6734..8f4b8bb 100644 --- a/Lib/test/test_importlib/extension/test_path_hook.py +++ b/Lib/test/test_importlib/extension/test_path_hook.py @@ -1,7 +1,6 @@ -from .. import util as test_util -from . import util +from .. import util -machinery = test_util.import_importlib('importlib.machinery') +machinery = util.import_importlib('importlib.machinery') import collections import sys @@ -22,10 +21,12 @@ class PathHookTests: def test_success(self): # Path hook should handle a directory where a known extension module # exists. - self.assertTrue(hasattr(self.hook(util.PATH), 'find_module')) + self.assertTrue(hasattr(self.hook(util.EXTENSIONS.path), 'find_module')) -Frozen_PathHooksTests, Source_PathHooksTests = test_util.test_both( - PathHookTests, machinery=machinery) + +(Frozen_PathHooksTests, + Source_PathHooksTests + ) = util.test_both(PathHookTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/extension/util.py b/Lib/test/test_importlib/extension/util.py deleted file mode 100644 index 8d089f0..0000000 --- a/Lib/test/test_importlib/extension/util.py +++ /dev/null @@ -1,19 +0,0 @@ -from importlib import machinery -import os -import sys - -PATH = None -EXT = None -FILENAME = None -NAME = '_testcapi' -try: - for PATH in sys.path: - for EXT in machinery.EXTENSION_SUFFIXES: - FILENAME = NAME + EXT - FILEPATH = os.path.join(PATH, FILENAME) - if os.path.exists(os.path.join(PATH, FILENAME)): - raise StopIteration - else: - PATH = EXT = FILENAME = FILEPATH = None -except StopIteration: - pass diff --git a/Lib/test/test_importlib/frozen/test_finder.py b/Lib/test/test_importlib/frozen/test_finder.py index f9f97f3..519aa02 100644 --- a/Lib/test/test_importlib/frozen/test_finder.py +++ b/Lib/test/test_importlib/frozen/test_finder.py @@ -37,8 +37,10 @@ class FindSpecTests(abc.FinderTests): spec = self.find('<not real>') self.assertIsNone(spec) -Frozen_FindSpecTests, Source_FindSpecTests = util.test_both(FindSpecTests, - machinery=machinery) + +(Frozen_FindSpecTests, + Source_FindSpecTests + ) = util.test_both(FindSpecTests, machinery=machinery) class FinderTests(abc.FinderTests): @@ -72,8 +74,10 @@ class FinderTests(abc.FinderTests): loader = self.find('<not real>') self.assertIsNone(loader) -Frozen_FinderTests, Source_FinderTests = util.test_both(FinderTests, - machinery=machinery) + +(Frozen_FinderTests, + Source_FinderTests + ) = util.test_both(FinderTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/frozen/test_loader.py b/Lib/test/test_importlib/frozen/test_loader.py index 7c01464..603c7d7 100644 --- a/Lib/test/test_importlib/frozen/test_loader.py +++ b/Lib/test/test_importlib/frozen/test_loader.py @@ -85,8 +85,10 @@ class ExecModuleTests(abc.LoaderTests): self.exec_module('_not_real') self.assertEqual(cm.exception.name, '_not_real') -Frozen_ExecModuleTests, Source_ExecModuleTests = util.test_both(ExecModuleTests, - machinery=machinery) + +(Frozen_ExecModuleTests, + Source_ExecModuleTests + ) = util.test_both(ExecModuleTests, machinery=machinery) class LoaderTests(abc.LoaderTests): @@ -175,8 +177,10 @@ class LoaderTests(abc.LoaderTests): self.machinery.FrozenImporter.load_module('_not_real') self.assertEqual(cm.exception.name, '_not_real') -Frozen_LoaderTests, Source_LoaderTests = util.test_both(LoaderTests, - machinery=machinery) + +(Frozen_LoaderTests, + Source_LoaderTests + ) = util.test_both(LoaderTests, machinery=machinery) class InspectLoaderTests: @@ -214,8 +218,9 @@ class InspectLoaderTests: method('importlib') self.assertEqual(cm.exception.name, 'importlib') -Frozen_ILTests, Source_ILTests = util.test_both(InspectLoaderTests, - machinery=machinery) +(Frozen_ILTests, + Source_ILTests + ) = util.test_both(InspectLoaderTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test___loader__.py b/Lib/test/test_importlib/import_/test___loader__.py index 6df8010..4b18093 100644 --- a/Lib/test/test_importlib/import_/test___loader__.py +++ b/Lib/test/test_importlib/import_/test___loader__.py @@ -4,7 +4,6 @@ import types import unittest from .. import util -from . import util as import_util class SpecLoaderMock: @@ -12,6 +11,9 @@ class SpecLoaderMock: def find_spec(self, fullname, path=None, target=None): return machinery.ModuleSpec(fullname, self) + def create_module(self, spec): + return None + def exec_module(self, module): pass @@ -24,8 +26,10 @@ class SpecLoaderAttributeTests: module = self.__import__('blah') self.assertEqual(loader, module.__loader__) -Frozen_SpecTests, Source_SpecTests = util.test_both( - SpecLoaderAttributeTests, __import__=import_util.__import__) + +(Frozen_SpecTests, + Source_SpecTests + ) = util.test_both(SpecLoaderAttributeTests, __import__=util.__import__) class LoaderMock: @@ -62,8 +66,9 @@ class LoaderAttributeTests: self.assertEqual(loader, module.__loader__) -Frozen_Tests, Source_Tests = util.test_both(LoaderAttributeTests, - __import__=import_util.__import__) +(Frozen_Tests, + Source_Tests + ) = util.test_both(LoaderAttributeTests, __import__=util.__import__) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test___package__.py b/Lib/test/test_importlib/import_/test___package__.py index 2e19725..c7d3a2a 100644 --- a/Lib/test/test_importlib/import_/test___package__.py +++ b/Lib/test/test_importlib/import_/test___package__.py @@ -6,7 +6,6 @@ of using the typical __path__/__name__ test). """ import unittest from .. import util -from . import util as import_util class Using__package__: @@ -70,17 +69,23 @@ class Using__package__: with self.assertRaises(TypeError): self.__import__('', globals, {}, ['relimport'], 1) + class Using__package__PEP302(Using__package__): mock_modules = util.mock_modules -Frozen_UsingPackagePEP302, Source_UsingPackagePEP302 = util.test_both( - Using__package__PEP302, __import__=import_util.__import__) -class Using__package__PEP302(Using__package__): +(Frozen_UsingPackagePEP302, + Source_UsingPackagePEP302 + ) = util.test_both(Using__package__PEP302, __import__=util.__import__) + + +class Using__package__PEP451(Using__package__): mock_modules = util.mock_spec -Frozen_UsingPackagePEP451, Source_UsingPackagePEP451 = util.test_both( - Using__package__PEP302, __import__=import_util.__import__) + +(Frozen_UsingPackagePEP451, + Source_UsingPackagePEP451 + ) = util.test_both(Using__package__PEP451, __import__=util.__import__) class Setting__package__: @@ -95,7 +100,7 @@ class Setting__package__: """ - __import__ = import_util.__import__[1] + __import__ = util.__import__['Source'] # [top-level] def test_top_level(self): diff --git a/Lib/test/test_importlib/import_/test_api.py b/Lib/test/test_importlib/import_/test_api.py index 439c105..7069d9e 100644 --- a/Lib/test/test_importlib/import_/test_api.py +++ b/Lib/test/test_importlib/import_/test_api.py @@ -1,5 +1,4 @@ from .. import util -from . import util as import_util from importlib import machinery import sys @@ -18,6 +17,10 @@ class BadSpecFinderLoader: return spec @staticmethod + def create_module(spec): + return None + + @staticmethod def exec_module(module): if module.__name__ == SUBMOD_NAME: raise ImportError('I cannot be loaded!') @@ -79,15 +82,19 @@ class APITest: class OldAPITests(APITest): bad_finder_loader = BadLoaderFinder -Frozen_OldAPITests, Source_OldAPITests = util.test_both( - OldAPITests, __import__=import_util.__import__) + +(Frozen_OldAPITests, + Source_OldAPITests + ) = util.test_both(OldAPITests, __import__=util.__import__) class SpecAPITests(APITest): bad_finder_loader = BadSpecFinderLoader -Frozen_SpecAPITests, Source_SpecAPITests = util.test_both( - SpecAPITests, __import__=import_util.__import__) + +(Frozen_SpecAPITests, + Source_SpecAPITests + ) = util.test_both(SpecAPITests, __import__=util.__import__) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test_caching.py b/Lib/test/test_importlib/import_/test_caching.py index c292ee4..8079add 100644 --- a/Lib/test/test_importlib/import_/test_caching.py +++ b/Lib/test/test_importlib/import_/test_caching.py @@ -1,6 +1,5 @@ """Test that sys.modules is used properly by import.""" from .. import util -from . import util as import_util import sys from types import MethodType import unittest @@ -39,15 +38,17 @@ class UseCache: self.__import__(name) self.assertEqual(cm.exception.name, name) -Frozen_UseCache, Source_UseCache = util.test_both( - UseCache, __import__=import_util.__import__) + +(Frozen_UseCache, + Source_UseCache + ) = util.test_both(UseCache, __import__=util.__import__) class ImportlibUseCache(UseCache, unittest.TestCase): # Pertinent only to PEP 302; exec_module() doesn't return a module. - __import__ = import_util.__import__[1] + __import__ = util.__import__['Source'] def create_mock(self, *names, return_=None): mock = util.mock_modules(*names) diff --git a/Lib/test/test_importlib/import_/test_fromlist.py b/Lib/test/test_importlib/import_/test_fromlist.py index a755b75..8045465 100644 --- a/Lib/test/test_importlib/import_/test_fromlist.py +++ b/Lib/test/test_importlib/import_/test_fromlist.py @@ -1,6 +1,5 @@ """Test that the semantics relating to the 'fromlist' argument are correct.""" from .. import util -from . import util as import_util import unittest @@ -29,8 +28,10 @@ class ReturnValue: module = self.__import__('pkg.module', fromlist=['attr']) self.assertEqual(module.__name__, 'pkg.module') -Frozen_ReturnValue, Source_ReturnValue = util.test_both( - ReturnValue, __import__=import_util.__import__) + +(Frozen_ReturnValue, + Source_ReturnValue + ) = util.test_both(ReturnValue, __import__=util.__import__) class HandlingFromlist: @@ -121,8 +122,10 @@ class HandlingFromlist: self.assertEqual(module.module1.__name__, 'pkg.module1') self.assertEqual(module.module2.__name__, 'pkg.module2') -Frozen_FromList, Source_FromList = util.test_both( - HandlingFromlist, __import__=import_util.__import__) + +(Frozen_FromList, + Source_FromList + ) = util.test_both(HandlingFromlist, __import__=util.__import__) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test_meta_path.py b/Lib/test/test_importlib/import_/test_meta_path.py index 5eeb145..c452cdd 100644 --- a/Lib/test/test_importlib/import_/test_meta_path.py +++ b/Lib/test/test_importlib/import_/test_meta_path.py @@ -1,5 +1,4 @@ from .. import util -from . import util as import_util import importlib._bootstrap import sys from types import MethodType @@ -46,8 +45,10 @@ class CallingOrder: self.assertEqual(len(w), 1) self.assertTrue(issubclass(w[-1].category, ImportWarning)) -Frozen_CallingOrder, Source_CallingOrder = util.test_both( - CallingOrder, __import__=import_util.__import__) + +(Frozen_CallingOrder, + Source_CallingOrder + ) = util.test_both(CallingOrder, __import__=util.__import__) class CallSignature: @@ -100,19 +101,25 @@ class CallSignature: self.assertEqual(args[0], mod_name) self.assertIs(args[1], path) + class CallSignaturePEP302(CallSignature): mock_modules = util.mock_modules finder_name = 'find_module' -Frozen_CallSignaturePEP302, Source_CallSignaturePEP302 = util.test_both( - CallSignaturePEP302, __import__=import_util.__import__) + +(Frozen_CallSignaturePEP302, + Source_CallSignaturePEP302 + ) = util.test_both(CallSignaturePEP302, __import__=util.__import__) + class CallSignaturePEP451(CallSignature): mock_modules = util.mock_spec finder_name = 'find_spec' -Frozen_CallSignaturePEP451, Source_CallSignaturePEP451 = util.test_both( - CallSignaturePEP451, __import__=import_util.__import__) + +(Frozen_CallSignaturePEP451, + Source_CallSignaturePEP451 + ) = util.test_both(CallSignaturePEP451, __import__=util.__import__) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test_packages.py b/Lib/test/test_importlib/import_/test_packages.py index 55a5d14..3755b84 100644 --- a/Lib/test/test_importlib/import_/test_packages.py +++ b/Lib/test/test_importlib/import_/test_packages.py @@ -1,5 +1,4 @@ from .. import util -from . import util as import_util import sys import unittest import importlib @@ -102,8 +101,10 @@ class ParentModuleTests: finally: support.unload(subname) -Frozen_ParentTests, Source_ParentTests = util.test_both( - ParentModuleTests, __import__=import_util.__import__) + +(Frozen_ParentTests, + Source_ParentTests + ) = util.test_both(ParentModuleTests, __import__=util.__import__) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test_path.py b/Lib/test/test_importlib/import_/test_path.py index 1274f8c..4359dd9 100644 --- a/Lib/test/test_importlib/import_/test_path.py +++ b/Lib/test/test_importlib/import_/test_path.py @@ -1,11 +1,12 @@ from .. import util -from . import util as import_util importlib = util.import_importlib('importlib') machinery = util.import_importlib('importlib.machinery') +import errno import os import sys +import tempfile from types import ModuleType import unittest import warnings @@ -58,7 +59,7 @@ class FinderTests: module = '<test module>' path = '<test path>' importer = util.mock_spec(module) - hook = import_util.mock_path_hook(path, importer=importer) + hook = util.mock_path_hook(path, importer=importer) with util.import_state(path_hooks=[hook]): loader = self.machinery.PathFinder.find_module(module, [path]) self.assertIs(loader, importer) @@ -83,7 +84,7 @@ class FinderTests: path = '' module = '<test module>' importer = util.mock_spec(module) - hook = import_util.mock_path_hook(os.getcwd(), importer=importer) + hook = util.mock_path_hook(os.getcwd(), importer=importer) with util.import_state(path=[path], path_hooks=[hook]): loader = self.machinery.PathFinder.find_module(module) self.assertIs(loader, importer) @@ -98,7 +99,7 @@ class FinderTests: new_path_importer_cache.pop(None, None) new_path_hooks = [zipimport.zipimporter, self.machinery.FileFinder.path_hook( - *self.importlib._bootstrap._get_supported_file_loaders())] + *self.importlib._bootstrap_external._get_supported_file_loaders())] missing = object() email = sys.modules.pop('email', missing) try: @@ -112,8 +113,74 @@ class FinderTests: if email is not missing: sys.modules['email'] = email -Frozen_FinderTests, Source_FinderTests = util.test_both( - FinderTests, importlib=importlib, machinery=machinery) + def test_finder_with_find_module(self): + class TestFinder: + def find_module(self, fullname): + return self.to_return + failing_finder = TestFinder() + failing_finder.to_return = None + path = 'testing path' + with util.import_state(path_importer_cache={path: failing_finder}): + self.assertIsNone( + self.machinery.PathFinder.find_spec('whatever', [path])) + success_finder = TestFinder() + success_finder.to_return = __loader__ + with util.import_state(path_importer_cache={path: success_finder}): + spec = self.machinery.PathFinder.find_spec('whatever', [path]) + self.assertEqual(spec.loader, __loader__) + + def test_finder_with_find_loader(self): + class TestFinder: + loader = None + portions = [] + def find_loader(self, fullname): + return self.loader, self.portions + path = 'testing path' + with util.import_state(path_importer_cache={path: TestFinder()}): + self.assertIsNone( + self.machinery.PathFinder.find_spec('whatever', [path])) + success_finder = TestFinder() + success_finder.loader = __loader__ + with util.import_state(path_importer_cache={path: success_finder}): + spec = self.machinery.PathFinder.find_spec('whatever', [path]) + self.assertEqual(spec.loader, __loader__) + + def test_finder_with_find_spec(self): + class TestFinder: + spec = None + def find_spec(self, fullname, target=None): + return self.spec + path = 'testing path' + with util.import_state(path_importer_cache={path: TestFinder()}): + self.assertIsNone( + self.machinery.PathFinder.find_spec('whatever', [path])) + success_finder = TestFinder() + success_finder.spec = self.machinery.ModuleSpec('whatever', __loader__) + with util.import_state(path_importer_cache={path: success_finder}): + got = self.machinery.PathFinder.find_spec('whatever', [path]) + self.assertEqual(got, success_finder.spec) + + @unittest.skipIf(sys.platform == 'win32', "cwd can't not exist on Windows") + def test_deleted_cwd(self): + # Issue #22834 + self.addCleanup(os.chdir, os.getcwd()) + try: + with tempfile.TemporaryDirectory() as path: + os.chdir(path) + except OSError as exc: + if exc.errno == errno.EINVAL: + self.skipTest("platform does not allow the deletion of the cwd") + raise + with util.import_state(path=['']): + # Do not want FileNotFoundError raised. + self.assertIsNone(self.machinery.PathFinder.find_spec('whatever')) + + + + +(Frozen_FinderTests, + Source_FinderTests + ) = util.test_both(FinderTests, importlib=importlib, machinery=machinery) class PathEntryFinderTests: @@ -136,8 +203,10 @@ class PathEntryFinderTests: path_hooks=[Finder]): self.machinery.PathFinder.find_spec('importlib') -Frozen_PEFTests, Source_PEFTests = util.test_both( - PathEntryFinderTests, machinery=machinery) + +(Frozen_PEFTests, + Source_PEFTests + ) = util.test_both(PathEntryFinderTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/test_relative_imports.py b/Lib/test/test_importlib/import_/test_relative_imports.py index b216e9c..28bb6f7 100644 --- a/Lib/test/test_importlib/import_/test_relative_imports.py +++ b/Lib/test/test_importlib/import_/test_relative_imports.py @@ -1,6 +1,5 @@ """Test relative imports (PEP 328).""" from .. import util -from . import util as import_util import sys import unittest @@ -208,8 +207,10 @@ class RelativeImports: with self.assertRaises(KeyError): self.__import__('sys', level=1) -Frozen_RelativeImports, Source_RelativeImports = util.test_both( - RelativeImports, __import__=import_util.__import__) + +(Frozen_RelativeImports, + Source_RelativeImports + ) = util.test_both(RelativeImports, __import__=util.__import__) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/import_/util.py b/Lib/test/test_importlib/import_/util.py deleted file mode 100644 index dcb490f..0000000 --- a/Lib/test/test_importlib/import_/util.py +++ /dev/null @@ -1,20 +0,0 @@ -from .. import util - -frozen_importlib, source_importlib = util.import_importlib('importlib') - -import builtins -import functools -import importlib -import unittest - - -__import__ = staticmethod(builtins.__import__), staticmethod(source_importlib.__import__) - - -def mock_path_hook(*entries, importer): - """A mock sys.path_hooks entry.""" - def hook(entry): - if entry not in entries: - raise ImportError - return importer - return hook diff --git a/Lib/test/test_importlib/source/test_case_sensitivity.py b/Lib/test/test_importlib/source/test_case_sensitivity.py index efd3146..c274b38 100644 --- a/Lib/test/test_importlib/source/test_case_sensitivity.py +++ b/Lib/test/test_importlib/source/test_case_sensitivity.py @@ -1,6 +1,5 @@ """Test case-sensitivity (PEP 235).""" from .. import util -from . import util as source_util importlib = util.import_importlib('importlib') machinery = util.import_importlib('importlib.machinery') @@ -32,7 +31,7 @@ class CaseSensitivityTest: """Look for a module with matching and non-matching sensitivity.""" sensitive_pkg = 'sensitive.{0}'.format(self.name) insensitive_pkg = 'insensitive.{0}'.format(self.name.lower()) - context = source_util.create_modules(insensitive_pkg, sensitive_pkg) + context = util.create_modules(insensitive_pkg, sensitive_pkg) with context as mapping: sensitive_path = os.path.join(mapping['.root'], 'sensitive') insensitive_path = os.path.join(mapping['.root'], 'insensitive') @@ -43,7 +42,7 @@ class CaseSensitivityTest: def test_sensitive(self): with test_support.EnvironmentVarGuard() as env: env.unset('PYTHONCASEOK') - if b'PYTHONCASEOK' in self.importlib._bootstrap._os.environ: + if b'PYTHONCASEOK' in self.importlib._bootstrap_external._os.environ: self.skipTest('os.environ changes not reflected in ' '_os.environ') sensitive, insensitive = self.sensitivity_test() @@ -54,7 +53,7 @@ class CaseSensitivityTest: def test_insensitive(self): with test_support.EnvironmentVarGuard() as env: env.set('PYTHONCASEOK', '1') - if b'PYTHONCASEOK' not in self.importlib._bootstrap._os.environ: + if b'PYTHONCASEOK' not in self.importlib._bootstrap_external._os.environ: self.skipTest('os.environ changes not reflected in ' '_os.environ') sensitive, insensitive = self.sensitivity_test() @@ -63,20 +62,28 @@ class CaseSensitivityTest: self.assertIsNotNone(insensitive) self.assertIn(self.name, insensitive.get_filename(self.name)) + class CaseSensitivityTestPEP302(CaseSensitivityTest): def find(self, finder): return finder.find_module(self.name) -Frozen_CaseSensitivityTestPEP302, Source_CaseSensitivityTestPEP302 = util.test_both( - CaseSensitivityTestPEP302, importlib=importlib, machinery=machinery) + +(Frozen_CaseSensitivityTestPEP302, + Source_CaseSensitivityTestPEP302 + ) = util.test_both(CaseSensitivityTestPEP302, importlib=importlib, + machinery=machinery) + class CaseSensitivityTestPEP451(CaseSensitivityTest): def find(self, finder): found = finder.find_spec(self.name) return found.loader if found is not None else found -Frozen_CaseSensitivityTestPEP451, Source_CaseSensitivityTestPEP451 = util.test_both( - CaseSensitivityTestPEP451, importlib=importlib, machinery=machinery) + +(Frozen_CaseSensitivityTestPEP451, + Source_CaseSensitivityTestPEP451 + ) = util.test_both(CaseSensitivityTestPEP451, importlib=importlib, + machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/source/test_file_loader.py b/Lib/test/test_importlib/source/test_file_loader.py index 2d415f9..73f4c62 100644 --- a/Lib/test/test_importlib/source/test_file_loader.py +++ b/Lib/test/test_importlib/source/test_file_loader.py @@ -1,6 +1,5 @@ from .. import abc from .. import util -from . import util as source_util importlib = util.import_importlib('importlib') importlib_abc = util.import_importlib('importlib.abc') @@ -71,7 +70,7 @@ class SimpleTest(abc.LoaderTests): # [basic] def test_module(self): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) @@ -83,7 +82,7 @@ class SimpleTest(abc.LoaderTests): self.assertEqual(getattr(module, attr), value) def test_package(self): - with source_util.create_modules('_pkg.__init__') as mapping: + with util.create_modules('_pkg.__init__') as mapping: loader = self.machinery.SourceFileLoader('_pkg', mapping['_pkg.__init__']) with warnings.catch_warnings(): @@ -98,7 +97,7 @@ class SimpleTest(abc.LoaderTests): def test_lacking_parent(self): - with source_util.create_modules('_pkg.__init__', '_pkg.mod')as mapping: + with util.create_modules('_pkg.__init__', '_pkg.mod')as mapping: loader = self.machinery.SourceFileLoader('_pkg.mod', mapping['_pkg.mod']) with warnings.catch_warnings(): @@ -115,7 +114,7 @@ class SimpleTest(abc.LoaderTests): return lambda name: fxn(name) + 1 def test_module_reuse(self): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) @@ -139,7 +138,7 @@ class SimpleTest(abc.LoaderTests): attributes = ('__file__', '__path__', '__package__') value = '<test>' name = '_temp' - with source_util.create_modules(name) as mapping: + with util.create_modules(name) as mapping: orig_module = types.ModuleType(name) for attr in attributes: setattr(orig_module, attr, value) @@ -159,7 +158,7 @@ class SimpleTest(abc.LoaderTests): # [syntax error] def test_bad_syntax(self): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: with open(mapping['_temp'], 'w') as file: file.write('=') loader = self.machinery.SourceFileLoader('_temp', mapping['_temp']) @@ -190,11 +189,11 @@ class SimpleTest(abc.LoaderTests): if os.path.exists(pycache): shutil.rmtree(pycache) - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_timestamp_overflow(self): # When a modification timestamp is larger than 2**32, it should be # truncated rather than raise an OverflowError. - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: source = mapping['_temp'] compiled = self.util.cache_from_source(source) with open(source, 'w') as f: @@ -236,9 +235,11 @@ class SimpleTest(abc.LoaderTests): warnings.simplefilter('ignore', DeprecationWarning) loader.load_module('bad name') -Frozen_SimpleTest, Source_SimpleTest = util.test_both( - SimpleTest, importlib=importlib, machinery=machinery, abc=importlib_abc, - util=importlib_util) + +(Frozen_SimpleTest, + Source_SimpleTest + ) = util.test_both(SimpleTest, importlib=importlib, machinery=machinery, + abc=importlib_abc, util=importlib_util) class BadBytecodeTest: @@ -275,45 +276,45 @@ class BadBytecodeTest: return bytecode_path def _test_empty_file(self, test, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, lambda bc: b'', del_source=del_source) test('_temp', mapping, bc_path) - @source_util.writes_bytecode_files + @util.writes_bytecode_files def _test_partial_magic(self, test, *, del_source=False): # When their are less than 4 bytes to a .pyc, regenerate it if # possible, else raise ImportError. - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, lambda bc: bc[:3], del_source=del_source) test('_temp', mapping, bc_path) def _test_magic_only(self, test, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, lambda bc: bc[:4], del_source=del_source) test('_temp', mapping, bc_path) def _test_partial_timestamp(self, test, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, lambda bc: bc[:7], del_source=del_source) test('_temp', mapping, bc_path) def _test_partial_size(self, test, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, lambda bc: bc[:11], del_source=del_source) test('_temp', mapping, bc_path) def _test_no_marshal(self, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, lambda bc: bc[:12], del_source=del_source) @@ -322,7 +323,7 @@ class BadBytecodeTest: self.import_(file_path, '_temp') def _test_non_code_marshal(self, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bytecode_path = self.manipulate_bytecode('_temp', mapping, lambda bc: bc[:12] + marshal.dumps(b'abcd'), del_source=del_source) @@ -333,7 +334,7 @@ class BadBytecodeTest: self.assertEqual(cm.exception.path, bytecode_path) def _test_bad_marshal(self, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bytecode_path = self.manipulate_bytecode('_temp', mapping, lambda bc: bc[:12] + b'<test>', del_source=del_source) @@ -342,11 +343,12 @@ class BadBytecodeTest: self.import_(file_path, '_temp') def _test_bad_magic(self, test, *, del_source=False): - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: bc_path = self.manipulate_bytecode('_temp', mapping, lambda bc: b'\x00\x00\x00\x00' + bc[4:]) test('_temp', mapping, bc_path) + class BadBytecodeTestPEP451(BadBytecodeTest): def import_(self, file, module_name): @@ -355,6 +357,7 @@ class BadBytecodeTestPEP451(BadBytecodeTest): module.__spec__ = self.util.spec_from_loader(module_name, loader) loader.exec_module(module) + class BadBytecodeTestPEP302(BadBytecodeTest): def import_(self, file, module_name): @@ -371,7 +374,7 @@ class SourceLoaderBadBytecodeTest: def setUpClass(cls): cls.loader = cls.machinery.SourceFileLoader - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_empty_file(self): # When a .pyc is empty, regenerate it if possible, else raise # ImportError. @@ -390,7 +393,7 @@ class SourceLoaderBadBytecodeTest: self._test_partial_magic(test) - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_magic_only(self): # When there is only the magic number, regenerate the .pyc if possible, # else raise EOFError. @@ -401,7 +404,7 @@ class SourceLoaderBadBytecodeTest: self._test_magic_only(test) - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_bad_magic(self): # When the magic number is different, the bytecode should be # regenerated. @@ -413,7 +416,7 @@ class SourceLoaderBadBytecodeTest: self._test_bad_magic(test) - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_partial_timestamp(self): # When the timestamp is partial, regenerate the .pyc, else # raise EOFError. @@ -424,7 +427,7 @@ class SourceLoaderBadBytecodeTest: self._test_partial_timestamp(test) - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_partial_size(self): # When the size is partial, regenerate the .pyc, else # raise EOFError. @@ -435,29 +438,29 @@ class SourceLoaderBadBytecodeTest: self._test_partial_size(test) - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_no_marshal(self): # When there is only the magic number and timestamp, raise EOFError. self._test_no_marshal() - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_non_code_marshal(self): self._test_non_code_marshal() # XXX ImportError when sourceless # [bad marshal] - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_bad_marshal(self): # Bad marshal data should raise a ValueError. self._test_bad_marshal() # [bad timestamp] - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_old_timestamp(self): # When the timestamp is older than the source, bytecode should be # regenerated. zeros = b'\x00\x00\x00\x00' - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: py_compile.compile(mapping['_temp']) bytecode_path = self.util.cache_from_source(mapping['_temp']) with open(bytecode_path, 'r+b') as bytecode_file: @@ -471,10 +474,10 @@ class SourceLoaderBadBytecodeTest: self.assertEqual(bytecode_file.read(4), source_timestamp) # [bytecode read-only] - @source_util.writes_bytecode_files + @util.writes_bytecode_files def test_read_only_bytecode(self): # When bytecode is read-only but should be rewritten, fail silently. - with source_util.create_modules('_temp') as mapping: + with util.create_modules('_temp') as mapping: # Create bytecode that will need to be re-created. py_compile.compile(mapping['_temp']) bytecode_path = self.util.cache_from_source(mapping['_temp']) @@ -491,21 +494,29 @@ class SourceLoaderBadBytecodeTest: # Make writable for eventual clean-up. os.chmod(bytecode_path, stat.S_IWUSR) + class SourceLoaderBadBytecodeTestPEP451( SourceLoaderBadBytecodeTest, BadBytecodeTestPEP451): pass -Frozen_SourceBadBytecodePEP451, Source_SourceBadBytecodePEP451 = util.test_both( - SourceLoaderBadBytecodeTestPEP451, importlib=importlib, machinery=machinery, - abc=importlib_abc, util=importlib_util) + +(Frozen_SourceBadBytecodePEP451, + Source_SourceBadBytecodePEP451 + ) = util.test_both(SourceLoaderBadBytecodeTestPEP451, importlib=importlib, + machinery=machinery, abc=importlib_abc, + util=importlib_util) + class SourceLoaderBadBytecodeTestPEP302( SourceLoaderBadBytecodeTest, BadBytecodeTestPEP302): pass -Frozen_SourceBadBytecodePEP302, Source_SourceBadBytecodePEP302 = util.test_both( - SourceLoaderBadBytecodeTestPEP302, importlib=importlib, machinery=machinery, - abc=importlib_abc, util=importlib_util) + +(Frozen_SourceBadBytecodePEP302, + Source_SourceBadBytecodePEP302 + ) = util.test_both(SourceLoaderBadBytecodeTestPEP302, importlib=importlib, + machinery=machinery, abc=importlib_abc, + util=importlib_util) class SourcelessLoaderBadBytecodeTest: @@ -567,21 +578,29 @@ class SourcelessLoaderBadBytecodeTest: def test_non_code_marshal(self): self._test_non_code_marshal(del_source=True) + class SourcelessLoaderBadBytecodeTestPEP451(SourcelessLoaderBadBytecodeTest, BadBytecodeTestPEP451): pass -Frozen_SourcelessBadBytecodePEP451, Source_SourcelessBadBytecodePEP451 = util.test_both( - SourcelessLoaderBadBytecodeTestPEP451, importlib=importlib, - machinery=machinery, abc=importlib_abc, util=importlib_util) + +(Frozen_SourcelessBadBytecodePEP451, + Source_SourcelessBadBytecodePEP451 + ) = util.test_both(SourcelessLoaderBadBytecodeTestPEP451, importlib=importlib, + machinery=machinery, abc=importlib_abc, + util=importlib_util) + class SourcelessLoaderBadBytecodeTestPEP302(SourcelessLoaderBadBytecodeTest, BadBytecodeTestPEP302): pass -Frozen_SourcelessBadBytecodePEP302, Source_SourcelessBadBytecodePEP302 = util.test_both( - SourcelessLoaderBadBytecodeTestPEP302, importlib=importlib, - machinery=machinery, abc=importlib_abc, util=importlib_util) + +(Frozen_SourcelessBadBytecodePEP302, + Source_SourcelessBadBytecodePEP302 + ) = util.test_both(SourcelessLoaderBadBytecodeTestPEP302, importlib=importlib, + machinery=machinery, abc=importlib_abc, + util=importlib_util) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/source/test_finder.py b/Lib/test/test_importlib/source/test_finder.py index 473297b..f372b85 100644 --- a/Lib/test/test_importlib/source/test_finder.py +++ b/Lib/test/test_importlib/source/test_finder.py @@ -1,6 +1,5 @@ from .. import abc from .. import util -from . import util as source_util machinery = util.import_importlib('importlib.machinery') @@ -60,7 +59,7 @@ class FinderTests(abc.FinderTests): """ if create is None: create = {test} - with source_util.create_modules(*create) as mapping: + with util.create_modules(*create) as mapping: if compile_: for name in compile_: py_compile.compile(mapping[name]) @@ -100,14 +99,14 @@ class FinderTests(abc.FinderTests): # [sub module] def test_module_in_package(self): - with source_util.create_modules('pkg.__init__', 'pkg.sub') as mapping: + with util.create_modules('pkg.__init__', 'pkg.sub') as mapping: pkg_dir = os.path.dirname(mapping['pkg.__init__']) loader = self.import_(pkg_dir, 'pkg.sub') self.assertTrue(hasattr(loader, 'load_module')) # [sub package] def test_package_in_package(self): - context = source_util.create_modules('pkg.__init__', 'pkg.sub.__init__') + context = util.create_modules('pkg.__init__', 'pkg.sub.__init__') with context as mapping: pkg_dir = os.path.dirname(mapping['pkg.__init__']) loader = self.import_(pkg_dir, 'pkg.sub') @@ -120,7 +119,7 @@ class FinderTests(abc.FinderTests): self.assertIn('__init__', loader.get_filename(name)) def test_failure(self): - with source_util.create_modules('blah') as mapping: + with util.create_modules('blah') as mapping: nothing = self.import_(mapping['.root'], 'sdfsadsadf') self.assertIsNone(nothing) @@ -147,7 +146,7 @@ class FinderTests(abc.FinderTests): # Regression test for http://bugs.python.org/issue14846 def test_dir_removal_handling(self): mod = 'mod' - with source_util.create_modules(mod) as mapping: + with util.create_modules(mod) as mapping: finder = self.get_finder(mapping['.root']) found = self._find(finder, 'mod', loader_only=True) self.assertIsNotNone(found) @@ -196,8 +195,10 @@ class FinderTestsPEP451(FinderTests): spec = finder.find_spec(name) return spec.loader if spec is not None else spec -Frozen_FinderTestsPEP451, Source_FinderTestsPEP451 = util.test_both( - FinderTestsPEP451, machinery=machinery) + +(Frozen_FinderTestsPEP451, + Source_FinderTestsPEP451 + ) = util.test_both(FinderTestsPEP451, machinery=machinery) class FinderTestsPEP420(FinderTests): @@ -210,8 +211,10 @@ class FinderTestsPEP420(FinderTests): loader_portions = finder.find_loader(name) return loader_portions[0] if loader_only else loader_portions -Frozen_FinderTestsPEP420, Source_FinderTestsPEP420 = util.test_both( - FinderTestsPEP420, machinery=machinery) + +(Frozen_FinderTestsPEP420, + Source_FinderTestsPEP420 + ) = util.test_both(FinderTestsPEP420, machinery=machinery) class FinderTestsPEP302(FinderTests): @@ -223,9 +226,10 @@ class FinderTestsPEP302(FinderTests): warnings.simplefilter("ignore", DeprecationWarning) return finder.find_module(name) -Frozen_FinderTestsPEP302, Source_FinderTestsPEP302 = util.test_both( - FinderTestsPEP302, machinery=machinery) +(Frozen_FinderTestsPEP302, + Source_FinderTestsPEP302 + ) = util.test_both(FinderTestsPEP302, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/source/test_path_hook.py b/Lib/test/test_importlib/source/test_path_hook.py index 92da772..e6a2415 100644 --- a/Lib/test/test_importlib/source/test_path_hook.py +++ b/Lib/test/test_importlib/source/test_path_hook.py @@ -1,5 +1,4 @@ from .. import util -from . import util as source_util machinery = util.import_importlib('importlib.machinery') @@ -15,7 +14,7 @@ class PathHookTest: self.machinery.SOURCE_SUFFIXES)) def test_success(self): - with source_util.create_modules('dummy') as mapping: + with util.create_modules('dummy') as mapping: self.assertTrue(hasattr(self.path_hook()(mapping['.root']), 'find_module')) @@ -23,7 +22,10 @@ class PathHookTest: # The empty string represents the cwd. self.assertTrue(hasattr(self.path_hook()(''), 'find_module')) -Frozen_PathHookTest, Source_PathHooktest = util.test_both(PathHookTest, machinery=machinery) + +(Frozen_PathHookTest, + Source_PathHooktest + ) = util.test_both(PathHookTest, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/source/test_source_encoding.py b/Lib/test/test_importlib/source/test_source_encoding.py index c62dfa1..b604afb 100644 --- a/Lib/test/test_importlib/source/test_source_encoding.py +++ b/Lib/test/test_importlib/source/test_source_encoding.py @@ -1,5 +1,4 @@ from .. import util -from . import util as source_util machinery = util.import_importlib('importlib.machinery') @@ -37,7 +36,7 @@ class EncodingTest: module_name = '_temp' def run_test(self, source): - with source_util.create_modules(self.module_name) as mapping: + with util.create_modules(self.module_name) as mapping: with open(mapping[self.module_name], 'wb') as file: file.write(source) loader = self.machinery.SourceFileLoader(self.module_name, @@ -89,6 +88,7 @@ class EncodingTest: with self.assertRaises(SyntaxError): self.run_test(source) + class EncodingTestPEP451(EncodingTest): def load(self, loader): @@ -97,8 +97,11 @@ class EncodingTestPEP451(EncodingTest): loader.exec_module(module) return module -Frozen_EncodingTestPEP451, Source_EncodingTestPEP451 = util.test_both( - EncodingTestPEP451, machinery=machinery) + +(Frozen_EncodingTestPEP451, + Source_EncodingTestPEP451 + ) = util.test_both(EncodingTestPEP451, machinery=machinery) + class EncodingTestPEP302(EncodingTest): @@ -107,8 +110,10 @@ class EncodingTestPEP302(EncodingTest): warnings.simplefilter('ignore', DeprecationWarning) return loader.load_module(self.module_name) -Frozen_EncodingTestPEP302, Source_EncodingTestPEP302 = util.test_both( - EncodingTestPEP302, machinery=machinery) + +(Frozen_EncodingTestPEP302, + Source_EncodingTestPEP302 + ) = util.test_both(EncodingTestPEP302, machinery=machinery) class LineEndingTest: @@ -120,7 +125,7 @@ class LineEndingTest: module_name = '_temp' source_lines = [b"a = 42", b"b = -13", b''] source = line_ending.join(source_lines) - with source_util.create_modules(module_name) as mapping: + with util.create_modules(module_name) as mapping: with open(mapping[module_name], 'wb') as file: file.write(source) loader = self.machinery.SourceFileLoader(module_name, @@ -139,6 +144,7 @@ class LineEndingTest: def test_lf(self): self.run_test(b'\n') + class LineEndingTestPEP451(LineEndingTest): def load(self, loader, module_name): @@ -147,8 +153,11 @@ class LineEndingTestPEP451(LineEndingTest): loader.exec_module(module) return module -Frozen_LineEndingTestPEP451, Source_LineEndingTestPEP451 = util.test_both( - LineEndingTestPEP451, machinery=machinery) + +(Frozen_LineEndingTestPEP451, + Source_LineEndingTestPEP451 + ) = util.test_both(LineEndingTestPEP451, machinery=machinery) + class LineEndingTestPEP302(LineEndingTest): @@ -157,8 +166,10 @@ class LineEndingTestPEP302(LineEndingTest): warnings.simplefilter('ignore', DeprecationWarning) return loader.load_module(module_name) -Frozen_LineEndingTestPEP302, Source_LineEndingTestPEP302 = util.test_both( - LineEndingTestPEP302, machinery=machinery) + +(Frozen_LineEndingTestPEP302, + Source_LineEndingTestPEP302 + ) = util.test_both(LineEndingTestPEP302, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/source/util.py b/Lib/test/test_importlib/source/util.py deleted file mode 100644 index 63cd25a..0000000 --- a/Lib/test/test_importlib/source/util.py +++ /dev/null @@ -1,96 +0,0 @@ -from .. import util -import contextlib -import errno -import functools -import os -import os.path -import sys -import tempfile -from test import support - - -def writes_bytecode_files(fxn): - """Decorator to protect sys.dont_write_bytecode from mutation and to skip - tests that require it to be set to False.""" - if sys.dont_write_bytecode: - return lambda *args, **kwargs: None - @functools.wraps(fxn) - def wrapper(*args, **kwargs): - original = sys.dont_write_bytecode - sys.dont_write_bytecode = False - try: - to_return = fxn(*args, **kwargs) - finally: - sys.dont_write_bytecode = original - return to_return - return wrapper - - -def ensure_bytecode_path(bytecode_path): - """Ensure that the __pycache__ directory for PEP 3147 pyc file exists. - - :param bytecode_path: File system path to PEP 3147 pyc file. - """ - try: - os.mkdir(os.path.dirname(bytecode_path)) - except OSError as error: - if error.errno != errno.EEXIST: - raise - - -@contextlib.contextmanager -def create_modules(*names): - """Temporarily create each named module with an attribute (named 'attr') - that contains the name passed into the context manager that caused the - creation of the module. - - All files are created in a temporary directory returned by - tempfile.mkdtemp(). This directory is inserted at the beginning of - sys.path. When the context manager exits all created files (source and - bytecode) are explicitly deleted. - - No magic is performed when creating packages! This means that if you create - a module within a package you must also create the package's __init__ as - well. - - """ - source = 'attr = {0!r}' - created_paths = [] - mapping = {} - state_manager = None - uncache_manager = None - try: - temp_dir = tempfile.mkdtemp() - mapping['.root'] = temp_dir - import_names = set() - for name in names: - if not name.endswith('__init__'): - import_name = name - else: - import_name = name[:-len('.__init__')] - import_names.add(import_name) - if import_name in sys.modules: - del sys.modules[import_name] - name_parts = name.split('.') - file_path = temp_dir - for directory in name_parts[:-1]: - file_path = os.path.join(file_path, directory) - if not os.path.exists(file_path): - os.mkdir(file_path) - created_paths.append(file_path) - file_path = os.path.join(file_path, name_parts[-1] + '.py') - with open(file_path, 'w') as file: - file.write(source.format(name)) - created_paths.append(file_path) - mapping[name] = file_path - uncache_manager = util.uncache(*import_names) - uncache_manager.__enter__() - state_manager = util.import_state(path=[temp_dir]) - state_manager.__enter__() - yield mapping - finally: - if state_manager is not None: - state_manager.__exit__(None, None, None) - if uncache_manager is not None: - uncache_manager.__exit__(None, None, None) - support.rmtree(temp_dir) diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py index a1f8e76..d4bf915 100644 --- a/Lib/test/test_importlib/test_abc.py +++ b/Lib/test/test_importlib/test_abc.py @@ -10,12 +10,13 @@ import unittest from unittest import mock import warnings -from . import util +from . import util as test_util + +init = test_util.import_importlib('importlib') +abc = test_util.import_importlib('importlib.abc') +machinery = test_util.import_importlib('importlib.machinery') +util = test_util.import_importlib('importlib.util') -frozen_init, source_init = util.import_importlib('importlib') -frozen_abc, source_abc = util.import_importlib('importlib.abc') -machinery = util.import_importlib('importlib.machinery') -frozen_util, source_util = util.import_importlib('importlib.util') ##### Inheritance ############################################################## class InheritanceTests: @@ -26,8 +27,7 @@ class InheritanceTests: subclasses = [] superclasses = [] - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def setUp(self): self.superclasses = [getattr(self.abc, class_name) for class_name in self.superclass_names] if hasattr(self, 'subclass_names'): @@ -36,11 +36,11 @@ class InheritanceTests: # checking across module boundaries (i.e. the _bootstrap in abc is # not the same as the one in machinery). That means stealing one of # the modules from the other to make sure the same instance is used. - self.subclasses = [getattr(self.abc.machinery, class_name) - for class_name in self.subclass_names] + machinery = self.abc.machinery + self.subclasses = [getattr(machinery, class_name) + for class_name in self.subclass_names] assert self.subclasses or self.superclasses, self.__class__ - testing = self.__class__.__name__.partition('_')[2] - self.__test = getattr(self.abc, testing) + self.__test = getattr(self.abc, self._NAME) def test_subclasses(self): # Test that the expected subclasses inherit. @@ -54,94 +54,97 @@ class InheritanceTests: self.assertTrue(issubclass(self.__test, superclass), "{0} is not a superclass of {1}".format(superclass, self.__test)) -def create_inheritance_tests(base_class): - def set_frozen(ns): - ns['abc'] = frozen_abc - def set_source(ns): - ns['abc'] = source_abc - - classes = [] - for prefix, ns_set in [('Frozen', set_frozen), ('Source', set_source)]: - classes.append(types.new_class('_'.join([prefix, base_class.__name__]), - (base_class, unittest.TestCase), - exec_body=ns_set)) - return classes - class MetaPathFinder(InheritanceTests): superclass_names = ['Finder'] subclass_names = ['BuiltinImporter', 'FrozenImporter', 'PathFinder', 'WindowsRegistryFinder'] -tests = create_inheritance_tests(MetaPathFinder) -Frozen_MetaPathFinderInheritanceTests, Source_MetaPathFinderInheritanceTests = tests + +(Frozen_MetaPathFinderInheritanceTests, + Source_MetaPathFinderInheritanceTests + ) = test_util.test_both(MetaPathFinder, abc=abc) class PathEntryFinder(InheritanceTests): superclass_names = ['Finder'] subclass_names = ['FileFinder'] -tests = create_inheritance_tests(PathEntryFinder) -Frozen_PathEntryFinderInheritanceTests, Source_PathEntryFinderInheritanceTests = tests + +(Frozen_PathEntryFinderInheritanceTests, + Source_PathEntryFinderInheritanceTests + ) = test_util.test_both(PathEntryFinder, abc=abc) class ResourceLoader(InheritanceTests): superclass_names = ['Loader'] -tests = create_inheritance_tests(ResourceLoader) -Frozen_ResourceLoaderInheritanceTests, Source_ResourceLoaderInheritanceTests = tests + +(Frozen_ResourceLoaderInheritanceTests, + Source_ResourceLoaderInheritanceTests + ) = test_util.test_both(ResourceLoader, abc=abc) class InspectLoader(InheritanceTests): superclass_names = ['Loader'] subclass_names = ['BuiltinImporter', 'FrozenImporter', 'ExtensionFileLoader'] -tests = create_inheritance_tests(InspectLoader) -Frozen_InspectLoaderInheritanceTests, Source_InspectLoaderInheritanceTests = tests + +(Frozen_InspectLoaderInheritanceTests, + Source_InspectLoaderInheritanceTests + ) = test_util.test_both(InspectLoader, abc=abc) class ExecutionLoader(InheritanceTests): superclass_names = ['InspectLoader'] subclass_names = ['ExtensionFileLoader'] -tests = create_inheritance_tests(ExecutionLoader) -Frozen_ExecutionLoaderInheritanceTests, Source_ExecutionLoaderInheritanceTests = tests + +(Frozen_ExecutionLoaderInheritanceTests, + Source_ExecutionLoaderInheritanceTests + ) = test_util.test_both(ExecutionLoader, abc=abc) class FileLoader(InheritanceTests): superclass_names = ['ResourceLoader', 'ExecutionLoader'] subclass_names = ['SourceFileLoader', 'SourcelessFileLoader'] -tests = create_inheritance_tests(FileLoader) -Frozen_FileLoaderInheritanceTests, Source_FileLoaderInheritanceTests = tests + +(Frozen_FileLoaderInheritanceTests, + Source_FileLoaderInheritanceTests + ) = test_util.test_both(FileLoader, abc=abc) class SourceLoader(InheritanceTests): superclass_names = ['ResourceLoader', 'ExecutionLoader'] subclass_names = ['SourceFileLoader'] -tests = create_inheritance_tests(SourceLoader) -Frozen_SourceLoaderInheritanceTests, Source_SourceLoaderInheritanceTests = tests + +(Frozen_SourceLoaderInheritanceTests, + Source_SourceLoaderInheritanceTests + ) = test_util.test_both(SourceLoader, abc=abc) + ##### Default return values #################################################### -def make_abc_subclasses(base_class): - classes = [] - for kind, abc in [('Frozen', frozen_abc), ('Source', source_abc)]: - name = '_'.join([kind, base_class.__name__]) - base_classes = base_class, getattr(abc, base_class.__name__) - classes.append(types.new_class(name, base_classes)) - return classes - -def make_return_value_tests(base_class, test_class): - frozen_class, source_class = make_abc_subclasses(base_class) - tests = [] - for prefix, class_in_test in [('Frozen', frozen_class), ('Source', source_class)]: - def set_ns(ns): - ns['ins'] = class_in_test() - tests.append(types.new_class('_'.join([prefix, test_class.__name__]), - (test_class, unittest.TestCase), - exec_body=set_ns)) - return tests + +def make_abc_subclasses(base_class, name=None, inst=False, **kwargs): + if name is None: + name = base_class.__name__ + base = {kind: getattr(splitabc, name) + for kind, splitabc in abc.items()} + return {cls._KIND: cls() if inst else cls + for cls in test_util.split_frozen(base_class, base, **kwargs)} + + +class ABCTestHarness: + + @property + def ins(self): + # Lazily set ins on the class. + cls = self.SPLIT[self._KIND] + ins = cls() + self.__class__.ins = ins + return ins class MetaPathFinder: @@ -149,10 +152,10 @@ class MetaPathFinder: def find_module(self, fullname, path): return super().find_module(fullname, path) -Frozen_MPF, Source_MPF = make_abc_subclasses(MetaPathFinder) +class MetaPathFinderDefaultsTests(ABCTestHarness): -class MetaPathFinderDefaultsTests: + SPLIT = make_abc_subclasses(MetaPathFinder) def test_find_module(self): # Default should return None. @@ -163,8 +166,9 @@ class MetaPathFinderDefaultsTests: self.ins.invalidate_caches() -tests = make_return_value_tests(MetaPathFinder, MetaPathFinderDefaultsTests) -Frozen_MPFDefaultTests, Source_MPFDefaultTests = tests +(Frozen_MPFDefaultTests, + Source_MPFDefaultTests + ) = test_util.test_both(MetaPathFinderDefaultsTests) class PathEntryFinder: @@ -172,10 +176,10 @@ class PathEntryFinder: def find_loader(self, fullname): return super().find_loader(fullname) -Frozen_PEF, Source_PEF = make_abc_subclasses(PathEntryFinder) +class PathEntryFinderDefaultsTests(ABCTestHarness): -class PathEntryFinderDefaultsTests: + SPLIT = make_abc_subclasses(PathEntryFinder) def test_find_loader(self): self.assertEqual((None, []), self.ins.find_loader('something')) @@ -188,8 +192,9 @@ class PathEntryFinderDefaultsTests: self.ins.invalidate_caches() -tests = make_return_value_tests(PathEntryFinder, PathEntryFinderDefaultsTests) -Frozen_PEFDefaultTests, Source_PEFDefaultTests = tests +(Frozen_PEFDefaultTests, + Source_PEFDefaultTests + ) = test_util.test_both(PathEntryFinderDefaultsTests) class Loader: @@ -198,10 +203,9 @@ class Loader: return super().load_module(fullname) -Frozen_L, Source_L = make_abc_subclasses(Loader) +class LoaderDefaultsTests(ABCTestHarness): - -class LoaderDefaultsTests: + SPLIT = make_abc_subclasses(Loader) def test_load_module(self): with self.assertRaises(ImportError): @@ -217,8 +221,9 @@ class LoaderDefaultsTests: self.assertTrue(repr(mod)) -tests = make_return_value_tests(Loader, LoaderDefaultsTests) -Frozen_LDefaultTests, SourceLDefaultTests = tests +(Frozen_LDefaultTests, + SourceLDefaultTests + ) = test_util.test_both(LoaderDefaultsTests) class ResourceLoader(Loader): @@ -227,18 +232,18 @@ class ResourceLoader(Loader): return super().get_data(path) -Frozen_RL, Source_RL = make_abc_subclasses(ResourceLoader) - +class ResourceLoaderDefaultsTests(ABCTestHarness): -class ResourceLoaderDefaultsTests: + SPLIT = make_abc_subclasses(ResourceLoader) def test_get_data(self): with self.assertRaises(IOError): self.ins.get_data('/some/path') -tests = make_return_value_tests(ResourceLoader, ResourceLoaderDefaultsTests) -Frozen_RLDefaultTests, Source_RLDefaultTests = tests +(Frozen_RLDefaultTests, + Source_RLDefaultTests + ) = test_util.test_both(ResourceLoaderDefaultsTests) class InspectLoader(Loader): @@ -250,10 +255,12 @@ class InspectLoader(Loader): return super().get_source(fullname) -Frozen_IL, Source_IL = make_abc_subclasses(InspectLoader) +SPLIT_IL = make_abc_subclasses(InspectLoader) -class InspectLoaderDefaultsTests: +class InspectLoaderDefaultsTests(ABCTestHarness): + + SPLIT = SPLIT_IL def test_is_package(self): with self.assertRaises(ImportError): @@ -264,8 +271,9 @@ class InspectLoaderDefaultsTests: self.ins.get_source('blah') -tests = make_return_value_tests(InspectLoader, InspectLoaderDefaultsTests) -Frozen_ILDefaultTests, Source_ILDefaultTests = tests +(Frozen_ILDefaultTests, + Source_ILDefaultTests + ) = test_util.test_both(InspectLoaderDefaultsTests) class ExecutionLoader(InspectLoader): @@ -273,21 +281,25 @@ class ExecutionLoader(InspectLoader): def get_filename(self, fullname): return super().get_filename(fullname) -Frozen_EL, Source_EL = make_abc_subclasses(ExecutionLoader) + +SPLIT_EL = make_abc_subclasses(ExecutionLoader) -class ExecutionLoaderDefaultsTests: +class ExecutionLoaderDefaultsTests(ABCTestHarness): + + SPLIT = SPLIT_EL def test_get_filename(self): with self.assertRaises(ImportError): self.ins.get_filename('blah') -tests = make_return_value_tests(ExecutionLoader, InspectLoaderDefaultsTests) -Frozen_ELDefaultTests, Source_ELDefaultsTests = tests +(Frozen_ELDefaultTests, + Source_ELDefaultsTests + ) = test_util.test_both(InspectLoaderDefaultsTests) -##### MetaPathFinder concrete methods ########################################## +##### MetaPathFinder concrete methods ########################################## class MetaPathFinderFindModuleTests: @classmethod @@ -317,13 +329,12 @@ class MetaPathFinderFindModuleTests: self.assertIs(found, spec.loader) -Frozen_MPFFindModuleTests, Source_MPFFindModuleTests = util.test_both( - MetaPathFinderFindModuleTests, - abc=(frozen_abc, source_abc), - util=(frozen_util, source_util)) +(Frozen_MPFFindModuleTests, + Source_MPFFindModuleTests + ) = test_util.test_both(MetaPathFinderFindModuleTests, abc=abc, util=util) -##### PathEntryFinder concrete methods ######################################### +##### PathEntryFinder concrete methods ######################################### class PathEntryFinderFindLoaderTests: @classmethod @@ -361,11 +372,10 @@ class PathEntryFinderFindLoaderTests: self.assertEqual(paths, found[1]) -Frozen_PEFFindLoaderTests, Source_PEFFindLoaderTests = util.test_both( - PathEntryFinderFindLoaderTests, - abc=(frozen_abc, source_abc), - machinery=machinery, - util=(frozen_util, source_util)) +(Frozen_PEFFindLoaderTests, + Source_PEFFindLoaderTests + ) = test_util.test_both(PathEntryFinderFindLoaderTests, abc=abc, util=util, + machinery=machinery) ##### Loader concrete methods ################################################## @@ -386,7 +396,7 @@ class LoaderLoadModuleTests: def test_fresh(self): loader = self.loader() name = 'blah' - with util.uncache(name): + with test_util.uncache(name): loader.load_module(name) module = loader.found self.assertIs(sys.modules[name], module) @@ -404,7 +414,7 @@ class LoaderLoadModuleTests: module = types.ModuleType(name) module.__spec__ = self.util.spec_from_loader(name, loader) module.__loader__ = loader - with util.uncache(name): + with test_util.uncache(name): sys.modules[name] = module loader.load_module(name) found = loader.found @@ -412,10 +422,9 @@ class LoaderLoadModuleTests: self.assertIs(module, sys.modules[name]) -Frozen_LoaderLoadModuleTests, Source_LoaderLoadModuleTests = util.test_both( - LoaderLoadModuleTests, - abc=(frozen_abc, source_abc), - util=(frozen_util, source_util)) +(Frozen_LoaderLoadModuleTests, + Source_LoaderLoadModuleTests + ) = test_util.test_both(LoaderLoadModuleTests, abc=abc, util=util) ##### InspectLoader concrete methods ########################################### @@ -461,11 +470,10 @@ class InspectLoaderSourceToCodeTests: self.assertEqual(code.co_filename, '<string>') -class Frozen_ILSourceToCodeTests(InspectLoaderSourceToCodeTests, unittest.TestCase): - InspectLoaderSubclass = Frozen_IL - -class Source_ILSourceToCodeTests(InspectLoaderSourceToCodeTests, unittest.TestCase): - InspectLoaderSubclass = Source_IL +(Frozen_ILSourceToCodeTests, + Source_ILSourceToCodeTests + ) = test_util.test_both(InspectLoaderSourceToCodeTests, + InspectLoaderSubclass=SPLIT_IL) class InspectLoaderGetCodeTests: @@ -495,11 +503,10 @@ class InspectLoaderGetCodeTests: loader.get_code('blah') -class Frozen_ILGetCodeTests(InspectLoaderGetCodeTests, unittest.TestCase): - InspectLoaderSubclass = Frozen_IL - -class Source_ILGetCodeTests(InspectLoaderGetCodeTests, unittest.TestCase): - InspectLoaderSubclass = Source_IL +(Frozen_ILGetCodeTests, + Source_ILGetCodeTests + ) = test_util.test_both(InspectLoaderGetCodeTests, + InspectLoaderSubclass=SPLIT_IL) class InspectLoaderLoadModuleTests: @@ -543,11 +550,10 @@ class InspectLoaderLoadModuleTests: self.assertEqual(module, sys.modules[self.module_name]) -class Frozen_ILLoadModuleTests(InspectLoaderLoadModuleTests, unittest.TestCase): - InspectLoaderSubclass = Frozen_IL - -class Source_ILLoadModuleTests(InspectLoaderLoadModuleTests, unittest.TestCase): - InspectLoaderSubclass = Source_IL +(Frozen_ILLoadModuleTests, + Source_ILLoadModuleTests + ) = test_util.test_both(InspectLoaderLoadModuleTests, + InspectLoaderSubclass=SPLIT_IL) ##### ExecutionLoader concrete methods ######################################### @@ -608,15 +614,14 @@ class ExecutionLoaderGetCodeTests: self.assertEqual(module.attr, 42) -class Frozen_ELGetCodeTests(ExecutionLoaderGetCodeTests, unittest.TestCase): - ExecutionLoaderSubclass = Frozen_EL - -class Source_ELGetCodeTests(ExecutionLoaderGetCodeTests, unittest.TestCase): - ExecutionLoaderSubclass = Source_EL +(Frozen_ELGetCodeTests, + Source_ELGetCodeTests + ) = test_util.test_both(ExecutionLoaderGetCodeTests, + ExecutionLoaderSubclass=SPLIT_EL) ##### SourceLoader concrete methods ############################################ -class SourceLoader: +class SourceOnlyLoader: # Globals that should be defined for all modules. source = (b"_ = '::'.join([__name__, __file__, __cached__, __package__, " @@ -637,10 +642,10 @@ class SourceLoader: return '<module>' -Frozen_SourceOnlyL, Source_SourceOnlyL = make_abc_subclasses(SourceLoader) +SPLIT_SOL = make_abc_subclasses(SourceOnlyLoader, 'SourceLoader') -class SourceLoader(SourceLoader): +class SourceLoader(SourceOnlyLoader): source_mtime = 1 @@ -677,11 +682,7 @@ class SourceLoader(SourceLoader): return path == self.bytecode_path -Frozen_SL, Source_SL = make_abc_subclasses(SourceLoader) -Frozen_SL.util = frozen_util -Source_SL.util = source_util -Frozen_SL.init = frozen_init -Source_SL.init = source_init +SPLIT_SL = make_abc_subclasses(SourceLoader, util=util, init=init) class SourceLoaderTestHarness: @@ -765,7 +766,7 @@ class SourceOnlyLoaderTests(SourceLoaderTestHarness): # Loading a module should set __name__, __loader__, __package__, # __path__ (for packages), __file__, and __cached__. # The module should also be put into sys.modules. - with util.uncache(self.name): + with test_util.uncache(self.name): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) module = self.loader.load_module(self.name) @@ -778,7 +779,7 @@ class SourceOnlyLoaderTests(SourceLoaderTestHarness): # is a package. # Testing the values for a package are covered by test_load_module. self.setUp(is_package=False) - with util.uncache(self.name): + with test_util.uncache(self.name): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) module = self.loader.load_module(self.name) @@ -798,13 +799,10 @@ class SourceOnlyLoaderTests(SourceLoaderTestHarness): self.assertEqual(returned_source, source) -class Frozen_SourceOnlyLTests(SourceOnlyLoaderTests, unittest.TestCase): - loader_mock = Frozen_SourceOnlyL - util = frozen_util - -class Source_SourceOnlyLTests(SourceOnlyLoaderTests, unittest.TestCase): - loader_mock = Source_SourceOnlyL - util = source_util +(Frozen_SourceOnlyLoaderTests, + Source_SourceOnlyLoaderTests + ) = test_util.test_both(SourceOnlyLoaderTests, util=util, + loader_mock=SPLIT_SOL) @unittest.skipIf(sys.dont_write_bytecode, "sys.dont_write_bytecode is true") @@ -896,15 +894,10 @@ class SourceLoaderBytecodeTests(SourceLoaderTestHarness): self.verify_code(code_object) -class Frozen_SLBytecodeTests(SourceLoaderBytecodeTests, unittest.TestCase): - loader_mock = Frozen_SL - init = frozen_init - util = frozen_util - -class SourceSLBytecodeTests(SourceLoaderBytecodeTests, unittest.TestCase): - loader_mock = Source_SL - init = source_init - util = source_util +(Frozen_SLBytecodeTests, + SourceSLBytecodeTests + ) = test_util.test_both(SourceLoaderBytecodeTests, init=init, util=util, + loader_mock=SPLIT_SL) class SourceLoaderGetSourceTests: @@ -940,11 +933,10 @@ class SourceLoaderGetSourceTests: self.assertEqual(mock.get_source(name), expect) -class Frozen_SourceOnlyLGetSourceTests(SourceLoaderGetSourceTests, unittest.TestCase): - SourceOnlyLoaderMock = Frozen_SourceOnlyL - -class Source_SourceOnlyLGetSourceTests(SourceLoaderGetSourceTests, unittest.TestCase): - SourceOnlyLoaderMock = Source_SourceOnlyL +(Frozen_SourceOnlyLoaderGetSourceTests, + Source_SourceOnlyLoaderGetSourceTests + ) = test_util.test_both(SourceLoaderGetSourceTests, + SourceOnlyLoaderMock=SPLIT_SOL) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/test_api.py b/Lib/test/test_importlib/test_api.py index 2a2d42b..6bc3c56 100644 --- a/Lib/test/test_importlib/test_api.py +++ b/Lib/test/test_importlib/test_api.py @@ -1,8 +1,8 @@ -from . import util +from . import util as test_util -frozen_init, source_init = util.import_importlib('importlib') -frozen_util, source_util = util.import_importlib('importlib.util') -frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') +init = test_util.import_importlib('importlib') +util = test_util.import_importlib('importlib.util') +machinery = test_util.import_importlib('importlib.machinery') import os.path import sys @@ -18,8 +18,8 @@ class ImportModuleTests: def test_module_import(self): # Test importing a top-level module. - with util.mock_modules('top_level') as mock: - with util.import_state(meta_path=[mock]): + with test_util.mock_modules('top_level') as mock: + with test_util.import_state(meta_path=[mock]): module = self.init.import_module('top_level') self.assertEqual(module.__name__, 'top_level') @@ -28,8 +28,8 @@ class ImportModuleTests: pkg_name = 'pkg' pkg_long_name = '{0}.__init__'.format(pkg_name) name = '{0}.mod'.format(pkg_name) - with util.mock_modules(pkg_long_name, name) as mock: - with util.import_state(meta_path=[mock]): + with test_util.mock_modules(pkg_long_name, name) as mock: + with test_util.import_state(meta_path=[mock]): module = self.init.import_module(name) self.assertEqual(module.__name__, name) @@ -40,16 +40,16 @@ class ImportModuleTests: module_name = 'mod' absolute_name = '{0}.{1}'.format(pkg_name, module_name) relative_name = '.{0}'.format(module_name) - with util.mock_modules(pkg_long_name, absolute_name) as mock: - with util.import_state(meta_path=[mock]): + with test_util.mock_modules(pkg_long_name, absolute_name) as mock: + with test_util.import_state(meta_path=[mock]): self.init.import_module(pkg_name) module = self.init.import_module(relative_name, pkg_name) self.assertEqual(module.__name__, absolute_name) def test_deep_relative_package_import(self): modules = ['a.__init__', 'a.b.__init__', 'a.c'] - with util.mock_modules(*modules) as mock: - with util.import_state(meta_path=[mock]): + with test_util.mock_modules(*modules) as mock: + with test_util.import_state(meta_path=[mock]): self.init.import_module('a') self.init.import_module('a.b') module = self.init.import_module('..c', 'a.b') @@ -61,8 +61,8 @@ class ImportModuleTests: pkg_name = 'pkg' pkg_long_name = '{0}.__init__'.format(pkg_name) name = '{0}.mod'.format(pkg_name) - with util.mock_modules(pkg_long_name, name) as mock: - with util.import_state(meta_path=[mock]): + with test_util.mock_modules(pkg_long_name, name) as mock: + with test_util.import_state(meta_path=[mock]): self.init.import_module(pkg_name) module = self.init.import_module(name, pkg_name) self.assertEqual(module.__name__, name) @@ -86,16 +86,15 @@ class ImportModuleTests: b_load_count += 1 code = {'a': load_a, 'a.b': load_b} modules = ['a.__init__', 'a.b'] - with util.mock_modules(*modules, module_code=code) as mock: - with util.import_state(meta_path=[mock]): + with test_util.mock_modules(*modules, module_code=code) as mock: + with test_util.import_state(meta_path=[mock]): self.init.import_module('a.b') self.assertEqual(b_load_count, 1) -class Frozen_ImportModuleTests(ImportModuleTests, unittest.TestCase): - init = frozen_init -class Source_ImportModuleTests(ImportModuleTests, unittest.TestCase): - init = source_init +(Frozen_ImportModuleTests, + Source_ImportModuleTests + ) = test_util.test_both(ImportModuleTests, init=init) class FindLoaderTests: @@ -107,7 +106,7 @@ class FindLoaderTests: def test_sys_modules(self): # If a module with __loader__ is in sys.modules, then return it. name = 'some_mod' - with util.uncache(name): + with test_util.uncache(name): module = types.ModuleType(name) loader = 'a loader!' module.__loader__ = loader @@ -120,7 +119,7 @@ class FindLoaderTests: def test_sys_modules_loader_is_None(self): # If sys.modules[name].__loader__ is None, raise ValueError. name = 'some_mod' - with util.uncache(name): + with test_util.uncache(name): module = types.ModuleType(name) module.__loader__ = None sys.modules[name] = module @@ -133,7 +132,7 @@ class FindLoaderTests: # Should raise ValueError # Issue #17099 name = 'some_mod' - with util.uncache(name): + with test_util.uncache(name): module = types.ModuleType(name) try: del module.__loader__ @@ -148,8 +147,8 @@ class FindLoaderTests: def test_success(self): # Return the loader found on sys.meta_path. name = 'some_mod' - with util.uncache(name): - with util.import_state(meta_path=[self.FakeMetaFinder]): + with test_util.uncache(name): + with test_util.import_state(meta_path=[self.FakeMetaFinder]): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) self.assertEqual((name, None), self.init.find_loader(name)) @@ -158,8 +157,8 @@ class FindLoaderTests: # Searching on a path should work. name = 'some_mod' path = 'path to some place' - with util.uncache(name): - with util.import_state(meta_path=[self.FakeMetaFinder]): + with test_util.uncache(name): + with test_util.import_state(meta_path=[self.FakeMetaFinder]): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) self.assertEqual((name, path), @@ -171,11 +170,10 @@ class FindLoaderTests: warnings.simplefilter('ignore', DeprecationWarning) self.assertIsNone(self.init.find_loader('nevergoingtofindthismodule')) -class Frozen_FindLoaderTests(FindLoaderTests, unittest.TestCase): - init = frozen_init -class Source_FindLoaderTests(FindLoaderTests, unittest.TestCase): - init = source_init +(Frozen_FindLoaderTests, + Source_FindLoaderTests + ) = test_util.test_both(FindLoaderTests, init=init) class ReloadTests: @@ -195,10 +193,10 @@ class ReloadTests: module = type(sys)('top_level') module.spam = 3 sys.modules['top_level'] = module - mock = util.mock_modules('top_level', - module_code={'top_level': code}) + mock = test_util.mock_modules('top_level', + module_code={'top_level': code}) with mock: - with util.import_state(meta_path=[mock]): + with test_util.import_state(meta_path=[mock]): module = self.init.import_module('top_level') reloaded = self.init.reload(module) actual = sys.modules['top_level'] @@ -230,7 +228,7 @@ class ReloadTests: def test_reload_location_changed(self): name = 'spam' with support.temp_cwd(None) as cwd: - with util.uncache('spam'): + with test_util.uncache('spam'): with support.DirsOnSysPath(cwd): # Start as a plain module. self.init.invalidate_caches() @@ -281,7 +279,7 @@ class ReloadTests: def test_reload_namespace_changed(self): name = 'spam' with support.temp_cwd(None) as cwd: - with util.uncache('spam'): + with test_util.uncache('spam'): with support.DirsOnSysPath(cwd): # Start as a namespace package. self.init.invalidate_caches() @@ -338,20 +336,16 @@ class ReloadTests: # See #19851. name = 'spam' subname = 'ham' - with util.temp_module(name, pkg=True) as pkg_dir: - fullname, _ = util.submodule(name, subname, pkg_dir) + with test_util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = test_util.submodule(name, subname, pkg_dir) ham = self.init.import_module(fullname) reloaded = self.init.reload(ham) self.assertIs(reloaded, ham) -class Frozen_ReloadTests(ReloadTests, unittest.TestCase): - init = frozen_init - util = frozen_util - -class Source_ReloadTests(ReloadTests, unittest.TestCase): - init = source_init - util = source_util +(Frozen_ReloadTests, + Source_ReloadTests + ) = test_util.test_both(ReloadTests, init=init, util=util) class InvalidateCacheTests: @@ -384,11 +378,10 @@ class InvalidateCacheTests: self.addCleanup(lambda: sys.path_importer_cache.__delitem__(key)) self.init.invalidate_caches() # Shouldn't trigger an exception. -class Frozen_InvalidateCacheTests(InvalidateCacheTests, unittest.TestCase): - init = frozen_init -class Source_InvalidateCacheTests(InvalidateCacheTests, unittest.TestCase): - init = source_init +(Frozen_InvalidateCacheTests, + Source_InvalidateCacheTests + ) = test_util.test_both(InvalidateCacheTests, init=init) class FrozenImportlibTests(unittest.TestCase): @@ -398,6 +391,7 @@ class FrozenImportlibTests(unittest.TestCase): # Can't do an isinstance() check since separate copies of importlib # may have been used for import, so just check the name is not for the # frozen loader. + source_init = init['Source'] self.assertNotEqual(source_init.__loader__.__class__.__name__, 'FrozenImporter') @@ -426,11 +420,10 @@ class StartupTests: elif self.machinery.FrozenImporter.find_module(name): self.assertIsNot(module.__spec__, None) -class Frozen_StartupTests(StartupTests, unittest.TestCase): - machinery = frozen_machinery -class Source_StartupTests(StartupTests, unittest.TestCase): - machinery = source_machinery +(Frozen_StartupTests, + Source_StartupTests + ) = test_util.test_both(StartupTests, machinery=machinery) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/test_lazy.py b/Lib/test/test_importlib/test_lazy.py new file mode 100644 index 0000000..2e191bb --- /dev/null +++ b/Lib/test/test_importlib/test_lazy.py @@ -0,0 +1,132 @@ +import importlib +from importlib import abc +from importlib import util +import unittest + +from . import util as test_util + + +class CollectInit: + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def exec_module(self, module): + return self + + +class LazyLoaderFactoryTests(unittest.TestCase): + + def test_init(self): + factory = util.LazyLoader.factory(CollectInit) + # E.g. what importlib.machinery.FileFinder instantiates loaders with + # plus keyword arguments. + lazy_loader = factory('module name', 'module path', kw='kw') + loader = lazy_loader.loader + self.assertEqual(('module name', 'module path'), loader.args) + self.assertEqual({'kw': 'kw'}, loader.kwargs) + + def test_validation(self): + # No exec_module(), no lazy loading. + with self.assertRaises(TypeError): + util.LazyLoader.factory(object) + + +class TestingImporter(abc.MetaPathFinder, abc.Loader): + + module_name = 'lazy_loader_test' + mutated_name = 'changed' + loaded = None + source_code = 'attr = 42; __name__ = {!r}'.format(mutated_name) + + def find_spec(self, name, path, target=None): + if name != self.module_name: + return None + return util.spec_from_loader(name, util.LazyLoader(self)) + + def exec_module(self, module): + exec(self.source_code, module.__dict__) + self.loaded = module + + +class LazyLoaderTests(unittest.TestCase): + + def test_init(self): + with self.assertRaises(TypeError): + util.LazyLoader(object) + + def new_module(self, source_code=None): + loader = TestingImporter() + if source_code is not None: + loader.source_code = source_code + spec = util.spec_from_loader(TestingImporter.module_name, + util.LazyLoader(loader)) + module = spec.loader.create_module(spec) + module.__spec__ = spec + module.__loader__ = spec.loader + spec.loader.exec_module(module) + # Module is now lazy. + self.assertIsNone(loader.loaded) + return module + + def test_e2e(self): + # End-to-end test to verify the load is in fact lazy. + importer = TestingImporter() + assert importer.loaded is None + with test_util.uncache(importer.module_name): + with test_util.import_state(meta_path=[importer]): + module = importlib.import_module(importer.module_name) + self.assertIsNone(importer.loaded) + # Trigger load. + self.assertEqual(module.__loader__, importer) + self.assertIsNotNone(importer.loaded) + self.assertEqual(module, importer.loaded) + + def test_attr_unchanged(self): + # An attribute only mutated as a side-effect of import should not be + # changed needlessly. + module = self.new_module() + self.assertEqual(TestingImporter.mutated_name, module.__name__) + + def test_new_attr(self): + # A new attribute should persist. + module = self.new_module() + module.new_attr = 42 + self.assertEqual(42, module.new_attr) + + def test_mutated_preexisting_attr(self): + # Changing an attribute that already existed on the module -- + # e.g. __name__ -- should persist. + module = self.new_module() + module.__name__ = 'bogus' + self.assertEqual('bogus', module.__name__) + + def test_mutated_attr(self): + # Changing an attribute that comes into existence after an import + # should persist. + module = self.new_module() + module.attr = 6 + self.assertEqual(6, module.attr) + + def test_delete_eventual_attr(self): + # Deleting an attribute should stay deleted. + module = self.new_module() + del module.attr + self.assertFalse(hasattr(module, 'attr')) + + def test_delete_preexisting_attr(self): + module = self.new_module() + del module.__name__ + self.assertFalse(hasattr(module, '__name__')) + + def test_module_substitution_error(self): + source_code = 'import sys; sys.modules[__name__] = 42' + module = self.new_module(source_code) + with test_util.uncache(TestingImporter.module_name): + with self.assertRaises(ValueError): + module.__name__ + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_importlib/test_locks.py b/Lib/test/test_importlib/test_locks.py index dc97ba1..df0af12 100644 --- a/Lib/test/test_importlib/test_locks.py +++ b/Lib/test/test_importlib/test_locks.py @@ -1,7 +1,6 @@ -from . import util -frozen_init, source_init = util.import_importlib('importlib') -frozen_bootstrap = frozen_init._bootstrap -source_bootstrap = source_init._bootstrap +from . import util as test_util + +init = test_util.import_importlib('importlib') import sys import time @@ -32,14 +31,20 @@ if threading is not None: test_timeout = None # _release_save() unsupported test_release_save_unacquired = None + # lock status in repr unsupported + test_repr = None + test_locked_repr = None - class Frozen_ModuleLockAsRLockTests(ModuleLockAsRLockTests, lock_tests.RLockTests): - LockType = frozen_bootstrap._ModuleLock - - class Source_ModuleLockAsRLockTests(ModuleLockAsRLockTests, lock_tests.RLockTests): - LockType = source_bootstrap._ModuleLock + LOCK_TYPES = {kind: splitinit._bootstrap._ModuleLock + for kind, splitinit in init.items()} + (Frozen_ModuleLockAsRLockTests, + Source_ModuleLockAsRLockTests + ) = test_util.test_both(ModuleLockAsRLockTests, lock_tests.RLockTests, + LockType=LOCK_TYPES) else: + LOCK_TYPES = {} + class Frozen_ModuleLockAsRLockTests(unittest.TestCase): pass @@ -47,78 +52,94 @@ else: pass -class DeadlockAvoidanceTests: - - def setUp(self): - try: - self.old_switchinterval = sys.getswitchinterval() - sys.setswitchinterval(0.000001) - except AttributeError: - self.old_switchinterval = None - - def tearDown(self): - if self.old_switchinterval is not None: - sys.setswitchinterval(self.old_switchinterval) - - def run_deadlock_avoidance_test(self, create_deadlock): - NLOCKS = 10 - locks = [self.LockType(str(i)) for i in range(NLOCKS)] - pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)] - if create_deadlock: - NTHREADS = NLOCKS - else: - NTHREADS = NLOCKS - 1 - barrier = threading.Barrier(NTHREADS) - results = [] - def _acquire(lock): - """Try to acquire the lock. Return True on success, False on deadlock.""" +if threading is not None: + class DeadlockAvoidanceTests: + + def setUp(self): try: - lock.acquire() - except self.DeadlockError: - return False + self.old_switchinterval = sys.getswitchinterval() + sys.setswitchinterval(0.000001) + except AttributeError: + self.old_switchinterval = None + + def tearDown(self): + if self.old_switchinterval is not None: + sys.setswitchinterval(self.old_switchinterval) + + def run_deadlock_avoidance_test(self, create_deadlock): + NLOCKS = 10 + locks = [self.LockType(str(i)) for i in range(NLOCKS)] + pairs = [(locks[i], locks[(i+1)%NLOCKS]) for i in range(NLOCKS)] + if create_deadlock: + NTHREADS = NLOCKS else: - return True - def f(): - a, b = pairs.pop() - ra = _acquire(a) - barrier.wait() - rb = _acquire(b) - results.append((ra, rb)) - if rb: - b.release() - if ra: - a.release() - lock_tests.Bunch(f, NTHREADS).wait_for_finished() - self.assertEqual(len(results), NTHREADS) - return results - - def test_deadlock(self): - results = self.run_deadlock_avoidance_test(True) - # At least one of the threads detected a potential deadlock on its - # second acquire() call. It may be several of them, because the - # deadlock avoidance mechanism is conservative. - nb_deadlocks = results.count((True, False)) - self.assertGreaterEqual(nb_deadlocks, 1) - self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks) - - def test_no_deadlock(self): - results = self.run_deadlock_avoidance_test(False) - self.assertEqual(results.count((True, False)), 0) - self.assertEqual(results.count((True, True)), len(results)) - -@unittest.skipUnless(threading, "threads needed for this test") -class Frozen_DeadlockAvoidanceTests(DeadlockAvoidanceTests, unittest.TestCase): - LockType = frozen_bootstrap._ModuleLock - DeadlockError = frozen_bootstrap._DeadlockError - -@unittest.skipUnless(threading, "threads needed for this test") -class Source_DeadlockAvoidanceTests(DeadlockAvoidanceTests, unittest.TestCase): - LockType = source_bootstrap._ModuleLock - DeadlockError = source_bootstrap._DeadlockError + NTHREADS = NLOCKS - 1 + barrier = threading.Barrier(NTHREADS) + results = [] + + def _acquire(lock): + """Try to acquire the lock. Return True on success, + False on deadlock.""" + try: + lock.acquire() + except self.DeadlockError: + return False + else: + return True + + def f(): + a, b = pairs.pop() + ra = _acquire(a) + barrier.wait() + rb = _acquire(b) + results.append((ra, rb)) + if rb: + b.release() + if ra: + a.release() + lock_tests.Bunch(f, NTHREADS).wait_for_finished() + self.assertEqual(len(results), NTHREADS) + return results + + def test_deadlock(self): + results = self.run_deadlock_avoidance_test(True) + # At least one of the threads detected a potential deadlock on its + # second acquire() call. It may be several of them, because the + # deadlock avoidance mechanism is conservative. + nb_deadlocks = results.count((True, False)) + self.assertGreaterEqual(nb_deadlocks, 1) + self.assertEqual(results.count((True, True)), len(results) - nb_deadlocks) + + def test_no_deadlock(self): + results = self.run_deadlock_avoidance_test(False) + self.assertEqual(results.count((True, False)), 0) + self.assertEqual(results.count((True, True)), len(results)) + + + DEADLOCK_ERRORS = {kind: splitinit._bootstrap._DeadlockError + for kind, splitinit in init.items()} + + (Frozen_DeadlockAvoidanceTests, + Source_DeadlockAvoidanceTests + ) = test_util.test_both(DeadlockAvoidanceTests, + LockType=LOCK_TYPES, + DeadlockError=DEADLOCK_ERRORS) +else: + DEADLOCK_ERRORS = {} + + class Frozen_DeadlockAvoidanceTests(unittest.TestCase): + pass + + class Source_DeadlockAvoidanceTests(unittest.TestCase): + pass class LifetimeTests: + @property + def bootstrap(self): + return self.init._bootstrap + def test_lock_lifetime(self): name = "xyzzy" self.assertNotIn(name, self.bootstrap._module_locks) @@ -135,11 +156,10 @@ class LifetimeTests: self.assertEqual(0, len(self.bootstrap._module_locks), self.bootstrap._module_locks) -class Frozen_LifetimeTests(LifetimeTests, unittest.TestCase): - bootstrap = frozen_bootstrap -class Source_LifetimeTests(LifetimeTests, unittest.TestCase): - bootstrap = source_bootstrap +(Frozen_LifetimeTests, + Source_LifetimeTests + ) = test_util.test_both(LifetimeTests, init=init) @support.reap_threads diff --git a/Lib/test/test_importlib/test_spec.py b/Lib/test/test_importlib/test_spec.py index 71541f6..8b333e8 100644 --- a/Lib/test/test_importlib/test_spec.py +++ b/Lib/test/test_importlib/test_spec.py @@ -1,10 +1,8 @@ -from . import util +from . import util as test_util -frozen_init, source_init = util.import_importlib('importlib') -frozen_bootstrap = frozen_init._bootstrap -source_bootstrap = source_init._bootstrap -frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') -frozen_util, source_util = util.import_importlib('importlib.util') +init = test_util.import_importlib('importlib') +machinery = test_util.import_importlib('importlib.machinery') +util = test_util.import_importlib('importlib.util') import os.path from test.support import CleanImport @@ -36,6 +34,9 @@ class TestLoader: def _is_package(self, name): return self.package + def create_module(self, spec): + return None + class NewLoader(TestLoader): @@ -52,6 +53,8 @@ class LegacyLoader(TestLoader): with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) + frozen_util = util['Frozen'] + @frozen_util.module_for_loader def load_module(self, module): module.ham = self.HAM @@ -221,18 +224,17 @@ class ModuleSpecTests: self.assertEqual(self.loc_spec.cached, 'spam.pyc') -class Frozen_ModuleSpecTests(ModuleSpecTests, unittest.TestCase): - util = frozen_util - machinery = frozen_machinery - - -class Source_ModuleSpecTests(ModuleSpecTests, unittest.TestCase): - util = source_util - machinery = source_machinery +(Frozen_ModuleSpecTests, + Source_ModuleSpecTests + ) = test_util.test_both(ModuleSpecTests, util=util, machinery=machinery) class ModuleSpecMethodsTests: + @property + def bootstrap(self): + return self.init._bootstrap + def setUp(self): self.name = 'spam' self.path = 'spam.py' @@ -243,152 +245,14 @@ class ModuleSpecMethodsTests: origin=self.path) self.loc_spec._set_fileattr = True - # init_module_attrs - - def test_init_module_attrs(self): - module = type(sys)(self.name) - spec = self.machinery.ModuleSpec(self.name, self.loader) - self.bootstrap._SpecMethods(spec).init_module_attrs(module) - - self.assertEqual(module.__name__, spec.name) - self.assertIs(module.__loader__, spec.loader) - self.assertEqual(module.__package__, spec.parent) - self.assertIs(module.__spec__, spec) - self.assertFalse(hasattr(module, '__path__')) - self.assertFalse(hasattr(module, '__file__')) - self.assertFalse(hasattr(module, '__cached__')) - - def test_init_module_attrs_package(self): - module = type(sys)(self.name) - spec = self.machinery.ModuleSpec(self.name, self.loader) - spec.submodule_search_locations = ['spam', 'ham'] - self.bootstrap._SpecMethods(spec).init_module_attrs(module) - - self.assertEqual(module.__name__, spec.name) - self.assertIs(module.__loader__, spec.loader) - self.assertEqual(module.__package__, spec.parent) - self.assertIs(module.__spec__, spec) - self.assertIs(module.__path__, spec.submodule_search_locations) - self.assertFalse(hasattr(module, '__file__')) - self.assertFalse(hasattr(module, '__cached__')) - - def test_init_module_attrs_location(self): - module = type(sys)(self.name) - spec = self.loc_spec - self.bootstrap._SpecMethods(spec).init_module_attrs(module) - - self.assertEqual(module.__name__, spec.name) - self.assertIs(module.__loader__, spec.loader) - self.assertEqual(module.__package__, spec.parent) - self.assertIs(module.__spec__, spec) - self.assertFalse(hasattr(module, '__path__')) - self.assertEqual(module.__file__, spec.origin) - self.assertEqual(module.__cached__, - self.util.cache_from_source(spec.origin)) - - def test_init_module_attrs_different_name(self): - module = type(sys)('eggs') - spec = self.machinery.ModuleSpec(self.name, self.loader) - self.bootstrap._SpecMethods(spec).init_module_attrs(module) - - self.assertEqual(module.__name__, spec.name) - - def test_init_module_attrs_different_spec(self): - module = type(sys)(self.name) - module.__spec__ = self.machinery.ModuleSpec('eggs', object()) - spec = self.machinery.ModuleSpec(self.name, self.loader) - self.bootstrap._SpecMethods(spec).init_module_attrs(module) - - self.assertEqual(module.__name__, spec.name) - self.assertIs(module.__loader__, spec.loader) - self.assertEqual(module.__package__, spec.parent) - self.assertIs(module.__spec__, spec) - - def test_init_module_attrs_already_set(self): - module = type(sys)('ham.eggs') - module.__loader__ = object() - module.__package__ = 'ham' - module.__path__ = ['eggs'] - module.__file__ = 'ham/eggs/__init__.py' - module.__cached__ = self.util.cache_from_source(module.__file__) - original = vars(module).copy() - spec = self.loc_spec - spec.submodule_search_locations = [''] - self.bootstrap._SpecMethods(spec).init_module_attrs(module) - - self.assertIs(module.__loader__, original['__loader__']) - self.assertEqual(module.__package__, original['__package__']) - self.assertIs(module.__path__, original['__path__']) - self.assertEqual(module.__file__, original['__file__']) - self.assertEqual(module.__cached__, original['__cached__']) - - def test_init_module_attrs_immutable(self): - module = object() - spec = self.loc_spec - spec.submodule_search_locations = [''] - self.bootstrap._SpecMethods(spec).init_module_attrs(module) - - self.assertFalse(hasattr(module, '__name__')) - self.assertFalse(hasattr(module, '__loader__')) - self.assertFalse(hasattr(module, '__package__')) - self.assertFalse(hasattr(module, '__spec__')) - self.assertFalse(hasattr(module, '__path__')) - self.assertFalse(hasattr(module, '__file__')) - self.assertFalse(hasattr(module, '__cached__')) - - # create() - - def test_create(self): - created = self.bootstrap._SpecMethods(self.spec).create() - - self.assertEqual(created.__name__, self.spec.name) - self.assertIs(created.__loader__, self.spec.loader) - self.assertEqual(created.__package__, self.spec.parent) - self.assertIs(created.__spec__, self.spec) - self.assertFalse(hasattr(created, '__path__')) - self.assertFalse(hasattr(created, '__file__')) - self.assertFalse(hasattr(created, '__cached__')) - - def test_create_from_loader(self): - module = type(sys.implementation)() - class CreatingLoader(TestLoader): - def create_module(self, spec): - return module - self.spec.loader = CreatingLoader() - created = self.bootstrap._SpecMethods(self.spec).create() - - self.assertIs(created, module) - self.assertEqual(created.__name__, self.spec.name) - self.assertIs(created.__loader__, self.spec.loader) - self.assertEqual(created.__package__, self.spec.parent) - self.assertIs(created.__spec__, self.spec) - self.assertFalse(hasattr(created, '__path__')) - self.assertFalse(hasattr(created, '__file__')) - self.assertFalse(hasattr(created, '__cached__')) - - def test_create_from_loader_not_handled(self): - class CreatingLoader(TestLoader): - def create_module(self, spec): - return None - self.spec.loader = CreatingLoader() - created = self.bootstrap._SpecMethods(self.spec).create() - - self.assertEqual(created.__name__, self.spec.name) - self.assertIs(created.__loader__, self.spec.loader) - self.assertEqual(created.__package__, self.spec.parent) - self.assertIs(created.__spec__, self.spec) - self.assertFalse(hasattr(created, '__path__')) - self.assertFalse(hasattr(created, '__file__')) - self.assertFalse(hasattr(created, '__cached__')) - # exec() def test_exec(self): self.spec.loader = NewLoader() - module = self.bootstrap._SpecMethods(self.spec).create() + module = self.util.module_from_spec(self.spec) sys.modules[self.name] = module self.assertFalse(hasattr(module, 'eggs')) - self.bootstrap._SpecMethods(self.spec).exec(module) + self.bootstrap._exec(self.spec, module) self.assertEqual(module.eggs, 1) @@ -397,7 +261,7 @@ class ModuleSpecMethodsTests: def test_load(self): self.spec.loader = NewLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) installed = sys.modules[self.spec.name] self.assertEqual(loaded.eggs, 1) @@ -410,7 +274,7 @@ class ModuleSpecMethodsTests: sys.modules[module.__name__] = replacement self.spec.loader = ReplacingLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) installed = sys.modules[self.spec.name] self.assertIs(loaded, replacement) @@ -423,7 +287,7 @@ class ModuleSpecMethodsTests: self.spec.loader = FailedLoader() with CleanImport(self.spec.name): with self.assertRaises(RuntimeError): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) self.assertNotIn(self.spec.name, sys.modules) def test_load_failed_removed(self): @@ -434,20 +298,20 @@ class ModuleSpecMethodsTests: self.spec.loader = FailedLoader() with CleanImport(self.spec.name): with self.assertRaises(RuntimeError): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) self.assertNotIn(self.spec.name, sys.modules) def test_load_legacy(self): self.spec.loader = LegacyLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) self.assertEqual(loaded.ham, -1) def test_load_legacy_attributes(self): self.spec.loader = LegacyLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) self.assertIs(loaded.__loader__, self.spec.loader) self.assertEqual(loaded.__package__, self.spec.parent) @@ -461,7 +325,7 @@ class ModuleSpecMethodsTests: return module self.spec.loader = ImmutableLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) self.assertIs(sys.modules[self.spec.name], module) @@ -470,8 +334,8 @@ class ModuleSpecMethodsTests: def test_reload(self): self.spec.loader = NewLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() - reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + loaded = self.bootstrap._load(self.spec) + reloaded = self.bootstrap._exec(self.spec, loaded) installed = sys.modules[self.spec.name] self.assertEqual(loaded.eggs, 1) @@ -481,9 +345,9 @@ class ModuleSpecMethodsTests: def test_reload_modified(self): self.spec.loader = NewLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) loaded.eggs = 2 - reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + reloaded = self.bootstrap._exec(self.spec, loaded) self.assertEqual(loaded.eggs, 1) self.assertIs(reloaded, loaded) @@ -491,9 +355,9 @@ class ModuleSpecMethodsTests: def test_reload_extra_attributes(self): self.spec.loader = NewLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) loaded.available = False - reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + reloaded = self.bootstrap._exec(self.spec, loaded) self.assertFalse(loaded.available) self.assertIs(reloaded, loaded) @@ -501,12 +365,12 @@ class ModuleSpecMethodsTests: def test_reload_init_module_attrs(self): self.spec.loader = NewLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() + loaded = self.bootstrap._load(self.spec) loaded.__name__ = 'ham' del loaded.__loader__ del loaded.__package__ del loaded.__spec__ - self.bootstrap._SpecMethods(self.spec).exec(loaded) + self.bootstrap._exec(self.spec, loaded) self.assertEqual(loaded.__name__, self.spec.name) self.assertIs(loaded.__loader__, self.spec.loader) @@ -519,8 +383,8 @@ class ModuleSpecMethodsTests: def test_reload_legacy(self): self.spec.loader = LegacyLoader() with CleanImport(self.spec.name): - loaded = self.bootstrap._SpecMethods(self.spec).load() - reloaded = self.bootstrap._SpecMethods(self.spec).exec(loaded) + loaded = self.bootstrap._load(self.spec) + reloaded = self.bootstrap._exec(self.spec, loaded) installed = sys.modules[self.spec.name] self.assertEqual(loaded.ham, -1) @@ -528,20 +392,18 @@ class ModuleSpecMethodsTests: self.assertIs(installed, loaded) -class Frozen_ModuleSpecMethodsTests(ModuleSpecMethodsTests, unittest.TestCase): - bootstrap = frozen_bootstrap - machinery = frozen_machinery - util = frozen_util - - -class Source_ModuleSpecMethodsTests(ModuleSpecMethodsTests, unittest.TestCase): - bootstrap = source_bootstrap - machinery = source_machinery - util = source_util +(Frozen_ModuleSpecMethodsTests, + Source_ModuleSpecMethodsTests + ) = test_util.test_both(ModuleSpecMethodsTests, init=init, util=util, + machinery=machinery) class ModuleReprTests: + @property + def bootstrap(self): + return self.init._bootstrap + def setUp(self): self.module = type(os)('spam') self.spec = self.machinery.ModuleSpec('spam', TestLoader()) @@ -625,16 +487,10 @@ class ModuleReprTests: self.assertEqual(modrepr, '<module {!r}>'.format('spam')) -class Frozen_ModuleReprTests(ModuleReprTests, unittest.TestCase): - bootstrap = frozen_bootstrap - machinery = frozen_machinery - util = frozen_util - - -class Source_ModuleReprTests(ModuleReprTests, unittest.TestCase): - bootstrap = source_bootstrap - machinery = source_machinery - util = source_util +(Frozen_ModuleReprTests, + Source_ModuleReprTests + ) = test_util.test_both(ModuleReprTests, init=init, util=util, + machinery=machinery) class FactoryTests: @@ -787,13 +643,14 @@ class FactoryTests: # spec_from_file_location() def test_spec_from_file_location_default(self): - if self.machinery is source_machinery: - raise unittest.SkipTest('not sure why this is breaking...') spec = self.util.spec_from_file_location(self.name, self.path) self.assertEqual(spec.name, self.name) + # Need to use a circuitous route to get at importlib.machinery to make + # sure the same class object is used in the isinstance() check as + # would have been used to create the loader. self.assertIsInstance(spec.loader, - self.machinery.SourceFileLoader) + self.util.abc.machinery.SourceFileLoader) self.assertEqual(spec.loader.name, self.name) self.assertEqual(spec.loader.path, self.path) self.assertEqual(spec.origin, self.path) @@ -947,11 +804,10 @@ class FactoryTests: self.assertTrue(spec.has_location) -class Frozen_FactoryTests(FactoryTests, unittest.TestCase): - util = frozen_util - machinery = frozen_machinery +(Frozen_FactoryTests, + Source_FactoryTests + ) = test_util.test_both(FactoryTests, util=util, machinery=machinery) -class Source_FactoryTests(FactoryTests, unittest.TestCase): - util = source_util - machinery = source_machinery +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_importlib/test_util.py b/Lib/test/test_importlib/test_util.py index b2823c6..69466b2 100644 --- a/Lib/test/test_importlib/test_util.py +++ b/Lib/test/test_importlib/test_util.py @@ -1,10 +1,11 @@ -from importlib import util -from . import util as test_util -frozen_init, source_init = test_util.import_importlib('importlib') -frozen_machinery, source_machinery = test_util.import_importlib('importlib.machinery') -frozen_util, source_util = test_util.import_importlib('importlib.util') +from . import util +abc = util.import_importlib('importlib.abc') +init = util.import_importlib('importlib') +machinery = util.import_importlib('importlib.machinery') +importlib_util = util.import_importlib('importlib.util') import os +import string import sys from test import support import types @@ -32,8 +33,94 @@ class DecodeSourceBytesTests: self.assertEqual(self.util.decode_source(source_bytes), '\n'.join([self.source, self.source])) -Frozen_DecodeSourceBytesTests, Source_DecodeSourceBytesTests = test_util.test_both( - DecodeSourceBytesTests, util=[frozen_util, source_util]) + +(Frozen_DecodeSourceBytesTests, + Source_DecodeSourceBytesTests + ) = util.test_both(DecodeSourceBytesTests, util=importlib_util) + + +class ModuleFromSpecTests: + + def test_no_create_module(self): + class Loader: + def exec_module(self, module): + pass + spec = self.machinery.ModuleSpec('test', Loader()) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + module = self.util.module_from_spec(spec) + self.assertEqual(1, len(w)) + self.assertTrue(issubclass(w[0].category, DeprecationWarning)) + self.assertIn('create_module', str(w[0].message)) + self.assertIsInstance(module, types.ModuleType) + self.assertEqual(module.__name__, spec.name) + + def test_create_module_returns_None(self): + class Loader(self.abc.Loader): + def create_module(self, spec): + return None + spec = self.machinery.ModuleSpec('test', Loader()) + module = self.util.module_from_spec(spec) + self.assertIsInstance(module, types.ModuleType) + self.assertEqual(module.__name__, spec.name) + + def test_create_module(self): + name = 'already set' + class CustomModule(types.ModuleType): + pass + class Loader(self.abc.Loader): + def create_module(self, spec): + module = CustomModule(spec.name) + module.__name__ = name + return module + spec = self.machinery.ModuleSpec('test', Loader()) + module = self.util.module_from_spec(spec) + self.assertIsInstance(module, CustomModule) + self.assertEqual(module.__name__, name) + + def test___name__(self): + spec = self.machinery.ModuleSpec('test', object()) + module = self.util.module_from_spec(spec) + self.assertEqual(module.__name__, spec.name) + + def test___spec__(self): + spec = self.machinery.ModuleSpec('test', object()) + module = self.util.module_from_spec(spec) + self.assertEqual(module.__spec__, spec) + + def test___loader__(self): + loader = object() + spec = self.machinery.ModuleSpec('test', loader) + module = self.util.module_from_spec(spec) + self.assertIs(module.__loader__, loader) + + def test___package__(self): + spec = self.machinery.ModuleSpec('test.pkg', object()) + module = self.util.module_from_spec(spec) + self.assertEqual(module.__package__, spec.parent) + + def test___path__(self): + spec = self.machinery.ModuleSpec('test', object(), is_package=True) + module = self.util.module_from_spec(spec) + self.assertEqual(module.__path__, spec.submodule_search_locations) + + def test___file__(self): + spec = self.machinery.ModuleSpec('test', object(), origin='some/path') + spec.has_location = True + module = self.util.module_from_spec(spec) + self.assertEqual(module.__file__, spec.origin) + + def test___cached__(self): + spec = self.machinery.ModuleSpec('test', object()) + spec.cached = 'some/path' + spec.has_location = True + module = self.util.module_from_spec(spec) + self.assertEqual(module.__cached__, spec.cached) + +(Frozen_ModuleFromSpecTests, + Source_ModuleFromSpecTests +) = util.test_both(ModuleFromSpecTests, abc=abc, machinery=machinery, + util=importlib_util) class ModuleForLoaderTests: @@ -70,7 +157,7 @@ class ModuleForLoaderTests: # Test that when no module exists in sys.modules a new module is # created. module_name = 'a.b.c' - with test_util.uncache(module_name): + with util.uncache(module_name): module = self.return_module(module_name) self.assertIn(module_name, sys.modules) self.assertIsInstance(module, types.ModuleType) @@ -88,7 +175,7 @@ class ModuleForLoaderTests: module = types.ModuleType('a.b.c') module.__loader__ = 42 module.__package__ = 42 - with test_util.uncache(name): + with util.uncache(name): sys.modules[name] = module loader = FakeLoader() returned_module = loader.load_module(name) @@ -100,7 +187,7 @@ class ModuleForLoaderTests: # Test that a module is removed from sys.modules if added but an # exception is raised. name = 'a.b.c' - with test_util.uncache(name): + with util.uncache(name): self.raise_exception(name) self.assertNotIn(name, sys.modules) @@ -108,7 +195,7 @@ class ModuleForLoaderTests: # Test that a failure on reload leaves the module in-place. name = 'a.b.c' module = types.ModuleType(name) - with test_util.uncache(name): + with util.uncache(name): sys.modules[name] = module self.raise_exception(name) self.assertIs(module, sys.modules[name]) @@ -127,7 +214,7 @@ class ModuleForLoaderTests: name = 'mod' module = FalseModule(name) - with test_util.uncache(name): + with util.uncache(name): self.assertFalse(module) sys.modules[name] = module given = self.return_module(name) @@ -146,7 +233,7 @@ class ModuleForLoaderTests: return module name = 'pkg.mod' - with test_util.uncache(name): + with util.uncache(name): loader = FakeLoader(False) module = loader.load_module(name) self.assertEqual(module.__name__, name) @@ -154,15 +241,17 @@ class ModuleForLoaderTests: self.assertEqual(module.__package__, 'pkg') name = 'pkg.sub' - with test_util.uncache(name): + with util.uncache(name): loader = FakeLoader(True) module = loader.load_module(name) self.assertEqual(module.__name__, name) self.assertIs(module.__loader__, loader) self.assertEqual(module.__package__, name) -Frozen_ModuleForLoaderTests, Source_ModuleForLoaderTests = test_util.test_both( - ModuleForLoaderTests, util=[frozen_util, source_util]) + +(Frozen_ModuleForLoaderTests, + Source_ModuleForLoaderTests + ) = util.test_both(ModuleForLoaderTests, util=importlib_util) class SetPackageTests: @@ -222,18 +311,25 @@ class SetPackageTests: self.assertEqual(wrapped.__name__, fxn.__name__) self.assertEqual(wrapped.__qualname__, fxn.__qualname__) -Frozen_SetPackageTests, Source_SetPackageTests = test_util.test_both( - SetPackageTests, util=[frozen_util, source_util]) + +(Frozen_SetPackageTests, + Source_SetPackageTests + ) = util.test_both(SetPackageTests, util=importlib_util) class SetLoaderTests: """Tests importlib.util.set_loader().""" - class DummyLoader: - @util.set_loader - def load_module(self, module): - return self.module + @property + def DummyLoader(self): + # Set DummyLoader on the class lazily. + class DummyLoader: + @self.util.set_loader + def load_module(self, module): + return self.module + self.__class__.DummyLoader = DummyLoader + return DummyLoader def test_no_attribute(self): loader = self.DummyLoader() @@ -262,17 +358,10 @@ class SetLoaderTests: warnings.simplefilter('ignore', DeprecationWarning) self.assertEqual(42, loader.load_module('blah').__loader__) -class Frozen_SetLoaderTests(SetLoaderTests, unittest.TestCase): - class DummyLoader: - @frozen_util.set_loader - def load_module(self, module): - return self.module -class Source_SetLoaderTests(SetLoaderTests, unittest.TestCase): - class DummyLoader: - @source_util.set_loader - def load_module(self, module): - return self.module +(Frozen_SetLoaderTests, + Source_SetLoaderTests + ) = util.test_both(SetLoaderTests, util=importlib_util) class ResolveNameTests: @@ -307,9 +396,10 @@ class ResolveNameTests: with self.assertRaises(ValueError): self.util.resolve_name('..bacon', 'spam') -Frozen_ResolveNameTests, Source_ResolveNameTests = test_util.test_both( - ResolveNameTests, - util=[frozen_util, source_util]) + +(Frozen_ResolveNameTests, + Source_ResolveNameTests + ) = util.test_both(ResolveNameTests, util=importlib_util) class FindSpecTests: @@ -320,7 +410,7 @@ class FindSpecTests: def test_sys_modules(self): name = 'some_mod' - with test_util.uncache(name): + with util.uncache(name): module = types.ModuleType(name) loader = 'a loader!' spec = self.machinery.ModuleSpec(name, loader) @@ -332,7 +422,7 @@ class FindSpecTests: def test_sys_modules_without___loader__(self): name = 'some_mod' - with test_util.uncache(name): + with util.uncache(name): module = types.ModuleType(name) del module.__loader__ loader = 'a loader!' @@ -344,7 +434,7 @@ class FindSpecTests: def test_sys_modules_spec_is_None(self): name = 'some_mod' - with test_util.uncache(name): + with util.uncache(name): module = types.ModuleType(name) module.__spec__ = None sys.modules[name] = module @@ -353,7 +443,7 @@ class FindSpecTests: def test_sys_modules_loader_is_None(self): name = 'some_mod' - with test_util.uncache(name): + with util.uncache(name): module = types.ModuleType(name) spec = self.machinery.ModuleSpec(name, None) module.__spec__ = spec @@ -363,7 +453,7 @@ class FindSpecTests: def test_sys_modules_spec_is_not_set(self): name = 'some_mod' - with test_util.uncache(name): + with util.uncache(name): module = types.ModuleType(name) try: del module.__spec__ @@ -375,20 +465,11 @@ class FindSpecTests: def test_success(self): name = 'some_mod' - with test_util.uncache(name): - with test_util.import_state(meta_path=[self.FakeMetaFinder]): + with util.uncache(name): + with util.import_state(meta_path=[self.FakeMetaFinder]): self.assertEqual((name, None, None), self.util.find_spec(name)) -# def test_success_path(self): -# # Searching on a path should work. -# name = 'some_mod' -# path = 'path to some place' -# with test_util.uncache(name): -# with test_util.import_state(meta_path=[self.FakeMetaFinder]): -# self.assertEqual((name, path, None), -# self.util.find_spec(name, path)) - def test_nothing(self): # None is returned upon failure to find a loader. self.assertIsNone(self.util.find_spec('nevergoingtofindthismodule')) @@ -396,8 +477,8 @@ class FindSpecTests: def test_find_submodule(self): name = 'spam' subname = 'ham' - with test_util.temp_module(name, pkg=True) as pkg_dir: - fullname, _ = test_util.submodule(name, subname, pkg_dir) + with util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = util.submodule(name, subname, pkg_dir) spec = self.util.find_spec(fullname) self.assertIsNot(spec, None) self.assertIn(name, sorted(sys.modules)) @@ -409,9 +490,9 @@ class FindSpecTests: def test_find_submodule_parent_already_imported(self): name = 'spam' subname = 'ham' - with test_util.temp_module(name, pkg=True) as pkg_dir: + with util.temp_module(name, pkg=True) as pkg_dir: self.init.import_module(name) - fullname, _ = test_util.submodule(name, subname, pkg_dir) + fullname, _ = util.submodule(name, subname, pkg_dir) spec = self.util.find_spec(fullname) self.assertIsNot(spec, None) self.assertIn(name, sorted(sys.modules)) @@ -423,8 +504,8 @@ class FindSpecTests: def test_find_relative_module(self): name = 'spam' subname = 'ham' - with test_util.temp_module(name, pkg=True) as pkg_dir: - fullname, _ = test_util.submodule(name, subname, pkg_dir) + with util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = util.submodule(name, subname, pkg_dir) relname = '.' + subname spec = self.util.find_spec(relname, name) self.assertIsNot(spec, None) @@ -437,8 +518,8 @@ class FindSpecTests: def test_find_relative_module_missing_package(self): name = 'spam' subname = 'ham' - with test_util.temp_module(name, pkg=True) as pkg_dir: - fullname, _ = test_util.submodule(name, subname, pkg_dir) + with util.temp_module(name, pkg=True) as pkg_dir: + fullname, _ = util.submodule(name, subname, pkg_dir) relname = '.' + subname with self.assertRaises(ValueError): self.util.find_spec(relname) @@ -446,15 +527,10 @@ class FindSpecTests: self.assertNotIn(fullname, sorted(sys.modules)) -class Frozen_FindSpecTests(FindSpecTests, unittest.TestCase): - init = frozen_init - machinery = frozen_machinery - util = frozen_util - -class Source_FindSpecTests(FindSpecTests, unittest.TestCase): - init = source_init - machinery = source_machinery - util = source_util +(Frozen_FindSpecTests, + Source_FindSpecTests + ) = util.test_both(FindSpecTests, init=init, util=importlib_util, + machinery=machinery) class MagicNumberTests: @@ -467,8 +543,10 @@ class MagicNumberTests: # The magic number uses \r\n to come out wrong when splitting on lines. self.assertTrue(self.util.MAGIC_NUMBER.endswith(b'\r\n')) -Frozen_MagicNumberTests, Source_MagicNumberTests = test_util.test_both( - MagicNumberTests, util=[frozen_util, source_util]) + +(Frozen_MagicNumberTests, + Source_MagicNumberTests + ) = util.test_both(MagicNumberTests, util=importlib_util) class PEP3147Tests: @@ -485,7 +563,8 @@ class PEP3147Tests: path = os.path.join('foo', 'bar', 'baz', 'qux.py') expect = os.path.join('foo', 'bar', 'baz', '__pycache__', 'qux.{}.pyc'.format(self.tag)) - self.assertEqual(self.util.cache_from_source(path, True), expect) + self.assertEqual(self.util.cache_from_source(path, optimization=''), + expect) def test_cache_from_source_no_cache_tag(self): # No cache tag means NotImplementedError. @@ -498,43 +577,103 @@ class PEP3147Tests: path = os.path.join('foo.bar', 'file') expect = os.path.join('foo.bar', '__pycache__', 'file{}.pyc'.format(self.tag)) - self.assertEqual(self.util.cache_from_source(path, True), expect) + self.assertEqual(self.util.cache_from_source(path, optimization=''), + expect) - def test_cache_from_source_optimized(self): - # Given the path to a .py file, return the path to its PEP 3147 - # defined .pyo file (i.e. under __pycache__). + def test_cache_from_source_debug_override(self): + # Given the path to a .py file, return the path to its PEP 3147/PEP 488 + # defined .pyc file (i.e. under __pycache__). path = os.path.join('foo', 'bar', 'baz', 'qux.py') - expect = os.path.join('foo', 'bar', 'baz', '__pycache__', - 'qux.{}.pyo'.format(self.tag)) - self.assertEqual(self.util.cache_from_source(path, False), expect) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + self.assertEqual(self.util.cache_from_source(path, False), + self.util.cache_from_source(path, optimization=1)) + self.assertEqual(self.util.cache_from_source(path, True), + self.util.cache_from_source(path, optimization='')) + with warnings.catch_warnings(): + warnings.simplefilter('error') + with self.assertRaises(DeprecationWarning): + self.util.cache_from_source(path, False) + with self.assertRaises(DeprecationWarning): + self.util.cache_from_source(path, True) def test_cache_from_source_cwd(self): path = 'foo.py' expect = os.path.join('__pycache__', 'foo.{}.pyc'.format(self.tag)) - self.assertEqual(self.util.cache_from_source(path, True), expect) + self.assertEqual(self.util.cache_from_source(path, optimization=''), + expect) def test_cache_from_source_override(self): # When debug_override is not None, it can be any true-ish or false-ish # value. path = os.path.join('foo', 'bar', 'baz.py') - partial_expect = os.path.join('foo', 'bar', '__pycache__', - 'baz.{}.py'.format(self.tag)) - self.assertEqual(self.util.cache_from_source(path, []), partial_expect + 'o') - self.assertEqual(self.util.cache_from_source(path, [17]), - partial_expect + 'c') # However if the bool-ishness can't be determined, the exception # propagates. class Bearish: def __bool__(self): raise RuntimeError - with self.assertRaises(RuntimeError): - self.util.cache_from_source('/foo/bar/baz.py', Bearish()) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + self.assertEqual(self.util.cache_from_source(path, []), + self.util.cache_from_source(path, optimization=1)) + self.assertEqual(self.util.cache_from_source(path, [17]), + self.util.cache_from_source(path, optimization='')) + with self.assertRaises(RuntimeError): + self.util.cache_from_source('/foo/bar/baz.py', Bearish()) + + + def test_cache_from_source_optimization_empty_string(self): + # Setting 'optimization' to '' leads to no optimization tag (PEP 488). + path = 'foo.py' + expect = os.path.join('__pycache__', 'foo.{}.pyc'.format(self.tag)) + self.assertEqual(self.util.cache_from_source(path, optimization=''), + expect) + + def test_cache_from_source_optimization_None(self): + # Setting 'optimization' to None uses the interpreter's optimization. + # (PEP 488) + path = 'foo.py' + optimization_level = sys.flags.optimize + almost_expect = os.path.join('__pycache__', 'foo.{}'.format(self.tag)) + if optimization_level == 0: + expect = almost_expect + '.pyc' + elif optimization_level <= 2: + expect = almost_expect + '.opt-{}.pyc'.format(optimization_level) + else: + msg = '{!r} is a non-standard optimization level'.format(optimization_level) + self.skipTest(msg) + self.assertEqual(self.util.cache_from_source(path, optimization=None), + expect) + + def test_cache_from_source_optimization_set(self): + # The 'optimization' parameter accepts anything that has a string repr + # that passes str.alnum(). + path = 'foo.py' + valid_characters = string.ascii_letters + string.digits + almost_expect = os.path.join('__pycache__', 'foo.{}'.format(self.tag)) + got = self.util.cache_from_source(path, optimization=valid_characters) + # Test all valid characters are accepted. + self.assertEqual(got, + almost_expect + '.opt-{}.pyc'.format(valid_characters)) + # str() should be called on argument. + self.assertEqual(self.util.cache_from_source(path, optimization=42), + almost_expect + '.opt-42.pyc') + # Invalid characters raise ValueError. + with self.assertRaises(ValueError): + self.util.cache_from_source(path, optimization='path/is/bad') + + def test_cache_from_source_debug_override_optimization_both_set(self): + # Can only set one of the optimization-related parameters. + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + with self.assertRaises(TypeError): + self.util.cache_from_source('foo.py', False, optimization='') @unittest.skipUnless(os.sep == '\\' and os.altsep == '/', 'test meaningful only where os.altsep is defined') def test_sep_altsep_and_sep_cache_from_source(self): # Windows path and PEP 3147 where sep is right of altsep. self.assertEqual( - self.util.cache_from_source('\\foo\\bar\\baz/qux.py', True), + self.util.cache_from_source('\\foo\\bar\\baz/qux.py', optimization=''), '\\foo\\bar\\baz\\__pycache__\\qux.{}.pyc'.format(self.tag)) @unittest.skipUnless(sys.implementation.cache_tag is not None, @@ -572,7 +711,12 @@ class PEP3147Tests: ValueError, self.util.source_from_cache, '__pycache__/foo.pyc') def test_source_from_cache_too_many_dots(self): - # Too many dots in final path component -> ValueError + with self.assertRaises(ValueError): + self.util.source_from_cache( + '__pycache__/foo.cpython-32.opt-1.foo.pyc') + + def test_source_from_cache_not_opt(self): + # Non-`opt-` path component -> ValueError self.assertRaises( ValueError, self.util.source_from_cache, '__pycache__/foo.cpython-32.foo.pyc') @@ -583,9 +727,21 @@ class PEP3147Tests: ValueError, self.util.source_from_cache, '/foo/bar/foo.cpython-32.foo.pyc') -Frozen_PEP3147Tests, Source_PEP3147Tests = test_util.test_both( - PEP3147Tests, - util=[frozen_util, source_util]) + def test_source_from_cache_optimized_bytecode(self): + # Optimized bytecode is not an issue. + path = os.path.join('__pycache__', 'foo.{}.opt-1.pyc'.format(self.tag)) + self.assertEqual(self.util.source_from_cache(path), 'foo.py') + + def test_source_from_cache_missing_optimization(self): + # An empty optimization level is a no-no. + path = os.path.join('__pycache__', 'foo.{}.opt-.pyc'.format(self.tag)) + with self.assertRaises(ValueError): + self.util.source_from_cache(path) + + +(Frozen_PEP3147Tests, + Source_PEP3147Tests + ) = util.test_both(PEP3147Tests, util=importlib_util) if __name__ == '__main__': diff --git a/Lib/test/test_importlib/test_windows.py b/Lib/test/test_importlib/test_windows.py index 96b4adc..c893bcf 100644 --- a/Lib/test/test_importlib/test_windows.py +++ b/Lib/test/test_importlib/test_windows.py @@ -1,14 +1,64 @@ -from . import util -frozen_machinery, source_machinery = util.import_importlib('importlib.machinery') +from . import util as test_util +machinery = test_util.import_importlib('importlib.machinery') +import os +import re import sys import unittest +from test import support +from distutils.util import get_platform +from contextlib import contextmanager +from .util import temp_module + +support.import_module('winreg', required_on=['win']) +from winreg import ( + CreateKey, HKEY_CURRENT_USER, + SetValue, REG_SZ, KEY_ALL_ACCESS, + EnumKey, CloseKey, DeleteKey, OpenKey +) + +def delete_registry_tree(root, subkey): + try: + hkey = OpenKey(root, subkey, access=KEY_ALL_ACCESS) + except OSError: + # subkey does not exist + return + while True: + try: + subsubkey = EnumKey(hkey, 0) + except OSError: + # no more subkeys + break + delete_registry_tree(hkey, subsubkey) + CloseKey(hkey) + DeleteKey(root, subkey) + +@contextmanager +def setup_module(machinery, name, path=None): + if machinery.WindowsRegistryFinder.DEBUG_BUILD: + root = machinery.WindowsRegistryFinder.REGISTRY_KEY_DEBUG + else: + root = machinery.WindowsRegistryFinder.REGISTRY_KEY + key = root.format(fullname=name, + sys_version=sys.version[:3]) + try: + with temp_module(name, "a = 1") as location: + subkey = CreateKey(HKEY_CURRENT_USER, key) + if path is None: + path = location + ".py" + SetValue(subkey, "", REG_SZ, path) + yield + finally: + if machinery.WindowsRegistryFinder.DEBUG_BUILD: + key = os.path.dirname(key) + delete_registry_tree(HKEY_CURRENT_USER, key) @unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows') class WindowsRegistryFinderTests: - - # XXX Need a test that finds the spec via the registry. + # The module name is process-specific, allowing for + # simultaneous runs of the same test on a single machine. + test_module = "spamham{}".format(os.getpid()) def test_find_spec_missing(self): spec = self.machinery.WindowsRegistryFinder.find_spec('spam') @@ -18,12 +68,42 @@ class WindowsRegistryFinderTests: loader = self.machinery.WindowsRegistryFinder.find_module('spam') self.assertIs(loader, None) + def test_module_found(self): + with setup_module(self.machinery, self.test_module): + loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module) + spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module) + self.assertIsNot(loader, None) + self.assertIsNot(spec, None) + + def test_module_not_found(self): + with setup_module(self.machinery, self.test_module, path="."): + loader = self.machinery.WindowsRegistryFinder.find_module(self.test_module) + spec = self.machinery.WindowsRegistryFinder.find_spec(self.test_module) + self.assertIsNone(loader) + self.assertIsNone(spec) + +(Frozen_WindowsRegistryFinderTests, + Source_WindowsRegistryFinderTests + ) = test_util.test_both(WindowsRegistryFinderTests, machinery=machinery) + +@unittest.skipUnless(sys.platform.startswith('win'), 'requires Windows') +class WindowsExtensionSuffixTests: + def test_tagged_suffix(self): + suffixes = self.machinery.EXTENSION_SUFFIXES + expected_tag = ".cp{0.major}{0.minor}-{1}.pyd".format(sys.version_info, + re.sub('[^a-zA-Z0-9]', '_', get_platform())) + try: + untagged_i = suffixes.index(".pyd") + except ValueError: + untagged_i = suffixes.index("_d.pyd") + expected_tag = "_d" + expected_tag -class Frozen_WindowsRegistryFinderTests(WindowsRegistryFinderTests, - unittest.TestCase): - machinery = frozen_machinery + self.assertIn(expected_tag, suffixes) + # Ensure the tags are in the correct order + tagged_i = suffixes.index(expected_tag) + self.assertLess(tagged_i, untagged_i) -class Source_WindowsRegistryFinderTests(WindowsRegistryFinderTests, - unittest.TestCase): - machinery = source_machinery +(Frozen_WindowsExtensionSuffixTests, + Source_WindowsExtensionSuffixTests + ) = test_util.test_both(WindowsExtensionSuffixTests, machinery=machinery) diff --git a/Lib/test/test_importlib/util.py b/Lib/test/test_importlib/util.py index 885cec3..ce20377 100644 --- a/Lib/test/test_importlib/util.py +++ b/Lib/test/test_importlib/util.py @@ -1,31 +1,85 @@ -from contextlib import contextmanager -from importlib import util, invalidate_caches +import builtins +import contextlib +import errno +import functools +import importlib +from importlib import machinery, util, invalidate_caches +import os import os.path from test import support import unittest import sys +import tempfile import types +BUILTINS = types.SimpleNamespace() +BUILTINS.good_name = None +BUILTINS.bad_name = None +if 'errno' in sys.builtin_module_names: + BUILTINS.good_name = 'errno' +if 'importlib' not in sys.builtin_module_names: + BUILTINS.bad_name = 'importlib' + +EXTENSIONS = types.SimpleNamespace() +EXTENSIONS.path = None +EXTENSIONS.ext = None +EXTENSIONS.filename = None +EXTENSIONS.file_path = None +EXTENSIONS.name = '_testcapi' + +def _extension_details(): + global EXTENSIONS + for path in sys.path: + for ext in machinery.EXTENSION_SUFFIXES: + filename = EXTENSIONS.name + ext + file_path = os.path.join(path, filename) + if os.path.exists(file_path): + EXTENSIONS.path = path + EXTENSIONS.ext = ext + EXTENSIONS.filename = filename + EXTENSIONS.file_path = file_path + return + +_extension_details() + + def import_importlib(module_name): """Import a module from importlib both w/ and w/o _frozen_importlib.""" fresh = ('importlib',) if '.' in module_name else () frozen = support.import_fresh_module(module_name) source = support.import_fresh_module(module_name, fresh=fresh, - blocked=('_frozen_importlib',)) + blocked=('_frozen_importlib', '_frozen_importlib_external')) + return {'Frozen': frozen, 'Source': source} + + +def specialize_class(cls, kind, base=None, **kwargs): + # XXX Support passing in submodule names--load (and cache) them? + # That would clean up the test modules a bit more. + if base is None: + base = unittest.TestCase + elif not isinstance(base, type): + base = base[kind] + name = '{}_{}'.format(kind, cls.__name__) + bases = (cls, base) + specialized = types.new_class(name, bases) + specialized.__module__ = cls.__module__ + specialized._NAME = cls.__name__ + specialized._KIND = kind + for attr, values in kwargs.items(): + value = values[kind] + setattr(specialized, attr, value) + return specialized + + +def split_frozen(cls, base=None, **kwargs): + frozen = specialize_class(cls, 'Frozen', base, **kwargs) + source = specialize_class(cls, 'Source', base, **kwargs) return frozen, source -def test_both(test_class, **kwargs): - frozen_tests = types.new_class('Frozen_'+test_class.__name__, - (test_class, unittest.TestCase)) - source_tests = types.new_class('Source_'+test_class.__name__, - (test_class, unittest.TestCase)) - frozen_tests.__module__ = source_tests.__module__ = test_class.__module__ - for attr, (frozen_value, source_value) in kwargs.items(): - setattr(frozen_tests, attr, frozen_value) - setattr(source_tests, attr, source_value) - return frozen_tests, source_tests +def test_both(test_class, base=None, **kwargs): + return split_frozen(test_class, base, **kwargs) CASE_INSENSITIVE_FS = True @@ -38,6 +92,10 @@ if sys.platform not in ('win32', 'cygwin'): if not os.path.exists(changed_name): CASE_INSENSITIVE_FS = False +source_importlib = import_importlib('importlib')['Source'] +__import__ = {'Frozen': staticmethod(builtins.__import__), + 'Source': staticmethod(source_importlib.__import__)} + def case_insensitive_tests(test): """Class decorator that nullifies tests requiring a case-insensitive @@ -53,7 +111,7 @@ def submodule(parent, name, pkg_dir, content=''): return '{}.{}'.format(parent, name), path -@contextmanager +@contextlib.contextmanager def uncache(*names): """Uncache a module from sys.modules. @@ -79,7 +137,7 @@ def uncache(*names): pass -@contextmanager +@contextlib.contextmanager def temp_module(name, content='', *, pkg=False): conflicts = [n for n in sys.modules if n.partition('.')[0] == name] with support.temp_cwd(None) as cwd: @@ -103,7 +161,7 @@ def temp_module(name, content='', *, pkg=False): yield location -@contextmanager +@contextlib.contextmanager def import_state(**kwargs): """Context manager to manage the various importers and stored state in the sys module. @@ -198,6 +256,7 @@ class mock_modules(_ImporterMock): raise return self.modules[fullname] + class mock_spec(_ImporterMock): """Importer mock using PEP 451 APIs.""" @@ -223,3 +282,99 @@ class mock_spec(_ImporterMock): self.module_code[module.__spec__.name]() except KeyError: pass + + +def writes_bytecode_files(fxn): + """Decorator to protect sys.dont_write_bytecode from mutation and to skip + tests that require it to be set to False.""" + if sys.dont_write_bytecode: + return lambda *args, **kwargs: None + @functools.wraps(fxn) + def wrapper(*args, **kwargs): + original = sys.dont_write_bytecode + sys.dont_write_bytecode = False + try: + to_return = fxn(*args, **kwargs) + finally: + sys.dont_write_bytecode = original + return to_return + return wrapper + + +def ensure_bytecode_path(bytecode_path): + """Ensure that the __pycache__ directory for PEP 3147 pyc file exists. + + :param bytecode_path: File system path to PEP 3147 pyc file. + """ + try: + os.mkdir(os.path.dirname(bytecode_path)) + except OSError as error: + if error.errno != errno.EEXIST: + raise + + +@contextlib.contextmanager +def create_modules(*names): + """Temporarily create each named module with an attribute (named 'attr') + that contains the name passed into the context manager that caused the + creation of the module. + + All files are created in a temporary directory returned by + tempfile.mkdtemp(). This directory is inserted at the beginning of + sys.path. When the context manager exits all created files (source and + bytecode) are explicitly deleted. + + No magic is performed when creating packages! This means that if you create + a module within a package you must also create the package's __init__ as + well. + + """ + source = 'attr = {0!r}' + created_paths = [] + mapping = {} + state_manager = None + uncache_manager = None + try: + temp_dir = tempfile.mkdtemp() + mapping['.root'] = temp_dir + import_names = set() + for name in names: + if not name.endswith('__init__'): + import_name = name + else: + import_name = name[:-len('.__init__')] + import_names.add(import_name) + if import_name in sys.modules: + del sys.modules[import_name] + name_parts = name.split('.') + file_path = temp_dir + for directory in name_parts[:-1]: + file_path = os.path.join(file_path, directory) + if not os.path.exists(file_path): + os.mkdir(file_path) + created_paths.append(file_path) + file_path = os.path.join(file_path, name_parts[-1] + '.py') + with open(file_path, 'w') as file: + file.write(source.format(name)) + created_paths.append(file_path) + mapping[name] = file_path + uncache_manager = uncache(*import_names) + uncache_manager.__enter__() + state_manager = import_state(path=[temp_dir]) + state_manager.__enter__() + yield mapping + finally: + if state_manager is not None: + state_manager.__exit__(None, None, None) + if uncache_manager is not None: + uncache_manager.__exit__(None, None, None) + support.rmtree(temp_dir) + + +def mock_path_hook(*entries, importer): + """A mock sys.path_hooks entry.""" + def hook(entry): + if entry not in entries: + raise ImportError + return importer + return hook diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py index 1a124b5..955b2ad 100644 --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -1,3 +1,4 @@ +import builtins import collections import datetime import functools @@ -8,6 +9,7 @@ import linecache import os from os.path import normcase import _pickle +import pickle import re import shutil import sys @@ -16,6 +18,7 @@ import textwrap import unicodedata import unittest import unittest.mock +import warnings try: from concurrent.futures import ThreadPoolExecutor @@ -23,8 +26,8 @@ except ImportError: ThreadPoolExecutor = None from test.support import run_unittest, TESTFN, DirsOnSysPath, cpython_only -from test.support import MISSING_C_DOCSTRINGS -from test.script_helper import assert_python_ok, assert_python_failure +from test.support import MISSING_C_DOCSTRINGS, cpython_only +from test.support.script_helper import assert_python_ok, assert_python_failure from test import inspect_fodder as mod from test import inspect_fodder2 as mod2 @@ -60,14 +63,16 @@ class IsTestBase(unittest.TestCase): predicates = set([inspect.isbuiltin, inspect.isclass, inspect.iscode, inspect.isframe, inspect.isfunction, inspect.ismethod, inspect.ismodule, inspect.istraceback, - inspect.isgenerator, inspect.isgeneratorfunction]) + inspect.isgenerator, inspect.isgeneratorfunction, + inspect.iscoroutine, inspect.iscoroutinefunction]) def istest(self, predicate, exp): obj = eval(exp) self.assertTrue(predicate(obj), '%s(%s)' % (predicate.__name__, exp)) for other in self.predicates - set([predicate]): - if predicate == inspect.isgeneratorfunction and\ + if (predicate == inspect.isgeneratorfunction or \ + predicate == inspect.iscoroutinefunction) and \ other == inspect.isfunction: continue self.assertFalse(other(obj), 'not %s(%s)' % (other.__name__, exp)) @@ -76,19 +81,19 @@ def generator_function_example(self): for i in range(2): yield i +async def coroutine_function_example(self): + return 'spam' + +@types.coroutine +def gen_coroutine_function_example(self): + yield + return 'spam' + class EqualsToAll: def __eq__(self, other): return True class TestPredicates(IsTestBase): - def test_sixteen(self): - count = len([x for x in dir(inspect) if x.startswith('is')]) - # This test is here for remember you to update Doc/library/inspect.rst - # which claims there are 16 such functions - expected = 16 - err_msg = "There are %d (not %d) is* functions" % (count, expected) - self.assertEqual(count, expected, err_msg) - def test_excluding_predicates(self): global tb @@ -116,11 +121,62 @@ class TestPredicates(IsTestBase): self.istest(inspect.isdatadescriptor, 'collections.defaultdict.default_factory') self.istest(inspect.isgenerator, '(x for x in range(2))') self.istest(inspect.isgeneratorfunction, 'generator_function_example') + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + self.istest(inspect.iscoroutine, 'coroutine_function_example(1)') + self.istest(inspect.iscoroutinefunction, 'coroutine_function_example') + if hasattr(types, 'MemberDescriptorType'): self.istest(inspect.ismemberdescriptor, 'datetime.timedelta.days') else: self.assertFalse(inspect.ismemberdescriptor(datetime.timedelta.days)) + def test_iscoroutine(self): + gen_coro = gen_coroutine_function_example(1) + coro = coroutine_function_example(1) + + self.assertFalse( + inspect.iscoroutinefunction(gen_coroutine_function_example)) + self.assertFalse(inspect.iscoroutine(gen_coro)) + + self.assertTrue( + inspect.isgeneratorfunction(gen_coroutine_function_example)) + self.assertTrue(inspect.isgenerator(gen_coro)) + + self.assertTrue( + inspect.iscoroutinefunction(coroutine_function_example)) + self.assertTrue(inspect.iscoroutine(coro)) + + self.assertFalse( + inspect.isgeneratorfunction(coroutine_function_example)) + self.assertFalse(inspect.isgenerator(coro)) + + coro.close(); gen_coro.close() # silence warnings + + def test_isawaitable(self): + def gen(): yield + self.assertFalse(inspect.isawaitable(gen())) + + coro = coroutine_function_example(1) + gen_coro = gen_coroutine_function_example(1) + + self.assertTrue(inspect.isawaitable(coro)) + self.assertTrue(inspect.isawaitable(gen_coro)) + + class Future: + def __await__(): + pass + self.assertTrue(inspect.isawaitable(Future())) + self.assertFalse(inspect.isawaitable(Future)) + + class NotFuture: pass + not_fut = NotFuture() + not_fut.__await__ = lambda: None + self.assertFalse(inspect.isawaitable(not_fut)) + + coro.close(); gen_coro.close() # silence warnings + def test_isroutine(self): self.assertTrue(inspect.isroutine(mod.spam)) self.assertTrue(inspect.isroutine([].count)) @@ -186,6 +242,14 @@ class TestInterpreterStack(IsTestBase): (modfile, 43, 'argue', [' spam(a, b, c)\n'], 0)) self.assertEqual(revise(*mod.st[3][1:]), (modfile, 39, 'abuse', [' self.argue(a, b, c)\n'], 0)) + # Test named tuple fields + record = mod.st[0] + self.assertIs(record.frame, mod.fr) + self.assertEqual(record.lineno, 16) + self.assertEqual(record.filename, mod.__file__) + self.assertEqual(record.function, 'eggs') + self.assertIn('inspect.stack()', record.code_context[0]) + self.assertEqual(record.index, 0) def test_trace(self): self.assertEqual(len(git.tr), 3) @@ -217,9 +281,7 @@ class GetSourceBase(unittest.TestCase): # Subclasses must override. fodderModule = None - def __init__(self, *args, **kwargs): - unittest.TestCase.__init__(self, *args, **kwargs) - + def setUp(self): with open(inspect.getsourcefile(self.fodderModule)) as fp: self.source = fp.read() @@ -274,6 +336,7 @@ class TestRetrievingSourceCode(GetSourceBase): def test_getfunctions(self): functions = inspect.getmembers(mod, inspect.isfunction) self.assertEqual(functions, [('eggs', mod.eggs), + ('lobbest', mod.lobbest), ('spam', mod.spam)]) @unittest.skipIf(sys.flags.optimize >= 2, @@ -285,6 +348,27 @@ class TestRetrievingSourceCode(GetSourceBase): self.assertEqual(inspect.getdoc(git.abuse), 'Another\n\ndocstring\n\ncontaining\n\ntabs') + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_getdoc_inherited(self): + self.assertEqual(inspect.getdoc(mod.FesteringGob), + 'A longer,\n\nindented\n\ndocstring.') + self.assertEqual(inspect.getdoc(mod.FesteringGob.abuse), + 'Another\n\ndocstring\n\ncontaining\n\ntabs') + self.assertEqual(inspect.getdoc(mod.FesteringGob().abuse), + 'Another\n\ndocstring\n\ncontaining\n\ntabs') + self.assertEqual(inspect.getdoc(mod.FesteringGob.contradiction), + 'The automatic gainsaying.') + + @unittest.skipIf(MISSING_C_DOCSTRINGS, "test requires docstrings") + def test_finddoc(self): + finddoc = inspect._finddoc + self.assertEqual(finddoc(int), int.__doc__) + self.assertEqual(finddoc(int.to_bytes), int.to_bytes.__doc__) + self.assertEqual(finddoc(int().to_bytes), int.to_bytes.__doc__) + self.assertEqual(finddoc(int.from_bytes), int.from_bytes.__doc__) + self.assertEqual(finddoc(int.real), int.real.__doc__) + def test_cleandoc(self): self.assertEqual(inspect.cleandoc('An\n indented\n docstring.'), 'An\nindented\ndocstring.') @@ -309,7 +393,8 @@ class TestRetrievingSourceCode(GetSourceBase): def test_getsource(self): self.assertSourceEqual(git.abuse, 29, 39) - self.assertSourceEqual(mod.StupidGit, 21, 46) + self.assertSourceEqual(mod.StupidGit, 21, 50) + self.assertSourceEqual(mod.lobbest, 70, 71) def test_getsourcefile(self): self.assertEqual(normcase(inspect.getsourcefile(mod.spam)), modfile) @@ -364,6 +449,9 @@ class TestRetrievingSourceCode(GetSourceBase): finally: linecache.getlines = getlines + def test_getsource_on_code_object(self): + self.assertSourceEqual(mod.eggs.__code__, 12, 18) + class TestDecorators(GetSourceBase): fodderModule = mod2 @@ -373,6 +461,12 @@ class TestDecorators(GetSourceBase): def test_replacing_decorator(self): self.assertSourceEqual(mod2.gone, 9, 10) + def test_getsource_unwrap(self): + self.assertSourceEqual(mod2.real, 130, 132) + + def test_decorator_with_lambda(self): + self.assertSourceEqual(mod2.func114, 113, 115) + class TestOneliners(GetSourceBase): fodderModule = mod2 def test_oneline_lambda(self): @@ -466,8 +560,15 @@ class TestBuggyCases(GetSourceBase): self.assertRaises(IOError, inspect.findsource, co) self.assertRaises(IOError, inspect.getsource, co) + def test_getsource_on_method(self): + self.assertSourceEqual(mod2.ClassWithMethod.method, 118, 119) + + def test_nested_func(self): + self.assertSourceEqual(mod2.cls135.func136, 136, 139) + + class TestNoEOL(GetSourceBase): - def __init__(self, *args, **kwargs): + def setUp(self): self.tempdir = TESTFN + '_dir' os.mkdir(self.tempdir) with open(os.path.join(self.tempdir, @@ -476,7 +577,7 @@ class TestNoEOL(GetSourceBase): with DirsOnSysPath(self.tempdir): import inspect_fodder3 as mod3 self.fodderModule = mod3 - GetSourceBase.__init__(self, *args, **kwargs) + super().setUp() def tearDown(self): shutil.rmtree(self.tempdir) @@ -529,7 +630,8 @@ class TestClassesAndFunctions(unittest.TestCase): def assertArgSpecEquals(self, routine, args_e, varargs_e=None, varkw_e=None, defaults_e=None, formatted=None): - args, varargs, varkw, defaults = inspect.getargspec(routine) + with self.assertWarns(DeprecationWarning): + args, varargs, varkw, defaults = inspect.getargspec(routine) self.assertEqual(args, args_e) self.assertEqual(varargs, varargs_e) self.assertEqual(varkw, varkw_e) @@ -1634,10 +1736,86 @@ class TestGetGeneratorState(unittest.TestCase): self.assertRaises(TypeError, inspect.getgeneratorlocals, (2,3)) +class TestGetCoroutineState(unittest.TestCase): + + def setUp(self): + @types.coroutine + def number_coroutine(): + for number in range(5): + yield number + async def coroutine(): + await number_coroutine() + self.coroutine = coroutine() + + def tearDown(self): + self.coroutine.close() + + def _coroutinestate(self): + return inspect.getcoroutinestate(self.coroutine) + + def test_created(self): + self.assertEqual(self._coroutinestate(), inspect.CORO_CREATED) + + def test_suspended(self): + self.coroutine.send(None) + self.assertEqual(self._coroutinestate(), inspect.CORO_SUSPENDED) + + def test_closed_after_exhaustion(self): + while True: + try: + self.coroutine.send(None) + except StopIteration: + break + + self.assertEqual(self._coroutinestate(), inspect.CORO_CLOSED) + + def test_closed_after_immediate_exception(self): + with self.assertRaises(RuntimeError): + self.coroutine.throw(RuntimeError) + self.assertEqual(self._coroutinestate(), inspect.CORO_CLOSED) + + def test_easy_debugging(self): + # repr() and str() of a coroutine state should contain the state name + names = 'CORO_CREATED CORO_RUNNING CORO_SUSPENDED CORO_CLOSED'.split() + for name in names: + state = getattr(inspect, name) + self.assertIn(name, repr(state)) + self.assertIn(name, str(state)) + + def test_getcoroutinelocals(self): + @types.coroutine + def gencoro(): + yield + + gencoro = gencoro() + async def func(a=None): + b = 'spam' + await gencoro + + coro = func() + self.assertEqual(inspect.getcoroutinelocals(coro), + {'a': None, 'gencoro': gencoro}) + coro.send(None) + self.assertEqual(inspect.getcoroutinelocals(coro), + {'a': None, 'gencoro': gencoro, 'b': 'spam'}) + + +class MySignature(inspect.Signature): + # Top-level to make it picklable; + # used in test_signature_object_pickle + pass + +class MyParameter(inspect.Parameter): + # Top-level to make it picklable; + # used in test_signature_object_pickle + pass + + + class TestSignatureObject(unittest.TestCase): @staticmethod - def signature(func): - sig = inspect.signature(func) + def signature(func, **kw): + sig = inspect.signature(func, **kw) return (tuple((param.name, (... if param.default is param.empty else param.default), (... if param.annotation is param.empty @@ -1691,6 +1869,37 @@ class TestSignatureObject(unittest.TestCase): with self.assertRaisesRegex(ValueError, 'follows default argument'): S((pkd, pk)) + self.assertTrue(repr(sig).startswith('<Signature')) + self.assertTrue('(po, pk' in repr(sig)) + + def test_signature_object_pickle(self): + def foo(a, b, *, c:1={}, **kw) -> {42:'ham'}: pass + foo_partial = functools.partial(foo, a=1) + + sig = inspect.signature(foo_partial) + + for ver in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_ver=ver, subclass=False): + sig_pickled = pickle.loads(pickle.dumps(sig, ver)) + self.assertEqual(sig, sig_pickled) + + # Test that basic sub-classing works + sig = inspect.signature(foo) + myparam = MyParameter(name='z', kind=inspect.Parameter.POSITIONAL_ONLY) + myparams = collections.OrderedDict(sig.parameters, a=myparam) + mysig = MySignature().replace(parameters=myparams.values(), + return_annotation=sig.return_annotation) + self.assertTrue(isinstance(mysig, MySignature)) + self.assertTrue(isinstance(mysig.parameters['z'], MyParameter)) + + for ver in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_ver=ver, subclass=True): + sig_pickled = pickle.loads(pickle.dumps(mysig, ver)) + self.assertEqual(mysig, sig_pickled) + self.assertTrue(isinstance(sig_pickled, MySignature)) + self.assertTrue(isinstance(sig_pickled.parameters['z'], + MyParameter)) + def test_signature_immutability(self): def test(a): pass @@ -1805,6 +2014,8 @@ class TestSignatureObject(unittest.TestCase): test_unbound_method(dict.__delitem__) test_unbound_method(property.__delete__) + # Regression test for issue #20586 + test_callable(_testcapi.docstring_with_signature_but_no_doc) @cpython_only @unittest.skipIf(MISSING_C_DOCSTRINGS, @@ -1824,23 +2035,26 @@ class TestSignatureObject(unittest.TestCase): self.assertEqual(inspect.signature(func), inspect.signature(decorated_func)) + def wrapper_like(*args, **kwargs) -> int: pass + self.assertEqual(inspect.signature(decorated_func, + follow_wrapped=False), + inspect.signature(wrapper_like)) + @cpython_only def test_signature_on_builtins_no_signature(self): import _testcapi - with self.assertRaisesRegex(ValueError, 'no signature found for builtin'): + with self.assertRaisesRegex(ValueError, + 'no signature found for builtin'): inspect.signature(_testcapi.docstring_no_signature) + with self.assertRaisesRegex(ValueError, + 'no signature found for builtin'): + inspect.signature(str) + def test_signature_on_non_function(self): with self.assertRaisesRegex(TypeError, 'is not a callable object'): inspect.signature(42) - with self.assertRaisesRegex(TypeError, 'is not a Python function'): - inspect.Signature.from_function(42) - - def test_signature_from_builtin_errors(self): - with self.assertRaisesRegex(TypeError, 'is not a Python builtin'): - inspect.Signature.from_builtin(42) - def test_signature_from_functionlike_object(self): def func(a,b, *args, kwonly=True, kwonlyreq, **kwargs): pass @@ -1861,9 +2075,9 @@ class TestSignatureObject(unittest.TestCase): def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) - sig_func = inspect.Signature.from_function(func) + sig_func = inspect.Signature.from_callable(func) - sig_funclike = inspect.Signature.from_function(funclike(func)) + sig_funclike = inspect.Signature.from_callable(funclike(func)) self.assertEqual(sig_funclike, sig_func) sig_funclike = inspect.signature(funclike(func)) @@ -1911,9 +2125,6 @@ class TestSignatureObject(unittest.TestCase): __defaults__ = func.__defaults__ __kwdefaults__ = func.__kwdefaults__ - with self.assertRaisesRegex(TypeError, 'is not a Python function'): - inspect.Signature.from_function(funclike) - self.assertEqual(str(inspect.signature(funclike)), '(marker)') def test_signature_on_method(self): @@ -2265,6 +2476,13 @@ class TestSignatureObject(unittest.TestCase): ('b', ..., ..., "positional_or_keyword")), ...)) + self.assertEqual(self.signature(Foo.bar, follow_wrapped=False), + ((('args', ..., ..., "var_positional"), + ('kwargs', ..., ..., "var_keyword")), + ...)) # functools.wraps will copy __annotations__ + # from "func" to "wrapper", hence no + # return_annotation + # Test that we handle method wrappers correctly def decorator(func): @functools.wraps(func) @@ -2471,60 +2689,102 @@ class TestSignatureObject(unittest.TestCase): def bar(a, *, b:int) -> float: pass self.assertTrue(inspect.signature(foo) == inspect.signature(bar)) self.assertFalse(inspect.signature(foo) != inspect.signature(bar)) + self.assertEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def bar(a, *, b:int) -> int: pass self.assertFalse(inspect.signature(foo) == inspect.signature(bar)) self.assertTrue(inspect.signature(foo) != inspect.signature(bar)) + self.assertNotEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def bar(a, *, b:int): pass self.assertFalse(inspect.signature(foo) == inspect.signature(bar)) self.assertTrue(inspect.signature(foo) != inspect.signature(bar)) + self.assertNotEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def bar(a, *, b:int=42) -> float: pass self.assertFalse(inspect.signature(foo) == inspect.signature(bar)) self.assertTrue(inspect.signature(foo) != inspect.signature(bar)) + self.assertNotEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def bar(a, *, c) -> float: pass self.assertFalse(inspect.signature(foo) == inspect.signature(bar)) self.assertTrue(inspect.signature(foo) != inspect.signature(bar)) + self.assertNotEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def bar(a, b:int) -> float: pass self.assertFalse(inspect.signature(foo) == inspect.signature(bar)) self.assertTrue(inspect.signature(foo) != inspect.signature(bar)) + self.assertNotEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def spam(b:int, a) -> float: pass self.assertFalse(inspect.signature(spam) == inspect.signature(bar)) self.assertTrue(inspect.signature(spam) != inspect.signature(bar)) + self.assertNotEqual( + hash(inspect.signature(spam)), hash(inspect.signature(bar))) def foo(*, a, b, c): pass def bar(*, c, b, a): pass self.assertTrue(inspect.signature(foo) == inspect.signature(bar)) self.assertFalse(inspect.signature(foo) != inspect.signature(bar)) + self.assertEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def foo(*, a=1, b, c): pass def bar(*, c, b, a=1): pass self.assertTrue(inspect.signature(foo) == inspect.signature(bar)) self.assertFalse(inspect.signature(foo) != inspect.signature(bar)) + self.assertEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def foo(pos, *, a=1, b, c): pass def bar(pos, *, c, b, a=1): pass self.assertTrue(inspect.signature(foo) == inspect.signature(bar)) self.assertFalse(inspect.signature(foo) != inspect.signature(bar)) + self.assertEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def foo(pos, *, a, b, c): pass def bar(pos, *, c, b, a=1): pass self.assertFalse(inspect.signature(foo) == inspect.signature(bar)) self.assertTrue(inspect.signature(foo) != inspect.signature(bar)) + self.assertNotEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) def foo(pos, *args, a=42, b, c, **kwargs:int): pass def bar(pos, *args, c, b, a=42, **kwargs:int): pass self.assertTrue(inspect.signature(foo) == inspect.signature(bar)) self.assertFalse(inspect.signature(foo) != inspect.signature(bar)) + self.assertEqual( + hash(inspect.signature(foo)), hash(inspect.signature(bar))) + + def test_signature_hashable(self): + S = inspect.Signature + P = inspect.Parameter - def test_signature_unhashable(self): def foo(a): pass - sig = inspect.signature(foo) + foo_sig = inspect.signature(foo) + + manual_sig = S(parameters=[P('a', P.POSITIONAL_OR_KEYWORD)]) + + self.assertEqual(hash(foo_sig), hash(manual_sig)) + self.assertNotEqual(hash(foo_sig), + hash(manual_sig.replace(return_annotation='spam'))) + + def bar(a) -> 1: pass + self.assertNotEqual(hash(foo_sig), hash(inspect.signature(bar))) + + def foo(a={}): pass with self.assertRaisesRegex(TypeError, 'unhashable type'): - hash(sig) + hash(inspect.signature(foo)) + + def foo(a) -> {}: pass + with self.assertRaisesRegex(TypeError, 'unhashable type'): + hash(inspect.signature(foo)) def test_signature_str(self): def foo(a:int=1, *, b, c=None, **kwargs) -> 42: @@ -2598,6 +2858,19 @@ class TestSignatureObject(unittest.TestCase): self.assertEqual(self.signature(Spam.foo), self.signature(Ham.foo)) + def test_signature_from_callable_python_obj(self): + class MySignature(inspect.Signature): pass + def foo(a, *, b:1): pass + foo_sig = MySignature.from_callable(foo) + self.assertTrue(isinstance(foo_sig, MySignature)) + + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_signature_from_callable_builtin_obj(self): + class MySignature(inspect.Signature): pass + sig = MySignature.from_callable(_pickle.Pickler) + self.assertTrue(isinstance(sig, MySignature)) + class TestParameterObject(unittest.TestCase): def test_signature_parameter_kinds(self): @@ -2643,6 +2916,16 @@ class TestParameterObject(unittest.TestCase): p.replace(kind=inspect.Parameter.VAR_POSITIONAL) self.assertTrue(repr(p).startswith('<Parameter')) + self.assertTrue('"a=42"' in repr(p)) + + def test_signature_parameter_hashable(self): + P = inspect.Parameter + foo = P('foo', kind=P.POSITIONAL_ONLY) + self.assertEqual(hash(foo), hash(P('foo', kind=P.POSITIONAL_ONLY))) + self.assertNotEqual(hash(foo), hash(P('foo', kind=P.POSITIONAL_ONLY, + default=42))) + self.assertNotEqual(hash(foo), + hash(foo.replace(kind=P.VAR_POSITIONAL))) def test_signature_parameter_equality(self): P = inspect.Parameter @@ -2660,13 +2943,6 @@ class TestParameterObject(unittest.TestCase): self.assertFalse(p != P('foo', default=42, kind=inspect.Parameter.KEYWORD_ONLY)) - def test_signature_parameter_unhashable(self): - p = inspect.Parameter('foo', default=42, - kind=inspect.Parameter.KEYWORD_ONLY) - - with self.assertRaisesRegex(TypeError, 'unhashable type'): - hash(p) - def test_signature_parameter_replace(self): p = inspect.Parameter('foo', default=42, kind=inspect.Parameter.KEYWORD_ONLY) @@ -2735,7 +3011,9 @@ class TestSignatureBind(unittest.TestCase): self.call(test, 1) with self.assertRaisesRegex(TypeError, 'too many positional arguments'): self.call(test, 1, spam=10) - with self.assertRaisesRegex(TypeError, 'too many keyword arguments'): + with self.assertRaisesRegex( + TypeError, "got an unexpected keyword argument 'spam'"): + self.call(test, spam=1) def test_signature_bind_var(self): @@ -2760,10 +3038,12 @@ class TestSignatureBind(unittest.TestCase): with self.assertRaisesRegex(TypeError, 'too many positional arguments'): self.call(test, 1, 2, 3, 4) - with self.assertRaisesRegex(TypeError, "'b' parameter lacking default"): + with self.assertRaisesRegex(TypeError, + "missing a required argument: 'b'"): self.call(test, 1) - with self.assertRaisesRegex(TypeError, "'a' parameter lacking default"): + with self.assertRaisesRegex(TypeError, + "missing a required argument: 'a'"): self.call(test) def test(a, b, c=10): @@ -2836,7 +3116,7 @@ class TestSignatureBind(unittest.TestCase): def test(a, *, foo=1, bar): return foo with self.assertRaisesRegex(TypeError, - "'bar' parameter lacking default value"): + "missing a required argument: 'bar'"): self.call(test, 1) def test(foo, *, bar): @@ -2844,8 +3124,9 @@ class TestSignatureBind(unittest.TestCase): self.assertEqual(self.call(test, 1, bar=2), (1, 2)) self.assertEqual(self.call(test, bar=2, foo=1), (1, 2)) - with self.assertRaisesRegex(TypeError, - 'too many keyword arguments'): + with self.assertRaisesRegex( + TypeError, "got an unexpected keyword argument 'spam'"): + self.call(test, bar=2, foo=1, spam=10) with self.assertRaisesRegex(TypeError, @@ -2856,12 +3137,13 @@ class TestSignatureBind(unittest.TestCase): 'too many positional arguments'): self.call(test, 1, 2, bar=2) - with self.assertRaisesRegex(TypeError, - 'too many keyword arguments'): + with self.assertRaisesRegex( + TypeError, "got an unexpected keyword argument 'spam'"): + self.call(test, 1, bar=2, spam='ham') with self.assertRaisesRegex(TypeError, - "'bar' parameter lacking default value"): + "missing a required argument: 'bar'"): self.call(test, 1) def test(foo, *, bar, **bin): @@ -2873,7 +3155,7 @@ class TestSignatureBind(unittest.TestCase): self.assertEqual(self.call(test, spam='ham', foo=1, bar=2), (1, 2, {'spam': 'ham'})) with self.assertRaisesRegex(TypeError, - "'foo' parameter lacking default value"): + "missing a required argument: 'foo'"): self.call(test, spam='ham', bar=2) self.assertEqual(self.call(test, 1, bar=2, bin=1, spam=10), (1, 2, {'bin': 1, 'spam': 10})) @@ -2938,7 +3220,9 @@ class TestSignatureBind(unittest.TestCase): return a, args sig = inspect.signature(test) - with self.assertRaisesRegex(TypeError, "too many keyword arguments"): + with self.assertRaisesRegex( + TypeError, "got an unexpected keyword argument 'args'"): + sig.bind(a=0, args=1) def test(*args, **kwargs): @@ -2982,6 +3266,64 @@ class TestBoundArguments(unittest.TestCase): self.assertFalse(ba == ba4) self.assertTrue(ba != ba4) + def foo(*, a, b): pass + sig = inspect.signature(foo) + ba1 = sig.bind(a=1, b=2) + ba2 = sig.bind(b=2, a=1) + self.assertTrue(ba1 == ba2) + self.assertFalse(ba1 != ba2) + + def test_signature_bound_arguments_pickle(self): + def foo(a, b, *, c:1={}, **kw) -> {42:'ham'}: pass + sig = inspect.signature(foo) + ba = sig.bind(20, 30, z={}) + + for ver in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(pickle_ver=ver): + ba_pickled = pickle.loads(pickle.dumps(ba, ver)) + self.assertEqual(ba, ba_pickled) + + def test_signature_bound_arguments_repr(self): + def foo(a, b, *, c:1={}, **kw) -> {42:'ham'}: pass + sig = inspect.signature(foo) + ba = sig.bind(20, 30, z={}) + self.assertRegex(repr(ba), r'<BoundArguments \(a=20,.*\}\}\)>') + + def test_signature_bound_arguments_apply_defaults(self): + def foo(a, b=1, *args, c:1={}, **kw): pass + sig = inspect.signature(foo) + + ba = sig.bind(20) + ba.apply_defaults() + self.assertEqual( + list(ba.arguments.items()), + [('a', 20), ('b', 1), ('args', ()), ('c', {}), ('kw', {})]) + + # Make sure that we preserve the order: + # i.e. 'c' should be *before* 'kw'. + ba = sig.bind(10, 20, 30, d=1) + ba.apply_defaults() + self.assertEqual( + list(ba.arguments.items()), + [('a', 10), ('b', 20), ('args', (30,)), ('c', {}), ('kw', {'d':1})]) + + # Make sure that BoundArguments produced by bind_partial() + # are supported. + def foo(a, b): pass + sig = inspect.signature(foo) + ba = sig.bind_partial(20) + ba.apply_defaults() + self.assertEqual( + list(ba.arguments.items()), + [('a', 20)]) + + # Test no args + def foo(): pass + sig = inspect.signature(foo) + ba = sig.bind() + ba.apply_defaults() + self.assertEqual(list(ba.arguments.items()), []) + class TestSignaturePrivateHelpers(unittest.TestCase): def test_signature_get_bound_param(self): @@ -3046,6 +3388,61 @@ class TestSignaturePrivateHelpers(unittest.TestCase): None, None) +class TestSignatureDefinitions(unittest.TestCase): + # This test case provides a home for checking that particular APIs + # have signatures available for introspection + + @cpython_only + @unittest.skipIf(MISSING_C_DOCSTRINGS, + "Signature information for builtins requires docstrings") + def test_builtins_have_signatures(self): + # This checks all builtin callables in CPython have signatures + # A few have signatures Signature can't yet handle, so we skip those + # since they will have to wait until PEP 457 adds the required + # introspection support to the inspect module + # Some others also haven't been converted yet for various other + # reasons, so we also skip those for the time being, but design + # the test to fail in order to indicate when it needs to be + # updated. + no_signature = set() + # These need PEP 457 groups + needs_groups = {"range", "slice", "dir", "getattr", + "next", "iter", "vars"} + no_signature |= needs_groups + # These need PEP 457 groups or a signature change to accept None + needs_semantic_update = {"round"} + no_signature |= needs_semantic_update + # These need *args support in Argument Clinic + needs_varargs = {"min", "max", "print", "__build_class__"} + no_signature |= needs_varargs + # These simply weren't covered in the initial AC conversion + # for builtin callables + not_converted_yet = {"open", "__import__"} + no_signature |= not_converted_yet + # These builtin types are expected to provide introspection info + types_with_signatures = set() + # Check the signatures we expect to be there + ns = vars(builtins) + for name, obj in sorted(ns.items()): + if not callable(obj): + continue + # The builtin types haven't been converted to AC yet + if isinstance(obj, type) and (name not in types_with_signatures): + # Note that this also skips all the exception types + no_signature.add(name) + if (name in no_signature): + # Not yet converted + continue + with self.subTest(builtin=name): + self.assertIsNotNone(inspect.signature(obj)) + # Check callables that haven't been converted don't claim a signature + # This ensures this test will start failing as more signatures are + # added, so the affected items can be moved into the scope of the + # regression test above + for name in no_signature: + with self.subTest(builtin=name): + self.assertIsNone(obj.__text_signature__) + class TestUnwrap(unittest.TestCase): @@ -3187,8 +3584,10 @@ def test_main(): TestGetcallargsFunctions, TestGetcallargsMethods, TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState, TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject, - TestBoundArguments, TestSignaturePrivateHelpers, TestGetClosureVars, - TestUnwrap, TestMain, TestReload + TestBoundArguments, TestSignaturePrivateHelpers, + TestSignatureDefinitions, + TestGetClosureVars, TestUnwrap, TestMain, TestReload, + TestGetCoroutineState ) if __name__ == "__main__": diff --git a/Lib/test/test_int.py b/Lib/test/test_int.py index e94602e..cb57f15 100644 --- a/Lib/test/test_int.py +++ b/Lib/test/test_int.py @@ -451,8 +451,5 @@ class IntTestCases(unittest.TestCase): check('123\ud800') check('123\ud800', 10) -def test_main(): - support.run_unittest(IntTestCases) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_int_literal.py b/Lib/test/test_int_literal.py index 1d578a7..bf72571 100644 --- a/Lib/test/test_int_literal.py +++ b/Lib/test/test_int_literal.py @@ -4,7 +4,6 @@ This is complex because of changes due to PEP 237. """ import unittest -from test import support class TestHexOctBin(unittest.TestCase): @@ -140,8 +139,5 @@ class TestHexOctBin(unittest.TestCase): self.assertEqual(-0b1000000000000000000000000000000000000000000000000000000000000000, -9223372036854775808) self.assertEqual(-0b1111111111111111111111111111111111111111111111111111111111111111, -18446744073709551615) -def test_main(): - support.run_unittest(TestHexOctBin) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_io.py b/Lib/test/test_io.py index f654406..465d45a 100644 --- a/Lib/test/test_io.py +++ b/Lib/test/test_io.py @@ -35,7 +35,7 @@ import weakref from collections import deque, UserList from itertools import cycle, count from test import support -from test.script_helper import assert_python_ok, run_python_until_end +from test.support.script_helper import assert_python_ok, run_python_until_end import codecs import io # C implementation of io @@ -44,10 +44,6 @@ try: import threading except ImportError: threading = None -try: - import fcntl -except ImportError: - fcntl = None def _default_chunk_size(): """Get the default TextIOWrapper chunk size""" @@ -367,8 +363,8 @@ class IOTest(unittest.TestCase): def test_open_handles_NUL_chars(self): fn_with_NUL = 'foo\0bar' - self.assertRaises(TypeError, self.open, fn_with_NUL, 'w') - self.assertRaises(TypeError, self.open, bytes(fn_with_NUL, 'ascii'), 'w') + self.assertRaises(ValueError, self.open, fn_with_NUL, 'w') + self.assertRaises(ValueError, self.open, bytes(fn_with_NUL, 'ascii'), 'w') def test_raw_file_io(self): with self.open(support.TESTFN, "wb", buffering=0) as f: @@ -723,6 +719,21 @@ class PyIOTest(IOTest): pass +@support.cpython_only +class APIMismatchTest(unittest.TestCase): + + def test_RawIOBase_io_in_pyio_match(self): + """Test that pyio RawIOBase class has all c RawIOBase methods""" + mismatch = support.detect_api_mismatch(pyio.RawIOBase, io.RawIOBase, + ignore=('__weakref__',)) + self.assertEqual(mismatch, set(), msg='Python RawIOBase does not have all C RawIOBase methods') + + def test_RawIOBase_pyio_in_io_match(self): + """Test that c RawIOBase class has all pyio RawIOBase methods""" + mismatch = support.detect_api_mismatch(io.RawIOBase, pyio.RawIOBase) + self.assertEqual(mismatch, set(), msg='C RawIOBase does not have all Python RawIOBase methods') + + class CommonBufferedTests: # Tests common to BufferedReader, BufferedWriter and BufferedRandom @@ -811,7 +822,7 @@ class CommonBufferedTests: def test_repr(self): raw = self.MockRawIO() b = self.tp(raw) - clsname = "%s.%s" % (self.tp.__module__, self.tp.__name__) + clsname = "%s.%s" % (self.tp.__module__, self.tp.__qualname__) self.assertEqual(repr(b), "<%s>" % clsname) raw.name = "dummy" self.assertEqual(repr(b), "<%s name='dummy'>" % clsname) @@ -985,6 +996,71 @@ class BufferedReaderTest(unittest.TestCase, CommonBufferedTests): self.assertEqual(bufio.readinto(b), 1) self.assertEqual(b, b"cb") + def test_readinto1(self): + buffer_size = 10 + rawio = self.MockRawIO((b"abc", b"de", b"fgh", b"jkl")) + bufio = self.tp(rawio, buffer_size=buffer_size) + b = bytearray(2) + self.assertEqual(bufio.peek(3), b'abc') + self.assertEqual(rawio._reads, 1) + self.assertEqual(bufio.readinto1(b), 2) + self.assertEqual(b, b"ab") + self.assertEqual(rawio._reads, 1) + self.assertEqual(bufio.readinto1(b), 1) + self.assertEqual(b[:1], b"c") + self.assertEqual(rawio._reads, 1) + self.assertEqual(bufio.readinto1(b), 2) + self.assertEqual(b, b"de") + self.assertEqual(rawio._reads, 2) + b = bytearray(2*buffer_size) + self.assertEqual(bufio.peek(3), b'fgh') + self.assertEqual(rawio._reads, 3) + self.assertEqual(bufio.readinto1(b), 6) + self.assertEqual(b[:6], b"fghjkl") + self.assertEqual(rawio._reads, 4) + + def test_readinto_array(self): + buffer_size = 60 + data = b"a" * 26 + rawio = self.MockRawIO((data,)) + bufio = self.tp(rawio, buffer_size=buffer_size) + + # Create an array with element size > 1 byte + b = array.array('i', b'x' * 32) + assert len(b) != 16 + + # Read into it. We should get as many *bytes* as we can fit into b + # (which is more than the number of elements) + n = bufio.readinto(b) + self.assertGreater(n, len(b)) + + # Check that old contents of b are preserved + bm = memoryview(b).cast('B') + self.assertLess(n, len(bm)) + self.assertEqual(bm[:n], data[:n]) + self.assertEqual(bm[n:], b'x' * (len(bm[n:]))) + + def test_readinto1_array(self): + buffer_size = 60 + data = b"a" * 26 + rawio = self.MockRawIO((data,)) + bufio = self.tp(rawio, buffer_size=buffer_size) + + # Create an array with element size > 1 byte + b = array.array('i', b'x' * 32) + assert len(b) != 16 + + # Read into it. We should get as many *bytes* as we can fit into b + # (which is more than the number of elements) + n = bufio.readinto1(b) + self.assertGreater(n, len(b)) + + # Check that old contents of b are preserved + bm = memoryview(b).cast('B') + self.assertLess(n, len(bm)) + self.assertEqual(bm[:n], data[:n]) + self.assertEqual(bm[n:], b'x' * (len(bm[n:]))) + def test_readlines(self): def bufio(): rawio = self.MockRawIO((b"abc\n", b"d\n", b"ef")) @@ -2913,6 +2989,17 @@ class TextIOWrapperTest(unittest.TestCase): self.assertFalse(err) self.assertEqual("ok", out.decode().strip()) + def test_read_byteslike(self): + r = MemviewBytesIO(b'Just some random string\n') + t = self.TextIOWrapper(r, 'utf-8') + + # TextIOwrapper will not read the full string, because + # we truncate it to a multiple of the native int size + # so that we can construct a more complex memoryview. + bytes_val = _to_memoryview(r.getvalue()).tobytes() + + self.assertEqual(t.read(200), bytes_val.decode('utf-8')) + def test_issue22849(self): class F(object): def readable(self): return True @@ -2929,6 +3016,25 @@ class TextIOWrapperTest(unittest.TestCase): t = self.TextIOWrapper(F(), encoding='utf-8') +class MemviewBytesIO(io.BytesIO): + '''A BytesIO object whose read method returns memoryviews + rather than bytes''' + + def read1(self, len_): + return _to_memoryview(super().read1(len_)) + + def read(self, len_): + return _to_memoryview(super().read(len_)) + +def _to_memoryview(buf): + '''Convert bytes-object *buf* to a non-trivial memoryview''' + + arr = array.array('i') + idx = len(buf) - len(buf) % arr.itemsize + arr.frombytes(buf[:idx]) + return memoryview(arr) + + class CTextIOWrapperTest(TextIOWrapperTest): io = io shutdown_error = "RuntimeError: could not find io module state" @@ -2937,8 +3043,6 @@ class CTextIOWrapperTest(TextIOWrapperTest): r = self.BytesIO(b"\xc3\xa9\n\n") b = self.BufferedReader(r, 1000) t = self.TextIOWrapper(b) - self.assertRaises(TypeError, t.__init__, b, newline=42) - self.assertRaises(ValueError, t.read) self.assertRaises(ValueError, t.__init__, b, newline='xyzzy') self.assertRaises(ValueError, t.read) @@ -3175,6 +3279,8 @@ class MiscIOTest(unittest.TestCase): self.assertRaises(ValueError, f.readall) if hasattr(f, "readinto"): self.assertRaises(ValueError, f.readinto, bytearray(1024)) + if hasattr(f, "readinto1"): + self.assertRaises(ValueError, f.readinto1, bytearray(1024)) self.assertRaises(ValueError, f.readline) self.assertRaises(ValueError, f.readlines) self.assertRaises(ValueError, f.seek, 0) @@ -3288,26 +3394,20 @@ class MiscIOTest(unittest.TestCase): with self.open(support.TESTFN, **kwargs) as f: self.assertRaises(TypeError, pickle.dumps, f, protocol) - @unittest.skipUnless(fcntl, 'fcntl required for this test') def test_nonblock_pipe_write_bigbuf(self): self._test_nonblock_pipe_write(16*1024) - @unittest.skipUnless(fcntl, 'fcntl required for this test') def test_nonblock_pipe_write_smallbuf(self): self._test_nonblock_pipe_write(1024) - def _set_non_blocking(self, fd): - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - self.assertNotEqual(flags, -1) - res = fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) - self.assertEqual(res, 0) - + @unittest.skipUnless(hasattr(os, 'set_blocking'), + 'os.set_blocking() required for this test') def _test_nonblock_pipe_write(self, bufsize): sent = [] received = [] r, w = os.pipe() - self._set_non_blocking(r) - self._set_non_blocking(w) + os.set_blocking(r, False) + os.set_blocking(w, False) # To exercise all code paths in the C implementation we need # to play with buffer sizes. For instance, if we choose a @@ -3456,6 +3556,7 @@ class SignalsTest(unittest.TestCase): t.daemon = True r, w = os.pipe() fdopen_kwargs["closefd"] = False + large_data = item * (support.PIPE_MAX_SIZE // len(item) + 1) try: wio = self.io.open(w, **fdopen_kwargs) t.start() @@ -3467,8 +3568,7 @@ class SignalsTest(unittest.TestCase): # handlers, which in this case will invoke alarm_interrupt(). signal.alarm(1) try: - with self.assertRaises(ZeroDivisionError): - wio.write(item * (support.PIPE_MAX_SIZE // len(item) + 1)) + self.assertRaises(ZeroDivisionError, wio.write, large_data) finally: signal.alarm(0) t.join() @@ -3569,11 +3669,13 @@ class SignalsTest(unittest.TestCase): returning a partial result or EINTR), properly invokes the signal handler and retries if the latter returned successfully.""" select = support.import_module("select") + # A quantity that exceeds the buffer size of an anonymous pipe's # write end. N = support.PIPE_MAX_SIZE r, w = os.pipe() fdopen_kwargs["closefd"] = False + # We need a separate thread to read from the pipe and allow the # write() to finish. This thread is started after the SIGALRM is # received (forcing a first EINTR in write()). @@ -3596,6 +3698,8 @@ class SignalsTest(unittest.TestCase): signal.alarm(1) def alarm2(sig, frame): t.start() + + large_data = item * N signal.signal(signal.SIGALRM, alarm1) try: wio = self.io.open(w, **fdopen_kwargs) @@ -3605,7 +3709,9 @@ class SignalsTest(unittest.TestCase): # and the first alarm) # - second raw write() returns EINTR (because of the second alarm) # - subsequent write()s are successful (either partial or complete) - self.assertEqual(N, wio.write(item * N)) + written = wio.write(large_data) + self.assertEqual(N, written) + wio.flush() write_finished = True t.join() @@ -3645,7 +3751,7 @@ class PySignalsTest(SignalsTest): def load_tests(*args): - tests = (CIOTest, PyIOTest, + tests = (CIOTest, PyIOTest, APIMismatchTest, CBufferedReaderTest, PyBufferedReaderTest, CBufferedWriterTest, PyBufferedWriterTest, CBufferedRWPairTest, PyBufferedRWPairTest, diff --git a/Lib/test/test_ioctl.py b/Lib/test/test_ioctl.py index efe9f51..b82b651 100644 --- a/Lib/test/test_ioctl.py +++ b/Lib/test/test_ioctl.py @@ -1,6 +1,6 @@ import array import unittest -from test.support import run_unittest, import_module, get_attribute +from test.support import import_module, get_attribute import os, struct fcntl = import_module('fcntl') termios = import_module('termios') diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index bfb5699..c217d36 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -9,7 +9,9 @@ import re import contextlib import functools import operator +import pickle import ipaddress +import weakref class BaseTestCase(unittest.TestCase): @@ -83,6 +85,13 @@ class CommonTestMixin: self.assertRaises(TypeError, hex, self.factory(1)) self.assertRaises(TypeError, bytes, self.factory(1)) + def pickle_test(self, addr): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + x = self.factory(addr) + y = pickle.loads(pickle.dumps(x, proto)) + self.assertEqual(y, x) + class CommonTestMixin_v4(CommonTestMixin): @@ -248,6 +257,12 @@ class AddressTestCase_v4(BaseTestCase, CommonTestMixin_v4): assertBadOctet("257.0.0.0", 257) assertBadOctet("192.168.0.999", 999) + def test_pickle(self): + self.pickle_test('192.0.2.1') + + def test_weakref(self): + weakref.ref(self.factory('192.0.2.1')) + class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6): factory = ipaddress.IPv6Address @@ -380,6 +395,12 @@ class AddressTestCase_v6(BaseTestCase, CommonTestMixin_v6): assertBadPart("02001:db8::", "02001") assertBadPart('2001:888888::1', "888888") + def test_pickle(self): + self.pickle_test('2001:db8::') + + def test_weakref(self): + weakref.ref(self.factory('2001:db8::')) + class NetmaskTestMixin_v4(CommonTestMixin_v4): """Input validation on interfaces and networks is very similar""" @@ -443,6 +464,11 @@ class NetmaskTestMixin_v4(CommonTestMixin_v4): assertBadNetmask("1.1.1.1", "pudding") assertBadNetmask("1.1.1.1", "::") + def test_pickle(self): + self.pickle_test('192.0.2.0/27') + self.pickle_test('192.0.2.0/31') # IPV4LENGTH - 1 + self.pickle_test('192.0.2.0') # IPV4LENGTH + class InterfaceTestCase_v4(BaseTestCase, NetmaskTestMixin_v4): factory = ipaddress.IPv4Interface @@ -501,6 +527,11 @@ class NetmaskTestMixin_v6(CommonTestMixin_v6): assertBadNetmask("::1", "pudding") assertBadNetmask("::", "::") + def test_pickle(self): + self.pickle_test('2001:db8::1000/124') + self.pickle_test('2001:db8::1000/127') # IPV6LENGTH - 1 + self.pickle_test('2001:db8::1000') # IPV6LENGTH + class InterfaceTestCase_v6(BaseTestCase, NetmaskTestMixin_v6): factory = ipaddress.IPv6Interface @@ -670,6 +701,119 @@ class IpaddrUnitTest(unittest.TestCase): self.assertEqual("IPv6Interface('::1/128')", repr(ipaddress.IPv6Interface('::1'))) + # issue #16531: constructing IPv4Network from a (address, mask) tuple + def testIPv4Tuple(self): + # /32 + ip = ipaddress.IPv4Address('192.0.2.1') + net = ipaddress.IPv4Network('192.0.2.1/32') + self.assertEqual(ipaddress.IPv4Network(('192.0.2.1', 32)), net) + self.assertEqual(ipaddress.IPv4Network((ip, 32)), net) + self.assertEqual(ipaddress.IPv4Network((3221225985, 32)), net) + self.assertEqual(ipaddress.IPv4Network(('192.0.2.1', + '255.255.255.255')), net) + self.assertEqual(ipaddress.IPv4Network((ip, + '255.255.255.255')), net) + self.assertEqual(ipaddress.IPv4Network((3221225985, + '255.255.255.255')), net) + # strict=True and host bits set + with self.assertRaises(ValueError): + ipaddress.IPv4Network(('192.0.2.1', 24)) + with self.assertRaises(ValueError): + ipaddress.IPv4Network((ip, 24)) + with self.assertRaises(ValueError): + ipaddress.IPv4Network((3221225985, 24)) + with self.assertRaises(ValueError): + ipaddress.IPv4Network(('192.0.2.1', '255.255.255.0')) + with self.assertRaises(ValueError): + ipaddress.IPv4Network((ip, '255.255.255.0')) + with self.assertRaises(ValueError): + ipaddress.IPv4Network((3221225985, '255.255.255.0')) + # strict=False and host bits set + net = ipaddress.IPv4Network('192.0.2.0/24') + self.assertEqual(ipaddress.IPv4Network(('192.0.2.1', 24), + strict=False), net) + self.assertEqual(ipaddress.IPv4Network((ip, 24), + strict=False), net) + self.assertEqual(ipaddress.IPv4Network((3221225985, 24), + strict=False), net) + self.assertEqual(ipaddress.IPv4Network(('192.0.2.1', + '255.255.255.0'), + strict=False), net) + self.assertEqual(ipaddress.IPv4Network((ip, + '255.255.255.0'), + strict=False), net) + self.assertEqual(ipaddress.IPv4Network((3221225985, + '255.255.255.0'), + strict=False), net) + + # /24 + ip = ipaddress.IPv4Address('192.0.2.0') + net = ipaddress.IPv4Network('192.0.2.0/24') + self.assertEqual(ipaddress.IPv4Network(('192.0.2.0', + '255.255.255.0')), net) + self.assertEqual(ipaddress.IPv4Network((ip, + '255.255.255.0')), net) + self.assertEqual(ipaddress.IPv4Network((3221225984, + '255.255.255.0')), net) + self.assertEqual(ipaddress.IPv4Network(('192.0.2.0', 24)), net) + self.assertEqual(ipaddress.IPv4Network((ip, 24)), net) + self.assertEqual(ipaddress.IPv4Network((3221225984, 24)), net) + + self.assertEqual(ipaddress.IPv4Interface(('192.0.2.1', 24)), + ipaddress.IPv4Interface('192.0.2.1/24')) + self.assertEqual(ipaddress.IPv4Interface((3221225985, 24)), + ipaddress.IPv4Interface('192.0.2.1/24')) + + # issue #16531: constructing IPv6Network from a (address, mask) tuple + def testIPv6Tuple(self): + # /128 + ip = ipaddress.IPv6Address('2001:db8::') + net = ipaddress.IPv6Network('2001:db8::/128') + self.assertEqual(ipaddress.IPv6Network(('2001:db8::', '128')), + net) + self.assertEqual(ipaddress.IPv6Network( + (42540766411282592856903984951653826560, 128)), + net) + self.assertEqual(ipaddress.IPv6Network((ip, '128')), + net) + ip = ipaddress.IPv6Address('2001:db8::') + net = ipaddress.IPv6Network('2001:db8::/96') + self.assertEqual(ipaddress.IPv6Network(('2001:db8::', '96')), + net) + self.assertEqual(ipaddress.IPv6Network( + (42540766411282592856903984951653826560, 96)), + net) + self.assertEqual(ipaddress.IPv6Network((ip, '96')), + net) + + # strict=True and host bits set + ip = ipaddress.IPv6Address('2001:db8::1') + with self.assertRaises(ValueError): + ipaddress.IPv6Network(('2001:db8::1', 96)) + with self.assertRaises(ValueError): + ipaddress.IPv6Network(( + 42540766411282592856903984951653826561, 96)) + with self.assertRaises(ValueError): + ipaddress.IPv6Network((ip, 96)) + # strict=False and host bits set + net = ipaddress.IPv6Network('2001:db8::/96') + self.assertEqual(ipaddress.IPv6Network(('2001:db8::1', 96), + strict=False), + net) + self.assertEqual(ipaddress.IPv6Network( + (42540766411282592856903984951653826561, 96), + strict=False), + net) + self.assertEqual(ipaddress.IPv6Network((ip, 96), strict=False), + net) + + # /96 + self.assertEqual(ipaddress.IPv6Interface(('2001:db8::1', '96')), + ipaddress.IPv6Interface('2001:db8::1/96')) + self.assertEqual(ipaddress.IPv6Interface( + (42540766411282592856903984951653826561, '96')), + ipaddress.IPv6Interface('2001:db8::1/96')) + # issue57 def testAddressIntMath(self): self.assertEqual(ipaddress.IPv4Address('1.1.1.1') + 255, @@ -690,20 +834,18 @@ class IpaddrUnitTest(unittest.TestCase): 2 ** ipaddress.IPV6LENGTH) def testInternals(self): - first, last = ipaddress._find_address_range([ - ipaddress.IPv4Address('10.10.10.10'), - ipaddress.IPv4Address('10.10.10.12')]) - self.assertEqual(first, last) + ip1 = ipaddress.IPv4Address('10.10.10.10') + ip2 = ipaddress.IPv4Address('10.10.10.11') + ip3 = ipaddress.IPv4Address('10.10.10.12') + self.assertEqual(list(ipaddress._find_address_range([ip1])), + [(ip1, ip1)]) + self.assertEqual(list(ipaddress._find_address_range([ip1, ip3])), + [(ip1, ip1), (ip3, ip3)]) + self.assertEqual(list(ipaddress._find_address_range([ip1, ip2, ip3])), + [(ip1, ip3)]) self.assertEqual(128, ipaddress._count_righthand_zero_bits(0, 128)) self.assertEqual("IPv4Network('1.2.3.0/24')", repr(self.ipv4_network)) - def testMissingAddressVersion(self): - class Broken(ipaddress._BaseAddress): - pass - broken = Broken('127.0.0.1') - with self.assertRaisesRegex(NotImplementedError, "Broken.*version"): - broken.version - def testMissingNetworkVersion(self): class Broken(ipaddress._BaseNetwork): pass @@ -1635,6 +1777,14 @@ class IpaddrUnitTest(unittest.TestCase): addr3.exploded) self.assertEqual('192.168.178.1', addr4.exploded) + def testReversePointer(self): + addr1 = ipaddress.IPv4Address('127.0.0.1') + addr2 = ipaddress.IPv6Address('2001:db8::1') + self.assertEqual('1.0.0.127.in-addr.arpa', addr1.reverse_pointer) + self.assertEqual('1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.' + + 'b.d.0.1.0.0.2.ip6.arpa', + addr2.reverse_pointer) + def testIntRepresentation(self): self.assertEqual(16909060, int(self.ipv4_address)) self.assertEqual(42540616829182469433547762482097946625, diff --git a/Lib/test/test_isinstance.py b/Lib/test/test_isinstance.py index 7a6730e..e63d59b 100644 --- a/Lib/test/test_isinstance.py +++ b/Lib/test/test_isinstance.py @@ -3,7 +3,6 @@ # testing of error conditions uncovered when using extension types. import unittest -from test import support import sys @@ -259,31 +258,23 @@ class TestIsInstanceIsSubclass(unittest.TestCase): self.assertEqual(True, issubclass(str, (str, (Child, NewChild, str)))) def test_subclass_recursion_limit(self): - # make sure that issubclass raises RuntimeError before the C stack is + # make sure that issubclass raises RecursionError before the C stack is # blown - self.assertRaises(RuntimeError, blowstack, issubclass, str, str) + self.assertRaises(RecursionError, blowstack, issubclass, str, str) def test_isinstance_recursion_limit(self): - # make sure that issubclass raises RuntimeError before the C stack is + # make sure that issubclass raises RecursionError before the C stack is # blown - self.assertRaises(RuntimeError, blowstack, isinstance, '', str) + self.assertRaises(RecursionError, blowstack, isinstance, '', str) def blowstack(fxn, arg, compare_to): # Make sure that calling isinstance with a deeply nested tuple for its - # argument will raise RuntimeError eventually. + # argument will raise RecursionError eventually. tuple_arg = (compare_to,) for cnt in range(sys.getrecursionlimit()+5): tuple_arg = (tuple_arg,) fxn(arg, tuple_arg) -def test_main(): - support.run_unittest( - TestIsInstanceExceptions, - TestIsSubclassExceptions, - TestIsInstanceIsSubclass - ) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 993438c..5b3ba7e 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1870,8 +1870,6 @@ class RegressionTests(unittest.TestCase): hist.append(3) yield 2 hist.append(4) - if x: - raise StopIteration hist = [] self.assertRaises(AssertionError, list, chain(gen1(), gen2(False))) diff --git a/Lib/test/test_json/__init__.py b/Lib/test/test_json/__init__.py index 2cf1032..0807e6f 100644 --- a/Lib/test/test_json/__init__.py +++ b/Lib/test/test_json/__init__.py @@ -9,12 +9,15 @@ from test import support # import json with and without accelerations cjson = support.import_fresh_module('json', fresh=['_json']) pyjson = support.import_fresh_module('json', blocked=['_json']) +# JSONDecodeError is cached inside the _json module +cjson.JSONDecodeError = cjson.decoder.JSONDecodeError = json.JSONDecodeError # create two base classes that will be used by the other tests class PyTest(unittest.TestCase): json = pyjson loads = staticmethod(pyjson.loads) dumps = staticmethod(pyjson.dumps) + JSONDecodeError = staticmethod(pyjson.JSONDecodeError) @unittest.skipUnless(cjson, 'requires _json') class CTest(unittest.TestCase): @@ -22,6 +25,7 @@ class CTest(unittest.TestCase): json = cjson loads = staticmethod(cjson.loads) dumps = staticmethod(cjson.dumps) + JSONDecodeError = staticmethod(cjson.JSONDecodeError) # test PyTest and CTest checking if the functions come from the right module class TestPyTest(PyTest): diff --git a/Lib/test/test_json/test_decode.py b/Lib/test/test_json/test_decode.py index 591b2e2..cc83b45 100644 --- a/Lib/test/test_json/test_decode.py +++ b/Lib/test/test_json/test_decode.py @@ -63,12 +63,12 @@ class TestDecode: def test_extra_data(self): s = '[1, 2, 3]5' msg = 'Extra data' - self.assertRaisesRegex(ValueError, msg, self.loads, s) + self.assertRaisesRegex(self.JSONDecodeError, msg, self.loads, s) def test_invalid_escape(self): s = '["abc\\y"]' msg = 'escape' - self.assertRaisesRegex(ValueError, msg, self.loads, s) + self.assertRaisesRegex(self.JSONDecodeError, msg, self.loads, s) def test_invalid_input_type(self): msg = 'the JSON object must be str' @@ -80,10 +80,10 @@ class TestDecode: def test_string_with_utf8_bom(self): # see #18958 bom_json = "[1,2,3]".encode('utf-8-sig').decode('utf-8') - with self.assertRaises(ValueError) as cm: + with self.assertRaises(self.JSONDecodeError) as cm: self.loads(bom_json) self.assertIn('BOM', str(cm.exception)) - with self.assertRaises(ValueError) as cm: + with self.assertRaises(self.JSONDecodeError) as cm: self.json.load(StringIO(bom_json)) self.assertIn('BOM', str(cm.exception)) # make sure that the BOM is not detected in the middle of a string diff --git a/Lib/test/test_json/test_encode_basestring_ascii.py b/Lib/test/test_json/test_encode_basestring_ascii.py index 2122da1..4bbc6c7 100644 --- a/Lib/test/test_json/test_encode_basestring_ascii.py +++ b/Lib/test/test_json/test_encode_basestring_ascii.py @@ -12,9 +12,6 @@ CASES = [ (' s p a c e d ', '" s p a c e d "'), ('\U0001d120', '"\\ud834\\udd20"'), ('\u03b1\u03a9', '"\\u03b1\\u03a9"'), - ('\u03b1\u03a9', '"\\u03b1\\u03a9"'), - ('\u03b1\u03a9', '"\\u03b1\\u03a9"'), - ('\u03b1\u03a9', '"\\u03b1\\u03a9"'), ("`1~!@#$%^&*()_+-={':[,]}|;.</>?", '"`1~!@#$%^&*()_+-={\':[,]}|;.</>?"'), ('\x08\x0c\n\r\t', '"\\b\\f\\n\\r\\t"'), ('\u0123\u4567\u89ab\ucdef\uabcd\uef4a', '"\\u0123\\u4567\\u89ab\\ucdef\\uabcd\\uef4a"'), diff --git a/Lib/test/test_json/test_fail.py b/Lib/test/test_json/test_fail.py index 7caafdb..95ff5b8 100644 --- a/Lib/test/test_json/test_fail.py +++ b/Lib/test/test_json/test_fail.py @@ -87,7 +87,7 @@ class TestFail: continue try: self.loads(doc) - except ValueError: + except self.JSONDecodeError: pass else: self.fail("Expected failure for fail{0}.json: {1!r}".format(idx, doc)) @@ -124,10 +124,16 @@ class TestFail: ('"spam', 'Unterminated string starting at', 0), ] for data, msg, idx in test_cases: - self.assertRaisesRegex(ValueError, - r'^{0}: line 1 column {1} \(char {2}\)'.format( - re.escape(msg), idx + 1, idx), - self.loads, data) + with self.assertRaises(self.JSONDecodeError) as cm: + self.loads(data) + err = cm.exception + self.assertEqual(err.msg, msg) + self.assertEqual(err.pos, idx) + self.assertEqual(err.lineno, 1) + self.assertEqual(err.colno, idx + 1) + self.assertEqual(str(err), + '%s: line 1 column %d (char %d)' % + (msg, idx + 1, idx)) def test_unexpected_data(self): test_cases = [ @@ -154,10 +160,16 @@ class TestFail: ('{"spam":42,}', 'Expecting property name enclosed in double quotes', 11), ] for data, msg, idx in test_cases: - self.assertRaisesRegex(ValueError, - r'^{0}: line 1 column {1} \(char {2}\)'.format( - re.escape(msg), idx + 1, idx), - self.loads, data) + with self.assertRaises(self.JSONDecodeError) as cm: + self.loads(data) + err = cm.exception + self.assertEqual(err.msg, msg) + self.assertEqual(err.pos, idx) + self.assertEqual(err.lineno, 1) + self.assertEqual(err.colno, idx + 1) + self.assertEqual(str(err), + '%s: line 1 column %d (char %d)' % + (msg, idx + 1, idx)) def test_extra_data(self): test_cases = [ @@ -171,11 +183,16 @@ class TestFail: ('"spam",42', 'Extra data', 6), ] for data, msg, idx in test_cases: - self.assertRaisesRegex(ValueError, - r'^{0}: line 1 column {1} - line 1 column {2}' - r' \(char {3} - {4}\)'.format( - re.escape(msg), idx + 1, len(data) + 1, idx, len(data)), - self.loads, data) + with self.assertRaises(self.JSONDecodeError) as cm: + self.loads(data) + err = cm.exception + self.assertEqual(err.msg, msg) + self.assertEqual(err.pos, idx) + self.assertEqual(err.lineno, 1) + self.assertEqual(err.colno, idx + 1) + self.assertEqual(str(err), + '%s: line 1 column %d (char %d)' % + (msg, idx + 1, idx)) def test_linecol(self): test_cases = [ @@ -185,10 +202,16 @@ class TestFail: ('\n \n\n !', 4, 6, 10), ] for data, line, col, idx in test_cases: - self.assertRaisesRegex(ValueError, - r'^Expecting value: line {0} column {1}' - r' \(char {2}\)$'.format(line, col, idx), - self.loads, data) + with self.assertRaises(self.JSONDecodeError) as cm: + self.loads(data) + err = cm.exception + self.assertEqual(err.msg, 'Expecting value') + self.assertEqual(err.pos, idx) + self.assertEqual(err.lineno, line) + self.assertEqual(err.colno, col) + self.assertEqual(str(err), + 'Expecting value: line %s column %d (char %d)' % + (line, col, idx)) class TestPyFail(TestFail, PyTest): pass class TestCFail(TestFail, CTest): pass diff --git a/Lib/test/test_json/test_recursion.py b/Lib/test/test_json/test_recursion.py index 1a76254..877dc44 100644 --- a/Lib/test/test_json/test_recursion.py +++ b/Lib/test/test_json/test_recursion.py @@ -68,11 +68,11 @@ class TestRecursion: def test_highly_nested_objects_decoding(self): # test that loading highly-nested objects doesn't segfault when C # accelerations are used. See #12017 - with self.assertRaises(RuntimeError): + with self.assertRaises(RecursionError): self.loads('{"a":' * 100000 + '1' + '}' * 100000) - with self.assertRaises(RuntimeError): + with self.assertRaises(RecursionError): self.loads('{"a":' * 100000 + '[1]' + '}' * 100000) - with self.assertRaises(RuntimeError): + with self.assertRaises(RecursionError): self.loads('[' * 100000 + '1' + ']' * 100000) def test_highly_nested_objects_encoding(self): @@ -80,9 +80,9 @@ class TestRecursion: l, d = [], {} for x in range(100000): l, d = [l], {'k':d} - with self.assertRaises(RuntimeError): + with self.assertRaises(RecursionError): self.dumps(l) - with self.assertRaises(RuntimeError): + with self.assertRaises(RecursionError): self.dumps(d) def test_endless_recursion(self): @@ -92,7 +92,7 @@ class TestRecursion: """If check_circular is False, this will keep adding another list.""" return [o] - with self.assertRaises(RuntimeError): + with self.assertRaises(RecursionError): EndlessJSONEncoder(check_circular=False).encode(5j) diff --git a/Lib/test/test_json/test_scanstring.py b/Lib/test/test_json/test_scanstring.py index 07f4358..2d3ee8a 100644 --- a/Lib/test/test_json/test_scanstring.py +++ b/Lib/test/test_json/test_scanstring.py @@ -129,7 +129,7 @@ class TestScanstring: '"\\ud834\\u0X20"', ] for s in bad_escapes: - with self.assertRaises(ValueError, msg=s): + with self.assertRaises(self.JSONDecodeError, msg=s): scanstring(s, 1, True) def test_overflow(self): diff --git a/Lib/test/test_json/test_tool.py b/Lib/test/test_json/test_tool.py index 0c39e56..15f3736 100644 --- a/Lib/test/test_json/test_tool.py +++ b/Lib/test/test_json/test_tool.py @@ -4,7 +4,8 @@ import textwrap import unittest import subprocess from test import support -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok + class TestTool(unittest.TestCase): data = """ @@ -15,7 +16,7 @@ class TestTool(unittest.TestCase): :"yes"} ] """ - expect = textwrap.dedent("""\ + expect_without_sort_keys = textwrap.dedent("""\ [ [ "blorpie" @@ -37,6 +38,28 @@ class TestTool(unittest.TestCase): ] """) + expect = textwrap.dedent("""\ + [ + [ + "blorpie" + ], + [ + "whoops" + ], + [], + "d-shtaeou", + "d-nthiouh", + "i-vhbjkhnth", + { + "nifty": 87 + }, + { + "morefield": false, + "field": "yes" + } + ] + """) + def test_stdin_stdout(self): with subprocess.Popen( (sys.executable, '-m', 'json.tool'), @@ -55,6 +78,7 @@ class TestTool(unittest.TestCase): def test_infile_stdout(self): infile = self._create_infile() rc, out, err = assert_python_ok('-m', 'json.tool', infile) + self.assertEqual(rc, 0) self.assertEqual(out.splitlines(), self.expect.encode().splitlines()) self.assertEqual(err, b'') @@ -65,5 +89,20 @@ class TestTool(unittest.TestCase): self.addCleanup(os.remove, outfile) with open(outfile, "r") as fp: self.assertEqual(fp.read(), self.expect) + self.assertEqual(rc, 0) self.assertEqual(out, b'') self.assertEqual(err, b'') + + def test_help_flag(self): + rc, out, err = assert_python_ok('-m', 'json.tool', '-h') + self.assertEqual(rc, 0) + self.assertTrue(out.startswith(b'usage: ')) + self.assertEqual(err, b'') + + def test_sort_keys_flag(self): + infile = self._create_infile() + rc, out, err = assert_python_ok('-m', 'json.tool', '--sort-keys', infile) + self.assertEqual(rc, 0) + self.assertEqual(out.splitlines(), + self.expect_without_sort_keys.encode().splitlines()) + self.assertEqual(err, b'') diff --git a/Lib/test/test_keywordonlyarg.py b/Lib/test/test_keywordonlyarg.py index 7f315d4..d82e33d 100644 --- a/Lib/test/test_keywordonlyarg.py +++ b/Lib/test/test_keywordonlyarg.py @@ -4,7 +4,6 @@ __author__ = "Jiwon Seo" __email__ = "seojiwon at gmail dot com" import unittest -from test.support import run_unittest def posonly_sum(pos_arg1, *arg, **kwarg): return pos_arg1 + sum(arg) + sum(kwarg.values()) @@ -186,8 +185,5 @@ class KeywordOnlyArgTestCase(unittest.TestCase): self.assertEqual(str(err.exception), "name 'b' is not defined") -def test_main(): - run_unittest(KeywordOnlyArgTestCase) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_kqueue.py b/Lib/test/test_kqueue.py index f79bd89..f822024 100644 --- a/Lib/test/test_kqueue.py +++ b/Lib/test/test_kqueue.py @@ -9,7 +9,6 @@ import sys import time import unittest -from test import support if not hasattr(select, "kqueue"): raise unittest.SkipTest("test works only on BSD") @@ -114,7 +113,7 @@ class TestKQueue(unittest.TestCase): def test_queue_event(self): serverSocket = socket.socket() serverSocket.bind(('127.0.0.1', 0)) - serverSocket.listen(1) + serverSocket.listen() client = socket.socket() client.setblocking(False) try: @@ -237,8 +236,5 @@ class TestKQueue(unittest.TestCase): self.assertEqual(os.get_inheritable(kqueue.fileno()), False) -def test_main(): - support.run_unittest(TestKQueue) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_linecache.py b/Lib/test/test_linecache.py index 79157de..21ef738 100644 --- a/Lib/test/test_linecache.py +++ b/Lib/test/test_linecache.py @@ -7,6 +7,7 @@ from test import support FILENAME = linecache.__file__ +NONEXISTENT_FILENAME = FILENAME + '.missing' INVALID_NAME = '!@$)(!@#_1' EMPTY = '' TESTS = 'inspect_fodder inspect_fodder2 mapping_tests' @@ -126,6 +127,48 @@ class LineCacheTests(unittest.TestCase): self.assertEqual(line, getline(source_name, index + 1)) source_list.append(line) + def test_lazycache_no_globals(self): + lines = linecache.getlines(FILENAME) + linecache.clearcache() + self.assertEqual(False, linecache.lazycache(FILENAME, None)) + self.assertEqual(lines, linecache.getlines(FILENAME)) + + def test_lazycache_smoke(self): + lines = linecache.getlines(NONEXISTENT_FILENAME, globals()) + linecache.clearcache() + self.assertEqual( + True, linecache.lazycache(NONEXISTENT_FILENAME, globals())) + self.assertEqual(1, len(linecache.cache[NONEXISTENT_FILENAME])) + # Note here that we're looking up a non existant filename with no + # globals: this would error if the lazy value wasn't resolved. + self.assertEqual(lines, linecache.getlines(NONEXISTENT_FILENAME)) + + def test_lazycache_provide_after_failed_lookup(self): + linecache.clearcache() + lines = linecache.getlines(NONEXISTENT_FILENAME, globals()) + linecache.clearcache() + linecache.getlines(NONEXISTENT_FILENAME) + linecache.lazycache(NONEXISTENT_FILENAME, globals()) + self.assertEqual(lines, linecache.updatecache(NONEXISTENT_FILENAME)) + + def test_lazycache_check(self): + linecache.clearcache() + linecache.lazycache(NONEXISTENT_FILENAME, globals()) + linecache.checkcache() + + def test_lazycache_bad_filename(self): + linecache.clearcache() + self.assertEqual(False, linecache.lazycache('', globals())) + self.assertEqual(False, linecache.lazycache('<foo>', globals())) + + def test_lazycache_already_cached(self): + linecache.clearcache() + lines = linecache.getlines(NONEXISTENT_FILENAME, globals()) + self.assertEqual( + False, + linecache.lazycache(NONEXISTENT_FILENAME, globals())) + self.assertEqual(4, len(linecache.cache[NONEXISTENT_FILENAME])) + def test_memoryerror(self): lines = linecache.getlines(FILENAME) self.assertTrue(lines) diff --git a/Lib/test/test_list.py b/Lib/test/test_list.py index 3b94700..ae1be6e 100644 --- a/Lib/test/test_list.py +++ b/Lib/test/test_list.py @@ -108,20 +108,5 @@ class ListTest(list_tests.CommonTest): with self.assertRaises(TypeError): (3,) + L([1,2]) -def test_main(verbose=None): - support.run_unittest(ListTest) - - # verify reference counting - import sys - if verbose and hasattr(sys, "gettotalrefcount"): - import gc - counts = [None] * 5 - for i in range(len(counts)): - support.run_unittest(ListTest) - gc.collect() - counts[i] = sys.gettotalrefcount() - print(counts) - - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py index 9369a25..fae2c3d 100644 --- a/Lib/test/test_locale.py +++ b/Lib/test/test_locale.py @@ -524,5 +524,59 @@ class TestMiscellaneous(unittest.TestCase): locale.setlocale(locale.LC_ALL, (b'not', b'valid')) +class BaseDelocalizeTest(BaseLocalizedTest): + + def _test_delocalize(self, value, out): + self.assertEqual(locale.delocalize(value), out) + + def _test_atof(self, value, out): + self.assertEqual(locale.atof(value), out) + + def _test_atoi(self, value, out): + self.assertEqual(locale.atoi(value), out) + + +class TestEnUSDelocalize(EnUSCookedTest, BaseDelocalizeTest): + + def test_delocalize(self): + self._test_delocalize('50000.00', '50000.00') + self._test_delocalize('50,000.00', '50000.00') + + def test_atof(self): + self._test_atof('50000.00', 50000.) + self._test_atof('50,000.00', 50000.) + + def test_atoi(self): + self._test_atoi('50000', 50000) + self._test_atoi('50,000', 50000) + + +class TestCDelocalizeTest(CCookedTest, BaseDelocalizeTest): + + def test_delocalize(self): + self._test_delocalize('50000.00', '50000.00') + + def test_atof(self): + self._test_atof('50000.00', 50000.) + + def test_atoi(self): + self._test_atoi('50000', 50000) + + +class TestfrFRDelocalizeTest(FrFRCookedTest, BaseDelocalizeTest): + + def test_delocalize(self): + self._test_delocalize('50000,00', '50000.00') + self._test_delocalize('50 000,00', '50000.00') + + def test_atof(self): + self._test_atof('50000,00', 50000.) + self._test_atof('50 000,00', 50000.) + + def test_atoi(self): + self._test_atoi('50000', 50000) + self._test_atoi('50 000', 50000) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py index be3d02c..041d38f 100644 --- a/Lib/test/test_logging.py +++ b/Lib/test/test_logging.py @@ -1,4 +1,4 @@ -# Copyright 2001-2013 by Vinay Sajip. All Rights Reserved. +# Copyright 2001-2014 by Vinay Sajip. All Rights Reserved. # # Permission to use, copy, modify, and distribute this software and its # documentation for any purpose and without fee is hereby granted, @@ -16,7 +16,7 @@ """Test harness for the logging module. Run all tests. -Copyright (C) 2001-2013 Vinay Sajip. All Rights Reserved. +Copyright (C) 2001-2014 Vinay Sajip. All Rights Reserved. """ import logging @@ -34,15 +34,12 @@ import os import queue import random import re -import select import socket import struct import sys import tempfile -from test.script_helper import assert_python_ok -from test.support import (captured_stdout, run_with_locale, run_unittest, - patch, requires_zlib, TestHandler, Matcher, HOST, - swap_attr) +from test.support.script_helper import assert_python_ok +from test import support import textwrap import time import unittest @@ -52,16 +49,12 @@ try: import threading # The following imports are needed only for tests which # require threading - import asynchat import asyncore - import errno from http.server import HTTPServer, BaseHTTPRequestHandler import smtpd from urllib.parse import urlparse, parse_qs from socketserver import (ThreadingUDPServer, DatagramRequestHandler, - ThreadingTCPServer, StreamRequestHandler, - ThreadingUnixStreamServer, - ThreadingUnixDatagramServer) + ThreadingTCPServer, StreamRequestHandler) except ImportError: threading = None try: @@ -642,22 +635,23 @@ class StreamHandlerTest(BaseTest): h = TestStreamHandler(BadStream()) r = logging.makeLogRecord({}) old_raise = logging.raiseExceptions - old_stderr = sys.stderr + try: h.handle(r) self.assertIs(h.error_record, r) + h = logging.StreamHandler(BadStream()) - sys.stderr = sio = io.StringIO() - h.handle(r) - self.assertIn('\nRuntimeError: deliberate mistake\n', - sio.getvalue()) + with support.captured_stderr() as stderr: + h.handle(r) + msg = '\nRuntimeError: deliberate mistake\n' + self.assertIn(msg, stderr.getvalue()) + logging.raiseExceptions = False - sys.stderr = sio = io.StringIO() - h.handle(r) - self.assertEqual('', sio.getvalue()) + with support.captured_stderr() as stderr: + h.handle(r) + self.assertEqual('', stderr.getvalue()) finally: logging.raiseExceptions = old_raise - sys.stderr = old_stderr # -- The following section could be moved into a server_helper.py module # -- if it proves to be of wider utility than just test_logging @@ -685,7 +679,8 @@ if threading: """ def __init__(self, addr, handler, poll_interval, sockmap): - smtpd.SMTPServer.__init__(self, addr, None, map=sockmap) + smtpd.SMTPServer.__init__(self, addr, None, map=sockmap, + decode_data=True) self.port = self.socket.getsockname()[1] self._handler = handler self._thread = None @@ -927,10 +922,10 @@ class SMTPHandlerTest(BaseTest): TIMEOUT = 8.0 def test_basic(self): sockmap = {} - server = TestSMTPServer((HOST, 0), self.process_message, 0.001, + server = TestSMTPServer((support.HOST, 0), self.process_message, 0.001, sockmap) server.start() - addr = (HOST, server.port) + addr = (support.HOST, server.port) h = logging.handlers.SMTPHandler(addr, 'me', 'you', 'Log', timeout=self.TIMEOUT) self.assertEqual(h.toaddrs, ['you']) @@ -1246,7 +1241,7 @@ class ConfigFileTest(BaseTest): def test_config0_ok(self): # A simple config file which overrides the default settings. - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config0) logger = logging.getLogger() # Won't output anything @@ -1261,7 +1256,7 @@ class ConfigFileTest(BaseTest): def test_config0_using_cp_ok(self): # A simple config file which overrides the default settings. - with captured_stdout() as output: + with support.captured_stdout() as output: file = io.StringIO(textwrap.dedent(self.config0)) cp = configparser.ConfigParser() cp.read_file(file) @@ -1279,7 +1274,7 @@ class ConfigFileTest(BaseTest): def test_config1_ok(self, config=config1): # A config file defining a sub-parser as well. - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(config) logger = logging.getLogger("compiler.parser") # Both will output a message @@ -1302,7 +1297,7 @@ class ConfigFileTest(BaseTest): def test_config4_ok(self): # A config file specifying a custom formatter class. - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config4) logger = logging.getLogger() try: @@ -1322,7 +1317,7 @@ class ConfigFileTest(BaseTest): self.test_config1_ok(config=self.config6) def test_config7_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config1a) logger = logging.getLogger("compiler.parser") # See issue #11424. compiler-hyphenated sorts @@ -1342,7 +1337,7 @@ class ConfigFileTest(BaseTest): ], stream=output) # Original logger output is empty. self.assert_log_lines([]) - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config7) logger = logging.getLogger("compiler.parser") self.assertFalse(logger.disabled) @@ -2489,7 +2484,7 @@ class ConfigDictTest(BaseTest): def test_config0_ok(self): # A simple config which overrides the default settings. - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config0) logger = logging.getLogger() # Won't output anything @@ -2504,7 +2499,7 @@ class ConfigDictTest(BaseTest): def test_config1_ok(self, config=config1): # A config defining a sub-parser as well. - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(config) logger = logging.getLogger("compiler.parser") # Both will output a message @@ -2535,7 +2530,7 @@ class ConfigDictTest(BaseTest): def test_config4_ok(self): # A config specifying a custom formatter class. - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config4) #logger = logging.getLogger() try: @@ -2550,7 +2545,7 @@ class ConfigDictTest(BaseTest): def test_config4a_ok(self): # A config specifying a custom formatter class. - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config4a) #logger = logging.getLogger() try: @@ -2570,7 +2565,7 @@ class ConfigDictTest(BaseTest): self.assertRaises(Exception, self.apply_config, self.config6) def test_config7_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config1) logger = logging.getLogger("compiler.parser") # Both will output a message @@ -2582,7 +2577,7 @@ class ConfigDictTest(BaseTest): ], stream=output) # Original logger output is empty. self.assert_log_lines([]) - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config7) logger = logging.getLogger("compiler.parser") self.assertTrue(logger.disabled) @@ -2599,7 +2594,7 @@ class ConfigDictTest(BaseTest): #Same as test_config_7_ok but don't disable old loggers. def test_config_8_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config1) logger = logging.getLogger("compiler.parser") # All will output a message @@ -2611,7 +2606,7 @@ class ConfigDictTest(BaseTest): ], stream=output) # Original logger output is empty. self.assert_log_lines([]) - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config8) logger = logging.getLogger("compiler.parser") self.assertFalse(logger.disabled) @@ -2632,7 +2627,7 @@ class ConfigDictTest(BaseTest): self.assert_log_lines([]) def test_config_8a_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config1a) logger = logging.getLogger("compiler.parser") # See issue #11424. compiler-hyphenated sorts @@ -2652,7 +2647,7 @@ class ConfigDictTest(BaseTest): ], stream=output) # Original logger output is empty. self.assert_log_lines([]) - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config8a) logger = logging.getLogger("compiler.parser") self.assertFalse(logger.disabled) @@ -2675,7 +2670,7 @@ class ConfigDictTest(BaseTest): self.assert_log_lines([]) def test_config_9_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config9) logger = logging.getLogger("compiler.parser") #Nothing will be output since both handler and logger are set to WARNING @@ -2693,7 +2688,7 @@ class ConfigDictTest(BaseTest): ], stream=output) def test_config_10_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config10) logger = logging.getLogger("compiler.parser") logger.warning(self.next_message()) @@ -2721,7 +2716,7 @@ class ConfigDictTest(BaseTest): self.assertRaises(Exception, self.apply_config, self.config13) def test_config14_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.apply_config(self.config14) h = logging._handlers['hand1'] self.assertEqual(h.foo, 'bar') @@ -2760,7 +2755,7 @@ class ConfigDictTest(BaseTest): @unittest.skipUnless(threading, 'Threading required for this test.') def test_listen_config_10_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.setup_via_listener(json.dumps(self.config10)) logger = logging.getLogger("compiler.parser") logger.warning(self.next_message()) @@ -2780,7 +2775,7 @@ class ConfigDictTest(BaseTest): @unittest.skipUnless(threading, 'Threading required for this test.') def test_listen_config_1_ok(self): - with captured_stdout() as output: + with support.captured_stdout() as output: self.setup_via_listener(textwrap.dedent(ConfigFileTest.config1)) logger = logging.getLogger("compiler.parser") # Both will output a message @@ -2807,7 +2802,7 @@ class ConfigDictTest(BaseTest): # First, specify a verification function that will fail. # We expect to see no output, since our configuration # never took effect. - with captured_stdout() as output: + with support.captured_stdout() as output: self.setup_via_listener(to_send, verify_fail) # Both will output a message logger.info(self.next_message()) @@ -2822,7 +2817,7 @@ class ConfigDictTest(BaseTest): # Now, perform no verification. Our configuration # should take effect. - with captured_stdout() as output: + with support.captured_stdout() as output: self.setup_via_listener(to_send) # no verify callable specified logger = logging.getLogger("compiler.parser") # Both will output a message @@ -2840,7 +2835,7 @@ class ConfigDictTest(BaseTest): # Now, perform verification which transforms the bytes. - with captured_stdout() as output: + with support.captured_stdout() as output: self.setup_via_listener(to_send[::-1], verify_reverse) logger = logging.getLogger("compiler.parser") # Both will output a message @@ -2995,7 +2990,7 @@ class QueueHandlerTest(BaseTest): @unittest.skipUnless(hasattr(logging.handlers, 'QueueListener'), 'logging.handlers.QueueListener required for this test') def test_queue_listener(self): - handler = TestHandler(Matcher()) + handler = support.TestHandler(support.Matcher()) listener = logging.handlers.QueueListener(self.queue, handler) listener.start() try: @@ -3007,6 +3002,25 @@ class QueueHandlerTest(BaseTest): self.assertTrue(handler.matches(levelno=logging.WARNING, message='1')) self.assertTrue(handler.matches(levelno=logging.ERROR, message='2')) self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='3')) + handler.close() + + # Now test with respect_handler_level set + + handler = support.TestHandler(support.Matcher()) + handler.setLevel(logging.CRITICAL) + listener = logging.handlers.QueueListener(self.queue, handler, + respect_handler_level=True) + listener.start() + try: + self.que_logger.warning(self.next_message()) + self.que_logger.error(self.next_message()) + self.que_logger.critical(self.next_message()) + finally: + listener.stop() + self.assertFalse(handler.matches(levelno=logging.WARNING, message='4')) + self.assertFalse(handler.matches(levelno=logging.ERROR, message='5')) + self.assertTrue(handler.matches(levelno=logging.CRITICAL, message='6')) + ZERO = datetime.timedelta(0) @@ -3163,32 +3177,35 @@ class LastResortTest(BaseTest): # Test the last resort handler root = self.root_logger root.removeHandler(self.root_hdlr) - old_stderr = sys.stderr old_lastresort = logging.lastResort old_raise_exceptions = logging.raiseExceptions + try: - sys.stderr = sio = io.StringIO() - root.debug('This should not appear') - self.assertEqual(sio.getvalue(), '') - root.warning('This is your final chance!') - self.assertEqual(sio.getvalue(), 'This is your final chance!\n') - #No handlers and no last resort, so 'No handlers' message + with support.captured_stderr() as stderr: + root.debug('This should not appear') + self.assertEqual(stderr.getvalue(), '') + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), 'Final chance!\n') + + # No handlers and no last resort, so 'No handlers' message logging.lastResort = None - sys.stderr = sio = io.StringIO() - root.warning('This is your final chance!') - self.assertEqual(sio.getvalue(), 'No handlers could be found for logger "root"\n') + with support.captured_stderr() as stderr: + root.warning('Final chance!') + msg = 'No handlers could be found for logger "root"\n' + self.assertEqual(stderr.getvalue(), msg) + # 'No handlers' message only printed once - sys.stderr = sio = io.StringIO() - root.warning('This is your final chance!') - self.assertEqual(sio.getvalue(), '') + with support.captured_stderr() as stderr: + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), '') + + # If raiseExceptions is False, no message is printed root.manager.emittedNoHandlerWarning = False - #If raiseExceptions is False, no message is printed logging.raiseExceptions = False - sys.stderr = sio = io.StringIO() - root.warning('This is your final chance!') - self.assertEqual(sio.getvalue(), '') + with support.captured_stderr() as stderr: + root.warning('Final chance!') + self.assertEqual(stderr.getvalue(), '') finally: - sys.stderr = old_stderr root.addHandler(self.root_hdlr) logging.lastResort = old_lastresort logging.raiseExceptions = old_raise_exceptions @@ -3319,8 +3336,8 @@ class ModuleLevelMiscTest(BaseTest): def _test_log(self, method, level=None): called = [] - patch(self, logging, 'basicConfig', - lambda *a, **kw: called.append((a, kw))) + support.patch(self, logging, 'basicConfig', + lambda *a, **kw: called.append((a, kw))) recording = RecordingHandler() logging.root.addHandler(recording) @@ -3491,7 +3508,7 @@ class BasicConfigTest(unittest.TestCase): self.assertEqual(logging.root.level, self.original_logging_level) def test_strformatstyle(self): - with captured_stdout() as output: + with support.captured_stdout() as output: logging.basicConfig(stream=sys.stdout, style="{") logging.error("Log an error") sys.stdout.seek(0) @@ -3499,7 +3516,7 @@ class BasicConfigTest(unittest.TestCase): "ERROR:root:Log an error") def test_stringtemplatestyle(self): - with captured_stdout() as output: + with support.captured_stdout() as output: logging.basicConfig(stream=sys.stdout, style="$") logging.error("Log an error") sys.stdout.seek(0) @@ -3620,7 +3637,7 @@ class BasicConfigTest(unittest.TestCase): self.addCleanup(logging.root.setLevel, old_level) called.append((a, kw)) - patch(self, logging, 'basicConfig', my_basic_config) + support.patch(self, logging, 'basicConfig', my_basic_config) log_method = getattr(logging, method) if level is not None: @@ -3686,6 +3703,19 @@ class LoggerAdapterTest(unittest.TestCase): self.assertEqual(record.exc_info, (exc.__class__, exc, exc.__traceback__)) + def test_exception_excinfo(self): + try: + 1 / 0 + except ZeroDivisionError as e: + exc = e + + self.adapter.exception('exc_info test', exc_info=exc) + + self.assertEqual(len(self.recording.records), 1) + record = self.recording.records[0] + self.assertEqual(record.exc_info, + (exc.__class__, exc, exc.__traceback__)) + def test_critical(self): msg = 'critical test! %r' self.adapter.critical(msg, self.recording) @@ -3745,17 +3775,17 @@ class LoggerTest(BaseTest): (exc.__class__, exc, exc.__traceback__)) def test_log_invalid_level_with_raise(self): - with swap_attr(logging, 'raiseExceptions', True): + with support.swap_attr(logging, 'raiseExceptions', True): self.assertRaises(TypeError, self.logger.log, '10', 'test message') def test_log_invalid_level_no_raise(self): - with swap_attr(logging, 'raiseExceptions', False): + with support.swap_attr(logging, 'raiseExceptions', False): self.logger.log('10', 'test message') # no exception happens def test_find_caller_with_stack_info(self): called = [] - patch(self, logging.traceback, 'print_stack', - lambda f, file: called.append(file.getvalue())) + support.patch(self, logging.traceback, 'print_stack', + lambda f, file: called.append(file.getvalue())) self.logger.findCaller(stack_info=True) @@ -3892,7 +3922,7 @@ class RotatingFileHandlerTest(BaseFileTest): self.assertFalse(os.path.exists(namer(self.fn + ".3"))) rh.close() - @requires_zlib + @support.requires_zlib def test_rotator(self): def namer(name): return name + ".gz" @@ -4131,22 +4161,20 @@ class NTEventLogHandlerTest(BaseTest): # Set the locale to the platform-dependent default. I have no idea # why the test does this, but in any case we save the current locale # first and restore it at the end. -@run_with_locale('LC_ALL', '') +@support.run_with_locale('LC_ALL', '') def test_main(): - run_unittest(BuiltinLevelsTest, BasicFilterTest, - CustomLevelsAndFiltersTest, HandlerTest, MemoryHandlerTest, - ConfigFileTest, SocketHandlerTest, DatagramHandlerTest, - MemoryTest, EncodingTest, WarningsTest, ConfigDictTest, - ManagerTest, FormatterTest, BufferingFormatterTest, - StreamHandlerTest, LogRecordFactoryTest, ChildLoggerTest, - QueueHandlerTest, ShutdownTest, ModuleLevelMiscTest, - BasicConfigTest, LoggerAdapterTest, LoggerTest, - SMTPHandlerTest, FileHandlerTest, RotatingFileHandlerTest, - LastResortTest, LogRecordTest, ExceptionTest, - SysLogHandlerTest, HTTPHandlerTest, NTEventLogHandlerTest, - TimedRotatingFileHandlerTest, UnixSocketHandlerTest, - UnixDatagramHandlerTest, UnixSysLogHandlerTest - ) + support.run_unittest( + BuiltinLevelsTest, BasicFilterTest, CustomLevelsAndFiltersTest, + HandlerTest, MemoryHandlerTest, ConfigFileTest, SocketHandlerTest, + DatagramHandlerTest, MemoryTest, EncodingTest, WarningsTest, + ConfigDictTest, ManagerTest, FormatterTest, BufferingFormatterTest, + StreamHandlerTest, LogRecordFactoryTest, ChildLoggerTest, + QueueHandlerTest, ShutdownTest, ModuleLevelMiscTest, BasicConfigTest, + LoggerAdapterTest, LoggerTest, SMTPHandlerTest, FileHandlerTest, + RotatingFileHandlerTest, LastResortTest, LogRecordTest, + ExceptionTest, SysLogHandlerTest, HTTPHandlerTest, + NTEventLogHandlerTest, TimedRotatingFileHandlerTest, + UnixSocketHandlerTest, UnixDatagramHandlerTest, UnixSysLogHandlerTest) if __name__ == "__main__": test_main() diff --git a/Lib/test/test_long.py b/Lib/test/test_long.py index 5b9e37a..62e69a9 100644 --- a/Lib/test/test_long.py +++ b/Lib/test/test_long.py @@ -582,8 +582,6 @@ class LongTest(unittest.TestCase): return (x > y) - (x < y) def __eq__(self, other): return self._cmp__(other) == 0 - def __ne__(self, other): - return self._cmp__(other) != 0 def __ge__(self, other): return self._cmp__(other) >= 0 def __gt__(self, other): @@ -1227,8 +1225,5 @@ class LongTest(unittest.TestCase): self.assertEqual(type(value >> shift), int) -def test_main(): - support.run_unittest(LongTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_longexp.py b/Lib/test/test_longexp.py index 1b40d02..f4c463a 100644 --- a/Lib/test/test_longexp.py +++ b/Lib/test/test_longexp.py @@ -1,5 +1,4 @@ import unittest -from test import support class LongExpText(unittest.TestCase): def test_longexp(self): @@ -7,8 +6,5 @@ class LongExpText(unittest.TestCase): l = eval("[" + "2," * REPS + "]") self.assertEqual(len(l), REPS) -def test_main(): - support.run_unittest(LongExpText) - -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_lzma.py b/Lib/test/test_lzma.py index 07fadbd..2d39099 100644 --- a/Lib/test/test_lzma.py +++ b/Lib/test/test_lzma.py @@ -1,4 +1,5 @@ -from io import BytesIO, UnsupportedOperation +import _compression +from io import BytesIO, UnsupportedOperation, DEFAULT_BUFFER_SIZE import os import pickle import random @@ -135,6 +136,97 @@ class CompressorDecompressorTestCase(unittest.TestCase): self.assertTrue(lzd.eof) self.assertEqual(lzd.unused_data, b"") + def test_decompressor_chunks_maxsize(self): + lzd = LZMADecompressor() + max_length = 100 + out = [] + + # Feed first half the input + len_ = len(COMPRESSED_XZ) // 2 + out.append(lzd.decompress(COMPRESSED_XZ[:len_], + max_length=max_length)) + self.assertFalse(lzd.needs_input) + self.assertEqual(len(out[-1]), max_length) + + # Retrieve more data without providing more input + out.append(lzd.decompress(b'', max_length=max_length)) + self.assertFalse(lzd.needs_input) + self.assertEqual(len(out[-1]), max_length) + + # Retrieve more data while providing more input + out.append(lzd.decompress(COMPRESSED_XZ[len_:], + max_length=max_length)) + self.assertLessEqual(len(out[-1]), max_length) + + # Retrieve remaining uncompressed data + while not lzd.eof: + out.append(lzd.decompress(b'', max_length=max_length)) + self.assertLessEqual(len(out[-1]), max_length) + + out = b"".join(out) + self.assertEqual(out, INPUT) + self.assertEqual(lzd.check, lzma.CHECK_CRC64) + self.assertEqual(lzd.unused_data, b"") + + def test_decompressor_inputbuf_1(self): + # Test reusing input buffer after moving existing + # contents to beginning + lzd = LZMADecompressor() + out = [] + + # Create input buffer and fill it + self.assertEqual(lzd.decompress(COMPRESSED_XZ[:100], + max_length=0), b'') + + # Retrieve some results, freeing capacity at beginning + # of input buffer + out.append(lzd.decompress(b'', 2)) + + # Add more data that fits into input buffer after + # moving existing data to beginning + out.append(lzd.decompress(COMPRESSED_XZ[100:105], 15)) + + # Decompress rest of data + out.append(lzd.decompress(COMPRESSED_XZ[105:])) + self.assertEqual(b''.join(out), INPUT) + + def test_decompressor_inputbuf_2(self): + # Test reusing input buffer by appending data at the + # end right away + lzd = LZMADecompressor() + out = [] + + # Create input buffer and empty it + self.assertEqual(lzd.decompress(COMPRESSED_XZ[:200], + max_length=0), b'') + out.append(lzd.decompress(b'')) + + # Fill buffer with new data + out.append(lzd.decompress(COMPRESSED_XZ[200:280], 2)) + + # Append some more data, not enough to require resize + out.append(lzd.decompress(COMPRESSED_XZ[280:300], 2)) + + # Decompress rest of data + out.append(lzd.decompress(COMPRESSED_XZ[300:])) + self.assertEqual(b''.join(out), INPUT) + + def test_decompressor_inputbuf_3(self): + # Test reusing input buffer after extending it + + lzd = LZMADecompressor() + out = [] + + # Create almost full input buffer + out.append(lzd.decompress(COMPRESSED_XZ[:200], 5)) + + # Add even more data to it, requiring resize + out.append(lzd.decompress(COMPRESSED_XZ[200:300], 5)) + + # Decompress rest of data + out.append(lzd.decompress(COMPRESSED_XZ[300:])) + self.assertEqual(b''.join(out), INPUT) + def test_decompressor_unused_data(self): lzd = LZMADecompressor() extra = b"fooblibar" @@ -681,13 +773,13 @@ class FileTestCase(unittest.TestCase): def test_read_multistream_buffer_size_aligned(self): # Test the case where a stream boundary coincides with the end # of the raw read buffer. - saved_buffer_size = lzma._BUFFER_SIZE - lzma._BUFFER_SIZE = len(COMPRESSED_XZ) + saved_buffer_size = _compression.BUFFER_SIZE + _compression.BUFFER_SIZE = len(COMPRESSED_XZ) try: with LZMAFile(BytesIO(COMPRESSED_XZ * 5)) as f: self.assertEqual(f.read(), INPUT * 5) finally: - lzma._BUFFER_SIZE = saved_buffer_size + _compression.BUFFER_SIZE = saved_buffer_size def test_read_trailing_junk(self): with LZMAFile(BytesIO(COMPRESSED_XZ + COMPRESSED_BOGUS)) as f: @@ -738,7 +830,7 @@ class FileTestCase(unittest.TestCase): with LZMAFile(BytesIO(), "w") as f: self.assertRaises(ValueError, f.read) with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: - self.assertRaises(TypeError, f.read, None) + self.assertRaises(TypeError, f.read, float()) def test_read_bad_data(self): with LZMAFile(BytesIO(COMPRESSED_BOGUS)) as f: @@ -834,6 +926,17 @@ class FileTestCase(unittest.TestCase): with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: self.assertListEqual(f.readlines(), lines) + def test_decompress_limited(self): + """Decompressed data buffering should be limited""" + bomb = lzma.compress(bytes(int(2e6)), preset=6) + self.assertLess(len(bomb), _compression.BUFFER_SIZE) + + decomp = LZMAFile(BytesIO(bomb)) + self.assertEqual(bytes(1), decomp.read(1)) + max_decomp = 1 + DEFAULT_BUFFER_SIZE + self.assertLessEqual(decomp._buffer.raw.tell(), max_decomp, + "Excessive amount of data was decompressed") + def test_write(self): with BytesIO() as dst: with LZMAFile(dst, "w") as f: @@ -999,7 +1102,8 @@ class FileTestCase(unittest.TestCase): self.assertRaises(ValueError, f.seek, 0) with LZMAFile(BytesIO(COMPRESSED_XZ)) as f: self.assertRaises(ValueError, f.seek, 0, 3) - self.assertRaises(ValueError, f.seek, 9, ()) + # io.BufferedReader raises TypeError instead of ValueError + self.assertRaises((TypeError, ValueError), f.seek, 9, ()) self.assertRaises(TypeError, f.seek, None) self.assertRaises(TypeError, f.seek, b"derp") diff --git a/Lib/test/test_macpath.py b/Lib/test/test_macpath.py index 22f8491..80bec7a 100644 --- a/Lib/test/test_macpath.py +++ b/Lib/test/test_macpath.py @@ -142,6 +142,8 @@ class MacPathTestCase(unittest.TestCase): class MacCommonTest(test_genericpath.CommonTest, unittest.TestCase): pathmodule = macpath + test_relpath_errors = None + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_mailcap.py b/Lib/test/test_mailcap.py index a4cd09c..22b2fcc 100644 --- a/Lib/test/test_mailcap.py +++ b/Lib/test/test_mailcap.py @@ -213,9 +213,5 @@ class FindmatchTest(unittest.TestCase): self.assertEqual(mailcap.findmatch(*c[0], **c[1]), c[2]) -def test_main(): - test.support.run_unittest(HelperFunctionTest, GetcapsTest, FindmatchTest) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py index 903e12c..c7def9a 100644 --- a/Lib/test/test_marshal.py +++ b/Lib/test/test_marshal.py @@ -193,7 +193,7 @@ class BugsTestCase(unittest.TestCase): head = last = [] # The max stack depth should match the value in Python/marshal.c. if os.name == 'nt' and hasattr(sys, 'gettotalrefcount'): - MAX_MARSHAL_STACK_DEPTH = 1500 + MAX_MARSHAL_STACK_DEPTH = 1000 else: MAX_MARSHAL_STACK_DEPTH = 2000 for i in range(MAX_MARSHAL_STACK_DEPTH - 2): diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index 48f84ba..6c7b99d 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -175,6 +175,14 @@ def parse_testfile(fname): flags ) +# Class providing an __index__ method. +class MyIndexable(object): + def __init__(self, value): + self.value = value + + def __index__(self): + return self.value + class MathTests(unittest.TestCase): def ftest(self, name, value, expected): @@ -422,9 +430,17 @@ class MathTests(unittest.TestCase): self.assertEqual(math.factorial(i), py_factorial(i)) self.assertRaises(ValueError, math.factorial, -1) self.assertRaises(ValueError, math.factorial, -1.0) + self.assertRaises(ValueError, math.factorial, -10**100) + self.assertRaises(ValueError, math.factorial, -1e100) self.assertRaises(ValueError, math.factorial, math.pi) - self.assertRaises(OverflowError, math.factorial, sys.maxsize+1) - self.assertRaises(OverflowError, math.factorial, 10e100) + + # Other implementations may place different upper bounds. + @support.cpython_only + def testFactorialHugeInputs(self): + # Currently raises ValueError for inputs that are too large + # to fit into a C long. + self.assertRaises(OverflowError, math.factorial, 10**100) + self.assertRaises(OverflowError, math.factorial, 1e100) def testFloor(self): self.assertRaises(TypeError, math.floor) @@ -587,6 +603,49 @@ class MathTests(unittest.TestCase): s = msum(vals) self.assertEqual(msum(vals), math.fsum(vals)) + def testGcd(self): + gcd = math.gcd + self.assertEqual(gcd(0, 0), 0) + self.assertEqual(gcd(1, 0), 1) + self.assertEqual(gcd(-1, 0), 1) + self.assertEqual(gcd(0, 1), 1) + self.assertEqual(gcd(0, -1), 1) + self.assertEqual(gcd(7, 1), 1) + self.assertEqual(gcd(7, -1), 1) + self.assertEqual(gcd(-23, 15), 1) + self.assertEqual(gcd(120, 84), 12) + self.assertEqual(gcd(84, -120), 12) + self.assertEqual(gcd(1216342683557601535506311712, + 436522681849110124616458784), 32) + c = 652560 + x = 434610456570399902378880679233098819019853229470286994367836600566 + y = 1064502245825115327754847244914921553977 + a = x * c + b = y * c + self.assertEqual(gcd(a, b), c) + self.assertEqual(gcd(b, a), c) + self.assertEqual(gcd(-a, b), c) + self.assertEqual(gcd(b, -a), c) + self.assertEqual(gcd(a, -b), c) + self.assertEqual(gcd(-b, a), c) + self.assertEqual(gcd(-a, -b), c) + self.assertEqual(gcd(-b, -a), c) + c = 576559230871654959816130551884856912003141446781646602790216406874 + a = x * c + b = y * c + self.assertEqual(gcd(a, b), c) + self.assertEqual(gcd(b, a), c) + self.assertEqual(gcd(-a, b), c) + self.assertEqual(gcd(b, -a), c) + self.assertEqual(gcd(a, -b), c) + self.assertEqual(gcd(-b, a), c) + self.assertEqual(gcd(-a, -b), c) + self.assertEqual(gcd(-b, -a), c) + + self.assertRaises(TypeError, gcd, 120.0, 84) + self.assertRaises(TypeError, gcd, 120, 84.0) + self.assertEqual(gcd(MyIndexable(120), MyIndexable(84)), 12) + def testHypot(self): self.assertRaises(TypeError, math.hypot) self.ftest('hypot(0,0)', math.hypot(0,0), 0) @@ -975,6 +1034,17 @@ class MathTests(unittest.TestCase): self.assertFalse(math.isinf(0.)) self.assertFalse(math.isinf(1.)) + @requires_IEEE_754 + def test_nan_constant(self): + self.assertTrue(math.isnan(math.nan)) + + @requires_IEEE_754 + def test_inf_constant(self): + self.assertTrue(math.isinf(math.inf)) + self.assertGreater(math.inf, 0.0) + self.assertEqual(math.inf, float("inf")) + self.assertEqual(-math.inf, float("-inf")) + # RED_FLAG 16-Oct-2000 Tim # While 2.0 is more consistent about exceptions than previous releases, it # still fails this part of the test on some platforms. For now, we only @@ -1096,10 +1166,131 @@ class MathTests(unittest.TestCase): '\n '.join(failures)) +class IsCloseTests(unittest.TestCase): + isclose = math.isclose # sublcasses should override this + + def assertIsClose(self, a, b, *args, **kwargs): + self.assertTrue(self.isclose(a, b, *args, **kwargs), + msg="%s and %s should be close!" % (a, b)) + + def assertIsNotClose(self, a, b, *args, **kwargs): + self.assertFalse(self.isclose(a, b, *args, **kwargs), + msg="%s and %s should not be close!" % (a, b)) + + def assertAllClose(self, examples, *args, **kwargs): + for a, b in examples: + self.assertIsClose(a, b, *args, **kwargs) + + def assertAllNotClose(self, examples, *args, **kwargs): + for a, b in examples: + self.assertIsNotClose(a, b, *args, **kwargs) + + def test_negative_tolerances(self): + # ValueError should be raised if either tolerance is less than zero + with self.assertRaises(ValueError): + self.assertIsClose(1, 1, rel_tol=-1e-100) + with self.assertRaises(ValueError): + self.assertIsClose(1, 1, rel_tol=1e-100, abs_tol=-1e10) + + def test_identical(self): + # identical values must test as close + identical_examples = [(2.0, 2.0), + (0.1e200, 0.1e200), + (1.123e-300, 1.123e-300), + (12345, 12345.0), + (0.0, -0.0), + (345678, 345678)] + self.assertAllClose(identical_examples, rel_tol=0.0, abs_tol=0.0) + + def test_eight_decimal_places(self): + # examples that are close to 1e-8, but not 1e-9 + eight_decimal_places_examples = [(1e8, 1e8 + 1), + (-1e-8, -1.000000009e-8), + (1.12345678, 1.12345679)] + self.assertAllClose(eight_decimal_places_examples, rel_tol=1e-8) + self.assertAllNotClose(eight_decimal_places_examples, rel_tol=1e-9) + + def test_near_zero(self): + # values close to zero + near_zero_examples = [(1e-9, 0.0), + (-1e-9, 0.0), + (-1e-150, 0.0)] + # these should not be close to any rel_tol + self.assertAllNotClose(near_zero_examples, rel_tol=0.9) + # these should be close to abs_tol=1e-8 + self.assertAllClose(near_zero_examples, abs_tol=1e-8) + + def test_identical_infinite(self): + # these are close regardless of tolerance -- i.e. they are equal + self.assertIsClose(INF, INF) + self.assertIsClose(INF, INF, abs_tol=0.0) + self.assertIsClose(NINF, NINF) + self.assertIsClose(NINF, NINF, abs_tol=0.0) + + def test_inf_ninf_nan(self): + # these should never be close (following IEEE 754 rules for equality) + not_close_examples = [(NAN, NAN), + (NAN, 1e-100), + (1e-100, NAN), + (INF, NAN), + (NAN, INF), + (INF, NINF), + (INF, 1.0), + (1.0, INF), + (INF, 1e308), + (1e308, INF)] + # use largest reasonable tolerance + self.assertAllNotClose(not_close_examples, abs_tol=0.999999999999999) + + def test_zero_tolerance(self): + # test with zero tolerance + zero_tolerance_close_examples = [(1.0, 1.0), + (-3.4, -3.4), + (-1e-300, -1e-300)] + self.assertAllClose(zero_tolerance_close_examples, rel_tol=0.0) + + zero_tolerance_not_close_examples = [(1.0, 1.000000000000001), + (0.99999999999999, 1.0), + (1.0e200, .999999999999999e200)] + self.assertAllNotClose(zero_tolerance_not_close_examples, rel_tol=0.0) + + def test_assymetry(self): + # test the assymetry example from PEP 485 + self.assertAllClose([(9, 10), (10, 9)], rel_tol=0.1) + + def test_integers(self): + # test with integer values + integer_examples = [(100000001, 100000000), + (123456789, 123456788)] + + self.assertAllClose(integer_examples, rel_tol=1e-8) + self.assertAllNotClose(integer_examples, rel_tol=1e-9) + + def test_decimals(self): + # test with Decimal values + from decimal import Decimal + + decimal_examples = [(Decimal('1.00000001'), Decimal('1.0')), + (Decimal('1.00000001e-20'), Decimal('1.0e-20')), + (Decimal('1.00000001e-100'), Decimal('1.0e-100'))] + self.assertAllClose(decimal_examples, rel_tol=1e-8) + self.assertAllNotClose(decimal_examples, rel_tol=1e-9) + + def test_fractions(self): + # test with Fraction values + from fractions import Fraction + + # could use some more examples here! + fraction_examples = [(Fraction(1, 100000000) + 1, Fraction(1))] + self.assertAllClose(fraction_examples, rel_tol=1e-8) + self.assertAllNotClose(fraction_examples, rel_tol=1e-9) + + def test_main(): from doctest import DocFileSuite suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(MathTests)) + suite.addTest(unittest.makeSuite(IsCloseTests)) suite.addTest(DocFileSuite("ieee754.txt")) run_unittest(suite) diff --git a/Lib/test/test_memoryio.py b/Lib/test/test_memoryio.py index 7ce95b9..44d66c3 100644 --- a/Lib/test/test_memoryio.py +++ b/Lib/test/test_memoryio.py @@ -9,6 +9,7 @@ from test import support import io import _pyio as pyio import pickle +import sys class MemorySeekTestMixin: @@ -165,6 +166,10 @@ class MemoryTestMixin: memio.seek(0) self.assertEqual(memio.read(None), buf) self.assertRaises(TypeError, memio.read, '') + memio.seek(len(buf) + 1) + self.assertEqual(memio.read(1), self.EOF) + memio.seek(len(buf) + 1) + self.assertEqual(memio.read(), self.EOF) memio.close() self.assertRaises(ValueError, memio.read) @@ -184,6 +189,9 @@ class MemoryTestMixin: self.assertEqual(memio.readline(-1), buf) memio.seek(0) self.assertEqual(memio.readline(0), self.EOF) + # Issue #24989: Buffer overread + memio.seek(len(buf) * 2 + 1) + self.assertEqual(memio.readline(), self.EOF) buf = self.buftype("1234567890\n") memio = self.ioclass((buf * 3)[:-1]) @@ -216,6 +224,9 @@ class MemoryTestMixin: memio.seek(0) self.assertEqual(memio.readlines(None), [buf] * 10) self.assertRaises(TypeError, memio.readlines, '') + # Issue #24989: Buffer overread + memio.seek(len(buf) * 10 + 1) + self.assertEqual(memio.readlines(), []) memio.close() self.assertRaises(ValueError, memio.readlines) @@ -237,6 +248,9 @@ class MemoryTestMixin: self.assertEqual(line, buf) i += 1 self.assertEqual(i, 10) + # Issue #24989: Buffer overread + memio.seek(len(buf) * 10 + 1) + self.assertEqual(list(memio), []) memio = self.ioclass(buf * 2) memio.close() self.assertRaises(ValueError, memio.__next__) @@ -718,12 +732,56 @@ class CBytesIOTest(PyBytesIOTest): @support.cpython_only def test_sizeof(self): - basesize = support.calcobjsize('P2nN2Pn') + basesize = support.calcobjsize('P2n2Pn') check = self.check_sizeof self.assertEqual(object.__sizeof__(io.BytesIO()), basesize) check(io.BytesIO(), basesize ) - check(io.BytesIO(b'a'), basesize + 1 + 1 ) - check(io.BytesIO(b'a' * 1000), basesize + 1000 + 1 ) + check(io.BytesIO(b'a' * 1000), basesize + sys.getsizeof(b'a' * 1000)) + + # Various tests of copy-on-write behaviour for BytesIO. + + def _test_cow_mutation(self, mutation): + # Common code for all BytesIO copy-on-write mutation tests. + imm = b' ' * 1024 + old_rc = sys.getrefcount(imm) + memio = self.ioclass(imm) + self.assertEqual(sys.getrefcount(imm), old_rc + 1) + mutation(memio) + self.assertEqual(sys.getrefcount(imm), old_rc) + + @support.cpython_only + def test_cow_truncate(self): + # Ensure truncate causes a copy. + def mutation(memio): + memio.truncate(1) + self._test_cow_mutation(mutation) + + @support.cpython_only + def test_cow_write(self): + # Ensure write that would not cause a resize still results in a copy. + def mutation(memio): + memio.seek(0) + memio.write(b'foo') + self._test_cow_mutation(mutation) + + @support.cpython_only + def test_cow_setstate(self): + # __setstate__ should cause buffer to be released. + memio = self.ioclass(b'foooooo') + state = memio.__getstate__() + def mutation(memio): + memio.__setstate__(state) + self._test_cow_mutation(mutation) + + @support.cpython_only + def test_cow_mutable(self): + # BytesIO should accept only Bytes for copy-on-write sharing, since + # arbitrary buffer-exporting objects like bytearray() aren't guaranteed + # to be immutable. + ba = bytearray(1024) + old_rc = sys.getrefcount(ba) + memio = self.ioclass(ba) + self.assertEqual(sys.getrefcount(ba), old_rc) class CStringIOTest(PyStringIOTest): ioclass = io.StringIO @@ -783,10 +841,5 @@ class CStringIOPickleTest(PyStringIOPickleTest): pass -def test_main(): - tests = [PyBytesIOTest, PyStringIOTest, CBytesIOTest, CStringIOTest, - PyStringIOPickleTest, CStringIOPickleTest] - support.run_unittest(*tests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_memoryview.py b/Lib/test/test_memoryview.py index 4bc3133..da01a84 100644 --- a/Lib/test/test_memoryview.py +++ b/Lib/test/test_memoryview.py @@ -369,12 +369,12 @@ class AbstractMemoryTests: d = memoryview(b) del b - + self.assertEqual(c[0], 256) self.assertEqual(d[0], 256) self.assertEqual(c.format, "H") self.assertEqual(d.format, "H") - + _ = m.cast('I') self.assertEqual(c[0], 256) self.assertEqual(d[0], 256) @@ -492,8 +492,26 @@ class ArrayMemorySliceSliceTest(unittest.TestCase, pass -def test_main(): - test.support.run_unittest(__name__) +class OtherTest(unittest.TestCase): + def test_ctypes_cast(self): + # Issue 15944: Allow all source formats when casting to bytes. + ctypes = test.support.import_module("ctypes") + p6 = bytes(ctypes.c_double(0.6)) + + d = ctypes.c_double() + m = memoryview(d).cast("B") + m[:2] = p6[:2] + m[2:] = p6[2:] + self.assertEqual(d.value, 0.6) + + for format in "Bbc": + with self.subTest(format): + d = ctypes.c_double() + m = memoryview(d).cast(format) + m[:2] = memoryview(p6).cast(format)[:2] + m[2:] = memoryview(p6).cast(format)[2:] + self.assertEqual(d.value, 0.6) + if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_mimetypes.py b/Lib/test/test_mimetypes.py index 0b53032..6856593 100644 --- a/Lib/test/test_mimetypes.py +++ b/Lib/test/test_mimetypes.py @@ -101,11 +101,5 @@ class Win32MimeTypesTestCase(unittest.TestCase): eq(self.db.guess_type("image.jpg"), ("image/jpeg", None)) eq(self.db.guess_type("image.png"), ("image/png", None)) -def test_main(): - support.run_unittest(MimeTypesTestCase, - Win32MimeTypesTestCase - ) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py index 05df6e9..e8ca497 100644 --- a/Lib/test/test_minidom.py +++ b/Lib/test/test_minidom.py @@ -1,7 +1,7 @@ # test for xml.dom.minidom import pickle -from test.support import run_unittest, findfile +from test.support import findfile import unittest import xml.dom.minidom @@ -49,6 +49,21 @@ class MinidomTest(unittest.TestCase): t = node.wholeText self.confirm(t == s, "looking for %r, found %r" % (s, t)) + def testDocumentAsyncAttr(self): + doc = Document() + self.assertFalse(doc.async_) + with self.assertWarns(DeprecationWarning): + self.assertFalse(getattr(doc, 'async', True)) + with self.assertWarns(DeprecationWarning): + setattr(doc, 'async', True) + with self.assertWarns(DeprecationWarning): + self.assertTrue(getattr(doc, 'async', False)) + self.assertTrue(doc.async_) + + self.assertFalse(Document.async_) + with self.assertWarns(DeprecationWarning): + self.assertFalse(getattr(Document, 'async', True)) + def testParseFromBinaryFile(self): with open(tstfile, 'rb') as file: dom = parse(file) @@ -1545,8 +1560,5 @@ class MinidomTest(unittest.TestCase): pi = doc.createProcessingInstruction("y", "z") pi.nodeValue = "crash" -def test_main(): - run_unittest(MinidomTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_mmap.py b/Lib/test/test_mmap.py index ad93a59..0f25742 100644 --- a/Lib/test/test_mmap.py +++ b/Lib/test/test_mmap.py @@ -282,6 +282,7 @@ class MmapTests(unittest.TestCase): self.assertEqual(m.find(b'one', 1), 8) self.assertEqual(m.find(b'one', 1, -1), 8) self.assertEqual(m.find(b'one', 1, -2), -1) + self.assertEqual(m.find(bytearray(b'one')), 0) def test_rfind(self): @@ -300,6 +301,7 @@ class MmapTests(unittest.TestCase): self.assertEqual(m.rfind(b'one', 0, -2), 0) self.assertEqual(m.rfind(b'one', 1, -1), 8) self.assertEqual(m.rfind(b'one', 1, -2), -1) + self.assertEqual(m.rfind(bytearray(b'one')), 8) def test_double_close(self): @@ -601,8 +603,10 @@ class MmapTests(unittest.TestCase): m.write(b"bar") self.assertEqual(m.tell(), 6) self.assertEqual(m[:], b"012bar6789") - m.seek(8) - self.assertRaises(ValueError, m.write, b"bar") + m.write(bytearray(b"baz")) + self.assertEqual(m.tell(), 9) + self.assertEqual(m[:], b"012barbaz9") + self.assertRaises(ValueError, m.write, b"ba") def test_non_ascii_byte(self): for b in (129, 200, 255): # > 128 diff --git a/Lib/test/test_module.py b/Lib/test/test_module.py index 1230293..48ab0b4 100644 --- a/Lib/test/test_module.py +++ b/Lib/test/test_module.py @@ -1,8 +1,8 @@ # Test the module type import unittest import weakref -from test.support import run_unittest, gc_collect -from test.script_helper import assert_python_ok +from test.support import gc_collect +from test.support.script_helper import assert_python_ok import sys ModuleType = type(sys) @@ -30,6 +30,22 @@ class ModuleTests(unittest.TestCase): pass self.assertEqual(foo.__doc__, ModuleType.__doc__) + def test_unintialized_missing_getattr(self): + # Issue 8297 + # test the text in the AttributeError of an uninitialized module + foo = ModuleType.__new__(ModuleType) + self.assertRaisesRegex( + AttributeError, "module has no attribute 'not_here'", + getattr, foo, "not_here") + + def test_missing_getattr(self): + # Issue 8297 + # test the text in the AttributeError + foo = ModuleType("foo") + self.assertRaisesRegex( + AttributeError, "module 'foo' has no attribute 'not_here'", + getattr, foo, "not_here") + def test_no_docstring(self): # Regularly initialized module, no docstring foo = ModuleType("foo") @@ -211,12 +227,16 @@ a = A(destroyed)""" b"len = len", b"shutil.rmtree = rmtree"}) - # frozen and namespace module reprs are tested in importlib. - + def test_descriptor_errors_propogate(self): + class Descr: + def __get__(self, o, t): + raise RuntimeError + class M(ModuleType): + melon = Descr() + self.assertRaises(RuntimeError, getattr, M("mymod"), "melon") -def test_main(): - run_unittest(ModuleTests) + # frozen and namespace module reprs are tested in importlib. if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_msilib.py b/Lib/test/test_msilib.py index ccdaec7..8ef334f 100644 --- a/Lib/test/test_msilib.py +++ b/Lib/test/test_msilib.py @@ -1,7 +1,7 @@ """ Test suite for the code in msilib """ import unittest import os -from test.support import run_unittest, import_module +from test.support import import_module msilib = import_module('msilib') class Test_make_id(unittest.TestCase): @@ -39,8 +39,5 @@ class Test_make_id(unittest.TestCase): msilib.make_id(".s\x82o?*+rt"), "_.s_o___rt") -def test_main(): - run_unittest(__name__) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_multiprocessing_main_handling.py b/Lib/test/test_multiprocessing_main_handling.py index de5f782..52273ea 100644 --- a/Lib/test/test_multiprocessing_main_handling.py +++ b/Lib/test/test_multiprocessing_main_handling.py @@ -13,10 +13,9 @@ import os import os.path import py_compile -from test.script_helper import ( +from test.support.script_helper import ( make_pkg, make_script, make_zip_pkg, make_zip_script, - assert_python_ok, assert_python_failure, temp_dir, - spawn_python, kill_python) + assert_python_ok, assert_python_failure, spawn_python, kill_python) # Look up which start methods are available to test import multiprocessing @@ -157,12 +156,12 @@ class MultiProcessingCmdLineMixin(): self._check_output(script_name, rc, out, err) def test_basic_script(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script') self._check_script(script_name) def test_basic_script_no_suffix(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script', omit_suffix=True) self._check_script(script_name) @@ -173,7 +172,7 @@ class MultiProcessingCmdLineMixin(): # a workaround for that case # See https://github.com/ipython/ipython/issues/4698 source = test_source_main_skipped_in_children - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'ipython', source=source) self._check_script(script_name) @@ -183,7 +182,7 @@ class MultiProcessingCmdLineMixin(): self._check_script(script_no_suffix) def test_script_compiled(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, 'script') py_compile.compile(script_name, doraise=True) os.remove(script_name) @@ -192,14 +191,14 @@ class MultiProcessingCmdLineMixin(): def test_directory(self): source = self.main_in_children_source - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__', source=source) self._check_script(script_dir) def test_directory_compiled(self): source = self.main_in_children_source - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__', source=source) py_compile.compile(script_name, doraise=True) @@ -209,7 +208,7 @@ class MultiProcessingCmdLineMixin(): def test_zipfile(self): source = self.main_in_children_source - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__', source=source) zip_name, run_name = make_zip_script(script_dir, 'test_zip', script_name) @@ -217,7 +216,7 @@ class MultiProcessingCmdLineMixin(): def test_zipfile_compiled(self): source = self.main_in_children_source - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: script_name = _make_test_script(script_dir, '__main__', source=source) compiled_name = py_compile.compile(script_name, doraise=True) @@ -225,7 +224,7 @@ class MultiProcessingCmdLineMixin(): self._check_script(zip_name) def test_module_in_package(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) script_name = _make_test_script(pkg_dir, 'check_sibling') @@ -234,20 +233,20 @@ class MultiProcessingCmdLineMixin(): self._check_script(launch_name) def test_module_in_package_in_zipfile(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script') launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.script', zip_name) self._check_script(launch_name) def test_module_in_subpackage_in_zipfile(self): - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script', depth=2) launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.test_pkg.script', zip_name) self._check_script(launch_name) def test_package(self): source = self.main_in_children_source - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) script_name = _make_test_script(pkg_dir, '__main__', @@ -257,7 +256,7 @@ class MultiProcessingCmdLineMixin(): def test_package_compiled(self): source = self.main_in_children_source - with temp_dir() as script_dir: + with support.temp_dir() as script_dir: pkg_dir = os.path.join(script_dir, 'test_pkg') make_pkg(pkg_dir) script_name = _make_test_script(pkg_dir, '__main__', diff --git a/Lib/test/test_nis.py b/Lib/test/test_nis.py index a3a3c26..387a4e7 100644 --- a/Lib/test/test_nis.py +++ b/Lib/test/test_nis.py @@ -36,8 +36,5 @@ class NisTests(unittest.TestCase): if done: break -def test_main(): - support.run_unittest(NisTests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_normalization.py b/Lib/test/test_normalization.py index 5dac5db..30fa612 100644 --- a/Lib/test/test_normalization.py +++ b/Lib/test/test_normalization.py @@ -1,4 +1,4 @@ -from test.support import run_unittest, open_urlresource +from test.support import open_urlresource import unittest from http.client import HTTPException @@ -97,8 +97,5 @@ class NormalizationTest(unittest.TestCase): normalize('NFC', '\ud55c\uae00') -def test_main(): - run_unittest(NormalizationTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py index dacddde..580f203 100644 --- a/Lib/test/test_ntpath.py +++ b/Lib/test/test_ntpath.py @@ -330,6 +330,75 @@ class TestNtpath(unittest.TestCase): tester('ntpath.relpath("/a/b", "/a/b")', '.') tester('ntpath.relpath("c:/foo", "C:/FOO")', '.') + def test_commonpath(self): + def check(paths, expected): + tester(('ntpath.commonpath(%r)' % paths).replace('\\\\', '\\'), + expected) + def check_error(exc, paths): + self.assertRaises(exc, ntpath.commonpath, paths) + self.assertRaises(exc, ntpath.commonpath, + [os.fsencode(p) for p in paths]) + + self.assertRaises(ValueError, ntpath.commonpath, []) + check_error(ValueError, ['C:\\Program Files', 'Program Files']) + check_error(ValueError, ['C:\\Program Files', 'C:Program Files']) + check_error(ValueError, ['\\Program Files', 'Program Files']) + check_error(ValueError, ['Program Files', 'C:\\Program Files']) + check(['C:\\Program Files'], 'C:\\Program Files') + check(['C:\\Program Files', 'C:\\Program Files'], 'C:\\Program Files') + check(['C:\\Program Files\\', 'C:\\Program Files'], + 'C:\\Program Files') + check(['C:\\Program Files\\', 'C:\\Program Files\\'], + 'C:\\Program Files') + check(['C:\\\\Program Files', 'C:\\Program Files\\\\'], + 'C:\\Program Files') + check(['C:\\.\\Program Files', 'C:\\Program Files\\.'], + 'C:\\Program Files') + check(['C:\\', 'C:\\bin'], 'C:\\') + check(['C:\\Program Files', 'C:\\bin'], 'C:\\') + check(['C:\\Program Files', 'C:\\Program Files\\Bar'], + 'C:\\Program Files') + check(['C:\\Program Files\\Foo', 'C:\\Program Files\\Bar'], + 'C:\\Program Files') + check(['C:\\Program Files', 'C:\\Projects'], 'C:\\') + check(['C:\\Program Files\\', 'C:\\Projects'], 'C:\\') + + check(['C:\\Program Files\\Foo', 'C:/Program Files/Bar'], + 'C:\\Program Files') + check(['C:\\Program Files\\Foo', 'c:/program files/bar'], + 'C:\\Program Files') + check(['c:/program files/bar', 'C:\\Program Files\\Foo'], + 'c:\\program files') + + check_error(ValueError, ['C:\\Program Files', 'D:\\Program Files']) + + check(['spam'], 'spam') + check(['spam', 'spam'], 'spam') + check(['spam', 'alot'], '') + check(['and\\jam', 'and\\spam'], 'and') + check(['and\\\\jam', 'and\\spam\\\\'], 'and') + check(['and\\.\\jam', '.\\and\\spam'], 'and') + check(['and\\jam', 'and\\spam', 'alot'], '') + check(['and\\jam', 'and\\spam', 'and'], 'and') + check(['C:and\\jam', 'C:and\\spam'], 'C:and') + + check([''], '') + check(['', 'spam\\alot'], '') + check_error(ValueError, ['', '\\spam\\alot']) + + self.assertRaises(TypeError, ntpath.commonpath, + [b'C:\\Program Files', 'C:\\Program Files\\Foo']) + self.assertRaises(TypeError, ntpath.commonpath, + [b'C:\\Program Files', 'Program Files\\Foo']) + self.assertRaises(TypeError, ntpath.commonpath, + [b'Program Files', 'C:\\Program Files\\Foo']) + self.assertRaises(TypeError, ntpath.commonpath, + ['C:\\Program Files', b'C:\\Program Files\\Foo']) + self.assertRaises(TypeError, ntpath.commonpath, + ['C:\\Program Files', b'Program Files\\Foo']) + self.assertRaises(TypeError, ntpath.commonpath, + ['Program Files', b'C:\\Program Files\\Foo']) + def test_sameopenfile(self): with TemporaryFile() as tf1, TemporaryFile() as tf2: # Make sure the same file is really the same diff --git a/Lib/test/test_numeric_tower.py b/Lib/test/test_numeric_tower.py index 3423d4e..c54dedb 100644 --- a/Lib/test/test_numeric_tower.py +++ b/Lib/test/test_numeric_tower.py @@ -5,7 +5,6 @@ import random import math import sys import operator -from test.support import run_unittest from decimal import Decimal as D from fractions import Fraction as F @@ -199,8 +198,5 @@ class ComparisonTest(unittest.TestCase): self.assertRaises(TypeError, op, v, z) -def test_main(): - run_unittest(HashTest, ComparisonTest) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_opcodes.py b/Lib/test/test_opcodes.py index f510bac..6ef93d9 100644 --- a/Lib/test/test_opcodes.py +++ b/Lib/test/test_opcodes.py @@ -1,6 +1,5 @@ # Python test set -- part 2, opcodes -from test.support import run_unittest import unittest class OpcodeTest(unittest.TestCase): @@ -105,8 +104,5 @@ class OpcodeTest(unittest.TestCase): self.assertEqual(MyString() % 3, 42) -def test_main(): - run_unittest(OpcodeTest) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_openpty.py b/Lib/test/test_openpty.py index 4785107..3f46a60 100644 --- a/Lib/test/test_openpty.py +++ b/Lib/test/test_openpty.py @@ -1,7 +1,6 @@ # Test to see if openpty works. (But don't worry if it isn't available.) import os, unittest -from test.support import run_unittest if not hasattr(os, "openpty"): raise unittest.SkipTest("os.openpty() not available.") @@ -18,8 +17,5 @@ class OpenptyTest(unittest.TestCase): os.write(slave, b'Ping!') self.assertEqual(os.read(master, 1024), b'Ping!') -def test_main(): - run_unittest(OpenptyTest) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py index ab58a98..da9c8ef 100644 --- a/Lib/test/test_operator.py +++ b/Lib/test/test_operator.py @@ -1,4 +1,6 @@ import unittest +import pickle +import sys from test import support @@ -203,6 +205,15 @@ class OperatorTestCase: self.assertRaises(TypeError, operator.mul, None, None) self.assertTrue(operator.mul(5, 2) == 10) + def test_matmul(self): + operator = self.module + self.assertRaises(TypeError, operator.matmul) + self.assertRaises(TypeError, operator.matmul, 42, 42) + class M: + def __matmul__(self, other): + return other - 1 + self.assertEqual(M() @ 42, 41) + def test_neg(self): operator = self.module self.assertRaises(TypeError, operator.neg) @@ -387,6 +398,7 @@ class OperatorTestCase: def test_methodcaller(self): operator = self.module self.assertRaises(TypeError, operator.methodcaller) + self.assertRaises(TypeError, operator.methodcaller, 12) class A: def foo(self, *args, **kwds): return args[0] + args[1] @@ -416,6 +428,7 @@ class OperatorTestCase: def __ilshift__ (self, other): return "ilshift" def __imod__ (self, other): return "imod" def __imul__ (self, other): return "imul" + def __imatmul__ (self, other): return "imatmul" def __ior__ (self, other): return "ior" def __ipow__ (self, other): return "ipow" def __irshift__ (self, other): return "irshift" @@ -430,6 +443,7 @@ class OperatorTestCase: self.assertEqual(operator.ilshift (c, 5), "ilshift") self.assertEqual(operator.imod (c, 5), "imod") self.assertEqual(operator.imul (c, 5), "imul") + self.assertEqual(operator.imatmul (c, 5), "imatmul") self.assertEqual(operator.ior (c, 5), "ior") self.assertEqual(operator.ipow (c, 5), "ipow") self.assertEqual(operator.irshift (c, 5), "irshift") @@ -480,5 +494,107 @@ class PyOperatorTestCase(OperatorTestCase, unittest.TestCase): class COperatorTestCase(OperatorTestCase, unittest.TestCase): module = c_operator + +class OperatorPickleTestCase: + def copy(self, obj, proto): + with support.swap_item(sys.modules, 'operator', self.module): + pickled = pickle.dumps(obj, proto) + with support.swap_item(sys.modules, 'operator', self.module2): + return pickle.loads(pickled) + + def test_attrgetter(self): + attrgetter = self.module.attrgetter + class A: + pass + a = A() + a.x = 'X' + a.y = 'Y' + a.z = 'Z' + a.t = A() + a.t.u = A() + a.t.u.v = 'V' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + f = attrgetter('x') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # multiple gets + f = attrgetter('x', 'y', 'z') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # recursive gets + f = attrgetter('t.u.v') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + + def test_itemgetter(self): + itemgetter = self.module.itemgetter + a = 'ABCDE' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + f = itemgetter(2) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # multiple gets + f = itemgetter(2, 0, 4) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + + def test_methodcaller(self): + methodcaller = self.module.methodcaller + class A: + def foo(self, *args, **kwds): + return args[0] + args[1] + def bar(self, f=42): + return f + def baz(*args, **kwds): + return kwds['name'], kwds['self'] + a = A() + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + f = methodcaller('bar') + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # positional args + f = methodcaller('foo', 1, 2) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + # keyword args + f = methodcaller('bar', f=5) + f2 = self.copy(f, proto) + self.assertEqual(repr(f2), repr(f)) + self.assertEqual(f2(a), f(a)) + f = methodcaller('baz', self='eggs', name='spam') + f2 = self.copy(f, proto) + # Can't test repr consistently with multiple keyword args + self.assertEqual(f2(a), f(a)) + +class PyPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = py_operator + module2 = py_operator + +@unittest.skipUnless(c_operator, 'requires _operator') +class PyCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = py_operator + module2 = c_operator + +@unittest.skipUnless(c_operator, 'requires _operator') +class CPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = c_operator + module2 = py_operator + +@unittest.skipUnless(c_operator, 'requires _operator') +class CCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase): + module = c_operator + module2 = c_operator + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index 54dd9da2..da5a130 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -9,6 +9,7 @@ import contextlib import decimal import errno import fractions +import getpass import itertools import locale import mmap @@ -40,8 +41,34 @@ try: import fcntl except ImportError: fcntl = None +try: + import _winapi +except ImportError: + _winapi = None +try: + import grp + groups = [g.gr_gid for g in grp.getgrall() if getpass.getuser() in g.gr_mem] + if hasattr(os, 'getgid'): + process_gid = os.getgid() + if process_gid not in groups: + groups.append(process_gid) +except ImportError: + groups = [] +try: + import pwd + all_users = [u.pw_uid for u in pwd.getpwall()] +except ImportError: + all_users = [] +try: + from _testcapi import INT_MAX, PY_SSIZE_T_MAX +except ImportError: + INT_MAX = PY_SSIZE_T_MAX = sys.maxsize + +from test.support.script_helper import assert_python_ok -from test.script_helper import assert_python_ok +root_in_posix = False +if hasattr(os, 'geteuid'): + root_in_posix = (os.geteuid() == 0) # Detect whether we're on a Linux system that uses the (now outdated # and unmaintained) linuxthreads threading library. There's an issue @@ -106,6 +133,26 @@ class FileTests(unittest.TestCase): self.assertEqual(type(s), bytes) self.assertEqual(s, b"spam") + @support.cpython_only + # Skip the test on 32-bit platforms: the number of bytes must fit in a + # Py_ssize_t type + @unittest.skipUnless(INT_MAX < PY_SSIZE_T_MAX, + "needs INT_MAX < PY_SSIZE_T_MAX") + @support.bigmemtest(size=INT_MAX + 10, memuse=1, dry_run=False) + def test_large_read(self, size): + with open(support.TESTFN, "wb") as fp: + fp.write(b'test') + self.addCleanup(support.unlink, support.TESTFN) + + # Issue #21932: Make sure that os.read() does not raise an + # OverflowError for size larger than INT_MAX + with open(support.TESTFN, "rb") as fp: + data = os.read(fp.fileno(), size) + + # The test does not try to read more than 2 GB at once because the + # operating system is free to return less bytes than requested. + self.assertEqual(data, b'test') + def test_write(self): # os.write() accepts bytes- and buffer-like objects but not strings fd = os.open(support.TESTFN, os.O_CREAT | os.O_WRONLY) @@ -363,6 +410,28 @@ class StatAttributeTests(unittest.TestCase): os.stat(r) self.assertEqual(ctx.exception.errno, errno.EBADF) + def check_file_attributes(self, result): + self.assertTrue(hasattr(result, 'st_file_attributes')) + self.assertTrue(isinstance(result.st_file_attributes, int)) + self.assertTrue(0 <= result.st_file_attributes <= 0xFFFFFFFF) + + @unittest.skipUnless(sys.platform == "win32", + "st_file_attributes is Win32 specific") + def test_file_attributes(self): + # test file st_file_attributes (FILE_ATTRIBUTE_DIRECTORY not set) + result = os.stat(self.fname) + self.check_file_attributes(result) + self.assertEqual( + result.st_file_attributes & stat.FILE_ATTRIBUTE_DIRECTORY, + 0) + + # test directory st_file_attributes (FILE_ATTRIBUTE_DIRECTORY set) + result = os.stat(support.TESTFN) + self.check_file_attributes(result) + self.assertEqual( + result.st_file_attributes & stat.FILE_ATTRIBUTE_DIRECTORY, + stat.FILE_ATTRIBUTE_DIRECTORY) + class UtimeTests(unittest.TestCase): def setUp(self): @@ -971,17 +1040,6 @@ class MakedirTests(unittest.TestCase): os.makedirs(path, mode=mode, exist_ok=True) os.umask(old_mask) - @unittest.skipUnless(hasattr(os, 'chown'), 'test needs os.chown') - def test_chown_uid_gid_arguments_must_be_index(self): - stat = os.stat(support.TESTFN) - uid = stat.st_uid - gid = stat.st_gid - for value in (-1.0, -1j, decimal.Decimal(-1), fractions.Fraction(-2, 2)): - self.assertRaises(TypeError, os.chown, support.TESTFN, value, gid) - self.assertRaises(TypeError, os.chown, support.TESTFN, uid, value) - self.assertIsNone(os.chown(support.TESTFN, uid, gid)) - self.assertIsNone(os.chown(support.TESTFN, -1, -1)) - def test_exist_ok_s_isgid_directory(self): path = os.path.join(support.TESTFN, 'dir1') S_ISGID = stat.S_ISGID @@ -1032,6 +1090,60 @@ class MakedirTests(unittest.TestCase): os.removedirs(path) +@unittest.skipUnless(hasattr(os, 'chown'), "Test needs chown") +class ChownFileTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + os.mkdir(support.TESTFN) + + def test_chown_uid_gid_arguments_must_be_index(self): + stat = os.stat(support.TESTFN) + uid = stat.st_uid + gid = stat.st_gid + for value in (-1.0, -1j, decimal.Decimal(-1), fractions.Fraction(-2, 2)): + self.assertRaises(TypeError, os.chown, support.TESTFN, value, gid) + self.assertRaises(TypeError, os.chown, support.TESTFN, uid, value) + self.assertIsNone(os.chown(support.TESTFN, uid, gid)) + self.assertIsNone(os.chown(support.TESTFN, -1, -1)) + + @unittest.skipUnless(len(groups) > 1, "test needs more than one group") + def test_chown(self): + gid_1, gid_2 = groups[:2] + uid = os.stat(support.TESTFN).st_uid + os.chown(support.TESTFN, uid, gid_1) + gid = os.stat(support.TESTFN).st_gid + self.assertEqual(gid, gid_1) + os.chown(support.TESTFN, uid, gid_2) + gid = os.stat(support.TESTFN).st_gid + self.assertEqual(gid, gid_2) + + @unittest.skipUnless(root_in_posix and len(all_users) > 1, + "test needs root privilege and more than one user") + def test_chown_with_root(self): + uid_1, uid_2 = all_users[:2] + gid = os.stat(support.TESTFN).st_gid + os.chown(support.TESTFN, uid_1, gid) + uid = os.stat(support.TESTFN).st_uid + self.assertEqual(uid, uid_1) + os.chown(support.TESTFN, uid_2, gid) + uid = os.stat(support.TESTFN).st_uid + self.assertEqual(uid, uid_2) + + @unittest.skipUnless(not root_in_posix and len(all_users) > 1, + "test needs non-root account and more than one user") + def test_chown_without_permission(self): + uid_1, uid_2 = all_users[:2] + gid = os.stat(support.TESTFN).st_gid + with self.assertRaises(PermissionError): + os.chown(support.TESTFN, uid_1, gid) + os.chown(support.TESTFN, uid_2, gid) + + @classmethod + def tearDownClass(cls): + os.rmdir(support.TESTFN) + + class RemoveDirsTests(unittest.TestCase): def setUp(self): os.makedirs(support.TESTFN) @@ -1114,10 +1226,15 @@ class URandomTests(unittest.TestCase): self.assertNotEqual(data1, data2) -HAVE_GETENTROPY = (sysconfig.get_config_var('HAVE_GETENTROPY') == 1) +# os.urandom() doesn't use a file descriptor when it is implemented with the +# getentropy() function, the getrandom() function or the getrandom() syscall +OS_URANDOM_DONT_USE_FD = ( + sysconfig.get_config_var('HAVE_GETENTROPY') == 1 + or sysconfig.get_config_var('HAVE_GETRANDOM') == 1 + or sysconfig.get_config_var('HAVE_GETRANDOM_SYSCALL') == 1) -@unittest.skipIf(HAVE_GETENTROPY, - "getentropy() does not use a file descriptor") +@unittest.skipIf(OS_URANDOM_DONT_USE_FD , + "os.random() does not use a file descriptor") class URandomFDTests(unittest.TestCase): @unittest.skipUnless(resource, "test requires the resource module") def test_urandom_failure(self): @@ -1148,8 +1265,10 @@ class URandomFDTests(unittest.TestCase): code = """if 1: import os import sys + import test.support os.urandom(4) - os.closerange(3, 256) + with test.support.SuppressCrashReport(): + os.closerange(3, 256) sys.stdout.buffer.write(os.urandom(4)) """ rc, out, err = assert_python_ok('-Sc', code) @@ -1163,16 +1282,18 @@ class URandomFDTests(unittest.TestCase): code = """if 1: import os import sys + import test.support os.urandom(4) - for fd in range(3, 256): - try: - os.close(fd) - except OSError: - pass - else: - # Found the urandom fd (XXX hopefully) - break - os.closerange(3, 256) + with test.support.SuppressCrashReport(): + for fd in range(3, 256): + try: + os.close(fd) + except OSError: + pass + else: + # Found the urandom fd (XXX hopefully) + break + os.closerange(3, 256) with open({TESTFN!r}, 'rb') as f: os.dup2(f.fileno(), fd) sys.stdout.buffer.write(os.urandom(4)) @@ -1398,6 +1519,16 @@ class TestInvalidFD(unittest.TestCase): def test_writev(self): self.check(os.writev, [b'abc']) + def test_inheritable(self): + self.check(os.get_inheritable) + self.check(os.set_inheritable, True) + + @unittest.skipUnless(hasattr(os, 'get_blocking'), + 'needs os.get_blocking() and os.set_blocking()') + def test_blocking(self): + self.check(os.get_blocking) + self.check(os.set_blocking, True) + class LinkTests(unittest.TestCase): def setUp(self): @@ -1845,6 +1976,37 @@ class Win32SymlinkTests(unittest.TestCase): shutil.rmtree(level1) +@unittest.skipUnless(sys.platform == "win32", "Win32 specific tests") +class Win32JunctionTests(unittest.TestCase): + junction = 'junctiontest' + junction_target = os.path.dirname(os.path.abspath(__file__)) + + def setUp(self): + assert os.path.exists(self.junction_target) + assert not os.path.exists(self.junction) + + def tearDown(self): + if os.path.exists(self.junction): + # os.rmdir delegates to Windows' RemoveDirectoryW, + # which removes junction points safely. + os.rmdir(self.junction) + + def test_create_junction(self): + _winapi.CreateJunction(self.junction_target, self.junction) + self.assertTrue(os.path.exists(self.junction)) + self.assertTrue(os.path.isdir(self.junction)) + + # Junctions are not recognized as links. + self.assertFalse(os.path.islink(self.junction)) + + def test_unlink_removes_junction(self): + _winapi.CreateJunction(self.junction_target, self.junction) + self.assertTrue(os.path.exists(self.junction)) + + os.unlink(self.junction) + self.assertFalse(os.path.exists(self.junction)) + + @support.skip_unless_symlink class NonLocalSymlinkTests(unittest.TestCase): @@ -1918,6 +2080,12 @@ class PidTests(unittest.TestCase): # We are the parent of our subprocess self.assertEqual(int(stdout), os.getpid()) + def test_waitpid(self): + args = [sys.executable, '-c', 'pass'] + pid = os.spawnv(os.P_NOWAIT, args[0], args) + status = os.waitpid(pid, 0) + self.assertEqual(status, (pid, 0)) + # The introduction of this TestCase caused at least two different errors on # *nix buildbots. Temporarily skip this to let the buildbots move along. @@ -2051,11 +2219,13 @@ class TestSendfile(unittest.TestCase): @classmethod def setUpClass(cls): + cls.key = support.threading_setup() with open(support.TESTFN, "wb") as f: f.write(cls.DATA) @classmethod def tearDownClass(cls): + support.threading_cleanup(*cls.key) support.unlink(support.TESTFN) def setUp(self): @@ -2590,43 +2760,251 @@ class FDInheritanceTests(unittest.TestCase): self.assertEqual(os.get_inheritable(slave_fd), False) -@support.reap_threads -def test_main(): - support.run_unittest( - FileTests, - StatAttributeTests, - UtimeTests, - EnvironTests, - WalkTests, - FwalkTests, - MakedirTests, - DevNullTests, - URandomTests, - URandomFDTests, - ExecTests, - Win32ErrorTests, - TestInvalidFD, - PosixUidGidTests, - Pep383Tests, - Win32KillTests, - Win32ListdirTests, - Win32SymlinkTests, - NonLocalSymlinkTests, - FSEncodingTests, - DeviceEncodingTests, - PidTests, - LoginTests, - LinkTests, - TestSendfile, - ProgramPriorityTests, - ExtendedAttributeTests, - Win32DeprecatedBytesAPI, - TermsizeTests, - OSErrorTests, - RemoveDirsTests, - CPUCountTests, - FDInheritanceTests, - ) +@unittest.skipUnless(hasattr(os, 'get_blocking'), + 'needs os.get_blocking() and os.set_blocking()') +class BlockingTests(unittest.TestCase): + def test_blocking(self): + fd = os.open(__file__, os.O_RDONLY) + self.addCleanup(os.close, fd) + self.assertEqual(os.get_blocking(fd), True) + + os.set_blocking(fd, False) + self.assertEqual(os.get_blocking(fd), False) + + os.set_blocking(fd, True) + self.assertEqual(os.get_blocking(fd), True) + + + +class ExportsTests(unittest.TestCase): + def test_os_all(self): + self.assertIn('open', os.__all__) + self.assertIn('walk', os.__all__) + + +class TestScandir(unittest.TestCase): + def setUp(self): + self.path = os.path.realpath(support.TESTFN) + self.addCleanup(support.rmtree, self.path) + os.mkdir(self.path) + + def create_file(self, name="file.txt"): + filename = os.path.join(self.path, name) + with open(filename, "wb") as fp: + fp.write(b'python') + return filename + + def get_entries(self, names): + entries = dict((entry.name, entry) + for entry in os.scandir(self.path)) + self.assertEqual(sorted(entries.keys()), names) + return entries + + def assert_stat_equal(self, stat1, stat2, skip_fields): + if skip_fields: + for attr in dir(stat1): + if not attr.startswith("st_"): + continue + if attr in ("st_dev", "st_ino", "st_nlink"): + continue + self.assertEqual(getattr(stat1, attr), + getattr(stat2, attr), + (stat1, stat2, attr)) + else: + self.assertEqual(stat1, stat2) + + def check_entry(self, entry, name, is_dir, is_file, is_symlink): + self.assertEqual(entry.name, name) + self.assertEqual(entry.path, os.path.join(self.path, name)) + self.assertEqual(entry.inode(), + os.stat(entry.path, follow_symlinks=False).st_ino) + + entry_stat = os.stat(entry.path) + self.assertEqual(entry.is_dir(), + stat.S_ISDIR(entry_stat.st_mode)) + self.assertEqual(entry.is_file(), + stat.S_ISREG(entry_stat.st_mode)) + self.assertEqual(entry.is_symlink(), + os.path.islink(entry.path)) + + entry_lstat = os.stat(entry.path, follow_symlinks=False) + self.assertEqual(entry.is_dir(follow_symlinks=False), + stat.S_ISDIR(entry_lstat.st_mode)) + self.assertEqual(entry.is_file(follow_symlinks=False), + stat.S_ISREG(entry_lstat.st_mode)) + + self.assert_stat_equal(entry.stat(), + entry_stat, + os.name == 'nt' and not is_symlink) + self.assert_stat_equal(entry.stat(follow_symlinks=False), + entry_lstat, + os.name == 'nt') + + def test_attributes(self): + link = hasattr(os, 'link') + symlink = support.can_symlink() + + dirname = os.path.join(self.path, "dir") + os.mkdir(dirname) + filename = self.create_file("file.txt") + if link: + os.link(filename, os.path.join(self.path, "link_file.txt")) + if symlink: + os.symlink(dirname, os.path.join(self.path, "symlink_dir"), + target_is_directory=True) + os.symlink(filename, os.path.join(self.path, "symlink_file.txt")) + + names = ['dir', 'file.txt'] + if link: + names.append('link_file.txt') + if symlink: + names.extend(('symlink_dir', 'symlink_file.txt')) + entries = self.get_entries(names) + + entry = entries['dir'] + self.check_entry(entry, 'dir', True, False, False) + + entry = entries['file.txt'] + self.check_entry(entry, 'file.txt', False, True, False) + + if link: + entry = entries['link_file.txt'] + self.check_entry(entry, 'link_file.txt', False, True, False) + + if symlink: + entry = entries['symlink_dir'] + self.check_entry(entry, 'symlink_dir', True, False, True) + + entry = entries['symlink_file.txt'] + self.check_entry(entry, 'symlink_file.txt', False, True, True) + + def get_entry(self, name): + entries = list(os.scandir(self.path)) + self.assertEqual(len(entries), 1) + + entry = entries[0] + self.assertEqual(entry.name, name) + return entry + + def create_file_entry(self): + filename = self.create_file() + return self.get_entry(os.path.basename(filename)) + + def test_current_directory(self): + filename = self.create_file() + old_dir = os.getcwd() + try: + os.chdir(self.path) + + # call scandir() without parameter: it must list the content + # of the current directory + entries = dict((entry.name, entry) for entry in os.scandir()) + self.assertEqual(sorted(entries.keys()), + [os.path.basename(filename)]) + finally: + os.chdir(old_dir) + + def test_repr(self): + entry = self.create_file_entry() + self.assertEqual(repr(entry), "<DirEntry 'file.txt'>") + + def test_removed_dir(self): + path = os.path.join(self.path, 'dir') + + os.mkdir(path) + entry = self.get_entry('dir') + os.rmdir(path) + + # On POSIX, is_dir() result depends if scandir() filled d_type or not + if os.name == 'nt': + self.assertTrue(entry.is_dir()) + self.assertFalse(entry.is_file()) + self.assertFalse(entry.is_symlink()) + if os.name == 'nt': + self.assertRaises(FileNotFoundError, entry.inode) + # don't fail + entry.stat() + entry.stat(follow_symlinks=False) + else: + self.assertGreater(entry.inode(), 0) + self.assertRaises(FileNotFoundError, entry.stat) + self.assertRaises(FileNotFoundError, entry.stat, follow_symlinks=False) + + def test_removed_file(self): + entry = self.create_file_entry() + os.unlink(entry.path) + + self.assertFalse(entry.is_dir()) + # On POSIX, is_dir() result depends if scandir() filled d_type or not + if os.name == 'nt': + self.assertTrue(entry.is_file()) + self.assertFalse(entry.is_symlink()) + if os.name == 'nt': + self.assertRaises(FileNotFoundError, entry.inode) + # don't fail + entry.stat() + entry.stat(follow_symlinks=False) + else: + self.assertGreater(entry.inode(), 0) + self.assertRaises(FileNotFoundError, entry.stat) + self.assertRaises(FileNotFoundError, entry.stat, follow_symlinks=False) + + def test_broken_symlink(self): + if not support.can_symlink(): + return self.skipTest('cannot create symbolic link') + + filename = self.create_file("file.txt") + os.symlink(filename, + os.path.join(self.path, "symlink.txt")) + entries = self.get_entries(['file.txt', 'symlink.txt']) + entry = entries['symlink.txt'] + os.unlink(filename) + + self.assertGreater(entry.inode(), 0) + self.assertFalse(entry.is_dir()) + self.assertFalse(entry.is_file()) # broken symlink returns False + self.assertFalse(entry.is_dir(follow_symlinks=False)) + self.assertFalse(entry.is_file(follow_symlinks=False)) + self.assertTrue(entry.is_symlink()) + self.assertRaises(FileNotFoundError, entry.stat) + # don't fail + entry.stat(follow_symlinks=False) + + def test_bytes(self): + if os.name == "nt": + # On Windows, os.scandir(bytes) must raise an exception + self.assertRaises(TypeError, os.scandir, b'.') + return + + self.create_file("file.txt") + + path_bytes = os.fsencode(self.path) + entries = list(os.scandir(path_bytes)) + self.assertEqual(len(entries), 1, entries) + entry = entries[0] + + self.assertEqual(entry.name, b'file.txt') + self.assertEqual(entry.path, + os.fsencode(os.path.join(self.path, 'file.txt'))) + + def test_empty_path(self): + self.assertRaises(FileNotFoundError, os.scandir, '') + + def test_consume_iterator_twice(self): + self.create_file("file.txt") + iterator = os.scandir(self.path) + + entries = list(iterator) + self.assertEqual(len(entries), 1, entries) + + # check than consuming the iterator twice doesn't raise exception + entries2 = list(iterator) + self.assertEqual(len(entries2), 0, entries2) + + def test_bad_path_type(self): + for obj in [1234, 1.234, {}, []]: + self.assertRaises(TypeError, os.scandir, obj) + if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_osx_env.py b/Lib/test/test_osx_env.py index d8eb981..8a3bc5a 100644 --- a/Lib/test/test_osx_env.py +++ b/Lib/test/test_osx_env.py @@ -2,7 +2,7 @@ Test suite for OS X interpreter environment variables. """ -from test.support import EnvironmentVarGuard, run_unittest +from test.support import EnvironmentVarGuard import subprocess import sys import sysconfig diff --git a/Lib/test/test_parser.py b/Lib/test/test_parser.py index e7968cc..3d301b4 100644 --- a/Lib/test/test_parser.py +++ b/Lib/test/test_parser.py @@ -4,7 +4,7 @@ import sys import operator import struct from test import support -from test.script_helper import assert_python_failure +from test.support.script_helper import assert_python_failure # # First, we test that we can generate trees from valid source fragments, @@ -63,6 +63,22 @@ class RoundtripLegalSyntaxTestCase(unittest.TestCase): " if (yield):\n" " yield x\n") + def test_await_statement(self): + self.check_suite("async def f():\n await smth()") + self.check_suite("async def f():\n foo = await smth()") + self.check_suite("async def f():\n foo, bar = await smth()") + self.check_suite("async def f():\n (await smth())") + self.check_suite("async def f():\n foo((await smth()))") + self.check_suite("async def f():\n await foo(); return 42") + + def test_async_with_statement(self): + self.check_suite("async def f():\n async with 1: pass") + self.check_suite("async def f():\n async with a as b, c as d: pass") + + def test_async_for_statement(self): + self.check_suite("async def f():\n async for i in (): pass") + self.check_suite("async def f():\n async for i, b in (): pass") + def test_nonlocal_statement(self): self.check_suite("def f():\n" " x = 0\n" @@ -313,7 +329,12 @@ class RoundtripLegalSyntaxTestCase(unittest.TestCase): "except Exception as e:\n" " raise ValueError from e\n") + def test_list_displays(self): + self.check_expr('[]') + self.check_expr('[*{2}, 3, *[4]]') + def test_set_displays(self): + self.check_expr('{*{2}, 3, *[4]}') self.check_expr('{2}') self.check_expr('{2,}') self.check_expr('{2, 3}') @@ -325,6 +346,15 @@ class RoundtripLegalSyntaxTestCase(unittest.TestCase): self.check_expr('{a:b,}') self.check_expr('{a:b, c:d}') self.check_expr('{a:b, c:d,}') + self.check_expr('{**{}}') + self.check_expr('{**{}, 3:4, **{5:6, 7:8}}') + + def test_argument_unpacking(self): + self.check_expr("f(*a, **b)") + self.check_expr('f(a, *b, *c, *d)') + self.check_expr('f(**a, **b)') + self.check_expr('f(2, *a, *b, **b, **c, **d)') + self.check_expr("f(*b, *() or () and (), **{} and {}, **() or {})") def test_set_comprehensions(self): self.check_expr('{x for x in seq}') @@ -730,16 +760,5 @@ class OtherParserCase(unittest.TestCase): with self.assertRaises(TypeError): parser.expr("a", "b") -def test_main(): - support.run_unittest( - RoundtripLegalSyntaxTestCase, - IllegalSyntaxTestCase, - CompileTestCase, - ParserStackLimitTestCase, - STObjectTestCase, - OtherParserCase, - ) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pathlib.py b/Lib/test/test_pathlib.py index 11420e2..1c53ab7 100644 --- a/Lib/test/test_pathlib.py +++ b/Lib/test/test_pathlib.py @@ -4,13 +4,10 @@ import os import errno import pathlib import pickle -import shutil import socket import stat -import sys import tempfile import unittest -from contextlib import contextmanager from test import support TESTFN = support.TESTFN @@ -747,7 +744,6 @@ class PureWindowsPathTest(_BasePurePathTest, unittest.TestCase): self.assertEqual(P('//Some/SHARE/a/B'), P('//somE/share/A/b')) def test_as_uri(self): - from urllib.parse import quote_from_bytes P = self.cls with self.assertRaises(ValueError): P('/a/b').as_uri() @@ -1269,11 +1265,55 @@ class _BasePathTest(object): p = self.cls.cwd() self._test_cwd(p) + def _test_home(self, p): + q = self.cls(os.path.expanduser('~')) + self.assertEqual(p, q) + self.assertEqual(str(p), str(q)) + self.assertIs(type(p), type(q)) + self.assertTrue(p.is_absolute()) + + def test_home(self): + p = self.cls.home() + self._test_home(p) + + def test_samefile(self): + fileA_path = os.path.join(BASE, 'fileA') + fileB_path = os.path.join(BASE, 'dirB', 'fileB') + p = self.cls(fileA_path) + pp = self.cls(fileA_path) + q = self.cls(fileB_path) + self.assertTrue(p.samefile(fileA_path)) + self.assertTrue(p.samefile(pp)) + self.assertFalse(p.samefile(fileB_path)) + self.assertFalse(p.samefile(q)) + # Test the non-existent file case + non_existent = os.path.join(BASE, 'foo') + r = self.cls(non_existent) + self.assertRaises(FileNotFoundError, p.samefile, r) + self.assertRaises(FileNotFoundError, p.samefile, non_existent) + self.assertRaises(FileNotFoundError, r.samefile, p) + self.assertRaises(FileNotFoundError, r.samefile, non_existent) + self.assertRaises(FileNotFoundError, r.samefile, r) + self.assertRaises(FileNotFoundError, r.samefile, non_existent) + def test_empty_path(self): # The empty path points to '.' p = self.cls('') self.assertEqual(p.stat(), os.stat('.')) + def test_expanduser_common(self): + P = self.cls + p = P('~') + self.assertEqual(p.expanduser(), P(os.path.expanduser('~'))) + p = P('foo') + self.assertEqual(p.expanduser(), p) + p = P('/~') + self.assertEqual(p.expanduser(), p) + p = P('../~') + self.assertEqual(p.expanduser(), p) + p = P(P('').absolute().anchor) / '~' + self.assertEqual(p.expanduser(), p) + def test_exists(self): P = self.cls p = P(BASE) @@ -1301,6 +1341,23 @@ class _BasePathTest(object): self.assertIsInstance(f, io.RawIOBase) self.assertEqual(f.read().strip(), b"this is file A") + def test_read_write_bytes(self): + p = self.cls(BASE) + (p / 'fileA').write_bytes(b'abcdefg') + self.assertEqual((p / 'fileA').read_bytes(), b'abcdefg') + # check that trying to write str does not truncate the file + self.assertRaises(TypeError, (p / 'fileA').write_bytes, 'somestr') + self.assertEqual((p / 'fileA').read_bytes(), b'abcdefg') + + def test_read_write_text(self): + p = self.cls(BASE) + (p / 'fileA').write_text('äbcdefg', encoding='latin-1') + self.assertEqual((p / 'fileA').read_text( + encoding='utf-8', errors='ignore'), 'bcdefg') + # check that trying to write bytes does not truncate the file + self.assertRaises(TypeError, (p / 'fileA').write_text, b'somebytes') + self.assertEqual((p / 'fileA').read_text(encoding='latin-1'), 'äbcdefg') + def test_iterdir(self): P = self.cls p = P(BASE) @@ -1604,6 +1661,59 @@ class _BasePathTest(object): # the parent's permissions follow the default process settings self.assertEqual(stat.S_IMODE(p.parent.stat().st_mode), mode) + def test_mkdir_exist_ok(self): + p = self.cls(BASE, 'dirB') + st_ctime_first = p.stat().st_ctime + self.assertTrue(p.exists()) + self.assertTrue(p.is_dir()) + with self.assertRaises(FileExistsError) as cm: + p.mkdir() + self.assertEqual(cm.exception.errno, errno.EEXIST) + p.mkdir(exist_ok=True) + self.assertTrue(p.exists()) + self.assertEqual(p.stat().st_ctime, st_ctime_first) + + def test_mkdir_exist_ok_with_parent(self): + p = self.cls(BASE, 'dirC') + self.assertTrue(p.exists()) + with self.assertRaises(FileExistsError) as cm: + p.mkdir() + self.assertEqual(cm.exception.errno, errno.EEXIST) + p = p / 'newdirC' + p.mkdir(parents=True) + st_ctime_first = p.stat().st_ctime + self.assertTrue(p.exists()) + with self.assertRaises(FileExistsError) as cm: + p.mkdir(parents=True) + self.assertEqual(cm.exception.errno, errno.EEXIST) + p.mkdir(parents=True, exist_ok=True) + self.assertTrue(p.exists()) + self.assertEqual(p.stat().st_ctime, st_ctime_first) + + def test_mkdir_with_child_file(self): + p = self.cls(BASE, 'dirB', 'fileB') + self.assertTrue(p.exists()) + # An exception is raised when the last path component is an existing + # regular file, regardless of whether exist_ok is true or not. + with self.assertRaises(FileExistsError) as cm: + p.mkdir(parents=True) + self.assertEqual(cm.exception.errno, errno.EEXIST) + with self.assertRaises(FileExistsError) as cm: + p.mkdir(parents=True, exist_ok=True) + self.assertEqual(cm.exception.errno, errno.EEXIST) + + def test_mkdir_no_parents_file(self): + p = self.cls(BASE, 'fileA') + self.assertTrue(p.exists()) + # An exception is raised when the last path component is an existing + # regular file, regardless of whether exist_ok is true or not. + with self.assertRaises(FileExistsError) as cm: + p.mkdir() + self.assertEqual(cm.exception.errno, errno.EEXIST) + with self.assertRaises(FileExistsError) as cm: + p.mkdir(exist_ok=True) + self.assertEqual(cm.exception.errno, errno.EEXIST) + @with_symlinks def test_symlink_to(self): P = self.cls(BASE) @@ -1846,7 +1956,6 @@ class PosixPathTest(_BasePathTest, unittest.TestCase): @with_symlinks def test_resolve_loop(self): # Loop detection for broken symlinks under POSIX - P = self.cls # Loops with relative symlinks os.symlink('linkX/inside', join('linkX')) self._check_symlink_loop(BASE, 'linkX') @@ -1878,6 +1987,48 @@ class PosixPathTest(_BasePathTest, unittest.TestCase): self.assertEqual(given, expect) self.assertEqual(set(p.rglob("FILEd*")), set()) + def test_expanduser(self): + P = self.cls + support.import_module('pwd') + import pwd + pwdent = pwd.getpwuid(os.getuid()) + username = pwdent.pw_name + userhome = pwdent.pw_dir.rstrip('/') + # find arbitrary different user (if exists) + for pwdent in pwd.getpwall(): + othername = pwdent.pw_name + otherhome = pwdent.pw_dir.rstrip('/') + if othername != username and otherhome: + break + + p1 = P('~/Documents') + p2 = P('~' + username + '/Documents') + p3 = P('~' + othername + '/Documents') + p4 = P('../~' + username + '/Documents') + p5 = P('/~' + username + '/Documents') + p6 = P('') + p7 = P('~fakeuser/Documents') + + with support.EnvironmentVarGuard() as env: + env.pop('HOME', None) + + self.assertEqual(p1.expanduser(), P(userhome) / 'Documents') + self.assertEqual(p2.expanduser(), P(userhome) / 'Documents') + self.assertEqual(p3.expanduser(), P(otherhome) / 'Documents') + self.assertEqual(p4.expanduser(), p4) + self.assertEqual(p5.expanduser(), p5) + self.assertEqual(p6.expanduser(), p6) + self.assertRaises(RuntimeError, p7.expanduser) + + env['HOME'] = '/tmp' + self.assertEqual(p1.expanduser(), P('/tmp/Documents')) + self.assertEqual(p2.expanduser(), P(userhome) / 'Documents') + self.assertEqual(p3.expanduser(), P(otherhome) / 'Documents') + self.assertEqual(p4.expanduser(), p4) + self.assertEqual(p5.expanduser(), p5) + self.assertEqual(p6.expanduser(), p6) + self.assertRaises(RuntimeError, p7.expanduser) + @only_nt class WindowsPathTest(_BasePathTest, unittest.TestCase): @@ -1893,6 +2044,61 @@ class WindowsPathTest(_BasePathTest, unittest.TestCase): p = P(BASE, "dirC") self.assertEqual(set(p.rglob("FILEd")), { P(BASE, "dirC/dirD/fileD") }) + def test_expanduser(self): + P = self.cls + with support.EnvironmentVarGuard() as env: + env.pop('HOME', None) + env.pop('USERPROFILE', None) + env.pop('HOMEPATH', None) + env.pop('HOMEDRIVE', None) + env['USERNAME'] = 'alice' + + # test that the path returns unchanged + p1 = P('~/My Documents') + p2 = P('~alice/My Documents') + p3 = P('~bob/My Documents') + p4 = P('/~/My Documents') + p5 = P('d:~/My Documents') + p6 = P('') + self.assertRaises(RuntimeError, p1.expanduser) + self.assertRaises(RuntimeError, p2.expanduser) + self.assertRaises(RuntimeError, p3.expanduser) + self.assertEqual(p4.expanduser(), p4) + self.assertEqual(p5.expanduser(), p5) + self.assertEqual(p6.expanduser(), p6) + + def check(): + env.pop('USERNAME', None) + self.assertEqual(p1.expanduser(), + P('C:/Users/alice/My Documents')) + self.assertRaises(KeyError, p2.expanduser) + env['USERNAME'] = 'alice' + self.assertEqual(p2.expanduser(), + P('C:/Users/alice/My Documents')) + self.assertEqual(p3.expanduser(), + P('C:/Users/bob/My Documents')) + self.assertEqual(p4.expanduser(), p4) + self.assertEqual(p5.expanduser(), p5) + self.assertEqual(p6.expanduser(), p6) + + # test the first lookup key in the env vars + env['HOME'] = 'C:\\Users\\alice' + check() + + # test that HOMEPATH is available instead + env.pop('HOME', None) + env['HOMEPATH'] = 'C:\\Users\\alice' + check() + + env['HOMEDRIVE'] = 'C:\\' + env['HOMEPATH'] = 'Users\\alice' + check() + + env.pop('HOMEDRIVE', None) + env.pop('HOMEPATH', None) + env['USERPROFILE'] = 'C:\\Users\\alice' + check() + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py index 5025792..41e5091 100644 --- a/Lib/test/test_peepholer.py +++ b/Lib/test/test_peepholer.py @@ -319,21 +319,5 @@ class TestBuglets(unittest.TestCase): f() -def test_main(verbose=None): - import sys - from test import support - test_classes = (TestTranforms, TestBuglets) - support.run_unittest(*test_classes) - - # verify reference counting - if verbose and hasattr(sys, 'gettotalrefcount'): - import gc - counts = [None] * 5 - for i in range(len(counts)): - support.run_unittest(*test_classes) - gc.collect() - counts[i] = sys.gettotalrefcount() - print(counts) - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_pep247.py b/Lib/test/test_pep247.py index b85a26a..ab5f4189 100644 --- a/Lib/test/test_pep247.py +++ b/Lib/test/test_pep247.py @@ -6,7 +6,6 @@ for hashing algorithms import hmac import unittest from hashlib import md5, sha1, sha224, sha256, sha384, sha512 -from test import support class Pep247Test(unittest.TestCase): @@ -63,8 +62,5 @@ class Pep247Test(unittest.TestCase): def test_hmac(self): self.check_module(hmac, key=b'abc') -def test_main(): - support.run_unittest(Pep247Test) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_pep292.py b/Lib/test/test_pep292.py index fd5256c..1e5e227 100644 --- a/Lib/test/test_pep292.py +++ b/Lib/test/test_pep292.py @@ -244,11 +244,5 @@ class TestTemplate(unittest.TestCase): 'tim likes to eat a bag of ham worth $100') -def test_main(): - from test import support - test_classes = [TestTemplate,] - support.run_unittest(*test_classes) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_pep3120.py b/Lib/test/test_pep3120.py index 5b63998..97dced8 100644 --- a/Lib/test/test_pep3120.py +++ b/Lib/test/test_pep3120.py @@ -1,7 +1,6 @@ # This file is marked as binary in the CVS, to prevent MacCVS from recoding it. import unittest -from test import support class PEP3120Test(unittest.TestCase): @@ -40,8 +39,5 @@ class BuiltinCompileTests(unittest.TestCase): self.assertEqual('Ç', ns['u']) -def test_main(): - support.run_unittest(PEP3120Test, BuiltinCompileTests) - -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_pep3131.py b/Lib/test/test_pep3131.py index 2e6b90a..0679845 100644 --- a/Lib/test/test_pep3131.py +++ b/Lib/test/test_pep3131.py @@ -1,6 +1,5 @@ import unittest import sys -from test import support class PEP3131Test(unittest.TestCase): @@ -28,8 +27,5 @@ class PEP3131Test(unittest.TestCase): else: self.fail("expected exception didn't occur") -def test_main(): - support.run_unittest(PEP3131Test) - -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_pep3151.py b/Lib/test/test_pep3151.py index 7d4a5d8..8d560cd 100644 --- a/Lib/test/test_pep3151.py +++ b/Lib/test/test_pep3151.py @@ -7,7 +7,6 @@ import unittest import errno from errno import EEXIST -from test import support class SubOSError(OSError): pass @@ -202,8 +201,5 @@ class ExplicitSubclassingTest(unittest.TestCase): self.assertEqual(str(e), '') -def test_main(): - support.run_unittest(__name__) - -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_pep380.py b/Lib/test/test_pep380.py index 69194df..23ffbed 100644 --- a/Lib/test/test_pep380.py +++ b/Lib/test/test_pep380.py @@ -1013,11 +1013,5 @@ class TestPEP380Operation(unittest.TestCase): self.assertEqual(v, (1, 2, 3, 4)) -def test_main(): - from test import support - test_classes = [TestPEP380Operation] - support.run_unittest(*test_classes) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_pep479.py b/Lib/test/test_pep479.py new file mode 100644 index 0000000..bc235ce --- /dev/null +++ b/Lib/test/test_pep479.py @@ -0,0 +1,34 @@ +from __future__ import generator_stop + +import unittest + + +class TestPEP479(unittest.TestCase): + def test_stopiteration_wrapping(self): + def f(): + raise StopIteration + def g(): + yield f() + with self.assertRaisesRegex(RuntimeError, + "generator raised StopIteration"): + next(g()) + + def test_stopiteration_wrapping_context(self): + def f(): + raise StopIteration + def g(): + yield f() + + try: + next(g()) + except RuntimeError as exc: + self.assertIs(type(exc.__cause__), StopIteration) + self.assertIs(type(exc.__context__), StopIteration) + self.assertTrue(exc.__suppress_context__) + else: + self.fail('__cause__, __context__, or __suppress_context__ ' + 'were not properly set') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_pickle.py b/Lib/test/test_pickle.py index 3d75d65..ab6d92b 100644 --- a/Lib/test/test_pickle.py +++ b/Lib/test/test_pickle.py @@ -366,7 +366,10 @@ class CompatPickleTests(unittest.TestCase): for name, exc in get_exceptions(builtins): with self.subTest(name): - if exc in (BlockingIOError, ResourceWarning): + if exc in (BlockingIOError, + ResourceWarning, + StopAsyncIteration, + RecursionError): continue if exc is not OSError and issubclass(exc, OSError): self.assertEqual(reverse_mapping('builtins', name), diff --git a/Lib/test/test_pkg.py b/Lib/test/test_pkg.py index 9883000..532e8fe 100644 --- a/Lib/test/test_pkg.py +++ b/Lib/test/test_pkg.py @@ -291,9 +291,5 @@ class TestPkg(unittest.TestCase): import t8 self.assertEqual(t8.__doc__, "doc for t8") -def test_main(): - support.run_unittest(__name__) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pkgimport.py b/Lib/test/test_pkgimport.py index 370b2aa..5d9a451 100644 --- a/Lib/test/test_pkgimport.py +++ b/Lib/test/test_pkgimport.py @@ -7,7 +7,7 @@ import tempfile import unittest from importlib.util import cache_from_source -from test.support import run_unittest, create_empty_file +from test.support import create_empty_file class TestImport(unittest.TestCase): @@ -76,9 +76,5 @@ class TestImport(unittest.TestCase): self.assertEqual(getattr(module, var), 1) -def test_main(): - run_unittest(TestImport) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pkgutil.py b/Lib/test/test_pkgutil.py index e0c8635de..57ebf1f 100644 --- a/Lib/test/test_pkgutil.py +++ b/Lib/test/test_pkgutil.py @@ -104,6 +104,9 @@ class PkgutilTests(unittest.TestCase): class PkgutilPEP302Tests(unittest.TestCase): class MyTestLoader(object): + def create_module(self, spec): + return None + def exec_module(self, mod): # Count how many times the module is reloaded mod.__dict__['loads'] = mod.__dict__.get('loads', 0) + 1 diff --git a/Lib/test/test_platform.py b/Lib/test/test_platform.py index b3de43b..3ea71f1 100644 --- a/Lib/test/test_platform.py +++ b/Lib/test/test_platform.py @@ -236,7 +236,14 @@ class PlatformTest(unittest.TestCase): self.assertEqual(sts, 0) def test_dist(self): - res = platform.dist() + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + 'dist\(\) and linux_distribution\(\) ' + 'functions are deprecated .*', + PendingDeprecationWarning, + ) + res = platform.dist() def test_libc_ver(self): import os @@ -305,16 +312,37 @@ class PlatformTest(unittest.TestCase): f.write('Fedora release 19 (Schr\xf6dinger\u2019s Cat)\n') with mock.patch('platform._UNIXCONFDIR', tempdir): - distname, version, distid = platform.linux_distribution() - - self.assertEqual(distname, 'Fedora') + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + 'dist\(\) and linux_distribution\(\) ' + 'functions are deprecated .*', + PendingDeprecationWarning, + ) + distname, version, distid = platform.linux_distribution() + + self.assertEqual(distname, 'Fedora') self.assertEqual(version, '19') self.assertEqual(distid, 'Schr\xf6dinger\u2019s Cat') -def test_main(): - support.run_unittest( - PlatformTest - ) + +class DeprecationTest(unittest.TestCase): + + def test_dist_deprecation(self): + with self.assertWarns(PendingDeprecationWarning) as cm: + platform.dist() + self.assertEqual(str(cm.warning), + 'dist() and linux_distribution() functions are ' + 'deprecated in Python 3.5 and will be removed in ' + 'Python 3.7') + + def test_linux_distribution_deprecation(self): + with self.assertWarns(PendingDeprecationWarning) as cm: + platform.linux_distribution() + self.assertEqual(str(cm.warning), + 'dist() and linux_distribution() functions are ' + 'deprecated in Python 3.5 and will be removed in ' + 'Python 3.7') if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_popen.py b/Lib/test/test_popen.py index 116b3dd..da01a87 100644 --- a/Lib/test/test_popen.py +++ b/Lib/test/test_popen.py @@ -61,8 +61,5 @@ class PopenTest(unittest.TestCase): with os.popen(cmd="exit 0", mode="w", buffering=-1): pass -def test_main(): - support.run_unittest(PopenTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py index 8a3c9f4..bceeb93 100644 --- a/Lib/test/test_poplib.py +++ b/Lib/test/test_poplib.py @@ -44,6 +44,7 @@ line3\r\n\ class DummyPOP3Handler(asynchat.async_chat): CAPAS = {'UIDL': [], 'IMPLEMENTATION': ['python-testlib-pop-server']} + enable_UTF8 = False def __init__(self, conn): asynchat.async_chat.__init__(self, conn) @@ -142,6 +143,11 @@ class DummyPOP3Handler(asynchat.async_chat): self.push(' '.join(_ln)) self.push('.') + def cmd_utf8(self, arg): + self.push('+OK I know RFC6856' + if self.enable_UTF8 + else '-ERR What is UTF8?!') + if SUPPORTS_SSL: def cmd_stls(self, arg): @@ -309,6 +315,16 @@ class TestPOP3Class(TestCase): self.client.uidl() self.client.uidl('foo') + def test_utf8_raises_if_unsupported(self): + self.server.handler.enable_UTF8 = False + self.assertRaises(poplib.error_proto, self.client.utf8) + + def test_utf8(self): + self.server.handler.enable_UTF8 = True + expected = b'+OK I know RFC6856' + result = self.client.utf8() + self.assertEqual(result, expected) + def test_capa(self): capa = self.client.capa() self.assertTrue('IMPLEMENTATION' in capa.keys()) @@ -345,23 +361,18 @@ class TestPOP3Class(TestCase): if SUPPORTS_SSL: + from test.test_ftplib import SSLConnection - class DummyPOP3_SSLHandler(DummyPOP3Handler): + class DummyPOP3_SSLHandler(SSLConnection, DummyPOP3Handler): def __init__(self, conn): asynchat.async_chat.__init__(self, conn) - ssl_socket = ssl.wrap_socket(self.socket, certfile=CERTFILE, - server_side=True, - do_handshake_on_connect=False) - self.del_channel() - self.set_socket(ssl_socket) - # Must try handshake before calling push() - self.tls_active = True - self.tls_starting = True - self._do_tls_handshake() + self.secure_connection() self.set_terminator(b"\r\n") self.in_buffer = [] self.push('+OK dummy pop3 server ready. <timestamp>') + self.tls_active = True + self.tls_starting = False @requires_ssl @@ -452,7 +463,7 @@ class TestTimeouts(TestCase): del self.thread # Clear out any dangling Thread objects. def server(self, evt, serv): - serv.listen(5) + serv.listen() evt.set() try: conn, addr = serv.accept() diff --git a/Lib/test/test_posix.py b/Lib/test/test_posix.py index d767989..2a59c38 100644 --- a/Lib/test/test_posix.py +++ b/Lib/test/test_posix.py @@ -9,7 +9,6 @@ import errno import sys import time import os -import fcntl import platform import pwd import shutil @@ -355,7 +354,7 @@ class PosixTester(unittest.TestCase): def test_oscloexec(self): fd = os.open(support.TESTFN, os.O_RDONLY|os.O_CLOEXEC) self.addCleanup(os.close, fd) - self.assertTrue(fcntl.fcntl(fd, fcntl.F_GETFD) & fcntl.FD_CLOEXEC) + self.assertFalse(os.get_inheritable(fd)) @unittest.skipUnless(hasattr(posix, 'O_EXLOCK'), 'test needs posix.O_EXLOCK') @@ -643,8 +642,8 @@ class PosixTester(unittest.TestCase): self.addCleanup(os.close, w) self.assertFalse(os.get_inheritable(r)) self.assertFalse(os.get_inheritable(w)) - self.assertTrue(fcntl.fcntl(r, fcntl.F_GETFL) & os.O_NONBLOCK) - self.assertTrue(fcntl.fcntl(w, fcntl.F_GETFL) & os.O_NONBLOCK) + self.assertFalse(os.get_blocking(r)) + self.assertFalse(os.get_blocking(w)) # try reading from an empty pipe: this should fail, not block self.assertRaises(OSError, os.read, r, 1) # try a write big enough to fill-up the pipe: this should either @@ -1184,16 +1183,16 @@ class PosixTester(unittest.TestCase): support.unlink(fn) fd = None try: - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): fd = os.open(fn_with_NUL, os.O_WRONLY | os.O_CREAT) # raises finally: if fd is not None: os.close(fd) self.assertFalse(os.path.exists(fn)) - self.assertRaises(TypeError, os.mkdir, fn_with_NUL) + self.assertRaises(ValueError, os.mkdir, fn_with_NUL) self.assertFalse(os.path.exists(fn)) open(fn, 'wb').close() - self.assertRaises(TypeError, os.stat, fn_with_NUL) + self.assertRaises(ValueError, os.stat, fn_with_NUL) def test_path_with_null_byte(self): fn = os.fsencode(support.TESTFN) diff --git a/Lib/test/test_posixpath.py b/Lib/test/test_posixpath.py index 1d4596e..9d20471 100644 --- a/Lib/test/test_posixpath.py +++ b/Lib/test/test_posixpath.py @@ -57,18 +57,6 @@ class PosixPathTest(unittest.TestCase): self.assertEqual(posixpath.join(b"/foo/", b"bar/", b"baz/"), b"/foo/bar/baz/") - def test_join_errors(self): - # Check posixpath.join raises friendly TypeErrors. - errmsg = "Can't mix strings and bytes in path components" - with self.assertRaisesRegex(TypeError, errmsg): - posixpath.join(b'bytes', 'str') - with self.assertRaisesRegex(TypeError, errmsg): - posixpath.join('str', b'bytes') - # regression, see #15377 - with self.assertRaises(TypeError) as cm: - posixpath.join(None, 'str') - self.assertNotEqual(cm.exception.args[0], errmsg) - def test_split(self): self.assertEqual(posixpath.split("/foo/bar"), ("/foo", "bar")) self.assertEqual(posixpath.split("/"), ("/", "")) @@ -523,6 +511,60 @@ class PosixPathTest(unittest.TestCase): finally: os.getcwdb = real_getcwdb + def test_commonpath(self): + def check(paths, expected): + self.assertEqual(posixpath.commonpath(paths), expected) + self.assertEqual(posixpath.commonpath([os.fsencode(p) for p in paths]), + os.fsencode(expected)) + def check_error(exc, paths): + self.assertRaises(exc, posixpath.commonpath, paths) + self.assertRaises(exc, posixpath.commonpath, + [os.fsencode(p) for p in paths]) + + self.assertRaises(ValueError, posixpath.commonpath, []) + check_error(ValueError, ['/usr', 'usr']) + check_error(ValueError, ['usr', '/usr']) + + check(['/usr/local'], '/usr/local') + check(['/usr/local', '/usr/local'], '/usr/local') + check(['/usr/local/', '/usr/local'], '/usr/local') + check(['/usr/local/', '/usr/local/'], '/usr/local') + check(['/usr//local', '//usr/local'], '/usr/local') + check(['/usr/./local', '/./usr/local'], '/usr/local') + check(['/', '/dev'], '/') + check(['/usr', '/dev'], '/') + check(['/usr/lib/', '/usr/lib/python3'], '/usr/lib') + check(['/usr/lib/', '/usr/lib64/'], '/usr') + + check(['/usr/lib', '/usr/lib64'], '/usr') + check(['/usr/lib/', '/usr/lib64'], '/usr') + + check(['spam'], 'spam') + check(['spam', 'spam'], 'spam') + check(['spam', 'alot'], '') + check(['and/jam', 'and/spam'], 'and') + check(['and//jam', 'and/spam//'], 'and') + check(['and/./jam', './and/spam'], 'and') + check(['and/jam', 'and/spam', 'alot'], '') + check(['and/jam', 'and/spam', 'and'], 'and') + + check([''], '') + check(['', 'spam/alot'], '') + check_error(ValueError, ['', '/spam/alot']) + + self.assertRaises(TypeError, posixpath.commonpath, + [b'/usr/lib/', '/usr/lib/python3']) + self.assertRaises(TypeError, posixpath.commonpath, + [b'/usr/lib/', 'usr/lib/python3']) + self.assertRaises(TypeError, posixpath.commonpath, + [b'usr/lib/', '/usr/lib/python3']) + self.assertRaises(TypeError, posixpath.commonpath, + ['/usr/lib/', b'/usr/lib/python3']) + self.assertRaises(TypeError, posixpath.commonpath, + ['/usr/lib/', b'usr/lib/python3']) + self.assertRaises(TypeError, posixpath.commonpath, + ['usr/lib/', b'/usr/lib/python3']) + class PosixCommonTest(test_genericpath.CommonTest, unittest.TestCase): pathmodule = posixpath diff --git a/Lib/test/test_pow.py b/Lib/test/test_pow.py index 20b1066..6feac40 100644 --- a/Lib/test/test_pow.py +++ b/Lib/test/test_pow.py @@ -122,8 +122,5 @@ class PowTest(unittest.TestCase): eq(pow(a, -fiveto), expected) eq(expected, 1.0) # else we didn't push fiveto to evenness -def test_main(): - test.support.run_unittest(PowTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pprint.py b/Lib/test/test_pprint.py index 180ddb0..357c5cf 100644 --- a/Lib/test/test_pprint.py +++ b/Lib/test/test_pprint.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- +import collections +import io +import itertools import pprint +import random import test.support -import unittest import test.test_set -import random -import collections -import itertools +import types +import unittest # list, tuple and dict subclasses that do or don't overwrite __repr__ class list2(list): @@ -48,6 +50,25 @@ class Unorderable: def __repr__(self): return str(id(self)) +# Class Orderable is orderable with any type +class Orderable: + def __init__(self, hash): + self._hash = hash + def __lt__(self, other): + return False + def __gt__(self, other): + return self != other + def __le__(self, other): + return self == other + def __ge__(self, other): + return True + def __eq__(self, other): + return self is other + def __ne__(self, other): + return self is not other + def __hash__(self): + return self._hash + class QueryTestCase(unittest.TestCase): def setUp(self): @@ -55,6 +76,18 @@ class QueryTestCase(unittest.TestCase): self.b = list(range(200)) self.a[-12] = self.b + def test_init(self): + pp = pprint.PrettyPrinter() + pp = pprint.PrettyPrinter(indent=4, width=40, depth=5, + stream=io.StringIO(), compact=True) + pp = pprint.PrettyPrinter(4, 40, 5, io.StringIO()) + with self.assertRaises(TypeError): + pp = pprint.PrettyPrinter(4, 40, 5, io.StringIO(), True) + self.assertRaises(ValueError, pprint.PrettyPrinter, indent=-1) + self.assertRaises(ValueError, pprint.PrettyPrinter, depth=0) + self.assertRaises(ValueError, pprint.PrettyPrinter, depth=-1) + self.assertRaises(ValueError, pprint.PrettyPrinter, width=0) + def test_basic(self): # Verify .isrecursive() and .isreadable() w/o recursion pp = pprint.PrettyPrinter() @@ -195,10 +228,52 @@ class QueryTestCase(unittest.TestCase): o = [o1, o2] expected = """\ [ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + {'first': 1, 'second': 2, 'third': 3}]""" + self.assertEqual(pprint.pformat(o, indent=4, width=42), expected) + expected = """\ +[ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], { 'first': 1, 'second': 2, 'third': 3}]""" - self.assertEqual(pprint.pformat(o, indent=4, width=42), expected) + self.assertEqual(pprint.pformat(o, indent=4, width=41), expected) + + def test_width(self): + expected = """\ +[[[[[[1, 2, 3], + '1 2']]]], + {1: [1, 2, 3], + 2: [12, 34]}, + 'abc def ghi', + ('ab cd ef',), + set2({1, 23}), + [[[[[1, 2, 3], + '1 2']]]]]""" + o = eval(expected) + self.assertEqual(pprint.pformat(o, width=15), expected) + self.assertEqual(pprint.pformat(o, width=16), expected) + self.assertEqual(pprint.pformat(o, width=25), expected) + self.assertEqual(pprint.pformat(o, width=14), """\ +[[[[[[1, + 2, + 3], + '1 ' + '2']]]], + {1: [1, + 2, + 3], + 2: [12, + 34]}, + 'abc def ' + 'ghi', + ('ab cd ' + 'ef',), + set2({1, + 23}), + [[[[[1, + 2, + 3], + '1 ' + '2']]]]]""") def test_sorted_dict(self): # Starting in Python 2.5, pprint sorts dict displays by key regardless @@ -219,19 +294,51 @@ class QueryTestCase(unittest.TestCase): r"{5: [[]], 'xy\tab\n': (3,), (): {}}") def test_ordered_dict(self): + d = collections.OrderedDict() + self.assertEqual(pprint.pformat(d, width=1), 'OrderedDict()') + d = collections.OrderedDict([]) + self.assertEqual(pprint.pformat(d, width=1), 'OrderedDict()') words = 'the quick brown fox jumped over a lazy dog'.split() d = collections.OrderedDict(zip(words, itertools.count())) self.assertEqual(pprint.pformat(d), """\ -{'the': 0, - 'quick': 1, - 'brown': 2, - 'fox': 3, - 'jumped': 4, - 'over': 5, - 'a': 6, - 'lazy': 7, - 'dog': 8}""") +OrderedDict([('the', 0), + ('quick', 1), + ('brown', 2), + ('fox', 3), + ('jumped', 4), + ('over', 5), + ('a', 6), + ('lazy', 7), + ('dog', 8)])""") + + def test_mapping_proxy(self): + words = 'the quick brown fox jumped over a lazy dog'.split() + d = dict(zip(words, itertools.count())) + m = types.MappingProxyType(d) + self.assertEqual(pprint.pformat(m), """\ +mappingproxy({'a': 6, + 'brown': 2, + 'dog': 8, + 'fox': 3, + 'jumped': 4, + 'lazy': 7, + 'over': 5, + 'quick': 1, + 'the': 0})""") + d = collections.OrderedDict(zip(words, itertools.count())) + m = types.MappingProxyType(d) + self.assertEqual(pprint.pformat(m), """\ +mappingproxy(OrderedDict([('the', 0), + ('quick', 1), + ('brown', 2), + ('fox', 3), + ('jumped', 4), + ('over', 5), + ('a', 6), + ('lazy', 7), + ('dog', 8)]))""") + def test_subclassing(self): o = {'names with spaces': 'should be presented using repr()', 'others.should.not.be': 'like.this'} @@ -535,16 +642,35 @@ frozenset2({0, self.assertEqual(pprint.pformat(dict.fromkeys(keys, 0)), '{%r: 0, %r: 0}' % tuple(sorted(keys, key=id))) + def test_sort_orderable_and_unorderable_values(self): + # Issue 22721: sorted pprints is not stable + a = Unorderable() + b = Orderable(hash(a)) # should have the same hash value + # self-test + self.assertLess(a, b) + self.assertLess(str(type(b)), str(type(a))) + self.assertEqual(sorted([b, a]), [a, b]) + self.assertEqual(sorted([a, b]), [a, b]) + # set + self.assertEqual(pprint.pformat(set([b, a]), width=1), + '{%r,\n %r}' % (a, b)) + self.assertEqual(pprint.pformat(set([a, b]), width=1), + '{%r,\n %r}' % (a, b)) + # dict + self.assertEqual(pprint.pformat(dict.fromkeys([b, a]), width=1), + '{%r: None,\n %r: None}' % (a, b)) + self.assertEqual(pprint.pformat(dict.fromkeys([a, b]), width=1), + '{%r: None,\n %r: None}' % (a, b)) + def test_str_wrap(self): # pprint tries to wrap strings intelligently fox = 'the quick brown fox jumped over a lazy dog' - self.assertEqual(pprint.pformat(fox, width=20), """\ -('the quick ' - 'brown fox ' - 'jumped over a ' - 'lazy dog')""") + self.assertEqual(pprint.pformat(fox, width=19), """\ +('the quick brown ' + 'fox jumped over ' + 'a lazy dog')""") self.assertEqual(pprint.pformat({'a': 1, 'b': fox, 'c': 2}, - width=26), """\ + width=25), """\ {'a': 1, 'b': 'the quick brown ' 'fox jumped over ' @@ -556,12 +682,34 @@ frozenset2({0, # - non-ASCII is allowed # - an apostrophe doesn't disrupt the pprint special = "Portons dix bons \"whiskys\"\nà l'avocat goujat\t qui fumait au zoo" - self.assertEqual(pprint.pformat(special, width=21), """\ -('Portons dix ' - 'bons "whiskys"\\n' + self.assertEqual(pprint.pformat(special, width=68), repr(special)) + self.assertEqual(pprint.pformat(special, width=31), """\ +('Portons dix bons "whiskys"\\n' + "à l'avocat goujat\\t qui " + 'fumait au zoo')""") + self.assertEqual(pprint.pformat(special, width=20), """\ +('Portons dix bons ' + '"whiskys"\\n' "à l'avocat " 'goujat\\t qui ' 'fumait au zoo')""") + self.assertEqual(pprint.pformat([[[[[special]]]]], width=35), """\ +[[[[['Portons dix bons "whiskys"\\n' + "à l'avocat goujat\\t qui " + 'fumait au zoo']]]]]""") + self.assertEqual(pprint.pformat([[[[[special]]]]], width=25), """\ +[[[[['Portons dix bons ' + '"whiskys"\\n' + "à l'avocat " + 'goujat\\t qui ' + 'fumait au zoo']]]]]""") + self.assertEqual(pprint.pformat([[[[[special]]]]], width=23), """\ +[[[[['Portons dix ' + 'bons "whiskys"\\n' + "à l'avocat " + 'goujat\\t qui ' + 'fumait au ' + 'zoo']]]]]""") # An unwrappable string is formatted as its repr unwrappable = "x" * 100 self.assertEqual(pprint.pformat(unwrappable, width=80), repr(unwrappable)) @@ -584,7 +732,267 @@ frozenset2({0, 14, 15], [], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]""" - self.assertEqual(pprint.pformat(o, width=48, compact=True), expected) + self.assertEqual(pprint.pformat(o, width=47, compact=True), expected) + + def test_compact_width(self): + levels = 20 + number = 10 + o = [0] * number + for i in range(levels - 1): + o = [o] + for w in range(levels * 2 + 1, levels + 3 * number - 1): + lines = pprint.pformat(o, width=w, compact=True).splitlines() + maxwidth = max(map(len, lines)) + self.assertLessEqual(maxwidth, w) + self.assertGreater(maxwidth, w - 3) + + def test_bytes_wrap(self): + self.assertEqual(pprint.pformat(b'', width=1), "b''") + self.assertEqual(pprint.pformat(b'abcd', width=1), "b'abcd'") + letters = b'abcdefghijklmnopqrstuvwxyz' + self.assertEqual(pprint.pformat(letters, width=29), repr(letters)) + self.assertEqual(pprint.pformat(letters, width=19), """\ +(b'abcdefghijkl' + b'mnopqrstuvwxyz')""") + self.assertEqual(pprint.pformat(letters, width=18), """\ +(b'abcdefghijkl' + b'mnopqrstuvwx' + b'yz')""") + self.assertEqual(pprint.pformat(letters, width=16), """\ +(b'abcdefghijkl' + b'mnopqrstuvwx' + b'yz')""") + special = bytes(range(16)) + self.assertEqual(pprint.pformat(special, width=61), repr(special)) + self.assertEqual(pprint.pformat(special, width=48), """\ +(b'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b' + b'\\x0c\\r\\x0e\\x0f')""") + self.assertEqual(pprint.pformat(special, width=32), """\ +(b'\\x00\\x01\\x02\\x03' + b'\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b' + b'\\x0c\\r\\x0e\\x0f')""") + self.assertEqual(pprint.pformat(special, width=1), """\ +(b'\\x00\\x01\\x02\\x03' + b'\\x04\\x05\\x06\\x07' + b'\\x08\\t\\n\\x0b' + b'\\x0c\\r\\x0e\\x0f')""") + self.assertEqual(pprint.pformat({'a': 1, 'b': letters, 'c': 2}, + width=21), """\ +{'a': 1, + 'b': b'abcdefghijkl' + b'mnopqrstuvwx' + b'yz', + 'c': 2}""") + self.assertEqual(pprint.pformat({'a': 1, 'b': letters, 'c': 2}, + width=20), """\ +{'a': 1, + 'b': b'abcdefgh' + b'ijklmnop' + b'qrstuvwxyz', + 'c': 2}""") + self.assertEqual(pprint.pformat([[[[[[letters]]]]]], width=25), """\ +[[[[[[b'abcdefghijklmnop' + b'qrstuvwxyz']]]]]]""") + self.assertEqual(pprint.pformat([[[[[[special]]]]]], width=41), """\ +[[[[[[b'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07' + b'\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f']]]]]]""") + # Check that the pprint is a usable repr + for width in range(1, 64): + formatted = pprint.pformat(special, width=width) + self.assertEqual(eval(formatted), special) + formatted = pprint.pformat([special] * 2, width=width) + self.assertEqual(eval(formatted), [special] * 2) + + def test_bytearray_wrap(self): + self.assertEqual(pprint.pformat(bytearray(), width=1), "bytearray(b'')") + letters = bytearray(b'abcdefghijklmnopqrstuvwxyz') + self.assertEqual(pprint.pformat(letters, width=40), repr(letters)) + self.assertEqual(pprint.pformat(letters, width=28), """\ +bytearray(b'abcdefghijkl' + b'mnopqrstuvwxyz')""") + self.assertEqual(pprint.pformat(letters, width=27), """\ +bytearray(b'abcdefghijkl' + b'mnopqrstuvwx' + b'yz')""") + self.assertEqual(pprint.pformat(letters, width=25), """\ +bytearray(b'abcdefghijkl' + b'mnopqrstuvwx' + b'yz')""") + special = bytearray(range(16)) + self.assertEqual(pprint.pformat(special, width=72), repr(special)) + self.assertEqual(pprint.pformat(special, width=57), """\ +bytearray(b'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b' + b'\\x0c\\r\\x0e\\x0f')""") + self.assertEqual(pprint.pformat(special, width=41), """\ +bytearray(b'\\x00\\x01\\x02\\x03' + b'\\x04\\x05\\x06\\x07\\x08\\t\\n\\x0b' + b'\\x0c\\r\\x0e\\x0f')""") + self.assertEqual(pprint.pformat(special, width=1), """\ +bytearray(b'\\x00\\x01\\x02\\x03' + b'\\x04\\x05\\x06\\x07' + b'\\x08\\t\\n\\x0b' + b'\\x0c\\r\\x0e\\x0f')""") + self.assertEqual(pprint.pformat({'a': 1, 'b': letters, 'c': 2}, + width=31), """\ +{'a': 1, + 'b': bytearray(b'abcdefghijkl' + b'mnopqrstuvwx' + b'yz'), + 'c': 2}""") + self.assertEqual(pprint.pformat([[[[[letters]]]]], width=37), """\ +[[[[[bytearray(b'abcdefghijklmnop' + b'qrstuvwxyz')]]]]]""") + self.assertEqual(pprint.pformat([[[[[special]]]]], width=50), """\ +[[[[[bytearray(b'\\x00\\x01\\x02\\x03\\x04\\x05\\x06\\x07' + b'\\x08\\t\\n\\x0b\\x0c\\r\\x0e\\x0f')]]]]]""") + + def test_default_dict(self): + d = collections.defaultdict(int) + self.assertEqual(pprint.pformat(d, width=1), "defaultdict(<class 'int'>, {})") + words = 'the quick brown fox jumped over a lazy dog'.split() + d = collections.defaultdict(int, zip(words, itertools.count())) + self.assertEqual(pprint.pformat(d), +"""\ +defaultdict(<class 'int'>, + {'a': 6, + 'brown': 2, + 'dog': 8, + 'fox': 3, + 'jumped': 4, + 'lazy': 7, + 'over': 5, + 'quick': 1, + 'the': 0})""") + + def test_counter(self): + d = collections.Counter() + self.assertEqual(pprint.pformat(d, width=1), "Counter()") + d = collections.Counter('senselessness') + self.assertEqual(pprint.pformat(d, width=40), +"""\ +Counter({'s': 6, + 'e': 4, + 'n': 2, + 'l': 1})""") + + def test_chainmap(self): + d = collections.ChainMap() + self.assertEqual(pprint.pformat(d, width=1), "ChainMap({})") + words = 'the quick brown fox jumped over a lazy dog'.split() + items = list(zip(words, itertools.count())) + d = collections.ChainMap(dict(items)) + self.assertEqual(pprint.pformat(d), +"""\ +ChainMap({'a': 6, + 'brown': 2, + 'dog': 8, + 'fox': 3, + 'jumped': 4, + 'lazy': 7, + 'over': 5, + 'quick': 1, + 'the': 0})""") + d = collections.ChainMap(dict(items), collections.OrderedDict(items)) + self.assertEqual(pprint.pformat(d), +"""\ +ChainMap({'a': 6, + 'brown': 2, + 'dog': 8, + 'fox': 3, + 'jumped': 4, + 'lazy': 7, + 'over': 5, + 'quick': 1, + 'the': 0}, + OrderedDict([('the', 0), + ('quick', 1), + ('brown', 2), + ('fox', 3), + ('jumped', 4), + ('over', 5), + ('a', 6), + ('lazy', 7), + ('dog', 8)]))""") + + def test_deque(self): + d = collections.deque() + self.assertEqual(pprint.pformat(d, width=1), "deque([])") + d = collections.deque(maxlen=7) + self.assertEqual(pprint.pformat(d, width=1), "deque([], maxlen=7)") + words = 'the quick brown fox jumped over a lazy dog'.split() + d = collections.deque(zip(words, itertools.count())) + self.assertEqual(pprint.pformat(d), +"""\ +deque([('the', 0), + ('quick', 1), + ('brown', 2), + ('fox', 3), + ('jumped', 4), + ('over', 5), + ('a', 6), + ('lazy', 7), + ('dog', 8)])""") + d = collections.deque(zip(words, itertools.count()), maxlen=7) + self.assertEqual(pprint.pformat(d), +"""\ +deque([('brown', 2), + ('fox', 3), + ('jumped', 4), + ('over', 5), + ('a', 6), + ('lazy', 7), + ('dog', 8)], + maxlen=7)""") + + def test_user_dict(self): + d = collections.UserDict() + self.assertEqual(pprint.pformat(d, width=1), "{}") + words = 'the quick brown fox jumped over a lazy dog'.split() + d = collections.UserDict(zip(words, itertools.count())) + self.assertEqual(pprint.pformat(d), +"""\ +{'a': 6, + 'brown': 2, + 'dog': 8, + 'fox': 3, + 'jumped': 4, + 'lazy': 7, + 'over': 5, + 'quick': 1, + 'the': 0}""") + + def test_user_dict(self): + d = collections.UserList() + self.assertEqual(pprint.pformat(d, width=1), "[]") + words = 'the quick brown fox jumped over a lazy dog'.split() + d = collections.UserList(zip(words, itertools.count())) + self.assertEqual(pprint.pformat(d), +"""\ +[('the', 0), + ('quick', 1), + ('brown', 2), + ('fox', 3), + ('jumped', 4), + ('over', 5), + ('a', 6), + ('lazy', 7), + ('dog', 8)]""") + + def test_user_string(self): + d = collections.UserString('') + self.assertEqual(pprint.pformat(d, width=1), "''") + d = collections.UserString('the quick brown fox jumped over a lazy dog') + self.assertEqual(pprint.pformat(d, width=20), +"""\ +('the quick brown ' + 'fox jumped over ' + 'a lazy dog')""") + self.assertEqual(pprint.pformat({1: d}, width=20), +"""\ +{1: 'the quick ' + 'brown fox ' + 'jumped over a ' + 'lazy dog'}""") class DottedPrettyPrinter(pprint.PrettyPrinter): diff --git a/Lib/test/test_property.py b/Lib/test/test_property.py index cee7203..5addd36 100644 --- a/Lib/test/test_property.py +++ b/Lib/test/test_property.py @@ -3,7 +3,6 @@ import sys import unittest -from test.support import run_unittest class PropertyBase(Exception): pass @@ -77,6 +76,13 @@ class PropertyNewGetter(object): """new docstring""" return 8 +class PropertyWritableDoc(object): + + @property + def spam(self): + """Eggs""" + return "eggs" + class PropertyTests(unittest.TestCase): def test_property_decorator_baseclass(self): # see #1620 @@ -151,6 +157,21 @@ class PropertyTests(unittest.TestCase): foo = property(foo) C.foo.__isabstractmethod__ + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_property_builtin_doc_writable(self): + p = property(doc='basic') + self.assertEqual(p.__doc__, 'basic') + p.__doc__ = 'extended' + self.assertEqual(p.__doc__, 'extended') + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def test_property_decorator_doc_writable(self): + sub = PropertyWritableDoc() + self.assertEqual(sub.__class__.spam.__doc__, 'Eggs') + sub.__class__.spam.__doc__ = 'Spam' + self.assertEqual(sub.__class__.spam.__doc__, 'Spam') # Issue 5890: subclasses of property do not preserve method __doc__ strings class PropertySub(property): @@ -247,8 +268,5 @@ class PropertySubclassTests(unittest.TestCase): -def test_main(): - run_unittest(PropertyTests, PropertySubclassTests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_pstats.py b/Lib/test/test_pstats.py index 9ebeebb..566b3ea 100644 --- a/Lib/test/test_pstats.py +++ b/Lib/test/test_pstats.py @@ -34,12 +34,5 @@ class StatsTestCase(unittest.TestCase): stats.add(self.stats, self.stats) -def test_main(): - support.run_unittest( - AddCallersTestCase, - StatsTestCase, - ) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py index 8916861..ef5e99e 100644 --- a/Lib/test/test_pty.py +++ b/Lib/test/test_pty.py @@ -1,7 +1,6 @@ -from test.support import verbose, run_unittest, import_module, reap_children +from test.support import verbose, import_module, reap_children -#Skip these tests if either fcntl or termios is not available -fcntl = import_module('fcntl') +# Skip these tests if termios is not available import_module('termios') import errno @@ -84,16 +83,18 @@ class PtyTest(unittest.TestCase): # in master_open(), we need to read the EOF. # Ensure the fd is non-blocking in case there's nothing to read. - orig_flags = fcntl.fcntl(master_fd, fcntl.F_GETFL) - fcntl.fcntl(master_fd, fcntl.F_SETFL, orig_flags | os.O_NONBLOCK) + blocking = os.get_blocking(master_fd) try: - s1 = os.read(master_fd, 1024) - self.assertEqual(b'', s1) - except OSError as e: - if e.errno != errno.EAGAIN: - raise - # Restore the original flags. - fcntl.fcntl(master_fd, fcntl.F_SETFL, orig_flags) + os.set_blocking(master_fd, False) + try: + s1 = os.read(master_fd, 1024) + self.assertEqual(b'', s1) + except OSError as e: + if e.errno != errno.EAGAIN: + raise + finally: + # Restore the original flags. + os.set_blocking(master_fd, blocking) debug("Writing to slave_fd") os.write(slave_fd, TEST_STRING_1) @@ -292,11 +293,8 @@ class SmallPtyTests(unittest.TestCase): pty._copy(masters[0]) -def test_main(verbose=None): - try: - run_unittest(SmallPtyTests, PtyTest) - finally: - reap_children() +def tearDownModule(): + reap_children() if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pulldom.py b/Lib/test/test_pulldom.py index b81a595..1932c6b 100644 --- a/Lib/test/test_pulldom.py +++ b/Lib/test/test_pulldom.py @@ -6,7 +6,7 @@ import xml.sax from xml.sax.xmlreader import AttributesImpl from xml.dom import pulldom -from test.support import run_unittest, findfile +from test.support import findfile tstfile = findfile("test.xml", subdir="xmltestdata") @@ -339,9 +339,5 @@ class SAX2DOMTestCase(unittest.TestCase): doc.unlink() -def test_main(): - run_unittest(PullDOMTestCase, ThoroughTestCase, SAX2DOMTestCase) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pwd.py b/Lib/test/test_pwd.py index 37a1bcb..b7b1a4a 100644 --- a/Lib/test/test_pwd.py +++ b/Lib/test/test_pwd.py @@ -107,8 +107,5 @@ class PwdTest(unittest.TestCase): self.assertRaises(KeyError, pwd.getpwuid, 2**128) self.assertRaises(KeyError, pwd.getpwuid, -2**128) -def test_main(): - support.run_unittest(PwdTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_py_compile.py b/Lib/test/test_py_compile.py index c1fd4f6..4a6caa5 100644 --- a/Lib/test/test_py_compile.py +++ b/Lib/test/test_py_compile.py @@ -98,6 +98,7 @@ class PyCompileTests(unittest.TestCase): self.assertFalse(os.path.exists( importlib.util.cache_from_source(bad_coding))) + @unittest.skipIf(sys.flags.optimize > 0, 'test does not work with -O') def test_double_dot_no_clobber(self): # http://bugs.python.org/issue22966 # py_compile foo.bar.py -> __pycache__/foo.cpython-34.pyc @@ -117,6 +118,10 @@ class PyCompileTests(unittest.TestCase): self.assertTrue(os.path.exists(cache_path)) self.assertFalse(os.path.exists(pyc_path)) + def test_optimization_path(self): + # Specifying optimized bytecode should lead to a path reflecting that. + self.assertIn('opt-2', py_compile.compile(self.source_path, optimize=2)) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_pyclbr.py b/Lib/test/test_pyclbr.py index 39eb65f..cab430b 100644 --- a/Lib/test/test_pyclbr.py +++ b/Lib/test/test_pyclbr.py @@ -2,7 +2,6 @@ Test cases for pyclbr.py Nick Mathewson ''' -from test.support import run_unittest import sys from types import FunctionType, MethodType, BuiltinFunctionType import pyclbr @@ -173,9 +172,5 @@ class PyclbrTest(TestCase): self.assertRaises(ImportError, pyclbr.readmodule_ex, 'asyncore.foo') -def test_main(): - run_unittest(PyclbrTest) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_pydoc.py b/Lib/test/test_pydoc.py index 0e990b6..ec5c31b 100644 --- a/Lib/test/test_pydoc.py +++ b/Lib/test/test_pydoc.py @@ -2,7 +2,6 @@ import os import sys import builtins import contextlib -import difflib import importlib.util import inspect import pydoc @@ -22,7 +21,7 @@ import xml.etree import textwrap from io import StringIO from collections import namedtuple -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok from test.support import ( TESTFN, rmtree, reap_children, reap_threads, captured_output, captured_stdout, @@ -257,7 +256,10 @@ expected_html_data_docstrings = tuple(s.replace(' ', ' ') for s in expected_data_docstrings) # output pattern for missing module -missing_pattern = "no Python documentation found for '%s'" +missing_pattern = '''\ +No Python documentation found for %r. +Use help() to get the interactive help utility. +Use help(str) for help on the str class.'''.replace('\n', os.linesep) # output pattern for module with bad imports badimport_pattern = "problem in %s - ImportError: No module named %r" @@ -364,15 +366,6 @@ def get_pydoc_text(module): output = patt.sub('', output) return output.strip(), loc -def print_diffs(text1, text2): - "Prints unified diffs for two texts" - # XXX now obsolete, use unittest built-in support - lines1 = text1.splitlines(keepends=True) - lines2 = text2.splitlines(keepends=True) - diffs = difflib.unified_diff(lines1, lines2, n=0, fromfile='expected', - tofile='got') - print('\n' + ''.join(diffs)) - def get_html_title(text): # Bit of hack, but good enough for test purposes header, _, _ = text.partition("</head>") @@ -418,9 +411,7 @@ class PydocDocTest(unittest.TestCase): expected_html = expected_html_pattern % ( (mod_url, mod_file, doc_loc) + expected_html_data_docstrings) - if result != expected_html: - print_diffs(expected_html, result) - self.fail("outputs are not equal, see diff above") + self.assertEqual(result, expected_html) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") @@ -433,9 +424,7 @@ class PydocDocTest(unittest.TestCase): (doc_loc,) + expected_text_data_docstrings + (inspect.getabsfile(pydoc_mod),)) - if result != expected_text: - print_diffs(expected_text, result) - self.fail("outputs are not equal, see diff above") + self.assertEqual(expected_text, result) def test_text_enum_member_with_value_zero(self): # Test issue #20654 to ensure enum member with value 0 can be @@ -931,9 +920,7 @@ class PydocWithMetaClasses(unittest.TestCase): expected_text = expected_dynamicattribute_pattern % ( (__name__,) + expected_text_data_docstrings[:2]) result = output.getvalue().strip() - if result != expected_text: - print_diffs(expected_text, result) - self.fail("outputs are not equal, see diff above") + self.assertEqual(expected_text, result) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") @@ -954,9 +941,7 @@ class PydocWithMetaClasses(unittest.TestCase): helper(Class) expected_text = expected_virtualattribute_pattern1 % __name__ result = output.getvalue().strip() - if result != expected_text: - print_diffs(expected_text, result) - self.fail("outputs are not equal, see diff above") + self.assertEqual(expected_text, result) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") @@ -996,19 +981,13 @@ class PydocWithMetaClasses(unittest.TestCase): helper(Class1) expected_text1 = expected_virtualattribute_pattern2 % __name__ result1 = output.getvalue().strip() - if result1 != expected_text1: - print_diffs(expected_text1, result1) - fail1 = True + self.assertEqual(expected_text1, result1) output = StringIO() helper = pydoc.Helper(output=output) helper(Class2) expected_text2 = expected_virtualattribute_pattern3 % __name__ result2 = output.getvalue().strip() - if result2 != expected_text2: - print_diffs(expected_text2, result2) - fail2 = True - if fail1 or fail2: - self.fail("outputs are not equal, see diff above") + self.assertEqual(expected_text2, result2) @unittest.skipIf(sys.flags.optimize >= 2, "Docstrings are omitted with -O2 and above") @@ -1025,9 +1004,7 @@ class PydocWithMetaClasses(unittest.TestCase): helper(C) expected_text = expected_missingattribute_pattern % __name__ result = output.getvalue().strip() - if result != expected_text: - print_diffs(expected_text, result) - self.fail("outputs are not equal, see diff above") + self.assertEqual(expected_text, result) def test_resolve_false(self): # Issue #23008: pydoc enum.{,Int}Enum failed diff --git a/Lib/test/test_pyexpat.py b/Lib/test/test_pyexpat.py index 216a46b..550aebf 100644 --- a/Lib/test/test_pyexpat.py +++ b/Lib/test/test_pyexpat.py @@ -11,7 +11,7 @@ import traceback from xml.parsers import expat from xml.parsers.expat import errors -from test.support import sortdict, run_unittest +from test.support import sortdict class SetAttributeTest(unittest.TestCase): @@ -737,19 +737,5 @@ class ForeignDTDTests(unittest.TestCase): self.assertEqual(handler_call_args, [("bar", "baz")]) -def test_main(): - run_unittest(SetAttributeTest, - ParseTest, - NamespaceSeparatorTest, - InterningTest, - BufferTextTest, - HandlerExceptionTest, - PositionTest, - sf1296433Test, - ChardataBufferTest, - MalformedInputTest, - ErrorMessageTest, - ForeignDTDTests) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py index 2cdfee4..4ccaa39 100644 --- a/Lib/test/test_queue.py +++ b/Lib/test/test_queue.py @@ -354,10 +354,5 @@ class FailingQueueTest(BlockingTestMixin, unittest.TestCase): self.failing_queue_test(q) -def test_main(): - support.run_unittest(QueueTest, LifoQueueTest, PriorityQueueTest, - FailingQueueTest) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_quopri.py b/Lib/test/test_quopri.py index 92511fa..7cac013 100644 --- a/Lib/test/test_quopri.py +++ b/Lib/test/test_quopri.py @@ -1,4 +1,3 @@ -from test import support import unittest import sys, os, io, subprocess @@ -207,9 +206,5 @@ zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz''') p = p.decode('latin-1') self.assertEqual(cout.splitlines(), p.splitlines()) -def test_main(): - support.run_unittest(QuopriTestCase) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_raise.py b/Lib/test/test_raise.py index be5c1c6..a41b353 100644 --- a/Lib/test/test_raise.py +++ b/Lib/test/test_raise.py @@ -415,8 +415,5 @@ class TestRemovedFunctionality(unittest.TestCase): self.fail("No exception raised") -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_range.py b/Lib/test/test_range.py index 2dbcebc..106c732 100644 --- a/Lib/test/test_range.py +++ b/Lib/test/test_range.py @@ -647,8 +647,5 @@ class RangeTest(unittest.TestCase): with self.assertRaises(AttributeError): del rangeobj.step -def test_main(): - test.support.run_unittest(RangeTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index 7348af3..7a74141 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -38,6 +38,24 @@ class ReTests(unittest.TestCase): self.assertIs(type(actual), type(expect), msg) recurse(actual, expect) + def checkPatternError(self, pattern, errmsg, pos=None): + with self.assertRaises(re.error) as cm: + re.compile(pattern) + with self.subTest(pattern=pattern): + err = cm.exception + self.assertEqual(err.msg, errmsg) + if pos is not None: + self.assertEqual(err.pos, pos) + + def checkTemplateError(self, pattern, repl, string, errmsg, pos=None): + with self.assertRaises(re.error) as cm: + re.sub(pattern, repl, string) + with self.subTest(pattern=pattern, repl=repl): + err = cm.exception + self.assertEqual(err.msg, errmsg) + if pos is not None: + self.assertEqual(err.pos, pos) + def test_keep_buffer(self): # See bug 14212 b = bytearray(b'x') @@ -84,7 +102,7 @@ class ReTests(unittest.TestCase): self.assertEqual(re.sub("(?i)b+", "x", "bbbb BBBB"), 'x x') self.assertEqual(re.sub(r'\d+', self.bump_num, '08.2 -2 23x99y'), '9.3 -3 24x100y') - self.assertEqual(re.sub(r'\d+', self.bump_num, '08.2 -2 23x99y', 3), + self.assertEqual(re.sub(r'\d+', self.bump_num, '08.2 -2 23x99y', count=3), '9.3 -3 23x99y') self.assertEqual(re.sub('.', lambda m: r"\n", 'x'), '\\n') @@ -100,11 +118,14 @@ class ReTests(unittest.TestCase): self.assertEqual(re.sub('(?P<unk>x)', '\g<unk>\g<unk>', 'xx'), 'xxxx') self.assertEqual(re.sub('(?P<unk>x)', '\g<1>\g<1>', 'xx'), 'xxxx') - self.assertEqual(re.sub('a',r'\t\n\v\r\f\a\b\B\Z\a\A\w\W\s\S\d\D','a'), - '\t\n\v\r\f\a\b\\B\\Z\a\\A\\w\\W\\s\\S\\d\\D') - self.assertEqual(re.sub('a', '\t\n\v\r\f\a', 'a'), '\t\n\v\r\f\a') - self.assertEqual(re.sub('a', '\t\n\v\r\f\a', 'a'), - (chr(9)+chr(10)+chr(11)+chr(13)+chr(12)+chr(7))) + self.assertEqual(re.sub('a', r'\t\n\v\r\f\a\b', 'a'), '\t\n\v\r\f\a\b') + self.assertEqual(re.sub('a', '\t\n\v\r\f\a\b', 'a'), '\t\n\v\r\f\a\b') + self.assertEqual(re.sub('a', '\t\n\v\r\f\a\b', 'a'), + (chr(9)+chr(10)+chr(11)+chr(13)+chr(12)+chr(7)+chr(8))) + for c in 'cdehijklmopqsuwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ': + with self.subTest(c): + with self.assertWarns(DeprecationWarning): + self.assertEqual(re.sub('a', '\\' + c, 'a'), '\\' + c) self.assertEqual(re.sub('^\s*', 'X', 'test'), 'Xtest') @@ -145,6 +166,7 @@ class ReTests(unittest.TestCase): self.assertEqual(re.sub('x', r'\009', 'x'), '\0' + '9') self.assertEqual(re.sub('x', r'\111', 'x'), '\111') self.assertEqual(re.sub('x', r'\117', 'x'), '\117') + self.assertEqual(re.sub('x', r'\377', 'x'), '\377') self.assertEqual(re.sub('x', r'\1111', 'x'), '\1111') self.assertEqual(re.sub('x', r'\1111', 'x'), '\111' + '1') @@ -155,21 +177,25 @@ class ReTests(unittest.TestCase): self.assertEqual(re.sub('x', r'\09', 'x'), '\0' + '9') self.assertEqual(re.sub('x', r'\0a', 'x'), '\0' + 'a') - self.assertEqual(re.sub('x', r'\400', 'x'), '\0') - self.assertEqual(re.sub('x', r'\777', 'x'), '\377') - - self.assertRaises(re.error, re.sub, 'x', r'\1', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\8', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\9', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\11', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\18', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\1a', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\90', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\99', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\118', 'x') # r'\11' + '8' - self.assertRaises(re.error, re.sub, 'x', r'\11a', 'x') - self.assertRaises(re.error, re.sub, 'x', r'\181', 'x') # r'\18' + '1' - self.assertRaises(re.error, re.sub, 'x', r'\800', 'x') # r'\80' + '0' + self.checkTemplateError('x', r'\400', 'x', + r'octal escape value \400 outside of ' + r'range 0-0o377', 0) + self.checkTemplateError('x', r'\777', 'x', + r'octal escape value \777 outside of ' + r'range 0-0o377', 0) + + self.checkTemplateError('x', r'\1', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\8', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\9', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\11', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\18', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\1a', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\90', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\99', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\118', 'x', 'invalid group reference') # r'\11' + '8' + self.checkTemplateError('x', r'\11a', 'x', 'invalid group reference') + self.checkTemplateError('x', r'\181', 'x', 'invalid group reference') # r'\18' + '1' + self.checkTemplateError('x', r'\800', 'x', 'invalid group reference') # r'\80' + '0' # in python2.3 (etc), these loop endlessly in sre_parser.py self.assertEqual(re.sub('(((((((((((x)))))))))))', r'\11', 'x'), 'x') @@ -180,7 +206,7 @@ class ReTests(unittest.TestCase): def test_qualified_re_sub(self): self.assertEqual(re.sub('a', 'b', 'aaaaa'), 'bbbbb') - self.assertEqual(re.sub('a', 'b', 'aaaaa', 1), 'baaaa') + self.assertEqual(re.sub('a', 'b', 'aaaaa', count=1), 'baaaa') def test_bug_114660(self): self.assertEqual(re.sub(r'(\S)\s+(\S)', r'\1 \2', 'hello there'), @@ -194,75 +220,105 @@ class ReTests(unittest.TestCase): def test_symbolic_groups(self): re.compile('(?P<a>x)(?P=a)(?(a)y)') re.compile('(?P<a1>x)(?P=a1)(?(a1)y)') - self.assertRaises(re.error, re.compile, '(?P<a>)(?P<a>)') - self.assertRaises(re.error, re.compile, '(?Px)') - self.assertRaises(re.error, re.compile, '(?P=)') - self.assertRaises(re.error, re.compile, '(?P=1)') - self.assertRaises(re.error, re.compile, '(?P=a)') - self.assertRaises(re.error, re.compile, '(?P=a1)') - self.assertRaises(re.error, re.compile, '(?P=a.)') - self.assertRaises(re.error, re.compile, '(?P<)') - self.assertRaises(re.error, re.compile, '(?P<>)') - self.assertRaises(re.error, re.compile, '(?P<1>)') - self.assertRaises(re.error, re.compile, '(?P<a.>)') - self.assertRaises(re.error, re.compile, '(?())') - self.assertRaises(re.error, re.compile, '(?(a))') - self.assertRaises(re.error, re.compile, '(?(1a))') - self.assertRaises(re.error, re.compile, '(?(a.))') + re.compile('(?P<a1>x)\1(?(1)y)') + self.checkPatternError('(?P<a>)(?P<a>)', + "redefinition of group name 'a' as group 2; " + "was group 1") + self.checkPatternError('(?P<a>(?P=a))', + "cannot refer to an open group", 10) + self.checkPatternError('(?Pxy)', 'unknown extension ?Px') + self.checkPatternError('(?P<a>)(?P=a', 'missing ), unterminated name', 11) + self.checkPatternError('(?P=', 'missing group name', 4) + self.checkPatternError('(?P=)', 'missing group name', 4) + self.checkPatternError('(?P=1)', "bad character in group name '1'", 4) + self.checkPatternError('(?P=a)', "unknown group name 'a'") + self.checkPatternError('(?P=a1)', "unknown group name 'a1'") + self.checkPatternError('(?P=a.)', "bad character in group name 'a.'", 4) + self.checkPatternError('(?P<)', 'missing >, unterminated name', 4) + self.checkPatternError('(?P<a', 'missing >, unterminated name', 4) + self.checkPatternError('(?P<', 'missing group name', 4) + self.checkPatternError('(?P<>)', 'missing group name', 4) + self.checkPatternError(r'(?P<1>)', "bad character in group name '1'", 4) + self.checkPatternError(r'(?P<a.>)', "bad character in group name 'a.'", 4) + self.checkPatternError(r'(?(', 'missing group name', 3) + self.checkPatternError(r'(?())', 'missing group name', 3) + self.checkPatternError(r'(?(a))', "unknown group name 'a'", 3) + self.checkPatternError(r'(?(-1))', "bad character in group name '-1'", 3) + self.checkPatternError(r'(?(1a))', "bad character in group name '1a'", 3) + self.checkPatternError(r'(?(a.))', "bad character in group name 'a.'", 3) # New valid/invalid identifiers in Python 3 re.compile('(?P<µ>x)(?P=µ)(?(µ)y)') re.compile('(?P<𝔘𝔫𝔦𝔠𝔬𝔡𝔢>x)(?P=𝔘𝔫𝔦𝔠𝔬𝔡𝔢)(?(𝔘𝔫𝔦𝔠𝔬𝔡𝔢)y)') - self.assertRaises(re.error, re.compile, '(?P<©>x)') + self.checkPatternError('(?P<©>x)', "bad character in group name '©'", 4) + # Support > 100 groups. + pat = '|'.join('x(?P<a%d>%x)y' % (i, i) for i in range(1, 200 + 1)) + pat = '(?:%s)(?(200)z|t)' % pat + self.assertEqual(re.match(pat, 'xc8yz').span(), (0, 5)) def test_symbolic_refs(self): - self.assertRaises(re.error, re.sub, '(?P<a>x)', '\g<a', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)', '\g<', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)', '\g', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)', '\g<a a>', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)', '\g<>', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)', '\g<1a1>', 'xx') - self.assertRaises(IndexError, re.sub, '(?P<a>x)', '\g<ab>', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)|(?P<b>y)', '\g<b>', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)|(?P<b>y)', '\\2', 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)', '\g<-1>', 'xx') + self.checkTemplateError('(?P<a>x)', '\g<a', 'xx', + 'missing >, unterminated name', 3) + self.checkTemplateError('(?P<a>x)', '\g<', 'xx', + 'missing group name', 3) + self.checkTemplateError('(?P<a>x)', '\g', 'xx', 'missing <', 2) + self.checkTemplateError('(?P<a>x)', '\g<a a>', 'xx', + "bad character in group name 'a a'", 3) + self.checkTemplateError('(?P<a>x)', '\g<>', 'xx', + 'missing group name', 3) + self.checkTemplateError('(?P<a>x)', '\g<1a1>', 'xx', + "bad character in group name '1a1'", 3) + self.checkTemplateError('(?P<a>x)', r'\g<2>', 'xx', + 'invalid group reference') + self.checkTemplateError('(?P<a>x)', r'\2', 'xx', + 'invalid group reference') + with self.assertRaisesRegex(IndexError, "unknown group name 'ab'"): + re.sub('(?P<a>x)', '\g<ab>', 'xx') + self.assertEqual(re.sub('(?P<a>x)|(?P<b>y)', r'\g<b>', 'xx'), '') + self.assertEqual(re.sub('(?P<a>x)|(?P<b>y)', r'\2', 'xx'), '') + self.checkTemplateError('(?P<a>x)', '\g<-1>', 'xx', + "bad character in group name '-1'", 3) # New valid/invalid identifiers in Python 3 self.assertEqual(re.sub('(?P<µ>x)', r'\g<µ>', 'xx'), 'xx') self.assertEqual(re.sub('(?P<𝔘𝔫𝔦𝔠𝔬𝔡𝔢>x)', r'\g<𝔘𝔫𝔦𝔠𝔬𝔡𝔢>', 'xx'), 'xx') - self.assertRaises(re.error, re.sub, '(?P<a>x)', r'\g<©>', 'xx') + self.checkTemplateError('(?P<a>x)', '\g<©>', 'xx', + "bad character in group name '©'", 3) + # Support > 100 groups. + pat = '|'.join('x(?P<a%d>%x)y' % (i, i) for i in range(1, 200 + 1)) + self.assertEqual(re.sub(pat, '\g<200>', 'xc8yzxc8y'), 'c8zc8') def test_re_subn(self): self.assertEqual(re.subn("(?i)b+", "x", "bbbb BBBB"), ('x x', 2)) self.assertEqual(re.subn("b+", "x", "bbbb BBBB"), ('x BBBB', 1)) self.assertEqual(re.subn("b+", "x", "xyz"), ('xyz', 0)) self.assertEqual(re.subn("b*", "x", "xyz"), ('xxxyxzx', 4)) - self.assertEqual(re.subn("b*", "x", "xyz", 2), ('xxxyz', 2)) + self.assertEqual(re.subn("b*", "x", "xyz", count=2), ('xxxyz', 2)) def test_re_split(self): for string in ":a:b::c", S(":a:b::c"): self.assertTypedEqual(re.split(":", string), ['', 'a', 'b', '', 'c']) - self.assertTypedEqual(re.split(":*", string), + self.assertTypedEqual(re.split(":+", string), ['', 'a', 'b', 'c']) - self.assertTypedEqual(re.split("(:*)", string), + self.assertTypedEqual(re.split("(:+)", string), ['', ':', 'a', ':', 'b', '::', 'c']) for string in (b":a:b::c", B(b":a:b::c"), bytearray(b":a:b::c"), memoryview(b":a:b::c")): self.assertTypedEqual(re.split(b":", string), [b'', b'a', b'b', b'', b'c']) - self.assertTypedEqual(re.split(b":*", string), + self.assertTypedEqual(re.split(b":+", string), [b'', b'a', b'b', b'c']) - self.assertTypedEqual(re.split(b"(:*)", string), + self.assertTypedEqual(re.split(b"(:+)", string), [b'', b':', b'a', b':', b'b', b'::', b'c']) for a, b, c in ("\xe0\xdf\xe7", "\u0430\u0431\u0432", "\U0001d49c\U0001d49e\U0001d4b5"): string = ":%s:%s::%s" % (a, b, c) self.assertEqual(re.split(":", string), ['', a, b, '', c]) - self.assertEqual(re.split(":*", string), ['', a, b, c]) - self.assertEqual(re.split("(:*)", string), + self.assertEqual(re.split(":+", string), ['', a, b, c]) + self.assertEqual(re.split("(:+)", string), ['', ':', a, ':', b, '::', c]) - self.assertEqual(re.split("(?::*)", ":a:b::c"), ['', 'a', 'b', 'c']) - self.assertEqual(re.split("(:)*", ":a:b::c"), + self.assertEqual(re.split("(?::+)", ":a:b::c"), ['', 'a', 'b', 'c']) + self.assertEqual(re.split("(:)+", ":a:b::c"), ['', ':', 'a', ':', 'b', ':', 'c']) self.assertEqual(re.split("([b:]+)", ":a:b::c"), ['', ':', 'a', ':b::', 'c']) @@ -272,13 +328,34 @@ class ReTests(unittest.TestCase): self.assertEqual(re.split("(?:b)|(?::+)", ":a:b::c"), ['', 'a', '', '', 'c']) + for sep, expected in [ + (':*', ['', 'a', 'b', 'c']), + ('(?::*)', ['', 'a', 'b', 'c']), + ('(:*)', ['', ':', 'a', ':', 'b', '::', 'c']), + ('(:)*', ['', ':', 'a', ':', 'b', ':', 'c']), + ]: + with self.subTest(sep=sep), self.assertWarns(FutureWarning): + self.assertTypedEqual(re.split(sep, ':a:b::c'), expected) + + for sep, expected in [ + ('', [':a:b::c']), + (r'\b', [':a:b::c']), + (r'(?=:)', [':a:b::c']), + (r'(?<=:)', [':a:b::c']), + ]: + with self.subTest(sep=sep), self.assertRaises(ValueError): + self.assertTypedEqual(re.split(sep, ':a:b::c'), expected) + def test_qualified_re_split(self): - self.assertEqual(re.split(":", ":a:b::c", 2), ['', 'a', 'b::c']) - self.assertEqual(re.split(':', 'a:b:c:d', 2), ['a', 'b', 'c:d']) - self.assertEqual(re.split("(:)", ":a:b::c", 2), + self.assertEqual(re.split(":", ":a:b::c", maxsplit=2), ['', 'a', 'b::c']) + self.assertEqual(re.split(':', 'a:b:c:d', maxsplit=2), ['a', 'b', 'c:d']) + self.assertEqual(re.split("(:)", ":a:b::c", maxsplit=2), ['', ':', 'a', ':', 'b::c']) - self.assertEqual(re.split("(:*)", ":a:b::c", 2), + self.assertEqual(re.split("(:+)", ":a:b::c", maxsplit=2), ['', ':', 'a', ':', 'b::c']) + with self.assertWarns(FutureWarning): + self.assertEqual(re.split("(:*)", ":a:b::c", maxsplit=2), + ['', ':', 'a', ':', 'b::c']) def test_re_findall(self): self.assertEqual(re.findall(":+", "abc"), []) @@ -405,6 +482,23 @@ class ReTests(unittest.TestCase): self.assertIsNone(p.match('abd')) self.assertIsNone(p.match('ac')) + # Support > 100 groups. + pat = '|'.join('x(?P<a%d>%x)y' % (i, i) for i in range(1, 200 + 1)) + pat = '(?:%s)(?(200)z)' % pat + self.assertEqual(re.match(pat, 'xc8yz').span(), (0, 5)) + + self.checkPatternError(r'(?P<a>)(?(0))', 'bad group number', 10) + self.checkPatternError(r'()(?(1)a|b', + 'missing ), unterminated subpattern', 2) + self.checkPatternError(r'()(?(1)a|b|c)', + 'conditional backref with more than ' + 'two branches', 10) + + def test_re_groupref_overflow(self): + self.checkTemplateError('()', '\g<%s>' % sre_constants.MAXGROUPS, 'xx', + 'invalid group reference', 3) + self.checkPatternError(r'(?P<a>)(?(%d))' % sre_constants.MAXGROUPS, + 'invalid group reference', 10) def test_re_groupref(self): self.assertEqual(re.match(r'^(\|)?([^()]+)\1$', '|a|').groups(), @@ -418,6 +512,8 @@ class ReTests(unittest.TestCase): self.assertEqual(re.match(r'^(?:(a)|c)(\1)?$', 'c').groups(), (None, None)) + self.checkPatternError(r'(abc\1)', 'cannot refer to an open group', 4) + def test_groupdict(self): self.assertEqual(re.match('(?P<first>first) (?P<second>second)', 'first second').groupdict(), @@ -428,6 +524,10 @@ class ReTests(unittest.TestCase): "first second") .expand(r"\2 \1 \g<second> \g<first>"), "second first second first") + self.assertEqual(re.match("(?P<first>first)|(?P<second>second)", + "first") + .expand(r"\2 \g<second>"), + " ") def test_repeat_minmax(self): self.assertIsNone(re.match("^(\w){1}$", "abc")) @@ -451,6 +551,7 @@ class ReTests(unittest.TestCase): self.assertTrue(re.match("^x{3}$", "xxx")) self.assertTrue(re.match("^x{1,3}$", "xxx")) + self.assertTrue(re.match("^x{3,3}$", "xxx")) self.assertTrue(re.match("^x{1,4}$", "xxx")) self.assertTrue(re.match("^x{3,4}?$", "xxx")) self.assertTrue(re.match("^x{3}?$", "xxx")) @@ -461,6 +562,9 @@ class ReTests(unittest.TestCase): self.assertIsNone(re.match("^x{}$", "xxx")) self.assertTrue(re.match("^x{}$", "x{}")) + self.checkPatternError(r'x{2,1}', + 'min repeat greater than max repeat', 2) + def test_getattr(self): self.assertEqual(re.compile("(?i)(a)(b)").pattern, "(?i)(a)(b)") self.assertEqual(re.compile("(?i)(a)(b)").flags, re.I | re.U) @@ -475,6 +579,14 @@ class ReTests(unittest.TestCase): self.assertEqual(re.match("(a)", "a").regs, ((0, 1), (0, 1))) self.assertTrue(re.match("(a)", "a").re) + # Issue 14260. groupindex should be non-modifiable mapping. + p = re.compile(r'(?i)(?P<first>a)(?P<other>b)') + self.assertEqual(sorted(p.groupindex), ['first', 'other']) + self.assertEqual(p.groupindex['other'], 2) + with self.assertRaises(TypeError): + p.groupindex['other'] = 0 + self.assertEqual(p.groupindex['other'], 2) + def test_special_escapes(self): self.assertEqual(re.search(r"\b(b.)\b", "abcd abc bcd bx").group(1), "bx") @@ -484,10 +596,6 @@ class ReTests(unittest.TestCase): "abcd abc bcd bx", re.ASCII).group(1), "bx") self.assertEqual(re.search(r"\B(b.)\B", "abc bcd bc abxd", re.ASCII).group(1), "bx") - self.assertEqual(re.search(r"\b(b.)\b", - "abcd abc bcd bx", re.LOCALE).group(1), "bx") - self.assertEqual(re.search(r"\B(b.)\B", - "abc bcd bc abxd", re.LOCALE).group(1), "bx") self.assertEqual(re.search(r"^abc$", "\nabc\n", re.M).group(0), "abc") self.assertEqual(re.search(r"^\Aabc\Z$", "abc", re.M).group(0), "abc") self.assertIsNone(re.search(r"^\Aabc\Z$", "\nabc\n", re.M)) @@ -508,11 +616,32 @@ class ReTests(unittest.TestCase): b"1aa! a").group(0), b"1aa! a") self.assertEqual(re.search(r"\d\D\w\W\s\S", "1aa! a", re.ASCII).group(0), "1aa! a") - self.assertEqual(re.search(r"\d\D\w\W\s\S", - "1aa! a", re.LOCALE).group(0), "1aa! a") self.assertEqual(re.search(br"\d\D\w\W\s\S", b"1aa! a", re.LOCALE).group(0), b"1aa! a") + def test_other_escapes(self): + self.checkPatternError("\\", 'bad escape (end of pattern)', 0) + self.assertEqual(re.match(r"\(", '(').group(), '(') + self.assertIsNone(re.match(r"\(", ')')) + self.assertEqual(re.match(r"\\", '\\').group(), '\\') + self.assertEqual(re.match(r"[\]]", ']').group(), ']') + self.assertIsNone(re.match(r"[\]]", '[')) + self.assertEqual(re.match(r"[a\-c]", '-').group(), '-') + self.assertIsNone(re.match(r"[a\-c]", 'b')) + self.assertEqual(re.match(r"[\^a]+", 'a^').group(), 'a^') + self.assertIsNone(re.match(r"[\^a]+", 'b')) + re.purge() # for warnings + for c in 'ceghijklmopqyzCEFGHIJKLMNOPQRTVXY': + with self.subTest(c): + with self.assertWarns(DeprecationWarning): + self.assertEqual(re.fullmatch('\\%c' % c, c).group(), c) + self.assertIsNone(re.match('\\%c' % c, 'a')) + for c in 'ceghijklmopqyzABCEFGHIJKLMNOPQRTVXYZ': + with self.subTest(c): + with self.assertWarns(DeprecationWarning): + self.assertEqual(re.fullmatch('[\\%c]' % c, c).group(), c) + self.assertIsNone(re.match('[\\%c]' % c, 'a')) + def test_string_boundaries(self): # See http://bugs.python.org/issue10713 self.assertEqual(re.search(r"\b(abc)\b", "abc").group(1), @@ -574,9 +703,6 @@ class ReTests(unittest.TestCase): # Group reference. self.assertTrue(re.match(r'(a)b(?=\1)a', 'aba')) self.assertIsNone(re.match(r'(a)b(?=\1)c', 'abac')) - # Named group reference. - self.assertTrue(re.match(r'(?P<g>a)b(?=(?P=g))a', 'aba')) - self.assertIsNone(re.match(r'(?P<g>a)b(?=(?P=g))c', 'abac')) # Conditional group reference. self.assertTrue(re.match(r'(?:(a)|(x))b(?=(?(2)x|c))c', 'abc')) self.assertIsNone(re.match(r'(?:(a)|(x))b(?=(?(2)c|x))c', 'abc')) @@ -594,13 +720,25 @@ class ReTests(unittest.TestCase): self.assertIsNone(re.match(r'ab(?<!b)c', 'abc')) self.assertTrue(re.match(r'ab(?<!c)c', 'abc')) # Group reference. - self.assertWarns(RuntimeWarning, re.compile, r'(a)a(?<=\1)c') - # Named group reference. - self.assertWarns(RuntimeWarning, re.compile, r'(?P<g>a)a(?<=(?P=g))c') + self.assertTrue(re.match(r'(a)a(?<=\1)c', 'aac')) + self.assertIsNone(re.match(r'(a)b(?<=\1)a', 'abaa')) + self.assertIsNone(re.match(r'(a)a(?<!\1)c', 'aac')) + self.assertTrue(re.match(r'(a)b(?<!\1)a', 'abaa')) # Conditional group reference. - self.assertWarns(RuntimeWarning, re.compile, r'(a)b(?<=(?(1)b|x))c') + self.assertIsNone(re.match(r'(?:(a)|(x))b(?<=(?(2)x|c))c', 'abc')) + self.assertIsNone(re.match(r'(?:(a)|(x))b(?<=(?(2)b|x))c', 'abc')) + self.assertTrue(re.match(r'(?:(a)|(x))b(?<=(?(2)x|b))c', 'abc')) + self.assertIsNone(re.match(r'(?:(a)|(x))b(?<=(?(1)c|x))c', 'abc')) + self.assertTrue(re.match(r'(?:(a)|(x))b(?<=(?(1)b|x))c', 'abc')) # Group used before defined. - self.assertWarns(RuntimeWarning, re.compile, r'(a)b(?<=(?(2)b|x))(c)') + self.assertRaises(re.error, re.compile, r'(a)b(?<=(?(2)b|x))(c)') + self.assertIsNone(re.match(r'(a)b(?<=(?(1)c|x))(c)', 'abc')) + self.assertTrue(re.match(r'(a)b(?<=(?(1)b|x))(c)', 'abc')) + # Group defined in the same lookbehind pattern + self.assertRaises(re.error, re.compile, r'(a)b(?<=(.)\2)(c)') + self.assertRaises(re.error, re.compile, r'(a)b(?<=(?P<a>.)(?P=a))(c)') + self.assertRaises(re.error, re.compile, r'(a)b(?<=(a)(?(2)b|x))(c)') + self.assertRaises(re.error, re.compile, r'(a)b(?<=(.)(?<=\2))(c)') def test_ignore_case(self): self.assertEqual(re.match("abc", "ABC", re.I).group(0), "ABC") @@ -692,9 +830,12 @@ class ReTests(unittest.TestCase): self.assertEqual(_sre.getlower(ord('A'), 0), ord('a')) self.assertEqual(_sre.getlower(ord('A'), re.LOCALE), ord('a')) self.assertEqual(_sre.getlower(ord('A'), re.UNICODE), ord('a')) + self.assertEqual(_sre.getlower(ord('A'), re.ASCII), ord('a')) self.assertEqual(re.match("abc", "ABC", re.I).group(0), "ABC") self.assertEqual(re.match(b"abc", b"ABC", re.I).group(0), b"ABC") + self.assertEqual(re.match("abc", "ABC", re.I|re.A).group(0), "ABC") + self.assertEqual(re.match(b"abc", b"ABC", re.I|re.L).group(0), b"ABC") def test_not_literal(self): self.assertEqual(re.search("\s([^a])", " b").group(1), "b") @@ -779,8 +920,10 @@ class ReTests(unittest.TestCase): self.assertEqual(re.X, re.VERBOSE) def test_flags(self): - for flag in [re.I, re.M, re.X, re.S, re.L]: + for flag in [re.I, re.M, re.X, re.S, re.A, re.U]: self.assertTrue(re.compile('^pattern$', flag)) + for flag in [re.I, re.M, re.X, re.S, re.A, re.L]: + self.assertTrue(re.compile(b'^pattern$', flag)) def test_sre_character_literals(self): for i in [0, 8, 16, 32, 64, 127, 128, 255, 256, 0xFFFF, 0x10000, 0x10FFFF]: @@ -802,15 +945,17 @@ class ReTests(unittest.TestCase): self.assertTrue(re.match(r"\08", "\0008")) self.assertTrue(re.match(r"\01", "\001")) self.assertTrue(re.match(r"\018", "\0018")) - self.assertTrue(re.match(r"\567", chr(0o167))) - self.assertRaises(re.error, re.match, r"\911", "") - self.assertRaises(re.error, re.match, r"\x1", "") - self.assertRaises(re.error, re.match, r"\x1z", "") - self.assertRaises(re.error, re.match, r"\u123", "") - self.assertRaises(re.error, re.match, r"\u123z", "") - self.assertRaises(re.error, re.match, r"\U0001234", "") - self.assertRaises(re.error, re.match, r"\U0001234z", "") - self.assertRaises(re.error, re.match, r"\U00110000", "") + self.checkPatternError(r"\567", + r'octal escape value \567 outside of ' + r'range 0-0o377', 0) + self.checkPatternError(r"\911", 'invalid group reference', 0) + self.checkPatternError(r"\x1", r'incomplete escape \x1', 0) + self.checkPatternError(r"\x1z", r'incomplete escape \x1', 0) + self.checkPatternError(r"\u123", r'incomplete escape \u123', 0) + self.checkPatternError(r"\u123z", r'incomplete escape \u123', 0) + self.checkPatternError(r"\U0001234", r'incomplete escape \U0001234', 0) + self.checkPatternError(r"\U0001234z", r'incomplete escape \U0001234', 0) + self.checkPatternError(r"\U00110000", r'bad escape \U00110000', 0) def test_sre_character_class_literals(self): for i in [0, 8, 16, 32, 64, 127, 128, 255, 256, 0xFFFF, 0x10000, 0x10FFFF]: @@ -830,12 +975,15 @@ class ReTests(unittest.TestCase): self.assertTrue(re.match(r"[\U%08x]" % i, chr(i))) self.assertTrue(re.match(r"[\U%08x0]" % i, chr(i)+"0")) self.assertTrue(re.match(r"[\U%08xz]" % i, chr(i)+"z")) + self.checkPatternError(r"[\567]", + r'octal escape value \567 outside of ' + r'range 0-0o377', 1) + self.checkPatternError(r"[\911]", r'bad escape \9', 1) + self.checkPatternError(r"[\x1z]", r'incomplete escape \x1', 1) + self.checkPatternError(r"[\u123z]", r'incomplete escape \u123', 1) + self.checkPatternError(r"[\U0001234z]", r'incomplete escape \U0001234', 1) + self.checkPatternError(r"[\U00110000]", r'bad escape \U00110000', 1) self.assertTrue(re.match(r"[\U0001d49c-\U0001d4b5]", "\U0001d49e")) - self.assertRaises(re.error, re.match, r"[\911]", "") - self.assertRaises(re.error, re.match, r"[\x1z]", "") - self.assertRaises(re.error, re.match, r"[\u123z]", "") - self.assertRaises(re.error, re.match, r"[\U0001234z]", "") - self.assertRaises(re.error, re.match, r"[\U00110000]", "") def test_sre_byte_literals(self): for i in [0, 8, 16, 32, 64, 127, 128, 255]: @@ -845,16 +993,20 @@ class ReTests(unittest.TestCase): self.assertTrue(re.match((r"\x%02x" % i).encode(), bytes([i]))) self.assertTrue(re.match((r"\x%02x0" % i).encode(), bytes([i])+b"0")) self.assertTrue(re.match((r"\x%02xz" % i).encode(), bytes([i])+b"z")) - self.assertTrue(re.match(br"\u", b'u')) - self.assertTrue(re.match(br"\U", b'U')) + with self.assertWarns(DeprecationWarning): + self.assertTrue(re.match(br"\u1234", b'u1234')) + with self.assertWarns(DeprecationWarning): + self.assertTrue(re.match(br"\U00012345", b'U00012345')) self.assertTrue(re.match(br"\0", b"\000")) self.assertTrue(re.match(br"\08", b"\0008")) self.assertTrue(re.match(br"\01", b"\001")) self.assertTrue(re.match(br"\018", b"\0018")) - self.assertTrue(re.match(br"\567", bytes([0o167]))) - self.assertRaises(re.error, re.match, br"\911", b"") - self.assertRaises(re.error, re.match, br"\x1", b"") - self.assertRaises(re.error, re.match, br"\x1z", b"") + self.checkPatternError(br"\567", + r'octal escape value \567 outside of ' + r'range 0-0o377', 0) + self.checkPatternError(br"\911", 'invalid group reference', 0) + self.checkPatternError(br"\x1", r'incomplete escape \x1', 0) + self.checkPatternError(br"\x1z", r'incomplete escape \x1', 0) def test_sre_byte_class_literals(self): for i in [0, 8, 16, 32, 64, 127, 128, 255]: @@ -866,10 +1018,26 @@ class ReTests(unittest.TestCase): self.assertTrue(re.match((r"[\x%02x]" % i).encode(), bytes([i]))) self.assertTrue(re.match((r"[\x%02x0]" % i).encode(), bytes([i]))) self.assertTrue(re.match((r"[\x%02xz]" % i).encode(), bytes([i]))) - self.assertTrue(re.match(br"[\u]", b'u')) - self.assertTrue(re.match(br"[\U]", b'U')) - self.assertRaises(re.error, re.match, br"[\911]", b"") - self.assertRaises(re.error, re.match, br"[\x1z]", b"") + with self.assertWarns(DeprecationWarning): + self.assertTrue(re.match(br"[\u1234]", b'u')) + with self.assertWarns(DeprecationWarning): + self.assertTrue(re.match(br"[\U00012345]", b'U')) + self.checkPatternError(br"[\567]", + r'octal escape value \567 outside of ' + r'range 0-0o377', 1) + self.checkPatternError(br"[\911]", r'bad escape \9', 1) + self.checkPatternError(br"[\x1z]", r'incomplete escape \x1', 1) + + def test_character_set_errors(self): + self.checkPatternError(r'[', 'unterminated character set', 0) + self.checkPatternError(r'[^', 'unterminated character set', 0) + self.checkPatternError(r'[a', 'unterminated character set', 0) + # bug 545855 -- This pattern failed to cause a compile error as it + # should, instead provoking a TypeError. + self.checkPatternError(r"[a-", 'unterminated character set', 0) + self.checkPatternError(r"[\w-b]", r'bad character range \w-b', 1) + self.checkPatternError(r"[a-\w]", r'bad character range a-\w', 1) + self.checkPatternError(r"[b-a]", 'bad character range b-a', 1) def test_bug_113254(self): self.assertEqual(re.match(r'(a)|(b)', 'b').start(1), -1) @@ -884,11 +1052,6 @@ class ReTests(unittest.TestCase): self.assertEqual(re.match("(?P<a>a(b))", "ab").lastgroup, 'a') self.assertEqual(re.match("((a))", "a").lastindex, 1) - def test_bug_545855(self): - # bug 545855 -- This pattern failed to cause a compile error as it - # should, instead provoking a TypeError. - self.assertRaises(re.error, re.compile, 'foo[a-') - def test_bug_418626(self): # bugs 418626 at al. -- Testing Greg Chapman's addition of op code # SRE_OP_MIN_REPEAT_ONE for eliminating recursion on simple uses of @@ -912,6 +1075,24 @@ class ReTests(unittest.TestCase): self.assertEqual(re.match('(x)*y', 50000*'x'+'y').group(1), 'x') self.assertEqual(re.match('(x)*?y', 50000*'x'+'y').group(1), 'x') + def test_nothing_to_repeat(self): + for reps in '*', '+', '?', '{1,2}': + for mod in '', '?': + self.checkPatternError('%s%s' % (reps, mod), + 'nothing to repeat', 0) + self.checkPatternError('(?:%s%s)' % (reps, mod), + 'nothing to repeat', 3) + + def test_multiple_repeat(self): + for outer_reps in '*', '+', '{1,2}': + for outer_mod in '', '?': + outer_op = outer_reps + outer_mod + for inner_reps in '*', '+', '?', '{1,2}': + for inner_mod in '', '?': + inner_op = inner_reps + inner_mod + self.checkPatternError(r'x%s%s' % (inner_op, outer_op), + 'multiple repeat', 1 + len(inner_op)) + def test_unlimited_zero_width_repeat(self): # Issue #9669 self.assertIsNone(re.match(r'(?:a?)*y', 'z')) @@ -1062,8 +1243,8 @@ class ReTests(unittest.TestCase): def test_inline_flags(self): # Bug #1700 - upper_char = chr(0x1ea0) # Latin Capital Letter A with Dot Bellow - lower_char = chr(0x1ea1) # Latin Small Letter A with Dot Bellow + upper_char = '\u1ea0' # Latin Capital Letter A with Dot Below + lower_char = '\u1ea1' # Latin Small Letter A with Dot Below p = re.compile(upper_char, re.I | re.U) q = p.match(lower_char) @@ -1143,6 +1324,52 @@ class ReTests(unittest.TestCase): self.assertRaises(ValueError, re.compile, '(?a)\w', re.UNICODE) self.assertRaises(ValueError, re.compile, '(?au)\w') + def test_locale_flag(self): + import locale + _, enc = locale.getlocale(locale.LC_CTYPE) + # Search non-ASCII letter + for i in range(128, 256): + try: + c = bytes([i]).decode(enc) + sletter = c.lower() + if sletter == c: continue + bletter = sletter.encode(enc) + if len(bletter) != 1: continue + if bletter.decode(enc) != sletter: continue + bpat = re.escape(bytes([i])) + break + except (UnicodeError, TypeError): + pass + else: + bletter = None + bpat = b'A' + # Bytes patterns + pat = re.compile(bpat, re.LOCALE | re.IGNORECASE) + if bletter: + self.assertTrue(pat.match(bletter)) + pat = re.compile(b'(?L)' + bpat, re.IGNORECASE) + if bletter: + self.assertTrue(pat.match(bletter)) + pat = re.compile(bpat, re.IGNORECASE) + if bletter: + self.assertIsNone(pat.match(bletter)) + pat = re.compile(b'\w', re.LOCALE) + if bletter: + self.assertTrue(pat.match(bletter)) + pat = re.compile(b'(?L)\w') + if bletter: + self.assertTrue(pat.match(bletter)) + pat = re.compile(b'\w') + if bletter: + self.assertIsNone(pat.match(bletter)) + # Incompatibilities + self.assertWarns(DeprecationWarning, re.compile, '', re.LOCALE) + self.assertWarns(DeprecationWarning, re.compile, '(?L)') + self.assertWarns(DeprecationWarning, re.compile, b'', re.LOCALE | re.ASCII) + self.assertWarns(DeprecationWarning, re.compile, b'(?L)', re.ASCII) + self.assertWarns(DeprecationWarning, re.compile, b'(?a)', re.LOCALE) + self.assertWarns(DeprecationWarning, re.compile, b'(?aL)') + def test_bug_6509(self): # Replacement strings of both types must parse properly. # all strings @@ -1170,8 +1397,10 @@ class ReTests(unittest.TestCase): # a RuntimeError is raised instead of OverflowError. long_overflow = 2**128 self.assertRaises(TypeError, re.finditer, "a", {}) - self.assertRaises(OverflowError, _sre.compile, "abc", 0, [long_overflow]) - self.assertRaises(TypeError, _sre.compile, {}, 0, []) + with self.assertRaises(OverflowError): + _sre.compile("abc", 0, [long_overflow], 0, [], []) + with self.assertRaises(TypeError): + _sre.compile({}, 0, [], 0, [], []) def test_search_dot_unicode(self): self.assertTrue(re.search("123.*-", '123abc-')) @@ -1193,8 +1422,9 @@ class ReTests(unittest.TestCase): def test_bug_13899(self): # Issue #13899: re pattern r"[\A]" should work like "A" but matches # nothing. Ditto B and Z. - self.assertEqual(re.findall(r'[\A\B\b\C\Z]', 'AB\bCZ'), - ['A', 'B', '\b', 'C', 'Z']) + with self.assertWarns(DeprecationWarning): + self.assertEqual(re.findall(r'[\A\B\b\C\Z]', 'AB\bCZ'), + ['A', 'B', '\b', 'C', 'Z']) @bigmemtest(size=_2G, memuse=1) def test_large_search(self, size): @@ -1253,13 +1483,13 @@ class ReTests(unittest.TestCase): def test_backref_group_name_in_exception(self): # Issue 17341: Poor error message when compiling invalid regex - with self.assertRaisesRegex(sre_constants.error, '<foo>'): - re.compile('(?P=<foo>)') + self.checkPatternError('(?P=<foo>)', + "bad character in group name '<foo>'", 4) def test_group_name_in_exception(self): # Issue 17341: Poor error message when compiling invalid regex - with self.assertRaisesRegex(sre_constants.error, '\?foo'): - re.compile('(?P<?foo>)') + self.checkPatternError('(?P<?foo>)', + "bad character in group name '?foo'", 4) def test_issue17998(self): for reps in '*', '+', '?', '{1}': @@ -1309,22 +1539,22 @@ class ReTests(unittest.TestCase): with captured_stdout() as out: re.compile(pat, re.DEBUG) dump = '''\ -subpattern 1 - literal 46 -subpattern None - branch - in - literal 99 - literal 104 - or - literal 112 - literal 121 -subpattern None - groupref_exists 1 - at at_end - else - literal 58 - literal 32 +SUBPATTERN 1 + LITERAL 46 +SUBPATTERN None + BRANCH + IN + LITERAL 99 + LITERAL 104 + OR + LITERAL 112 + LITERAL 121 +SUBPATTERN None + GROUPREF_EXISTS 1 + AT AT_END + ELSE + LITERAL 58 + LITERAL 32 ''' self.assertEqual(out.getvalue(), dump) # Debug output is output again even a second time (bypassing @@ -1392,6 +1622,55 @@ subpattern None self.assertIsNone(re.match(b'(?Li)\xc5', b'\xe5')) self.assertIsNone(re.match(b'(?Li)\xe5', b'\xc5')) + def test_error(self): + with self.assertRaises(re.error) as cm: + re.compile('(\u20ac))') + err = cm.exception + self.assertIsInstance(err.pattern, str) + self.assertEqual(err.pattern, '(\u20ac))') + self.assertEqual(err.pos, 3) + self.assertEqual(err.lineno, 1) + self.assertEqual(err.colno, 4) + self.assertIn(err.msg, str(err)) + self.assertIn(' at position 3', str(err)) + self.assertNotIn(' at position 3', err.msg) + # Bytes pattern + with self.assertRaises(re.error) as cm: + re.compile(b'(\xa4))') + err = cm.exception + self.assertIsInstance(err.pattern, bytes) + self.assertEqual(err.pattern, b'(\xa4))') + self.assertEqual(err.pos, 3) + # Multiline pattern + with self.assertRaises(re.error) as cm: + re.compile(""" + ( + abc + ) + ) + ( + """, re.VERBOSE) + err = cm.exception + self.assertEqual(err.pos, 77) + self.assertEqual(err.lineno, 5) + self.assertEqual(err.colno, 17) + self.assertIn(err.msg, str(err)) + self.assertIn(' at position 77', str(err)) + self.assertIn('(line 5, column 17)', str(err)) + + def test_misc_errors(self): + self.checkPatternError(r'(', 'missing ), unterminated subpattern', 0) + self.checkPatternError(r'((a|b)', 'missing ), unterminated subpattern', 0) + self.checkPatternError(r'(a|b))', 'unbalanced parenthesis', 5) + self.checkPatternError(r'(?P', 'unexpected end of pattern', 3) + self.checkPatternError(r'(?z)', 'unknown extension ?z', 1) + self.checkPatternError(r'(?iz)', 'unknown flag', 3) + self.checkPatternError(r'(?i', 'missing )', 3) + self.checkPatternError(r'(?#abc', 'missing ), unterminated comment', 0) + self.checkPatternError(r'(?<', 'unexpected end of pattern', 3) + self.checkPatternError(r'(?<>)', 'unknown extension ?<>', 1) + self.checkPatternError(r'(?', 'unexpected end of pattern', 2) + class PatternReprTests(unittest.TestCase): def check(self, pattern, expected): @@ -1436,6 +1715,10 @@ class PatternReprTests(unittest.TestCase): self.check_flags(b'bytes pattern', re.A, "re.compile(b'bytes pattern', re.ASCII)") + def test_locale(self): + self.check_flags(b'bytes pattern', re.L, + "re.compile(b'bytes pattern', re.LOCALE)") + def test_quotes(self): self.check('random "double quoted" pattern', '''re.compile('random "double quoted" pattern')''') @@ -1549,8 +1832,16 @@ class ExternalTests(unittest.TestCase): pass else: with self.subTest('bytes pattern match'): - bpat = re.compile(bpat) - self.assertTrue(bpat.search(bs)) + obj = re.compile(bpat) + self.assertTrue(obj.search(bs)) + + # Try the match with LOCALE enabled, and check that it + # still succeeds. + with self.subTest('locale-sensitive match'): + obj = re.compile(bpat, re.LOCALE) + result = obj.search(bs) + if result is None: + print('=== Fails on locale-sensitive match', t) # Try the match with the search area limited to the extent # of the match and see if it still succeeds. \B will @@ -1568,13 +1859,6 @@ class ExternalTests(unittest.TestCase): obj = re.compile(pattern, re.IGNORECASE) self.assertTrue(obj.search(s)) - # Try the match with LOCALE enabled, and check that it - # still succeeds. - if '(?u)' not in pattern: - with self.subTest('locale-sensitive match'): - obj = re.compile(pattern, re.LOCALE) - self.assertTrue(obj.search(s)) - # Try the match with UNICODE locale enabled, and check # that it still succeeds. with self.subTest('unicode-sensitive match'): diff --git a/Lib/test/test_readline.py b/Lib/test/test_readline.py index 0b2b0a5..35330ab 100644 --- a/Lib/test/test_readline.py +++ b/Lib/test/test_readline.py @@ -2,9 +2,10 @@ Very minimal unittests for parts of the readline module. """ import os +import tempfile import unittest -from test.support import run_unittest, import_module -from test.script_helper import assert_python_ok +from test.support import import_module, unlink +from test.support.script_helper import assert_python_ok # Skip tests if there is no readline module readline = import_module('readline') @@ -42,6 +43,45 @@ class TestHistoryManipulation (unittest.TestCase): self.assertEqual(readline.get_current_history_length(), 1) + @unittest.skipUnless(hasattr(readline, "append_history_file"), + "append_history not available") + def test_write_read_append(self): + hfile = tempfile.NamedTemporaryFile(delete=False) + hfile.close() + hfilename = hfile.name + self.addCleanup(unlink, hfilename) + + # test write-clear-read == nop + readline.clear_history() + readline.add_history("first line") + readline.add_history("second line") + readline.write_history_file(hfilename) + + readline.clear_history() + self.assertEqual(readline.get_current_history_length(), 0) + + readline.read_history_file(hfilename) + self.assertEqual(readline.get_current_history_length(), 2) + self.assertEqual(readline.get_history_item(1), "first line") + self.assertEqual(readline.get_history_item(2), "second line") + + # test append + readline.append_history_file(1, hfilename) + readline.clear_history() + readline.read_history_file(hfilename) + self.assertEqual(readline.get_current_history_length(), 3) + self.assertEqual(readline.get_history_item(1), "first line") + self.assertEqual(readline.get_history_item(2), "second line") + self.assertEqual(readline.get_history_item(3), "second line") + + # test 'no such file' behaviour + os.unlink(hfilename) + with self.assertRaises(FileNotFoundError): + readline.append_history_file(1, hfilename) + + # write_history_file can create the target + readline.write_history_file(hfilename) + class TestReadline(unittest.TestCase): @@ -57,8 +97,5 @@ class TestReadline(unittest.TestCase): self.assertEqual(stdout, b'') -def test_main(): - run_unittest(TestHistoryManipulation, TestReadline) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_reprlib.py b/Lib/test/test_reprlib.py index ae67f06..a51c4d7 100644 --- a/Lib/test/test_reprlib.py +++ b/Lib/test/test_reprlib.py @@ -10,7 +10,7 @@ import importlib import importlib.util import unittest -from test.support import run_unittest, create_empty_file, verbose +from test.support import create_empty_file, verbose from reprlib import repr as r # Don't shadow builtin repr from reprlib import Repr from reprlib import recursive_repr @@ -70,18 +70,18 @@ class ReprTests(unittest.TestCase): eq(r([1, 2, 3, 4, 5, 6, 7]), "[1, 2, 3, 4, 5, 6, ...]") # Sets give up after 6 as well - eq(r(set([])), "set([])") - eq(r(set([1])), "set([1])") - eq(r(set([1, 2, 3])), "set([1, 2, 3])") - eq(r(set([1, 2, 3, 4, 5, 6])), "set([1, 2, 3, 4, 5, 6])") - eq(r(set([1, 2, 3, 4, 5, 6, 7])), "set([1, 2, 3, 4, 5, 6, ...])") + eq(r(set([])), "set()") + eq(r(set([1])), "{1}") + eq(r(set([1, 2, 3])), "{1, 2, 3}") + eq(r(set([1, 2, 3, 4, 5, 6])), "{1, 2, 3, 4, 5, 6}") + eq(r(set([1, 2, 3, 4, 5, 6, 7])), "{1, 2, 3, 4, 5, 6, ...}") # Frozensets give up after 6 as well - eq(r(frozenset([])), "frozenset([])") - eq(r(frozenset([1])), "frozenset([1])") - eq(r(frozenset([1, 2, 3])), "frozenset([1, 2, 3])") - eq(r(frozenset([1, 2, 3, 4, 5, 6])), "frozenset([1, 2, 3, 4, 5, 6])") - eq(r(frozenset([1, 2, 3, 4, 5, 6, 7])), "frozenset([1, 2, 3, 4, 5, 6, ...])") + eq(r(frozenset([])), "frozenset()") + eq(r(frozenset([1])), "frozenset({1})") + eq(r(frozenset([1, 2, 3])), "frozenset({1, 2, 3})") + eq(r(frozenset([1, 2, 3, 4, 5, 6])), "frozenset({1, 2, 3, 4, 5, 6})") + eq(r(frozenset([1, 2, 3, 4, 5, 6, 7])), "frozenset({1, 2, 3, 4, 5, 6, ...})") # collections.deque after 6 eq(r(deque([1, 2, 3, 4, 5, 6, 7])), "deque([1, 2, 3, 4, 5, 6, ...])") @@ -94,7 +94,7 @@ class ReprTests(unittest.TestCase): eq(r(d), "{'alice': 1, 'arthur': 1, 'bob': 2, 'charles': 3, ...}") # array.array after 5. - eq(r(array('i')), "array('i', [])") + eq(r(array('i')), "array('i')") eq(r(array('i', [1])), "array('i', [1])") eq(r(array('i', [1, 2])), "array('i', [1, 2])") eq(r(array('i', [1, 2, 3])), "array('i', [1, 2, 3])") @@ -103,6 +103,20 @@ class ReprTests(unittest.TestCase): eq(r(array('i', [1, 2, 3, 4, 5, 6])), "array('i', [1, 2, 3, 4, 5, ...])") + def test_set_literal(self): + eq = self.assertEqual + eq(r({1}), "{1}") + eq(r({1, 2, 3}), "{1, 2, 3}") + eq(r({1, 2, 3, 4, 5, 6}), "{1, 2, 3, 4, 5, 6}") + eq(r({1, 2, 3, 4, 5, 6, 7}), "{1, 2, 3, 4, 5, 6, ...}") + + def test_frozenset(self): + eq = self.assertEqual + eq(r(frozenset({1})), "frozenset({1})") + eq(r(frozenset({1, 2, 3})), "frozenset({1, 2, 3})") + eq(r(frozenset({1, 2, 3, 4, 5, 6})), "frozenset({1, 2, 3, 4, 5, 6})") + eq(r(frozenset({1, 2, 3, 4, 5, 6, 7})), "frozenset({1, 2, 3, 4, 5, 6, ...})") + def test_numbers(self): eq = self.assertEqual eq(r(123), repr(123)) @@ -123,7 +137,7 @@ class ReprTests(unittest.TestCase): eq(r(i2), expected) i3 = ClassWithFailingRepr() - eq(r(i3), ("<ClassWithFailingRepr instance at %x>"%id(i3))) + eq(r(i3), ("<ClassWithFailingRepr instance at %#x>"%id(i3))) s = r(ClassWithFailingRepr) self.assertTrue(s.startswith("<class ")) @@ -373,11 +387,5 @@ class TestRecursiveRepr(unittest.TestCase): m.append(m) self.assertEqual(repr(m), '<a, b, c, d, e, +++, x, +++>') -def test_main(): - run_unittest(ReprTests) - run_unittest(LongReprTest) - run_unittest(TestRecursiveRepr) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_richcmp.py b/Lib/test/test_richcmp.py index 0b629dc..1582caa 100644 --- a/Lib/test/test_richcmp.py +++ b/Lib/test/test_richcmp.py @@ -228,25 +228,25 @@ class MiscTest(unittest.TestCase): b = UserList() a.append(b) b.append(a) - self.assertRaises(RuntimeError, operator.eq, a, b) - self.assertRaises(RuntimeError, operator.ne, a, b) - self.assertRaises(RuntimeError, operator.lt, a, b) - self.assertRaises(RuntimeError, operator.le, a, b) - self.assertRaises(RuntimeError, operator.gt, a, b) - self.assertRaises(RuntimeError, operator.ge, a, b) + self.assertRaises(RecursionError, operator.eq, a, b) + self.assertRaises(RecursionError, operator.ne, a, b) + self.assertRaises(RecursionError, operator.lt, a, b) + self.assertRaises(RecursionError, operator.le, a, b) + self.assertRaises(RecursionError, operator.gt, a, b) + self.assertRaises(RecursionError, operator.ge, a, b) b.append(17) # Even recursive lists of different lengths are different, # but they cannot be ordered self.assertTrue(not (a == b)) self.assertTrue(a != b) - self.assertRaises(RuntimeError, operator.lt, a, b) - self.assertRaises(RuntimeError, operator.le, a, b) - self.assertRaises(RuntimeError, operator.gt, a, b) - self.assertRaises(RuntimeError, operator.ge, a, b) + self.assertRaises(RecursionError, operator.lt, a, b) + self.assertRaises(RecursionError, operator.le, a, b) + self.assertRaises(RecursionError, operator.gt, a, b) + self.assertRaises(RecursionError, operator.ge, a, b) a.append(17) - self.assertRaises(RuntimeError, operator.eq, a, b) - self.assertRaises(RuntimeError, operator.ne, a, b) + self.assertRaises(RecursionError, operator.eq, a, b) + self.assertRaises(RecursionError, operator.ne, a, b) a.insert(0, 11) b.insert(0, 12) self.assertTrue(not (a == b)) @@ -326,8 +326,5 @@ class ListTest(unittest.TestCase): self.assertIs(op(x, y), True) -def test_main(): - support.run_unittest(VectorTest, NumberTest, MiscTest, DictTest, ListTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_rlcompleter.py b/Lib/test/test_rlcompleter.py index 2da7fce..d37b620 100644 --- a/Lib/test/test_rlcompleter.py +++ b/Lib/test/test_rlcompleter.py @@ -72,6 +72,5 @@ class TestRlcompleter(unittest.TestCase): self.assertEqual(completer.complete('as', 2), 'assert') self.assertEqual(completer.complete('an', 0), 'and') - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_runpy.py b/Lib/test/test_runpy.py index 786b813..4bae949 100644 --- a/Lib/test/test_runpy.py +++ b/Lib/test/test_runpy.py @@ -8,10 +8,10 @@ import tempfile import importlib, importlib.machinery, importlib.util import py_compile from test.support import ( - forget, make_legacy_pyc, run_unittest, unload, verbose, no_tracing, - create_empty_file) -from test.script_helper import ( - make_pkg, make_script, make_zip_pkg, make_zip_script, temp_dir) + forget, make_legacy_pyc, unload, verbose, no_tracing, + create_empty_file, temp_dir) +from test.support.script_helper import ( + make_pkg, make_script, make_zip_pkg, make_zip_script) import runpy @@ -269,7 +269,7 @@ class RunModuleTestCase(unittest.TestCase, CodeExecutionMixin): if verbose > 1: print(ex) # Persist with cleaning up def _fix_ns_for_legacy_pyc(self, ns, alter_sys): - char_to_add = "c" if __debug__ else "o" + char_to_add = "c" ns["__file__"] += char_to_add ns["__cached__"] = ns["__file__"] spec = ns["__spec__"] @@ -673,7 +673,7 @@ class RunPathTestCase(unittest.TestCase, CodeExecutionMixin): script_name = self._make_test_script(script_dir, mod_name, source) zip_name, fname = make_zip_script(script_dir, 'test_zip', script_name) msg = "recursion depth exceeded" - self.assertRaisesRegex(RuntimeError, msg, run_path, zip_name) + self.assertRaisesRegex(RecursionError, msg, run_path, zip_name) def test_encoding(self): with temp_dir() as script_dir: diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py index 90f3016..2411895 100644 --- a/Lib/test/test_sax.py +++ b/Lib/test/test_sax.py @@ -200,6 +200,13 @@ class ParseTest(unittest.TestCase): parseString(s, XMLGenerator(result, 'utf-8')) self.assertEqual(result.getvalue(), xml_str(self.data, 'utf-8')) + def test_parseString_text(self): + encodings = ('us-ascii', 'iso-8859-1', 'utf-8', + 'utf-16', 'utf-16le', 'utf-16be') + for encoding in encodings: + self.check_parseString(xml_str(self.data, encoding)) + self.check_parseString(self.data) + def test_parseString_bytes(self): # UTF-8 is default encoding, US-ASCII is compatible with UTF-8, # UTF-16 is autodetected @@ -306,12 +313,24 @@ class PrepareInputSourceTest(unittest.TestCase): def make_byte_stream(self): return BytesIO(b"This is a byte stream.") + def make_character_stream(self): + return StringIO("This is a character stream.") + def checkContent(self, stream, content): self.assertIsNotNone(stream) self.assertEqual(stream.read(), content) stream.close() + def test_character_stream(self): + # If the source is an InputSource with a character stream, use it. + src = InputSource(self.file) + src.setCharacterStream(self.make_character_stream()) + prep = prepare_input_source(src) + self.assertIsNone(prep.getByteStream()) + self.checkContent(prep.getCharacterStream(), + "This is a character stream.") + def test_byte_stream(self): # If the source is an InputSource that does not have a character # stream but does have a byte stream, use the byte stream. @@ -346,6 +365,14 @@ class PrepareInputSourceTest(unittest.TestCase): self.checkContent(prep.getByteStream(), b"This is a byte stream.") + def test_text_file(self): + # If the source is a text file-like object, use it as a character + # stream. + prep = prepare_input_source(self.make_character_stream()) + self.assertIsNone(prep.getByteStream()) + self.checkContent(prep.getCharacterStream(), + "This is a character stream.") + # ===== XMLGenerator @@ -1025,6 +1052,19 @@ class ExpatReaderTest(XmlTestBase): self.assertEqual(result.getvalue(), xml_test_out) + def test_expat_inpsource_character_stream(self): + parser = create_parser() + result = BytesIO() + xmlgen = XMLGenerator(result) + + parser.setContentHandler(xmlgen) + inpsrc = InputSource() + with open(TEST_XMLFILE, 'rt', encoding='iso-8859-1') as f: + inpsrc.setCharacterStream(f) + parser.parse(inpsrc) + + self.assertEqual(result.getvalue(), xml_test_out) + # ===== IncrementalParser support def test_expat_incremental(self): diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py index b325545..4239b26 100644 --- a/Lib/test/test_scope.py +++ b/Lib/test/test_scope.py @@ -1,7 +1,7 @@ import unittest import weakref -from test.support import check_syntax_error, cpython_only, run_unittest +from test.support import check_syntax_error, cpython_only class ScopeTests(unittest.TestCase): @@ -757,8 +757,5 @@ class ScopeTests(unittest.TestCase): self.assertIsNone(ref()) -def test_main(): - run_unittest(ScopeTests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_script_helper.py b/Lib/test/test_script_helper.py index 372d6a7..a7680f8 100755..100644 --- a/Lib/test/test_script_helper.py +++ b/Lib/test/test_script_helper.py @@ -1,8 +1,8 @@ -"""Unittests for test.script_helper. Who tests the test helper?""" +"""Unittests for test.support.script_helper. Who tests the test helper?""" import subprocess import sys -from test import script_helper +from test.support import script_helper import unittest from unittest import mock @@ -23,21 +23,21 @@ class TestScriptHelper(unittest.TestCase): with self.assertRaises(AssertionError) as error_context: script_helper.assert_python_ok('-c', 'sys.exit(0)') error_msg = str(error_context.exception) - self.assertIn('command line was:', error_msg) + self.assertIn('command line:', error_msg) self.assertIn('sys.exit(0)', error_msg, msg='unexpected command line') def test_assert_python_failure_raises(self): with self.assertRaises(AssertionError) as error_context: script_helper.assert_python_failure('-c', 'import sys; sys.exit(0)') error_msg = str(error_context.exception) - self.assertIn('Process return code is 0,', error_msg) + self.assertIn('Process return code is 0\n', error_msg) self.assertIn('import sys; sys.exit(0)', error_msg, msg='unexpected command line.') @mock.patch('subprocess.Popen') def test_assert_python_isolated_when_env_not_required(self, mock_popen): with mock.patch.object(script_helper, - '_interpreter_requires_environment', + 'interpreter_requires_environment', return_value=False) as mock_ire_func: mock_popen.side_effect = RuntimeError('bail out of unittest') try: @@ -56,7 +56,7 @@ class TestScriptHelper(unittest.TestCase): def test_assert_python_not_isolated_when_env_is_required(self, mock_popen): """Ensure that -I is not passed when the environment is required.""" with mock.patch.object(script_helper, - '_interpreter_requires_environment', + 'interpreter_requires_environment', return_value=True) as mock_ire_func: mock_popen.side_effect = RuntimeError('bail out of unittest') try: @@ -69,7 +69,7 @@ class TestScriptHelper(unittest.TestCase): class TestScriptHelperEnvironment(unittest.TestCase): - """Code coverage for _interpreter_requires_environment().""" + """Code coverage for interpreter_requires_environment().""" def setUp(self): self.assertTrue( @@ -84,22 +84,22 @@ class TestScriptHelperEnvironment(unittest.TestCase): @mock.patch('subprocess.check_call') def test_interpreter_requires_environment_true(self, mock_check_call): mock_check_call.side_effect = subprocess.CalledProcessError('', '') - self.assertTrue(script_helper._interpreter_requires_environment()) - self.assertTrue(script_helper._interpreter_requires_environment()) + self.assertTrue(script_helper.interpreter_requires_environment()) + self.assertTrue(script_helper.interpreter_requires_environment()) self.assertEqual(1, mock_check_call.call_count) @mock.patch('subprocess.check_call') def test_interpreter_requires_environment_false(self, mock_check_call): # The mocked subprocess.check_call fakes a no-error process. - script_helper._interpreter_requires_environment() - self.assertFalse(script_helper._interpreter_requires_environment()) + script_helper.interpreter_requires_environment() + self.assertFalse(script_helper.interpreter_requires_environment()) self.assertEqual(1, mock_check_call.call_count) @mock.patch('subprocess.check_call') def test_interpreter_requires_environment_details(self, mock_check_call): - script_helper._interpreter_requires_environment() - self.assertFalse(script_helper._interpreter_requires_environment()) - self.assertFalse(script_helper._interpreter_requires_environment()) + script_helper.interpreter_requires_environment() + self.assertFalse(script_helper.interpreter_requires_environment()) + self.assertFalse(script_helper.interpreter_requires_environment()) self.assertEqual(1, mock_check_call.call_count) check_call_command = mock_check_call.call_args[0][0] self.assertEqual(sys.executable, check_call_command[0]) diff --git a/Lib/test/test_select.py b/Lib/test/test_select.py index 8f9a1c9..a973f3f 100644 --- a/Lib/test/test_select.py +++ b/Lib/test/test_select.py @@ -75,9 +75,8 @@ class SelectTestCase(unittest.TestCase): a[:] = [F()] * 10 self.assertEqual(select.select([], a, []), ([], a[:5], [])) -def test_main(): - support.run_unittest(SelectTestCase) +def tearDownModule(): support.reap_children() if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_selectors.py b/Lib/test/test_selectors.py index 952fda6..454c17b 100644 --- a/Lib/test/test_selectors.py +++ b/Lib/test/test_selectors.py @@ -9,10 +9,7 @@ from test import support from time import sleep import unittest import unittest.mock -try: - from time import monotonic as time -except ImportError: - from time import time as time +from time import monotonic as time try: import resource except ImportError: @@ -25,7 +22,7 @@ else: def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): with socket.socket(family, type, proto) as l: l.bind((support.HOST, 0)) - l.listen(3) + l.listen() c = socket.socket(family, type, proto) try: c.connect(l.getsockname()) @@ -188,8 +185,8 @@ class BaseSelectorTestCase(unittest.TestCase): s.register(wr, selectors.EVENT_WRITE) s.close() - self.assertRaises(KeyError, s.get_key, rd) - self.assertRaises(KeyError, s.get_key, wr) + self.assertRaises(RuntimeError, s.get_key, rd) + self.assertRaises(RuntimeError, s.get_key, wr) self.assertRaises(KeyError, mapping.__getitem__, rd) self.assertRaises(KeyError, mapping.__getitem__, wr) @@ -258,8 +255,8 @@ class BaseSelectorTestCase(unittest.TestCase): sel.register(rd, selectors.EVENT_READ) sel.register(wr, selectors.EVENT_WRITE) - self.assertRaises(KeyError, s.get_key, rd) - self.assertRaises(KeyError, s.get_key, wr) + self.assertRaises(RuntimeError, s.get_key, rd) + self.assertRaises(RuntimeError, s.get_key, wr) def test_fileno(self): s = self.SELECTOR() @@ -360,7 +357,35 @@ class BaseSelectorTestCase(unittest.TestCase): @unittest.skipUnless(hasattr(signal, "alarm"), "signal.alarm() required for this test") - def test_select_interrupt(self): + def test_select_interrupt_exc(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + class InterruptSelect(Exception): + pass + + def handler(*args): + raise InterruptSelect + + orig_alrm_handler = signal.signal(signal.SIGALRM, handler) + self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) + self.addCleanup(signal.alarm, 0) + + signal.alarm(1) + + s.register(rd, selectors.EVENT_READ) + t = time() + # select() is interrupted by a signal which raises an exception + with self.assertRaises(InterruptSelect): + s.select(30) + # select() was interrupted before the timeout of 30 seconds + self.assertLess(time() - t, 5.0) + + @unittest.skipUnless(hasattr(signal, "alarm"), + "signal.alarm() required for this test") + def test_select_interrupt_noraise(self): s = self.SELECTOR() self.addCleanup(s.close) @@ -374,8 +399,11 @@ class BaseSelectorTestCase(unittest.TestCase): s.register(rd, selectors.EVENT_READ) t = time() - self.assertFalse(s.select(2)) - self.assertLess(time() - t, 2.5) + # select() is interrupted by a signal, but the signal handler doesn't + # raise an exception, so select() should by retries with a recomputed + # timeout + self.assertFalse(s.select(1.5)) + self.assertGreaterEqual(time() - t, 1.0) class ScalableSelectorMixIn: @@ -455,10 +483,18 @@ class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): SELECTOR = getattr(selectors, 'KqueueSelector', None) +@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'), + "Test needs selectors.DevpollSelector") +class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): + + SELECTOR = getattr(selectors, 'DevpollSelector', None) + + + def test_main(): tests = [DefaultSelectorTestCase, SelectSelectorTestCase, PollSelectorTestCase, EpollSelectorTestCase, - KqueueSelectorTestCase] + KqueueSelectorTestCase, DevpollSelectorTestCase] support.run_unittest(*tests) support.reap_children() diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index f084ebe..54de508 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -931,7 +931,7 @@ class TestBasicOpsString(TestBasicOps, unittest.TestCase): class TestBasicOpsBytes(TestBasicOps, unittest.TestCase): def setUp(self): - self.case = "string set" + self.case = "bytes set" self.values = [b"a", b"b", b"c"] self.set = set(self.values) self.dup = set(self.values) @@ -1742,6 +1742,19 @@ class TestWeirdBugs(unittest.TestCase): s.update(range(100)) list(si) + def test_merge_and_mutate(self): + class X: + def __hash__(self): + return hash(0) + def __eq__(self, o): + other.clear() + return False + + other = set() + other = {X() for i in range(10)} + s = {0} + s.update(other) + # Application tests (based on David Eppstein's graph recipes ==================================== def powerset(U): diff --git a/Lib/test/test_shlex.py b/Lib/test/test_shlex.py index d2809ae..4fafdd4 100644 --- a/Lib/test/test_shlex.py +++ b/Lib/test/test_shlex.py @@ -3,7 +3,6 @@ import shlex import string import unittest -from test import support # The original test data set was from shellwords, by Hartmut Goebel. @@ -195,8 +194,5 @@ if not getattr(shlex, "split", None): if methname.startswith("test") and methname != "testCompat": delattr(ShlexTest, methname) -def test_main(): - support.run_unittest(ShlexTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_shutil.py b/Lib/test/test_shutil.py index 5b4e7e7..522959a 100644 --- a/Lib/test/test_shutil.py +++ b/Lib/test/test_shutil.py @@ -32,6 +32,12 @@ try: except ImportError: BZ2_SUPPORTED = False +try: + import lzma + LZMA_SUPPORTED = True +except ImportError: + LZMA_SUPPORTED = False + TESTFN2 = TESTFN + "2" try: @@ -1208,6 +1214,8 @@ class TestShutil(unittest.TestCase): formats = ['tar', 'gztar', 'zip'] if BZ2_SUPPORTED: formats.append('bztar') + if LZMA_SUPPORTED: + formats.append('xztar') root_dir, base_dir = self._create_files() expected = rlistdir(root_dir) @@ -1644,6 +1652,24 @@ class TestMove(unittest.TestCase): rv = shutil.move(self.src_file, os.path.join(self.dst_dir, 'bar')) self.assertEqual(rv, os.path.join(self.dst_dir, 'bar')) + @mock_rename + def test_move_file_special_function(self): + moved = [] + def _copy(src, dst): + moved.append((src, dst)) + shutil.move(self.src_file, self.dst_dir, copy_function=_copy) + self.assertEqual(len(moved), 1) + + @mock_rename + def test_move_dir_special_function(self): + moved = [] + def _copy(src, dst): + moved.append((src, dst)) + support.create_empty_file(os.path.join(self.src_dir, 'child')) + support.create_empty_file(os.path.join(self.src_dir, 'child1')) + shutil.move(self.src_dir, self.dst_dir, copy_function=_copy) + self.assertEqual(len(moved), 3) + class TestCopyFile(unittest.TestCase): diff --git a/Lib/test/test_signal.py b/Lib/test/test_signal.py index 74f74af..1b80ff0 100644 --- a/Lib/test/test_signal.py +++ b/Lib/test/test_signal.py @@ -1,19 +1,25 @@ import unittest from test import support from contextlib import closing +import enum import gc import pickle import select import signal +import socket import struct import subprocess import traceback import sys, os, time, errno -from test.script_helper import assert_python_ok, spawn_python +from test.support.script_helper import assert_python_ok, spawn_python try: import threading except ImportError: threading = None +try: + import _testcapi +except ImportError: + _testcapi = None class HandlerBCalled(Exception): @@ -39,6 +45,23 @@ def ignoring_eintr(__func, *args, **kwargs): return None +class GenericTests(unittest.TestCase): + + @unittest.skipIf(threading is None, "test needs threading module") + def test_enums(self): + for name in dir(signal): + sig = getattr(signal, name) + if name in {'SIG_DFL', 'SIG_IGN'}: + self.assertIsInstance(sig, signal.Handlers) + elif name in {'SIG_BLOCK', 'SIG_UNBLOCK', 'SIG_SETMASK'}: + self.assertIsInstance(sig, signal.Sigmasks) + elif name.startswith('SIG') and not name.startswith('SIG_'): + self.assertIsInstance(sig, signal.Signals) + elif name.startswith('CTRL_'): + self.assertIsInstance(sig, signal.Signals) + self.assertEqual(sys.platform, "win32") + + @unittest.skipIf(sys.platform == "win32", "Not valid on Windows") class InterProcessSignalTests(unittest.TestCase): MAX_DURATION = 20 # Entire test should last at most 20 sec. @@ -195,6 +218,7 @@ class PosixTests(unittest.TestCase): def test_getsignal(self): hup = signal.signal(signal.SIGHUP, self.trivial_signal_handler) + self.assertIsInstance(hup, signal.Handlers) self.assertEqual(signal.getsignal(signal.SIGHUP), self.trivial_signal_handler) signal.signal(signal.SIGHUP, hup) @@ -229,15 +253,77 @@ class WakeupFDTests(unittest.TestCase): def test_invalid_fd(self): fd = support.make_bad_fd() - self.assertRaises(ValueError, signal.set_wakeup_fd, fd) + self.assertRaises((ValueError, OSError), + signal.set_wakeup_fd, fd) + + def test_invalid_socket(self): + sock = socket.socket() + fd = sock.fileno() + sock.close() + self.assertRaises((ValueError, OSError), + signal.set_wakeup_fd, fd) + + def test_set_wakeup_fd_result(self): + r1, w1 = os.pipe() + self.addCleanup(os.close, r1) + self.addCleanup(os.close, w1) + r2, w2 = os.pipe() + self.addCleanup(os.close, r2) + self.addCleanup(os.close, w2) + + if hasattr(os, 'set_blocking'): + os.set_blocking(w1, False) + os.set_blocking(w2, False) + + signal.set_wakeup_fd(w1) + self.assertEqual(signal.set_wakeup_fd(w2), w1) + self.assertEqual(signal.set_wakeup_fd(-1), w2) + self.assertEqual(signal.set_wakeup_fd(-1), -1) + + def test_set_wakeup_fd_socket_result(self): + sock1 = socket.socket() + self.addCleanup(sock1.close) + sock1.setblocking(False) + fd1 = sock1.fileno() + + sock2 = socket.socket() + self.addCleanup(sock2.close) + sock2.setblocking(False) + fd2 = sock2.fileno() + + signal.set_wakeup_fd(fd1) + self.assertEqual(signal.set_wakeup_fd(fd2), fd1) + self.assertEqual(signal.set_wakeup_fd(-1), fd2) + self.assertEqual(signal.set_wakeup_fd(-1), -1) + + # On Windows, files are always blocking and Windows does not provide a + # function to test if a socket is in non-blocking mode. + @unittest.skipIf(sys.platform == "win32", "tests specific to POSIX") + def test_set_wakeup_fd_blocking(self): + rfd, wfd = os.pipe() + self.addCleanup(os.close, rfd) + self.addCleanup(os.close, wfd) + + # fd must be non-blocking + os.set_blocking(wfd, True) + with self.assertRaises(ValueError) as cm: + signal.set_wakeup_fd(wfd) + self.assertEqual(str(cm.exception), + "the fd %s must be in non-blocking mode" % wfd) + + # non-blocking is ok + os.set_blocking(wfd, False) + signal.set_wakeup_fd(wfd) + signal.set_wakeup_fd(-1) @unittest.skipIf(sys.platform == "win32", "Not valid on Windows") class WakeupSignalTests(unittest.TestCase): + @unittest.skipIf(_testcapi is None, 'need _testcapi') def check_wakeup(self, test_body, *signals, ordered=True): # use a subprocess to have only one thread code = """if 1: - import fcntl + import _testcapi import os import signal import struct @@ -260,10 +346,7 @@ class WakeupSignalTests(unittest.TestCase): signal.signal(signal.SIGALRM, handler) read, write = os.pipe() - for fd in (read, write): - flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) - flags = flags | os.O_NONBLOCK - fcntl.fcntl(fd, fcntl.F_SETFL, flags) + os.set_blocking(write, False) signal.set_wakeup_fd(write) test() @@ -271,21 +354,21 @@ class WakeupSignalTests(unittest.TestCase): os.close(read) os.close(write) - """.format(signals, ordered, test_body) + """.format(tuple(map(int, signals)), ordered, test_body) assert_python_ok('-c', code) + @unittest.skipIf(_testcapi is None, 'need _testcapi') def test_wakeup_write_error(self): # Issue #16105: write() errors in the C signal handler should not # pass silently. # Use a subprocess to have only one thread. code = """if 1: + import _testcapi import errno - import fcntl import os import signal import sys - import time from test.support import captured_stderr def handler(signum, frame): @@ -293,15 +376,13 @@ class WakeupSignalTests(unittest.TestCase): signal.signal(signal.SIGALRM, handler) r, w = os.pipe() - flags = fcntl.fcntl(r, fcntl.F_GETFL, 0) - fcntl.fcntl(r, fcntl.F_SETFL, flags | os.O_NONBLOCK) + os.set_blocking(r, False) # Set wakeup_fd a read-only file descriptor to trigger the error signal.set_wakeup_fd(r) try: with captured_stderr() as err: - signal.alarm(1) - time.sleep(5.0) + _testcapi.raise_signal(signal.SIGALRM) except ZeroDivisionError: # An ignored exception should have been printed out on stderr err = err.getvalue() @@ -312,6 +393,9 @@ class WakeupSignalTests(unittest.TestCase): raise AssertionError(err) else: raise AssertionError("ZeroDivisionError not raised") + + os.close(r) + os.close(w) """ r, w = os.pipe() try: @@ -334,18 +418,28 @@ class WakeupSignalTests(unittest.TestCase): TIMEOUT_FULL = 10 TIMEOUT_HALF = 5 + class InterruptSelect(Exception): + pass + + def handler(signum, frame): + raise InterruptSelect + signal.signal(signal.SIGALRM, handler) + signal.alarm(1) - before_time = time.time() + # We attempt to get a signal during the sleep, # before select is called - time.sleep(TIMEOUT_FULL) - mid_time = time.time() - dt = mid_time - before_time - if dt >= TIMEOUT_HALF: - raise Exception("%s >= %s" % (dt, TIMEOUT_HALF)) + try: + select.select([], [], [], TIMEOUT_FULL) + except InterruptSelect: + pass + else: + raise Exception("select() was not interrupted") + + before_time = time.monotonic() select.select([read], [], [], TIMEOUT_FULL) - after_time = time.time() - dt = after_time - mid_time + after_time = time.monotonic() + dt = after_time - before_time if dt >= TIMEOUT_HALF: raise Exception("%s >= %s" % (dt, TIMEOUT_HALF)) """, signal.SIGALRM) @@ -358,16 +452,23 @@ class WakeupSignalTests(unittest.TestCase): TIMEOUT_FULL = 10 TIMEOUT_HALF = 5 + class InterruptSelect(Exception): + pass + + def handler(signum, frame): + raise InterruptSelect + signal.signal(signal.SIGALRM, handler) + signal.alarm(1) - before_time = time.time() + before_time = time.monotonic() # We attempt to get a signal during the select call try: select.select([read], [], [], TIMEOUT_FULL) - except OSError: + except InterruptSelect: pass else: - raise Exception("OSError not raised") - after_time = time.time() + raise Exception("select() was not interrupted") + after_time = time.monotonic() dt = after_time - before_time if dt >= TIMEOUT_HALF: raise Exception("%s >= %s" % (dt, TIMEOUT_HALF)) @@ -375,9 +476,10 @@ class WakeupSignalTests(unittest.TestCase): def test_signum(self): self.check_wakeup("""def test(): + import _testcapi signal.signal(signal.SIGUSR1, handler) - os.kill(os.getpid(), signal.SIGUSR1) - os.kill(os.getpid(), signal.SIGALRM) + _testcapi.raise_signal(signal.SIGUSR1) + _testcapi.raise_signal(signal.SIGALRM) """, signal.SIGUSR1, signal.SIGALRM) @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'), @@ -391,13 +493,97 @@ class WakeupSignalTests(unittest.TestCase): signal.signal(signum2, handler) signal.pthread_sigmask(signal.SIG_BLOCK, (signum1, signum2)) - os.kill(os.getpid(), signum1) - os.kill(os.getpid(), signum2) + _testcapi.raise_signal(signum1) + _testcapi.raise_signal(signum2) # Unblocking the 2 signals calls the C signal handler twice signal.pthread_sigmask(signal.SIG_UNBLOCK, (signum1, signum2)) """, signal.SIGUSR1, signal.SIGUSR2, ordered=False) +@unittest.skipUnless(hasattr(socket, 'socketpair'), 'need socket.socketpair') +class WakeupSocketSignalTests(unittest.TestCase): + + @unittest.skipIf(_testcapi is None, 'need _testcapi') + def test_socket(self): + # use a subprocess to have only one thread + code = """if 1: + import signal + import socket + import struct + import _testcapi + + signum = signal.SIGINT + signals = (signum,) + + def handler(signum, frame): + pass + + signal.signal(signum, handler) + + read, write = socket.socketpair() + read.setblocking(False) + write.setblocking(False) + signal.set_wakeup_fd(write.fileno()) + + _testcapi.raise_signal(signum) + + data = read.recv(1) + if not data: + raise Exception("no signum written") + raised = struct.unpack('B', data) + if raised != signals: + raise Exception("%r != %r" % (raised, signals)) + + read.close() + write.close() + """ + + assert_python_ok('-c', code) + + @unittest.skipIf(_testcapi is None, 'need _testcapi') + def test_send_error(self): + # Use a subprocess to have only one thread. + if os.name == 'nt': + action = 'send' + else: + action = 'write' + code = """if 1: + import errno + import signal + import socket + import sys + import time + import _testcapi + from test.support import captured_stderr + + signum = signal.SIGINT + + def handler(signum, frame): + pass + + signal.signal(signum, handler) + + read, write = socket.socketpair() + read.setblocking(False) + write.setblocking(False) + + signal.set_wakeup_fd(write.fileno()) + + # Close sockets: send() will fail + read.close() + write.close() + + with captured_stderr() as err: + _testcapi.raise_signal(signum) + + err = err.getvalue() + if ('Exception ignored when trying to {action} to the signal wakeup fd' + not in err): + raise AssertionError(err) + """.format(action=action) + assert_python_ok('-c', code) + + @unittest.skipIf(sys.platform == "win32", "Not valid on Windows") class SiginterruptTest(unittest.TestCase): @@ -418,7 +604,7 @@ class SiginterruptTest(unittest.TestCase): r, w = os.pipe() def handler(signum, frame): - pass + 1 / 0 signal.signal(signal.SIGALRM, handler) if interrupt is not None: @@ -428,18 +614,21 @@ class SiginterruptTest(unittest.TestCase): sys.stdout.flush() # run the test twice - for loop in range(2): - # send a SIGALRM in a second (during the read) - signal.alarm(1) - try: - # blocking call: read from a pipe without data - os.read(r, 1) - except OSError as err: - if err.errno != errno.EINTR: - raise - else: - sys.exit(2) - sys.exit(3) + try: + for loop in range(2): + # send a SIGALRM in a second (during the read) + signal.alarm(1) + try: + # blocking call: read from a pipe without data + os.read(r, 1) + except ZeroDivisionError: + pass + else: + sys.exit(2) + sys.exit(3) + finally: + os.close(r) + os.close(w) """ % (interrupt,) with spawn_python('-c', code) as process: try: @@ -537,8 +726,8 @@ class ItimerTest(unittest.TestCase): signal.signal(signal.SIGVTALRM, self.sig_vtalrm) signal.setitimer(self.itimer, 0.3, 0.2) - start_time = time.time() - while time.time() - start_time < 60.0: + start_time = time.monotonic() + while time.monotonic() - start_time < 60.0: # use up some virtual time by doing real work _ = pow(12345, 67890, 10000019) if signal.getitimer(self.itimer) == (0.0, 0.0): @@ -560,8 +749,8 @@ class ItimerTest(unittest.TestCase): signal.signal(signal.SIGPROF, self.sig_prof) signal.setitimer(self.itimer, 0.2, 0.2) - start_time = time.time() - while time.time() - start_time < 60.0: + start_time = time.monotonic() + while time.monotonic() - start_time < 60.0: # do some work _ = pow(12345, 67890, 10000019) if signal.getitimer(self.itimer) == (0.0, 0.0): @@ -604,6 +793,8 @@ class PendingSignalsTests(unittest.TestCase): signal.pthread_sigmask(signal.SIG_BLOCK, [signum]) os.kill(os.getpid(), signum) pending = signal.sigpending() + for sig in pending: + assert isinstance(sig, signal.Signals), repr(pending) if pending != {signum}: raise Exception('%s != {%s}' % (pending, signum)) try: @@ -660,6 +851,7 @@ class PendingSignalsTests(unittest.TestCase): code = '''if 1: import signal import sys + from signal import Signals def handler(signum, frame): 1/0 @@ -702,6 +894,7 @@ class PendingSignalsTests(unittest.TestCase): def test(signum): signal.alarm(1) received = signal.sigwait([signum]) + assert isinstance(received, signal.Signals), received if received != signum: raise Exception('received %s, not %s' % (received, signum)) ''') @@ -757,35 +950,6 @@ class PendingSignalsTests(unittest.TestCase): signum = signal.SIGALRM self.assertRaises(ValueError, signal.sigtimedwait, [signum], -1.0) - @unittest.skipUnless(hasattr(signal, 'sigwaitinfo'), - 'need signal.sigwaitinfo()') - # Issue #18238: sigwaitinfo() can be interrupted on Linux (raises - # InterruptedError), but not on AIX - @unittest.skipIf(sys.platform.startswith("aix"), - 'signal.sigwaitinfo() cannot be interrupted on AIX') - def test_sigwaitinfo_interrupted(self): - self.wait_helper(signal.SIGUSR1, ''' - def test(signum): - import errno - - hndl_called = True - def alarm_handler(signum, frame): - hndl_called = False - - signal.signal(signal.SIGALRM, alarm_handler) - signal.alarm(1) - try: - signal.sigwaitinfo([signal.SIGUSR1]) - except OSError as e: - if e.errno == errno.EINTR: - if not hndl_called: - raise Exception("SIGALRM handler not called") - else: - raise Exception("Expected EINTR to be raised by sigwaitinfo") - else: - raise Exception("Expected EINTR to be raised by sigwaitinfo") - ''') - @unittest.skipUnless(hasattr(signal, 'sigwait'), 'need signal.sigwait()') @unittest.skipUnless(hasattr(signal, 'pthread_sigmask'), @@ -842,8 +1006,14 @@ class PendingSignalsTests(unittest.TestCase): def kill(signum): os.kill(os.getpid(), signum) + def check_mask(mask): + for sig in mask: + assert isinstance(sig, signal.Signals), repr(sig) + def read_sigmask(): - return signal.pthread_sigmask(signal.SIG_BLOCK, []) + sigmask = signal.pthread_sigmask(signal.SIG_BLOCK, []) + check_mask(sigmask) + return sigmask signum = signal.SIGUSR1 @@ -852,6 +1022,7 @@ class PendingSignalsTests(unittest.TestCase): # Unblock SIGUSR1 (and copy the old mask) to test our signal handler old_mask = signal.pthread_sigmask(signal.SIG_UNBLOCK, [signum]) + check_mask(old_mask) try: kill(signum) except ZeroDivisionError: @@ -861,11 +1032,13 @@ class PendingSignalsTests(unittest.TestCase): # Block and then raise SIGUSR1. The signal is blocked: the signal # handler is not called, and the signal is now pending - signal.pthread_sigmask(signal.SIG_BLOCK, [signum]) + mask = signal.pthread_sigmask(signal.SIG_BLOCK, [signum]) + check_mask(mask) kill(signum) # Check the new mask blocked = read_sigmask() + check_mask(blocked) if signum not in blocked: raise Exception("%s not in %s" % (signum, blocked)) if old_mask ^ blocked != {signum}: @@ -926,15 +1099,8 @@ class PendingSignalsTests(unittest.TestCase): (exitcode, stdout)) -def test_main(): - try: - support.run_unittest(PosixTests, InterProcessSignalTests, - WakeupFDTests, WakeupSignalTests, - SiginterruptTest, ItimerTest, WindowsSignalTests, - PendingSignalsTests) - finally: - support.reap_children() - +def tearDownModule(): + support.reap_children() if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_site.py b/Lib/test/test_site.py index f71cf73..6615080 100644 --- a/Lib/test/test_site.py +++ b/Lib/test/test_site.py @@ -147,7 +147,7 @@ class HelperFunctionsTests(unittest.TestCase): re.escape(os.path.join(pth_dir, pth_fn))) # XXX: ditto previous XXX comment. self.assertRegex(err_out.getvalue(), 'Traceback') - self.assertRegex(err_out.getvalue(), 'TypeError') + self.assertRegex(err_out.getvalue(), 'ValueError') def test_addsitedir(self): # Same tests for test_addpackage since addsitedir() essentially just @@ -235,20 +235,18 @@ class HelperFunctionsTests(unittest.TestCase): # OS X framework builds site.PREFIXES = ['Python.framework'] dirs = site.getsitepackages() - self.assertEqual(len(dirs), 3) + self.assertEqual(len(dirs), 2) wanted = os.path.join('/Library', sysconfig.get_config_var("PYTHONFRAMEWORK"), sys.version[:3], 'site-packages') - self.assertEqual(dirs[2], wanted) + self.assertEqual(dirs[1], wanted) elif os.sep == '/': # OS X non-framwework builds, Linux, FreeBSD, etc - self.assertEqual(len(dirs), 2) + self.assertEqual(len(dirs), 1) wanted = os.path.join('xoxo', 'lib', 'python' + sys.version[:3], 'site-packages') self.assertEqual(dirs[0], wanted) - wanted = os.path.join('xoxo', 'lib', 'site-python') - self.assertEqual(dirs[1], wanted) else: # other platforms self.assertEqual(len(dirs), 2) @@ -357,8 +355,12 @@ class ImportSideEffectTests(unittest.TestCase): stdout, stderr = proc.communicate() self.assertEqual(proc.returncode, 0) os__file__, os__cached__ = stdout.splitlines()[:2] - self.assertTrue(os.path.isabs(os__file__)) - self.assertTrue(os.path.isabs(os__cached__)) + self.assertTrue(os.path.isabs(os__file__), + "expected absolute path, got {}" + .format(os__file__.decode('ascii'))) + self.assertTrue(os.path.isabs(os__cached__), + "expected absolute path, got {}" + .format(os__cached__.decode('ascii'))) def test_no_duplicate_paths(self): # No duplicate paths should exist in sys.path diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py index 1ed71f9..8c4e670 100644 --- a/Lib/test/test_slice.py +++ b/Lib/test/test_slice.py @@ -1,7 +1,6 @@ # tests for slice objects; in particular the indices method. import unittest -from test import support from pickle import loads, dumps import itertools @@ -241,8 +240,5 @@ class SliceTest(unittest.TestCase): self.assertEqual(s.indices(15), t.indices(15)) self.assertNotEqual(id(s), id(t)) -def test_main(): - support.run_unittest(SliceTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_smtpd.py b/Lib/test/test_smtpd.py index 93f14c4..1aa55d2 100644 --- a/Lib/test/test_smtpd.py +++ b/Lib/test/test_smtpd.py @@ -1,4 +1,5 @@ import unittest +import textwrap from test import support, mock_socket import socket import io @@ -7,14 +8,20 @@ import asyncore class DummyServer(smtpd.SMTPServer): - def __init__(self, localaddr, remoteaddr): - smtpd.SMTPServer.__init__(self, localaddr, remoteaddr) + def __init__(self, *args, **kwargs): + smtpd.SMTPServer.__init__(self, *args, **kwargs) self.messages = [] + if self._decode_data: + self.return_status = 'return status' + else: + self.return_status = b'return status' - def process_message(self, peer, mailfrom, rcpttos, data): + def process_message(self, peer, mailfrom, rcpttos, data, **kw): self.messages.append((peer, mailfrom, rcpttos, data)) - if data == 'return status': + if data == self.return_status: return '250 Okish' + if 'mail_options' in kw and 'SMTPUTF8' in kw['mail_options']: + return '250 SMTPUTF8 message okish' class DummyDispatcherBroken(Exception): @@ -31,9 +38,10 @@ class SMTPDServerTest(unittest.TestCase): smtpd.socket = asyncore.socket = mock_socket def test_process_message_unimplemented(self): - server = smtpd.SMTPServer('a', 'b') + server = smtpd.SMTPServer((support.HOST, 0), ('b', 0), + decode_data=True) conn, addr = server.accept() - channel = smtpd.SMTPChannel(server, conn, addr) + channel = smtpd.SMTPChannel(server, conn, addr, decode_data=True) def write_line(line): channel.socket.queue_recv(line) @@ -45,19 +53,251 @@ class SMTPDServerTest(unittest.TestCase): write_line(b'DATA') self.assertRaises(NotImplementedError, write_line, b'spam\r\n.\r\n') + def test_decode_data_default_warns(self): + with self.assertWarns(DeprecationWarning): + smtpd.SMTPServer((support.HOST, 0), ('b', 0)) + + def test_decode_data_and_enable_SMTPUTF8_raises(self): + self.assertRaises( + ValueError, + smtpd.SMTPServer, + (support.HOST, 0), + ('b', 0), + enable_SMTPUTF8=True, + decode_data=True) + + def tearDown(self): + asyncore.close_all() + asyncore.socket = smtpd.socket = socket + + +class DebuggingServerTest(unittest.TestCase): + + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + + def send_data(self, channel, data, enable_SMTPUTF8=False): + def write_line(line): + channel.socket.queue_recv(line) + channel.handle_read() + write_line(b'EHLO example') + if enable_SMTPUTF8: + write_line(b'MAIL From:eggs@example BODY=8BITMIME SMTPUTF8') + else: + write_line(b'MAIL From:eggs@example') + write_line(b'RCPT To:spam@example') + write_line(b'DATA') + write_line(data) + write_line(b'.') + + def test_process_message_with_decode_data_true(self): + server = smtpd.DebuggingServer((support.HOST, 0), ('b', 0), + decode_data=True) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, decode_data=True) + with support.captured_stdout() as s: + self.send_data(channel, b'From: test\n\nhello\n') + stdout = s.getvalue() + self.assertEqual(stdout, textwrap.dedent("""\ + ---------- MESSAGE FOLLOWS ---------- + From: test + X-Peer: peer-address + + hello + ------------ END MESSAGE ------------ + """)) + + def test_process_message_with_decode_data_false(self): + server = smtpd.DebuggingServer((support.HOST, 0), ('b', 0), + decode_data=False) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, decode_data=False) + with support.captured_stdout() as s: + self.send_data(channel, b'From: test\n\nh\xc3\xa9llo\xff\n') + stdout = s.getvalue() + self.assertEqual(stdout, textwrap.dedent("""\ + ---------- MESSAGE FOLLOWS ---------- + b'From: test' + b'X-Peer: peer-address' + b'' + b'h\\xc3\\xa9llo\\xff' + ------------ END MESSAGE ------------ + """)) + + def test_process_message_with_enable_SMTPUTF8_true(self): + server = smtpd.DebuggingServer((support.HOST, 0), ('b', 0), + enable_SMTPUTF8=True) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, enable_SMTPUTF8=True) + with support.captured_stdout() as s: + self.send_data(channel, b'From: test\n\nh\xc3\xa9llo\xff\n') + stdout = s.getvalue() + self.assertEqual(stdout, textwrap.dedent("""\ + ---------- MESSAGE FOLLOWS ---------- + b'From: test' + b'X-Peer: peer-address' + b'' + b'h\\xc3\\xa9llo\\xff' + ------------ END MESSAGE ------------ + """)) + + def test_process_SMTPUTF8_message_with_enable_SMTPUTF8_true(self): + server = smtpd.DebuggingServer((support.HOST, 0), ('b', 0), + enable_SMTPUTF8=True) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, enable_SMTPUTF8=True) + with support.captured_stdout() as s: + self.send_data(channel, b'From: test\n\nh\xc3\xa9llo\xff\n', + enable_SMTPUTF8=True) + stdout = s.getvalue() + self.assertEqual(stdout, textwrap.dedent("""\ + ---------- MESSAGE FOLLOWS ---------- + mail options: ['BODY=8BITMIME', 'SMTPUTF8'] + b'From: test' + b'X-Peer: peer-address' + b'' + b'h\\xc3\\xa9llo\\xff' + ------------ END MESSAGE ------------ + """)) + + def tearDown(self): + asyncore.close_all() + asyncore.socket = smtpd.socket = socket + + +class TestFamilyDetection(unittest.TestCase): + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + def tearDown(self): asyncore.close_all() asyncore.socket = smtpd.socket = socket + @unittest.skipUnless(support.IPV6_ENABLED, "IPv6 not enabled") + def test_socket_uses_IPv6(self): + server = smtpd.SMTPServer((support.HOSTv6, 0), (support.HOST, 0), + decode_data=False) + self.assertEqual(server.socket.family, socket.AF_INET6) + + def test_socket_uses_IPv4(self): + server = smtpd.SMTPServer((support.HOST, 0), (support.HOSTv6, 0), + decode_data=False) + self.assertEqual(server.socket.family, socket.AF_INET) + + +class TestRcptOptionParsing(unittest.TestCase): + error_response = (b'555 RCPT TO parameters not recognized or not ' + b'implemented\r\n') + + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.old_debugstream = smtpd.DEBUGSTREAM + self.debug = smtpd.DEBUGSTREAM = io.StringIO() + + def tearDown(self): + asyncore.close_all() + asyncore.socket = smtpd.socket = socket + smtpd.DEBUGSTREAM = self.old_debugstream + + def write_line(self, channel, line): + channel.socket.queue_recv(line) + channel.handle_read() + + def test_params_rejected(self): + server = DummyServer((support.HOST, 0), ('b', 0), decode_data=False) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, decode_data=False) + self.write_line(channel, b'EHLO example') + self.write_line(channel, b'MAIL from: <foo@example.com> size=20') + self.write_line(channel, b'RCPT to: <foo@example.com> foo=bar') + self.assertEqual(channel.socket.last, self.error_response) + + def test_nothing_accepted(self): + server = DummyServer((support.HOST, 0), ('b', 0), decode_data=False) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, decode_data=False) + self.write_line(channel, b'EHLO example') + self.write_line(channel, b'MAIL from: <foo@example.com> size=20') + self.write_line(channel, b'RCPT to: <foo@example.com>') + self.assertEqual(channel.socket.last, b'250 OK\r\n') + + +class TestMailOptionParsing(unittest.TestCase): + error_response = (b'555 MAIL FROM parameters not recognized or not ' + b'implemented\r\n') + + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.old_debugstream = smtpd.DEBUGSTREAM + self.debug = smtpd.DEBUGSTREAM = io.StringIO() + + def tearDown(self): + asyncore.close_all() + asyncore.socket = smtpd.socket = socket + smtpd.DEBUGSTREAM = self.old_debugstream + + def write_line(self, channel, line): + channel.socket.queue_recv(line) + channel.handle_read() + + def test_with_decode_data_true(self): + server = DummyServer((support.HOST, 0), ('b', 0), decode_data=True) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, decode_data=True) + self.write_line(channel, b'EHLO example') + for line in [ + b'MAIL from: <foo@example.com> size=20 SMTPUTF8', + b'MAIL from: <foo@example.com> size=20 SMTPUTF8 BODY=8BITMIME', + b'MAIL from: <foo@example.com> size=20 BODY=UNKNOWN', + b'MAIL from: <foo@example.com> size=20 body=8bitmime', + ]: + self.write_line(channel, line) + self.assertEqual(channel.socket.last, self.error_response) + self.write_line(channel, b'MAIL from: <foo@example.com> size=20') + self.assertEqual(channel.socket.last, b'250 OK\r\n') + + def test_with_decode_data_false(self): + server = DummyServer((support.HOST, 0), ('b', 0), decode_data=False) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, decode_data=False) + self.write_line(channel, b'EHLO example') + for line in [ + b'MAIL from: <foo@example.com> size=20 SMTPUTF8', + b'MAIL from: <foo@example.com> size=20 SMTPUTF8 BODY=8BITMIME', + ]: + self.write_line(channel, line) + self.assertEqual(channel.socket.last, self.error_response) + self.write_line( + channel, + b'MAIL from: <foo@example.com> size=20 SMTPUTF8 BODY=UNKNOWN') + self.assertEqual( + channel.socket.last, + b'501 Error: BODY can only be one of 7BIT, 8BITMIME\r\n') + self.write_line( + channel, b'MAIL from: <foo@example.com> size=20 body=8bitmime') + self.assertEqual(channel.socket.last, b'250 OK\r\n') + + def test_with_enable_smtputf8_true(self): + server = DummyServer((support.HOST, 0), ('b', 0), enable_SMTPUTF8=True) + conn, addr = server.accept() + channel = smtpd.SMTPChannel(server, conn, addr, enable_SMTPUTF8=True) + self.write_line(channel, b'EHLO example') + self.write_line( + channel, + b'MAIL from: <foo@example.com> size=20 body=8bitmime smtputf8') + self.assertEqual(channel.socket.last, b'250 OK\r\n') + class SMTPDChannelTest(unittest.TestCase): def setUp(self): smtpd.socket = asyncore.socket = mock_socket self.old_debugstream = smtpd.DEBUGSTREAM self.debug = smtpd.DEBUGSTREAM = io.StringIO() - self.server = DummyServer('a', 'b') + self.server = DummyServer((support.HOST, 0), ('b', 0), + decode_data=True) conn, addr = self.server.accept() - self.channel = smtpd.SMTPChannel(self.server, conn, addr) + self.channel = smtpd.SMTPChannel(self.server, conn, addr, + decode_data=True) def tearDown(self): asyncore.close_all() @@ -69,7 +309,9 @@ class SMTPDChannelTest(unittest.TestCase): self.channel.handle_read() def test_broken_connect(self): - self.assertRaises(DummyDispatcherBroken, BrokenDummyServer, 'a', 'b') + self.assertRaises( + DummyDispatcherBroken, BrokenDummyServer, + (support.HOST, 0), ('b', 0), decode_data=True) def test_server_accept(self): self.server.handle_accept() @@ -214,6 +456,12 @@ class SMTPDChannelTest(unittest.TestCase): self.assertEqual(self.channel.socket.last, b'500 Error: line too long\r\n') + def test_MAIL_command_rejects_SMTPUTF8_by_default(self): + self.write_line(b'EHLO example') + self.write_line( + b'MAIL from: <naive@example.com> BODY=8BITMIME SMTPUTF8') + self.assertEqual(self.channel.socket.last[0:1], b'5') + def test_data_longer_than_default_data_size_limit(self): # Hack the default so we don't have to generate so much data. self.channel.data_size_limit = 1048 @@ -387,7 +635,10 @@ class SMTPDChannelTest(unittest.TestCase): self.write_line(b'data\r\nmore\r\n.') self.assertEqual(self.channel.socket.last, b'250 OK\r\n') self.assertEqual(self.server.messages, - [('peer', 'eggs@example', ['spam@example'], 'data\nmore')]) + [(('peer-address', 'peer-port'), + 'eggs@example', + ['spam@example'], + 'data\nmore')]) def test_DATA_syntax(self): self.write_line(b'HELO example') @@ -417,7 +668,10 @@ class SMTPDChannelTest(unittest.TestCase): self.write_line(b'DATA') self.write_line(b'data\r\n.') self.assertEqual(self.server.messages, - [('peer', 'eggs@example', ['spam@example','ham@example'], 'data')]) + [(('peer-address', 'peer-port'), + 'eggs@example', + ['spam@example','ham@example'], + 'data')]) def test_manual_status(self): # checks that the Channel is able to return a custom status message @@ -439,7 +693,10 @@ class SMTPDChannelTest(unittest.TestCase): self.write_line(b'DATA') self.write_line(b'data\r\n.') self.assertEqual(self.server.messages, - [('peer', 'foo@example', ['eggs@example'], 'data')]) + [(('peer-address', 'peer-port'), + 'foo@example', + ['eggs@example'], + 'data')]) def test_HELO_RSET(self): self.write_line(b'HELO example') @@ -502,6 +759,24 @@ class SMTPDChannelTest(unittest.TestCase): with support.check_warnings(('', DeprecationWarning)): self.channel._SMTPChannel__addr = 'spam' + def test_decode_data_default_warning(self): + with self.assertWarns(DeprecationWarning): + server = DummyServer((support.HOST, 0), ('b', 0)) + conn, addr = self.server.accept() + with self.assertWarns(DeprecationWarning): + smtpd.SMTPChannel(server, conn, addr) + +@unittest.skipUnless(support.IPV6_ENABLED, "IPv6 not enabled") +class SMTPDChannelIPv6Test(SMTPDChannelTest): + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.old_debugstream = smtpd.DEBUGSTREAM + self.debug = smtpd.DEBUGSTREAM = io.StringIO() + self.server = DummyServer((support.HOSTv6, 0), ('b', 0), + decode_data=True) + conn, addr = self.server.accept() + self.channel = smtpd.SMTPChannel(self.server, conn, addr, + decode_data=True) class SMTPDChannelWithDataSizeLimitTest(unittest.TestCase): @@ -509,10 +784,12 @@ class SMTPDChannelWithDataSizeLimitTest(unittest.TestCase): smtpd.socket = asyncore.socket = mock_socket self.old_debugstream = smtpd.DEBUGSTREAM self.debug = smtpd.DEBUGSTREAM = io.StringIO() - self.server = DummyServer('a', 'b') + self.server = DummyServer((support.HOST, 0), ('b', 0), + decode_data=True) conn, addr = self.server.accept() # Set DATA size limit to 32 bytes for easy testing - self.channel = smtpd.SMTPChannel(self.server, conn, addr, 32) + self.channel = smtpd.SMTPChannel(self.server, conn, addr, 32, + decode_data=True) def tearDown(self): asyncore.close_all() @@ -536,7 +813,10 @@ class SMTPDChannelWithDataSizeLimitTest(unittest.TestCase): self.write_line(b'data\r\nmore\r\n.') self.assertEqual(self.channel.socket.last, b'250 OK\r\n') self.assertEqual(self.server.messages, - [('peer', 'eggs@example', ['spam@example'], 'data\nmore')]) + [(('peer-address', 'peer-port'), + 'eggs@example', + ['spam@example'], + 'data\nmore')]) def test_data_limit_dialog_too_much_data(self): self.write_line(b'HELO example') @@ -553,5 +833,181 @@ class SMTPDChannelWithDataSizeLimitTest(unittest.TestCase): b'552 Error: Too much mail data\r\n') +class SMTPDChannelWithDecodeDataFalse(unittest.TestCase): + + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.old_debugstream = smtpd.DEBUGSTREAM + self.debug = smtpd.DEBUGSTREAM = io.StringIO() + self.server = DummyServer((support.HOST, 0), ('b', 0), + decode_data=False) + conn, addr = self.server.accept() + # Set decode_data to False + self.channel = smtpd.SMTPChannel(self.server, conn, addr, + decode_data=False) + + def tearDown(self): + asyncore.close_all() + asyncore.socket = smtpd.socket = socket + smtpd.DEBUGSTREAM = self.old_debugstream + + def write_line(self, line): + self.channel.socket.queue_recv(line) + self.channel.handle_read() + + def test_ascii_data(self): + self.write_line(b'HELO example') + self.write_line(b'MAIL From:eggs@example') + self.write_line(b'RCPT To:spam@example') + self.write_line(b'DATA') + self.write_line(b'plain ascii text') + self.write_line(b'.') + self.assertEqual(self.channel.received_data, b'plain ascii text') + + def test_utf8_data(self): + self.write_line(b'HELO example') + self.write_line(b'MAIL From:eggs@example') + self.write_line(b'RCPT To:spam@example') + self.write_line(b'DATA') + self.write_line(b'utf8 enriched text: \xc5\xbc\xc5\xba\xc4\x87') + self.write_line(b'and some plain ascii') + self.write_line(b'.') + self.assertEqual( + self.channel.received_data, + b'utf8 enriched text: \xc5\xbc\xc5\xba\xc4\x87\n' + b'and some plain ascii') + + +class SMTPDChannelWithDecodeDataTrue(unittest.TestCase): + + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.old_debugstream = smtpd.DEBUGSTREAM + self.debug = smtpd.DEBUGSTREAM = io.StringIO() + self.server = DummyServer((support.HOST, 0), ('b', 0), + decode_data=True) + conn, addr = self.server.accept() + # Set decode_data to True + self.channel = smtpd.SMTPChannel(self.server, conn, addr, + decode_data=True) + + def tearDown(self): + asyncore.close_all() + asyncore.socket = smtpd.socket = socket + smtpd.DEBUGSTREAM = self.old_debugstream + + def write_line(self, line): + self.channel.socket.queue_recv(line) + self.channel.handle_read() + + def test_ascii_data(self): + self.write_line(b'HELO example') + self.write_line(b'MAIL From:eggs@example') + self.write_line(b'RCPT To:spam@example') + self.write_line(b'DATA') + self.write_line(b'plain ascii text') + self.write_line(b'.') + self.assertEqual(self.channel.received_data, 'plain ascii text') + + def test_utf8_data(self): + self.write_line(b'HELO example') + self.write_line(b'MAIL From:eggs@example') + self.write_line(b'RCPT To:spam@example') + self.write_line(b'DATA') + self.write_line(b'utf8 enriched text: \xc5\xbc\xc5\xba\xc4\x87') + self.write_line(b'and some plain ascii') + self.write_line(b'.') + self.assertEqual( + self.channel.received_data, + 'utf8 enriched text: żźć\nand some plain ascii') + + +class SMTPDChannelTestWithEnableSMTPUTF8True(unittest.TestCase): + def setUp(self): + smtpd.socket = asyncore.socket = mock_socket + self.old_debugstream = smtpd.DEBUGSTREAM + self.debug = smtpd.DEBUGSTREAM = io.StringIO() + self.server = DummyServer((support.HOST, 0), ('b', 0), + enable_SMTPUTF8=True) + conn, addr = self.server.accept() + self.channel = smtpd.SMTPChannel(self.server, conn, addr, + enable_SMTPUTF8=True) + + def tearDown(self): + asyncore.close_all() + asyncore.socket = smtpd.socket = socket + smtpd.DEBUGSTREAM = self.old_debugstream + + def write_line(self, line): + self.channel.socket.queue_recv(line) + self.channel.handle_read() + + def test_MAIL_command_accepts_SMTPUTF8_when_announced(self): + self.write_line(b'EHLO example') + self.write_line( + 'MAIL from: <naïve@example.com> BODY=8BITMIME SMTPUTF8'.encode( + 'utf-8') + ) + self.assertEqual(self.channel.socket.last, b'250 OK\r\n') + + def test_process_smtputf8_message(self): + self.write_line(b'EHLO example') + for mail_parameters in [b'', b'BODY=8BITMIME SMTPUTF8']: + self.write_line(b'MAIL from: <a@example> ' + mail_parameters) + self.assertEqual(self.channel.socket.last[0:3], b'250') + self.write_line(b'rcpt to:<b@example.com>') + self.assertEqual(self.channel.socket.last[0:3], b'250') + self.write_line(b'data') + self.assertEqual(self.channel.socket.last[0:3], b'354') + self.write_line(b'c\r\n.') + if mail_parameters == b'': + self.assertEqual(self.channel.socket.last, b'250 OK\r\n') + else: + self.assertEqual(self.channel.socket.last, + b'250 SMTPUTF8 message okish\r\n') + + def test_utf8_data(self): + self.write_line(b'EHLO example') + self.write_line( + 'MAIL From: naïve@examplé BODY=8BITMIME SMTPUTF8'.encode('utf-8')) + self.assertEqual(self.channel.socket.last[0:3], b'250') + self.write_line('RCPT To:späm@examplé'.encode('utf-8')) + self.assertEqual(self.channel.socket.last[0:3], b'250') + self.write_line(b'DATA') + self.assertEqual(self.channel.socket.last[0:3], b'354') + self.write_line(b'utf8 enriched text: \xc5\xbc\xc5\xba\xc4\x87') + self.write_line(b'.') + self.assertEqual( + self.channel.received_data, + b'utf8 enriched text: \xc5\xbc\xc5\xba\xc4\x87') + + def test_MAIL_command_limit_extended_with_SIZE_and_SMTPUTF8(self): + self.write_line(b'ehlo example') + fill_len = (512 + 26 + 10) - len('mail from:<@example>') + self.write_line(b'MAIL from:<' + + b'a' * (fill_len + 1) + + b'@example>') + self.assertEqual(self.channel.socket.last, + b'500 Error: line too long\r\n') + self.write_line(b'MAIL from:<' + + b'a' * fill_len + + b'@example>') + self.assertEqual(self.channel.socket.last, b'250 OK\r\n') + + def test_multiple_emails_with_extended_command_length(self): + self.write_line(b'ehlo example') + fill_len = (512 + 26 + 10) - len('mail from:<@example>') + for char in [b'a', b'b', b'c']: + self.write_line(b'MAIL from:<' + char * fill_len + b'a@example>') + self.assertEqual(self.channel.socket.last[0:3], b'500') + self.write_line(b'MAIL from:<' + char * fill_len + b'@example>') + self.assertEqual(self.channel.socket.last[0:3], b'250') + self.write_line(b'rcpt to:<hans@example.com>') + self.assertEqual(self.channel.socket.last[0:3], b'250') + self.write_line(b'data') + self.assertEqual(self.channel.socket.last[0:3], b'354') + self.write_line(b'test\r\n.') + self.assertEqual(self.channel.socket.last[0:3], b'250') + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index 95a9dbe..8e362414 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -1,5 +1,7 @@ import asyncore import email.mime.text +from email.message import EmailMessage +from email.base64mime import body_encode as encode_base64 import email.utils import socket import smtpd @@ -10,6 +12,7 @@ import sys import time import select import errno +import textwrap import unittest from test import support, mock_socket @@ -30,7 +33,7 @@ if sys.platform == 'darwin': def server(evt, buf, serv): - serv.listen(5) + serv.listen() evt.set() try: conn, addr = serv.accept() @@ -123,6 +126,27 @@ class GeneralTests(unittest.TestCase): self.assertEqual(smtp.sock.gettimeout(), 30) smtp.close() + def test_debuglevel(self): + mock_socket.reply_with(b"220 Hello world") + smtp = smtplib.SMTP() + smtp.set_debuglevel(1) + with support.captured_stderr() as stderr: + smtp.connect(HOST, self.port) + smtp.close() + expected = re.compile(r"^connect:", re.MULTILINE) + self.assertRegex(stderr.getvalue(), expected) + + def test_debuglevel_2(self): + mock_socket.reply_with(b"220 Hello world") + smtp = smtplib.SMTP() + smtp.set_debuglevel(2) + with support.captured_stderr() as stderr: + smtp.connect(HOST, self.port) + smtp.close() + expected = re.compile(r"^\d{2}:\d{2}:\d{2}\.\d{6} connect: ", + re.MULTILINE) + self.assertRegex(stderr.getvalue(), expected) + # Test server thread using the specified SMTP server class def debugging_server(serv, serv_evt, client_evt): @@ -184,7 +208,8 @@ class DebuggingServerTests(unittest.TestCase): self.old_DEBUGSTREAM = smtpd.DEBUGSTREAM smtpd.DEBUGSTREAM = io.StringIO() # Pick a random unused port by passing 0 for the port number - self.serv = smtpd.DebuggingServer((HOST, 0), ('nowhere', -1)) + self.serv = smtpd.DebuggingServer((HOST, 0), ('nowhere', -1), + decode_data=True) # Keep a note of what port was assigned self.port = self.serv.socket.getsockname()[1] serv_args = (self.serv, self.serv_evt, self.client_evt) @@ -604,7 +629,8 @@ sim_auth_credentials = { 'cram-md5': ('TXIUQUBZB21LD2HLCMUUY29TIDG4OWQ0MJ' 'KWZGQ4ODNMNDA4NTGXMDRLZWMYZJDMODG1'), } -sim_auth_login_password = 'C29TZXBHC3N3B3JK' +sim_auth_login_user = 'TXIUQUBZB21LD2HLCMUUY29T' +sim_auth_plain = 'AE1YLKFAC29TZXDOZXJLLMNVBQBZB21LCGFZC3DVCMQ=' sim_lists = {'list-1':['Mr.A@somewhere.com','Mrs.C@somewhereesle.com'], 'list-2':['Ms.B@xn--fo-fka.com',], @@ -658,18 +684,16 @@ class SimSMTPChannel(smtpd.SMTPChannel): self.push('550 No access for you!') def smtp_AUTH(self, arg): - if arg.strip().lower()=='cram-md5': + mech = arg.strip().lower() + if mech=='cram-md5': self.push('334 {}'.format(sim_cram_md5_challenge)) - return - mech, auth = arg.split() - mech = mech.lower() - if mech not in sim_auth_credentials: + elif mech not in sim_auth_credentials: self.push('504 auth type unimplemented') return - if mech == 'plain' and auth==sim_auth_credentials['plain']: - self.push('235 plain auth ok') - elif mech=='login' and auth==sim_auth_credentials['login']: - self.push('334 Password:') + elif mech=='plain': + self.push('334 ') + elif mech=='login': + self.push('334 ') else: self.push('550 No access for you!') @@ -719,7 +743,8 @@ class SimSMTPServer(smtpd.SMTPServer): def handle_accepted(self, conn, addr): self._SMTPchannel = self.channel_class( - self._extra_features, self, conn, addr) + self._extra_features, self, conn, addr, + decode_data=self._decode_data) def process_message(self, peer, mailfrom, rcpttos, data): pass @@ -742,7 +767,7 @@ class SMTPSimTests(unittest.TestCase): self.serv_evt = threading.Event() self.client_evt = threading.Event() # Pick a random unused port by passing 0 for the port number - self.serv = SimSMTPServer((HOST, 0), ('nowhere', -1)) + self.serv = SimSMTPServer((HOST, 0), ('nowhere', -1), decode_data=True) # Keep a note of what port was assigned self.port = self.serv.socket.getsockname()[1] serv_args = (self.serv, self.serv_evt, self.client_evt) @@ -790,11 +815,11 @@ class SMTPSimTests(unittest.TestCase): def testVRFY(self): smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15) - for email, name in sim_users.items(): + for addr_spec, name in sim_users.items(): expected_known = (250, bytes('%s %s' % - (name, smtplib.quoteaddr(email)), + (name, smtplib.quoteaddr(addr_spec)), "ascii")) - self.assertEqual(smtp.vrfy(email), expected_known) + self.assertEqual(smtp.vrfy(addr_spec), expected_known) u = 'nobody@nowhere.com' expected_unknown = (550, ('No such user: %s' % u).encode('ascii')) @@ -816,28 +841,28 @@ class SMTPSimTests(unittest.TestCase): self.assertEqual(smtp.expn(u), expected_unknown) smtp.quit() - def testAUTH_PLAIN(self): - self.serv.add_feature("AUTH PLAIN") - smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15) - - expected_auth_ok = (235, b'plain auth ok') - self.assertEqual(smtp.login(sim_auth[0], sim_auth[1]), expected_auth_ok) - smtp.close() - - # SimSMTPChannel doesn't fully support LOGIN or CRAM-MD5 auth because they - # require a synchronous read to obtain the credentials...so instead smtpd + # SimSMTPChannel doesn't fully support AUTH because it requires a + # synchronous read to obtain the credentials...so instead smtpd # sees the credential sent by smtplib's login method as an unknown command, # which results in smtplib raising an auth error. Fortunately the error # message contains the encoded credential, so we can partially check that it # was generated correctly (partially, because the 'word' is uppercased in # the error message). + def testAUTH_PLAIN(self): + self.serv.add_feature("AUTH PLAIN") + smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15) + try: smtp.login(sim_auth[0], sim_auth[1], initial_response_ok=False) + except smtplib.SMTPAuthenticationError as err: + self.assertIn(sim_auth_plain, str(err)) + smtp.close() + def testAUTH_LOGIN(self): self.serv.add_feature("AUTH LOGIN") smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15) try: smtp.login(sim_auth[0], sim_auth[1]) except smtplib.SMTPAuthenticationError as err: - self.assertIn(sim_auth_login_password, str(err)) + self.assertIn(sim_auth_login_user, str(err)) smtp.close() def testAUTH_CRAM_MD5(self): @@ -855,7 +880,23 @@ class SMTPSimTests(unittest.TestCase): smtp = smtplib.SMTP(HOST, self.port, local_hostname='localhost', timeout=15) try: smtp.login(sim_auth[0], sim_auth[1]) except smtplib.SMTPAuthenticationError as err: - self.assertIn(sim_auth_login_password, str(err)) + self.assertIn(sim_auth_login_user, str(err)) + smtp.close() + + def test_auth_function(self): + smtp = smtplib.SMTP(HOST, self.port, + local_hostname='localhost', timeout=15) + self.serv.add_feature("AUTH CRAM-MD5") + smtp.user, smtp.password = sim_auth[0], sim_auth[1] + supported = {'CRAM-MD5': smtp.auth_cram_md5, + 'PLAIN': smtp.auth_plain, + 'LOGIN': smtp.auth_login, + } + for mechanism, method in supported.items(): + try: smtp.auth(mechanism, method, initial_response_ok=False) + except smtplib.SMTPAuthenticationError as err: + self.assertIn(sim_auth_credentials[mechanism.lower()].upper(), + str(err)) smtp.close() def test_quit_resets_greeting(self): @@ -938,13 +979,249 @@ class SMTPSimTests(unittest.TestCase): self.assertIsNone(smtp.sock) self.assertEqual(self.serv._SMTPchannel.rcpt_count, 0) + def test_smtputf8_NotSupportedError_if_no_server_support(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', timeout=3) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertTrue(smtp.does_esmtp) + self.assertFalse(smtp.has_extn('smtputf8')) + self.assertRaises( + smtplib.SMTPNotSupportedError, + smtp.sendmail, + 'John', 'Sally', '', mail_options=['BODY=8BITMIME', 'SMTPUTF8']) + self.assertRaises( + smtplib.SMTPNotSupportedError, + smtp.mail, 'John', options=['BODY=8BITMIME', 'SMTPUTF8']) + + def test_send_unicode_without_SMTPUTF8(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', timeout=3) + self.addCleanup(smtp.close) + self.assertRaises(UnicodeEncodeError, smtp.sendmail, 'Alice', 'Böb', '') + self.assertRaises(UnicodeEncodeError, smtp.mail, 'Älice') + + +class SimSMTPUTF8Server(SimSMTPServer): + + def __init__(self, *args, **kw): + # The base SMTP server turns these on automatically, but our test + # server is set up to munge the EHLO response, so we need to provide + # them as well. And yes, the call is to SMTPServer not SimSMTPServer. + self._extra_features = ['SMTPUTF8', '8BITMIME'] + smtpd.SMTPServer.__init__(self, *args, **kw) + + def handle_accepted(self, conn, addr): + self._SMTPchannel = self.channel_class( + self._extra_features, self, conn, addr, + decode_data=self._decode_data, + enable_SMTPUTF8=self.enable_SMTPUTF8, + ) + + def process_message(self, peer, mailfrom, rcpttos, data, mail_options=None, + rcpt_options=None): + self.last_peer = peer + self.last_mailfrom = mailfrom + self.last_rcpttos = rcpttos + self.last_message = data + self.last_mail_options = mail_options + self.last_rcpt_options = rcpt_options + + +@unittest.skipUnless(threading, 'Threading required for this test.') +class SMTPUTF8SimTests(unittest.TestCase): + + maxDiff = None + + def setUp(self): + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPUTF8Server((HOST, 0), ('nowhere', -1), + decode_data=False, + enable_SMTPUTF8=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + self.thread.join() + + def test_test_server_supports_extensions(self): + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', timeout=3) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertTrue(smtp.does_esmtp) + self.assertTrue(smtp.has_extn('smtputf8')) + + def test_send_unicode_with_SMTPUTF8_via_sendmail(self): + m = '¡a test message containing unicode!'.encode('utf-8') + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', timeout=3) + self.addCleanup(smtp.close) + smtp.sendmail('Jőhn', 'Sálly', m, + mail_options=['BODY=8BITMIME', 'SMTPUTF8']) + self.assertEqual(self.serv.last_mailfrom, 'Jőhn') + self.assertEqual(self.serv.last_rcpttos, ['Sálly']) + self.assertEqual(self.serv.last_message, m) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + def test_send_unicode_with_SMTPUTF8_via_low_level_API(self): + m = '¡a test message containing unicode!'.encode('utf-8') + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', timeout=3) + self.addCleanup(smtp.close) + smtp.ehlo() + self.assertEqual( + smtp.mail('Jő', options=['BODY=8BITMIME', 'SMTPUTF8']), + (250, b'OK')) + self.assertEqual(smtp.rcpt('János'), (250, b'OK')) + self.assertEqual(smtp.data(m), (250, b'OK')) + self.assertEqual(self.serv.last_mailfrom, 'Jő') + self.assertEqual(self.serv.last_rcpttos, ['János']) + self.assertEqual(self.serv.last_message, m) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + def test_send_message_uses_smtputf8_if_addrs_non_ascii(self): + msg = EmailMessage() + msg['From'] = "Páolo <főo@bar.com>" + msg['To'] = 'Dinsdale' + msg['Subject'] = 'Nudge nudge, wink, wink \u1F609' + # XXX I don't know why I need two \n's here, but this is an existing + # bug (if it is one) and not a problem with the new functionality. + msg.set_content("oh là là, know what I mean, know what I mean?\n\n") + # XXX smtpd converts received /r/n to /n, so we can't easily test that + # we are successfully sending /r/n :(. + expected = textwrap.dedent("""\ + From: Páolo <főo@bar.com> + To: Dinsdale + Subject: Nudge nudge, wink, wink \u1F609 + Content-Type: text/plain; charset="utf-8" + Content-Transfer-Encoding: 8bit + MIME-Version: 1.0 + + oh là là, know what I mean, know what I mean? + """) + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', timeout=3) + self.addCleanup(smtp.close) + self.assertEqual(smtp.send_message(msg), {}) + self.assertEqual(self.serv.last_mailfrom, 'főo@bar.com') + self.assertEqual(self.serv.last_rcpttos, ['Dinsdale']) + self.assertEqual(self.serv.last_message.decode(), expected) + self.assertIn('BODY=8BITMIME', self.serv.last_mail_options) + self.assertIn('SMTPUTF8', self.serv.last_mail_options) + self.assertEqual(self.serv.last_rcpt_options, []) + + def test_send_message_error_on_non_ascii_addrs_if_no_smtputf8(self): + msg = EmailMessage() + msg['From'] = "Páolo <főo@bar.com>" + msg['To'] = 'Dinsdale' + msg['Subject'] = 'Nudge nudge, wink, wink \u1F609' + smtp = smtplib.SMTP( + HOST, self.port, local_hostname='localhost', timeout=3) + self.addCleanup(smtp.close) + self.assertRaises(smtplib.SMTPNotSupportedError, + smtp.send_message(msg)) + + +EXPECTED_RESPONSE = encode_base64(b'\0psu\0doesnotexist', eol='') + +class SimSMTPAUTHInitialResponseChannel(SimSMTPChannel): + def smtp_AUTH(self, arg): + # RFC 4954's AUTH command allows for an optional initial-response. + # Not all AUTH methods support this; some require a challenge. AUTH + # PLAIN does those, so test that here. See issue #15014. + args = arg.split() + if args[0].lower() == 'plain': + if len(args) == 2: + # AUTH PLAIN <initial-response> with the response base 64 + # encoded. Hard code the expected response for the test. + if args[1] == EXPECTED_RESPONSE: + self.push('235 Ok') + return + self.push('571 Bad authentication') + +class SimSMTPAUTHInitialResponseServer(SimSMTPServer): + channel_class = SimSMTPAUTHInitialResponseChannel + + +@unittest.skipUnless(threading, 'Threading required for this test.') +class SMTPAUTHInitialResponseSimTests(unittest.TestCase): + def setUp(self): + self.real_getfqdn = socket.getfqdn + socket.getfqdn = mock_socket.getfqdn + self.serv_evt = threading.Event() + self.client_evt = threading.Event() + # Pick a random unused port by passing 0 for the port number + self.serv = SimSMTPAUTHInitialResponseServer( + (HOST, 0), ('nowhere', -1), decode_data=True) + # Keep a note of what port was assigned + self.port = self.serv.socket.getsockname()[1] + serv_args = (self.serv, self.serv_evt, self.client_evt) + self.thread = threading.Thread(target=debugging_server, args=serv_args) + self.thread.start() + + # wait until server thread has assigned a port number + self.serv_evt.wait() + self.serv_evt.clear() + + def tearDown(self): + socket.getfqdn = self.real_getfqdn + # indicate that the client is finished + self.client_evt.set() + # wait for the server thread to terminate + self.serv_evt.wait() + self.thread.join() + + def testAUTH_PLAIN_initial_response_login(self): + self.serv.add_feature('AUTH PLAIN') + smtp = smtplib.SMTP(HOST, self.port, + local_hostname='localhost', timeout=15) + smtp.login('psu', 'doesnotexist') + smtp.close() + + def testAUTH_PLAIN_initial_response_auth(self): + self.serv.add_feature('AUTH PLAIN') + smtp = smtplib.SMTP(HOST, self.port, + local_hostname='localhost', timeout=15) + smtp.user = 'psu' + smtp.password = 'doesnotexist' + code, response = smtp.auth('plain', smtp.auth_plain) + smtp.close() + self.assertEqual(code, 235) + @support.reap_threads def test_main(verbose=None): - support.run_unittest(GeneralTests, DebuggingServerTests, - NonConnectingTests, - BadHELOServerTests, SMTPSimTests, - TooLongLineTests) + support.run_unittest( + BadHELOServerTests, + DebuggingServerTests, + GeneralTests, + NonConnectingTests, + SMTPAUTHInitialResponseSimTests, + SMTPSimTests, + TooLongLineTests, + ) + if __name__ == '__main__': test_main() diff --git a/Lib/test/test_smtpnet.py b/Lib/test/test_smtpnet.py index 15654f2..cc9bab4 100644 --- a/Lib/test/test_smtpnet.py +++ b/Lib/test/test_smtpnet.py @@ -79,8 +79,5 @@ class SmtpSSLTest(unittest.TestCase): server.quit() -def test_main(): - support.run_unittest(SmtpTest, SmtpSSLTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_sndhdr.py b/Lib/test/test_sndhdr.py index 5e0abe0..426417c 100644 --- a/Lib/test/test_sndhdr.py +++ b/Lib/test/test_sndhdr.py @@ -1,4 +1,5 @@ import sndhdr +import pickle import unittest from test.support import findfile @@ -18,6 +19,19 @@ class TestFormats(unittest.TestCase): what = sndhdr.what(filename) self.assertNotEqual(what, None, filename) self.assertSequenceEqual(what, expected) + self.assertEqual(what.filetype, expected[0]) + self.assertEqual(what.framerate, expected[1]) + self.assertEqual(what.nchannels, expected[2]) + self.assertEqual(what.nframes, expected[3]) + self.assertEqual(what.sampwidth, expected[4]) + + def test_pickleable(self): + filename = findfile('sndhdr.aifc', subdir="sndhdrdata") + what = sndhdr.what(filename) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + dump = pickle.dumps(what, proto) + self.assertEqual(pickle.loads(dump), what) + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index d319112..17819f2 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -20,6 +20,8 @@ import signal import math import pickle import struct +import random +import string try: import multiprocessing except ImportError: @@ -76,7 +78,7 @@ class SocketTCPTest(unittest.TestCase): def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.port = support.bind_port(self.serv) - self.serv.listen(1) + self.serv.listen() def tearDown(self): self.serv.close() @@ -445,7 +447,7 @@ class SocketListeningTestMixin(SocketTestBase): def setUp(self): super().setUp() - self.serv.listen(1) + self.serv.listen() class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase, @@ -716,11 +718,11 @@ class GeneralModuleTests(unittest.TestCase): with self.assertRaises(TypeError) as cm: s.sendto('\u2620', sockname) self.assertEqual(str(cm.exception), - "'str' does not support the buffer interface") + "a bytes-like object is required, not 'str'") with self.assertRaises(TypeError) as cm: s.sendto(5j, sockname) self.assertEqual(str(cm.exception), - "'complex' does not support the buffer interface") + "a bytes-like object is required, not 'complex'") with self.assertRaises(TypeError) as cm: s.sendto(b'foo', None) self.assertIn('not NoneType',str(cm.exception)) @@ -728,11 +730,11 @@ class GeneralModuleTests(unittest.TestCase): with self.assertRaises(TypeError) as cm: s.sendto('\u2620', 0, sockname) self.assertEqual(str(cm.exception), - "'str' does not support the buffer interface") + "a bytes-like object is required, not 'str'") with self.assertRaises(TypeError) as cm: s.sendto(5j, 0, sockname) self.assertEqual(str(cm.exception), - "'complex' does not support the buffer interface") + "a bytes-like object is required, not 'complex'") with self.assertRaises(TypeError) as cm: s.sendto(b'foo', 0, None) self.assertIn('not NoneType', str(cm.exception)) @@ -1072,6 +1074,7 @@ class GeneralModuleTests(unittest.TestCase): assertInvalid(f, b'\x00' * 3) assertInvalid(f, b'\x00' * 5) assertInvalid(f, b'\x00' * 16) + self.assertEqual('170.85.170.85', f(bytearray(b'\xaa\x55\xaa\x55'))) self.assertEqual('1.0.1.0', g(b'\x01\x00\x01\x00')) self.assertEqual('170.85.170.85', g(b'\xaa\x55\xaa\x55')) @@ -1079,6 +1082,7 @@ class GeneralModuleTests(unittest.TestCase): assertInvalid(g, b'\x00' * 3) assertInvalid(g, b'\x00' * 5) assertInvalid(g, b'\x00' * 16) + self.assertEqual('170.85.170.85', g(bytearray(b'\xaa\x55\xaa\x55'))) @unittest.skipUnless(hasattr(socket, 'inet_ntop'), 'test needs socket.inet_ntop()') @@ -1108,6 +1112,7 @@ class GeneralModuleTests(unittest.TestCase): 'aef:b01:506:1001:ffff:9997:55:170', f(b'\x0a\xef\x0b\x01\x05\x06\x10\x01\xff\xff\x99\x97\x00\x55\x01\x70') ) + self.assertEqual('::1', f(bytearray(b'\x00' * 15 + b'\x01'))) assertInvalid(b'\x12' * 15) assertInvalid(b'\x12' * 17) @@ -1382,10 +1387,13 @@ class GeneralModuleTests(unittest.TestCase): def test_listen_backlog(self): for backlog in 0, -1: - srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: + srv.bind((HOST, 0)) + srv.listen(backlog) + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv: srv.bind((HOST, 0)) - srv.listen(backlog) - srv.close() + srv.listen() @support.cpython_only def test_listen_backlog_overflow(self): @@ -1491,6 +1499,7 @@ class BasicCANTest(unittest.TestCase): s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, can_filter) self.assertEqual(can_filter, s.getsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, 8)) + s.setsockopt(socket.SOL_CAN_RAW, socket.CAN_RAW_FILTER, bytearray(can_filter)) @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') @@ -3593,7 +3602,7 @@ class InterruptedTimeoutBase(unittest.TestCase): def setUp(self): super().setUp() orig_alrm_handler = signal.signal(signal.SIGALRM, - lambda signum, frame: None) + lambda signum, frame: 1 / 0) self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) self.addCleanup(self.setAlarm, 0) @@ -3630,13 +3639,11 @@ class InterruptedRecvTimeoutTest(InterruptedTimeoutBase, UDPTestBase): self.serv.settimeout(self.timeout) def checkInterruptedRecv(self, func, *args, **kwargs): - # Check that func(*args, **kwargs) raises OSError with an + # Check that func(*args, **kwargs) raises # errno of EINTR when interrupted by a signal. self.setAlarm(self.alarm_time) - with self.assertRaises(OSError) as cm: + with self.assertRaises(ZeroDivisionError) as cm: func(*args, **kwargs) - self.assertNotIsInstance(cm.exception, socket.timeout) - self.assertEqual(cm.exception.errno, errno.EINTR) def testInterruptedRecvTimeout(self): self.checkInterruptedRecv(self.serv.recv, 1024) @@ -3692,12 +3699,10 @@ class InterruptedSendTimeoutTest(InterruptedTimeoutBase, # Check that func(*args, **kwargs), run in a loop, raises # OSError with an errno of EINTR when interrupted by a # signal. - with self.assertRaises(OSError) as cm: + with self.assertRaises(ZeroDivisionError) as cm: while True: self.setAlarm(self.alarm_time) func(*args, **kwargs) - self.assertNotIsInstance(cm.exception, socket.timeout) - self.assertEqual(cm.exception.errno, errno.EINTR) # Issue #12958: The following tests have problems on OS X prior to 10.7 @support.requires_mac_ver(10, 7) @@ -3739,8 +3744,6 @@ class TCPCloserTest(ThreadedTCPSocketTest): self.cli.connect((HOST, self.port)) time.sleep(1.0) -@unittest.skipUnless(hasattr(socket, 'socketpair'), - 'test needs socket.socketpair()') @unittest.skipUnless(thread, 'Threading required for this test.') class BasicSocketPairTest(SocketPairTest): @@ -3821,7 +3824,7 @@ class NonBlockingTCPTests(ThreadedTCPSocketTest): self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) self.port = support.bind_port(self.serv) - self.serv.listen(1) + self.serv.listen() # actual testing start = time.time() try: @@ -4067,117 +4070,6 @@ class FileObjectClassTestCase(SocketConnectedTest): pass -class FileObjectInterruptedTestCase(unittest.TestCase): - """Test that the file object correctly handles EINTR internally.""" - - class MockSocket(object): - def __init__(self, recv_funcs=()): - # A generator that returns callables that we'll call for each - # call to recv(). - self._recv_step = iter(recv_funcs) - - def recv_into(self, buffer): - data = next(self._recv_step)() - assert len(buffer) >= len(data) - buffer[:len(data)] = data - return len(data) - - def _decref_socketios(self): - pass - - def _textiowrap_for_test(self, buffering=-1): - raw = socket.SocketIO(self, "r") - if buffering < 0: - buffering = io.DEFAULT_BUFFER_SIZE - if buffering == 0: - return raw - buffer = io.BufferedReader(raw, buffering) - text = io.TextIOWrapper(buffer, None, None) - text.mode = "rb" - return text - - @staticmethod - def _raise_eintr(): - raise OSError(errno.EINTR, "interrupted") - - def _textiowrap_mock_socket(self, mock, buffering=-1): - raw = socket.SocketIO(mock, "r") - if buffering < 0: - buffering = io.DEFAULT_BUFFER_SIZE - if buffering == 0: - return raw - buffer = io.BufferedReader(raw, buffering) - text = io.TextIOWrapper(buffer, None, None) - text.mode = "rb" - return text - - def _test_readline(self, size=-1, buffering=-1): - mock_sock = self.MockSocket(recv_funcs=[ - lambda : b"This is the first line\nAnd the sec", - self._raise_eintr, - lambda : b"ond line is here\n", - lambda : b"", - lambda : b"", # XXX(gps): io library does an extra EOF read - ]) - fo = mock_sock._textiowrap_for_test(buffering=buffering) - self.assertEqual(fo.readline(size), "This is the first line\n") - self.assertEqual(fo.readline(size), "And the second line is here\n") - - def _test_read(self, size=-1, buffering=-1): - mock_sock = self.MockSocket(recv_funcs=[ - lambda : b"This is the first line\nAnd the sec", - self._raise_eintr, - lambda : b"ond line is here\n", - lambda : b"", - lambda : b"", # XXX(gps): io library does an extra EOF read - ]) - expecting = (b"This is the first line\n" - b"And the second line is here\n") - fo = mock_sock._textiowrap_for_test(buffering=buffering) - if buffering == 0: - data = b'' - else: - data = '' - expecting = expecting.decode('utf-8') - while len(data) != len(expecting): - part = fo.read(size) - if not part: - break - data += part - self.assertEqual(data, expecting) - - def test_default(self): - self._test_readline() - self._test_readline(size=100) - self._test_read() - self._test_read(size=100) - - def test_with_1k_buffer(self): - self._test_readline(buffering=1024) - self._test_readline(size=100, buffering=1024) - self._test_read(buffering=1024) - self._test_read(size=100, buffering=1024) - - def _test_readline_no_buffer(self, size=-1): - mock_sock = self.MockSocket(recv_funcs=[ - lambda : b"a", - lambda : b"\n", - lambda : b"B", - self._raise_eintr, - lambda : b"b", - lambda : b"", - ]) - fo = mock_sock._textiowrap_for_test(buffering=0) - self.assertEqual(fo.readline(size), b"a\n") - self.assertEqual(fo.readline(size), b"Bb") - - def test_no_buffer(self): - self._test_readline_no_buffer() - self._test_readline_no_buffer(size=4) - self._test_read(buffering=0) - self._test_read(size=100, buffering=0) - - class UnbufferedFileObjectClassTestCase(FileObjectClassTestCase): """Repeat the tests from FileObjectClassTestCase with bufsize==0. @@ -4596,7 +4488,7 @@ class TestLinuxAbstractNamespace(unittest.TestCase): address = b"\x00python-test-hello\x00\xff" with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s1: s1.bind(address) - s1.listen(1) + s1.listen() with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s2: s2.connect(s1.getsockname()) with s1.accept()[0] as s3: @@ -4623,6 +4515,12 @@ class TestLinuxAbstractNamespace(unittest.TestCase): finally: s.close() + def testBytearrayName(self): + # Check that an abstract name can be passed as a bytearray. + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.bind(bytearray(b"\x00python\x00test\x00")) + self.assertEqual(s.getsockname(), b"\x00python\x00test\x00") + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'test needs socket.AF_UNIX') class TestUnixDomain(unittest.TestCase): @@ -4828,7 +4726,7 @@ class TIPCThreadableTest(unittest.TestCase, ThreadableTest): srvaddr = (socket.TIPC_ADDR_NAMESEQ, TIPC_STYPE, TIPC_LOWER, TIPC_UPPER) self.srv.bind(srvaddr) - self.srv.listen(5) + self.srv.listen() self.serverExplicitReady() self.conn, self.connaddr = self.srv.accept() self.addCleanup(self.conn.close) @@ -5117,6 +5015,275 @@ class TestSocketSharing(SocketTCPTest): source.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileUsingSendTest(ThreadedTCPSocketTest): + """ + Test the send() implementation of socket.sendfile(). + """ + + FILESIZE = (10 * 1024 * 1024) # 10MB + BUFSIZE = 8192 + FILEDATA = b"" + TIMEOUT = 2 + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step + total -= step + if total: + yield total + + chunk = b"".join([random.choice(string.ascii_letters).encode() + for i in range(cls.BUFSIZE)]) + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, cls.BUFSIZE): + f.write(chunk) + with open(support.TESTFN, 'rb') as f: + cls.FILEDATA = f.read() + assert len(cls.FILEDATA) == cls.FILESIZE + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def accept_conn(self): + self.serv.settimeout(self.TIMEOUT) + conn, addr = self.serv.accept() + conn.settimeout(self.TIMEOUT) + self.addCleanup(conn.close) + return conn + + def recv_data(self, conn): + received = [] + while True: + chunk = conn.recv(self.BUFSIZE) + if not chunk: + break + received.append(chunk) + return b''.join(received) + + def meth_from_sock(self, sock): + # Depending on the mixin class being run return either send() + # or sendfile() method implementation. + return getattr(sock, "_sendfile_use_send") + + # regular file + + def _testRegularFile(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + + def testRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(self.FILEDATA) + with socket.create_connection(address) as sock, file as file: + sent = sock.sendfile(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + self.assertRaises(socket._GiveupOnSendfile, + sock._sendfile_use_sendfile, file) + + def testNonRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # empty file + + def _testEmptyFileSend(self): + address = self.serv.getsockname() + filename = support.TESTFN + "2" + with open(filename, 'wb'): + self.addCleanup(support.unlink, filename) + file = open(filename, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, 0) + self.assertEqual(file.tell(), 0) + + def testEmptyFileSend(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(data, b"") + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file, offset=5000) + self.assertEqual(sent, self.FILESIZE - 5000) + self.assertEqual(file.tell(), self.FILESIZE) + + def testOffset(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE - 5000) + self.assertEqual(data, self.FILEDATA[5000:]) + + # count + + def _testCount(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 5000007 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCount(self): + count = 5000007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count small + + def _testCountSmall(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 1 + meth = self.meth_from_sock(sock) + sent = meth(file, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count) + + def testCountSmall(self): + count = 1 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[:count]) + + # count + offset + + def _testCountWithOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + count = 100007 + meth = self.meth_from_sock(sock) + sent = meth(file, offset=2007, count=count) + self.assertEqual(sent, count) + self.assertEqual(file.tell(), count + 2007) + + def testCountWithOffset(self): + count = 100007 + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), count) + self.assertEqual(data, self.FILEDATA[2007:count+2007]) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + meth = self.meth_from_sock(sock) + self.assertRaises(ValueError, meth, file) + self.assertRaises(ValueError, sock.sendfile, file) + + def testNonBlocking(self): + conn = self.accept_conn() + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + meth = self.meth_from_sock(sock) + sent = meth(file) + self.assertEqual(sent, self.FILESIZE) + + def testWithTimeout(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # timeout (triggered) + + def _testWithTimeoutTriggeredSend(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + meth = self.meth_from_sock(sock) + self.assertRaises(socket.timeout, meth, file) + + def testWithTimeoutTriggeredSend(self): + conn = self.accept_conn() + conn.recv(88192) + + # errors + + def _test_errors(self): + pass + + def test_errors(self): + with open(support.TESTFN, 'rb') as file: + with socket.socket(type=socket.SOCK_DGRAM) as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "SOCK_STREAM", meth, file) + with open(support.TESTFN, 'rt') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex( + ValueError, "binary mode", meth, file) + with open(support.TESTFN, 'rb') as file: + with socket.socket() as s: + meth = self.meth_from_sock(s) + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count='2') + self.assertRaisesRegex(TypeError, "positive integer", + meth, file, count=0.1) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=0) + self.assertRaisesRegex(ValueError, "positive integer", + meth, file, count=-1) + + +@unittest.skipUnless(thread, 'Threading required for this test.') +@unittest.skipUnless(hasattr(os, "sendfile"), + 'os.sendfile() required for this test.') +class SendfileUsingSendfileTest(SendfileUsingSendTest): + """ + Test the sendfile() implementation of socket.sendfile(). + """ + def meth_from_sock(self, sock): + return getattr(sock, "_sendfile_use_sendfile") + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -5124,7 +5291,6 @@ def test_main(): tests.extend([ NonBlockingTCPTests, FileObjectClassTestCase, - FileObjectInterruptedTestCase, UnbufferedFileObjectClassTestCase, LineBufferedFileObjectClassTestCase, SmallBufferedFileObjectClassTestCase, @@ -5169,6 +5335,8 @@ def test_main(): InterruptedRecvTimeoutTest, InterruptedSendTimeoutTest, TestSocketSharing, + SendfileUsingSendTest, + SendfileUsingSendfileTest, ]) thread_info = support.threading_setup() diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py index 325d485..1ea66a6 100644 --- a/Lib/test/test_socketserver.py +++ b/Lib/test/test_socketserver.py @@ -222,38 +222,6 @@ class SocketServerTest(unittest.TestCase): socketserver.DatagramRequestHandler, self.dgram_examine) - @contextlib.contextmanager - def mocked_select_module(self): - """Mocks the select.select() call to raise EINTR for first call""" - old_select = select.select - - class MockSelect: - def __init__(self): - self.called = 0 - - def __call__(self, *args): - self.called += 1 - if self.called == 1: - # raise the exception on first call - raise OSError(errno.EINTR, os.strerror(errno.EINTR)) - else: - # Return real select value for consecutive calls - return old_select(*args) - - select.select = MockSelect() - try: - yield select.select - finally: - select.select = old_select - - def test_InterruptServerSelectCall(self): - with self.mocked_select_module() as mock_select: - pid = self.run_server(socketserver.TCPServer, - socketserver.StreamRequestHandler, - self.stream_examine) - # Make sure select was called again: - self.assertGreater(mock_select.called, 1) - # Alas, on Linux (at least) recvfrom() doesn't return a meaningful # client address so this cannot work: diff --git a/Lib/test/test_sort.py b/Lib/test/test_sort.py index 8f6af64..a5d0ebf 100644 --- a/Lib/test/test_sort.py +++ b/Lib/test/test_sort.py @@ -262,24 +262,5 @@ class TestDecorateSortUndecorate(unittest.TestCase): #============================================================================== -def test_main(verbose=None): - test_classes = ( - TestBase, - TestDecorateSortUndecorate, - TestBugs, - ) - - support.run_unittest(*test_classes) - - # verify reference counting - if verbose and hasattr(sys, "gettotalrefcount"): - import gc - counts = [None] * 5 - for i in range(len(counts)): - support.run_unittest(*test_classes) - gc.collect() - counts[i] = sys.gettotalrefcount() - print(counts) - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index cdf3aed..f314ff4 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -85,6 +85,12 @@ def have_verify_flags(): # 0.9.8 or higher return ssl.OPENSSL_VERSION_INFO >= (0, 9, 8, 0, 15) +def utc_offset(): #NOTE: ignore issues like #1647654 + # local time = utc time + utc offset + if time.daylight and time.localtime().tm_isdst > 0: + return -time.altzone # seconds + return -time.timezone + def asn1time(cert_time): # Some versions of OpenSSL ignore seconds, see #18207 # 0.9.8.i @@ -133,6 +139,14 @@ class BasicSocketTests(unittest.TestCase): self.assertIn(ssl.HAS_SNI, {True, False}) self.assertIn(ssl.HAS_ECDH, {True, False}) + def test_str_for_enums(self): + # Make sure that the PROTOCOL_* constants have enum-like string + # reprs. + proto = ssl.PROTOCOL_SSLv23 + self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_SSLv23') + ctx = ssl.SSLContext(proto) + self.assertIs(ctx.protocol, proto) + def test_random(self): v = ssl.RAND_status() if support.verbose: @@ -157,6 +171,8 @@ class BasicSocketTests(unittest.TestCase): self.assertRaises(TypeError, ssl.RAND_egd, 1) self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1) ssl.RAND_add("this is a random string", 75.0) + ssl.RAND_add(b"this is a random bytes object", 75.0) + ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0) @unittest.skipUnless(os.name == 'posix', 'requires posix') def test_random_fork(self): @@ -297,10 +313,10 @@ class BasicSocketTests(unittest.TestCase): # Version string as returned by {Open,Libre}SSL, the format might change if "LibreSSL" in s: self.assertTrue(s.startswith("LibreSSL {:d}.{:d}".format(major, minor)), - (s, t)) + (s, t, hex(n))) else: self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)), - (s, t)) + (s, t, hex(n))) @support.cpython_only def test_refcycle(self): @@ -368,6 +384,8 @@ class BasicSocketTests(unittest.TestCase): self.assertRaises(ssl.CertificateError, ssl.match_hostname, cert, hostname) + # -- Hostname matching -- + cert = {'subject': ((('commonName', 'example.com'),),)} ok(cert, 'example.com') ok(cert, 'ExAmple.cOm') @@ -453,6 +471,28 @@ class BasicSocketTests(unittest.TestCase): # Only commonName is considered fail(cert, 'California') + # -- IPv4 matching -- + cert = {'subject': ((('commonName', 'example.com'),),), + 'subjectAltName': (('DNS', 'example.com'), + ('IP Address', '10.11.12.13'), + ('IP Address', '14.15.16.17'))} + ok(cert, '10.11.12.13') + ok(cert, '14.15.16.17') + fail(cert, '14.15.16.18') + fail(cert, 'example.net') + + # -- IPv6 matching -- + cert = {'subject': ((('commonName', 'example.com'),),), + 'subjectAltName': (('DNS', 'example.com'), + ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'), + ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))} + ok(cert, '2001::cafe') + ok(cert, '2003::baba') + fail(cert, '2003::bebe') + fail(cert, 'example.net') + + # -- Miscellaneous -- + # Neither commonName nor subjectAltName cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT', 'subject': ((('countryName', 'US'),), @@ -504,9 +544,14 @@ class BasicSocketTests(unittest.TestCase): def test_unknown_channel_binding(self): # should raise ValueError for unknown type s = socket.socket(socket.AF_INET) - with ssl.wrap_socket(s) as ss: + s.bind(('127.0.0.1', 0)) + s.listen() + c = socket.socket(socket.AF_INET) + c.connect(s.getsockname()) + with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss: with self.assertRaises(ValueError): ss.get_channel_binding("unknown-type") + s.close() @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, "'tls-unique' channel binding not available") @@ -647,6 +692,71 @@ class BasicSocketTests(unittest.TestCase): ctx.wrap_socket(s) self.assertEqual(str(cx.exception), "only stream sockets are supported") + def cert_time_ok(self, timestring, timestamp): + self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp) + + def cert_time_fail(self, timestring): + with self.assertRaises(ValueError): + ssl.cert_time_to_seconds(timestring) + + @unittest.skipUnless(utc_offset(), + 'local time needs to be different from UTC') + def test_cert_time_to_seconds_timezone(self): + # Issue #19940: ssl.cert_time_to_seconds() returns wrong + # results if local timezone is not UTC + self.cert_time_ok("May 9 00:00:00 2007 GMT", 1178668800.0) + self.cert_time_ok("Jan 5 09:34:43 2018 GMT", 1515144883.0) + + def test_cert_time_to_seconds(self): + timestring = "Jan 5 09:34:43 2018 GMT" + ts = 1515144883.0 + self.cert_time_ok(timestring, ts) + # accept keyword parameter, assert its name + self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts) + # accept both %e and %d (space or zero generated by strftime) + self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts) + # case-insensitive + self.cert_time_ok("JaN 5 09:34:43 2018 GmT", ts) + self.cert_time_fail("Jan 5 09:34 2018 GMT") # no seconds + self.cert_time_fail("Jan 5 09:34:43 2018") # no GMT + self.cert_time_fail("Jan 5 09:34:43 2018 UTC") # not GMT timezone + self.cert_time_fail("Jan 35 09:34:43 2018 GMT") # invalid day + self.cert_time_fail("Jon 5 09:34:43 2018 GMT") # invalid month + self.cert_time_fail("Jan 5 24:00:00 2018 GMT") # invalid hour + self.cert_time_fail("Jan 5 09:60:43 2018 GMT") # invalid minute + + newyear_ts = 1230768000.0 + # leap seconds + self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts) + # same timestamp + self.cert_time_ok("Jan 1 00:00:00 2009 GMT", newyear_ts) + + self.cert_time_ok("Jan 5 09:34:59 2018 GMT", 1515144899) + # allow 60th second (even if it is not a leap second) + self.cert_time_ok("Jan 5 09:34:60 2018 GMT", 1515144900) + # allow 2nd leap second for compatibility with time.strptime() + self.cert_time_ok("Jan 5 09:34:61 2018 GMT", 1515144901) + self.cert_time_fail("Jan 5 09:34:62 2018 GMT") # invalid seconds + + # no special treatement for the special value: + # 99991231235959Z (rfc 5280) + self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0) + + @support.run_with_locale('LC_ALL', '') + def test_cert_time_to_seconds_locale(self): + # `cert_time_to_seconds()` should be locale independent + + def local_february_name(): + return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0)) + + if local_february_name().lower() == 'feb': + self.skipTest("locale-specific month name needs to be " + "different from C locale") + + # locale-independent + self.cert_time_ok("Feb 9 00:00:00 2007 GMT", 1170979200.0) + self.cert_time_fail(local_february_name() + " 9 00:00:00 2007 GMT") + class ContextTests(unittest.TestCase): @@ -1156,7 +1266,7 @@ class SSLErrorTests(unittest.TestCase): ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) with socket.socket() as s: s.bind(("127.0.0.1", 0)) - s.listen(5) + s.listen() c = socket.socket() c.connect(s.getsockname()) c.setblocking(False) @@ -1169,6 +1279,69 @@ class SSLErrorTests(unittest.TestCase): self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ) +class MemoryBIOTests(unittest.TestCase): + + def test_read_write(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + self.assertEqual(bio.read(), b'') + bio.write(b'foo') + bio.write(b'bar') + self.assertEqual(bio.read(), b'foobar') + self.assertEqual(bio.read(), b'') + bio.write(b'baz') + self.assertEqual(bio.read(2), b'ba') + self.assertEqual(bio.read(1), b'z') + self.assertEqual(bio.read(1), b'') + + def test_eof(self): + bio = ssl.MemoryBIO() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertFalse(bio.eof) + bio.write(b'foo') + self.assertFalse(bio.eof) + bio.write_eof() + self.assertFalse(bio.eof) + self.assertEqual(bio.read(2), b'fo') + self.assertFalse(bio.eof) + self.assertEqual(bio.read(1), b'o') + self.assertTrue(bio.eof) + self.assertEqual(bio.read(), b'') + self.assertTrue(bio.eof) + + def test_pending(self): + bio = ssl.MemoryBIO() + self.assertEqual(bio.pending, 0) + bio.write(b'foo') + self.assertEqual(bio.pending, 3) + for i in range(3): + bio.read(1) + self.assertEqual(bio.pending, 3-i-1) + for i in range(3): + bio.write(b'x') + self.assertEqual(bio.pending, i+1) + bio.read() + self.assertEqual(bio.pending, 0) + + def test_buffer_types(self): + bio = ssl.MemoryBIO() + bio.write(b'foo') + self.assertEqual(bio.read(), b'foo') + bio.write(bytearray(b'bar')) + self.assertEqual(bio.read(), b'bar') + bio.write(memoryview(b'baz')) + self.assertEqual(bio.read(), b'baz') + + def test_error_types(self): + bio = ssl.MemoryBIO() + self.assertRaises(TypeError, bio.write, 'foo') + self.assertRaises(TypeError, bio.write, None) + self.assertRaises(TypeError, bio.write, True) + self.assertRaises(TypeError, bio.write, 1) + + class NetworkedTests(unittest.TestCase): def test_connect(self): @@ -1396,14 +1569,12 @@ class NetworkedTests(unittest.TestCase): def test_get_server_certificate(self): def _test_get_server_certificate(host, port, cert=None): with support.transient_internet(host): - pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23) + pem = ssl.get_server_certificate((host, port)) if not pem: self.fail("No server certificate on %s:%s!" % (host, port)) try: pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23, ca_certs=CERTFILE) except ssl.SSLError as x: #should fail @@ -1413,7 +1584,6 @@ class NetworkedTests(unittest.TestCase): self.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) pem = ssl.get_server_certificate((host, port), - ssl.PROTOCOL_SSLv23, ca_certs=cert) if not pem: self.fail("No server certificate on %s:%s!" % (host, port)) @@ -1499,6 +1669,93 @@ class NetworkedTests(unittest.TestCase): self.assertIs(ss.context, ctx2) self.assertIs(ss._sslobj.context, ctx2) + +class NetworkedBIOTests(unittest.TestCase): + + def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs): + # A simple IO loop. Call func(*args) depending on the error we get + # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs. + timeout = kwargs.get('timeout', 10) + count = 0 + while True: + errno = None + count += 1 + try: + ret = func(*args) + except ssl.SSLError as e: + # Note that we get a spurious -1/SSL_ERROR_SYSCALL for + # non-blocking IO. The SSL_shutdown manpage hints at this. + # It *should* be safe to just ignore SYS_ERROR_SYSCALL because + # with a Memory BIO there's no syscalls (for IO at least). + if e.errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + errno = e.errno + # Get any data from the outgoing BIO irrespective of any error, and + # send it to the socket. + buf = outgoing.read() + sock.sendall(buf) + # If there's no error, we're done. For WANT_READ, we need to get + # data from the socket and put it in the incoming BIO. + if errno is None: + break + elif errno == ssl.SSL_ERROR_WANT_READ: + buf = sock.recv(32768) + if buf: + incoming.write(buf) + else: + incoming.write_eof() + if support.verbose: + sys.stdout.write("Needed %d calls to complete %s().\n" + % (count, func.__name__)) + return ret + + def test_handshake(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_REQUIRED + ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT) + ctx.check_hostname = True + sslobj = ctx.wrap_bio(incoming, outgoing, False, 'svn.python.org') + self.assertIs(sslobj._sslobj.owner, sslobj) + self.assertIsNone(sslobj.cipher()) + self.assertIsNone(sslobj.shared_ciphers()) + self.assertRaises(ValueError, sslobj.getpeercert) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertIsNone(sslobj.get_channel_binding('tls-unique')) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + self.assertTrue(sslobj.cipher()) + self.assertIsNone(sslobj.shared_ciphers()) + self.assertTrue(sslobj.getpeercert()) + if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES: + self.assertTrue(sslobj.get_channel_binding('tls-unique')) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + self.assertRaises(ssl.SSLError, sslobj.write, b'foo') + sock.close() + + def test_read_write_data(self): + with support.transient_internet("svn.python.org"): + sock = socket.socket(socket.AF_INET) + sock.connect(("svn.python.org", 443)) + incoming = ssl.MemoryBIO() + outgoing = ssl.MemoryBIO() + ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ctx.verify_mode = ssl.CERT_NONE + sslobj = ctx.wrap_bio(incoming, outgoing, False) + self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake) + req = b'GET / HTTP/1.0\r\n\r\n' + self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req) + buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024) + self.assertEqual(buf[:5], b'HTTP/') + self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap) + sock.close() + + try: import threading except ImportError: @@ -1530,7 +1787,8 @@ else: try: self.sslconn = self.server.context.wrap_socket( self.sock, server_side=True) - self.server.selected_protocols.append(self.sslconn.selected_npn_protocol()) + self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol()) + self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol()) except (ssl.SSLError, ConnectionResetError) as e: # We treat ConnectionResetError as though it were an # SSLError - OpenSSL on Ubuntu abruptly closes the @@ -1547,6 +1805,7 @@ else: self.close() return False else: + self.server.shared_ciphers.append(self.sslconn.shared_ciphers()) if self.server.context.verify_mode == ssl.CERT_REQUIRED: cert = self.sslconn.getpeercert() if support.verbose and self.server.chatty: @@ -1637,7 +1896,8 @@ else: def __init__(self, certificate=None, ssl_version=None, certreqs=None, cacerts=None, chatty=True, connectionchatty=False, starttls_server=False, - npn_protocols=None, ciphers=None, context=None): + npn_protocols=None, alpn_protocols=None, + ciphers=None, context=None): if context: self.context = context else: @@ -1652,6 +1912,8 @@ else: self.context.load_cert_chain(certificate) if npn_protocols: self.context.set_npn_protocols(npn_protocols) + if alpn_protocols: + self.context.set_alpn_protocols(alpn_protocols) if ciphers: self.context.set_ciphers(ciphers) self.chatty = chatty @@ -1661,7 +1923,9 @@ else: self.port = support.bind_port(self.sock) self.flag = None self.active = False - self.selected_protocols = [] + self.selected_npn_protocols = [] + self.selected_alpn_protocols = [] + self.shared_ciphers = [] self.conn_errors = [] threading.Thread.__init__(self) self.daemon = True @@ -1681,7 +1945,7 @@ else: def run(self): self.sock.settimeout(0.05) - self.sock.listen(5) + self.sock.listen() self.active = True if self.flag: # signal an event @@ -1887,14 +2151,25 @@ else: 'compression': s.compression(), 'cipher': s.cipher(), 'peercert': s.getpeercert(), - 'client_npn_protocol': s.selected_npn_protocol() + 'client_alpn_protocol': s.selected_alpn_protocol(), + 'client_npn_protocol': s.selected_npn_protocol(), + 'version': s.version(), }) s.close() - stats['server_npn_protocols'] = server.selected_protocols + stats['server_alpn_protocols'] = server.selected_alpn_protocols + stats['server_npn_protocols'] = server.selected_npn_protocols + stats['server_shared_ciphers'] = server.shared_ciphers return stats def try_protocol_combo(server_protocol, client_protocol, expect_success, certsreqs=None, server_options=0, client_options=0): + """ + Try to SSL-connect using *client_protocol* to *server_protocol*. + If *expect_success* is true, assert that the connection succeeds, + if it's false, assert that the connection fails. + Also, if *expect_success* is a string, assert that it is the protocol + version actually used by the connection. + """ if certsreqs is None: certsreqs = ssl.CERT_NONE certtype = { @@ -1924,8 +2199,8 @@ else: ctx.load_cert_chain(CERTFILE) ctx.load_verify_locations(CERTFILE) try: - server_params_test(client_context, server_context, - chatty=False, connectionchatty=False) + stats = server_params_test(client_context, server_context, + chatty=False, connectionchatty=False) # Protocol mismatch can result in either an SSLError, or a # "Connection reset by peer" error. except ssl.SSLError: @@ -1940,6 +2215,10 @@ else: "Client protocol %s succeeded with server protocol %s!" % (ssl.get_protocol_name(client_protocol), ssl.get_protocol_name(server_protocol))) + elif (expect_success is not True + and expect_success != stats['version']): + raise AssertionError("version mismatch: expected %r, got %r" + % (expect_success, stats['version'])) class ThreadedTests(unittest.TestCase): @@ -2107,7 +2386,7 @@ else: # and sets Event `listener_gone` to let the main thread know # the socket is gone. def listener(): - s.listen(5) + s.listen() listener_ready.set() newsock, addr = s.accept() newsock.close() @@ -2172,19 +2451,19 @@ else: " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n" % str(x)) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3') try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1') if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) if hasattr(ssl, 'PROTOCOL_SSLv3'): - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) # Server with specific SSL options if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2204,9 +2483,9 @@ else: """Connecting to an SSLv3 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3') + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False) try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False, @@ -2214,7 +2493,7 @@ else: try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False) if no_sslv2_implies_sslv3_hello(): # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs - try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, True, + try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, 'SSLv3', client_options=ssl.OP_NO_SSLv2) @skip_if_broken_ubuntu_ssl @@ -2222,9 +2501,9 @@ else: """Connecting to a TLSv1 server with various client options""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_OPTIONAL) - try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, True, ssl.CERT_REQUIRED) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1') + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL) + try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED) if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2240,7 +2519,7 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') if hasattr(ssl, 'PROTOCOL_SSLv2'): try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False) if hasattr(ssl, 'PROTOCOL_SSLv3'): @@ -2248,7 +2527,7 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_1) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1') try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False) @@ -2261,7 +2540,7 @@ else: Testing against older TLS versions.""" if support.verbose: sys.stdout.write("\n") - try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, True, + try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2', server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2, client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,) if hasattr(ssl, 'PROTOCOL_SSLv2'): @@ -2271,7 +2550,7 @@ else: try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False, client_options=ssl.OP_NO_TLSv1_2) - try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, True) + try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2') try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False) try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False) try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False) @@ -2507,6 +2786,36 @@ else: s.write(b"over\n") s.close() + def test_nonblocking_send(self): + server = ThreadedEchoServer(CERTFILE, + certreqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1, + cacerts=CERTFILE, + chatty=True, + connectionchatty=False) + with server: + s = ssl.wrap_socket(socket.socket(), + server_side=False, + certfile=CERTFILE, + ca_certs=CERTFILE, + cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_TLSv1) + s.connect((HOST, server.port)) + s.setblocking(False) + + # If we keep sending data, at some point the buffers + # will be full and the call will block + buf = bytearray(8192) + def fill_buffer(): + while True: + s.send(buf) + self.assertRaises((ssl.SSLWantWriteError, + ssl.SSLWantReadError), fill_buffer) + + # Now read all the output and discard it + s.setblocking(True) + s.close() + def test_handshake_timeout(self): # Issue #5103: SSL handshake must respect the socket timeout server = socket.socket(socket.AF_INET) @@ -2516,7 +2825,7 @@ else: finish = False def serve(): - server.listen(5) + server.listen() started.set() conns = [] while not finish: @@ -2573,7 +2882,7 @@ else: peer = None def serve(): nonlocal remote, peer - server.listen(5) + server.listen() # Block on the accept and wait on the connection to close. evt.set() remote, peer = server.accept() @@ -2623,6 +2932,21 @@ else: s.connect((HOST, server.port)) self.assertIn("no shared cipher", str(server.conn_errors[0])) + def test_version_basic(self): + """ + Basic tests for SSLSocket.version(). + More tests are done in the test_protocol_*() methods. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + with ThreadedEchoServer(CERTFILE, + ssl_version=ssl.PROTOCOL_TLSv1, + chatty=False) as server: + with context.wrap_socket(socket.socket()) as s: + self.assertIs(s.version(), None) + s.connect((HOST, server.port)) + self.assertEqual(s.version(), "TLSv1") + self.assertIs(s.version(), None) + @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL") def test_default_ecdh_curve(self): # Issue #21015: elliptic curve-based Diffie Hellman key exchange @@ -2732,6 +3056,55 @@ else: if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: self.fail("Non-DH cipher: " + cipher[0]) + def test_selected_alpn_protocol(self): + # selected_alpn_protocol() is None unless ALPN is used. + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.load_cert_chain(CERTFILE) + stats = server_params_test(context, context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required") + def test_selected_alpn_protocol_if_server_uses_alpn(self): + # selected_alpn_protocol() is None unless ALPN is used by the client. + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_verify_locations(CERTFILE) + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_alpn_protocols(['foo', 'bar']) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + self.assertIs(stats['client_alpn_protocol'], None) + + @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test") + def test_alpn_protocols(self): + server_protocols = ['foo', 'bar', 'milkshake'] + protocol_tests = [ + (['foo', 'bar'], 'foo'), + (['bar', 'foo'], 'foo'), + (['milkshake'], 'milkshake'), + (['http/3.0', 'http/4.0'], None) + ] + for client_protocols, expected in protocol_tests: + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(CERTFILE) + server_context.set_alpn_protocols(server_protocols) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.load_cert_chain(CERTFILE) + client_context.set_alpn_protocols(client_protocols) + stats = server_params_test(client_context, server_context, + chatty=True, connectionchatty=True) + + msg = "failed trying %s (s) and %s (c).\n" \ + "was expecting %s, but got %%s from the %%s" \ + % (str(server_protocols), str(client_protocols), + str(expected)) + client_result = stats['client_alpn_protocol'] + self.assertEqual(client_result, expected, msg % (client_result, "client")) + server_result = stats['server_alpn_protocols'][-1] \ + if len(stats['server_alpn_protocols']) else 'nothing' + self.assertEqual(server_result, expected, msg % (server_result, "server")) + def test_selected_npn_protocol(self): # selected_npn_protocol() is None unless NPN is used context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) @@ -2872,6 +3245,20 @@ else: self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR') self.assertIn("TypeError", stderr.getvalue()) + def test_shared_ciphers(self): + server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + server_context.load_cert_chain(SIGNED_CERTFILE) + client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + client_context.verify_mode = ssl.CERT_REQUIRED + client_context.load_verify_locations(SIGNING_CA) + client_context.set_ciphers("RC4") + server_context.set_ciphers("AES:RC4") + stats = server_params_test(client_context, server_context) + ciphers = stats['server_shared_ciphers'][0] + self.assertGreater(len(ciphers), 0) + for name, tls_version, bits in ciphers: + self.assertIn("RC4", name.split("-")) + def test_read_write_after_close_raises_valuerror(self): context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) context.verify_mode = ssl.CERT_REQUIRED @@ -2887,21 +3274,46 @@ else: self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.write, b'hello') + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + with open(support.TESTFN, 'rb') as file: + s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + def test_main(verbose=False): if support.verbose: + import warnings plats = { 'Linux': platform.linux_distribution, 'Mac': platform.mac_ver, 'Windows': platform.win32_ver, } - for name, func in plats.items(): - plat = func() - if plat and plat[0]: - plat = '%s %r' % (name, plat) - break - else: - plat = repr(platform.platform()) + with warnings.catch_warnings(): + warnings.filterwarnings( + 'ignore', + 'dist\(\) and linux_distribution\(\) ' + 'functions are deprecated .*', + PendingDeprecationWarning, + ) + for name, func in plats.items(): + plat = func() + if plat and plat[0]: + plat = '%s %r' % (name, plat) + break + else: + plat = repr(platform.platform()) print("test_ssl: testing with %r %r" % (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO)) print(" under %s" % plat) @@ -2920,10 +3332,11 @@ def test_main(verbose=False): if not os.path.exists(filename): raise support.TestFailed("Can't read certificate file %r" % filename) - tests = [ContextTests, BasicSocketTests, SSLErrorTests] + tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests] if support.is_resource_enabled('network'): tests.append(NetworkedTests) + tests.append(NetworkedBIOTests) if _have_threads: thread_info = support.threading_setup() diff --git a/Lib/test/test_startfile.py b/Lib/test/test_startfile.py index 43abf9b..f59252e 100644 --- a/Lib/test/test_startfile.py +++ b/Lib/test/test_startfile.py @@ -30,8 +30,5 @@ class TestCase(unittest.TestCase): startfile(empty) startfile(empty, "open") -def test_main(): - support.run_unittest(TestCase) - -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_stat.py b/Lib/test/test_stat.py index af6ced4..f1a5938 100644 --- a/Lib/test/test_stat.py +++ b/Lib/test/test_stat.py @@ -1,5 +1,6 @@ import unittest import os +import sys from test.support import TESTFN, import_fresh_module c_stat = import_fresh_module('stat', fresh=['_stat']) @@ -52,6 +53,26 @@ class TestFilemode: 'S_IWOTH': 0o002, 'S_IXOTH': 0o001} + # defined by the Windows API documentation + file_attributes = { + 'FILE_ATTRIBUTE_ARCHIVE': 32, + 'FILE_ATTRIBUTE_COMPRESSED': 2048, + 'FILE_ATTRIBUTE_DEVICE': 64, + 'FILE_ATTRIBUTE_DIRECTORY': 16, + 'FILE_ATTRIBUTE_ENCRYPTED': 16384, + 'FILE_ATTRIBUTE_HIDDEN': 2, + 'FILE_ATTRIBUTE_INTEGRITY_STREAM': 32768, + 'FILE_ATTRIBUTE_NORMAL': 128, + 'FILE_ATTRIBUTE_NOT_CONTENT_INDEXED': 8192, + 'FILE_ATTRIBUTE_NO_SCRUB_DATA': 131072, + 'FILE_ATTRIBUTE_OFFLINE': 4096, + 'FILE_ATTRIBUTE_READONLY': 1, + 'FILE_ATTRIBUTE_REPARSE_POINT': 1024, + 'FILE_ATTRIBUTE_SPARSE_FILE': 512, + 'FILE_ATTRIBUTE_SYSTEM': 4, + 'FILE_ATTRIBUTE_TEMPORARY': 256, + 'FILE_ATTRIBUTE_VIRTUAL': 65536} + def setUp(self): try: os.remove(TESTFN) @@ -185,6 +206,14 @@ class TestFilemode: self.assertTrue(callable(func)) self.assertEqual(func(0), 0) + @unittest.skipUnless(sys.platform == "win32", + "FILE_ATTRIBUTE_* constants are Win32 specific") + def test_file_attribute_constants(self): + for key, value in sorted(self.file_attributes.items()): + self.assertTrue(hasattr(self.statmod, key), key) + modvalue = getattr(self.statmod, key) + self.assertEqual(value, modvalue, key) + class TestFilemodeCStat(TestFilemode, unittest.TestCase): statmod = c_stat diff --git a/Lib/test/test_string.py b/Lib/test/test_string.py index 57963bf..0cd2b86 100644 --- a/Lib/test/test_string.py +++ b/Lib/test/test_string.py @@ -1,19 +1,22 @@ import unittest, string -from test import support class ModuleTest(unittest.TestCase): def test_attrs(self): - string.whitespace - string.ascii_lowercase - string.ascii_uppercase - string.ascii_letters - string.digits - string.hexdigits - string.octdigits - string.punctuation - string.printable + # While the exact order of the items in these attributes is not + # technically part of the "language spec", in practice there is almost + # certainly user code that depends on the order, so de-facto it *is* + # part of the spec. + self.assertEqual(string.whitespace, ' \t\n\r\x0b\x0c') + self.assertEqual(string.ascii_lowercase, 'abcdefghijklmnopqrstuvwxyz') + self.assertEqual(string.ascii_uppercase, 'ABCDEFGHIJKLMNOPQRSTUVWXYZ') + self.assertEqual(string.ascii_letters, string.ascii_lowercase + string.ascii_uppercase) + self.assertEqual(string.digits, '0123456789') + self.assertEqual(string.hexdigits, string.digits + 'abcdefABCDEF') + self.assertEqual(string.octdigits, '01234567') + self.assertEqual(string.punctuation, '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~') + self.assertEqual(string.printable, string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + string.whitespace) def test_capwords(self): self.assertEqual(string.capwords('abc def ghi'), 'Abc Def Ghi') @@ -43,8 +46,9 @@ class ModuleTest(unittest.TestCase): self.assertEqual(fmt.format("-{format_string}-", format_string='test'), '-test-') self.assertRaises(KeyError, fmt.format, "-{format_string}-") - self.assertEqual(fmt.format(arg='test', format_string="-{arg}-"), - '-test-') + with self.assertWarnsRegex(DeprecationWarning, "format_string"): + self.assertEqual(fmt.format(arg='test', format_string="-{arg}-"), + '-test-') def test_auto_numbering(self): fmt = string.Formatter() @@ -183,8 +187,5 @@ class ModuleTest(unittest.TestCase): self.assertIn("recursion", str(err.exception)) -def test_main(): - support.run_unittest(ModuleTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_stringprep.py b/Lib/test/test_stringprep.py index e763635..d4b4a13 100644 --- a/Lib/test/test_stringprep.py +++ b/Lib/test/test_stringprep.py @@ -2,7 +2,6 @@ # Since we don't have them, this test checks only a few code points. import unittest -from test import support from stringprep import * @@ -89,8 +88,5 @@ class StringprepTests(unittest.TestCase): # h.update(data) # print p, h.hexdigest() -def test_main(): - support.run_unittest(StringprepTests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_strlit.py b/Lib/test/test_strlit.py index d01322f..87cffe8 100644 --- a/Lib/test/test_strlit.py +++ b/Lib/test/test_strlit.py @@ -32,7 +32,6 @@ import sys import shutil import tempfile import unittest -import test.support TEMPLATE = r"""# coding: %s @@ -199,8 +198,5 @@ class TestLiterals(unittest.TestCase): self.check_encoding("latin9") -def test_main(): - test.support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_strptime.py b/Lib/test/test_strptime.py index 2a6f3f8..346e2c6 100644 --- a/Lib/test/test_strptime.py +++ b/Lib/test/test_strptime.py @@ -578,18 +578,5 @@ class CacheTests(unittest.TestCase): locale.setlocale(locale.LC_TIME, locale_info) -def test_main(): - support.run_unittest( - getlang_Tests, - LocaleTime_Tests, - TimeRETests, - StrptimeTests, - Strptime12AMPMTests, - JulianTests, - CalculationTests, - CacheTests - ) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_strtod.py b/Lib/test/test_strtod.py index 41b6e5f..2727514 100644 --- a/Lib/test/test_strtod.py +++ b/Lib/test/test_strtod.py @@ -429,8 +429,5 @@ class StrtodTests(unittest.TestCase): for s in test_strings: self.check_strtod(s) -def test_main(): - test.support.run_unittest(StrtodTests) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py index 0107eeb..efbdbfc 100644 --- a/Lib/test/test_struct.py +++ b/Lib/test/test_struct.py @@ -660,8 +660,5 @@ class UnpackIteratorTest(unittest.TestCase): self.assertRaises(StopIteration, next, it) -def test_main(): - support.run_unittest(__name__) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_structmembers.py b/Lib/test/test_structmembers.py index 1c931ae..57ec45f 100644 --- a/Lib/test/test_structmembers.py +++ b/Lib/test/test_structmembers.py @@ -140,8 +140,5 @@ class TestWarnings(unittest.TestCase): ts.T_USHORT = USHRT_MAX+1 -def test_main(verbose=None): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index 353d0ea..3ecb27d 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -1,7 +1,6 @@ import os import time import unittest -from test import support class StructSeqTest(unittest.TestCase): @@ -123,8 +122,5 @@ class StructSeqTest(unittest.TestCase): self.assertEqual(list(t[start:stop:step]), L[start:stop:step]) -def test_main(): - support.run_unittest(StructSeqTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_subprocess.py b/Lib/test/test_subprocess.py index 2ce5951..5fd6526 100644 --- a/Lib/test/test_subprocess.py +++ b/Lib/test/test_subprocess.py @@ -1,5 +1,5 @@ import unittest -from test import script_helper +from test.support import script_helper from test import support import subprocess import sys @@ -381,7 +381,7 @@ class ProcessTestCase(BaseTestCase): python_dir, python_base = self._split_python_path() abs_python = os.path.join(python_dir, python_base) rel_python = os.path.join(os.curdir, python_base) - with script_helper.temp_dir() as wrong_dir: + with support.temp_dir() as wrong_dir: # Before calling with an absolute path, confirm that using a # relative path fails. self.assertRaises(FileNotFoundError, subprocess.Popen, @@ -1219,6 +1219,102 @@ class ProcessTestCase(BaseTestCase): fds_after_exception = os.listdir(fd_directory) self.assertEqual(fds_before_popen, fds_after_exception) + +class RunFuncTestCase(BaseTestCase): + def run_python(self, code, **kwargs): + """Run Python code in a subprocess using subprocess.run""" + argv = [sys.executable, "-c", code] + return subprocess.run(argv, **kwargs) + + def test_returncode(self): + # call() function with sequence argument + cp = self.run_python("import sys; sys.exit(47)") + self.assertEqual(cp.returncode, 47) + with self.assertRaises(subprocess.CalledProcessError): + cp.check_returncode() + + def test_check(self): + with self.assertRaises(subprocess.CalledProcessError) as c: + self.run_python("import sys; sys.exit(47)", check=True) + self.assertEqual(c.exception.returncode, 47) + + def test_check_zero(self): + # check_returncode shouldn't raise when returncode is zero + cp = self.run_python("import sys; sys.exit(0)", check=True) + self.assertEqual(cp.returncode, 0) + + def test_timeout(self): + # run() function with timeout argument; we want to test that the child + # process gets killed when the timeout expires. If the child isn't + # killed, this call will deadlock since subprocess.run waits for the + # child. + with self.assertRaises(subprocess.TimeoutExpired): + self.run_python("while True: pass", timeout=0.0001) + + def test_capture_stdout(self): + # capture stdout with zero return code + cp = self.run_python("print('BDFL')", stdout=subprocess.PIPE) + self.assertIn(b'BDFL', cp.stdout) + + def test_capture_stderr(self): + cp = self.run_python("import sys; sys.stderr.write('BDFL')", + stderr=subprocess.PIPE) + self.assertIn(b'BDFL', cp.stderr) + + def test_check_output_stdin_arg(self): + # run() can be called with stdin set to a file + tf = tempfile.TemporaryFile() + self.addCleanup(tf.close) + tf.write(b'pear') + tf.seek(0) + cp = self.run_python( + "import sys; sys.stdout.write(sys.stdin.read().upper())", + stdin=tf, stdout=subprocess.PIPE) + self.assertIn(b'PEAR', cp.stdout) + + def test_check_output_input_arg(self): + # check_output() can be called with input set to a string + cp = self.run_python( + "import sys; sys.stdout.write(sys.stdin.read().upper())", + input=b'pear', stdout=subprocess.PIPE) + self.assertIn(b'PEAR', cp.stdout) + + def test_check_output_stdin_with_input_arg(self): + # run() refuses to accept 'stdin' with 'input' + tf = tempfile.TemporaryFile() + self.addCleanup(tf.close) + tf.write(b'pear') + tf.seek(0) + with self.assertRaises(ValueError, + msg="Expected ValueError when stdin and input args supplied.") as c: + output = self.run_python("print('will not be run')", + stdin=tf, input=b'hare') + self.assertIn('stdin', c.exception.args[0]) + self.assertIn('input', c.exception.args[0]) + + def test_check_output_timeout(self): + with self.assertRaises(subprocess.TimeoutExpired) as c: + cp = self.run_python(( + "import sys, time\n" + "sys.stdout.write('BDFL')\n" + "sys.stdout.flush()\n" + "time.sleep(3600)"), + # Some heavily loaded buildbots (sparc Debian 3.x) require + # this much time to start and print. + timeout=3, stdout=subprocess.PIPE) + self.assertEqual(c.exception.output, b'BDFL') + # output is aliased to stdout + self.assertEqual(c.exception.stdout, b'BDFL') + + def test_run_kwargs(self): + newenv = os.environ.copy() + newenv["FRUIT"] = "banana" + cp = self.run_python(('import sys, os;' + 'sys.exit(33 if os.getenv("FRUIT")=="banana" else 31)'), + env=newenv) + self.assertEqual(cp.returncode, 33) + + @unittest.skipIf(mswindows, "POSIX specific tests") class POSIXProcessTestCase(BaseTestCase): @@ -2407,24 +2503,20 @@ class ProcessTestCaseNoPoll(ProcessTestCase): subprocess._PopenSelector = self.orig_selector ProcessTestCase.tearDown(self) + def test__all__(self): + """Ensure that __all__ is populated properly.""" + intentionally_excluded = set(("list2cmdline",)) + exported = set(subprocess.__all__) + possible_exports = set() + import types + for name, value in subprocess.__dict__.items(): + if name.startswith('_'): + continue + if isinstance(value, (types.ModuleType,)): + continue + possible_exports.add(name) + self.assertEqual(exported, possible_exports - intentionally_excluded) -class HelperFunctionTests(unittest.TestCase): - @unittest.skipIf(mswindows, "errno and EINTR make no sense on windows") - def test_eintr_retry_call(self): - record_calls = [] - def fake_os_func(*args): - record_calls.append(args) - if len(record_calls) == 2: - raise OSError(errno.EINTR, "fake interrupted system call") - return tuple(reversed(args)) - - self.assertEqual((999, 256), - subprocess._eintr_retry_call(fake_os_func, 256, 999)) - self.assertEqual([(256, 999)], record_calls) - # This time there will be an EINTR so it will loop once. - self.assertEqual((666,), - subprocess._eintr_retry_call(fake_os_func, 666)) - self.assertEqual([(256, 999), (666,), (666,)], record_calls) @unittest.skipUnless(mswindows, "Windows-specific tests") @@ -2531,9 +2623,9 @@ def test_main(): Win32ProcessTestCase, CommandTests, ProcessTestCaseNoPoll, - HelperFunctionTests, CommandsWithSpaces, ContextManagerTests, + RunFuncTestCase, ) support.run_unittest(*unit_tests) diff --git a/Lib/test/test_sundry.py b/Lib/test/test_sundry.py index e99ca9e..1fb9964 100644 --- a/Lib/test/test_sundry.py +++ b/Lib/test/test_sundry.py @@ -22,8 +22,6 @@ class TestUntestedModules(unittest.TestCase): import distutils.ccompiler import distutils.cygwinccompiler import distutils.filelist - if sys.platform.startswith('win'): - import distutils.msvccompiler import distutils.text_file import distutils.unixccompiler diff --git a/Lib/test/test_super.py b/Lib/test/test_super.py index 37fc2d9..dc3a15f 100644 --- a/Lib/test/test_super.py +++ b/Lib/test/test_super.py @@ -2,7 +2,6 @@ import sys import unittest -from test import support class A: @@ -173,9 +172,5 @@ class TestSuper(unittest.TestCase): self.assertRaises(TypeError, X.meth, c) -def test_main(): - support.run_unittest(TestSuper) - - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py index 03ce9d1..2c00417 100644 --- a/Lib/test/test_support.py +++ b/Lib/test/test_support.py @@ -85,7 +85,7 @@ class TestSupport(unittest.TestCase): def test_bind_port(self): s = socket.socket() support.bind_port(s) - s.listen(1) + s.listen() s.close() # Tests for temp_dir() @@ -280,6 +280,38 @@ class TestSupport(unittest.TestCase): self.assertEqual(D["item"], 5) self.assertEqual(D["item"], 1) + class RefClass: + attribute1 = None + attribute2 = None + _hidden_attribute1 = None + __magic_1__ = None + + class OtherClass: + attribute2 = None + attribute3 = None + __magic_1__ = None + __magic_2__ = None + + def test_detect_api_mismatch(self): + missing_items = support.detect_api_mismatch(self.RefClass, + self.OtherClass) + self.assertEqual({'attribute1'}, missing_items) + + missing_items = support.detect_api_mismatch(self.OtherClass, + self.RefClass) + self.assertEqual({'attribute3', '__magic_2__'}, missing_items) + + def test_detect_api_mismatch__ignore(self): + ignore = ['attribute1', 'attribute3', '__magic_2__', 'not_in_either'] + + missing_items = support.detect_api_mismatch( + self.RefClass, self.OtherClass, ignore=ignore) + self.assertEqual(set(), missing_items) + + missing_items = support.detect_api_mismatch( + self.OtherClass, self.RefClass, ignore=ignore) + self.assertEqual(set(), missing_items) + # XXX -follows a list of untested API # make_legacy_pyc # is_resource_enabled diff --git a/Lib/test/test_symtable.py b/Lib/test/test_symtable.py index 335b4dc..e5e7b83 100644 --- a/Lib/test/test_symtable.py +++ b/Lib/test/test_symtable.py @@ -4,7 +4,6 @@ Test the API of the symtable module. import symtable import unittest -from test import support TEST_CODE = """ @@ -169,8 +168,5 @@ class SymtableTest(unittest.TestCase): symbols = symtable.symtable("def f(x): return x", "?", "exec") -def test_main(): - support.run_unittest(SymtableTest) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py index a9d3628..a22cebb 100644 --- a/Lib/test/test_syntax.py +++ b/Lib/test/test_syntax.py @@ -141,6 +141,9 @@ From ast_for_call(): >>> f(x for x in L, 1) Traceback (most recent call last): SyntaxError: Generator expression must be parenthesized if not sole argument +>>> f(x for x in L, y for y in L) +Traceback (most recent call last): +SyntaxError: Generator expression must be parenthesized if not sole argument >>> f((x for x in L), 1) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] @@ -582,7 +585,18 @@ class SyntaxTestCase(unittest.TestCase): subclass=IndentationError) def test_kwargs_last(self): - self._check_error("int(base=10, '2')", "non-keyword arg") + self._check_error("int(base=10, '2')", + "positional argument follows keyword argument") + + def test_kwargs_last2(self): + self._check_error("int(**{base: 10}, '2')", + "positional argument follows " + "keyword argument unpacking") + + def test_kwargs_last3(self): + self._check_error("int(**{base: 10}, *['2'])", + "iterable argument unpacking follows " + "keyword argument unpacking") def test_main(): support.run_unittest(SyntaxTestCase) diff --git a/Lib/test/test_sys.py b/Lib/test/test_sys.py index b7ddbe9..8ec38c8 100644 --- a/Lib/test/test_sys.py +++ b/Lib/test/test_sys.py @@ -1,5 +1,5 @@ import unittest, test.support -from test.script_helper import assert_python_ok, assert_python_failure +from test.support.script_helper import assert_python_ok, assert_python_failure import sys, io, os import struct import subprocess @@ -212,8 +212,8 @@ class SysModuleTest(unittest.TestCase): for i in (50, 1000): # Issue #5392: stack overflow after hitting recursion limit twice sys.setrecursionlimit(i) - self.assertRaises(RuntimeError, f) - self.assertRaises(RuntimeError, f) + self.assertRaises(RecursionError, f) + self.assertRaises(RecursionError, f) finally: sys.setrecursionlimit(oldlimit) @@ -226,7 +226,7 @@ class SysModuleTest(unittest.TestCase): def f(): try: f() - except RuntimeError: + except RecursionError: f() sys.setrecursionlimit(%d) @@ -637,6 +637,53 @@ class SysModuleTest(unittest.TestCase): expected = None self.check_fsencoding(fs_encoding, expected) + def c_locale_get_error_handler(self, isolated=False, encoding=None): + # Force the POSIX locale + env = os.environ.copy() + env["LC_ALL"] = "C" + code = '\n'.join(( + 'import sys', + 'def dump(name):', + ' std = getattr(sys, name)', + ' print("%s: %s" % (name, std.errors))', + 'dump("stdin")', + 'dump("stdout")', + 'dump("stderr")', + )) + args = [sys.executable, "-c", code] + if isolated: + args.append("-I") + elif encoding: + env['PYTHONIOENCODING'] = encoding + p = subprocess.Popen(args, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + universal_newlines=True) + stdout, stderr = p.communicate() + return stdout + + def test_c_locale_surrogateescape(self): + out = self.c_locale_get_error_handler(isolated=True) + self.assertEqual(out, + 'stdin: surrogateescape\n' + 'stdout: surrogateescape\n' + 'stderr: backslashreplace\n') + + # replace the default error handler + out = self.c_locale_get_error_handler(encoding=':strict') + self.assertEqual(out, + 'stdin: strict\n' + 'stdout: strict\n' + 'stderr: backslashreplace\n') + + # force the encoding + out = self.c_locale_get_error_handler(encoding='iso8859-1') + self.assertEqual(out, + 'stdin: surrogateescape\n' + 'stdout: surrogateescape\n' + 'stderr: backslashreplace\n') + def test_implementation(self): # This test applies to all implementations equally. @@ -662,7 +709,7 @@ class SysModuleTest(unittest.TestCase): @test.support.cpython_only def test_debugmallocstats(self): # Test sys._debugmallocstats() - from test.script_helper import assert_python_ok + from test.support.script_helper import assert_python_ok args = ['-c', 'import sys; sys._debugmallocstats()'] ret, out, err = assert_python_ok(*args) self.assertIn(b"free PyDictObjects", err) @@ -699,6 +746,27 @@ class SysModuleTest(unittest.TestCase): c = sys.getallocatedblocks() self.assertIn(c, range(b - 50, b + 50)) + def test_is_finalizing(self): + self.assertIs(sys.is_finalizing(), False) + # Don't use the atexit module because _Py_Finalizing is only set + # after calling atexit callbacks + code = """if 1: + import sys + + class AtExit: + is_finalizing = sys.is_finalizing + print = print + + def __del__(self): + self.print(self.is_finalizing(), flush=True) + + # Keep a reference in the __main__ module namespace, so the + # AtExit destructor will be called at Python exit + ref = AtExit() + """ + rc, stdout, stderr = assert_python_ok('-c', code) + self.assertEqual(stdout.rstrip(), b'True') + @test.support.cpython_only class SizeofTest(unittest.TestCase): @@ -771,7 +839,7 @@ class SizeofTest(unittest.TestCase): # buffer # XXX # builtin_function_or_method - check(len, size('3P')) # XXX check layout + check(len, size('4P')) # XXX check layout # bytearray samples = [b'', b'u'*100000] for sample in samples: @@ -875,7 +943,7 @@ class SizeofTest(unittest.TestCase): check(bar, size('PP')) # generator def get_gen(): yield 1 - check(get_gen(), size('Pb2P')) + check(get_gen(), size('Pb2PPP')) # iterator check(iter('abc'), size('lP')) # callable-iterator @@ -929,7 +997,7 @@ class SizeofTest(unittest.TestCase): # frozenset PySet_MINSIZE = 8 samples = [[], range(10), range(50)] - s = size('3n2P' + PySet_MINSIZE*'nP' + 'nP') + s = size('3nP' + PySet_MINSIZE*'nP' + '2nP') for sample in samples: minused = len(sample) if minused == 0: tmp = 1 @@ -958,9 +1026,9 @@ class SizeofTest(unittest.TestCase): # static type: PyTypeObject s = vsize('P2n15Pl4Pn9Pn11PIP') check(int, s) - # (PyTypeObject + PyNumberMethods + PyMappingMethods + + # (PyTypeObject + PyAsyncMethods + PyNumberMethods + PyMappingMethods + # PySequenceMethods + PyBufferProcs + 4P) - s = vsize('P2n15Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 10P 2P 4P') + s = vsize('P2n17Pl4Pn9Pn11PIP') + struct.calcsize('34P 3P 3P 10P 2P 4P') # Separate block for PyDictKeysObject with 4 entries s += struct.calcsize("2nPn") + 4*struct.calcsize("n2P") # class diff --git a/Lib/test/test_sys_setprofile.py b/Lib/test/test_sys_setprofile.py index 9816e3e..bb71acd 100644 --- a/Lib/test/test_sys_setprofile.py +++ b/Lib/test/test_sys_setprofile.py @@ -3,7 +3,6 @@ import pprint import sys import unittest -from test import support class TestGetProfile(unittest.TestCase): def setUp(self): @@ -260,7 +259,6 @@ class ProfileHookTestCase(TestCaseBase): def f(): for i in range(2): yield i - raise StopIteration def g(p): for i in f(): pass @@ -374,13 +372,5 @@ def show_events(callable): pprint.pprint(capture_events(callable)) -def test_main(): - support.run_unittest( - TestGetProfile, - ProfileHookTestCase, - ProfileSimulatorTestCase - ) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_sysconfig.py b/Lib/test/test_sysconfig.py index 8ed729a..0917c3e 100644 --- a/Lib/test/test_sysconfig.py +++ b/Lib/test/test_sysconfig.py @@ -385,6 +385,25 @@ class TestSysConfig(unittest.TestCase): self.assertIsNotNone(vars['SO']) self.assertEqual(vars['SO'], vars['EXT_SUFFIX']) + @unittest.skipUnless(sys.platform == 'linux', 'Linux-specific test') + def test_triplet_in_ext_suffix(self): + import ctypes, platform, re + machine = platform.machine() + suffix = sysconfig.get_config_var('EXT_SUFFIX') + if re.match('(aarch64|arm|mips|ppc|powerpc|s390|sparc)', machine): + self.assertTrue('linux' in suffix, suffix) + if re.match('(i[3-6]86|x86_64)$', machine): + if ctypes.sizeof(ctypes.c_char_p()) == 4: + self.assertTrue(suffix.endswith('i386-linux-gnu.so') \ + or suffix.endswith('x86_64-linux-gnux32.so'), + suffix) + else: # 8 byte pointer size + self.assertTrue(suffix.endswith('x86_64-linux-gnu.so'), suffix) + + @unittest.skipUnless(sys.platform == 'darwin', 'OS X-specific test') + def test_osx_ext_suffix(self): + suffix = sysconfig.get_config_var('EXT_SUFFIX') + self.assertTrue(suffix.endswith('-darwin.so'), suffix) class MakefileTests(unittest.TestCase): diff --git a/Lib/test/test_syslog.py b/Lib/test/test_syslog.py index b7fd2bd..6f902f1 100644 --- a/Lib/test/test_syslog.py +++ b/Lib/test/test_syslog.py @@ -36,8 +36,5 @@ class Test(unittest.TestCase): syslog.openlog() syslog.syslog('test message from python test_syslog') -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py index 3091ce7..1412cae 100644 --- a/Lib/test/test_tarfile.py +++ b/Lib/test/test_tarfile.py @@ -2,11 +2,14 @@ import sys import os import io from hashlib import md5 +from contextlib import contextmanager import unittest +import unittest.mock import tarfile -from test import support, script_helper +from test import support +from test.support import script_helper # Check for our compression modules. try: @@ -285,6 +288,18 @@ class ListTest(ReadTest, unittest.TestCase): self.assertIn(b'pax' + (b'/123' * 125) + b'/longlink link to pax' + (b'/123' * 125) + b'/longname', out) + def test_list_members(self): + tio = io.TextIOWrapper(io.BytesIO(), 'ascii', newline='\n') + def members(tar): + for tarinfo in tar.getmembers(): + if 'reg' in tarinfo.name: + yield tarinfo + with support.swap_attr(sys, 'stdout', tio): + self.tar.list(verbose=False, members=members(self.tar)) + out = tio.detach().getvalue() + self.assertIn(b'ustar/regtype', out) + self.assertNotIn(b'ustar/conttype', out) + class GzipListTest(GzipTest, ListTest): pass @@ -990,6 +1005,19 @@ class WriteTestBase(TarTest): self.assertFalse(fobj.closed) self.assertEqual(data, fobj.getvalue()) + def test_eof_marker(self): + # Make sure an end of archive marker is written (two zero blocks). + # tarfile insists on aligning archives to a 20 * 512 byte recordsize. + # So, we create an archive that has exactly 10240 bytes without the + # marker, and has 20480 bytes once the marker is written. + with tarfile.open(tmpname, self.mode) as tar: + t = tarfile.TarInfo("foo") + t.size = tarfile.RECORDSIZE - tarfile.BLOCKSIZE + tar.addfile(t, io.BytesIO(b"a" * t.size)) + + with self.open(tmpname, "rb") as fobj: + self.assertEqual(len(fobj.read()), tarfile.RECORDSIZE * 2) + class WriteTest(WriteTestBase, unittest.TestCase): @@ -1433,6 +1461,88 @@ class GNUWriteTest(unittest.TestCase): ("longlnk/" * 127) + "longlink_") +class CreateTest(WriteTestBase, unittest.TestCase): + + prefix = "x:" + + file_path = os.path.join(TEMPDIR, "spameggs42") + + def setUp(self): + support.unlink(tmpname) + + @classmethod + def setUpClass(cls): + with open(cls.file_path, "wb") as fobj: + fobj.write(b"aaa") + + @classmethod + def tearDownClass(cls): + support.unlink(cls.file_path) + + def test_create(self): + with tarfile.open(tmpname, self.mode) as tobj: + tobj.add(self.file_path) + + with self.taropen(tmpname) as tobj: + names = tobj.getnames() + self.assertEqual(len(names), 1) + self.assertIn('spameggs42', names[0]) + + def test_create_existing(self): + with tarfile.open(tmpname, self.mode) as tobj: + tobj.add(self.file_path) + + with self.assertRaises(FileExistsError): + tobj = tarfile.open(tmpname, self.mode) + + with self.taropen(tmpname) as tobj: + names = tobj.getnames() + self.assertEqual(len(names), 1) + self.assertIn('spameggs42', names[0]) + + def test_create_taropen(self): + with self.taropen(tmpname, "x") as tobj: + tobj.add(self.file_path) + + with self.taropen(tmpname) as tobj: + names = tobj.getnames() + self.assertEqual(len(names), 1) + self.assertIn('spameggs42', names[0]) + + def test_create_existing_taropen(self): + with self.taropen(tmpname, "x") as tobj: + tobj.add(self.file_path) + + with self.assertRaises(FileExistsError): + with self.taropen(tmpname, "x"): + pass + + with self.taropen(tmpname) as tobj: + names = tobj.getnames() + self.assertEqual(len(names), 1) + self.assertIn("spameggs42", names[0]) + + +class GzipCreateTest(GzipTest, CreateTest): + pass + + +class Bz2CreateTest(Bz2Test, CreateTest): + pass + + +class LzmaCreateTest(LzmaTest, CreateTest): + pass + + +class CreateWithXModeTest(CreateTest): + + prefix = "x" + + test_create_taropen = None + test_create_existing_taropen = None + + @unittest.skipUnless(hasattr(os, "link"), "Missing hardlink implementation") class HardlinkTest(unittest.TestCase): # Test the creation of LNKTYPE (hardlink) members in an archive. @@ -2191,6 +2301,138 @@ class Bz2PartialReadTest(Bz2Test, unittest.TestCase): self._test_partial_input("r:bz2") +def root_is_uid_gid_0(): + try: + import pwd, grp + except ImportError: + return False + if pwd.getpwuid(0)[0] != 'root': + return False + if grp.getgrgid(0)[0] != 'root': + return False + return True + + +@unittest.skipUnless(hasattr(os, 'chown'), "missing os.chown") +@unittest.skipUnless(hasattr(os, 'geteuid'), "missing os.geteuid") +class NumericOwnerTest(unittest.TestCase): + # mock the following: + # os.chown: so we can test what's being called + # os.chmod: so the modes are not actually changed. if they are, we can't + # delete the files/directories + # os.geteuid: so we can lie and say we're root (uid = 0) + + @staticmethod + def _make_test_archive(filename_1, dirname_1, filename_2): + # the file contents to write + fobj = io.BytesIO(b"content") + + # create a tar file with a file, a directory, and a file within that + # directory. Assign various .uid/.gid values to them + items = [(filename_1, 99, 98, tarfile.REGTYPE, fobj), + (dirname_1, 77, 76, tarfile.DIRTYPE, None), + (filename_2, 88, 87, tarfile.REGTYPE, fobj), + ] + with tarfile.open(tmpname, 'w') as tarfl: + for name, uid, gid, typ, contents in items: + t = tarfile.TarInfo(name) + t.uid = uid + t.gid = gid + t.uname = 'root' + t.gname = 'root' + t.type = typ + tarfl.addfile(t, contents) + + # return the full pathname to the tar file + return tmpname + + @staticmethod + @contextmanager + def _setup_test(mock_geteuid): + mock_geteuid.return_value = 0 # lie and say we're root + fname = 'numeric-owner-testfile' + dirname = 'dir' + + # the names we want stored in the tarfile + filename_1 = fname + dirname_1 = dirname + filename_2 = os.path.join(dirname, fname) + + # create the tarfile with the contents we're after + tar_filename = NumericOwnerTest._make_test_archive(filename_1, + dirname_1, + filename_2) + + # open the tarfile for reading. yield it and the names of the items + # we stored into the file + with tarfile.open(tar_filename) as tarfl: + yield tarfl, filename_1, dirname_1, filename_2 + + @unittest.mock.patch('os.chown') + @unittest.mock.patch('os.chmod') + @unittest.mock.patch('os.geteuid') + def test_extract_with_numeric_owner(self, mock_geteuid, mock_chmod, + mock_chown): + with self._setup_test(mock_geteuid) as (tarfl, filename_1, _, + filename_2): + tarfl.extract(filename_1, TEMPDIR, numeric_owner=True) + tarfl.extract(filename_2 , TEMPDIR, numeric_owner=True) + + # convert to filesystem paths + f_filename_1 = os.path.join(TEMPDIR, filename_1) + f_filename_2 = os.path.join(TEMPDIR, filename_2) + + mock_chown.assert_has_calls([unittest.mock.call(f_filename_1, 99, 98), + unittest.mock.call(f_filename_2, 88, 87), + ], + any_order=True) + + @unittest.mock.patch('os.chown') + @unittest.mock.patch('os.chmod') + @unittest.mock.patch('os.geteuid') + def test_extractall_with_numeric_owner(self, mock_geteuid, mock_chmod, + mock_chown): + with self._setup_test(mock_geteuid) as (tarfl, filename_1, dirname_1, + filename_2): + tarfl.extractall(TEMPDIR, numeric_owner=True) + + # convert to filesystem paths + f_filename_1 = os.path.join(TEMPDIR, filename_1) + f_dirname_1 = os.path.join(TEMPDIR, dirname_1) + f_filename_2 = os.path.join(TEMPDIR, filename_2) + + mock_chown.assert_has_calls([unittest.mock.call(f_filename_1, 99, 98), + unittest.mock.call(f_dirname_1, 77, 76), + unittest.mock.call(f_filename_2, 88, 87), + ], + any_order=True) + + # this test requires that uid=0 and gid=0 really be named 'root'. that's + # because the uname and gname in the test file are 'root', and extract() + # will look them up using pwd and grp to find their uid and gid, which we + # test here to be 0. + @unittest.skipUnless(root_is_uid_gid_0(), + 'uid=0,gid=0 must be named "root"') + @unittest.mock.patch('os.chown') + @unittest.mock.patch('os.chmod') + @unittest.mock.patch('os.geteuid') + def test_extract_without_numeric_owner(self, mock_geteuid, mock_chmod, + mock_chown): + with self._setup_test(mock_geteuid) as (tarfl, filename_1, _, _): + tarfl.extract(filename_1, TEMPDIR, numeric_owner=False) + + # convert to filesystem paths + f_filename_1 = os.path.join(TEMPDIR, filename_1) + + mock_chown.assert_called_with(f_filename_1, 0, 0) + + @unittest.mock.patch('os.geteuid') + def test_keyword_only(self, mock_geteuid): + with self._setup_test(mock_geteuid) as (tarfl, filename_1, _, _): + self.assertRaises(TypeError, + tarfl.extract, filename_1, TEMPDIR, False, True) + + def setUpModule(): support.unlink(TEMPDIR) os.makedirs(TEMPDIR) diff --git a/Lib/test/test_tcl.py b/Lib/test/test_tcl.py index 66e9d49..5be645a 100644 --- a/Lib/test/test_tcl.py +++ b/Lib/test/test_tcl.py @@ -7,9 +7,7 @@ from test import support # Skip this test if the _tkinter module wasn't built. _tkinter = support.import_module('_tkinter') -# Make sure tkinter._fix runs to set up the environment -tkinter = support.import_fresh_module('tkinter') - +import tkinter from tkinter import Tcl from _tkinter import TclError @@ -130,9 +128,7 @@ class TclTest(unittest.TestCase): self.assertRaises(TclError,tcl.unsetvar,'a') def get_integers(self): - integers = (0, 1, -1, 2**31-1, -2**31) - if tcl_version >= (8, 4): # wideInt was added in Tcl 8.4 - integers += (2**31, -2**31-1, 2**63-1, -2**63) + integers = (0, 1, -1, 2**31-1, -2**31, 2**31, -2**31-1, 2**63-1, -2**63) # bignum was added in Tcl 8.5, but its support is able only since 8.5.8 if (get_tk_patchlevel() >= (8, 6, 0, 'final') or (8, 5, 8) <= get_tk_patchlevel() < (8, 6)): @@ -165,10 +161,10 @@ class TclTest(unittest.TestCase): self.assertEqual(tcl.getdouble(' 42 '), 42.0) self.assertEqual(tcl.getdouble(' 42.5 '), 42.5) self.assertEqual(tcl.getdouble(42.5), 42.5) + self.assertEqual(tcl.getdouble(42), 42.0) self.assertRaises(TypeError, tcl.getdouble) self.assertRaises(TypeError, tcl.getdouble, '42.5', '10') self.assertRaises(TypeError, tcl.getdouble, b'42.5') - self.assertRaises(TypeError, tcl.getdouble, 42) self.assertRaises(TclError, tcl.getdouble, 'a') self.assertRaises((TypeError, ValueError, TclError), tcl.getdouble, '42.5\0') @@ -464,6 +460,8 @@ class TclTest(unittest.TestCase): # XXX NaN representation can be not parsable by float() self.assertEqual(passValue((1, '2', (3.4,))), (1, '2', (3.4,)) if self.wantobjects else '1 2 3.4') + self.assertEqual(passValue(['a', ['b', 'c']]), + ('a', ('b', 'c')) if self.wantobjects else 'a {b c}') def test_user_command(self): result = None @@ -517,6 +515,7 @@ class TclTest(unittest.TestCase): # XXX NaN representation can be not parsable by float() check((), '') check((1, (2,), (3, 4), '5 6', ()), '1 2 {3 4} {5 6} {}') + check([1, [2,], [3, 4], '5 6', []], '1 2 {3 4} {5 6} {}') def test_splitlist(self): splitlist = self.interp.tk.splitlist @@ -542,12 +541,15 @@ class TclTest(unittest.TestCase): ('a 3.4', ('a', '3.4')), (('a', 3.4), ('a', 3.4)), ((), ()), + ([], ()), + (['a', ['b', 'c']], ('a', ['b', 'c'])), (call('list', 1, '2', (3.4,)), (1, '2', (3.4,)) if self.wantobjects else ('1', '2', '3.4')), ] + tk_patchlevel = get_tk_patchlevel() if tcl_version >= (8, 5): - if not self.wantobjects or get_tk_patchlevel() < (8, 5, 5): + if not self.wantobjects or tk_patchlevel < (8, 5, 5): # Before 8.5.5 dicts were converted to lists through string expected = ('12', '\u20ac', '\xe2\x82\xac', '3.4') else: @@ -556,8 +558,11 @@ class TclTest(unittest.TestCase): (call('dict', 'create', 12, '\u20ac', b'\xe2\x82\xac', (3.4,)), expected), ] + dbg_info = ('want objects? %s, Tcl version: %s, Tk patchlevel: %s' + % (self.wantobjects, tcl_version, tk_patchlevel)) for arg, res in testcases: - self.assertEqual(splitlist(arg), res, msg=arg) + self.assertEqual(splitlist(arg), res, + 'arg=%a, %s' % (arg, dbg_info)) self.assertRaises(TclError, splitlist, '{') def test_split(self): @@ -589,6 +594,9 @@ class TclTest(unittest.TestCase): (('a', 3.4), ('a', 3.4)), (('a', (2, 3.4)), ('a', (2, 3.4))), ((), ()), + ([], ()), + (['a', 'b c'], ('a', ('b', 'c'))), + (['a', ['b', 'c']], ('a', ('b', 'c'))), (call('list', 1, '2', (3.4,)), (1, '2', (3.4,)) if self.wantobjects else ('1', '2', '3.4')), diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py index ee1c357..524bba3 100644 --- a/Lib/test/test_telnetlib.py +++ b/Lib/test/test_telnetlib.py @@ -11,7 +11,7 @@ threading = support.import_module('threading') HOST = support.HOST def server(evt, serv): - serv.listen(5) + serv.listen() evt.set() try: conn, addr = serv.accept() @@ -393,9 +393,5 @@ class ExpectTests(ExpectAndReadTestCase): self.assertEqual(data, b''.join(want[:-1])) -def test_main(verbose=None): - support.run_unittest(GeneralTests, ReadTests, WriteTests, OptionTests, - ExpectTests) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py index 6641298..4a077cc 100644 --- a/Lib/test/test_tempfile.py +++ b/Lib/test/test_tempfile.py @@ -12,7 +12,8 @@ import weakref from unittest import mock import unittest -from test import support, script_helper +from test import support +from test.support import script_helper if hasattr(os, 'stat'): @@ -35,10 +36,38 @@ else: # in order of their appearance in the file. Testing which requires # threads is not done here. +class TestLowLevelInternals(unittest.TestCase): + def test_infer_return_type_singles(self): + self.assertIs(str, tempfile._infer_return_type('')) + self.assertIs(bytes, tempfile._infer_return_type(b'')) + self.assertIs(str, tempfile._infer_return_type(None)) + + def test_infer_return_type_multiples(self): + self.assertIs(str, tempfile._infer_return_type('', '')) + self.assertIs(bytes, tempfile._infer_return_type(b'', b'')) + with self.assertRaises(TypeError): + tempfile._infer_return_type('', b'') + with self.assertRaises(TypeError): + tempfile._infer_return_type(b'', '') + + def test_infer_return_type_multiples_and_none(self): + self.assertIs(str, tempfile._infer_return_type(None, '')) + self.assertIs(str, tempfile._infer_return_type('', None)) + self.assertIs(str, tempfile._infer_return_type(None, None)) + self.assertIs(bytes, tempfile._infer_return_type(b'', None)) + self.assertIs(bytes, tempfile._infer_return_type(None, b'')) + with self.assertRaises(TypeError): + tempfile._infer_return_type('', None, b'') + with self.assertRaises(TypeError): + tempfile._infer_return_type(b'', None, '') + + # Common functionality. + class BaseTestCase(unittest.TestCase): str_check = re.compile(r"^[a-z0-9_-]{8}$") + b_check = re.compile(br"^[a-z0-9_-]{8}$") def setUp(self): self._warnings_manager = support.check_warnings() @@ -55,18 +84,31 @@ class BaseTestCase(unittest.TestCase): npre = nbase[:len(pre)] nsuf = nbase[len(nbase)-len(suf):] + if dir is not None: + self.assertIs(type(name), str if type(dir) is str else bytes, + "unexpected return type") + if pre is not None: + self.assertIs(type(name), str if type(pre) is str else bytes, + "unexpected return type") + if suf is not None: + self.assertIs(type(name), str if type(suf) is str else bytes, + "unexpected return type") + if (dir, pre, suf) == (None, None, None): + self.assertIs(type(name), str, "default return type must be str") + # check for equality of the absolute paths! self.assertEqual(os.path.abspath(ndir), os.path.abspath(dir), - "file '%s' not in directory '%s'" % (name, dir)) + "file %r not in directory %r" % (name, dir)) self.assertEqual(npre, pre, - "file '%s' does not begin with '%s'" % (nbase, pre)) + "file %r does not begin with %r" % (nbase, pre)) self.assertEqual(nsuf, suf, - "file '%s' does not end with '%s'" % (nbase, suf)) + "file %r does not end with %r" % (nbase, suf)) nbase = nbase[len(pre):len(nbase)-len(suf)] - self.assertTrue(self.str_check.match(nbase), - "random string '%s' does not match ^[a-z0-9_-]{8}$" - % nbase) + check = self.str_check if isinstance(nbase, str) else self.b_check + self.assertTrue(check.match(nbase), + "random characters %r do not match %r" + % (nbase, check.pattern)) class TestExports(BaseTestCase): @@ -82,7 +124,9 @@ class TestExports(BaseTestCase): "mktemp" : 1, "TMP_MAX" : 1, "gettempprefix" : 1, + "gettempprefixb" : 1, "gettempdir" : 1, + "gettempdirb" : 1, "tempdir" : 1, "template" : 1, "SpooledTemporaryFile" : 1, @@ -319,7 +363,8 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase): if bin: flags = self._bflags else: flags = self._tflags - (self.fd, self.name) = tempfile._mkstemp_inner(dir, pre, suf, flags) + output_type = tempfile._infer_return_type(dir, pre, suf) + (self.fd, self.name) = tempfile._mkstemp_inner(dir, pre, suf, flags, output_type) def write(self, str): os.write(self.fd, str) @@ -328,9 +373,17 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase): self._close(self.fd) self._unlink(self.name) - def do_create(self, dir=None, pre="", suf="", bin=1): + def do_create(self, dir=None, pre=None, suf=None, bin=1): + output_type = tempfile._infer_return_type(dir, pre, suf) if dir is None: - dir = tempfile.gettempdir() + if output_type is str: + dir = tempfile.gettempdir() + else: + dir = tempfile.gettempdirb() + if pre is None: + pre = output_type() + if suf is None: + suf = output_type() file = self.mkstemped(dir, pre, suf, bin) self.nameCheck(file.name, dir, pre, suf) @@ -344,6 +397,23 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase): self.do_create(pre="a", suf="b").write(b"blat") self.do_create(pre="aa", suf=".txt").write(b"blat") + def test_basic_with_bytes_names(self): + # _mkstemp_inner can create files when given name parts all + # specified as bytes. + dir_b = tempfile.gettempdirb() + self.do_create(dir=dir_b, suf=b"").write(b"blat") + self.do_create(dir=dir_b, pre=b"a").write(b"blat") + self.do_create(dir=dir_b, suf=b"b").write(b"blat") + self.do_create(dir=dir_b, pre=b"a", suf=b"b").write(b"blat") + self.do_create(dir=dir_b, pre=b"aa", suf=b".txt").write(b"blat") + # Can't mix str & binary types in the args. + with self.assertRaises(TypeError): + self.do_create(dir="", suf=b"").write(b"blat") + with self.assertRaises(TypeError): + self.do_create(dir=dir_b, pre="").write(b"blat") + with self.assertRaises(TypeError): + self.do_create(dir=dir_b, pre=b"", suf="").write(b"blat") + def test_basic_many(self): # _mkstemp_inner can create many files (stochastic) extant = list(range(TEST_FILES)) @@ -423,9 +493,10 @@ class TestMkstempInner(TestBadTempdir, BaseTestCase): def make_temp(self): return tempfile._mkstemp_inner(tempfile.gettempdir(), - tempfile.template, + tempfile.gettempprefix(), '', - tempfile._bin_openflags) + tempfile._bin_openflags, + str) def test_collision_with_existing_file(self): # _mkstemp_inner tries another name when a file with @@ -461,7 +532,12 @@ class TestGetTempPrefix(BaseTestCase): p = tempfile.gettempprefix() self.assertIsInstance(p, str) - self.assertTrue(len(p) > 0) + self.assertGreater(len(p), 0) + + pb = tempfile.gettempprefixb() + + self.assertIsInstance(pb, bytes) + self.assertGreater(len(pb), 0) def test_usable_template(self): # gettempprefix returns a usable prefix string @@ -486,11 +562,11 @@ class TestGetTempDir(BaseTestCase): def test_directory_exists(self): # gettempdir returns a directory which exists - dir = tempfile.gettempdir() - self.assertTrue(os.path.isabs(dir) or dir == os.curdir, - "%s is not an absolute path" % dir) - self.assertTrue(os.path.isdir(dir), - "%s is not a directory" % dir) + for d in (tempfile.gettempdir(), tempfile.gettempdirb()): + self.assertTrue(os.path.isabs(d) or d == os.curdir, + "%r is not an absolute path" % d) + self.assertTrue(os.path.isdir(d), + "%r is not a directory" % d) def test_directory_writable(self): # gettempdir returns a directory writable by the user @@ -506,8 +582,11 @@ class TestGetTempDir(BaseTestCase): # gettempdir always returns the same object a = tempfile.gettempdir() b = tempfile.gettempdir() + c = tempfile.gettempdirb() self.assertTrue(a is b) + self.assertNotEqual(type(a), type(c)) + self.assertEqual(a, os.fsdecode(c)) def test_case_sensitive(self): # gettempdir should not flatten its case @@ -527,9 +606,17 @@ class TestGetTempDir(BaseTestCase): class TestMkstemp(BaseTestCase): """Test mkstemp().""" - def do_create(self, dir=None, pre="", suf=""): + def do_create(self, dir=None, pre=None, suf=None): + output_type = tempfile._infer_return_type(dir, pre, suf) if dir is None: - dir = tempfile.gettempdir() + if output_type is str: + dir = tempfile.gettempdir() + else: + dir = tempfile.gettempdirb() + if pre is None: + pre = output_type() + if suf is None: + suf = output_type() (fd, name) = tempfile.mkstemp(dir=dir, prefix=pre, suffix=suf) (ndir, nbase) = os.path.split(name) adir = os.path.abspath(dir) @@ -551,6 +638,24 @@ class TestMkstemp(BaseTestCase): self.do_create(pre="aa", suf=".txt") self.do_create(dir=".") + def test_basic_with_bytes_names(self): + # mkstemp can create files when given name parts all + # specified as bytes. + d = tempfile.gettempdirb() + self.do_create(dir=d, suf=b"") + self.do_create(dir=d, pre=b"a") + self.do_create(dir=d, suf=b"b") + self.do_create(dir=d, pre=b"a", suf=b"b") + self.do_create(dir=d, pre=b"aa", suf=b".txt") + self.do_create(dir=b".") + with self.assertRaises(TypeError): + self.do_create(dir=".", pre=b"aa", suf=b".txt") + with self.assertRaises(TypeError): + self.do_create(dir=b".", pre="aa", suf=b".txt") + with self.assertRaises(TypeError): + self.do_create(dir=b".", pre=b"aa", suf=".txt") + + def test_choose_directory(self): # mkstemp can create directories in a user-selected directory dir = tempfile.mkdtemp() @@ -566,9 +671,17 @@ class TestMkdtemp(TestBadTempdir, BaseTestCase): def make_temp(self): return tempfile.mkdtemp() - def do_create(self, dir=None, pre="", suf=""): + def do_create(self, dir=None, pre=None, suf=None): + output_type = tempfile._infer_return_type(dir, pre, suf) if dir is None: - dir = tempfile.gettempdir() + if output_type is str: + dir = tempfile.gettempdir() + else: + dir = tempfile.gettempdirb() + if pre is None: + pre = output_type() + if suf is None: + suf = output_type() name = tempfile.mkdtemp(dir=dir, prefix=pre, suffix=suf) try: @@ -586,6 +699,21 @@ class TestMkdtemp(TestBadTempdir, BaseTestCase): os.rmdir(self.do_create(pre="a", suf="b")) os.rmdir(self.do_create(pre="aa", suf=".txt")) + def test_basic_with_bytes_names(self): + # mkdtemp can create directories when given all binary parts + d = tempfile.gettempdirb() + os.rmdir(self.do_create(dir=d)) + os.rmdir(self.do_create(dir=d, pre=b"a")) + os.rmdir(self.do_create(dir=d, suf=b"b")) + os.rmdir(self.do_create(dir=d, pre=b"a", suf=b"b")) + os.rmdir(self.do_create(dir=d, pre=b"aa", suf=b".txt")) + with self.assertRaises(TypeError): + os.rmdir(self.do_create(dir=d, pre="aa", suf=b".txt")) + with self.assertRaises(TypeError): + os.rmdir(self.do_create(dir=d, pre=b"aa", suf=".txt")) + with self.assertRaises(TypeError): + os.rmdir(self.do_create(dir="", pre=b"aa", suf=b".txt")) + def test_basic_many(self): # mkdtemp can create many directories (stochastic) extant = list(range(TEST_FILES)) @@ -1313,8 +1441,5 @@ class TestTemporaryDirectory(BaseTestCase): self.assertFalse(os.path.exists(name)) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_textwrap.py b/Lib/test/test_textwrap.py index 1bba77e..707aaaa 100644 --- a/Lib/test/test_textwrap.py +++ b/Lib/test/test_textwrap.py @@ -184,6 +184,16 @@ What a mess! self.check_wrap(text, 42, ["this-is-a-useful-feature-for-reformatting-", "posts-from-tim-peters'ly"]) + # The test tests current behavior but is not testing parts of the API. + expect = ("this-|is-|a-|useful-|feature-|for-|" + "reformatting-|posts-|from-|tim-|peters'ly").split('|') + self.check_wrap(text, 1, expect, break_long_words=False) + self.check_split(text, expect) + + self.check_split('e-mail', ['e-mail']) + self.check_split('Jelly-O', ['Jelly-O']) + # The test tests current behavior but is not testing parts of the API. + self.check_split('half-a-crown', 'half-|a-|crown'.split('|')) def test_hyphenated_numbers(self): # Test that hyphenated numbers (eg. dates) are not broken like words. @@ -195,6 +205,7 @@ What a mess! 'released on 1994-02-15.']) self.check_wrap(text, 40, ['Python 1.0.0 was released on 1994-01-26.', 'Python 1.0.1 was released on 1994-02-15.']) + self.check_wrap(text, 1, text.split(), break_long_words=False) text = "I do all my shopping at 7-11." self.check_wrap(text, 25, ["I do all my shopping at", @@ -202,6 +213,7 @@ What a mess! self.check_wrap(text, 27, ["I do all my shopping at", "7-11."]) self.check_wrap(text, 29, ["I do all my shopping at 7-11."]) + self.check_wrap(text, 1, text.split(), break_long_words=False) def test_em_dash(self): # Test text with em-dashes @@ -326,6 +338,10 @@ What a mess! self.check_split("the ['wibble-wobble'] widget", ['the', ' ', "['wibble-", "wobble']", ' ', 'widget']) + # The test tests current behavior but is not testing parts of the API. + self.check_split("what-d'you-call-it.", + "what-d'you-|call-|it.".split('|')) + def test_funky_parens (self): # Second part of SF bug #596434: long option strings inside # parentheses. diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py index 6144901..ef3059b 100644 --- a/Lib/test/test_thread.py +++ b/Lib/test/test_thread.py @@ -252,9 +252,5 @@ class TestForkInThread(unittest.TestCase): pass -def test_main(): - support.run_unittest(ThreadRunningTests, BarrierTest, LockTests, - TestForkInThread) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_threaded_import.py b/Lib/test/test_threaded_import.py index 4be615a..9b2d9a6 100644 --- a/Lib/test/test_threaded_import.py +++ b/Lib/test/test_threaded_import.py @@ -115,12 +115,18 @@ class ThreadedImportTests(unittest.TestCase): errors = [] done_tasks = [] done.clear() + t0 = time.monotonic() with start_threads(threading.Thread(target=task, args=(N, done, done_tasks, errors,)) for i in range(N)): pass - self.assertTrue(done.wait(60)) - self.assertFalse(errors) + completed = done.wait(10 * 60) + dt = time.monotonic() - t0 + if verbose: + print("%.1f ms" % (dt*1e3), flush=True, end=" ") + dbg_info = 'done: %s/%s' % (len(done_tasks), N) + self.assertFalse(errors, dbg_info) + self.assertTrue(completed, dbg_info) if verbose: print("OK.") diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 4b75ea6..3b11bf6 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -4,7 +4,7 @@ Tests for the threading module. import test.support from test.support import verbose, strip_python_stderr, import_module, cpython_only -from test.script_helper import assert_python_ok, assert_python_failure +from test.support.script_helper import assert_python_ok, assert_python_failure import random import re @@ -945,7 +945,7 @@ class ThreadingExceptionTests(BaseTestCase): def outer(): try: recurse() - except RuntimeError: + except RecursionError: pass w = threading.Thread(target=outer) diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py index be7ddcc..de0cbc4 100644 --- a/Lib/test/test_time.py +++ b/Lib/test/test_time.py @@ -1,21 +1,37 @@ from test import support -import time -import unittest +import enum import locale -import sysconfig -import sys import platform +import sys +import sysconfig +import time +import unittest try: import threading except ImportError: threading = None +try: + import _testcapi +except ImportError: + _testcapi = None + # Max year is only limited by the size of C int. SIZEOF_INT = sysconfig.get_config_var('SIZEOF_INT') or 4 TIME_MAXYEAR = (1 << 8 * SIZEOF_INT - 1) - 1 TIME_MINYEAR = -TIME_MAXYEAR - 1 -_PyTime_ROUND_DOWN = 0 -_PyTime_ROUND_UP = 1 + +US_TO_NS = 10 ** 3 +MS_TO_NS = 10 ** 6 +SEC_TO_NS = 10 ** 9 + +class _PyTime(enum.IntEnum): + # Round towards minus infinity (-inf) + ROUND_FLOOR = 0 + # Round towards infinity (+inf) + ROUND_CEILING = 1 + +ALL_ROUNDING_METHODS = (_PyTime.ROUND_FLOOR, _PyTime.ROUND_CEILING) class TimeTestCase(unittest.TestCase): @@ -98,13 +114,6 @@ class TimeTestCase(unittest.TestCase): except ValueError: self.fail('conversion specifier: %r failed.' % format) - # Issue #10762: Guard against invalid/non-supported format string - # so that Python don't crash (Windows crashes when the format string - # input to [w]strftime is not kosher. - if sys.platform.startswith('win'): - with self.assertRaises(ValueError): - time.strftime('%f') - def _bounds_checking(self, func): # Make sure that strftime() checks the bounds of the various parts # of the time tuple (0 is valid for *all* values). @@ -165,6 +174,19 @@ class TimeTestCase(unittest.TestCase): def test_strftime_bounding_check(self): self._bounds_checking(lambda tup: time.strftime('', tup)) + def test_strftime_format_check(self): + # Test that strftime does not crash on invalid format strings + # that may trigger a buffer overread. When not triggered, + # strftime may succeed or raise ValueError depending on + # the platform. + for x in [ '', 'A', '%A', '%AA' ]: + for y in range(0x0, 0x10): + for z in [ '%', 'A%', 'AA%', '%A%', 'A%A%', '%#' ]: + try: + time.strftime(x * y + z) + except ValueError: + pass + def test_default_values_for_zero(self): # Make sure that using all zeros uses the proper default # values. No test for daylight savings since strftime() does @@ -595,112 +617,65 @@ class TestPytime(unittest.TestCase): def test_time_t(self): from _testcapi import pytime_object_to_time_t for obj, time_t, rnd in ( - # Round towards zero - (0, 0, _PyTime_ROUND_DOWN), - (-1, -1, _PyTime_ROUND_DOWN), - (-1.0, -1, _PyTime_ROUND_DOWN), - (-1.9, -1, _PyTime_ROUND_DOWN), - (1.0, 1, _PyTime_ROUND_DOWN), - (1.9, 1, _PyTime_ROUND_DOWN), - # Round away from zero - (0, 0, _PyTime_ROUND_UP), - (-1, -1, _PyTime_ROUND_UP), - (-1.0, -1, _PyTime_ROUND_UP), - (-1.9, -2, _PyTime_ROUND_UP), - (1.0, 1, _PyTime_ROUND_UP), - (1.9, 2, _PyTime_ROUND_UP), + # Round towards minus infinity (-inf) + (0, 0, _PyTime.ROUND_FLOOR), + (-1, -1, _PyTime.ROUND_FLOOR), + (-1.0, -1, _PyTime.ROUND_FLOOR), + (-1.9, -2, _PyTime.ROUND_FLOOR), + (1.0, 1, _PyTime.ROUND_FLOOR), + (1.9, 1, _PyTime.ROUND_FLOOR), + # Round towards infinity (+inf) + (0, 0, _PyTime.ROUND_CEILING), + (-1, -1, _PyTime.ROUND_CEILING), + (-1.0, -1, _PyTime.ROUND_CEILING), + (-1.9, -1, _PyTime.ROUND_CEILING), + (1.0, 1, _PyTime.ROUND_CEILING), + (1.9, 2, _PyTime.ROUND_CEILING), ): self.assertEqual(pytime_object_to_time_t(obj, rnd), time_t) - rnd = _PyTime_ROUND_DOWN + rnd = _PyTime.ROUND_FLOOR for invalid in self.invalid_values: self.assertRaises(OverflowError, pytime_object_to_time_t, invalid, rnd) @support.cpython_only - def test_timeval(self): - from _testcapi import pytime_object_to_timeval - for obj, timeval, rnd in ( - # Round towards zero - (0, (0, 0), _PyTime_ROUND_DOWN), - (-1, (-1, 0), _PyTime_ROUND_DOWN), - (-1.0, (-1, 0), _PyTime_ROUND_DOWN), - (1e-6, (0, 1), _PyTime_ROUND_DOWN), - (1e-7, (0, 0), _PyTime_ROUND_DOWN), - (-1e-6, (-1, 999999), _PyTime_ROUND_DOWN), - (-1e-7, (-1, 999999), _PyTime_ROUND_DOWN), - (-1.2, (-2, 800000), _PyTime_ROUND_DOWN), - (0.9999999, (0, 999999), _PyTime_ROUND_DOWN), - (0.0000041, (0, 4), _PyTime_ROUND_DOWN), - (1.1234560, (1, 123456), _PyTime_ROUND_DOWN), - (1.1234569, (1, 123456), _PyTime_ROUND_DOWN), - (-0.0000040, (-1, 999996), _PyTime_ROUND_DOWN), - (-0.0000041, (-1, 999995), _PyTime_ROUND_DOWN), - (-1.1234560, (-2, 876544), _PyTime_ROUND_DOWN), - (-1.1234561, (-2, 876543), _PyTime_ROUND_DOWN), - # Round away from zero - (0, (0, 0), _PyTime_ROUND_UP), - (-1, (-1, 0), _PyTime_ROUND_UP), - (-1.0, (-1, 0), _PyTime_ROUND_UP), - (1e-6, (0, 1), _PyTime_ROUND_UP), - (1e-7, (0, 1), _PyTime_ROUND_UP), - (-1e-6, (-1, 999999), _PyTime_ROUND_UP), - (-1e-7, (-1, 999999), _PyTime_ROUND_UP), - (-1.2, (-2, 800000), _PyTime_ROUND_UP), - (0.9999999, (1, 0), _PyTime_ROUND_UP), - (0.0000041, (0, 5), _PyTime_ROUND_UP), - (1.1234560, (1, 123457), _PyTime_ROUND_UP), - (1.1234569, (1, 123457), _PyTime_ROUND_UP), - (-0.0000040, (-1, 999996), _PyTime_ROUND_UP), - (-0.0000041, (-1, 999995), _PyTime_ROUND_UP), - (-1.1234560, (-2, 876544), _PyTime_ROUND_UP), - (-1.1234561, (-2, 876543), _PyTime_ROUND_UP), - ): - with self.subTest(obj=obj, round=rnd, timeval=timeval): - self.assertEqual(pytime_object_to_timeval(obj, rnd), timeval) - - rnd = _PyTime_ROUND_DOWN - for invalid in self.invalid_values: - self.assertRaises(OverflowError, - pytime_object_to_timeval, invalid, rnd) - - @support.cpython_only def test_timespec(self): from _testcapi import pytime_object_to_timespec for obj, timespec, rnd in ( - # Round towards zero - (0, (0, 0), _PyTime_ROUND_DOWN), - (-1, (-1, 0), _PyTime_ROUND_DOWN), - (-1.0, (-1, 0), _PyTime_ROUND_DOWN), - (1e-9, (0, 1), _PyTime_ROUND_DOWN), - (1e-10, (0, 0), _PyTime_ROUND_DOWN), - (-1e-9, (-1, 999999999), _PyTime_ROUND_DOWN), - (-1e-10, (-1, 999999999), _PyTime_ROUND_DOWN), - (-1.2, (-2, 800000000), _PyTime_ROUND_DOWN), - (0.9999999999, (0, 999999999), _PyTime_ROUND_DOWN), - (1.1234567890, (1, 123456789), _PyTime_ROUND_DOWN), - (1.1234567899, (1, 123456789), _PyTime_ROUND_DOWN), - (-1.1234567890, (-2, 876543211), _PyTime_ROUND_DOWN), - (-1.1234567891, (-2, 876543210), _PyTime_ROUND_DOWN), - # Round away from zero - (0, (0, 0), _PyTime_ROUND_UP), - (-1, (-1, 0), _PyTime_ROUND_UP), - (-1.0, (-1, 0), _PyTime_ROUND_UP), - (1e-9, (0, 1), _PyTime_ROUND_UP), - (1e-10, (0, 1), _PyTime_ROUND_UP), - (-1e-9, (-1, 999999999), _PyTime_ROUND_UP), - (-1e-10, (-1, 999999999), _PyTime_ROUND_UP), - (-1.2, (-2, 800000000), _PyTime_ROUND_UP), - (0.9999999999, (1, 0), _PyTime_ROUND_UP), - (1.1234567890, (1, 123456790), _PyTime_ROUND_UP), - (1.1234567899, (1, 123456790), _PyTime_ROUND_UP), - (-1.1234567890, (-2, 876543211), _PyTime_ROUND_UP), - (-1.1234567891, (-2, 876543210), _PyTime_ROUND_UP), + # Round towards minus infinity (-inf) + (0, (0, 0), _PyTime.ROUND_FLOOR), + (-1, (-1, 0), _PyTime.ROUND_FLOOR), + (-1.0, (-1, 0), _PyTime.ROUND_FLOOR), + (1e-9, (0, 1), _PyTime.ROUND_FLOOR), + (1e-10, (0, 0), _PyTime.ROUND_FLOOR), + (-1e-9, (-1, 999999999), _PyTime.ROUND_FLOOR), + (-1e-10, (-1, 999999999), _PyTime.ROUND_FLOOR), + (-1.2, (-2, 800000000), _PyTime.ROUND_FLOOR), + (0.9999999999, (0, 999999999), _PyTime.ROUND_FLOOR), + (1.1234567890, (1, 123456789), _PyTime.ROUND_FLOOR), + (1.1234567899, (1, 123456789), _PyTime.ROUND_FLOOR), + (-1.1234567890, (-2, 876543211), _PyTime.ROUND_FLOOR), + (-1.1234567891, (-2, 876543210), _PyTime.ROUND_FLOOR), + # Round towards infinity (+inf) + (0, (0, 0), _PyTime.ROUND_CEILING), + (-1, (-1, 0), _PyTime.ROUND_CEILING), + (-1.0, (-1, 0), _PyTime.ROUND_CEILING), + (1e-9, (0, 1), _PyTime.ROUND_CEILING), + (1e-10, (0, 1), _PyTime.ROUND_CEILING), + (-1e-9, (-1, 999999999), _PyTime.ROUND_CEILING), + (-1e-10, (0, 0), _PyTime.ROUND_CEILING), + (-1.2, (-2, 800000000), _PyTime.ROUND_CEILING), + (0.9999999999, (1, 0), _PyTime.ROUND_CEILING), + (1.1234567890, (1, 123456790), _PyTime.ROUND_CEILING), + (1.1234567899, (1, 123456790), _PyTime.ROUND_CEILING), + (-1.1234567890, (-2, 876543211), _PyTime.ROUND_CEILING), + (-1.1234567891, (-2, 876543211), _PyTime.ROUND_CEILING), ): with self.subTest(obj=obj, round=rnd, timespec=timespec): self.assertEqual(pytime_object_to_timespec(obj, rnd), timespec) - rnd = _PyTime_ROUND_DOWN + rnd = _PyTime.ROUND_FLOOR for invalid in self.invalid_values: self.assertRaises(OverflowError, pytime_object_to_timespec, invalid, rnd) @@ -759,5 +734,267 @@ class TestPytime(unittest.TestCase): self.assertIs(lt.tm_zone, None) +@unittest.skipUnless(_testcapi is not None, + 'need the _testcapi module') +class TestPyTime_t(unittest.TestCase): + def test_FromSeconds(self): + from _testcapi import PyTime_FromSeconds + for seconds in (0, 3, -456, _testcapi.INT_MAX, _testcapi.INT_MIN): + with self.subTest(seconds=seconds): + self.assertEqual(PyTime_FromSeconds(seconds), + seconds * SEC_TO_NS) + + def test_FromSecondsObject(self): + from _testcapi import PyTime_FromSecondsObject + + # Conversion giving the same result for all rounding methods + for rnd in ALL_ROUNDING_METHODS: + for obj, ts in ( + # integers + (0, 0), + (1, SEC_TO_NS), + (-3, -3 * SEC_TO_NS), + + # float: subseconds + (0.0, 0), + (1e-9, 1), + (1e-6, 10 ** 3), + (1e-3, 10 ** 6), + + # float: seconds + (2.0, 2 * SEC_TO_NS), + (123.0, 123 * SEC_TO_NS), + (-7.0, -7 * SEC_TO_NS), + + # nanosecond are kept for value <= 2^23 seconds + (2**22 - 1e-9, 4194303999999999), + (2**22, 4194304000000000), + (2**22 + 1e-9, 4194304000000001), + (2**23 - 1e-9, 8388607999999999), + (2**23, 8388608000000000), + + # start loosing precision for value > 2^23 seconds + (2**23 + 1e-9, 8388608000000002), + + # nanoseconds are lost for value > 2^23 seconds + (2**24 - 1e-9, 16777215999999998), + (2**24, 16777216000000000), + (2**24 + 1e-9, 16777216000000000), + (2**25 - 1e-9, 33554432000000000), + (2**25 , 33554432000000000), + (2**25 + 1e-9, 33554432000000000), + + # close to 2^63 nanoseconds (_PyTime_t limit) + (9223372036, 9223372036 * SEC_TO_NS), + (9223372036.0, 9223372036 * SEC_TO_NS), + (-9223372036, -9223372036 * SEC_TO_NS), + (-9223372036.0, -9223372036 * SEC_TO_NS), + ): + with self.subTest(obj=obj, round=rnd, timestamp=ts): + self.assertEqual(PyTime_FromSecondsObject(obj, rnd), ts) + + with self.subTest(round=rnd): + with self.assertRaises(OverflowError): + PyTime_FromSecondsObject(9223372037, rnd) + PyTime_FromSecondsObject(9223372037.0, rnd) + PyTime_FromSecondsObject(-9223372037, rnd) + PyTime_FromSecondsObject(-9223372037.0, rnd) + + # Conversion giving different results depending on the rounding method + FLOOR = _PyTime.ROUND_FLOOR + CEILING = _PyTime.ROUND_CEILING + for obj, ts, rnd in ( + # close to zero + ( 1e-10, 0, FLOOR), + ( 1e-10, 1, CEILING), + (-1e-10, -1, FLOOR), + (-1e-10, 0, CEILING), + + # test rounding of the last nanosecond + ( 1.1234567899, 1123456789, FLOOR), + ( 1.1234567899, 1123456790, CEILING), + (-1.1234567899, -1123456790, FLOOR), + (-1.1234567899, -1123456789, CEILING), + + # close to 1 second + ( 0.9999999999, 999999999, FLOOR), + ( 0.9999999999, 1000000000, CEILING), + (-0.9999999999, -1000000000, FLOOR), + (-0.9999999999, -999999999, CEILING), + ): + with self.subTest(obj=obj, round=rnd, timestamp=ts): + self.assertEqual(PyTime_FromSecondsObject(obj, rnd), ts) + + def test_AsSecondsDouble(self): + from _testcapi import PyTime_AsSecondsDouble + + for nanoseconds, seconds in ( + # near 1 nanosecond + ( 0, 0.0), + ( 1, 1e-9), + (-1, -1e-9), + + # near 1 second + (SEC_TO_NS + 1, 1.0 + 1e-9), + (SEC_TO_NS, 1.0), + (SEC_TO_NS - 1, 1.0 - 1e-9), + + # a few seconds + (123 * SEC_TO_NS, 123.0), + (-567 * SEC_TO_NS, -567.0), + + # nanosecond are kept for value <= 2^23 seconds + (4194303999999999, 2**22 - 1e-9), + (4194304000000000, 2**22), + (4194304000000001, 2**22 + 1e-9), + + # start loosing precision for value > 2^23 seconds + (8388608000000002, 2**23 + 1e-9), + + # nanoseconds are lost for value > 2^23 seconds + (16777215999999998, 2**24 - 1e-9), + (16777215999999999, 2**24 - 1e-9), + (16777216000000000, 2**24 ), + (16777216000000001, 2**24 ), + (16777216000000002, 2**24 + 2e-9), + + (33554432000000000, 2**25 ), + (33554432000000002, 2**25 ), + (33554432000000004, 2**25 + 4e-9), + + # close to 2^63 nanoseconds (_PyTime_t limit) + (9223372036 * SEC_TO_NS, 9223372036.0), + (-9223372036 * SEC_TO_NS, -9223372036.0), + ): + with self.subTest(nanoseconds=nanoseconds, seconds=seconds): + self.assertEqual(PyTime_AsSecondsDouble(nanoseconds), + seconds) + + def test_timeval(self): + from _testcapi import PyTime_AsTimeval + for rnd in ALL_ROUNDING_METHODS: + for ns, tv in ( + # microseconds + (0, (0, 0)), + (1000, (0, 1)), + (-1000, (-1, 999999)), + + # seconds + (2 * SEC_TO_NS, (2, 0)), + (-3 * SEC_TO_NS, (-3, 0)), + ): + with self.subTest(nanoseconds=ns, timeval=tv, round=rnd): + self.assertEqual(PyTime_AsTimeval(ns, rnd), tv) + + FLOOR = _PyTime.ROUND_FLOOR + CEILING = _PyTime.ROUND_CEILING + for ns, tv, rnd in ( + # nanoseconds + (1, (0, 0), FLOOR), + (1, (0, 1), CEILING), + (-1, (-1, 999999), FLOOR), + (-1, (0, 0), CEILING), + + # seconds + nanoseconds + (1234567001, (1, 234567), FLOOR), + (1234567001, (1, 234568), CEILING), + (-1234567001, (-2, 765432), FLOOR), + (-1234567001, (-2, 765433), CEILING), + ): + with self.subTest(nanoseconds=ns, timeval=tv, round=rnd): + self.assertEqual(PyTime_AsTimeval(ns, rnd), tv) + + @unittest.skipUnless(hasattr(_testcapi, 'PyTime_AsTimespec'), + 'need _testcapi.PyTime_AsTimespec') + def test_timespec(self): + from _testcapi import PyTime_AsTimespec + for ns, ts in ( + # nanoseconds + (0, (0, 0)), + (1, (0, 1)), + (-1, (-1, 999999999)), + + # seconds + (2 * SEC_TO_NS, (2, 0)), + (-3 * SEC_TO_NS, (-3, 0)), + + # seconds + nanoseconds + (1234567890, (1, 234567890)), + (-1234567890, (-2, 765432110)), + ): + with self.subTest(nanoseconds=ns, timespec=ts): + self.assertEqual(PyTime_AsTimespec(ns), ts) + + def test_milliseconds(self): + from _testcapi import PyTime_AsMilliseconds + for rnd in ALL_ROUNDING_METHODS: + for ns, tv in ( + # milliseconds + (1 * MS_TO_NS, 1), + (-2 * MS_TO_NS, -2), + + # seconds + (2 * SEC_TO_NS, 2000), + (-3 * SEC_TO_NS, -3000), + ): + with self.subTest(nanoseconds=ns, timeval=tv, round=rnd): + self.assertEqual(PyTime_AsMilliseconds(ns, rnd), tv) + + FLOOR = _PyTime.ROUND_FLOOR + CEILING = _PyTime.ROUND_CEILING + for ns, ms, rnd in ( + # nanoseconds + (1, 0, FLOOR), + (1, 1, CEILING), + (-1, -1, FLOOR), + (-1, 0, CEILING), + + # seconds + nanoseconds + (1234 * MS_TO_NS + 1, 1234, FLOOR), + (1234 * MS_TO_NS + 1, 1235, CEILING), + (-1234 * MS_TO_NS - 1, -1235, FLOOR), + (-1234 * MS_TO_NS - 1, -1234, CEILING), + ): + with self.subTest(nanoseconds=ns, milliseconds=ms, round=rnd): + self.assertEqual(PyTime_AsMilliseconds(ns, rnd), ms) + + def test_microseconds(self): + from _testcapi import PyTime_AsMicroseconds + for rnd in ALL_ROUNDING_METHODS: + for ns, tv in ( + # microseconds + (1 * US_TO_NS, 1), + (-2 * US_TO_NS, -2), + + # milliseconds + (1 * MS_TO_NS, 1000), + (-2 * MS_TO_NS, -2000), + + # seconds + (2 * SEC_TO_NS, 2000000), + (-3 * SEC_TO_NS, -3000000), + ): + with self.subTest(nanoseconds=ns, timeval=tv, round=rnd): + self.assertEqual(PyTime_AsMicroseconds(ns, rnd), tv) + + FLOOR = _PyTime.ROUND_FLOOR + CEILING = _PyTime.ROUND_CEILING + for ns, ms, rnd in ( + # nanoseconds + (1, 0, FLOOR), + (1, 1, CEILING), + (-1, -1, FLOOR), + (-1, 0, CEILING), + + # seconds + nanoseconds + (1234 * US_TO_NS + 1, 1234, FLOOR), + (1234 * US_TO_NS + 1, 1235, CEILING), + (-1234 * US_TO_NS - 1, -1235, FLOOR), + (-1234 * US_TO_NS - 1, -1234, CEILING), + ): + with self.subTest(nanoseconds=ns, milliseconds=ms, round=rnd): + self.assertEqual(PyTime_AsMicroseconds(ns, rnd), ms) + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_timeit.py b/Lib/test/test_timeit.py index 918a294..2db3c1b 100644 --- a/Lib/test/test_timeit.py +++ b/Lib/test/test_timeit.py @@ -5,7 +5,6 @@ import io import time from textwrap import dedent -from test.support import run_unittest from test.support import captured_stdout from test.support import captured_stderr @@ -89,8 +88,8 @@ class TestTimeit(unittest.TestCase): self.assertRaises(SyntaxError, timeit.Timer, setup='continue') self.assertRaises(SyntaxError, timeit.Timer, setup='from timeit import *') - fake_setup = "import timeit; timeit._fake_timer.setup()" - fake_stmt = "import timeit; timeit._fake_timer.inc()" + fake_setup = "import timeit\ntimeit._fake_timer.setup()" + fake_stmt = "import timeit\ntimeit._fake_timer.inc()" def fake_callable_setup(self): self.fake_timer.setup() @@ -98,9 +97,10 @@ class TestTimeit(unittest.TestCase): def fake_callable_stmt(self): self.fake_timer.inc() - def timeit(self, stmt, setup, number=None): + def timeit(self, stmt, setup, number=None, globals=None): self.fake_timer = FakeTimer() - t = timeit.Timer(stmt=stmt, setup=setup, timer=self.fake_timer) + t = timeit.Timer(stmt=stmt, setup=setup, timer=self.fake_timer, + globals=globals) kwargs = {} if number is None: number = DEFAULT_NUMBER @@ -142,6 +142,17 @@ class TestTimeit(unittest.TestCase): timer=FakeTimer()) self.assertEqual(delta_time, 0) + def test_timeit_globals_args(self): + global _global_timer + _global_timer = FakeTimer() + t = timeit.Timer(stmt='_global_timer.inc()', timer=_global_timer) + self.assertRaises(NameError, t.timeit, number=3) + timeit.timeit(stmt='_global_timer.inc()', timer=_global_timer, + globals=globals(), number=3) + local_timer = FakeTimer() + timeit.timeit(stmt='local_timer.inc()', timer=local_timer, + globals=locals(), number=3) + def repeat(self, stmt, setup, repeat=None, number=None): self.fake_timer = FakeTimer() t = timeit.Timer(stmt=stmt, setup=setup, timer=self.fake_timer) @@ -261,6 +272,12 @@ class TestTimeit(unittest.TestCase): self.assertEqual(s, "CustomSetup\n" * 3 + "35 loops, best of 3: 2 sec per loop\n") + def test_main_multiple_setups(self): + s = self.run_main(seconds_per_increment=2.0, + switches=['-n35', '-s', 'a = "CustomSetup"', '-s', 'print(a)']) + self.assertEqual(s, "CustomSetup\n" * 3 + + "35 loops, best of 3: 2 sec per loop\n") + def test_main_fixed_reps(self): s = self.run_main(seconds_per_increment=60.0, switches=['-r9']) self.assertEqual(s, "10 loops, best of 9: 60 sec per loop\n") @@ -307,6 +324,26 @@ class TestTimeit(unittest.TestCase): 10000 loops, best of 3: 50 usec per loop """)) + def test_main_with_time_unit(self): + unit_sec = self.run_main(seconds_per_increment=0.002, + switches=['-u', 'sec']) + self.assertEqual(unit_sec, + "1000 loops, best of 3: 0.002 sec per loop\n") + unit_msec = self.run_main(seconds_per_increment=0.002, + switches=['-u', 'msec']) + self.assertEqual(unit_msec, + "1000 loops, best of 3: 2 msec per loop\n") + unit_usec = self.run_main(seconds_per_increment=0.002, + switches=['-u', 'usec']) + self.assertEqual(unit_usec, + "1000 loops, best of 3: 2e+03 usec per loop\n") + # Test invalid unit input + with captured_stderr() as error_stringio: + invalid = self.run_main(seconds_per_increment=0.002, + switches=['-u', 'parsec']) + self.assertEqual(error_stringio.getvalue(), + "Unrecognized unit. Please select usec, msec, or sec.\n") + def test_main_exception(self): with captured_stderr() as error_stringio: s = self.run_main(switches=['1/0']) @@ -318,8 +355,5 @@ class TestTimeit(unittest.TestCase): self.assert_exc_string(error_stringio.getvalue(), 'ZeroDivisionError') -def test_main(): - run_unittest(TestTimeit) - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_timeout.py b/Lib/test/test_timeout.py index 703c43a..3c75dcc 100644 --- a/Lib/test/test_timeout.py +++ b/Lib/test/test_timeout.py @@ -243,14 +243,14 @@ class TCPTimeoutTestCase(TimeoutTestCase): def testAcceptTimeout(self): # Test accept() timeout support.bind_port(self.sock, self.localhost) - self.sock.listen(5) + self.sock.listen() self._sock_operation(1, 1.5, 'accept') def testSend(self): # Test send() timeout with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as serv: support.bind_port(serv, self.localhost) - serv.listen(5) + serv.listen() self.sock.connect(serv.getsockname()) # Send a lot of data in order to bypass buffering in the TCP stack. self._sock_operation(100, 1.5, 'send', b"X" * 200000) @@ -259,7 +259,7 @@ class TCPTimeoutTestCase(TimeoutTestCase): # Test sendto() timeout with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as serv: support.bind_port(serv, self.localhost) - serv.listen(5) + serv.listen() self.sock.connect(serv.getsockname()) # The address argument is ignored since we already connected. self._sock_operation(100, 1.5, 'sendto', b"X" * 200000, @@ -269,7 +269,7 @@ class TCPTimeoutTestCase(TimeoutTestCase): # Test sendall() timeout with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as serv: support.bind_port(serv, self.localhost) - serv.listen(5) + serv.listen() self.sock.connect(serv.getsockname()) # Send a lot of data in order to bypass buffering in the TCP stack. self._sock_operation(100, 1.5, 'sendall', b"X" * 200000) diff --git a/Lib/test/test_tix.py b/Lib/test/test_tix.py new file mode 100644 index 0000000..e6ea3d0 --- /dev/null +++ b/Lib/test/test_tix.py @@ -0,0 +1,32 @@ +import unittest +from test import support +import sys + +# Skip this test if the _tkinter module wasn't built. +_tkinter = support.import_module('_tkinter') + +# Skip test if tk cannot be initialized. +support.requires('gui') + +from tkinter import tix, TclError + + +class TestTix(unittest.TestCase): + + def setUp(self): + try: + self.root = tix.Tk() + except TclError: + if sys.platform.startswith('win'): + self.fail('Tix should always be available on Windows') + self.skipTest('Tix not available') + else: + self.addCleanup(self.root.destroy) + + def test_tix_available(self): + # this test is just here to make setUp run + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_tk.py b/Lib/test/test_tk.py index 62729f0..48cefd9 100644 --- a/Lib/test/test_tk.py +++ b/Lib/test/test_tk.py @@ -2,9 +2,6 @@ from test import support # Skip test if _tkinter wasn't built. support.import_module('_tkinter') -# Make sure tkinter._fix runs to set up the environment -support.import_fresh_module('tkinter') - # Skip test if tk cannot be initialized. support.requires('gui') diff --git a/Lib/test/test_tokenize.py b/Lib/test/test_tokenize.py index 6506b67..b7ca089 100644 --- a/Lib/test/test_tokenize.py +++ b/Lib/test/test_tokenize.py @@ -466,7 +466,7 @@ Additive Multiplicative - >>> dump_tokens("x = 1//1*1/5*12%0x12") + >>> dump_tokens("x = 1//1*1/5*12%0x12@42") ENCODING 'utf-8' (0, 0) (0, 0) NAME 'x' (1, 0) (1, 1) OP '=' (1, 2) (1, 3) @@ -481,6 +481,8 @@ Multiplicative NUMBER '12' (1, 13) (1, 15) OP '%' (1, 15) (1, 16) NUMBER '0x12' (1, 16) (1, 20) + OP '@' (1, 20) (1, 21) + NUMBER '42' (1, 21) (1, 23) Unary @@ -641,6 +643,276 @@ Legacy unicode literals: NAME 'grün' (2, 0) (2, 4) OP '=' (2, 5) (2, 6) STRING "U'green'" (2, 7) (2, 15) + +Async/await extension: + + >>> dump_tokens("async = 1") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + OP '=' (1, 6) (1, 7) + NUMBER '1' (1, 8) (1, 9) + + >>> dump_tokens("a = (async = 1)") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'a' (1, 0) (1, 1) + OP '=' (1, 2) (1, 3) + OP '(' (1, 4) (1, 5) + NAME 'async' (1, 5) (1, 10) + OP '=' (1, 11) (1, 12) + NUMBER '1' (1, 13) (1, 14) + OP ')' (1, 14) (1, 15) + + >>> dump_tokens("async()") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + OP '(' (1, 5) (1, 6) + OP ')' (1, 6) (1, 7) + + >>> dump_tokens("class async(Bar):pass") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'class' (1, 0) (1, 5) + NAME 'async' (1, 6) (1, 11) + OP '(' (1, 11) (1, 12) + NAME 'Bar' (1, 12) (1, 15) + OP ')' (1, 15) (1, 16) + OP ':' (1, 16) (1, 17) + NAME 'pass' (1, 17) (1, 21) + + >>> dump_tokens("class async:pass") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'class' (1, 0) (1, 5) + NAME 'async' (1, 6) (1, 11) + OP ':' (1, 11) (1, 12) + NAME 'pass' (1, 12) (1, 16) + + >>> dump_tokens("await = 1") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'await' (1, 0) (1, 5) + OP '=' (1, 6) (1, 7) + NUMBER '1' (1, 8) (1, 9) + + >>> dump_tokens("foo.async") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'foo' (1, 0) (1, 3) + OP '.' (1, 3) (1, 4) + NAME 'async' (1, 4) (1, 9) + + >>> dump_tokens("async for a in b: pass") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + NAME 'for' (1, 6) (1, 9) + NAME 'a' (1, 10) (1, 11) + NAME 'in' (1, 12) (1, 14) + NAME 'b' (1, 15) (1, 16) + OP ':' (1, 16) (1, 17) + NAME 'pass' (1, 18) (1, 22) + + >>> dump_tokens("async with a as b: pass") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + NAME 'with' (1, 6) (1, 10) + NAME 'a' (1, 11) (1, 12) + NAME 'as' (1, 13) (1, 15) + NAME 'b' (1, 16) (1, 17) + OP ':' (1, 17) (1, 18) + NAME 'pass' (1, 19) (1, 23) + + >>> dump_tokens("async.foo") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + OP '.' (1, 5) (1, 6) + NAME 'foo' (1, 6) (1, 9) + + >>> dump_tokens("async") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + + >>> dump_tokens("async\\n#comment\\nawait") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + NEWLINE '\\n' (1, 5) (1, 6) + COMMENT '#comment' (2, 0) (2, 8) + NL '\\n' (2, 8) (2, 9) + NAME 'await' (3, 0) (3, 5) + + >>> dump_tokens("async\\n...\\nawait") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + NEWLINE '\\n' (1, 5) (1, 6) + OP '...' (2, 0) (2, 3) + NEWLINE '\\n' (2, 3) (2, 4) + NAME 'await' (3, 0) (3, 5) + + >>> dump_tokens("async\\nawait") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'async' (1, 0) (1, 5) + NEWLINE '\\n' (1, 5) (1, 6) + NAME 'await' (2, 0) (2, 5) + + >>> dump_tokens("foo.async + 1") + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'foo' (1, 0) (1, 3) + OP '.' (1, 3) (1, 4) + NAME 'async' (1, 4) (1, 9) + OP '+' (1, 10) (1, 11) + NUMBER '1' (1, 12) (1, 13) + + >>> dump_tokens("async def foo(): pass") + ENCODING 'utf-8' (0, 0) (0, 0) + ASYNC 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + OP '(' (1, 13) (1, 14) + OP ')' (1, 14) (1, 15) + OP ':' (1, 15) (1, 16) + NAME 'pass' (1, 17) (1, 21) + + >>> dump_tokens('''async def foo(): + ... def foo(await): + ... await = 1 + ... if 1: + ... await + ... async += 1 + ... ''') + ENCODING 'utf-8' (0, 0) (0, 0) + ASYNC 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + OP '(' (1, 13) (1, 14) + OP ')' (1, 14) (1, 15) + OP ':' (1, 15) (1, 16) + NEWLINE '\\n' (1, 16) (1, 17) + INDENT ' ' (2, 0) (2, 2) + NAME 'def' (2, 2) (2, 5) + NAME 'foo' (2, 6) (2, 9) + OP '(' (2, 9) (2, 10) + AWAIT 'await' (2, 10) (2, 15) + OP ')' (2, 15) (2, 16) + OP ':' (2, 16) (2, 17) + NEWLINE '\\n' (2, 17) (2, 18) + INDENT ' ' (3, 0) (3, 4) + AWAIT 'await' (3, 4) (3, 9) + OP '=' (3, 10) (3, 11) + NUMBER '1' (3, 12) (3, 13) + NEWLINE '\\n' (3, 13) (3, 14) + DEDENT '' (4, 2) (4, 2) + NAME 'if' (4, 2) (4, 4) + NUMBER '1' (4, 5) (4, 6) + OP ':' (4, 6) (4, 7) + NEWLINE '\\n' (4, 7) (4, 8) + INDENT ' ' (5, 0) (5, 4) + AWAIT 'await' (5, 4) (5, 9) + NEWLINE '\\n' (5, 9) (5, 10) + DEDENT '' (6, 0) (6, 0) + DEDENT '' (6, 0) (6, 0) + NAME 'async' (6, 0) (6, 5) + OP '+=' (6, 6) (6, 8) + NUMBER '1' (6, 9) (6, 10) + NEWLINE '\\n' (6, 10) (6, 11) + + >>> dump_tokens('''async def foo(): + ... async for i in 1: pass''') + ENCODING 'utf-8' (0, 0) (0, 0) + ASYNC 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + OP '(' (1, 13) (1, 14) + OP ')' (1, 14) (1, 15) + OP ':' (1, 15) (1, 16) + NEWLINE '\\n' (1, 16) (1, 17) + INDENT ' ' (2, 0) (2, 2) + ASYNC 'async' (2, 2) (2, 7) + NAME 'for' (2, 8) (2, 11) + NAME 'i' (2, 12) (2, 13) + NAME 'in' (2, 14) (2, 16) + NUMBER '1' (2, 17) (2, 18) + OP ':' (2, 18) (2, 19) + NAME 'pass' (2, 20) (2, 24) + DEDENT '' (3, 0) (3, 0) + + >>> dump_tokens('''async def foo(async): await''') + ENCODING 'utf-8' (0, 0) (0, 0) + ASYNC 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + OP '(' (1, 13) (1, 14) + ASYNC 'async' (1, 14) (1, 19) + OP ')' (1, 19) (1, 20) + OP ':' (1, 20) (1, 21) + AWAIT 'await' (1, 22) (1, 27) + + >>> dump_tokens('''def f(): + ... + ... def baz(): pass + ... async def bar(): pass + ... + ... await = 2''') + ENCODING 'utf-8' (0, 0) (0, 0) + NAME 'def' (1, 0) (1, 3) + NAME 'f' (1, 4) (1, 5) + OP '(' (1, 5) (1, 6) + OP ')' (1, 6) (1, 7) + OP ':' (1, 7) (1, 8) + NEWLINE '\\n' (1, 8) (1, 9) + NL '\\n' (2, 0) (2, 1) + INDENT ' ' (3, 0) (3, 2) + NAME 'def' (3, 2) (3, 5) + NAME 'baz' (3, 6) (3, 9) + OP '(' (3, 9) (3, 10) + OP ')' (3, 10) (3, 11) + OP ':' (3, 11) (3, 12) + NAME 'pass' (3, 13) (3, 17) + NEWLINE '\\n' (3, 17) (3, 18) + ASYNC 'async' (4, 2) (4, 7) + NAME 'def' (4, 8) (4, 11) + NAME 'bar' (4, 12) (4, 15) + OP '(' (4, 15) (4, 16) + OP ')' (4, 16) (4, 17) + OP ':' (4, 17) (4, 18) + NAME 'pass' (4, 19) (4, 23) + NEWLINE '\\n' (4, 23) (4, 24) + NL '\\n' (5, 0) (5, 1) + NAME 'await' (6, 2) (6, 7) + OP '=' (6, 8) (6, 9) + NUMBER '2' (6, 10) (6, 11) + DEDENT '' (7, 0) (7, 0) + + >>> dump_tokens('''async def f(): + ... + ... def baz(): pass + ... async def bar(): pass + ... + ... await = 2''') + ENCODING 'utf-8' (0, 0) (0, 0) + ASYNC 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'f' (1, 10) (1, 11) + OP '(' (1, 11) (1, 12) + OP ')' (1, 12) (1, 13) + OP ':' (1, 13) (1, 14) + NEWLINE '\\n' (1, 14) (1, 15) + NL '\\n' (2, 0) (2, 1) + INDENT ' ' (3, 0) (3, 2) + NAME 'def' (3, 2) (3, 5) + NAME 'baz' (3, 6) (3, 9) + OP '(' (3, 9) (3, 10) + OP ')' (3, 10) (3, 11) + OP ':' (3, 11) (3, 12) + NAME 'pass' (3, 13) (3, 17) + NEWLINE '\\n' (3, 17) (3, 18) + ASYNC 'async' (4, 2) (4, 7) + NAME 'def' (4, 8) (4, 11) + NAME 'bar' (4, 12) (4, 15) + OP '(' (4, 15) (4, 16) + OP ')' (4, 16) (4, 17) + OP ':' (4, 17) (4, 18) + NAME 'pass' (4, 19) (4, 23) + NEWLINE '\\n' (4, 23) (4, 24) + NL '\\n' (5, 0) (5, 1) + AWAIT 'await' (6, 2) (6, 7) + OP '=' (6, 8) (6, 9) + NUMBER '2' (6, 10) (6, 11) + DEDENT '' (7, 0) (7, 0) """ from test import support @@ -1111,6 +1383,17 @@ class TestTokenize(TestCase): self.assertTrue(encoding_used, encoding) + def test_oneline_defs(self): + buf = [] + for i in range(500): + buf.append('def i{i}(): return {i}'.format(i=i)) + buf.append('OK') + buf = '\n'.join(buf) + + # Test that 500 consequent, one-line defs is OK + toks = list(tokenize(BytesIO(buf.encode('utf-8')).readline)) + self.assertEqual(toks[-2].string, 'OK') # [-1] is always ENDMARKER + def assertExactTypeEqual(self, opstr, *optypes): tokens = list(tokenize(BytesIO(opstr.encode('utf-8')).readline)) num_optypes = len(optypes) @@ -1165,6 +1448,7 @@ class TestTokenize(TestCase): self.assertExactTypeEqual('//', token.DOUBLESLASH) self.assertExactTypeEqual('//=', token.DOUBLESLASHEQUAL) self.assertExactTypeEqual('@', token.AT) + self.assertExactTypeEqual('@=', token.ATEQUAL) self.assertExactTypeEqual('a**2+b**2==c**2', NAME, token.DOUBLESTAR, NUMBER, diff --git a/Lib/test/test_tools/test_i18n.py b/Lib/test/test_tools/test_i18n.py new file mode 100644 index 0000000..6eaa8dd --- /dev/null +++ b/Lib/test/test_tools/test_i18n.py @@ -0,0 +1,68 @@ +"""Tests to cover the Tools/i18n package""" + +import os +import unittest + +from test.support.script_helper import assert_python_ok +from test.test_tools import toolsdir +from test.support import temp_cwd + +class Test_pygettext(unittest.TestCase): + """Tests for the pygettext.py tool""" + + script = os.path.join(toolsdir,'i18n', 'pygettext.py') + + def get_header(self, data): + """ utility: return the header of a .po file as a dictionary """ + headers = {} + for line in data.split('\n'): + if not line or line.startswith(('#', 'msgid','msgstr')): + continue + line = line.strip('"') + key, val = line.split(':',1) + headers[key] = val.strip() + return headers + + def test_header(self): + """Make sure the required fields are in the header, according to: + http://www.gnu.org/software/gettext/manual/gettext.html#Header-Entry + """ + with temp_cwd(None) as cwd: + assert_python_ok(self.script) + with open('messages.pot') as fp: + data = fp.read() + header = self.get_header(data) + + self.assertIn("Project-Id-Version", header) + self.assertIn("POT-Creation-Date", header) + self.assertIn("PO-Revision-Date", header) + self.assertIn("Last-Translator", header) + self.assertIn("Language-Team", header) + self.assertIn("MIME-Version", header) + self.assertIn("Content-Type", header) + self.assertIn("Content-Transfer-Encoding", header) + self.assertIn("Generated-By", header) + + # not clear if these should be required in POT (template) files + #self.assertIn("Report-Msgid-Bugs-To", header) + #self.assertIn("Language", header) + + #"Plural-Forms" is optional + + + def test_POT_Creation_Date(self): + """ Match the date format from xgettext for POT-Creation-Date """ + from datetime import datetime + with temp_cwd(None) as cwd: + assert_python_ok(self.script) + with open('messages.pot') as fp: + data = fp.read() + header = self.get_header(data) + creationDate = header['POT-Creation-Date'] + + # peel off the escaped newline at the end of string + if creationDate.endswith('\\n'): + creationDate = creationDate[:-len('\\n')] + + # This will raise if the date format does not exactly match. + datetime.strptime(creationDate, '%Y-%m-%d %H:%M%z') diff --git a/Lib/test/test_tools/test_md5sum.py b/Lib/test/test_tools/test_md5sum.py index 59ea149..1305295 100644 --- a/Lib/test/test_tools/test_md5sum.py +++ b/Lib/test/test_tools/test_md5sum.py @@ -4,7 +4,7 @@ import os import sys import unittest from test import support -from test.script_helper import assert_python_ok, assert_python_failure +from test.support.script_helper import assert_python_ok, assert_python_failure from test.test_tools import scriptsdir, import_tool, skip_if_missing diff --git a/Lib/test/test_tools/test_pindent.py b/Lib/test/test_tools/test_pindent.py index 14a0aa2..e293bc8 100644 --- a/Lib/test/test_tools/test_pindent.py +++ b/Lib/test/test_tools/test_pindent.py @@ -6,7 +6,7 @@ import unittest import subprocess import textwrap from test import support -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok from test.test_tools import scriptsdir, skip_if_missing diff --git a/Lib/test/test_tools/test_reindent.py b/Lib/test/test_tools/test_reindent.py index 45cebf7..d7c20e1 100644 --- a/Lib/test/test_tools/test_reindent.py +++ b/Lib/test/test_tools/test_reindent.py @@ -6,7 +6,7 @@ Tools directory of a Python checkout or tarball, such as reindent.py. import os import unittest -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok from test.test_tools import scriptsdir, skip_if_missing diff --git a/Lib/test/test_trace.py b/Lib/test/test_trace.py index ee33986..03dff84 100644 --- a/Lib/test/test_trace.py +++ b/Lib/test/test_trace.py @@ -9,12 +9,11 @@ from trace import CoverageResults, Trace from test.tracedmodules import testmod - #------------------------------- Utilities -----------------------------------# def fix_ext_py(filename): - """Given a .pyc/.pyo filename converts it to the appropriate .py""" - if filename.endswith(('.pyc', '.pyo')): + """Given a .pyc filename converts it to the appropriate .py""" + if filename.endswith('.pyc'): filename = filename[:-1] return filename @@ -223,6 +222,11 @@ class TestFuncs(unittest.TestCase): self.addCleanup(sys.settrace, sys.gettrace()) self.tracer = Trace(count=0, trace=0, countfuncs=1) self.filemod = my_file_and_modname() + self._saved_tracefunc = sys.gettrace() + + def tearDown(self): + if self._saved_tracefunc is not None: + sys.settrace(self._saved_tracefunc) def test_simple_caller(self): self.tracer.runfunc(traced_func_simple_caller, 1) diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py index d6c9df2..b7695d6 100644 --- a/Lib/test/test_traceback.py +++ b/Lib/test/test_traceback.py @@ -1,15 +1,24 @@ """Test cases for traceback module""" +from collections import namedtuple from io import StringIO +import linecache import sys import unittest import re -from test.support import run_unittest, Error, captured_output -from test.support import TESTFN, unlink, cpython_only +from test import support +from test.support import TESTFN, Error, captured_output, unlink, cpython_only +from test.support.script_helper import assert_python_ok +import textwrap import traceback +test_code = namedtuple('code', ['co_filename', 'co_name']) +test_frame = namedtuple('frame', ['f_code', 'f_globals', 'f_locals']) +test_tb = namedtuple('tb', ['tb_frame', 'tb_lineno', 'tb_next']) + + class SyntaxTracebackCases(unittest.TestCase): # For now, a very minimal set of tests. I want to be sure that # formatting of SyntaxErrors works based on changes for 2.1. @@ -92,9 +101,9 @@ class SyntaxTracebackCases(unittest.TestCase): self.assertEqual(len(err), 1) str_value = '<unprintable %s object>' % X.__name__ if X.__module__ in ('__main__', 'builtins'): - str_name = X.__name__ + str_name = X.__qualname__ else: - str_name = '.'.join([X.__module__, X.__name__]) + str_name = '.'.join([X.__module__, X.__qualname__]) self.assertEqual(err[0], "%s: %s\n" % (str_name, str_value)) def test_without_exception(self): @@ -169,6 +178,44 @@ class SyntaxTracebackCases(unittest.TestCase): # Issue #18960: coding spec should has no effect do_test("0\n# coding: GBK\n", "h\xe9 ho", 'utf-8', 5) + def test_print_traceback_at_exit(self): + # Issue #22599: Ensure that it is possible to use the traceback module + # to display an exception at Python exit + code = textwrap.dedent(""" + import sys + import traceback + + class PrintExceptionAtExit(object): + def __init__(self): + try: + x = 1 / 0 + except Exception: + self.exc_info = sys.exc_info() + # self.exc_info[1] (traceback) contains frames: + # explicitly clear the reference to self in the current + # frame to break a reference cycle + self = None + + def __del__(self): + traceback.print_exception(*self.exc_info) + + # Keep a reference in the module namespace to call the destructor + # when the module is unloaded + obj = PrintExceptionAtExit() + """) + rc, stdout, stderr = assert_python_ok('-c', code) + expected = [b'Traceback (most recent call last):', + b' File "<string>", line 8, in __init__', + b'ZeroDivisionError: division by zero'] + self.assertEqual(stderr.splitlines(), expected) + + def test_print_exception(self): + output = StringIO() + traceback.print_exception( + Exception, Exception("projector"), None, file=output + ) + self.assertEqual(output.getvalue(), "Exception: projector\n") + class TracebackFormatTests(unittest.TestCase): @@ -439,6 +486,126 @@ class CExcReportingTests(BaseExceptionReportingTests, unittest.TestCase): return s.getvalue() +class LimitTests(unittest.TestCase): + + ''' Tests for limit argument. + It's enough to test extact_tb, extract_stack and format_exception ''' + + def last_raises1(self): + raise Exception('Last raised') + + def last_raises2(self): + self.last_raises1() + + def last_raises3(self): + self.last_raises2() + + def last_raises4(self): + self.last_raises3() + + def last_raises5(self): + self.last_raises4() + + def last_returns_frame1(self): + return sys._getframe() + + def last_returns_frame2(self): + return self.last_returns_frame1() + + def last_returns_frame3(self): + return self.last_returns_frame2() + + def last_returns_frame4(self): + return self.last_returns_frame3() + + def last_returns_frame5(self): + return self.last_returns_frame4() + + def test_extract_stack(self): + frame = self.last_returns_frame5() + def extract(**kwargs): + return traceback.extract_stack(frame, **kwargs) + def assertEqualExcept(actual, expected, ignore): + self.assertEqual(actual[:ignore], expected[:ignore]) + self.assertEqual(actual[ignore+1:], expected[ignore+1:]) + self.assertEqual(len(actual), len(expected)) + + with support.swap_attr(sys, 'tracebacklimit', 1000): + nolim = extract() + self.assertGreater(len(nolim), 5) + self.assertEqual(extract(limit=2), nolim[-2:]) + assertEqualExcept(extract(limit=100), nolim[-100:], -5-1) + self.assertEqual(extract(limit=-2), nolim[:2]) + assertEqualExcept(extract(limit=-100), nolim[:100], len(nolim)-5-1) + self.assertEqual(extract(limit=0), []) + del sys.tracebacklimit + assertEqualExcept(extract(), nolim, -5-1) + sys.tracebacklimit = 2 + self.assertEqual(extract(), nolim[-2:]) + self.assertEqual(extract(limit=3), nolim[-3:]) + self.assertEqual(extract(limit=-3), nolim[:3]) + sys.tracebacklimit = 0 + self.assertEqual(extract(), []) + sys.tracebacklimit = -1 + self.assertEqual(extract(), []) + + def test_extract_tb(self): + try: + self.last_raises5() + except Exception: + exc_type, exc_value, tb = sys.exc_info() + def extract(**kwargs): + return traceback.extract_tb(tb, **kwargs) + + with support.swap_attr(sys, 'tracebacklimit', 1000): + nolim = extract() + self.assertEqual(len(nolim), 5+1) + self.assertEqual(extract(limit=2), nolim[:2]) + self.assertEqual(extract(limit=10), nolim) + self.assertEqual(extract(limit=-2), nolim[-2:]) + self.assertEqual(extract(limit=-10), nolim) + self.assertEqual(extract(limit=0), []) + del sys.tracebacklimit + self.assertEqual(extract(), nolim) + sys.tracebacklimit = 2 + self.assertEqual(extract(), nolim[:2]) + self.assertEqual(extract(limit=3), nolim[:3]) + self.assertEqual(extract(limit=-3), nolim[-3:]) + sys.tracebacklimit = 0 + self.assertEqual(extract(), []) + sys.tracebacklimit = -1 + self.assertEqual(extract(), []) + + def test_format_exception(self): + try: + self.last_raises5() + except Exception: + exc_type, exc_value, tb = sys.exc_info() + # [1:-1] to exclude "Traceback (...)" header and + # exception type and value + def extract(**kwargs): + return traceback.format_exception(exc_type, exc_value, tb, **kwargs)[1:-1] + + with support.swap_attr(sys, 'tracebacklimit', 1000): + nolim = extract() + self.assertEqual(len(nolim), 5+1) + self.assertEqual(extract(limit=2), nolim[:2]) + self.assertEqual(extract(limit=10), nolim) + self.assertEqual(extract(limit=-2), nolim[-2:]) + self.assertEqual(extract(limit=-10), nolim) + self.assertEqual(extract(limit=0), []) + del sys.tracebacklimit + self.assertEqual(extract(), nolim) + sys.tracebacklimit = 2 + self.assertEqual(extract(), nolim[:2]) + self.assertEqual(extract(limit=3), nolim[:3]) + self.assertEqual(extract(limit=-3), nolim[-3:]) + sys.tracebacklimit = 0 + self.assertEqual(extract(), []) + sys.tracebacklimit = -1 + self.assertEqual(extract(), []) + + class MiscTracebackCases(unittest.TestCase): # # Check non-printing functions in traceback module @@ -476,11 +643,279 @@ class MiscTracebackCases(unittest.TestCase): self.assertEqual(result[-2:], [ (__file__, lineno+2, 'test_extract_stack', 'result = extract()'), (__file__, lineno+1, 'extract', 'return traceback.extract_stack()'), - ]) + ]) + + +class TestFrame(unittest.TestCase): + + def test_basics(self): + linecache.clearcache() + linecache.lazycache("f", globals()) + f = traceback.FrameSummary("f", 1, "dummy") + self.assertEqual(f, + ("f", 1, "dummy", '"""Test cases for traceback module"""')) + self.assertEqual(tuple(f), + ("f", 1, "dummy", '"""Test cases for traceback module"""')) + self.assertEqual(f, traceback.FrameSummary("f", 1, "dummy")) + self.assertEqual(f, tuple(f)) + # Since tuple.__eq__ doesn't support FrameSummary, the equality + # operator fallbacks to FrameSummary.__eq__. + self.assertEqual(tuple(f), f) + self.assertIsNone(f.locals) + + def test_lazy_lines(self): + linecache.clearcache() + f = traceback.FrameSummary("f", 1, "dummy", lookup_line=False) + self.assertEqual(None, f._line) + linecache.lazycache("f", globals()) + self.assertEqual( + '"""Test cases for traceback module"""', + f.line) + + def test_explicit_line(self): + f = traceback.FrameSummary("f", 1, "dummy", line="line") + self.assertEqual("line", f.line) + + +class TestStack(unittest.TestCase): + + def test_walk_stack(self): + s = list(traceback.walk_stack(None)) + self.assertGreater(len(s), 10) + + def test_walk_tb(self): + try: + 1/0 + except Exception: + _, _, tb = sys.exc_info() + s = list(traceback.walk_tb(tb)) + self.assertEqual(len(s), 1) + + def test_extract_stack(self): + s = traceback.StackSummary.extract(traceback.walk_stack(None)) + self.assertIsInstance(s, traceback.StackSummary) + + def test_extract_stack_limit(self): + s = traceback.StackSummary.extract(traceback.walk_stack(None), limit=5) + self.assertEqual(len(s), 5) + + def test_extract_stack_lookup_lines(self): + linecache.clearcache() + linecache.updatecache('/foo.py', globals()) + c = test_code('/foo.py', 'method') + f = test_frame(c, None, None) + s = traceback.StackSummary.extract(iter([(f, 6)]), lookup_lines=True) + linecache.clearcache() + self.assertEqual(s[0].line, "import sys") + + def test_extract_stackup_deferred_lookup_lines(self): + linecache.clearcache() + c = test_code('/foo.py', 'method') + f = test_frame(c, None, None) + s = traceback.StackSummary.extract(iter([(f, 6)]), lookup_lines=False) + self.assertEqual({}, linecache.cache) + linecache.updatecache('/foo.py', globals()) + self.assertEqual(s[0].line, "import sys") + + def test_from_list(self): + s = traceback.StackSummary.from_list([('foo.py', 1, 'fred', 'line')]) + self.assertEqual( + [' File "foo.py", line 1, in fred\n line\n'], + s.format()) + + def test_from_list_edited_stack(self): + s = traceback.StackSummary.from_list([('foo.py', 1, 'fred', 'line')]) + s[0] = ('foo.py', 2, 'fred', 'line') + s2 = traceback.StackSummary.from_list(s) + self.assertEqual( + [' File "foo.py", line 2, in fred\n line\n'], + s2.format()) + + def test_format_smoke(self): + # For detailed tests see the format_list tests, which consume the same + # code. + s = traceback.StackSummary.from_list([('foo.py', 1, 'fred', 'line')]) + self.assertEqual( + [' File "foo.py", line 1, in fred\n line\n'], + s.format()) + + def test_locals(self): + linecache.updatecache('/foo.py', globals()) + c = test_code('/foo.py', 'method') + f = test_frame(c, globals(), {'something': 1}) + s = traceback.StackSummary.extract(iter([(f, 6)]), capture_locals=True) + self.assertEqual(s[0].locals, {'something': '1'}) + + def test_no_locals(self): + linecache.updatecache('/foo.py', globals()) + c = test_code('/foo.py', 'method') + f = test_frame(c, globals(), {'something': 1}) + s = traceback.StackSummary.extract(iter([(f, 6)])) + self.assertEqual(s[0].locals, None) + + def test_format_locals(self): + def some_inner(k, v): + a = 1 + b = 2 + return traceback.StackSummary.extract( + traceback.walk_stack(None), capture_locals=True, limit=1) + s = some_inner(3, 4) + self.assertEqual( + [' File "%s", line %d, in some_inner\n' + ' traceback.walk_stack(None), capture_locals=True, limit=1)\n' + ' a = 1\n' + ' b = 2\n' + ' k = 3\n' + ' v = 4\n' % (__file__, some_inner.__code__.co_firstlineno + 4) + ], s.format()) + +class TestTracebackException(unittest.TestCase): + + def test_smoke(self): + try: + 1/0 + except Exception: + exc_info = sys.exc_info() + exc = traceback.TracebackException(*exc_info) + expected_stack = traceback.StackSummary.extract( + traceback.walk_tb(exc_info[2])) + self.assertEqual(None, exc.__cause__) + self.assertEqual(None, exc.__context__) + self.assertEqual(False, exc.__suppress_context__) + self.assertEqual(expected_stack, exc.stack) + self.assertEqual(exc_info[0], exc.exc_type) + self.assertEqual(str(exc_info[1]), str(exc)) + + def test_from_exception(self): + # Check all the parameters are accepted. + def foo(): + 1/0 + try: + foo() + except Exception as e: + exc_info = sys.exc_info() + self.expected_stack = traceback.StackSummary.extract( + traceback.walk_tb(exc_info[2]), limit=1, lookup_lines=False, + capture_locals=True) + self.exc = traceback.TracebackException.from_exception( + e, limit=1, lookup_lines=False, capture_locals=True) + expected_stack = self.expected_stack + exc = self.exc + self.assertEqual(None, exc.__cause__) + self.assertEqual(None, exc.__context__) + self.assertEqual(False, exc.__suppress_context__) + self.assertEqual(expected_stack, exc.stack) + self.assertEqual(exc_info[0], exc.exc_type) + self.assertEqual(str(exc_info[1]), str(exc)) + + def test_cause(self): + try: + try: + 1/0 + finally: + exc_info_context = sys.exc_info() + exc_context = traceback.TracebackException(*exc_info_context) + cause = Exception("cause") + raise Exception("uh oh") from cause + except Exception: + exc_info = sys.exc_info() + exc = traceback.TracebackException(*exc_info) + expected_stack = traceback.StackSummary.extract( + traceback.walk_tb(exc_info[2])) + exc_cause = traceback.TracebackException(Exception, cause, None) + self.assertEqual(exc_cause, exc.__cause__) + self.assertEqual(exc_context, exc.__context__) + self.assertEqual(True, exc.__suppress_context__) + self.assertEqual(expected_stack, exc.stack) + self.assertEqual(exc_info[0], exc.exc_type) + self.assertEqual(str(exc_info[1]), str(exc)) + def test_context(self): + try: + try: + 1/0 + finally: + exc_info_context = sys.exc_info() + exc_context = traceback.TracebackException(*exc_info_context) + raise Exception("uh oh") + except Exception: + exc_info = sys.exc_info() + exc = traceback.TracebackException(*exc_info) + expected_stack = traceback.StackSummary.extract( + traceback.walk_tb(exc_info[2])) + self.assertEqual(None, exc.__cause__) + self.assertEqual(exc_context, exc.__context__) + self.assertEqual(False, exc.__suppress_context__) + self.assertEqual(expected_stack, exc.stack) + self.assertEqual(exc_info[0], exc.exc_type) + self.assertEqual(str(exc_info[1]), str(exc)) + + def test_limit(self): + def recurse(n): + if n: + recurse(n-1) + else: + 1/0 + try: + recurse(10) + except Exception: + exc_info = sys.exc_info() + exc = traceback.TracebackException(*exc_info, limit=5) + expected_stack = traceback.StackSummary.extract( + traceback.walk_tb(exc_info[2]), limit=5) + self.assertEqual(expected_stack, exc.stack) + + def test_lookup_lines(self): + linecache.clearcache() + e = Exception("uh oh") + c = test_code('/foo.py', 'method') + f = test_frame(c, None, None) + tb = test_tb(f, 6, None) + exc = traceback.TracebackException(Exception, e, tb, lookup_lines=False) + self.assertEqual({}, linecache.cache) + linecache.updatecache('/foo.py', globals()) + self.assertEqual(exc.stack[0].line, "import sys") + + def test_locals(self): + linecache.updatecache('/foo.py', globals()) + e = Exception("uh oh") + c = test_code('/foo.py', 'method') + f = test_frame(c, globals(), {'something': 1, 'other': 'string'}) + tb = test_tb(f, 6, None) + exc = traceback.TracebackException( + Exception, e, tb, capture_locals=True) + self.assertEqual( + exc.stack[0].locals, {'something': '1', 'other': "'string'"}) + + def test_no_locals(self): + linecache.updatecache('/foo.py', globals()) + e = Exception("uh oh") + c = test_code('/foo.py', 'method') + f = test_frame(c, globals(), {'something': 1}) + tb = test_tb(f, 6, None) + exc = traceback.TracebackException(Exception, e, tb) + self.assertEqual(exc.stack[0].locals, None) + + def test_traceback_header(self): + # do not print a traceback header if exc_traceback is None + # see issue #24695 + exc = traceback.TracebackException(Exception, Exception("haven"), None) + self.assertEqual(list(exc.format()), ["Exception: haven\n"]) + + +class MiscTest(unittest.TestCase): + + def test_all(self): + expected = set() + blacklist = {'print_list'} + for name in dir(traceback): + if name.startswith('_') or name in blacklist: + continue + module_object = getattr(traceback, name) + if getattr(module_object, '__module__', None) == 'traceback': + expected.add(name) + self.assertCountEqual(traceback.__all__, expected) -def test_main(): - run_unittest(__name__) if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_tracemalloc.py b/Lib/test/test_tracemalloc.py index 48ccab2..f65e361 100644 --- a/Lib/test/test_tracemalloc.py +++ b/Lib/test/test_tracemalloc.py @@ -4,8 +4,9 @@ import sys import tracemalloc import unittest from unittest.mock import patch -from test.script_helper import assert_python_ok, assert_python_failure -from test import script_helper, support +from test.support.script_helper import (assert_python_ok, assert_python_failure, + interpreter_requires_environment) +from test import support try: import threading except ImportError: @@ -660,11 +661,9 @@ class TestFilters(unittest.TestCase): self.assertFalse(fnmatch('abcdd', 'a*c*e')) self.assertFalse(fnmatch('abcbdefef', 'a*bd*eg')) - # replace .pyc and .pyo suffix with .py + # replace .pyc suffix with .py self.assertTrue(fnmatch('a.pyc', 'a.py')) - self.assertTrue(fnmatch('a.pyo', 'a.py')) self.assertTrue(fnmatch('a.py', 'a.pyc')) - self.assertTrue(fnmatch('a.py', 'a.pyo')) if os.name == 'nt': # case insensitive @@ -672,18 +671,14 @@ class TestFilters(unittest.TestCase): self.assertTrue(fnmatch('aBcDe', 'Ab*dE')) self.assertTrue(fnmatch('a.pyc', 'a.PY')) - self.assertTrue(fnmatch('a.PYO', 'a.py')) self.assertTrue(fnmatch('a.py', 'a.PYC')) - self.assertTrue(fnmatch('a.PY', 'a.pyo')) else: # case sensitive self.assertFalse(fnmatch('aBC', 'ABc')) self.assertFalse(fnmatch('aBcDe', 'Ab*dE')) self.assertFalse(fnmatch('a.pyc', 'a.PY')) - self.assertFalse(fnmatch('a.PYO', 'a.py')) self.assertFalse(fnmatch('a.py', 'a.PYC')) - self.assertFalse(fnmatch('a.PY', 'a.pyo')) if os.name == 'nt': # normalize alternate separator "/" to the standard separator "\" @@ -698,6 +693,9 @@ class TestFilters(unittest.TestCase): self.assertFalse(fnmatch(r'a/b\c', r'a\b/c')) self.assertFalse(fnmatch(r'a/b/c', r'a\b\c')) + # as of 3.5, .pyo is no longer munged to .py + self.assertFalse(fnmatch('a.pyo', 'a.py')) + def test_filter_match_trace(self): t1 = (("a.py", 2), ("b.py", 3)) t2 = (("b.py", 4), ("b.py", 5)) @@ -755,7 +753,7 @@ class TestCommandLine(unittest.TestCase): stdout = stdout.rstrip() self.assertEqual(stdout, b'False') - @unittest.skipIf(script_helper._interpreter_requires_environment(), + @unittest.skipIf(interpreter_requires_environment(), 'Cannot run -E tests when PYTHON env vars are required.') def test_env_var_ignored_with_E(self): """PYTHON* environment variables must be ignored when -E is present.""" diff --git a/Lib/test/test_ttk_guionly.py b/Lib/test/test_ttk_guionly.py index fcdedac..490e723 100644 --- a/Lib/test/test_ttk_guionly.py +++ b/Lib/test/test_ttk_guionly.py @@ -5,12 +5,10 @@ from test import support # Skip this test if _tkinter wasn't built. support.import_module('_tkinter') -# Make sure tkinter._fix runs to set up the environment -tkinter = support.import_fresh_module('tkinter') - # Skip test if tk cannot be initialized. support.requires('gui') +import tkinter from _tkinter import TclError from tkinter import ttk from tkinter.test import runtktests diff --git a/Lib/test/test_ttk_textonly.py b/Lib/test/test_ttk_textonly.py index 1cfeb15..566fc9d 100644 --- a/Lib/test/test_ttk_textonly.py +++ b/Lib/test/test_ttk_textonly.py @@ -4,9 +4,6 @@ from test import support # Skip this test if _tkinter does not exist. support.import_module('_tkinter') -# Make sure tkinter._fix runs to set up the environment -support.import_fresh_module('tkinter') - from tkinter.test import runtktests def test_main(): diff --git a/Lib/test/test_tuple.py b/Lib/test/test_tuple.py index 51875a1..fb113ab 100644 --- a/Lib/test/test_tuple.py +++ b/Lib/test/test_tuple.py @@ -6,6 +6,11 @@ import pickle class TupleTest(seq_tests.CommonTest): type2test = tuple + def test_getitem_error(self): + msg = "tuple indices must be integers or slices" + with self.assertRaisesRegex(TypeError, msg): + ()['a'] + def test_constructors(self): super().test_constructors() # calling built-in types without argument must return empty @@ -203,8 +208,13 @@ class TupleTest(seq_tests.CommonTest): with self.assertRaises(TypeError): [3,] + T((1,2)) -def test_main(): - support.run_unittest(TupleTest) + def test_lexicographic_ordering(self): + # Issue 21100 + a = self.type2test([1, 2]) + b = self.type2test([1, 2, 0]) + c = self.type2test([1, 3]) + self.assertLess(a, b) + self.assertLess(b, c) -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_typechecks.py b/Lib/test/test_typechecks.py index 17cd5d3..a0e617b 100644 --- a/Lib/test/test_typechecks.py +++ b/Lib/test/test_typechecks.py @@ -1,7 +1,6 @@ """Unit tests for __instancecheck__ and __subclasscheck__.""" import unittest -from test import support class ABC(type): @@ -68,9 +67,5 @@ class TypeChecksTest(unittest.TestCase): self.assertEqual(isinstance(42, (SubInt,)), False) -def test_main(): - support.run_unittest(TypeChecksTest) - - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 849cba9..5e74115 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -1,12 +1,14 @@ # Python test set -- part 6, built-in types -from test.support import run_unittest, run_with_locale -import collections +from test.support import run_with_locale +import collections.abc +import inspect import pickle import locale import sys import types -import unittest +import unittest.mock +import weakref class TypesTests(unittest.TestCase): @@ -343,6 +345,8 @@ class TypesTests(unittest.TestCase): self.assertRaises(ValueError, 3 .__format__, ",n") # can't have ',' with 'c' self.assertRaises(ValueError, 3 .__format__, ",c") + # can't have '#' with 'c' + self.assertRaises(ValueError, 3 .__format__, "#c") # ensure that only int and float type specifiers work for format_spec in ([chr(x) for x in range(ord('a'), ord('z')+1)] + @@ -1186,9 +1190,318 @@ class SimpleNamespaceTests(unittest.TestCase): types.SimpleNamespace() >= FakeSimpleNamespace() -def test_main(): - run_unittest(TypesTests, MappingProxyTests, ClassCreationTests, - SimpleNamespaceTests) +class CoroutineTests(unittest.TestCase): + def test_wrong_args(self): + samples = [None, 1, object()] + for sample in samples: + with self.assertRaisesRegex(TypeError, + 'types.coroutine.*expects a callable'): + types.coroutine(sample) + + def test_non_gen_values(self): + @types.coroutine + def foo(): + return 'spam' + self.assertEqual(foo(), 'spam') + + class Awaitable: + def __await__(self): + return () + aw = Awaitable() + @types.coroutine + def foo(): + return aw + self.assertIs(aw, foo()) + + # decorate foo second time + foo = types.coroutine(foo) + self.assertIs(aw, foo()) + + def test_async_def(self): + # Test that types.coroutine passes 'async def' coroutines + # without modification + + async def foo(): pass + foo_code = foo.__code__ + foo_flags = foo.__code__.co_flags + decorated_foo = types.coroutine(foo) + self.assertIs(foo, decorated_foo) + self.assertEqual(foo.__code__.co_flags, foo_flags) + self.assertIs(decorated_foo.__code__, foo_code) + + foo_coro = foo() + def bar(): return foo_coro + for _ in range(2): + bar = types.coroutine(bar) + coro = bar() + self.assertIs(foo_coro, coro) + self.assertEqual(coro.cr_code.co_flags, foo_flags) + coro.close() + + def test_duck_coro(self): + class CoroLike: + def send(self): pass + def throw(self): pass + def close(self): pass + def __await__(self): return self + + coro = CoroLike() + @types.coroutine + def foo(): + return coro + self.assertIs(foo(), coro) + self.assertIs(foo().__await__(), coro) + + def test_duck_corogen(self): + class CoroGenLike: + def send(self): pass + def throw(self): pass + def close(self): pass + def __await__(self): return self + def __iter__(self): return self + def __next__(self): pass + + coro = CoroGenLike() + @types.coroutine + def foo(): + return coro + self.assertIs(foo(), coro) + self.assertIs(foo().__await__(), coro) + + def test_duck_gen(self): + class GenLike: + def send(self): pass + def throw(self): pass + def close(self): pass + def __iter__(self): pass + def __next__(self): pass + + # Setup generator mock object + gen = unittest.mock.MagicMock(GenLike) + gen.__iter__ = lambda gen: gen + gen.__name__ = 'gen' + gen.__qualname__ = 'test.gen' + self.assertIsInstance(gen, collections.abc.Generator) + self.assertIs(gen, iter(gen)) + + @types.coroutine + def foo(): return gen + + wrapper = foo() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + self.assertIs(wrapper.__await__(), wrapper) + # Wrapper proxies duck generators completely: + self.assertIs(iter(wrapper), wrapper) + + self.assertIsInstance(wrapper, collections.abc.Coroutine) + self.assertIsInstance(wrapper, collections.abc.Awaitable) + + self.assertIs(wrapper.__qualname__, gen.__qualname__) + self.assertIs(wrapper.__name__, gen.__name__) + + # Test AttributeErrors + for name in {'gi_running', 'gi_frame', 'gi_code', 'gi_yieldfrom', + 'cr_running', 'cr_frame', 'cr_code', 'cr_await'}: + with self.assertRaises(AttributeError): + getattr(wrapper, name) + + # Test attributes pass-through + gen.gi_running = object() + gen.gi_frame = object() + gen.gi_code = object() + gen.gi_yieldfrom = object() + self.assertIs(wrapper.gi_running, gen.gi_running) + self.assertIs(wrapper.gi_frame, gen.gi_frame) + self.assertIs(wrapper.gi_code, gen.gi_code) + self.assertIs(wrapper.gi_yieldfrom, gen.gi_yieldfrom) + self.assertIs(wrapper.cr_running, gen.gi_running) + self.assertIs(wrapper.cr_frame, gen.gi_frame) + self.assertIs(wrapper.cr_code, gen.gi_code) + self.assertIs(wrapper.cr_await, gen.gi_yieldfrom) + + wrapper.close() + gen.close.assert_called_once_with() + + wrapper.send(1) + gen.send.assert_called_once_with(1) + gen.reset_mock() + + next(wrapper) + gen.__next__.assert_called_once_with() + gen.reset_mock() + + wrapper.throw(1, 2, 3) + gen.throw.assert_called_once_with(1, 2, 3) + gen.reset_mock() + + wrapper.throw(1, 2) + gen.throw.assert_called_once_with(1, 2) + gen.reset_mock() + + wrapper.throw(1) + gen.throw.assert_called_once_with(1) + gen.reset_mock() + + # Test exceptions propagation + error = Exception() + gen.throw.side_effect = error + try: + wrapper.throw(1) + except Exception as ex: + self.assertIs(ex, error) + else: + self.fail('wrapper did not propagate an exception') + + # Test invalid args + gen.reset_mock() + with self.assertRaises(TypeError): + wrapper.throw() + self.assertFalse(gen.throw.called) + with self.assertRaises(TypeError): + wrapper.close(1) + self.assertFalse(gen.close.called) + with self.assertRaises(TypeError): + wrapper.send() + self.assertFalse(gen.send.called) + + # Test that we do not double wrap + @types.coroutine + def bar(): return wrapper + self.assertIs(wrapper, bar()) + + # Test weakrefs support + ref = weakref.ref(wrapper) + self.assertIs(ref(), wrapper) + + def test_duck_functional_gen(self): + class Generator: + """Emulates the following generator (very clumsy): + + def gen(fut): + result = yield fut + return result * 2 + """ + def __init__(self, fut): + self._i = 0 + self._fut = fut + def __iter__(self): + return self + def __next__(self): + return self.send(None) + def send(self, v): + try: + if self._i == 0: + assert v is None + return self._fut + if self._i == 1: + raise StopIteration(v * 2) + if self._i > 1: + raise StopIteration + finally: + self._i += 1 + def throw(self, tp, *exc): + self._i = 100 + if tp is not GeneratorExit: + raise tp + def close(self): + self.throw(GeneratorExit) + + @types.coroutine + def foo(): return Generator('spam') + + wrapper = foo() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + + async def corofunc(): + return await foo() + 100 + coro = corofunc() + + self.assertEqual(coro.send(None), 'spam') + try: + coro.send(20) + except StopIteration as ex: + self.assertEqual(ex.args[0], 140) + else: + self.fail('StopIteration was expected') + + def test_gen(self): + def gen_func(): + yield 1 + return (yield 2) + gen = gen_func() + @types.coroutine + def foo(): return gen + wrapper = foo() + self.assertIsInstance(wrapper, types._GeneratorWrapper) + self.assertIs(wrapper.__await__(), gen) + + for name in ('__name__', '__qualname__', 'gi_code', + 'gi_running', 'gi_frame'): + self.assertIs(getattr(foo(), name), + getattr(gen, name)) + self.assertIs(foo().cr_code, gen.gi_code) + + self.assertEqual(next(wrapper), 1) + self.assertEqual(wrapper.send(None), 2) + with self.assertRaisesRegex(StopIteration, 'spam'): + wrapper.send('spam') + + gen = gen_func() + wrapper = foo() + wrapper.send(None) + with self.assertRaisesRegex(Exception, 'ham'): + wrapper.throw(Exception, Exception('ham')) + + # decorate foo second time + foo = types.coroutine(foo) + self.assertIs(foo().__await__(), gen) + + def test_returning_itercoro(self): + @types.coroutine + def gen(): + yield + + gencoro = gen() + + @types.coroutine + def foo(): + return gencoro + + self.assertIs(foo(), gencoro) + + # decorate foo second time + foo = types.coroutine(foo) + self.assertIs(foo(), gencoro) + + def test_genfunc(self): + def gen(): yield + self.assertIs(types.coroutine(gen), gen) + self.assertIs(types.coroutine(types.coroutine(gen)), gen) + + self.assertTrue(gen.__code__.co_flags & inspect.CO_ITERABLE_COROUTINE) + self.assertFalse(gen.__code__.co_flags & inspect.CO_COROUTINE) + + g = gen() + self.assertTrue(g.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) + self.assertFalse(g.gi_code.co_flags & inspect.CO_COROUTINE) + + self.assertIs(types.coroutine(gen), gen) + + def test_wrapper_object(self): + def gen(): + yield + @types.coroutine + def coro(): + return gen() + + wrapper = coro() + self.assertIn('GeneratorWrapper', repr(wrapper)) + self.assertEqual(repr(wrapper), str(wrapper)) + self.assertTrue(set(dir(wrapper)).issuperset({ + '__await__', '__iter__', '__next__', 'cr_code', 'cr_running', + 'cr_frame', 'gi_code', 'gi_frame', 'gi_running', 'send', + 'close', 'throw'})) + if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py new file mode 100644 index 0000000..1461cfb --- /dev/null +++ b/Lib/test/test_typing.py @@ -0,0 +1,1257 @@ +from collections import namedtuple +import re +import sys +from unittest import TestCase, main +try: + from unittest import mock +except ImportError: + import mock # 3rd party install, for PY3.2. + +from typing import Any +from typing import TypeVar, AnyStr +from typing import T, KT, VT # Not in __all__. +from typing import Union, Optional +from typing import Tuple +from typing import Callable +from typing import Generic +from typing import cast +from typing import get_type_hints +from typing import no_type_check, no_type_check_decorator +from typing import NamedTuple +from typing import IO, TextIO, BinaryIO +from typing import Pattern, Match +import typing + + +class Employee: + pass + + +class Manager(Employee): + pass + + +class Founder(Employee): + pass + + +class ManagingFounder(Manager, Founder): + pass + + +class AnyTests(TestCase): + + def test_any_instance_type_error(self): + with self.assertRaises(TypeError): + isinstance(42, Any) + + def test_any_subclass(self): + self.assertTrue(issubclass(Employee, Any)) + self.assertTrue(issubclass(int, Any)) + self.assertTrue(issubclass(type(None), Any)) + self.assertTrue(issubclass(object, Any)) + + def test_others_any(self): + self.assertFalse(issubclass(Any, Employee)) + self.assertFalse(issubclass(Any, int)) + self.assertFalse(issubclass(Any, type(None))) + # However, Any is a subclass of object (this can't be helped). + self.assertTrue(issubclass(Any, object)) + + def test_repr(self): + self.assertEqual(repr(Any), 'typing.Any') + + def test_errors(self): + with self.assertRaises(TypeError): + issubclass(42, Any) + with self.assertRaises(TypeError): + Any[int] # Any is not a generic type. + + def test_cannot_subclass(self): + with self.assertRaises(TypeError): + class A(Any): + pass + + def test_cannot_instantiate(self): + with self.assertRaises(TypeError): + Any() + + def test_cannot_subscript(self): + with self.assertRaises(TypeError): + Any[int] + + def test_any_is_subclass(self): + # Any should be considered a subclass of everything. + assert issubclass(Any, Any) + assert issubclass(Any, typing.List) + assert issubclass(Any, typing.List[int]) + assert issubclass(Any, typing.List[T]) + assert issubclass(Any, typing.Mapping) + assert issubclass(Any, typing.Mapping[str, int]) + assert issubclass(Any, typing.Mapping[KT, VT]) + assert issubclass(Any, Generic) + assert issubclass(Any, Generic[T]) + assert issubclass(Any, Generic[KT, VT]) + assert issubclass(Any, AnyStr) + assert issubclass(Any, Union) + assert issubclass(Any, Union[int, str]) + assert issubclass(Any, typing.Match) + assert issubclass(Any, typing.Match[str]) + # These expressions must simply not fail. + typing.Match[Any] + typing.Pattern[Any] + typing.IO[Any] + + +class TypeVarTests(TestCase): + + def test_basic_plain(self): + T = TypeVar('T') + # Every class is a subclass of T. + assert issubclass(int, T) + assert issubclass(str, T) + # T equals itself. + assert T == T + # T is a subclass of itself. + assert issubclass(T, T) + # T is an instance of TypeVar + assert isinstance(T, TypeVar) + + def test_typevar_instance_type_error(self): + T = TypeVar('T') + with self.assertRaises(TypeError): + isinstance(42, T) + + def test_basic_constrained(self): + A = TypeVar('A', str, bytes) + # Only str and bytes are subclasses of A. + assert issubclass(str, A) + assert issubclass(bytes, A) + assert not issubclass(int, A) + # A equals itself. + assert A == A + # A is a subclass of itself. + assert issubclass(A, A) + + def test_constrained_error(self): + with self.assertRaises(TypeError): + X = TypeVar('X', int) + + def test_union_unique(self): + X = TypeVar('X') + Y = TypeVar('Y') + assert X != Y + assert Union[X] == X + assert Union[X] != Union[X, Y] + assert Union[X, X] == X + assert Union[X, int] != Union[X] + assert Union[X, int] != Union[int] + assert Union[X, int].__union_params__ == (X, int) + assert Union[X, int].__union_set_params__ == {X, int} + + def test_union_constrained(self): + A = TypeVar('A', str, bytes) + assert Union[A, str] != Union[A] + + def test_repr(self): + self.assertEqual(repr(T), '~T') + self.assertEqual(repr(KT), '~KT') + self.assertEqual(repr(VT), '~VT') + self.assertEqual(repr(AnyStr), '~AnyStr') + T_co = TypeVar('T_co', covariant=True) + self.assertEqual(repr(T_co), '+T_co') + T_contra = TypeVar('T_contra', contravariant=True) + self.assertEqual(repr(T_contra), '-T_contra') + + def test_no_redefinition(self): + self.assertNotEqual(TypeVar('T'), TypeVar('T')) + self.assertNotEqual(TypeVar('T', int, str), TypeVar('T', int, str)) + + def test_subclass_as_unions(self): + # None of these are true -- each type var is its own world. + self.assertFalse(issubclass(TypeVar('T', int, str), + TypeVar('T', int, str))) + self.assertFalse(issubclass(TypeVar('T', int, float), + TypeVar('T', int, float, str))) + self.assertFalse(issubclass(TypeVar('T', int, str), + TypeVar('T', str, int))) + A = TypeVar('A', int, str) + B = TypeVar('B', int, str, float) + self.assertFalse(issubclass(A, B)) + self.assertFalse(issubclass(B, A)) + + def test_cannot_subclass_vars(self): + with self.assertRaises(TypeError): + class V(TypeVar('T')): + pass + + def test_cannot_subclass_var_itself(self): + with self.assertRaises(TypeError): + class V(TypeVar): + pass + + def test_cannot_instantiate_vars(self): + with self.assertRaises(TypeError): + TypeVar('A')() + + def test_bound(self): + X = TypeVar('X', bound=Employee) + assert issubclass(Employee, X) + assert issubclass(Manager, X) + assert not issubclass(int, X) + + def test_bound_errors(self): + with self.assertRaises(TypeError): + TypeVar('X', bound=42) + with self.assertRaises(TypeError): + TypeVar('X', str, float, bound=Employee) + + +class UnionTests(TestCase): + + def test_basics(self): + u = Union[int, float] + self.assertNotEqual(u, Union) + self.assertTrue(issubclass(int, u)) + self.assertTrue(issubclass(float, u)) + + def test_union_any(self): + u = Union[Any] + self.assertEqual(u, Any) + u = Union[int, Any] + self.assertEqual(u, Any) + u = Union[Any, int] + self.assertEqual(u, Any) + + def test_union_object(self): + u = Union[object] + self.assertEqual(u, object) + u = Union[int, object] + self.assertEqual(u, object) + u = Union[object, int] + self.assertEqual(u, object) + + def test_union_any_object(self): + u = Union[object, Any] + self.assertEqual(u, Any) + u = Union[Any, object] + self.assertEqual(u, Any) + + def test_unordered(self): + u1 = Union[int, float] + u2 = Union[float, int] + self.assertEqual(u1, u2) + + def test_subclass(self): + u = Union[int, Employee] + self.assertTrue(issubclass(Manager, u)) + + def test_self_subclass(self): + self.assertTrue(issubclass(Union[KT, VT], Union)) + self.assertFalse(issubclass(Union, Union[KT, VT])) + + def test_multiple_inheritance(self): + u = Union[int, Employee] + self.assertTrue(issubclass(ManagingFounder, u)) + + def test_single_class_disappears(self): + t = Union[Employee] + self.assertIs(t, Employee) + + def test_base_class_disappears(self): + u = Union[Employee, Manager, int] + self.assertEqual(u, Union[int, Employee]) + u = Union[Manager, int, Employee] + self.assertEqual(u, Union[int, Employee]) + u = Union[Employee, Manager] + self.assertIs(u, Employee) + + def test_weird_subclasses(self): + u = Union[Employee, int, float] + v = Union[int, float] + self.assertTrue(issubclass(v, u)) + w = Union[int, Manager] + self.assertTrue(issubclass(w, u)) + + def test_union_union(self): + u = Union[int, float] + v = Union[u, Employee] + self.assertEqual(v, Union[int, float, Employee]) + + def test_repr(self): + self.assertEqual(repr(Union), 'typing.Union') + u = Union[Employee, int] + self.assertEqual(repr(u), 'typing.Union[%s.Employee, int]' % __name__) + u = Union[int, Employee] + self.assertEqual(repr(u), 'typing.Union[int, %s.Employee]' % __name__) + + def test_cannot_subclass(self): + with self.assertRaises(TypeError): + class C(Union): + pass + with self.assertRaises(TypeError): + class C(Union[int, str]): + pass + + def test_cannot_instantiate(self): + with self.assertRaises(TypeError): + Union() + u = Union[int, float] + with self.assertRaises(TypeError): + u() + + def test_optional(self): + o = Optional[int] + u = Union[int, None] + self.assertEqual(o, u) + + def test_empty(self): + with self.assertRaises(TypeError): + Union[()] + + def test_issubclass_union(self): + assert issubclass(Union[int, str], Union) + assert not issubclass(int, Union) + + def test_union_instance_type_error(self): + with self.assertRaises(TypeError): + isinstance(42, Union[int, str]) + + +class TypeVarUnionTests(TestCase): + + def test_simpler(self): + A = TypeVar('A', int, str, float) + B = TypeVar('B', int, str) + assert issubclass(A, A) + assert issubclass(B, B) + assert not issubclass(B, A) + assert issubclass(A, Union[int, str, float]) + assert not issubclass(Union[int, str, float], A) + assert not issubclass(Union[int, str], B) + assert issubclass(B, Union[int, str]) + assert not issubclass(A, B) + assert not issubclass(Union[int, str, float], B) + assert not issubclass(A, Union[int, str]) + + def test_var_union_subclass(self): + self.assertTrue(issubclass(T, Union[int, T])) + self.assertTrue(issubclass(KT, Union[KT, VT])) + + def test_var_union(self): + TU = TypeVar('TU', Union[int, float], None) + assert issubclass(int, TU) + assert issubclass(float, TU) + + +class TupleTests(TestCase): + + def test_basics(self): + self.assertTrue(issubclass(Tuple[int, str], Tuple)) + self.assertTrue(issubclass(Tuple[int, str], Tuple[int, str])) + self.assertFalse(issubclass(int, Tuple)) + self.assertFalse(issubclass(Tuple[float, str], Tuple[int, str])) + self.assertFalse(issubclass(Tuple[int, str, int], Tuple[int, str])) + self.assertFalse(issubclass(Tuple[int, str], Tuple[int, str, int])) + self.assertTrue(issubclass(tuple, Tuple)) + self.assertFalse(issubclass(Tuple, tuple)) # Can't have it both ways. + + def test_tuple_subclass(self): + class MyTuple(tuple): + pass + self.assertTrue(issubclass(MyTuple, Tuple)) + + def test_tuple_instance_type_error(self): + with self.assertRaises(TypeError): + isinstance((0, 0), Tuple[int, int]) + with self.assertRaises(TypeError): + isinstance((0, 0), Tuple) + + def test_tuple_ellipsis_subclass(self): + + class B: + pass + + class C(B): + pass + + assert not issubclass(Tuple[B], Tuple[B, ...]) + assert issubclass(Tuple[C, ...], Tuple[B, ...]) + assert not issubclass(Tuple[C, ...], Tuple[B]) + assert not issubclass(Tuple[C], Tuple[B, ...]) + + def test_repr(self): + self.assertEqual(repr(Tuple), 'typing.Tuple') + self.assertEqual(repr(Tuple[()]), 'typing.Tuple[]') + self.assertEqual(repr(Tuple[int, float]), 'typing.Tuple[int, float]') + self.assertEqual(repr(Tuple[int, ...]), 'typing.Tuple[int, ...]') + + def test_errors(self): + with self.assertRaises(TypeError): + issubclass(42, Tuple) + with self.assertRaises(TypeError): + issubclass(42, Tuple[int]) + + +class CallableTests(TestCase): + + def test_self_subclass(self): + self.assertTrue(issubclass(Callable[[int], int], Callable)) + self.assertFalse(issubclass(Callable, Callable[[int], int])) + self.assertTrue(issubclass(Callable[[int], int], Callable[[int], int])) + self.assertFalse(issubclass(Callable[[Employee], int], + Callable[[Manager], int])) + self.assertFalse(issubclass(Callable[[Manager], int], + Callable[[Employee], int])) + self.assertFalse(issubclass(Callable[[int], Employee], + Callable[[int], Manager])) + self.assertFalse(issubclass(Callable[[int], Manager], + Callable[[int], Employee])) + + def test_eq_hash(self): + self.assertEqual(Callable[[int], int], Callable[[int], int]) + self.assertEqual(len({Callable[[int], int], Callable[[int], int]}), 1) + self.assertNotEqual(Callable[[int], int], Callable[[int], str]) + self.assertNotEqual(Callable[[int], int], Callable[[str], int]) + self.assertNotEqual(Callable[[int], int], Callable[[int, int], int]) + self.assertNotEqual(Callable[[int], int], Callable[[], int]) + self.assertNotEqual(Callable[[int], int], Callable) + + def test_cannot_subclass(self): + with self.assertRaises(TypeError): + + class C(Callable): + pass + + with self.assertRaises(TypeError): + + class C(Callable[[int], int]): + pass + + def test_cannot_instantiate(self): + with self.assertRaises(TypeError): + Callable() + c = Callable[[int], str] + with self.assertRaises(TypeError): + c() + + def test_callable_instance_works(self): + def f(): + pass + assert isinstance(f, Callable) + assert not isinstance(None, Callable) + + def test_callable_instance_type_error(self): + def f(): + pass + with self.assertRaises(TypeError): + assert isinstance(f, Callable[[], None]) + with self.assertRaises(TypeError): + assert isinstance(f, Callable[[], Any]) + with self.assertRaises(TypeError): + assert not isinstance(None, Callable[[], None]) + with self.assertRaises(TypeError): + assert not isinstance(None, Callable[[], Any]) + + def test_repr(self): + ct0 = Callable[[], bool] + self.assertEqual(repr(ct0), 'typing.Callable[[], bool]') + ct2 = Callable[[str, float], int] + self.assertEqual(repr(ct2), 'typing.Callable[[str, float], int]') + ctv = Callable[..., str] + self.assertEqual(repr(ctv), 'typing.Callable[..., str]') + + def test_callable_with_ellipsis(self): + + def foo(a: Callable[..., T]): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Callable[..., T]}) + + +XK = TypeVar('XK', str, bytes) +XV = TypeVar('XV') + + +class SimpleMapping(Generic[XK, XV]): + + def __getitem__(self, key: XK) -> XV: + ... + + def __setitem__(self, key: XK, value: XV): + ... + + def get(self, key: XK, default: XV = None) -> XV: + ... + + +class MySimpleMapping(SimpleMapping): + + def __init__(self): + self.store = {} + + def __getitem__(self, key: str): + return self.store[key] + + def __setitem__(self, key: str, value): + self.store[key] = value + + def get(self, key: str, default=None): + try: + return self.store[key] + except KeyError: + return default + + +class ProtocolTests(TestCase): + + def test_supports_int(self): + assert issubclass(int, typing.SupportsInt) + assert not issubclass(str, typing.SupportsInt) + + def test_supports_float(self): + assert issubclass(float, typing.SupportsFloat) + assert not issubclass(str, typing.SupportsFloat) + + def test_supports_complex(self): + + # Note: complex itself doesn't have __complex__. + class C: + def __complex__(self): + return 0j + + assert issubclass(C, typing.SupportsComplex) + assert not issubclass(str, typing.SupportsComplex) + + def test_supports_bytes(self): + + # Note: bytes itself doesn't have __bytes__. + class B: + def __bytes__(self): + return b'' + + assert issubclass(B, typing.SupportsBytes) + assert not issubclass(str, typing.SupportsBytes) + + def test_supports_abs(self): + assert issubclass(float, typing.SupportsAbs) + assert issubclass(int, typing.SupportsAbs) + assert not issubclass(str, typing.SupportsAbs) + + def test_supports_round(self): + assert issubclass(float, typing.SupportsRound) + assert issubclass(int, typing.SupportsRound) + assert not issubclass(str, typing.SupportsRound) + + def test_reversible(self): + assert issubclass(list, typing.Reversible) + assert not issubclass(int, typing.Reversible) + + def test_protocol_instance_type_error(self): + with self.assertRaises(TypeError): + isinstance([], typing.Reversible) + + +class GenericTests(TestCase): + + def test_basics(self): + X = SimpleMapping[str, Any] + Y = SimpleMapping[XK, str] + X[str, str] + Y[str, str] + with self.assertRaises(TypeError): + X[int, str] + with self.assertRaises(TypeError): + Y[str, bytes] + + def test_init(self): + T = TypeVar('T') + S = TypeVar('S') + with self.assertRaises(TypeError): + Generic[T, T] + with self.assertRaises(TypeError): + Generic[T, S, T] + + def test_repr(self): + self.assertEqual(repr(SimpleMapping), + __name__ + '.' + 'SimpleMapping[~XK, ~XV]') + self.assertEqual(repr(MySimpleMapping), + __name__ + '.' + 'MySimpleMapping[~XK, ~XV]') + + def test_errors(self): + with self.assertRaises(TypeError): + B = SimpleMapping[XK, Any] + + class C(Generic[B]): + pass + + def test_repr_2(self): + PY32 = sys.version_info[:2] < (3, 3) + + class C(Generic[T]): + pass + + assert C.__module__ == __name__ + if not PY32: + assert C.__qualname__ == 'GenericTests.test_repr_2.<locals>.C' + assert repr(C).split('.')[-1] == 'C[~T]' + X = C[int] + assert X.__module__ == __name__ + if not PY32: + assert X.__qualname__ == 'C' + assert repr(X).split('.')[-1] == 'C[int]' + + class Y(C[int]): + pass + + assert Y.__module__ == __name__ + if not PY32: + assert Y.__qualname__ == 'GenericTests.test_repr_2.<locals>.Y' + assert repr(Y).split('.')[-1] == 'Y[int]' + + def test_eq_1(self): + assert Generic == Generic + assert Generic[T] == Generic[T] + assert Generic[KT] != Generic[VT] + + def test_eq_2(self): + + class A(Generic[T]): + pass + + class B(Generic[T]): + pass + + assert A == A + assert A != B + assert A[T] == A[T] + assert A[T] != B[T] + + def test_multiple_inheritance(self): + + class A(Generic[T, VT]): + pass + + class B(Generic[KT, T]): + pass + + class C(A, Generic[KT, VT], B): + pass + + assert C.__parameters__ == (T, VT, KT) + + def test_nested(self): + + class G(Generic): + pass + + class Visitor(G[T]): + + a = None + + def set(self, a: T): + self.a = a + + def get(self): + return self.a + + def visit(self) -> T: + return self.a + + V = Visitor[typing.List[int]] + + class IntListVisitor(V): + + def append(self, x: int): + self.a.append(x) + + a = IntListVisitor() + a.set([]) + a.append(1) + a.append(42) + assert a.get() == [1, 42] + + def test_type_erasure(self): + T = TypeVar('T') + + class Node(Generic[T]): + def __init__(self, label: T, + left: 'Node[T]' = None, + right: 'Node[T]' = None): + self.label = label # type: T + self.left = left # type: Optional[Node[T]] + self.right = right # type: Optional[Node[T]] + + def foo(x: T): + a = Node(x) + b = Node[T](x) + c = Node[Any](x) + assert type(a) is Node + assert type(b) is Node + assert type(c) is Node + + foo(42) + + +class VarianceTests(TestCase): + + def test_invariance(self): + # Because of invariance, List[subclass of X] is not a subclass + # of List[X], and ditto for MutableSequence. + assert not issubclass(typing.List[Manager], typing.List[Employee]) + assert not issubclass(typing.MutableSequence[Manager], + typing.MutableSequence[Employee]) + # It's still reflexive. + assert issubclass(typing.List[Employee], typing.List[Employee]) + assert issubclass(typing.MutableSequence[Employee], + typing.MutableSequence[Employee]) + + def test_covariance_tuple(self): + # Check covariace for Tuple (which are really special cases). + assert issubclass(Tuple[Manager], Tuple[Employee]) + assert not issubclass(Tuple[Employee], Tuple[Manager]) + # And pairwise. + assert issubclass(Tuple[Manager, Manager], Tuple[Employee, Employee]) + assert not issubclass(Tuple[Employee, Employee], + Tuple[Manager, Employee]) + # And using ellipsis. + assert issubclass(Tuple[Manager, ...], Tuple[Employee, ...]) + assert not issubclass(Tuple[Employee, ...], Tuple[Manager, ...]) + + def test_covariance_sequence(self): + # Check covariance for Sequence (which is just a generic class + # for this purpose, but using a covariant type variable). + assert issubclass(typing.Sequence[Manager], typing.Sequence[Employee]) + assert not issubclass(typing.Sequence[Employee], + typing.Sequence[Manager]) + + def test_covariance_mapping(self): + # Ditto for Mapping (covariant in the value, invariant in the key). + assert issubclass(typing.Mapping[Employee, Manager], + typing.Mapping[Employee, Employee]) + assert not issubclass(typing.Mapping[Manager, Employee], + typing.Mapping[Employee, Employee]) + assert not issubclass(typing.Mapping[Employee, Manager], + typing.Mapping[Manager, Manager]) + assert not issubclass(typing.Mapping[Manager, Employee], + typing.Mapping[Manager, Manager]) + + +class CastTests(TestCase): + + def test_basics(self): + assert cast(int, 42) == 42 + assert cast(float, 42) == 42 + assert type(cast(float, 42)) is int + assert cast(Any, 42) == 42 + assert cast(list, 42) == 42 + assert cast(Union[str, float], 42) == 42 + assert cast(AnyStr, 42) == 42 + assert cast(None, 42) == 42 + + def test_errors(self): + # Bogus calls are not expected to fail. + cast(42, 42) + cast('hello', 42) + + +class ForwardRefTests(TestCase): + + def test_basics(self): + + class Node(Generic[T]): + + def __init__(self, label: T): + self.label = label + self.left = self.right = None + + def add_both(self, + left: 'Optional[Node[T]]', + right: 'Node[T]' = None, + stuff: int = None, + blah=None): + self.left = left + self.right = right + + def add_left(self, node: Optional['Node[T]']): + self.add_both(node, None) + + def add_right(self, node: 'Node[T]' = None): + self.add_both(None, node) + + t = Node[int] + both_hints = get_type_hints(t.add_both, globals(), locals()) + assert both_hints['left'] == both_hints['right'] == Optional[Node[T]] + assert both_hints['stuff'] == Optional[int] + assert 'blah' not in both_hints + + left_hints = get_type_hints(t.add_left, globals(), locals()) + assert left_hints['node'] == Optional[Node[T]] + + right_hints = get_type_hints(t.add_right, globals(), locals()) + assert right_hints['node'] == Optional[Node[T]] + + def test_forwardref_instance_type_error(self): + fr = typing._ForwardRef('int') + with self.assertRaises(TypeError): + isinstance(42, fr) + + def test_union_forward(self): + + def foo(a: Union['T']): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Union[T]}) + + def test_tuple_forward(self): + + def foo(a: Tuple['T']): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Tuple[T]}) + + def test_callable_forward(self): + + def foo(a: Callable[['T'], 'T']): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Callable[[T], T]}) + + def test_callable_with_ellipsis_forward(self): + + def foo(a: 'Callable[..., T]'): + pass + + self.assertEqual(get_type_hints(foo, globals(), locals()), + {'a': Callable[..., T]}) + + def test_syntax_error(self): + + with self.assertRaises(SyntaxError): + Generic['/T'] + + def test_delayed_syntax_error(self): + + def foo(a: 'Node[T'): + pass + + with self.assertRaises(SyntaxError): + get_type_hints(foo) + + def test_type_error(self): + + def foo(a: Tuple['42']): + pass + + with self.assertRaises(TypeError): + get_type_hints(foo) + + def test_name_error(self): + + def foo(a: 'Noode[T]'): + pass + + with self.assertRaises(NameError): + get_type_hints(foo, locals()) + + def test_no_type_check(self): + + @no_type_check + def foo(a: 'whatevers') -> {}: + pass + + th = get_type_hints(foo) + self.assertEqual(th, {}) + + def test_no_type_check_class(self): + + @no_type_check + class C: + def foo(a: 'whatevers') -> {}: + pass + + cth = get_type_hints(C.foo) + self.assertEqual(cth, {}) + ith = get_type_hints(C().foo) + self.assertEqual(ith, {}) + + def test_meta_no_type_check(self): + + @no_type_check_decorator + def magic_decorator(deco): + return deco + + self.assertEqual(magic_decorator.__name__, 'magic_decorator') + + @magic_decorator + def foo(a: 'whatevers') -> {}: + pass + + @magic_decorator + class C: + def foo(a: 'whatevers') -> {}: + pass + + self.assertEqual(foo.__name__, 'foo') + th = get_type_hints(foo) + self.assertEqual(th, {}) + cth = get_type_hints(C.foo) + self.assertEqual(cth, {}) + ith = get_type_hints(C().foo) + self.assertEqual(ith, {}) + + def test_default_globals(self): + code = ("class C:\n" + " def foo(self, a: 'C') -> 'D': pass\n" + "class D:\n" + " def bar(self, b: 'D') -> C: pass\n" + ) + ns = {} + exec(code, ns) + hints = get_type_hints(ns['C'].foo) + assert hints == {'a': ns['C'], 'return': ns['D']} + + +class OverloadTests(TestCase): + + def test_overload_exists(self): + from typing import overload + + def test_overload_fails(self): + from typing import overload + + with self.assertRaises(RuntimeError): + @overload + def blah(): + pass + + +class CollectionsAbcTests(TestCase): + + def test_hashable(self): + assert isinstance(42, typing.Hashable) + assert not isinstance([], typing.Hashable) + + def test_iterable(self): + assert isinstance([], typing.Iterable) + # Due to ABC caching, the second time takes a separate code + # path and could fail. So call this a few times. + assert isinstance([], typing.Iterable) + assert isinstance([], typing.Iterable) + assert isinstance([], typing.Iterable[int]) + assert not isinstance(42, typing.Iterable) + # Just in case, also test issubclass() a few times. + assert issubclass(list, typing.Iterable) + assert issubclass(list, typing.Iterable) + + def test_iterator(self): + it = iter([]) + assert isinstance(it, typing.Iterator) + assert isinstance(it, typing.Iterator[int]) + assert not isinstance(42, typing.Iterator) + + def test_sized(self): + assert isinstance([], typing.Sized) + assert not isinstance(42, typing.Sized) + + def test_container(self): + assert isinstance([], typing.Container) + assert not isinstance(42, typing.Container) + + def test_abstractset(self): + assert isinstance(set(), typing.AbstractSet) + assert not isinstance(42, typing.AbstractSet) + + def test_mutableset(self): + assert isinstance(set(), typing.MutableSet) + assert not isinstance(frozenset(), typing.MutableSet) + + def test_mapping(self): + assert isinstance({}, typing.Mapping) + assert not isinstance(42, typing.Mapping) + + def test_mutablemapping(self): + assert isinstance({}, typing.MutableMapping) + assert not isinstance(42, typing.MutableMapping) + + def test_sequence(self): + assert isinstance([], typing.Sequence) + assert not isinstance(42, typing.Sequence) + + def test_mutablesequence(self): + assert isinstance([], typing.MutableSequence) + assert not isinstance((), typing.MutableSequence) + + def test_bytestring(self): + assert isinstance(b'', typing.ByteString) + assert isinstance(bytearray(b''), typing.ByteString) + + def test_list(self): + assert issubclass(list, typing.List) + + def test_set(self): + assert issubclass(set, typing.Set) + assert not issubclass(frozenset, typing.Set) + + def test_frozenset(self): + assert issubclass(frozenset, typing.FrozenSet) + assert not issubclass(set, typing.FrozenSet) + + def test_dict(self): + assert issubclass(dict, typing.Dict) + + def test_no_list_instantiation(self): + with self.assertRaises(TypeError): + typing.List() + with self.assertRaises(TypeError): + typing.List[T]() + with self.assertRaises(TypeError): + typing.List[int]() + + def test_list_subclass_instantiation(self): + + class MyList(typing.List[int]): + pass + + a = MyList() + assert isinstance(a, MyList) + + def test_no_dict_instantiation(self): + with self.assertRaises(TypeError): + typing.Dict() + with self.assertRaises(TypeError): + typing.Dict[KT, VT]() + with self.assertRaises(TypeError): + typing.Dict[str, int]() + + def test_dict_subclass_instantiation(self): + + class MyDict(typing.Dict[str, int]): + pass + + d = MyDict() + assert isinstance(d, MyDict) + + def test_no_set_instantiation(self): + with self.assertRaises(TypeError): + typing.Set() + with self.assertRaises(TypeError): + typing.Set[T]() + with self.assertRaises(TypeError): + typing.Set[int]() + + def test_set_subclass_instantiation(self): + + class MySet(typing.Set[int]): + pass + + d = MySet() + assert isinstance(d, MySet) + + def test_no_frozenset_instantiation(self): + with self.assertRaises(TypeError): + typing.FrozenSet() + with self.assertRaises(TypeError): + typing.FrozenSet[T]() + with self.assertRaises(TypeError): + typing.FrozenSet[int]() + + def test_frozenset_subclass_instantiation(self): + + class MyFrozenSet(typing.FrozenSet[int]): + pass + + d = MyFrozenSet() + assert isinstance(d, MyFrozenSet) + + def test_no_tuple_instantiation(self): + with self.assertRaises(TypeError): + Tuple() + with self.assertRaises(TypeError): + Tuple[T]() + with self.assertRaises(TypeError): + Tuple[int]() + + def test_generator(self): + def foo(): + yield 42 + g = foo() + assert issubclass(type(g), typing.Generator) + assert issubclass(typing.Generator[Manager, Employee, Manager], + typing.Generator[Employee, Manager, Employee]) + assert not issubclass(typing.Generator[Manager, Manager, Manager], + typing.Generator[Employee, Employee, Employee]) + + def test_no_generator_instantiation(self): + with self.assertRaises(TypeError): + typing.Generator() + with self.assertRaises(TypeError): + typing.Generator[T, T, T]() + with self.assertRaises(TypeError): + typing.Generator[int, int, int]() + + def test_subclassing(self): + + class MMA(typing.MutableMapping): + pass + + with self.assertRaises(TypeError): # It's abstract + MMA() + + class MMC(MMA): + def __len__(self): + return 0 + + assert len(MMC()) == 0 + + class MMB(typing.MutableMapping[KT, VT]): + def __len__(self): + return 0 + + assert len(MMB()) == 0 + assert len(MMB[str, str]()) == 0 + assert len(MMB[KT, VT]()) == 0 + + +class NamedTupleTests(TestCase): + + def test_basics(self): + Emp = NamedTuple('Emp', [('name', str), ('id', int)]) + assert issubclass(Emp, tuple) + joe = Emp('Joe', 42) + jim = Emp(name='Jim', id=1) + assert isinstance(joe, Emp) + assert isinstance(joe, tuple) + assert joe.name == 'Joe' + assert joe.id == 42 + assert jim.name == 'Jim' + assert jim.id == 1 + assert Emp.__name__ == 'Emp' + assert Emp._fields == ('name', 'id') + assert Emp._field_types == dict(name=str, id=int) + + +class IOTests(TestCase): + + def test_io(self): + + def stuff(a: IO) -> AnyStr: + return a.readline() + + a = stuff.__annotations__['a'] + assert a.__parameters__ == (AnyStr,) + + def test_textio(self): + + def stuff(a: TextIO) -> str: + return a.readline() + + a = stuff.__annotations__['a'] + assert a.__parameters__ == (str,) + + def test_binaryio(self): + + def stuff(a: BinaryIO) -> bytes: + return a.readline() + + a = stuff.__annotations__['a'] + assert a.__parameters__ == (bytes,) + + def test_io_submodule(self): + from typing.io import IO, TextIO, BinaryIO, __all__, __name__ + assert IO is typing.IO + assert TextIO is typing.TextIO + assert BinaryIO is typing.BinaryIO + assert set(__all__) == set(['IO', 'TextIO', 'BinaryIO']) + assert __name__ == 'typing.io' + + +class RETests(TestCase): + # Much of this is really testing _TypeAlias. + + def test_basics(self): + pat = re.compile('[a-z]+', re.I) + assert issubclass(pat.__class__, Pattern) + assert issubclass(type(pat), Pattern) + assert issubclass(type(pat), Pattern[str]) + + mat = pat.search('12345abcde.....') + assert issubclass(mat.__class__, Match) + assert issubclass(mat.__class__, Match[str]) + assert issubclass(mat.__class__, Match[bytes]) # Sad but true. + assert issubclass(type(mat), Match) + assert issubclass(type(mat), Match[str]) + + p = Pattern[Union[str, bytes]] + assert issubclass(Pattern[str], Pattern) + assert issubclass(Pattern[str], p) + + m = Match[Union[bytes, str]] + assert issubclass(Match[bytes], Match) + assert issubclass(Match[bytes], m) + + def test_errors(self): + with self.assertRaises(TypeError): + # Doesn't fit AnyStr. + Pattern[int] + with self.assertRaises(TypeError): + # Can't change type vars? + Match[T] + m = Match[Union[str, bytes]] + with self.assertRaises(TypeError): + # Too complicated? + m[str] + with self.assertRaises(TypeError): + # We don't support isinstance(). + isinstance(42, Pattern) + with self.assertRaises(TypeError): + # We don't support isinstance(). + isinstance(42, Pattern[str]) + + def test_repr(self): + assert repr(Pattern) == 'Pattern[~AnyStr]' + assert repr(Pattern[str]) == 'Pattern[str]' + assert repr(Pattern[bytes]) == 'Pattern[bytes]' + assert repr(Match) == 'Match[~AnyStr]' + assert repr(Match[str]) == 'Match[str]' + assert repr(Match[bytes]) == 'Match[bytes]' + + def test_re_submodule(self): + from typing.re import Match, Pattern, __all__, __name__ + assert Match is typing.Match + assert Pattern is typing.Pattern + assert set(__all__) == set(['Match', 'Pattern']) + assert __name__ == 'typing.re' + + def test_cannot_subclass(self): + with self.assertRaises(TypeError) as ex: + + class A(typing.Match): + pass + + assert str(ex.exception) == "A type alias cannot be subclassed" + + +class AllTests(TestCase): + """Tests for __all__.""" + + def test_all(self): + from typing import __all__ as a + # Just spot-check the first and last of every category. + assert 'AbstractSet' in a + assert 'ValuesView' in a + assert 'cast' in a + assert 'overload' in a + assert 'io' in a + assert 're' in a + # Spot-check that stdlib modules aren't exported. + assert 'os' not in a + assert 'sys' not in a + + +if __name__ == '__main__': + main() diff --git a/Lib/test/test_ucn.py b/Lib/test/test_ucn.py index 1e07f66..8febf0a 100644 --- a/Lib/test/test_ucn.py +++ b/Lib/test/test_ucn.py @@ -233,8 +233,5 @@ class UnicodeNamesTest(unittest.TestCase): ) -def test_main(): - support.run_unittest(UnicodeNamesTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_unary.py b/Lib/test/test_unary.py index b835564..c3c17cc 100644 --- a/Lib/test/test_unary.py +++ b/Lib/test/test_unary.py @@ -1,7 +1,6 @@ """Test compiler changes for unary ops (+, -, ~) introduced in Python 2.2""" import unittest -from test.support import run_unittest class UnaryOpTestCase(unittest.TestCase): @@ -50,9 +49,5 @@ class UnaryOpTestCase(unittest.TestCase): self.assertRaises(TypeError, eval, "~2.0") -def test_main(): - run_unittest(UnaryOpTestCase) - - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 2cc1d7c..1429a6d 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -8,6 +8,7 @@ Written by Marc-Andre Lemburg (mal@lemburg.com). import _string import codecs import itertools +import operator import struct import sys import unittest @@ -315,6 +316,7 @@ class UnicodeTest(string_tests.CommonTest, {ord('a'): None, ord('b'): ''}) self.checkequalnofix('xyyx', 'xzx', 'translate', {ord('z'): 'yy'}) + # this needs maketrans() self.checkequalnofix('abababc', 'abababc', 'translate', {'b': '<i>'}) @@ -324,6 +326,33 @@ class UnicodeTest(string_tests.CommonTest, tbl = self.type2test.maketrans('abc', 'xyz', 'd') self.checkequalnofix('xyzzy', 'abdcdcbdddd', 'translate', tbl) + # various tests switching from ASCII to latin1 or the opposite; + # same length, remove a letter, or replace with a longer string. + self.assertEqual("[a]".translate(str.maketrans('a', 'X')), + "[X]") + self.assertEqual("[a]".translate(str.maketrans({'a': 'X'})), + "[X]") + self.assertEqual("[a]".translate(str.maketrans({'a': None})), + "[]") + self.assertEqual("[a]".translate(str.maketrans({'a': 'XXX'})), + "[XXX]") + self.assertEqual("[a]".translate(str.maketrans({'a': '\xe9'})), + "[\xe9]") + self.assertEqual("[a]".translate(str.maketrans({'a': '<\xe9>'})), + "[<\xe9>]") + self.assertEqual("[\xe9]".translate(str.maketrans({'\xe9': 'a'})), + "[a]") + self.assertEqual("[\xe9]".translate(str.maketrans({'\xe9': None})), + "[]") + + # invalid Unicode characters + invalid_char = 0x10ffff+1 + for before in "a\xe9\u20ac\U0010ffff": + mapping = str.maketrans({before: invalid_char}) + text = "[%s]" % before + self.assertRaises(ValueError, text.translate, mapping) + + # errors self.assertRaises(TypeError, self.type2test.maketrans) self.assertRaises(ValueError, self.type2test.maketrans, 'abc', 'defg') self.assertRaises(TypeError, self.type2test.maketrans, 2, 'def') @@ -1306,20 +1335,20 @@ class UnicodeTest(string_tests.CommonTest, self.assertEqual('%.2s' % "a\xe9\u20ac", 'a\xe9') #issue 19995 - class PsuedoInt: + class PseudoInt: def __init__(self, value): self.value = int(value) def __int__(self): return self.value def __index__(self): return self.value - class PsuedoFloat: + class PseudoFloat: def __init__(self, value): self.value = float(value) def __int__(self): return int(self.value) - pi = PsuedoFloat(3.1415) - letter_m = PsuedoInt(109) + pi = PseudoFloat(3.1415) + letter_m = PseudoInt(109) self.assertEqual('%x' % 42, '2a') self.assertEqual('%X' % 15, 'F') self.assertEqual('%o' % 9, '11') @@ -1328,11 +1357,11 @@ class UnicodeTest(string_tests.CommonTest, self.assertEqual('%X' % letter_m, '6D') self.assertEqual('%o' % letter_m, '155') self.assertEqual('%c' % letter_m, 'm') - self.assertWarns(DeprecationWarning, '%x'.__mod__, pi), - self.assertWarns(DeprecationWarning, '%x'.__mod__, 3.14), - self.assertWarns(DeprecationWarning, '%X'.__mod__, 2.11), - self.assertWarns(DeprecationWarning, '%o'.__mod__, 1.79), - self.assertWarns(DeprecationWarning, '%c'.__mod__, pi), + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not float', operator.mod, '%x', 3.14), + self.assertRaisesRegex(TypeError, '%X format: an integer is required, not float', operator.mod, '%X', 2.11), + self.assertRaisesRegex(TypeError, '%o format: an integer is required, not float', operator.mod, '%o', 1.79), + self.assertRaisesRegex(TypeError, '%x format: an integer is required, not PseudoFloat', operator.mod, '%x', pi), + self.assertRaises(TypeError, operator.mod, '%c', pi), def test_formatting_with_enum(self): # issue18780 @@ -2053,7 +2082,8 @@ class UnicodeTest(string_tests.CommonTest, 'cp863', 'cp865', 'cp866', 'cp1125', 'iso8859_10', 'iso8859_13', 'iso8859_14', 'iso8859_15', 'iso8859_2', 'iso8859_3', 'iso8859_4', 'iso8859_5', 'iso8859_6', - 'iso8859_7', 'iso8859_9', 'koi8_r', 'latin_1', + 'iso8859_7', 'iso8859_9', + 'koi8_r', 'koi8_t', 'koi8_u', 'kz1048', 'latin_1', 'mac_cyrillic', 'mac_latin2', 'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', @@ -2081,14 +2111,14 @@ class UnicodeTest(string_tests.CommonTest, 'cp863', 'cp865', 'cp866', 'cp1125', 'iso8859_10', 'iso8859_13', 'iso8859_14', 'iso8859_15', 'iso8859_2', 'iso8859_4', 'iso8859_5', - 'iso8859_9', 'koi8_r', 'latin_1', + 'iso8859_9', 'koi8_r', 'koi8_u', 'latin_1', 'mac_cyrillic', 'mac_latin2', ### These have undefined mappings: #'cp1250', 'cp1251', 'cp1252', 'cp1253', 'cp1254', 'cp1255', #'cp1256', 'cp1257', 'cp1258', #'cp424', 'cp856', 'cp857', 'cp864', 'cp869', 'cp874', - #'iso8859_3', 'iso8859_6', 'iso8859_7', + #'iso8859_3', 'iso8859_6', 'iso8859_7', 'koi8_t', 'kz1048', #'mac_greek', 'mac_iceland','mac_roman', 'mac_turkish', ### These fail the round-trip: diff --git a/Lib/test/test_unicodedata.py b/Lib/test/test_unicodedata.py index 707b30e..0f33d19 100644 --- a/Lib/test/test_unicodedata.py +++ b/Lib/test/test_unicodedata.py @@ -21,7 +21,7 @@ errors = 'surrogatepass' class UnicodeMethodsTest(unittest.TestCase): # update this, if the database changes - expectedchecksum = 'e74e878de71b6e780ffac271785c3cb58f6251f3' + expectedchecksum = '5971760872b2f98bb9c701e6c0db3273d756b3ec' def test_method_checksum(self): h = hashlib.sha1() @@ -79,8 +79,9 @@ class UnicodeDatabaseTest(unittest.TestCase): class UnicodeFunctionsTest(UnicodeDatabaseTest): - # update this, if the database changes - expectedchecksum = 'f0b74d26776331cc7bdc3a4698f037d73f2cee2b' + # Update this if the database changes. Make sure to do a full rebuild + # (e.g. 'make distclean && make') to get the correct checksum. + expectedchecksum = '5e74827cd07f9e546a30f34b7bcf6cc2eac38c8c' def test_function_checksum(self): data = [] h = hashlib.sha1() @@ -312,12 +313,5 @@ class UnicodeMiscTest(UnicodeDatabaseTest): self.assertEqual(len(lines), 1, r"\u%.4x should not be a linebreak" % i) -def test_main(): - test.support.run_unittest( - UnicodeMiscTest, - UnicodeMethodsTest, - UnicodeFunctionsTest - ) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_unpack.py b/Lib/test/test_unpack.py index b1c483d..d1ccb38 100644 --- a/Lib/test/test_unpack.py +++ b/Lib/test/test_unpack.py @@ -76,7 +76,7 @@ Unpacking sequence too short >>> a, b, c, d = Seq() Traceback (most recent call last): ... - ValueError: need more than 3 values to unpack + ValueError: not enough values to unpack (expected 4, got 3) Unpacking sequence too long diff --git a/Lib/test/test_unpack_ex.py b/Lib/test/test_unpack_ex.py index ae2dcbd..d27eef0 100644 --- a/Lib/test/test_unpack_ex.py +++ b/Lib/test/test_unpack_ex.py @@ -71,8 +71,188 @@ Multiple targets >>> a == 0 and b == [1, 2, 3] and c == 4 and d == [0, 1, 2, 3] and e == 4 True +Assignment unpacking + + >>> a, b, *c = range(5) + >>> a, b, c + (0, 1, [2, 3, 4]) + >>> *a, b, c = a, b, *c + >>> a, b, c + ([0, 1, 2], 3, 4) + +Set display element unpacking + + >>> a = [1, 2, 3] + >>> sorted({1, *a, 0, 4}) + [0, 1, 2, 3, 4] + + >>> {1, *1, 0, 4} + Traceback (most recent call last): + ... + TypeError: 'int' object is not iterable + +Dict display element unpacking + + >>> kwds = {'z': 0, 'w': 12} + >>> sorted({'x': 1, 'y': 2, **kwds}.items()) + [('w', 12), ('x', 1), ('y', 2), ('z', 0)] + + >>> sorted({**{'x': 1}, 'y': 2, **{'z': 3}}.items()) + [('x', 1), ('y', 2), ('z', 3)] + + >>> sorted({**{'x': 1}, 'y': 2, **{'x': 3}}.items()) + [('x', 3), ('y', 2)] + + >>> sorted({**{'x': 1}, **{'x': 3}, 'x': 4}.items()) + [('x', 4)] + + >>> {**{}} + {} + + >>> a = {} + >>> {**a}[0] = 1 + >>> a + {} + + >>> {**1} + Traceback (most recent call last): + ... + TypeError: 'int' object is not a mapping + + >>> {**[]} + Traceback (most recent call last): + ... + TypeError: 'list' object is not a mapping + + >>> len(eval("{" + ", ".join("**{{{}: {}}}".format(i, i) + ... for i in range(1000)) + "}")) + 1000 + + >>> {0:1, **{0:2}, 0:3, 0:4} + {0: 4} + +List comprehension element unpacking + + >>> a, b, c = [0, 1, 2], 3, 4 + >>> [*a, b, c] + [0, 1, 2, 3, 4] + + >>> l = [a, (3, 4), {5}, {6: None}, (i for i in range(7, 10))] + >>> [*item for item in l] + Traceback (most recent call last): + ... + SyntaxError: iterable unpacking cannot be used in comprehension + + >>> [*[0, 1] for i in range(10)] + Traceback (most recent call last): + ... + SyntaxError: iterable unpacking cannot be used in comprehension + + >>> [*'a' for i in range(10)] + Traceback (most recent call last): + ... + SyntaxError: iterable unpacking cannot be used in comprehension + + >>> [*[] for i in range(10)] + Traceback (most recent call last): + ... + SyntaxError: iterable unpacking cannot be used in comprehension + +Generator expression in function arguments + + >>> list(*x for x in (range(5) for i in range(3))) + Traceback (most recent call last): + ... + list(*x for x in (range(5) for i in range(3))) + ^ + SyntaxError: invalid syntax + + >>> dict(**x for x in [{1:2}]) + Traceback (most recent call last): + ... + dict(**x for x in [{1:2}]) + ^ + SyntaxError: invalid syntax + +Iterable argument unpacking + + >>> print(*[1], *[2], 3) + 1 2 3 + +Make sure that they don't corrupt the passed-in dicts. + + >>> def f(x, y): + ... print(x, y) + ... + >>> original_dict = {'x': 1} + >>> f(**original_dict, y=2) + 1 2 + >>> original_dict + {'x': 1} + Now for some failures +Make sure the raised errors are right for keyword argument unpackings + + >>> from collections.abc import MutableMapping + >>> class CrazyDict(MutableMapping): + ... def __init__(self): + ... self.d = {} + ... + ... def __iter__(self): + ... for x in self.d.__iter__(): + ... if x == 'c': + ... self.d['z'] = 10 + ... yield x + ... + ... def __getitem__(self, k): + ... return self.d[k] + ... + ... def __len__(self): + ... return len(self.d) + ... + ... def __setitem__(self, k, v): + ... self.d[k] = v + ... + ... def __delitem__(self, k): + ... del self.d[k] + ... + >>> d = CrazyDict() + >>> d.d = {chr(ord('a') + x): x for x in range(5)} + >>> e = {**d} + Traceback (most recent call last): + ... + RuntimeError: dictionary changed size during iteration + + >>> d.d = {chr(ord('a') + x): x for x in range(5)} + >>> def f(**kwargs): print(kwargs) + >>> f(**d) + Traceback (most recent call last): + ... + RuntimeError: dictionary changed size during iteration + +Overridden parameters + + >>> f(x=5, **{'x': 3}, y=2) + Traceback (most recent call last): + ... + TypeError: f() got multiple values for keyword argument 'x' + + >>> f(**{'x': 3}, x=5, y=2) + Traceback (most recent call last): + ... + TypeError: f() got multiple values for keyword argument 'x' + + >>> f(**{'x': 3}, **{'x': 5}, y=2) + Traceback (most recent call last): + ... + TypeError: f() got multiple values for keyword argument 'x' + + >>> f(**{1: 3}, **{1: 5}) + Traceback (most recent call last): + ... + TypeError: f() keywords must be strings + Unpacking non-sequence >>> a, *b = 7 @@ -85,7 +265,14 @@ Unpacking sequence too short >>> a, *b, c, d, e = Seq() Traceback (most recent call last): ... - ValueError: need more than 3 values to unpack + ValueError: not enough values to unpack (expected at least 4, got 3) + +Unpacking sequence too short and target appears last + + >>> a, b, c, d, *e = Seq() + Traceback (most recent call last): + ... + ValueError: not enough values to unpack (expected at least 4, got 3) Unpacking a sequence where the test for too long raises a different kind of error @@ -131,17 +318,17 @@ Now some general starred expressions (all fail). >>> *a # doctest:+ELLIPSIS Traceback (most recent call last): ... - SyntaxError: can use starred expression only as assignment target + SyntaxError: can't use starred expression here >>> *1 # doctest:+ELLIPSIS Traceback (most recent call last): ... - SyntaxError: can use starred expression only as assignment target + SyntaxError: can't use starred expression here >>> x = *a # doctest:+ELLIPSIS Traceback (most recent call last): ... - SyntaxError: can use starred expression only as assignment target + SyntaxError: can't use starred expression here Some size constraints (all fail.) diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py index 16236ef..58ca2a5 100644 --- a/Lib/test/test_urllib.py +++ b/Lib/test/test_urllib.py @@ -10,7 +10,10 @@ import unittest from unittest.mock import patch from test import support import os -import ssl +try: + import ssl +except ImportError: + ssl = None import sys import tempfile from nturl2path import url2pathname, pathname2url @@ -380,6 +383,7 @@ Content-Type: text/html; charset=iso-8859-1 with support.check_warnings(('',DeprecationWarning)): urllib.request.URLopener() + @unittest.skipUnless(ssl, "ssl module required") def test_cafile_and_context(self): context = ssl.create_default_context() with self.assertRaises(ValueError): @@ -1331,7 +1335,7 @@ class URLopener_Tests(unittest.TestCase): # serv.settimeout(3) # serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # serv.bind(("", 9093)) -# serv.listen(5) +# serv.listen() # try: # conn, addr = serv.accept() # conn.send("1 Hola mundo\n") diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py index 7d41ea1..a5281d8 100644 --- a/Lib/test/test_urllib2.py +++ b/Lib/test/test_urllib2.py @@ -11,7 +11,9 @@ import sys import urllib.request # The proxy bypass method imported below has logic specific to the OSX # proxy config data structure but is testable on all platforms. -from urllib.request import Request, OpenerDirector, _parse_proxy, _proxy_bypass_macosx_sysconf +from urllib.request import (Request, OpenerDirector, HTTPBasicAuthHandler, + HTTPPasswordMgrWithPriorAuth, _parse_proxy, + _proxy_bypass_macosx_sysconf) from urllib.parse import urlparse import urllib.error import http.client @@ -21,6 +23,7 @@ import http.client # CacheFTPHandler (hard to write) # parse_keqv_list, parse_http_list, HTTPDigestAuthHandler + class TrivialTests(unittest.TestCase): def test___all__(self): @@ -71,6 +74,7 @@ class TrivialTests(unittest.TestCase): err = urllib.error.URLError('reason') self.assertIn(err.reason, str(err)) + class RequestHdrsTests(unittest.TestCase): def test_request_headers_dict(self): @@ -130,7 +134,6 @@ class RequestHdrsTests(unittest.TestCase): req.remove_header("Unredirected-spam") self.assertFalse(req.has_header("Unredirected-spam")) - def test_password_manager(self): mgr = urllib.request.HTTPPasswordMgr() add = mgr.add_password @@ -234,43 +237,60 @@ class RequestHdrsTests(unittest.TestCase): class MockOpener: addheaders = [] + def open(self, req, data=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): self.req, self.data, self.timeout = req, data, timeout + def error(self, proto, *args): self.proto, self.args = proto, args + class MockFile: - def read(self, count=None): pass - def readline(self, count=None): pass - def close(self): pass + def read(self, count=None): + pass + + def readline(self, count=None): + pass + + def close(self): + pass + class MockHeaders(dict): def getheaders(self, name): return list(self.values()) + class MockResponse(io.StringIO): def __init__(self, code, msg, headers, data, url=None): io.StringIO.__init__(self, data) self.code, self.msg, self.headers, self.url = code, msg, headers, url + def info(self): return self.headers + def geturl(self): return self.url + class MockCookieJar: def add_cookie_header(self, request): self.ach_req = request + def extract_cookies(self, response, request): self.ec_req, self.ec_r = request, response + class FakeMethod: def __init__(self, meth_name, action, handle): self.meth_name = meth_name self.handle = handle self.action = action + def __call__(self, *args): return self.handle(self.meth_name, self.action, *args) + class MockHTTPResponse(io.IOBase): def __init__(self, fp, msg, status, reason): self.fp = fp @@ -324,24 +344,31 @@ class MockHTTPClass: self.data = body if self.raise_on_endheaders: raise OSError() + def getresponse(self): return MockHTTPResponse(MockFile(), {}, 200, "OK") def close(self): pass + class MockHandler: # useful for testing handler machinery # see add_ordered_mock_handlers() docstring handler_order = 500 + def __init__(self, methods): self._define_methods(methods) + def _define_methods(self, methods): for spec in methods: - if len(spec) == 2: name, action = spec - else: name, action = spec, None + if len(spec) == 2: + name, action = spec + else: + name, action = spec, None meth = FakeMethod(name, action, self.handle) setattr(self.__class__, name, meth) + def handle(self, fn_name, action, *args, **kwds): self.parent.calls.append((self, fn_name, args, kwds)) if action is None: @@ -364,16 +391,21 @@ class MockHandler: elif action == "raise": raise urllib.error.URLError("blah") assert False - def close(self): pass + + def close(self): + pass + def add_parent(self, parent): self.parent = parent self.parent.calls = [] + def __lt__(self, other): if not hasattr(other, "handler_order"): # No handler_order, leave in original order. Yuck. return True return self.handler_order < other.handler_order + def add_ordered_mock_handlers(opener, meth_spec): """Create MockHandlers and add them to an OpenerDirector. @@ -396,7 +428,9 @@ def add_ordered_mock_handlers(opener, meth_spec): handlers = [] count = 0 for meths in meth_spec: - class MockHandlerSubclass(MockHandler): pass + class MockHandlerSubclass(MockHandler): + pass + h = MockHandlerSubclass(meths) h.handler_order += count h.add_parent(opener) @@ -405,12 +439,14 @@ def add_ordered_mock_handlers(opener, meth_spec): opener.add_handler(h) return handlers + def build_test_opener(*handler_instances): opener = OpenerDirector() for h in handler_instances: opener.add_handler(h) return opener + class MockHTTPHandler(urllib.request.BaseHandler): # useful for testing redirections and auth # sends supplied headers and code as first response @@ -419,9 +455,11 @@ class MockHTTPHandler(urllib.request.BaseHandler): self.code = code self.headers = headers self.reset() + def reset(self): self._count = 0 self.requests = [] + def http_open(self, req): import email, http.client, copy self.requests.append(copy.deepcopy(req)) @@ -436,6 +474,7 @@ class MockHTTPHandler(urllib.request.BaseHandler): msg = email.message_from_string("\r\n\r\n") return MockResponse(200, "OK", msg, "", req.get_full_url()) + class MockHTTPSHandler(urllib.request.AbstractHTTPHandler): # Useful for testing the Proxy-Authorization request by verifying the # properties of httpcon @@ -447,12 +486,33 @@ class MockHTTPSHandler(urllib.request.AbstractHTTPHandler): def https_open(self, req): return self.do_open(self.httpconn, req) + +class MockHTTPHandlerCheckAuth(urllib.request.BaseHandler): + # useful for testing auth + # sends supplied code response + # checks if auth header is specified in request + def __init__(self, code): + self.code = code + self.has_auth_header = False + + def reset(self): + self.has_auth_header = False + + def http_open(self, req): + if req.has_header('Authorization'): + self.has_auth_header = True + name = http.client.responses[self.code] + return MockResponse(self.code, name, MockFile(), "", req.get_full_url()) + + + class MockPasswordManager: def add_password(self, realm, uri, user, password): self.realm = realm self.url = uri self.user = user self.password = password + def find_user_password(self, realm, authuri): self.target_realm = realm self.target_url = authuri @@ -517,11 +577,11 @@ class OpenerDirectorTests(unittest.TestCase): def test_handler_order(self): o = OpenerDirector() handlers = [] - for meths, handler_order in [ - ([("http_open", "return self")], 500), - (["http_open"], 0), - ]: - class MockHandlerSubclass(MockHandler): pass + for meths, handler_order in [([("http_open", "return self")], 500), + (["http_open"], 0)]: + class MockHandlerSubclass(MockHandler): + pass + h = MockHandlerSubclass(meths) h.handler_order = handler_order handlers.append(h) @@ -559,7 +619,8 @@ class OpenerDirectorTests(unittest.TestCase): handlers = add_ordered_mock_handlers(o, meth_spec) class Unknown: - def __eq__(self, other): return True + def __eq__(self, other): + return True req = Request("http://example.com/") o.open(req) @@ -572,7 +633,6 @@ class OpenerDirectorTests(unittest.TestCase): self.assertEqual((handler, method_name), got[:2]) self.assertEqual(args, got[2]) - def test_processors(self): # *_request / *_response methods get called appropriately o = OpenerDirector() @@ -608,6 +668,7 @@ class OpenerDirectorTests(unittest.TestCase): if args[1] is not None: self.assertIsInstance(args[1], MockResponse) + def sanepathname2url(path): try: path.encode("utf-8") @@ -619,18 +680,25 @@ def sanepathname2url(path): # XXX don't ask me about the mac... return urlpath + class HandlerTests(unittest.TestCase): def test_ftp(self): class MockFTPWrapper: - def __init__(self, data): self.data = data + def __init__(self, data): + self.data = data + def retrfile(self, filename, filetype): self.filename, self.filetype = filename, filetype return io.StringIO(self.data), len(self.data) - def close(self): pass + + def close(self): + pass class NullFTPHandler(urllib.request.FTPHandler): - def __init__(self, data): self.data = data + def __init__(self, data): + self.data = data + def connect_ftp(self, user, passwd, host, port, dirs, timeout=socket._GLOBAL_DEFAULT_TIMEOUT): self.user, self.passwd = user, passwd @@ -868,7 +936,7 @@ class HandlerTests(unittest.TestCase): self.assertRaises(ValueError, h.do_request_, req) else: newreq = h.do_request_(req) - self.assertEqual(int(newreq.get_header('Content-length')),30) + self.assertEqual(int(newreq.get_header('Content-length')), 30) file_obj.close() @@ -901,12 +969,12 @@ class HandlerTests(unittest.TestCase): # Check whether host is determined correctly if there is no proxy np_ds_req = h.do_request_(ds_req) - self.assertEqual(np_ds_req.unredirected_hdrs["Host"],"example.com") + self.assertEqual(np_ds_req.unredirected_hdrs["Host"], "example.com") # Check whether host is determined correctly if there is a proxy - ds_req.set_proxy("someproxy:3128",None) + ds_req.set_proxy("someproxy:3128", None) p_ds_req = h.do_request_(ds_req) - self.assertEqual(p_ds_req.unredirected_hdrs["Host"],"example.com") + self.assertEqual(p_ds_req.unredirected_hdrs["Host"], "example.com") def test_full_url_setter(self): # Checks to ensure that components are set correctly after setting the @@ -948,15 +1016,14 @@ class HandlerTests(unittest.TestCase): weird_url = 'http://www.python.org?getspam' req = Request(weird_url) newreq = h.do_request_(req) - self.assertEqual(newreq.host,'www.python.org') - self.assertEqual(newreq.selector,'/?getspam') + self.assertEqual(newreq.host, 'www.python.org') + self.assertEqual(newreq.selector, '/?getspam') url_without_path = 'http://www.python.org' req = Request(url_without_path) newreq = h.do_request_(req) - self.assertEqual(newreq.host,'www.python.org') - self.assertEqual(newreq.selector,'') - + self.assertEqual(newreq.host, 'www.python.org') + self.assertEqual(newreq.selector, '') def test_errors(self): h = urllib.request.HTTPErrorProcessor() @@ -1043,6 +1110,7 @@ class HandlerTests(unittest.TestCase): # loop detection req = Request(from_url) req.timeout = socket._GLOBAL_DEFAULT_TIMEOUT + def redirect(h, req, url=to_url): h.http_error_302(req, MockFile(), 302, "Blah", MockHeaders({"location": url})) @@ -1073,7 +1141,6 @@ class HandlerTests(unittest.TestCase): self.assertEqual(count, urllib.request.HTTPRedirectHandler.max_redirections) - def test_invalid_redirect(self): from_url = "http://example.com/a.html" valid_schemes = ['http','https','ftp'] @@ -1176,7 +1243,6 @@ class HandlerTests(unittest.TestCase): self.assertEqual(req.host, "www.python.org") del os.environ['no_proxy'] - def test_proxy_https(self): o = OpenerDirector() ph = urllib.request.ProxyHandler(dict(https="proxy.example.com:3128")) @@ -1200,21 +1266,21 @@ class HandlerTests(unittest.TestCase): https_handler = MockHTTPSHandler() o.add_handler(https_handler) req = Request("https://www.example.com/") - req.add_header("Proxy-Authorization","FooBar") - req.add_header("User-Agent","Grail") + req.add_header("Proxy-Authorization", "FooBar") + req.add_header("User-Agent", "Grail") self.assertEqual(req.host, "www.example.com") self.assertIsNone(req._tunnel_host) o.open(req) # Verify Proxy-Authorization gets tunneled to request. # httpsconn req_headers do not have the Proxy-Authorization header but # the req will have. - self.assertNotIn(("Proxy-Authorization","FooBar"), + self.assertNotIn(("Proxy-Authorization", "FooBar"), https_handler.httpconn.req_headers) - self.assertIn(("User-Agent","Grail"), + self.assertIn(("User-Agent", "Grail"), https_handler.httpconn.req_headers) self.assertIsNotNone(req._tunnel_host) self.assertEqual(req.host, "proxy.example.com:3128") - self.assertEqual(req.get_header("Proxy-authorization"),"FooBar") + self.assertEqual(req.get_header("Proxy-authorization"), "FooBar") # TODO: This should be only for OSX @unittest.skipUnless(sys.platform == 'darwin', "only relevant for OSX") @@ -1246,7 +1312,7 @@ class HandlerTests(unittest.TestCase): realm = "ACME Widget Store" http_handler = MockHTTPHandler( 401, 'WWW-Authenticate: Basic realm=%s%s%s\r\n\r\n' % - (quote_char, realm, quote_char) ) + (quote_char, realm, quote_char)) opener.add_handler(auth_handler) opener.add_handler(http_handler) self._test_basic_auth(opener, auth_handler, "Authorization", @@ -1304,13 +1370,16 @@ class HandlerTests(unittest.TestCase): def __init__(self): OpenerDirector.__init__(self) self.recorded = [] + def record(self, info): self.recorded.append(info) + class TestDigestAuthHandler(urllib.request.HTTPDigestAuthHandler): def http_error_401(self, *args, **kwds): self.parent.record("digest") urllib.request.HTTPDigestAuthHandler.http_error_401(self, *args, **kwds) + class TestBasicAuthHandler(urllib.request.HTTPBasicAuthHandler): def http_error_401(self, *args, **kwds): self.parent.record("basic") @@ -1346,7 +1415,7 @@ class HandlerTests(unittest.TestCase): 401, 'WWW-Authenticate: Kerberos\r\n\r\n') opener.add_handler(digest_auth_handler) opener.add_handler(http_handler) - self.assertRaises(ValueError,opener.open,"http://www.example.com") + self.assertRaises(ValueError, opener.open, "http://www.example.com") def test_unsupported_auth_basic_handler(self): # While using BasicAuthHandler @@ -1356,7 +1425,7 @@ class HandlerTests(unittest.TestCase): 401, 'WWW-Authenticate: NTLM\r\n\r\n') opener.add_handler(basic_auth_handler) opener.add_handler(http_handler) - self.assertRaises(ValueError,opener.open,"http://www.example.com") + self.assertRaises(ValueError, opener.open, "http://www.example.com") def _test_basic_auth(self, opener, auth_handler, auth_header, realm, http_handler, password_manager, @@ -1395,6 +1464,72 @@ class HandlerTests(unittest.TestCase): self.assertEqual(len(http_handler.requests), 1) self.assertFalse(http_handler.requests[0].has_header(auth_header)) + def test_basic_prior_auth_auto_send(self): + # Assume already authenticated if is_authenticated=True + # for APIs like Github that don't return 401 + + user, password = "wile", "coyote" + request_url = "http://acme.example.com/protected" + + http_handler = MockHTTPHandlerCheckAuth(200) + + pwd_manager = HTTPPasswordMgrWithPriorAuth() + auth_prior_handler = HTTPBasicAuthHandler(pwd_manager) + auth_prior_handler.add_password( + None, request_url, user, password, is_authenticated=True) + + is_auth = pwd_manager.is_authenticated(request_url) + self.assertTrue(is_auth) + + opener = OpenerDirector() + opener.add_handler(auth_prior_handler) + opener.add_handler(http_handler) + + opener.open(request_url) + + # expect request to be sent with auth header + self.assertTrue(http_handler.has_auth_header) + + def test_basic_prior_auth_send_after_first_success(self): + # Auto send auth header after authentication is successful once + + user, password = 'wile', 'coyote' + request_url = 'http://acme.example.com/protected' + realm = 'ACME' + + pwd_manager = HTTPPasswordMgrWithPriorAuth() + auth_prior_handler = HTTPBasicAuthHandler(pwd_manager) + auth_prior_handler.add_password(realm, request_url, user, password) + + is_auth = pwd_manager.is_authenticated(request_url) + self.assertFalse(is_auth) + + opener = OpenerDirector() + opener.add_handler(auth_prior_handler) + + http_handler = MockHTTPHandler( + 401, 'WWW-Authenticate: Basic realm="%s"\r\n\r\n' % None) + opener.add_handler(http_handler) + + opener.open(request_url) + + is_auth = pwd_manager.is_authenticated(request_url) + self.assertTrue(is_auth) + + http_handler = MockHTTPHandlerCheckAuth(200) + self.assertFalse(http_handler.has_auth_header) + + opener = OpenerDirector() + opener.add_handler(auth_prior_handler) + opener.add_handler(http_handler) + + # After getting 200 from MockHTTPHandler + # Next request sends header in the first request + opener.open(request_url) + + # expect request to be sent with auth header + self.assertTrue(http_handler.has_auth_header) + def test_http_closed(self): """Test the connection is cleaned up when the response is closed""" for (transfer, data) in ( @@ -1423,6 +1558,7 @@ class HandlerTests(unittest.TestCase): self.assertTrue(conn.fakesock.closed, "Connection not closed") + class MiscTests(unittest.TestCase): def opener_has_handler(self, opener, handler_class): @@ -1430,11 +1566,16 @@ class MiscTests(unittest.TestCase): for h in opener.handlers)) def test_build_opener(self): - class MyHTTPHandler(urllib.request.HTTPHandler): pass + class MyHTTPHandler(urllib.request.HTTPHandler): + pass + class FooHandler(urllib.request.BaseHandler): - def foo_open(self): pass + def foo_open(self): + pass + class BarHandler(urllib.request.BaseHandler): - def bar_open(self): pass + def bar_open(self): + pass build_opener = urllib.request.build_opener @@ -1461,7 +1602,9 @@ class MiscTests(unittest.TestCase): self.opener_has_handler(o, urllib.request.HTTPHandler) # Issue2670: multiple handlers sharing the same base class - class MyOtherHTTPHandler(urllib.request.HTTPHandler): pass + class MyOtherHTTPHandler(urllib.request.HTTPHandler): + pass + o = build_opener(MyHTTPHandler, MyOtherHTTPHandler) self.opener_has_handler(o, MyHTTPHandler) self.opener_has_handler(o, MyOtherHTTPHandler) @@ -1497,6 +1640,8 @@ class MiscTests(unittest.TestCase): self.assertEqual(err.headers, 'Content-Length: 42') expected_errmsg = 'HTTP Error %s: %s' % (err.code, err.msg) self.assertEqual(str(err), expected_errmsg) + expected_errmsg = '<HTTPError %s: %r>' % (err.code, err.msg) + self.assertEqual(repr(err), expected_errmsg) def test_parse_proxy(self): parse_proxy_test_cases = [ @@ -1535,9 +1680,10 @@ class MiscTests(unittest.TestCase): self.assertRaises(ValueError, _parse_proxy, 'file:/ftp.example.com'), + class RequestTests(unittest.TestCase): class PutRequest(Request): - method='PUT' + method = 'PUT' def setUp(self): self.get = Request("http://www.python.org/~jeremy/") @@ -1626,7 +1772,7 @@ class RequestTests(unittest.TestCase): def test_url_fullurl_get_full_url(self): urls = ['http://docs.python.org', 'http://docs.python.org/library/urllib2.html#OK', - 'http://www.python.org/?qs=query#fragment=true' ] + 'http://www.python.org/?qs=query#fragment=true'] for url in urls: req = Request(url) self.assertEqual(req.get_full_url(), req.full_url) diff --git a/Lib/test/test_urllib2net.py b/Lib/test/test_urllib2net.py index 17f9d1b..cad83fd 100644 --- a/Lib/test/test_urllib2net.py +++ b/Lib/test/test_urllib2net.py @@ -20,8 +20,6 @@ def _retry_thrice(func, exc, *args, **kwargs): except exc as e: last_exc = e continue - except: - raise raise last_exc def _wrap_with_retry_thrice(func, exc): diff --git a/Lib/test/test_urlparse.py b/Lib/test/test_urlparse.py index 1775ef3..0552f90 100644 --- a/Lib/test/test_urlparse.py +++ b/Lib/test/test_urlparse.py @@ -210,10 +210,6 @@ class UrlParseTestCase(unittest.TestCase): # "abnormal" cases from RFC 1808: self.checkJoin(RFC1808_BASE, '', 'http://a/b/c/d;p?q#f') - self.checkJoin(RFC1808_BASE, '../../../g', 'http://a/../g') - self.checkJoin(RFC1808_BASE, '../../../../g', 'http://a/../../g') - self.checkJoin(RFC1808_BASE, '/./g', 'http://a/./g') - self.checkJoin(RFC1808_BASE, '/../g', 'http://a/../g') self.checkJoin(RFC1808_BASE, 'g.', 'http://a/b/c/g.') self.checkJoin(RFC1808_BASE, '.g', 'http://a/b/c/.g') self.checkJoin(RFC1808_BASE, 'g..', 'http://a/b/c/g..') @@ -228,6 +224,13 @@ class UrlParseTestCase(unittest.TestCase): #self.checkJoin(RFC1808_BASE, 'http:g', 'http:g') #self.checkJoin(RFC1808_BASE, 'http:', 'http:') + # XXX: The following tests are no longer compatible with RFC3986 + # self.checkJoin(RFC1808_BASE, '../../../g', 'http://a/../g') + # self.checkJoin(RFC1808_BASE, '../../../../g', 'http://a/../../g') + # self.checkJoin(RFC1808_BASE, '/./g', 'http://a/./g') + # self.checkJoin(RFC1808_BASE, '/../g', 'http://a/../g') + + def test_RFC2368(self): # Issue 11467: path that starts with a number is not parsed correctly self.assertEqual(urllib.parse.urlparse('mailto:1337@example.org'), @@ -258,10 +261,6 @@ class UrlParseTestCase(unittest.TestCase): self.checkJoin(RFC2396_BASE, '../../', 'http://a/') self.checkJoin(RFC2396_BASE, '../../g', 'http://a/g') self.checkJoin(RFC2396_BASE, '', RFC2396_BASE) - self.checkJoin(RFC2396_BASE, '../../../g', 'http://a/../g') - self.checkJoin(RFC2396_BASE, '../../../../g', 'http://a/../../g') - self.checkJoin(RFC2396_BASE, '/./g', 'http://a/./g') - self.checkJoin(RFC2396_BASE, '/../g', 'http://a/../g') self.checkJoin(RFC2396_BASE, 'g.', 'http://a/b/c/g.') self.checkJoin(RFC2396_BASE, '.g', 'http://a/b/c/.g') self.checkJoin(RFC2396_BASE, 'g..', 'http://a/b/c/g..') @@ -277,10 +276,17 @@ class UrlParseTestCase(unittest.TestCase): self.checkJoin(RFC2396_BASE, 'g#s/./x', 'http://a/b/c/g#s/./x') self.checkJoin(RFC2396_BASE, 'g#s/../x', 'http://a/b/c/g#s/../x') + # XXX: The following tests are no longer compatible with RFC3986 + # self.checkJoin(RFC2396_BASE, '../../../g', 'http://a/../g') + # self.checkJoin(RFC2396_BASE, '../../../../g', 'http://a/../../g') + # self.checkJoin(RFC2396_BASE, '/./g', 'http://a/./g') + # self.checkJoin(RFC2396_BASE, '/../g', 'http://a/../g') + + def test_RFC3986(self): # Test cases from RFC3986 self.checkJoin(RFC3986_BASE, '?y','http://a/b/c/d;p?y') - self.checkJoin(RFC2396_BASE, ';x', 'http://a/b/c/;x') + self.checkJoin(RFC3986_BASE, ';x', 'http://a/b/c/;x') self.checkJoin(RFC3986_BASE, 'g:h','g:h') self.checkJoin(RFC3986_BASE, 'g','http://a/b/c/g') self.checkJoin(RFC3986_BASE, './g','http://a/b/c/g') @@ -304,17 +310,17 @@ class UrlParseTestCase(unittest.TestCase): self.checkJoin(RFC3986_BASE, '../..','http://a/') self.checkJoin(RFC3986_BASE, '../../','http://a/') self.checkJoin(RFC3986_BASE, '../../g','http://a/g') + self.checkJoin(RFC3986_BASE, '../../../g', 'http://a/g') #Abnormal Examples # The 'abnormal scenarios' are incompatible with RFC2986 parsing # Tests are here for reference. - #self.checkJoin(RFC3986_BASE, '../../../g','http://a/g') - #self.checkJoin(RFC3986_BASE, '../../../../g','http://a/g') - #self.checkJoin(RFC3986_BASE, '/./g','http://a/g') - #self.checkJoin(RFC3986_BASE, '/../g','http://a/g') - + self.checkJoin(RFC3986_BASE, '../../../g','http://a/g') + self.checkJoin(RFC3986_BASE, '../../../../g','http://a/g') + self.checkJoin(RFC3986_BASE, '/./g','http://a/g') + self.checkJoin(RFC3986_BASE, '/../g','http://a/g') self.checkJoin(RFC3986_BASE, 'g.','http://a/b/c/g.') self.checkJoin(RFC3986_BASE, '.g','http://a/b/c/.g') self.checkJoin(RFC3986_BASE, 'g..','http://a/b/c/g..') @@ -354,10 +360,8 @@ class UrlParseTestCase(unittest.TestCase): self.checkJoin(SIMPLE_BASE, '../g','http://a/b/g') self.checkJoin(SIMPLE_BASE, '../..','http://a/') self.checkJoin(SIMPLE_BASE, '../../g','http://a/g') - self.checkJoin(SIMPLE_BASE, '../../../g','http://a/../g') self.checkJoin(SIMPLE_BASE, './../g','http://a/b/g') self.checkJoin(SIMPLE_BASE, './g/.','http://a/b/c/g/') - self.checkJoin(SIMPLE_BASE, '/./g','http://a/./g') self.checkJoin(SIMPLE_BASE, 'g/./h','http://a/b/c/g/h') self.checkJoin(SIMPLE_BASE, 'g/../h','http://a/b/c/h') self.checkJoin(SIMPLE_BASE, 'http:g','http://a/b/c/g') @@ -371,6 +375,25 @@ class UrlParseTestCase(unittest.TestCase): self.checkJoin('svn://pathtorepo/dir1', 'dir2', 'svn://pathtorepo/dir2') self.checkJoin('svn+ssh://pathtorepo/dir1', 'dir2', 'svn+ssh://pathtorepo/dir2') + # XXX: The following tests are no longer compatible with RFC3986 + # self.checkJoin(SIMPLE_BASE, '../../../g','http://a/../g') + # self.checkJoin(SIMPLE_BASE, '/./g','http://a/./g') + + # test for issue22118 duplicate slashes + self.checkJoin(SIMPLE_BASE + '/', 'foo', SIMPLE_BASE + '/foo') + + # Non-RFC-defined tests, covering variations of base and trailing + # slashes + self.checkJoin('http://a/b/c/d/e/', '../../f/g/', 'http://a/b/c/f/g/') + self.checkJoin('http://a/b/c/d/e', '../../f/g/', 'http://a/b/f/g/') + self.checkJoin('http://a/b/c/d/e/', '/../../f/g/', 'http://a/f/g/') + self.checkJoin('http://a/b/c/d/e', '/../../f/g/', 'http://a/f/g/') + self.checkJoin('http://a/b/c/d/e/', '../../f/g', 'http://a/b/c/f/g') + self.checkJoin('http://a/b/', '../../f/g/', 'http://a/f/g/') + + # issue 23703: don't duplicate filename + self.checkJoin('a', 'b', 'b') + def test_RFC2732(self): str_cases = [ ('http://Test.python.org:5432/foo/', 'test.python.org', 5432), @@ -803,6 +826,16 @@ class UrlParseTestCase(unittest.TestCase): result = urllib.parse.urlencode({'a': Trivial()}, True) self.assertEqual(result, 'a=trivial') + def test_urlencode_quote_via(self): + result = urllib.parse.urlencode({'a': 'some value'}) + self.assertEqual(result, "a=some+value") + result = urllib.parse.urlencode({'a': 'some value/another'}, + quote_via=urllib.parse.quote) + self.assertEqual(result, "a=some%20value%2Fanother") + result = urllib.parse.urlencode({'a': 'some value/another'}, + safe='/', quote_via=urllib.parse.quote) + self.assertEqual(result, "a=some%20value/another") + def test_quote_from_bytes(self): self.assertRaises(TypeError, urllib.parse.quote_from_bytes, 'foo') result = urllib.parse.quote_from_bytes(b'archaeological arcana') @@ -861,6 +894,22 @@ class UrlParseTestCase(unittest.TestCase): quoter = urllib.parse.Quoter(urllib.parse._ALWAYS_SAFE) self.assertIn('Quoter', repr(quoter)) + def test_all(self): + expected = [] + undocumented = { + 'splitattr', 'splithost', 'splitnport', 'splitpasswd', + 'splitport', 'splitquery', 'splittag', 'splittype', 'splituser', + 'splitvalue', + 'Quoter', 'ResultBase', 'clear_cache', 'to_bytes', 'unwrap', + } + for name in dir(urllib.parse): + if name.startswith('_') or name in undocumented: + continue + object = getattr(urllib.parse, name) + if getattr(object, '__module__', None) == 'urllib.parse': + expected.append(name) + self.assertCountEqual(urllib.parse.__all__, expected) + class Utility_Tests(unittest.TestCase): """Testcase to test the various utility functions in the urllib.""" diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py index e7fee55..68a582c 100644 --- a/Lib/test/test_userdict.py +++ b/Lib/test/test_userdict.py @@ -215,10 +215,5 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol): -def test_main(): - support.run_unittest( - UserDictTest, - ) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_userlist.py b/Lib/test/test_userlist.py index 6381070..4a304df 100644 --- a/Lib/test/test_userlist.py +++ b/Lib/test/test_userlist.py @@ -58,8 +58,5 @@ class UserListTest(list_tests.CommonTest): self.assertEqual(u, v) self.assertEqual(type(u), type(v)) -def test_main(): - support.run_unittest(UserListTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_uuid.py b/Lib/test/test_uuid.py index 1e8cba3..fcb8454 100644 --- a/Lib/test/test_uuid.py +++ b/Lib/test/test_uuid.py @@ -1,9 +1,10 @@ -import unittest +import unittest.mock from test import support import builtins import io import os import shutil +import subprocess import uuid def importable(name): @@ -412,28 +413,27 @@ class TestUUID(unittest.TestCase): class TestInternals(unittest.TestCase): @unittest.skipUnless(os.name == 'posix', 'requires Posix') def test_find_mac(self): - data = '''\ - + data = ''' fake hwaddr cscotun0 Link encap:UNSPEC HWaddr 00-00-00-00-00-00-00-00-00-00-00-00-00-00-00-00 eth0 Link encap:Ethernet HWaddr 12:34:56:78:90:ab ''' - def mock_popen(cmd): - return io.StringIO(data) - - if shutil.which('ifconfig') is None: - path = os.pathsep.join(('/sbin', '/usr/sbin')) - if shutil.which('ifconfig', path=path) is None: - self.skipTest('requires ifconfig') - - with support.swap_attr(os, 'popen', mock_popen): - mac = uuid._find_mac( - command='ifconfig', - args='', - hw_identifiers=['hwaddr'], - get_index=lambda x: x + 1, - ) - self.assertEqual(mac, 0x1234567890ab) + + popen = unittest.mock.MagicMock() + popen.stdout = io.BytesIO(data.encode()) + + with unittest.mock.patch.object(shutil, 'which', + return_value='/sbin/ifconfig'): + with unittest.mock.patch.object(subprocess, 'Popen', + return_value=popen): + mac = uuid._find_mac( + command='ifconfig', + args='', + hw_identifiers=[b'hwaddr'], + get_index=lambda x: x + 1, + ) + + self.assertEqual(mac, 0x1234567890ab) def check_node(self, node, requires=None, network=False): if requires and node is None: @@ -454,6 +454,11 @@ eth0 Link encap:Ethernet HWaddr 12:34:56:78:90:ab self.check_node(node, 'ifconfig', True) @unittest.skipUnless(os.name == 'posix', 'requires Posix') + def test_ip_getnode(self): + node = uuid._ip_getnode() + self.check_node(node, 'ip', True) + + @unittest.skipUnless(os.name == 'posix', 'requires Posix') def test_arp_getnode(self): node = uuid._arp_getnode() self.check_node(node, 'arp', True) diff --git a/Lib/test/test_venv.py b/Lib/test/test_venv.py index b462588..9207a68 100644 --- a/Lib/test/test_venv.py +++ b/Lib/test/test_venv.py @@ -12,7 +12,7 @@ import struct import subprocess import sys import tempfile -from test.support import (captured_stdout, captured_stderr, run_unittest, +from test.support import (captured_stdout, captured_stderr, can_symlink, EnvironmentVarGuard, rmtree) import textwrap import unittest @@ -398,8 +398,5 @@ class EnsurePipTest(BaseTest): self.assert_pip_not_installed() -def test_main(): - run_unittest(BasicTest, EnsurePipTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_wait3.py b/Lib/test/test_wait3.py index f6a065d..eb51b2c 100644 --- a/Lib/test/test_wait3.py +++ b/Lib/test/test_wait3.py @@ -5,7 +5,7 @@ import os import time import unittest from test.fork_wait import ForkWait -from test.support import run_unittest, reap_children +from test.support import reap_children if not hasattr(os, 'fork'): raise unittest.SkipTest("os.fork not defined") @@ -18,7 +18,8 @@ class Wait3Test(ForkWait): # This many iterations can be required, since some previously run # tests (e.g. test_ctypes) could have spawned a lot of children # very quickly. - for i in range(30): + deadline = time.monotonic() + 10.0 + while time.monotonic() <= deadline: # wait3() shouldn't hang, but some of the buildbots seem to hang # in the forking tests. This is an attempt to fix the problem. spid, status, rusage = os.wait3(os.WNOHANG) @@ -30,9 +31,8 @@ class Wait3Test(ForkWait): self.assertEqual(status, 0, "cause = %d, exit = %d" % (status&0xff, status>>8)) self.assertTrue(rusage) -def test_main(): - run_unittest(Wait3Test) +def tearDownModule(): reap_children() if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_wait4.py b/Lib/test/test_wait4.py index 352c11a..43869be 100644 --- a/Lib/test/test_wait4.py +++ b/Lib/test/test_wait4.py @@ -5,7 +5,7 @@ import os import time import sys from test.fork_wait import ForkWait -from test.support import run_unittest, reap_children, get_attribute +from test.support import reap_children, get_attribute # If either of these do not exist, skip this test. get_attribute(os, 'fork') @@ -19,20 +19,20 @@ class Wait4Test(ForkWait): # Issue #11185: wait4 is broken on AIX and will always return 0 # with WNOHANG. option = 0 - for i in range(10): + deadline = time.monotonic() + 10.0 + while time.monotonic() <= deadline: # wait4() shouldn't hang, but some of the buildbots seem to hang # in the forking tests. This is an attempt to fix the problem. spid, status, rusage = os.wait4(cpid, option) if spid == cpid: break - time.sleep(1.0) + time.sleep(0.1) self.assertEqual(spid, cpid) self.assertEqual(status, 0, "cause = %d, exit = %d" % (status&0xff, status>>8)) self.assertTrue(rusage) -def test_main(): - run_unittest(Wait4Test) +def tearDownModule(): reap_children() if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_warnings.py b/Lib/test/test_warnings/__init__.py index b519f0a..cea9c57 100644 --- a/Lib/test/test_warnings.py +++ b/Lib/test/test_warnings/__init__.py @@ -5,9 +5,9 @@ from io import StringIO import sys import unittest from test import support -from test.script_helper import assert_python_ok +from test.support.script_helper import assert_python_ok, assert_python_failure -from test import warning_tests +from test.test_warnings.data import stacklevel as warning_tests import warnings as original_warnings @@ -194,11 +194,11 @@ class FilterTests(BaseTest): self.module.resetwarnings() self.module.filterwarnings("once", category=UserWarning) message = UserWarning("FilterTests.test_once") - self.module.warn_explicit(message, UserWarning, "test_warnings.py", + self.module.warn_explicit(message, UserWarning, "__init__.py", 42) self.assertEqual(w[-1].message, message) del w[:] - self.module.warn_explicit(message, UserWarning, "test_warnings.py", + self.module.warn_explicit(message, UserWarning, "__init__.py", 13) self.assertEqual(len(w), 0) self.module.warn_explicit(message, UserWarning, "test_warnings2.py", @@ -304,10 +304,10 @@ class WarnTests(BaseTest): module=self.module) as w: warning_tests.inner("spam1") self.assertEqual(os.path.basename(w[-1].filename), - "warning_tests.py") + "stacklevel.py") warning_tests.outer("spam2") self.assertEqual(os.path.basename(w[-1].filename), - "warning_tests.py") + "stacklevel.py") def test_stacklevel(self): # Test stacklevel argument @@ -317,25 +317,36 @@ class WarnTests(BaseTest): module=self.module) as w: warning_tests.inner("spam3", stacklevel=1) self.assertEqual(os.path.basename(w[-1].filename), - "warning_tests.py") + "stacklevel.py") warning_tests.outer("spam4", stacklevel=1) self.assertEqual(os.path.basename(w[-1].filename), - "warning_tests.py") + "stacklevel.py") warning_tests.inner("spam5", stacklevel=2) self.assertEqual(os.path.basename(w[-1].filename), - "test_warnings.py") + "__init__.py") warning_tests.outer("spam6", stacklevel=2) self.assertEqual(os.path.basename(w[-1].filename), - "warning_tests.py") + "stacklevel.py") warning_tests.outer("spam6.5", stacklevel=3) self.assertEqual(os.path.basename(w[-1].filename), - "test_warnings.py") + "__init__.py") warning_tests.inner("spam7", stacklevel=9999) self.assertEqual(os.path.basename(w[-1].filename), "sys") + def test_stacklevel_import(self): + # Issue #24305: With stacklevel=2, module-level warnings should work. + support.unload('test.test_warnings.data.import_warning') + with warnings_state(self.module): + with original_warnings.catch_warnings(record=True, + module=self.module) as w: + self.module.simplefilter('always') + import test.test_warnings.data.import_warning + self.assertEqual(len(w), 1) + self.assertEqual(w[0].filename, __file__) + def test_missing_filename_not_main(self): # If __file__ is not specified and __main__ is not the module name, # then __file__ should be set to the module name. @@ -450,6 +461,44 @@ class WarnTests(BaseTest): with self.assertRaises(ValueError): self.module.warn(BadStrWarning()) + def test_warning_classes(self): + class MyWarningClass(Warning): + pass + + class NonWarningSubclass: + pass + + # passing a non-subclass of Warning should raise a TypeError + with self.assertRaises(TypeError) as cm: + self.module.warn('bad warning category', '') + self.assertIn('category must be a Warning subclass, not ', + str(cm.exception)) + + with self.assertRaises(TypeError) as cm: + self.module.warn('bad warning category', NonWarningSubclass) + self.assertIn('category must be a Warning subclass, not ', + str(cm.exception)) + + # check that warning instances also raise a TypeError + with self.assertRaises(TypeError) as cm: + self.module.warn('bad warning category', MyWarningClass()) + self.assertIn('category must be a Warning subclass, not ', + str(cm.exception)) + + with original_warnings.catch_warnings(module=self.module): + self.module.resetwarnings() + self.module.filterwarnings('default') + with self.assertWarns(MyWarningClass) as cm: + self.module.warn('good warning category', MyWarningClass) + self.assertEqual('good warning category', str(cm.warning)) + + with self.assertWarns(UserWarning) as cm: + self.module.warn('good warning category', None) + self.assertEqual('good warning category', str(cm.warning)) + + with self.assertWarns(MyWarningClass) as cm: + self.module.warn('good warning category', MyWarningClass) + self.assertIsInstance(cm.warning, Warning) class CWarnTests(WarnTests, unittest.TestCase): module = c_warnings @@ -839,7 +888,19 @@ class EnvironmentVariableTests(BaseTest): "import sys; sys.stdout.write(str(sys.warnoptions))", PYTHONWARNINGS="ignore::DeprecationWarning") self.assertEqual(stdout, - b"['ignore::UnicodeWarning', 'ignore::DeprecationWarning']") + b"['ignore::DeprecationWarning', 'ignore::UnicodeWarning']") + + def test_conflicting_envvar_and_command_line(self): + rc, stdout, stderr = assert_python_failure("-Werror::DeprecationWarning", "-c", + "import sys, warnings; sys.stdout.write(str(sys.warnoptions)); " + "warnings.warn('Message', DeprecationWarning)", + PYTHONWARNINGS="default::DeprecationWarning") + self.assertEqual(stdout, + b"['default::DeprecationWarning', 'error::DeprecationWarning']") + self.assertEqual(stderr.splitlines(), + [b"Traceback (most recent call last):", + b" File \"<string>\", line 1, in <module>", + b"DeprecationWarning: Message"]) @unittest.skipUnless(sys.getfilesystemencoding() != 'ascii', 'requires non-ascii filesystemencoding') diff --git a/Lib/test/test_warnings/__main__.py b/Lib/test/test_warnings/__main__.py new file mode 100644 index 0000000..44e52ec --- /dev/null +++ b/Lib/test/test_warnings/__main__.py @@ -0,0 +1,3 @@ +import unittest + +unittest.main('test.test_warnings') diff --git a/Lib/test/test_warnings/data/import_warning.py b/Lib/test/test_warnings/data/import_warning.py new file mode 100644 index 0000000..d6ea2ce --- /dev/null +++ b/Lib/test/test_warnings/data/import_warning.py @@ -0,0 +1,3 @@ +import warnings + +warnings.warn('module-level warning', DeprecationWarning, stacklevel=2)
\ No newline at end of file diff --git a/Lib/test/warning_tests.py b/Lib/test/test_warnings/data/stacklevel.py index d0519ef..d0519ef 100644 --- a/Lib/test/warning_tests.py +++ b/Lib/test/test_warnings/data/stacklevel.py diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index 212cf34..f04e72b 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -7,7 +7,8 @@ import operator import contextlib import copy -from test import support, script_helper +from test import support +from test.support import script_helper # Used in ReferencesTestCase.test_ref_created_during_del() . ref_from_del = None @@ -92,6 +93,18 @@ class ReferencesTestCase(TestBase): self.check_basic_callback(create_function) self.check_basic_callback(create_bound_method) + @support.cpython_only + def test_cfunction(self): + import _testcapi + create_cfunction = _testcapi.create_cfunction + f = create_cfunction() + wr = weakref.ref(f) + self.assertIs(wr(), f) + del f + self.assertIsNone(wr()) + self.check_basic_ref(create_cfunction) + self.check_basic_callback(create_cfunction) + def test_multiple_callbacks(self): o = C() ref1 = weakref.ref(o, self.callback) @@ -1599,6 +1612,14 @@ class MappingTestCase(TestBase): self.assertEqual(len(d), 0) self.assertEqual(count, 2) + def test_make_weak_valued_dict_repr(self): + dict = weakref.WeakValueDictionary() + self.assertRegex(repr(dict), '<WeakValueDictionary at 0x.*>') + + def test_make_weak_keyed_dict_repr(self): + dict = weakref.WeakKeyDictionary() + self.assertRegex(repr(dict), '<WeakKeyDictionary at 0x.*>') + from test import mapping_tests class WeakValueDictionaryTestCase(mapping_tests.BasicTestMappingProtocol): diff --git a/Lib/test/test_weakset.py b/Lib/test/test_weakset.py index fb22879..9ce672b 100644 --- a/Lib/test/test_weakset.py +++ b/Lib/test/test_weakset.py @@ -1,5 +1,4 @@ import unittest -from test import support from weakref import proxy, ref, WeakSet import operator import copy @@ -443,8 +442,5 @@ class TestWeakSet(unittest.TestCase): self.assertLessEqual(n2, n1) -def test_main(verbose=None): - support.run_unittest(TestWeakSet) - if __name__ == "__main__": - test_main(verbose=True) + unittest.main() diff --git a/Lib/test/test_winsound.py b/Lib/test/test_winsound.py index 83618b6..7afb24b 100644 --- a/Lib/test/test_winsound.py +++ b/Lib/test/test_winsound.py @@ -246,8 +246,5 @@ def _have_soundcard(): return __have_soundcard_cache -def test_main(): - support.run_unittest(BeepTest, MessageBeepTest, PlaySoundTest) - -if __name__=="__main__": - test_main() +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py index cbaafcf..e8d789b 100644 --- a/Lib/test/test_with.py +++ b/Lib/test/test_with.py @@ -8,7 +8,6 @@ import sys import unittest from collections import deque from contextlib import _GeneratorContextManager, contextmanager -from test.support import run_unittest class MockContextManager(_GeneratorContextManager): @@ -455,7 +454,8 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase): with cm(): raise StopIteration("from with") - self.assertRaises(StopIteration, shouldThrow) + with self.assertWarnsRegex(PendingDeprecationWarning, "StopIteration"): + self.assertRaises(StopIteration, shouldThrow) def testRaisedStopIteration2(self): # From bug 1462485 @@ -482,7 +482,8 @@ class ExceptionalTestCase(ContextmanagerAssertionMixin, unittest.TestCase): with cm(): raise next(iter([])) - self.assertRaises(StopIteration, shouldThrow) + with self.assertWarnsRegex(PendingDeprecationWarning, "StopIteration"): + self.assertRaises(StopIteration, shouldThrow) def testRaisedGeneratorExit1(self): # From bug 1462485 @@ -737,14 +738,5 @@ class NestedWith(unittest.TestCase): self.assertEqual(10, b1) self.assertEqual(20, b2) -def test_main(): - run_unittest(FailureTestCase, NonexceptionalTestCase, - NestedNonexceptionalTestCase, ExceptionalTestCase, - NonLocalFlowControlTestCase, - AssignmentTargetTestCase, - ExitSwallowsExceptionTestCase, - NestedWith) - - if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py index 4076b68..8cca595 100644 --- a/Lib/test/test_wsgiref.py +++ b/Lib/test/test_wsgiref.py @@ -15,8 +15,6 @@ import re import sys import unittest -from test import support - class MockServer(WSGIServer): """Non-socket HTTP server""" @@ -369,6 +367,7 @@ class HeaderTests(TestCase): def testMappingInterface(self): test = [('x','y')] + self.assertEqual(len(Headers()), 0) self.assertEqual(len(Headers([])),0) self.assertEqual(len(Headers(test[:])),1) self.assertEqual(Headers(test[:]).keys(), ['x']) @@ -376,7 +375,7 @@ class HeaderTests(TestCase): self.assertEqual(Headers(test[:]).items(), test) self.assertIsNot(Headers(test).items(), test) # must be copy! - h=Headers([]) + h = Headers() del h['foo'] # should not raise an error h['Foo'] = 'bar' @@ -401,9 +400,8 @@ class HeaderTests(TestCase): def testRequireList(self): self.assertRaises(TypeError, Headers, "foo") - def testExtras(self): - h = Headers([]) + h = Headers() self.assertEqual(str(h),'\r\n') h.add_header('foo','bar',baz="spam") @@ -659,8 +657,5 @@ class HandlerTests(TestCase): self.assertEqual(side_effects['close_called'], True) -def test_main(): - support.run_unittest(__name__) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_xdrlib.py b/Lib/test/test_xdrlib.py index 70496d6..3df5f26 100644 --- a/Lib/test/test_xdrlib.py +++ b/Lib/test/test_xdrlib.py @@ -1,4 +1,3 @@ -from test import support import unittest import xdrlib @@ -74,9 +73,5 @@ class ConversionErrorTest(unittest.TestCase): def test_uhyper(self): self.assertRaisesConversion(self.packer.pack_uhyper, 'string') -def test_main(): - support.run_unittest(XDRTest) - support.run_unittest(ConversionErrorTest) - if __name__ == "__main__": - test_main() + unittest.main() diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index b87b098..ca8cdf8 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -705,7 +705,7 @@ class ElementTreeTest(unittest.TestCase): 'mac-roman', 'mac-turkish', 'iso2022-jp', 'iso2022-jp-1', 'iso2022-jp-2', 'iso2022-jp-2004', 'iso2022-jp-3', 'iso2022-jp-ext', - 'koi8-r', 'koi8-u', + 'koi8-r', 'koi8-t', 'koi8-u', 'kz1048', 'hz', 'ptcp154', ] for encoding in supported_encodings: diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py index 816aa86..d0df38d 100644 --- a/Lib/test/test_xml_etree_c.py +++ b/Lib/test/test_xml_etree_c.py @@ -55,7 +55,7 @@ class SizeofTest(unittest.TestCase): def setUp(self): self.elementsize = support.calcobjsize('5P') # extra - self.extra = struct.calcsize('PiiP4P') + self.extra = struct.calcsize('PnnP4P') check_sizeof = support.check_sizeof diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py index 7ae0dce..9880f4a 100644 --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -287,7 +287,7 @@ class DateTimeTestCase(unittest.TestCase): def test_repr(self): d = datetime.datetime(2007,1,2,3,4,5) t = xmlrpclib.DateTime(d) - val ="<DateTime '20070102T03:04:05' at %x>" % id(t) + val ="<DateTime '20070102T03:04:05' at %#x>" % id(t) self.assertEqual(repr(t), val) def test_decode(self): @@ -713,6 +713,23 @@ class SimpleServerTestCase(BaseServerTestCase): conn.request('POST', '/RPC2 HTTP/1.0\r\nContent-Length: 100\r\n\r\nbye') conn.close() + def test_context_manager(self): + with xmlrpclib.ServerProxy(URL) as server: + server.add(2, 3) + self.assertNotEqual(server('transport')._connection, + (None, None)) + self.assertEqual(server('transport')._connection, + (None, None)) + + def test_context_manager_method_error(self): + try: + with xmlrpclib.ServerProxy(URL) as server: + server.add(2, "a") + except xmlrpclib.Fault: + pass + self.assertEqual(server('transport')._connection, + (None, None)) + class MultiPathServerTestCase(BaseServerTestCase): threadFunc = staticmethod(http_multi_server) @@ -919,6 +936,7 @@ class ServerProxyTestCase(unittest.TestCase): p = xmlrpclib.ServerProxy(self.url, transport=t) self.assertEqual(p('transport'), t) + # This is a contrived way to make a failure occur on the server side # in order to test the _send_traceback_header flag on the server class FailingMessageClass(http.client.HTTPMessage): diff --git a/Lib/test/test_zipapp.py b/Lib/test/test_zipapp.py new file mode 100644 index 0000000..9734380 --- /dev/null +++ b/Lib/test/test_zipapp.py @@ -0,0 +1,349 @@ +"""Test harness for the zipapp module.""" + +import io +import pathlib +import stat +import sys +import tempfile +import unittest +import zipapp +import zipfile + +from unittest.mock import patch + +class ZipAppTest(unittest.TestCase): + + """Test zipapp module functionality.""" + + def setUp(self): + tmpdir = tempfile.TemporaryDirectory() + self.addCleanup(tmpdir.cleanup) + self.tmpdir = pathlib.Path(tmpdir.name) + + def test_create_archive(self): + # Test packing a directory. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target)) + self.assertTrue(target.is_file()) + + def test_create_archive_with_pathlib(self): + # Test packing a directory using Path objects for source and target. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(source, target) + self.assertTrue(target.is_file()) + + def test_create_archive_with_subdirs(self): + # Test packing a directory includes entries for subdirectories. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + (source / 'foo').mkdir() + (source / 'bar').mkdir() + (source / 'foo' / '__init__.py').touch() + target = io.BytesIO() + zipapp.create_archive(str(source), target) + target.seek(0) + with zipfile.ZipFile(target, 'r') as z: + self.assertIn('foo/', z.namelist()) + self.assertIn('bar/', z.namelist()) + + def test_create_archive_default_target(self): + # Test packing a directory to the default name. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + zipapp.create_archive(str(source)) + expected_target = self.tmpdir / 'source.pyz' + self.assertTrue(expected_target.is_file()) + + def test_no_main(self): + # Test that packing a directory with no __main__.py fails. + source = self.tmpdir / 'source' + source.mkdir() + (source / 'foo.py').touch() + target = self.tmpdir / 'source.pyz' + with self.assertRaises(zipapp.ZipAppError): + zipapp.create_archive(str(source), str(target)) + + def test_main_and_main_py(self): + # Test that supplying a main argument with __main__.py fails. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + with self.assertRaises(zipapp.ZipAppError): + zipapp.create_archive(str(source), str(target), main='pkg.mod:fn') + + def test_main_written(self): + # Test that the __main__.py is written correctly. + source = self.tmpdir / 'source' + source.mkdir() + (source / 'foo.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), main='pkg.mod:fn') + with zipfile.ZipFile(str(target), 'r') as z: + self.assertIn('__main__.py', z.namelist()) + self.assertIn(b'pkg.mod.fn()', z.read('__main__.py')) + + def test_main_only_written_once(self): + # Test that we don't write multiple __main__.py files. + # The initial implementation had this bug; zip files allow + # multiple entries with the same name + source = self.tmpdir / 'source' + source.mkdir() + # Write 2 files, as the original bug wrote __main__.py + # once for each file written :-( + # See http://bugs.python.org/review/23491/diff/13982/Lib/zipapp.py#newcode67Lib/zipapp.py:67 + # (line 67) + (source / 'foo.py').touch() + (source / 'bar.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), main='pkg.mod:fn') + with zipfile.ZipFile(str(target), 'r') as z: + self.assertEqual(1, z.namelist().count('__main__.py')) + + def test_main_validation(self): + # Test that invalid values for main are rejected. + source = self.tmpdir / 'source' + source.mkdir() + target = self.tmpdir / 'source.pyz' + problems = [ + '', 'foo', 'foo:', ':bar', '12:bar', 'a.b.c.:d', + '.a:b', 'a:b.', 'a:.b', 'a:silly name' + ] + for main in problems: + with self.subTest(main=main): + with self.assertRaises(zipapp.ZipAppError): + zipapp.create_archive(str(source), str(target), main=main) + + def test_default_no_shebang(self): + # Test that no shebang line is written to the target by default. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target)) + with target.open('rb') as f: + self.assertNotEqual(f.read(2), b'#!') + + def test_custom_interpreter(self): + # Test that a shebang line with a custom interpreter is written + # correctly. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), interpreter='python') + with target.open('rb') as f: + self.assertEqual(f.read(2), b'#!') + self.assertEqual(b'python\n', f.readline()) + + def test_pack_to_fileobj(self): + # Test that we can pack to a file object. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = io.BytesIO() + zipapp.create_archive(str(source), target, interpreter='python') + self.assertTrue(target.getvalue().startswith(b'#!python\n')) + + def test_read_shebang(self): + # Test that we can read the shebang line correctly. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), interpreter='python') + self.assertEqual(zipapp.get_interpreter(str(target)), 'python') + + def test_read_missing_shebang(self): + # Test that reading the shebang line of a file without one returns None. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target)) + self.assertEqual(zipapp.get_interpreter(str(target)), None) + + def test_modify_shebang(self): + # Test that we can change the shebang of a file. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), interpreter='python') + new_target = self.tmpdir / 'changed.pyz' + zipapp.create_archive(str(target), str(new_target), interpreter='python2.7') + self.assertEqual(zipapp.get_interpreter(str(new_target)), 'python2.7') + + def test_write_shebang_to_fileobj(self): + # Test that we can change the shebang of a file, writing the result to a + # file object. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), interpreter='python') + new_target = io.BytesIO() + zipapp.create_archive(str(target), new_target, interpreter='python2.7') + self.assertTrue(new_target.getvalue().startswith(b'#!python2.7\n')) + + def test_read_from_pathobj(self): + # Test that we can copy an archive using an pathlib.Path object + # for the source. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target1 = self.tmpdir / 'target1.pyz' + target2 = self.tmpdir / 'target2.pyz' + zipapp.create_archive(source, target1, interpreter='python') + zipapp.create_archive(target1, target2, interpreter='python2.7') + self.assertEqual(zipapp.get_interpreter(target2), 'python2.7') + + def test_read_from_fileobj(self): + # Test that we can copy an archive using an open file object. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + temp_archive = io.BytesIO() + zipapp.create_archive(str(source), temp_archive, interpreter='python') + new_target = io.BytesIO() + temp_archive.seek(0) + zipapp.create_archive(temp_archive, new_target, interpreter='python2.7') + self.assertTrue(new_target.getvalue().startswith(b'#!python2.7\n')) + + def test_remove_shebang(self): + # Test that we can remove the shebang from a file. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), interpreter='python') + new_target = self.tmpdir / 'changed.pyz' + zipapp.create_archive(str(target), str(new_target), interpreter=None) + self.assertEqual(zipapp.get_interpreter(str(new_target)), None) + + def test_content_of_copied_archive(self): + # Test that copying an archive doesn't corrupt it. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = io.BytesIO() + zipapp.create_archive(str(source), target, interpreter='python') + new_target = io.BytesIO() + target.seek(0) + zipapp.create_archive(target, new_target, interpreter=None) + new_target.seek(0) + with zipfile.ZipFile(new_target, 'r') as z: + self.assertEqual(set(z.namelist()), {'__main__.py'}) + + # (Unix only) tests that archives with shebang lines are made executable + @unittest.skipIf(sys.platform == 'win32', + 'Windows does not support an executable bit') + def test_shebang_is_executable(self): + # Test that an archive with a shebang line is made executable. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), interpreter='python') + self.assertTrue(target.stat().st_mode & stat.S_IEXEC) + + @unittest.skipIf(sys.platform == 'win32', + 'Windows does not support an executable bit') + def test_no_shebang_is_not_executable(self): + # Test that an archive with no shebang line is not made executable. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(str(source), str(target), interpreter=None) + self.assertFalse(target.stat().st_mode & stat.S_IEXEC) + + +class ZipAppCmdlineTest(unittest.TestCase): + + """Test zipapp module command line API.""" + + def setUp(self): + tmpdir = tempfile.TemporaryDirectory() + self.addCleanup(tmpdir.cleanup) + self.tmpdir = pathlib.Path(tmpdir.name) + + def make_archive(self): + # Test that an archive with no shebang line is not made executable. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + target = self.tmpdir / 'source.pyz' + zipapp.create_archive(source, target) + return target + + def test_cmdline_create(self): + # Test the basic command line API. + source = self.tmpdir / 'source' + source.mkdir() + (source / '__main__.py').touch() + args = [str(source)] + zipapp.main(args) + target = source.with_suffix('.pyz') + self.assertTrue(target.is_file()) + + def test_cmdline_copy(self): + # Test copying an archive. + original = self.make_archive() + target = self.tmpdir / 'target.pyz' + args = [str(original), '-o', str(target)] + zipapp.main(args) + self.assertTrue(target.is_file()) + + def test_cmdline_copy_inplace(self): + # Test copying an archive in place fails. + original = self.make_archive() + target = self.tmpdir / 'target.pyz' + args = [str(original), '-o', str(original)] + with self.assertRaises(SystemExit) as cm: + zipapp.main(args) + # Program should exit with a non-zero returm code. + self.assertTrue(cm.exception.code) + + def test_cmdline_copy_change_main(self): + # Test copying an archive doesn't allow changing __main__.py. + original = self.make_archive() + target = self.tmpdir / 'target.pyz' + args = [str(original), '-o', str(target), '-m', 'foo:bar'] + with self.assertRaises(SystemExit) as cm: + zipapp.main(args) + # Program should exit with a non-zero returm code. + self.assertTrue(cm.exception.code) + + @patch('sys.stdout', new_callable=io.StringIO) + def test_info_command(self, mock_stdout): + # Test the output of the info command. + target = self.make_archive() + args = [str(target), '--info'] + with self.assertRaises(SystemExit) as cm: + zipapp.main(args) + # Program should exit with a zero returm code. + self.assertEqual(cm.exception.code, 0) + self.assertEqual(mock_stdout.getvalue(), "Interpreter: <none>\n") + + def test_info_error(self): + # Test the info command fails when the archive does not exist. + target = self.tmpdir / 'dummy.pyz' + args = [str(target), '--info'] + with self.assertRaises(SystemExit) as cm: + zipapp.main(args) + # Program should exit with a non-zero returm code. + self.assertTrue(cm.exception.code) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py index 4bdf7d4..2c10821 100644 --- a/Lib/test/test_zipfile.py +++ b/Lib/test/test_zipfile.py @@ -330,6 +330,37 @@ class AbstractTestsWithSourceFile: while zipopen.read1(100): pass + def test_repr(self): + fname = 'file.name' + for f in get_files(self): + with zipfile.ZipFile(f, 'w', self.compression) as zipfp: + zipfp.write(TESTFN, fname) + r = repr(zipfp) + self.assertIn("mode='w'", r) + + with zipfile.ZipFile(f, 'r') as zipfp: + r = repr(zipfp) + if isinstance(f, str): + self.assertIn('filename=%r' % f, r) + else: + self.assertIn('file=%r' % f, r) + self.assertIn("mode='r'", r) + r = repr(zipfp.getinfo(fname)) + self.assertIn('filename=%r' % fname, r) + self.assertIn('filemode=', r) + self.assertIn('file_size=', r) + if self.compression != zipfile.ZIP_STORED: + self.assertIn('compress_type=', r) + self.assertIn('compress_size=', r) + with zipfp.open(fname) as zipopen: + r = repr(zipopen) + self.assertIn('name=%r' % fname, r) + self.assertIn("mode='r'", r) + if self.compression != zipfile.ZIP_STORED: + self.assertIn('compress_type=', r) + self.assertIn('[closed]', repr(zipopen)) + self.assertIn('[closed]', repr(zipfp)) + def tearDown(self): unlink(TESTFN) unlink(TESTFN2) @@ -665,7 +696,7 @@ class PyZipFileTests(unittest.TestCase): self.requiresWriteAccess(os.path.dirname(__file__)) with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp: fn = __file__ - if fn.endswith('.pyc') or fn.endswith('.pyo'): + if fn.endswith('.pyc'): path_split = fn.split(os.sep) if os.altsep is not None: path_split.extend(fn.split(os.altsep)) @@ -682,7 +713,7 @@ class PyZipFileTests(unittest.TestCase): with TemporaryFile() as t, zipfile.PyZipFile(t, "w") as zipfp: fn = __file__ - if fn.endswith(('.pyc', '.pyo')): + if fn.endswith('.pyc'): fn = fn[:-1] zipfp.writepy(fn, "testpackage") @@ -739,10 +770,8 @@ class PyZipFileTests(unittest.TestCase): import email packagedir = os.path.dirname(email.__file__) self.requiresWriteAccess(packagedir) - # use .pyc if running test in optimization mode, - # use .pyo if running test in debug mode optlevel = 1 if __debug__ else 0 - ext = '.pyo' if optlevel == 1 else '.pyc' + ext = '.pyc' with TemporaryFile() as t, \ zipfile.PyZipFile(t, "w", optimize=optlevel) as zipfp: @@ -816,11 +845,10 @@ class PyZipFileTests(unittest.TestCase): self.assertIn("SyntaxError", s.getvalue()) # as it will not have compiled the python file, it will - # include the .py file not .pyc or .pyo + # include the .py file not .pyc names = zipfp.namelist() self.assertIn('mod1.py', names) self.assertNotIn('mod1.pyc', names) - self.assertNotIn('mod1.pyo', names) finally: rmtree(TESTFN2) @@ -1081,6 +1109,19 @@ class OtherTests(unittest.TestCase): self.assertEqual(zf.filelist[0].filename, "foo.txt") self.assertEqual(zf.filelist[1].filename, "\xf6.txt") + def test_exclusive_create_zip_file(self): + """Test exclusive creating a new zipfile.""" + unlink(TESTFN2) + filename = 'testfile.txt' + content = b'hello, world. this is some content.' + with zipfile.ZipFile(TESTFN2, "x", zipfile.ZIP_STORED) as zipfp: + zipfp.writestr(filename, content) + with self.assertRaises(FileExistsError): + zipfile.ZipFile(TESTFN2, "x", zipfile.ZIP_STORED) + with zipfile.ZipFile(TESTFN2, "r") as zipfp: + self.assertEqual(zipfp.namelist(), [filename]) + self.assertEqual(zipfp.read(filename), content) + def test_create_non_existent_file_for_append(self): if os.path.exists(TESTFN): os.unlink(TESTFN) @@ -1655,6 +1696,72 @@ class LzmaTestsWithRandomBinaryFiles(AbstractTestsWithRandomBinaryFiles, compression = zipfile.ZIP_LZMA +# Privide the tell() method but not seek() +class Tellable: + def __init__(self, fp): + self.fp = fp + self.offset = 0 + + def write(self, data): + n = self.fp.write(data) + self.offset += n + return n + + def tell(self): + return self.offset + + def flush(self): + self.fp.flush() + +class Unseekable: + def __init__(self, fp): + self.fp = fp + + def write(self, data): + return self.fp.write(data) + + def flush(self): + self.fp.flush() + +class UnseekableTests(unittest.TestCase): + def test_writestr(self): + for wrapper in (lambda f: f), Tellable, Unseekable: + with self.subTest(wrapper=wrapper): + f = io.BytesIO() + f.write(b'abc') + bf = io.BufferedWriter(f) + with zipfile.ZipFile(wrapper(bf), 'w', zipfile.ZIP_STORED) as zipfp: + zipfp.writestr('ones', b'111') + zipfp.writestr('twos', b'222') + self.assertEqual(f.getvalue()[:5], b'abcPK') + with zipfile.ZipFile(f, mode='r') as zipf: + with zipf.open('ones') as zopen: + self.assertEqual(zopen.read(), b'111') + with zipf.open('twos') as zopen: + self.assertEqual(zopen.read(), b'222') + + def test_write(self): + for wrapper in (lambda f: f), Tellable, Unseekable: + with self.subTest(wrapper=wrapper): + f = io.BytesIO() + f.write(b'abc') + bf = io.BufferedWriter(f) + with zipfile.ZipFile(wrapper(bf), 'w', zipfile.ZIP_STORED) as zipfp: + self.addCleanup(unlink, TESTFN) + with open(TESTFN, 'wb') as f2: + f2.write(b'111') + zipfp.write(TESTFN, 'ones') + with open(TESTFN, 'wb') as f2: + f2.write(b'222') + zipfp.write(TESTFN, 'twos') + self.assertEqual(f.getvalue()[:5], b'abcPK') + with zipfile.ZipFile(f, mode='r') as zipf: + with zipf.open('ones') as zopen: + self.assertEqual(zopen.read(), b'111') + with zipf.open('twos') as zopen: + self.assertEqual(zopen.read(), b'222') + + @requires_zlib class TestsWithMultipleOpens(unittest.TestCase): @classmethod @@ -1671,35 +1778,52 @@ class TestsWithMultipleOpens(unittest.TestCase): def test_same_file(self): # Verify that (when the ZipFile is in control of creating file objects) # multiple open() calls can be made without interfering with each other. - self.make_test_archive(TESTFN2) - with zipfile.ZipFile(TESTFN2, mode="r") as zipf: - with zipf.open('ones') as zopen1, zipf.open('ones') as zopen2: - data1 = zopen1.read(500) - data2 = zopen2.read(500) - data1 += zopen1.read() - data2 += zopen2.read() - self.assertEqual(data1, data2) - self.assertEqual(data1, self.data1) + for f in get_files(self): + self.make_test_archive(f) + with zipfile.ZipFile(f, mode="r") as zipf: + with zipf.open('ones') as zopen1, zipf.open('ones') as zopen2: + data1 = zopen1.read(500) + data2 = zopen2.read(500) + data1 += zopen1.read() + data2 += zopen2.read() + self.assertEqual(data1, data2) + self.assertEqual(data1, self.data1) def test_different_file(self): # Verify that (when the ZipFile is in control of creating file objects) # multiple open() calls can be made without interfering with each other. - self.make_test_archive(TESTFN2) - with zipfile.ZipFile(TESTFN2, mode="r") as zipf: - with zipf.open('ones') as zopen1, zipf.open('twos') as zopen2: - data1 = zopen1.read(500) - data2 = zopen2.read(500) - data1 += zopen1.read() - data2 += zopen2.read() - self.assertEqual(data1, self.data1) - self.assertEqual(data2, self.data2) + for f in get_files(self): + self.make_test_archive(f) + with zipfile.ZipFile(f, mode="r") as zipf: + with zipf.open('ones') as zopen1, zipf.open('twos') as zopen2: + data1 = zopen1.read(500) + data2 = zopen2.read(500) + data1 += zopen1.read() + data2 += zopen2.read() + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) def test_interleaved(self): # Verify that (when the ZipFile is in control of creating file objects) # multiple open() calls can be made without interfering with each other. - self.make_test_archive(TESTFN2) - with zipfile.ZipFile(TESTFN2, mode="r") as zipf: - with zipf.open('ones') as zopen1, zipf.open('twos') as zopen2: + for f in get_files(self): + self.make_test_archive(f) + with zipfile.ZipFile(f, mode="r") as zipf: + with zipf.open('ones') as zopen1, zipf.open('twos') as zopen2: + data1 = zopen1.read(500) + data2 = zopen2.read(500) + data1 += zopen1.read() + data2 += zopen2.read() + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) + + def test_read_after_close(self): + for f in get_files(self): + self.make_test_archive(f) + with contextlib.ExitStack() as stack: + with zipfile.ZipFile(f, 'r') as zipf: + zopen1 = stack.enter_context(zipf.open('ones')) + zopen2 = stack.enter_context(zipf.open('twos')) data1 = zopen1.read(500) data2 = zopen2.read(500) data1 += zopen1.read() @@ -1707,43 +1831,32 @@ class TestsWithMultipleOpens(unittest.TestCase): self.assertEqual(data1, self.data1) self.assertEqual(data2, self.data2) - def test_read_after_close(self): - self.make_test_archive(TESTFN2) - with contextlib.ExitStack() as stack: - with zipfile.ZipFile(TESTFN2, 'r') as zipf: - zopen1 = stack.enter_context(zipf.open('ones')) - zopen2 = stack.enter_context(zipf.open('twos')) - data1 = zopen1.read(500) - data2 = zopen2.read(500) - data1 += zopen1.read() - data2 += zopen2.read() - self.assertEqual(data1, self.data1) - self.assertEqual(data2, self.data2) - def test_read_after_write(self): - with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_DEFLATED) as zipf: - zipf.writestr('ones', self.data1) - zipf.writestr('twos', self.data2) - with zipf.open('ones') as zopen1: - data1 = zopen1.read(500) - self.assertEqual(data1, self.data1[:500]) - with zipfile.ZipFile(TESTFN2, 'r') as zipf: - data1 = zipf.read('ones') - data2 = zipf.read('twos') - self.assertEqual(data1, self.data1) - self.assertEqual(data2, self.data2) + for f in get_files(self): + with zipfile.ZipFile(f, 'w', zipfile.ZIP_DEFLATED) as zipf: + zipf.writestr('ones', self.data1) + zipf.writestr('twos', self.data2) + with zipf.open('ones') as zopen1: + data1 = zopen1.read(500) + self.assertEqual(data1, self.data1[:500]) + with zipfile.ZipFile(f, 'r') as zipf: + data1 = zipf.read('ones') + data2 = zipf.read('twos') + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) def test_write_after_read(self): - with zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_DEFLATED) as zipf: - zipf.writestr('ones', self.data1) - with zipf.open('ones') as zopen1: - zopen1.read(500) - zipf.writestr('twos', self.data2) - with zipfile.ZipFile(TESTFN2, 'r') as zipf: - data1 = zipf.read('ones') - data2 = zipf.read('twos') - self.assertEqual(data1, self.data1) - self.assertEqual(data2, self.data2) + for f in get_files(self): + with zipfile.ZipFile(f, "w", zipfile.ZIP_DEFLATED) as zipf: + zipf.writestr('ones', self.data1) + with zipf.open('ones') as zopen1: + zopen1.read(500) + zipf.writestr('twos', self.data2) + with zipfile.ZipFile(f, 'r') as zipf: + data1 = zipf.read('ones') + data2 = zipf.read('twos') + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) def test_many_opens(self): # Verify that read() and open() promptly close the file descriptor, diff --git a/Lib/test/test_zipimport.py b/Lib/test/test_zipimport.py index 1e351c8..a97a778 100644 --- a/Lib/test/test_zipimport.py +++ b/Lib/test/test_zipimport.py @@ -51,7 +51,7 @@ TESTPACK2 = "ziptestpackage2" TEMP_ZIP = os.path.abspath("junk95142.zip") pyc_file = importlib.util.cache_from_source(TESTMOD + '.py') -pyc_ext = ('.pyc' if __debug__ else '.pyo') +pyc_ext = '.pyc' class ImportHooksBaseTestCase(unittest.TestCase): @@ -450,7 +450,9 @@ class BadFileZipImportTestCase(unittest.TestCase): fd = os.open(TESTMOD, os.O_CREAT, 000) try: os.close(fd) - self.assertZipFailure(TESTMOD) + + with self.assertRaises(zipimport.ZipImportError) as cm: + zipimport.zipimporter(TESTMOD) finally: # If we leave "the read-only bit" set on Windows, nothing can # delete TESTMOD, and later tests suffer bogus failures. diff --git a/Lib/test/test_zipimport_support.py b/Lib/test/test_zipimport_support.py index 66c3557..5913622 100644 --- a/Lib/test/test_zipimport_support.py +++ b/Lib/test/test_zipimport_support.py @@ -14,8 +14,8 @@ import inspect import linecache import pdb import unittest -from test.script_helper import (spawn_python, kill_python, assert_python_ok, - temp_dir, make_script, make_zip_script) +from test.support.script_helper import (spawn_python, kill_python, assert_python_ok, + make_script, make_zip_script) verbose = test.support.verbose @@ -39,7 +39,7 @@ def _run_object_doctest(obj, module): # Use the object's fully qualified name if it has one # Otherwise, use the module's name try: - name = "%s.%s" % (obj.__module__, obj.__name__) + name = "%s.%s" % (obj.__module__, obj.__qualname__) except AttributeError: name = module.__name__ for example in finder.find(obj, name, module): @@ -78,7 +78,7 @@ class ZipSupportTests(unittest.TestCase): def test_inspect_getsource_issue4223(self): test_src = "def foo(): pass\n" - with temp_dir() as d: + with test.support.temp_dir() as d: init_name = make_script(d, '__init__', test_src) name_in_zip = os.path.join('zip_pkg', os.path.basename(init_name)) @@ -118,7 +118,7 @@ class ZipSupportTests(unittest.TestCase): mod_name = mod_name.replace("sample_", "sample_zipped_") sample_sources[mod_name] = src - with temp_dir() as d: + with test.support.temp_dir() as d: script_name = make_script(d, 'test_zipped_doctest', test_src) zip_name, run_name = make_zip_script(d, 'test_zip', @@ -195,7 +195,7 @@ class ZipSupportTests(unittest.TestCase): doctest.testmod() """) pattern = 'File "%s", line 2, in %s' - with temp_dir() as d: + with test.support.temp_dir() as d: script_name = make_script(d, 'script', test_src) rc, out, err = assert_python_ok(script_name) expected = pattern % (script_name, "__main__.Test") @@ -222,7 +222,7 @@ class ZipSupportTests(unittest.TestCase): import pdb pdb.Pdb(nosigint=True).runcall(f) """) - with temp_dir() as d: + with test.support.temp_dir() as d: script_name = make_script(d, 'script', test_src) p = spawn_python(script_name) p.stdin.write(b'l\n') @@ -238,9 +238,8 @@ class ZipSupportTests(unittest.TestCase): self.assertIn(os.path.normcase(run_name.encode('utf-8')), data) -def test_main(): - test.support.run_unittest(ZipSupportTests) +def tearDownModule(): test.support.reap_children() if __name__ == '__main__': - test_main() + unittest.main() diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py index 53bb2ad..7cd1d7c 100644 --- a/Lib/test/test_zlib.py +++ b/Lib/test/test_zlib.py @@ -714,16 +714,5 @@ LAERTES """ -def test_main(): - support.run_unittest( - VersionTestCase, - ChecksumTestCase, - ChecksumBigBufferTestCase, - ExceptionTestCase, - CompressTestCase, - CompressObjectTestCase - ) - if __name__ == "__main__": - unittest.main() # XXX - ###test_main() + unittest.main() diff --git a/Lib/test/tf_inherit_check.py b/Lib/test/tf_inherit_check.py index afe50d2..138f25a 100644 --- a/Lib/test/tf_inherit_check.py +++ b/Lib/test/tf_inherit_check.py @@ -4,22 +4,24 @@ import sys import os +from test.support import SuppressCrashReport -verbose = (sys.argv[1] == 'v') -try: - fd = int(sys.argv[2]) - +with SuppressCrashReport(): + verbose = (sys.argv[1] == 'v') try: - os.write(fd, b"blat") - except OSError: - # Success -- could not write to fd. - sys.exit(0) - else: + fd = int(sys.argv[2]) + + try: + os.write(fd, b"blat") + except OSError: + # Success -- could not write to fd. + sys.exit(0) + else: + if verbose: + sys.stderr.write("fd %d is open in child" % fd) + sys.exit(1) + + except Exception: if verbose: - sys.stderr.write("fd %d is open in child" % fd) + raise sys.exit(1) - -except Exception: - if verbose: - raise - sys.exit(1) |