diff options
Diffstat (limited to 'Lib/test/test_unittest')
30 files changed, 18359 insertions, 0 deletions
diff --git a/Lib/test/test_unittest/__init__.py b/Lib/test/test_unittest/__init__.py new file mode 100644 index 0000000..bc502ef --- /dev/null +++ b/Lib/test/test_unittest/__init__.py @@ -0,0 +1,6 @@ +import os.path +from test.support import load_package_tests + + +def load_tests(*args): + return load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_unittest/__main__.py b/Lib/test/test_unittest/__main__.py new file mode 100644 index 0000000..40a23a2 --- /dev/null +++ b/Lib/test/test_unittest/__main__.py @@ -0,0 +1,4 @@ +from . import load_tests +import unittest + +unittest.main() diff --git a/Lib/test/test_unittest/_test_warnings.py b/Lib/test/test_unittest/_test_warnings.py new file mode 100644 index 0000000..5cbfb53 --- /dev/null +++ b/Lib/test/test_unittest/_test_warnings.py @@ -0,0 +1,73 @@ +# helper module for test_runner.Test_TextTestRunner.test_warnings + +""" +This module has a number of tests that raise different kinds of warnings. +When the tests are run, the warnings are caught and their messages are printed +to stdout. This module also accepts an arg that is then passed to +unittest.main to affect the behavior of warnings. +Test_TextTestRunner.test_warnings executes this script with different +combinations of warnings args and -W flags and check that the output is correct. +See #10535. +""" + +import sys +import unittest +import warnings + +def warnfun(): + warnings.warn('rw', RuntimeWarning) + +class TestWarnings(unittest.TestCase): + # unittest warnings will be printed at most once per type (max one message + # for the fail* methods, and one for the assert* methods) + def test_assert(self): + self.assertEquals(2+2, 4) + self.assertEquals(2*2, 4) + self.assertEquals(2**2, 4) + + def test_fail(self): + self.failUnless(1) + self.failUnless(True) + + def test_other_unittest(self): + self.assertAlmostEqual(2+2, 4) + self.assertNotAlmostEqual(4+4, 2) + + # these warnings are normally silenced, but they are printed in unittest + def test_deprecation(self): + warnings.warn('dw', DeprecationWarning) + warnings.warn('dw', DeprecationWarning) + warnings.warn('dw', DeprecationWarning) + + def test_import(self): + warnings.warn('iw', ImportWarning) + warnings.warn('iw', ImportWarning) + warnings.warn('iw', ImportWarning) + + # user warnings should always be printed + def test_warning(self): + warnings.warn('uw') + warnings.warn('uw') + warnings.warn('uw') + + # these warnings come from the same place; they will be printed + # only once by default or three times if the 'always' filter is used + def test_function(self): + + warnfun() + warnfun() + warnfun() + + + +if __name__ == '__main__': + with warnings.catch_warnings(record=True) as ws: + # if an arg is provided pass it to unittest.main as 'warnings' + if len(sys.argv) == 2: + unittest.main(exit=False, warnings=sys.argv.pop()) + else: + unittest.main(exit=False) + + # print all the warning messages collected + for w in ws: + print(w.message) diff --git a/Lib/test/test_unittest/dummy.py b/Lib/test/test_unittest/dummy.py new file mode 100644 index 0000000..e4f14e4 --- /dev/null +++ b/Lib/test/test_unittest/dummy.py @@ -0,0 +1 @@ +# Empty module for testing the loading of modules diff --git a/Lib/test/test_unittest/support.py b/Lib/test/test_unittest/support.py new file mode 100644 index 0000000..5292653 --- /dev/null +++ b/Lib/test/test_unittest/support.py @@ -0,0 +1,138 @@ +import unittest + + +class TestEquality(object): + """Used as a mixin for TestCase""" + + # Check for a valid __eq__ implementation + def test_eq(self): + for obj_1, obj_2 in self.eq_pairs: + self.assertEqual(obj_1, obj_2) + self.assertEqual(obj_2, obj_1) + + # Check for a valid __ne__ implementation + def test_ne(self): + for obj_1, obj_2 in self.ne_pairs: + self.assertNotEqual(obj_1, obj_2) + self.assertNotEqual(obj_2, obj_1) + +class TestHashing(object): + """Used as a mixin for TestCase""" + + # Check for a valid __hash__ implementation + def test_hash(self): + for obj_1, obj_2 in self.eq_pairs: + try: + if not hash(obj_1) == hash(obj_2): + self.fail("%r and %r do not hash equal" % (obj_1, obj_2)) + except Exception as e: + self.fail("Problem hashing %r and %r: %s" % (obj_1, obj_2, e)) + + for obj_1, obj_2 in self.ne_pairs: + try: + if hash(obj_1) == hash(obj_2): + self.fail("%s and %s hash equal, but shouldn't" % + (obj_1, obj_2)) + except Exception as e: + self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e)) + + +class _BaseLoggingResult(unittest.TestResult): + def __init__(self, log): + self._events = log + super().__init__() + + def startTest(self, test): + self._events.append('startTest') + super().startTest(test) + + def startTestRun(self): + self._events.append('startTestRun') + super().startTestRun() + + def stopTest(self, test): + self._events.append('stopTest') + super().stopTest(test) + + def stopTestRun(self): + self._events.append('stopTestRun') + super().stopTestRun() + + def addFailure(self, *args): + self._events.append('addFailure') + super().addFailure(*args) + + def addSuccess(self, *args): + self._events.append('addSuccess') + super().addSuccess(*args) + + def addError(self, *args): + self._events.append('addError') + super().addError(*args) + + def addSkip(self, *args): + self._events.append('addSkip') + super().addSkip(*args) + + def addExpectedFailure(self, *args): + self._events.append('addExpectedFailure') + super().addExpectedFailure(*args) + + def addUnexpectedSuccess(self, *args): + self._events.append('addUnexpectedSuccess') + super().addUnexpectedSuccess(*args) + + +class LegacyLoggingResult(_BaseLoggingResult): + """ + A legacy TestResult implementation, without an addSubTest method, + which records its method calls. + """ + + @property + def addSubTest(self): + raise AttributeError + + +class LoggingResult(_BaseLoggingResult): + """ + A TestResult implementation which records its method calls. + """ + + def addSubTest(self, test, subtest, err): + if err is None: + self._events.append('addSubTestSuccess') + else: + self._events.append('addSubTestFailure') + super().addSubTest(test, subtest, err) + + +class ResultWithNoStartTestRunStopTestRun(object): + """An object honouring TestResult before startTestRun/stopTestRun.""" + + def __init__(self): + self.failures = [] + self.errors = [] + self.testsRun = 0 + self.skipped = [] + self.expectedFailures = [] + self.unexpectedSuccesses = [] + self.shouldStop = False + + def startTest(self, test): + pass + + def stopTest(self, test): + pass + + def addError(self, test): + pass + + def addFailure(self, test): + pass + + def addSuccess(self, test): + pass + + def wasSuccessful(self): + return True diff --git a/Lib/test/test_unittest/test_assertions.py b/Lib/test/test_unittest/test_assertions.py new file mode 100644 index 0000000..a0db3423 --- /dev/null +++ b/Lib/test/test_unittest/test_assertions.py @@ -0,0 +1,416 @@ +import datetime +import warnings +import weakref +import unittest +from test.support import gc_collect +from itertools import product + + +class Test_Assertions(unittest.TestCase): + def test_AlmostEqual(self): + self.assertAlmostEqual(1.00000001, 1.0) + self.assertNotAlmostEqual(1.0000001, 1.0) + self.assertRaises(self.failureException, + self.assertAlmostEqual, 1.0000001, 1.0) + self.assertRaises(self.failureException, + self.assertNotAlmostEqual, 1.00000001, 1.0) + + self.assertAlmostEqual(1.1, 1.0, places=0) + self.assertRaises(self.failureException, + self.assertAlmostEqual, 1.1, 1.0, places=1) + + self.assertAlmostEqual(0, .1+.1j, places=0) + self.assertNotAlmostEqual(0, .1+.1j, places=1) + self.assertRaises(self.failureException, + self.assertAlmostEqual, 0, .1+.1j, places=1) + self.assertRaises(self.failureException, + self.assertNotAlmostEqual, 0, .1+.1j, places=0) + + self.assertAlmostEqual(float('inf'), float('inf')) + self.assertRaises(self.failureException, self.assertNotAlmostEqual, + float('inf'), float('inf')) + + def test_AmostEqualWithDelta(self): + self.assertAlmostEqual(1.1, 1.0, delta=0.5) + self.assertAlmostEqual(1.0, 1.1, delta=0.5) + self.assertNotAlmostEqual(1.1, 1.0, delta=0.05) + self.assertNotAlmostEqual(1.0, 1.1, delta=0.05) + + self.assertAlmostEqual(1.0, 1.0, delta=0.5) + self.assertRaises(self.failureException, self.assertNotAlmostEqual, + 1.0, 1.0, delta=0.5) + + self.assertRaises(self.failureException, self.assertAlmostEqual, + 1.1, 1.0, delta=0.05) + self.assertRaises(self.failureException, self.assertNotAlmostEqual, + 1.1, 1.0, delta=0.5) + + self.assertRaises(TypeError, self.assertAlmostEqual, + 1.1, 1.0, places=2, delta=2) + self.assertRaises(TypeError, self.assertNotAlmostEqual, + 1.1, 1.0, places=2, delta=2) + + first = datetime.datetime.now() + second = first + datetime.timedelta(seconds=10) + self.assertAlmostEqual(first, second, + delta=datetime.timedelta(seconds=20)) + self.assertNotAlmostEqual(first, second, + delta=datetime.timedelta(seconds=5)) + + def test_assertRaises(self): + def _raise(e): + raise e + self.assertRaises(KeyError, _raise, KeyError) + self.assertRaises(KeyError, _raise, KeyError("key")) + try: + self.assertRaises(KeyError, lambda: None) + except self.failureException as e: + self.assertIn("KeyError not raised", str(e)) + else: + self.fail("assertRaises() didn't fail") + try: + self.assertRaises(KeyError, _raise, ValueError) + except ValueError: + pass + else: + self.fail("assertRaises() didn't let exception pass through") + with self.assertRaises(KeyError) as cm: + try: + raise KeyError + except Exception as e: + exc = e + raise + self.assertIs(cm.exception, exc) + + with self.assertRaises(KeyError): + raise KeyError("key") + try: + with self.assertRaises(KeyError): + pass + except self.failureException as e: + self.assertIn("KeyError not raised", str(e)) + else: + self.fail("assertRaises() didn't fail") + try: + with self.assertRaises(KeyError): + raise ValueError + except ValueError: + pass + else: + self.fail("assertRaises() didn't let exception pass through") + + def test_assertRaises_frames_survival(self): + # Issue #9815: assertRaises should avoid keeping local variables + # in a traceback alive. + class A: + pass + wr = None + + class Foo(unittest.TestCase): + + def foo(self): + nonlocal wr + a = A() + wr = weakref.ref(a) + try: + raise OSError + except OSError: + raise ValueError + + def test_functional(self): + self.assertRaises(ValueError, self.foo) + + def test_with(self): + with self.assertRaises(ValueError): + self.foo() + + Foo("test_functional").run() + gc_collect() # For PyPy or other GCs. + self.assertIsNone(wr()) + Foo("test_with").run() + gc_collect() # For PyPy or other GCs. + self.assertIsNone(wr()) + + def testAssertNotRegex(self): + self.assertNotRegex('Ala ma kota', r'r+') + try: + self.assertNotRegex('Ala ma kota', r'k.t', 'Message') + except self.failureException as e: + self.assertIn('Message', e.args[0]) + else: + self.fail('assertNotRegex should have failed.') + + +class TestLongMessage(unittest.TestCase): + """Test that the individual asserts honour longMessage. + This actually tests all the message behaviour for + asserts that use longMessage.""" + + def setUp(self): + class TestableTestFalse(unittest.TestCase): + longMessage = False + failureException = self.failureException + + def testTest(self): + pass + + class TestableTestTrue(unittest.TestCase): + longMessage = True + failureException = self.failureException + + def testTest(self): + pass + + self.testableTrue = TestableTestTrue('testTest') + self.testableFalse = TestableTestFalse('testTest') + + def testDefault(self): + self.assertTrue(unittest.TestCase.longMessage) + + def test_formatMsg(self): + self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo") + self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo") + + self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo") + self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo") + + # This blows up if _formatMessage uses string concatenation + self.testableTrue._formatMessage(object(), 'foo') + + def test_formatMessage_unicode_error(self): + one = ''.join(chr(i) for i in range(255)) + # this used to cause a UnicodeDecodeError constructing msg + self.testableTrue._formatMessage(one, '\uFFFD') + + def assertMessages(self, methodName, args, errors): + """ + Check that methodName(*args) raises the correct error messages. + errors should be a list of 4 regex that match the error when: + 1) longMessage = False and no msg passed; + 2) longMessage = False and msg passed; + 3) longMessage = True and no msg passed; + 4) longMessage = True and msg passed; + """ + def getMethod(i): + useTestableFalse = i < 2 + if useTestableFalse: + test = self.testableFalse + else: + test = self.testableTrue + return getattr(test, methodName) + + for i, expected_regex in enumerate(errors): + testMethod = getMethod(i) + kwargs = {} + withMsg = i % 2 + if withMsg: + kwargs = {"msg": "oops"} + + with self.assertRaisesRegex(self.failureException, + expected_regex=expected_regex): + testMethod(*args, **kwargs) + + def testAssertTrue(self): + self.assertMessages('assertTrue', (False,), + ["^False is not true$", "^oops$", "^False is not true$", + "^False is not true : oops$"]) + + def testAssertFalse(self): + self.assertMessages('assertFalse', (True,), + ["^True is not false$", "^oops$", "^True is not false$", + "^True is not false : oops$"]) + + def testNotEqual(self): + self.assertMessages('assertNotEqual', (1, 1), + ["^1 == 1$", "^oops$", "^1 == 1$", + "^1 == 1 : oops$"]) + + def testAlmostEqual(self): + self.assertMessages( + 'assertAlmostEqual', (1, 2), + [r"^1 != 2 within 7 places \(1 difference\)$", "^oops$", + r"^1 != 2 within 7 places \(1 difference\)$", + r"^1 != 2 within 7 places \(1 difference\) : oops$"]) + + def testNotAlmostEqual(self): + self.assertMessages('assertNotAlmostEqual', (1, 1), + ["^1 == 1 within 7 places$", "^oops$", + "^1 == 1 within 7 places$", "^1 == 1 within 7 places : oops$"]) + + def test_baseAssertEqual(self): + self.assertMessages('_baseAssertEqual', (1, 2), + ["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"]) + + def testAssertSequenceEqual(self): + # Error messages are multiline so not testing on full message + # assertTupleEqual and assertListEqual delegate to this method + self.assertMessages('assertSequenceEqual', ([], [None]), + [r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", + r"\+ \[None\] : oops$"]) + + def testAssertSetEqual(self): + self.assertMessages('assertSetEqual', (set(), set([None])), + ["None$", "^oops$", "None$", + "None : oops$"]) + + def testAssertIn(self): + self.assertMessages('assertIn', (None, []), + [r'^None not found in \[\]$', "^oops$", + r'^None not found in \[\]$', + r'^None not found in \[\] : oops$']) + + def testAssertNotIn(self): + self.assertMessages('assertNotIn', (None, [None]), + [r'^None unexpectedly found in \[None\]$', "^oops$", + r'^None unexpectedly found in \[None\]$', + r'^None unexpectedly found in \[None\] : oops$']) + + def testAssertDictEqual(self): + self.assertMessages('assertDictEqual', ({}, {'key': 'value'}), + [r"\+ \{'key': 'value'\}$", "^oops$", + r"\+ \{'key': 'value'\}$", + r"\+ \{'key': 'value'\} : oops$"]) + + def testAssertDictContainsSubset(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + self.assertMessages('assertDictContainsSubset', ({'key': 'value'}, {}), + ["^Missing: 'key'$", "^oops$", + "^Missing: 'key'$", + "^Missing: 'key' : oops$"]) + + def testAssertMultiLineEqual(self): + self.assertMessages('assertMultiLineEqual', ("", "foo"), + [r"\+ foo$", "^oops$", + r"\+ foo$", + r"\+ foo : oops$"]) + + def testAssertLess(self): + self.assertMessages('assertLess', (2, 1), + ["^2 not less than 1$", "^oops$", + "^2 not less than 1$", "^2 not less than 1 : oops$"]) + + def testAssertLessEqual(self): + self.assertMessages('assertLessEqual', (2, 1), + ["^2 not less than or equal to 1$", "^oops$", + "^2 not less than or equal to 1$", + "^2 not less than or equal to 1 : oops$"]) + + def testAssertGreater(self): + self.assertMessages('assertGreater', (1, 2), + ["^1 not greater than 2$", "^oops$", + "^1 not greater than 2$", + "^1 not greater than 2 : oops$"]) + + def testAssertGreaterEqual(self): + self.assertMessages('assertGreaterEqual', (1, 2), + ["^1 not greater than or equal to 2$", "^oops$", + "^1 not greater than or equal to 2$", + "^1 not greater than or equal to 2 : oops$"]) + + def testAssertIsNone(self): + self.assertMessages('assertIsNone', ('not None',), + ["^'not None' is not None$", "^oops$", + "^'not None' is not None$", + "^'not None' is not None : oops$"]) + + def testAssertIsNotNone(self): + self.assertMessages('assertIsNotNone', (None,), + ["^unexpectedly None$", "^oops$", + "^unexpectedly None$", + "^unexpectedly None : oops$"]) + + def testAssertIs(self): + self.assertMessages('assertIs', (None, 'foo'), + ["^None is not 'foo'$", "^oops$", + "^None is not 'foo'$", + "^None is not 'foo' : oops$"]) + + def testAssertIsNot(self): + self.assertMessages('assertIsNot', (None, None), + ["^unexpectedly identical: None$", "^oops$", + "^unexpectedly identical: None$", + "^unexpectedly identical: None : oops$"]) + + def testAssertRegex(self): + self.assertMessages('assertRegex', ('foo', 'bar'), + ["^Regex didn't match:", + "^oops$", + "^Regex didn't match:", + "^Regex didn't match: (.*) : oops$"]) + + def testAssertNotRegex(self): + self.assertMessages('assertNotRegex', ('foo', 'foo'), + ["^Regex matched:", + "^oops$", + "^Regex matched:", + "^Regex matched: (.*) : oops$"]) + + + def assertMessagesCM(self, methodName, args, func, errors): + """ + Check that the correct error messages are raised while executing: + with method(*args): + func() + *errors* should be a list of 4 regex that match the error when: + 1) longMessage = False and no msg passed; + 2) longMessage = False and msg passed; + 3) longMessage = True and no msg passed; + 4) longMessage = True and msg passed; + """ + p = product((self.testableFalse, self.testableTrue), + ({}, {"msg": "oops"})) + for (cls, kwargs), err in zip(p, errors): + method = getattr(cls, methodName) + with self.assertRaisesRegex(cls.failureException, err): + with method(*args, **kwargs) as cm: + func() + + def testAssertRaises(self): + self.assertMessagesCM('assertRaises', (TypeError,), lambda: None, + ['^TypeError not raised$', '^oops$', + '^TypeError not raised$', + '^TypeError not raised : oops$']) + + def testAssertRaisesRegex(self): + # test error not raised + self.assertMessagesCM('assertRaisesRegex', (TypeError, 'unused regex'), + lambda: None, + ['^TypeError not raised$', '^oops$', + '^TypeError not raised$', + '^TypeError not raised : oops$']) + # test error raised but with wrong message + def raise_wrong_message(): + raise TypeError('foo') + self.assertMessagesCM('assertRaisesRegex', (TypeError, 'regex'), + raise_wrong_message, + ['^"regex" does not match "foo"$', '^oops$', + '^"regex" does not match "foo"$', + '^"regex" does not match "foo" : oops$']) + + def testAssertWarns(self): + self.assertMessagesCM('assertWarns', (UserWarning,), lambda: None, + ['^UserWarning not triggered$', '^oops$', + '^UserWarning not triggered$', + '^UserWarning not triggered : oops$']) + + def testAssertWarnsRegex(self): + # test error not raised + self.assertMessagesCM('assertWarnsRegex', (UserWarning, 'unused regex'), + lambda: None, + ['^UserWarning not triggered$', '^oops$', + '^UserWarning not triggered$', + '^UserWarning not triggered : oops$']) + # test warning raised but with wrong message + def raise_wrong_message(): + warnings.warn('foo') + self.assertMessagesCM('assertWarnsRegex', (UserWarning, 'regex'), + raise_wrong_message, + ['^"regex" does not match "foo"$', '^oops$', + '^"regex" does not match "foo"$', + '^"regex" does not match "foo" : oops$']) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_async_case.py b/Lib/test/test_unittest/test_async_case.py new file mode 100644 index 0000000..beadcac --- /dev/null +++ b/Lib/test/test_unittest/test_async_case.py @@ -0,0 +1,439 @@ +import asyncio +import contextvars +import unittest +from test import support + +support.requires_working_socket(module=True) + + +class MyException(Exception): + pass + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class TestCM: + def __init__(self, ordering, enter_result=None): + self.ordering = ordering + self.enter_result = enter_result + + async def __aenter__(self): + self.ordering.append('enter') + return self.enter_result + + async def __aexit__(self, *exc_info): + self.ordering.append('exit') + + +class LacksEnterAndExit: + pass +class LacksEnter: + async def __aexit__(self, *exc_info): + pass +class LacksExit: + async def __aenter__(self): + pass + + +VAR = contextvars.ContextVar('VAR', default=()) + + +class TestAsyncCase(unittest.TestCase): + maxDiff = None + + def tearDown(self): + # Ensure that IsolatedAsyncioTestCase instances are destroyed before + # starting a new event loop + support.gc_collect() + + def test_full_cycle(self): + class Test(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.assertEqual(events, []) + events.append('setUp') + VAR.set(VAR.get() + ('setUp',)) + + async def asyncSetUp(self): + self.assertEqual(events, ['setUp']) + events.append('asyncSetUp') + VAR.set(VAR.get() + ('asyncSetUp',)) + self.addAsyncCleanup(self.on_cleanup1) + + async def test_func(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp']) + events.append('test') + VAR.set(VAR.get() + ('test',)) + self.addAsyncCleanup(self.on_cleanup2) + + async def asyncTearDown(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test']) + VAR.set(VAR.get() + ('asyncTearDown',)) + events.append('asyncTearDown') + + def tearDown(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test', + 'asyncTearDown']) + events.append('tearDown') + VAR.set(VAR.get() + ('tearDown',)) + + async def on_cleanup1(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test', + 'asyncTearDown', + 'tearDown', + 'cleanup2']) + events.append('cleanup1') + VAR.set(VAR.get() + ('cleanup1',)) + nonlocal cvar + cvar = VAR.get() + + async def on_cleanup2(self): + self.assertEqual(events, ['setUp', + 'asyncSetUp', + 'test', + 'asyncTearDown', + 'tearDown']) + events.append('cleanup2') + VAR.set(VAR.get() + ('cleanup2',)) + + events = [] + cvar = () + test = Test("test_func") + result = test.run() + self.assertEqual(result.errors, []) + self.assertEqual(result.failures, []) + expected = ['setUp', 'asyncSetUp', 'test', + 'asyncTearDown', 'tearDown', 'cleanup2', 'cleanup1'] + self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) + + events = [] + cvar = () + test = Test("test_func") + test.debug() + self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) + test.doCleanups() + self.assertEqual(events, expected) + self.assertEqual(cvar, tuple(expected)) + + def test_exception_in_setup(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + self.addAsyncCleanup(self.on_cleanup) + raise MyException() + + async def test_func(self): + events.append('test') + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def on_cleanup(self): + events.append('cleanup') + + + events = [] + test = Test("test_func") + result = test.run() + self.assertEqual(events, ['asyncSetUp', 'cleanup']) + self.assertIs(result.errors[0][0], test) + self.assertIn('MyException', result.errors[0][1]) + + events = [] + test = Test("test_func") + try: + test.debug() + except MyException: + pass + else: + self.fail('Expected a MyException exception') + self.assertEqual(events, ['asyncSetUp']) + test.doCleanups() + self.assertEqual(events, ['asyncSetUp', 'cleanup']) + + def test_exception_in_test(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + self.addAsyncCleanup(self.on_cleanup) + raise MyException() + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def on_cleanup(self): + events.append('cleanup') + + events = [] + test = Test("test_func") + result = test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + self.assertIs(result.errors[0][0], test) + self.assertIn('MyException', result.errors[0][1]) + + events = [] + test = Test("test_func") + try: + test.debug() + except MyException: + pass + else: + self.fail('Expected a MyException exception') + self.assertEqual(events, ['asyncSetUp', 'test']) + test.doCleanups() + self.assertEqual(events, ['asyncSetUp', 'test', 'cleanup']) + + def test_exception_in_tear_down(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + self.addAsyncCleanup(self.on_cleanup) + + async def asyncTearDown(self): + events.append('asyncTearDown') + raise MyException() + + async def on_cleanup(self): + events.append('cleanup') + + events = [] + test = Test("test_func") + result = test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + self.assertIs(result.errors[0][0], test) + self.assertIn('MyException', result.errors[0][1]) + + events = [] + test = Test("test_func") + try: + test.debug() + except MyException: + pass + else: + self.fail('Expected a MyException exception') + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown']) + test.doCleanups() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + + def test_exception_in_tear_clean_up(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + self.addAsyncCleanup(self.on_cleanup1) + self.addAsyncCleanup(self.on_cleanup2) + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def on_cleanup1(self): + events.append('cleanup1') + raise MyException('some error') + + async def on_cleanup2(self): + events.append('cleanup2') + raise MyException('other error') + + events = [] + test = Test("test_func") + result = test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup2', 'cleanup1']) + self.assertIs(result.errors[0][0], test) + self.assertIn('MyException: other error', result.errors[0][1]) + self.assertIn('MyException: some error', result.errors[1][1]) + + events = [] + test = Test("test_func") + try: + test.debug() + except MyException: + pass + else: + self.fail('Expected a MyException exception') + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup2']) + test.doCleanups() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup2', 'cleanup1']) + + def test_deprecation_of_return_val_from_test(self): + # Issue 41322 - deprecate return of value!=None from a test + class Test(unittest.IsolatedAsyncioTestCase): + async def test1(self): + return 1 + async def test2(self): + yield 1 + + with self.assertWarns(DeprecationWarning) as w: + Test('test1').run() + self.assertIn('It is deprecated to return a value!=None', str(w.warning)) + self.assertIn('test1', str(w.warning)) + self.assertEqual(w.filename, __file__) + + with self.assertWarns(DeprecationWarning) as w: + Test('test2').run() + self.assertIn('It is deprecated to return a value!=None', str(w.warning)) + self.assertIn('test2', str(w.warning)) + self.assertEqual(w.filename, __file__) + + def test_cleanups_interleave_order(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(self): + self.addAsyncCleanup(self.on_sync_cleanup, 1) + self.addAsyncCleanup(self.on_async_cleanup, 2) + self.addAsyncCleanup(self.on_sync_cleanup, 3) + self.addAsyncCleanup(self.on_async_cleanup, 4) + + async def on_sync_cleanup(self, val): + events.append(f'sync_cleanup {val}') + + async def on_async_cleanup(self, val): + events.append(f'async_cleanup {val}') + + test = Test("test_func") + test.run() + self.assertEqual(events, ['async_cleanup 4', + 'sync_cleanup 3', + 'async_cleanup 2', + 'sync_cleanup 1']) + + def test_base_exception_from_async_method(self): + events = [] + class Test(unittest.IsolatedAsyncioTestCase): + async def test_base(self): + events.append("test_base") + raise BaseException() + events.append("not it") + + async def test_no_err(self): + events.append("test_no_err") + + async def test_cancel(self): + raise asyncio.CancelledError() + + test = Test("test_base") + output = test.run() + self.assertFalse(output.wasSuccessful()) + + test = Test("test_no_err") + test.run() + self.assertEqual(events, ['test_base', 'test_no_err']) + + test = Test("test_cancel") + output = test.run() + self.assertFalse(output.wasSuccessful()) + + def test_cancellation_hanging_tasks(self): + cancelled = False + class Test(unittest.IsolatedAsyncioTestCase): + async def test_leaking_task(self): + async def coro(): + nonlocal cancelled + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + cancelled = True + raise + + # Leave this running in the background + asyncio.create_task(coro()) + + test = Test("test_leaking_task") + output = test.run() + self.assertTrue(cancelled) + + def test_enterAsyncContext(self): + events = [] + + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(slf): + slf.addAsyncCleanup(events.append, 'cleanup1') + cm = TestCM(events, 42) + self.assertEqual(await slf.enterAsyncContext(cm), 42) + slf.addAsyncCleanup(events.append, 'cleanup2') + events.append('test') + + test = Test('test_func') + output = test.run() + self.assertTrue(output.wasSuccessful(), output) + self.assertEqual(events, ['enter', 'test', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterAsyncContext_arg_errors(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def test_func(slf): + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'asynchronous context manager'): + await slf.enterAsyncContext(LacksExit()) + + test = Test('test_func') + output = test.run() + self.assertTrue(output.wasSuccessful()) + + def test_debug_cleanup_same_loop(self): + class Test(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + async def coro(): + await asyncio.sleep(0) + fut = asyncio.ensure_future(coro()) + self.addAsyncCleanup(self.cleanup, fut) + events.append('asyncSetUp') + + async def test_func(self): + events.append('test') + raise MyException() + + async def asyncTearDown(self): + events.append('asyncTearDown') + + async def cleanup(self, fut): + try: + # Raises an exception if in different loop + await asyncio.wait([fut]) + events.append('cleanup') + except: + import traceback + traceback.print_exc() + raise + + events = [] + test = Test("test_func") + result = test.run() + self.assertEqual(events, ['asyncSetUp', 'test', 'asyncTearDown', 'cleanup']) + self.assertIn('MyException', result.errors[0][1]) + + events = [] + test = Test("test_func") + try: + test.debug() + except MyException: + pass + else: + self.fail('Expected a MyException exception') + self.assertEqual(events, ['asyncSetUp', 'test']) + test.doCleanups() + self.assertEqual(events, ['asyncSetUp', 'test', 'cleanup']) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_break.py b/Lib/test/test_unittest/test_break.py new file mode 100644 index 0000000..33cbdd2 --- /dev/null +++ b/Lib/test/test_unittest/test_break.py @@ -0,0 +1,306 @@ +import gc +import io +import os +import sys +import signal +import weakref +import unittest + +from test import support + + +@unittest.skipUnless(hasattr(os, 'kill'), "Test requires os.kill") +@unittest.skipIf(sys.platform =="win32", "Test cannot run on Windows") +class TestBreak(unittest.TestCase): + int_handler = None + # This number was smart-guessed, previously tests were failing + # after 7th run. So, we take `x * 2 + 1` to be sure. + default_repeats = 15 + + def setUp(self): + self._default_handler = signal.getsignal(signal.SIGINT) + if self.int_handler is not None: + signal.signal(signal.SIGINT, self.int_handler) + + def tearDown(self): + signal.signal(signal.SIGINT, self._default_handler) + unittest.signals._results = weakref.WeakKeyDictionary() + unittest.signals._interrupt_handler = None + + + def withRepeats(self, test_function, repeats=None): + if not support.check_impl_detail(cpython=True): + # Override repeats count on non-cpython to execute only once. + # Because this test only makes sense to be repeated on CPython. + repeats = 1 + elif repeats is None: + repeats = self.default_repeats + + for repeat in range(repeats): + with self.subTest(repeat=repeat): + # We don't run `setUp` for the very first repeat + # and we don't run `tearDown` for the very last one, + # because they are handled by the test class itself. + if repeat != 0: + self.setUp() + try: + test_function() + finally: + if repeat != repeats - 1: + self.tearDown() + + def testInstallHandler(self): + default_handler = signal.getsignal(signal.SIGINT) + unittest.installHandler() + self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) + + try: + pid = os.getpid() + os.kill(pid, signal.SIGINT) + except KeyboardInterrupt: + self.fail("KeyboardInterrupt not handled") + + self.assertTrue(unittest.signals._interrupt_handler.called) + + def testRegisterResult(self): + result = unittest.TestResult() + self.assertNotIn(result, unittest.signals._results) + + unittest.registerResult(result) + try: + self.assertIn(result, unittest.signals._results) + finally: + unittest.removeResult(result) + + def testInterruptCaught(self): + def test(result): + pid = os.getpid() + os.kill(pid, signal.SIGINT) + result.breakCaught = True + self.assertTrue(result.shouldStop) + + def test_function(): + result = unittest.TestResult() + unittest.installHandler() + unittest.registerResult(result) + + self.assertNotEqual( + signal.getsignal(signal.SIGINT), + self._default_handler, + ) + + try: + test(result) + except KeyboardInterrupt: + self.fail("KeyboardInterrupt not handled") + self.assertTrue(result.breakCaught) + self.withRepeats(test_function) + + def testSecondInterrupt(self): + # Can't use skipIf decorator because the signal handler may have + # been changed after defining this method. + if signal.getsignal(signal.SIGINT) == signal.SIG_IGN: + self.skipTest("test requires SIGINT to not be ignored") + + def test(result): + pid = os.getpid() + os.kill(pid, signal.SIGINT) + result.breakCaught = True + self.assertTrue(result.shouldStop) + os.kill(pid, signal.SIGINT) + self.fail("Second KeyboardInterrupt not raised") + + def test_function(): + result = unittest.TestResult() + unittest.installHandler() + unittest.registerResult(result) + + with self.assertRaises(KeyboardInterrupt): + test(result) + self.assertTrue(result.breakCaught) + self.withRepeats(test_function) + + + def testTwoResults(self): + def test_function(): + unittest.installHandler() + + result = unittest.TestResult() + unittest.registerResult(result) + new_handler = signal.getsignal(signal.SIGINT) + + result2 = unittest.TestResult() + unittest.registerResult(result2) + self.assertEqual(signal.getsignal(signal.SIGINT), new_handler) + + result3 = unittest.TestResult() + + try: + os.kill(os.getpid(), signal.SIGINT) + except KeyboardInterrupt: + self.fail("KeyboardInterrupt not handled") + + self.assertTrue(result.shouldStop) + self.assertTrue(result2.shouldStop) + self.assertFalse(result3.shouldStop) + self.withRepeats(test_function) + + + def testHandlerReplacedButCalled(self): + # Can't use skipIf decorator because the signal handler may have + # been changed after defining this method. + if signal.getsignal(signal.SIGINT) == signal.SIG_IGN: + self.skipTest("test requires SIGINT to not be ignored") + + def test_function(): + # If our handler has been replaced (is no longer installed) but is + # called by the *new* handler, then it isn't safe to delay the + # SIGINT and we should immediately delegate to the default handler + unittest.installHandler() + + handler = signal.getsignal(signal.SIGINT) + def new_handler(frame, signum): + handler(frame, signum) + signal.signal(signal.SIGINT, new_handler) + + try: + os.kill(os.getpid(), signal.SIGINT) + except KeyboardInterrupt: + pass + else: + self.fail("replaced but delegated handler doesn't raise interrupt") + self.withRepeats(test_function) + + def testRunner(self): + # Creating a TextTestRunner with the appropriate argument should + # register the TextTestResult it creates + runner = unittest.TextTestRunner(stream=io.StringIO()) + + result = runner.run(unittest.TestSuite()) + self.assertIn(result, unittest.signals._results) + + def testWeakReferences(self): + # Calling registerResult on a result should not keep it alive + result = unittest.TestResult() + unittest.registerResult(result) + + ref = weakref.ref(result) + del result + + # For non-reference counting implementations + gc.collect();gc.collect() + self.assertIsNone(ref()) + + + def testRemoveResult(self): + result = unittest.TestResult() + unittest.registerResult(result) + + unittest.installHandler() + self.assertTrue(unittest.removeResult(result)) + + # Should this raise an error instead? + self.assertFalse(unittest.removeResult(unittest.TestResult())) + + try: + pid = os.getpid() + os.kill(pid, signal.SIGINT) + except KeyboardInterrupt: + pass + + self.assertFalse(result.shouldStop) + + def testMainInstallsHandler(self): + failfast = object() + test = object() + verbosity = object() + result = object() + default_handler = signal.getsignal(signal.SIGINT) + + class FakeRunner(object): + initArgs = [] + runArgs = [] + def __init__(self, *args, **kwargs): + self.initArgs.append((args, kwargs)) + def run(self, test): + self.runArgs.append(test) + return result + + class Program(unittest.TestProgram): + def __init__(self, catchbreak): + self.exit = False + self.verbosity = verbosity + self.failfast = failfast + self.catchbreak = catchbreak + self.tb_locals = False + self.testRunner = FakeRunner + self.test = test + self.result = None + + p = Program(False) + p.runTests() + + self.assertEqual(FakeRunner.initArgs, [((), {'buffer': None, + 'verbosity': verbosity, + 'failfast': failfast, + 'tb_locals': False, + 'warnings': None})]) + self.assertEqual(FakeRunner.runArgs, [test]) + self.assertEqual(p.result, result) + + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + FakeRunner.initArgs = [] + FakeRunner.runArgs = [] + p = Program(True) + p.runTests() + + self.assertEqual(FakeRunner.initArgs, [((), {'buffer': None, + 'verbosity': verbosity, + 'failfast': failfast, + 'tb_locals': False, + 'warnings': None})]) + self.assertEqual(FakeRunner.runArgs, [test]) + self.assertEqual(p.result, result) + + self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) + + def testRemoveHandler(self): + default_handler = signal.getsignal(signal.SIGINT) + unittest.installHandler() + unittest.removeHandler() + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + # check that calling removeHandler multiple times has no ill-effect + unittest.removeHandler() + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + def testRemoveHandlerAsDecorator(self): + default_handler = signal.getsignal(signal.SIGINT) + unittest.installHandler() + + @unittest.removeHandler + def test(): + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + test() + self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) + +@unittest.skipUnless(hasattr(os, 'kill'), "Test requires os.kill") +@unittest.skipIf(sys.platform =="win32", "Test cannot run on Windows") +class TestBreakDefaultIntHandler(TestBreak): + int_handler = signal.default_int_handler + +@unittest.skipUnless(hasattr(os, 'kill'), "Test requires os.kill") +@unittest.skipIf(sys.platform =="win32", "Test cannot run on Windows") +class TestBreakSignalIgnored(TestBreak): + int_handler = signal.SIG_IGN + +@unittest.skipUnless(hasattr(os, 'kill'), "Test requires os.kill") +@unittest.skipIf(sys.platform =="win32", "Test cannot run on Windows") +class TestBreakSignalDefault(TestBreak): + int_handler = signal.SIG_DFL + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_case.py b/Lib/test/test_unittest/test_case.py new file mode 100644 index 0000000..e000fe4 --- /dev/null +++ b/Lib/test/test_unittest/test_case.py @@ -0,0 +1,1977 @@ +import contextlib +import difflib +import pprint +import pickle +import re +import sys +import logging +import warnings +import weakref +import inspect +import types + +from copy import deepcopy +from test import support + +import unittest + +from test.test_unittest.support import ( + TestEquality, TestHashing, LoggingResult, LegacyLoggingResult, + ResultWithNoStartTestRunStopTestRun +) +from test.support import captured_stderr, gc_collect + + +log_foo = logging.getLogger('foo') +log_foobar = logging.getLogger('foo.bar') +log_quux = logging.getLogger('quux') + + +class Test(object): + "Keep these TestCase classes out of the main namespace" + + class Foo(unittest.TestCase): + def runTest(self): pass + def test1(self): pass + + class Bar(Foo): + def test2(self): pass + + class LoggingTestCase(unittest.TestCase): + """A test case which logs its calls.""" + + def __init__(self, events): + super(Test.LoggingTestCase, self).__init__('test') + self.events = events + + def setUp(self): + self.events.append('setUp') + + def test(self): + self.events.append('test') + + def tearDown(self): + self.events.append('tearDown') + + +class Test_TestCase(unittest.TestCase, TestEquality, TestHashing): + + ### Set up attributes used by inherited tests + ################################################################ + + # Used by TestHashing.test_hash and TestEquality.test_eq + eq_pairs = [(Test.Foo('test1'), Test.Foo('test1'))] + + # Used by TestEquality.test_ne + ne_pairs = [(Test.Foo('test1'), Test.Foo('runTest')), + (Test.Foo('test1'), Test.Bar('test1')), + (Test.Foo('test1'), Test.Bar('test2'))] + + ################################################################ + ### /Set up attributes used by inherited tests + + + # "class TestCase([methodName])" + # ... + # "Each instance of TestCase will run a single test method: the + # method named methodName." + # ... + # "methodName defaults to "runTest"." + # + # Make sure it really is optional, and that it defaults to the proper + # thing. + def test_init__no_test_name(self): + class Test(unittest.TestCase): + def runTest(self): raise MyException() + def test(self): pass + + self.assertEqual(Test().id()[-13:], '.Test.runTest') + + # test that TestCase can be instantiated with no args + # primarily for use at the interactive interpreter + test = unittest.TestCase() + test.assertEqual(3, 3) + with test.assertRaises(test.failureException): + test.assertEqual(3, 2) + + with self.assertRaises(AttributeError): + test.run() + + # "class TestCase([methodName])" + # ... + # "Each instance of TestCase will run a single test method: the + # method named methodName." + def test_init__test_name__valid(self): + class Test(unittest.TestCase): + def runTest(self): raise MyException() + def test(self): pass + + self.assertEqual(Test('test').id()[-10:], '.Test.test') + + # "class TestCase([methodName])" + # ... + # "Each instance of TestCase will run a single test method: the + # method named methodName." + def test_init__test_name__invalid(self): + class Test(unittest.TestCase): + def runTest(self): raise MyException() + def test(self): pass + + try: + Test('testfoo') + except ValueError: + pass + else: + self.fail("Failed to raise ValueError") + + # "Return the number of tests represented by the this test object. For + # TestCase instances, this will always be 1" + def test_countTestCases(self): + class Foo(unittest.TestCase): + def test(self): pass + + self.assertEqual(Foo('test').countTestCases(), 1) + + # "Return the default type of test result object to be used to run this + # test. For TestCase instances, this will always be + # unittest.TestResult; subclasses of TestCase should + # override this as necessary." + def test_defaultTestResult(self): + class Foo(unittest.TestCase): + def runTest(self): + pass + + result = Foo().defaultTestResult() + self.assertEqual(type(result), unittest.TestResult) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if setUp() raises + # an exception. + def test_run_call_order__error_in_setUp(self): + events = [] + result = LoggingResult(events) + + class Foo(Test.LoggingTestCase): + def setUp(self): + super(Foo, self).setUp() + raise RuntimeError('raised by Foo.setUp') + + Foo(events).run(result) + expected = ['startTest', 'setUp', 'addError', 'stopTest'] + self.assertEqual(events, expected) + + # "With a temporary result stopTestRun is called when setUp errors. + def test_run_call_order__error_in_setUp_default_result(self): + events = [] + + class Foo(Test.LoggingTestCase): + def defaultTestResult(self): + return LoggingResult(self.events) + + def setUp(self): + super(Foo, self).setUp() + raise RuntimeError('raised by Foo.setUp') + + Foo(events).run() + expected = ['startTestRun', 'startTest', 'setUp', 'addError', + 'stopTest', 'stopTestRun'] + self.assertEqual(events, expected) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if the test raises + # an error (as opposed to a failure). + def test_run_call_order__error_in_test(self): + events = [] + result = LoggingResult(events) + + class Foo(Test.LoggingTestCase): + def test(self): + super(Foo, self).test() + raise RuntimeError('raised by Foo.test') + + expected = ['startTest', 'setUp', 'test', + 'addError', 'tearDown', 'stopTest'] + Foo(events).run(result) + self.assertEqual(events, expected) + + # "With a default result, an error in the test still results in stopTestRun + # being called." + def test_run_call_order__error_in_test_default_result(self): + events = [] + + class Foo(Test.LoggingTestCase): + def defaultTestResult(self): + return LoggingResult(self.events) + + def test(self): + super(Foo, self).test() + raise RuntimeError('raised by Foo.test') + + expected = ['startTestRun', 'startTest', 'setUp', 'test', + 'addError', 'tearDown', 'stopTest', 'stopTestRun'] + Foo(events).run() + self.assertEqual(events, expected) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if the test signals + # a failure (as opposed to an error). + def test_run_call_order__failure_in_test(self): + events = [] + result = LoggingResult(events) + + class Foo(Test.LoggingTestCase): + def test(self): + super(Foo, self).test() + self.fail('raised by Foo.test') + + expected = ['startTest', 'setUp', 'test', + 'addFailure', 'tearDown', 'stopTest'] + Foo(events).run(result) + self.assertEqual(events, expected) + + # "When a test fails with a default result stopTestRun is still called." + def test_run_call_order__failure_in_test_default_result(self): + + class Foo(Test.LoggingTestCase): + def defaultTestResult(self): + return LoggingResult(self.events) + def test(self): + super(Foo, self).test() + self.fail('raised by Foo.test') + + expected = ['startTestRun', 'startTest', 'setUp', 'test', + 'addFailure', 'tearDown', 'stopTest', 'stopTestRun'] + events = [] + Foo(events).run() + self.assertEqual(events, expected) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if tearDown() raises + # an exception. + def test_run_call_order__error_in_tearDown(self): + events = [] + result = LoggingResult(events) + + class Foo(Test.LoggingTestCase): + def tearDown(self): + super(Foo, self).tearDown() + raise RuntimeError('raised by Foo.tearDown') + + Foo(events).run(result) + expected = ['startTest', 'setUp', 'test', 'tearDown', 'addError', + 'stopTest'] + self.assertEqual(events, expected) + + # "When tearDown errors with a default result stopTestRun is still called." + def test_run_call_order__error_in_tearDown_default_result(self): + + class Foo(Test.LoggingTestCase): + def defaultTestResult(self): + return LoggingResult(self.events) + def tearDown(self): + super(Foo, self).tearDown() + raise RuntimeError('raised by Foo.tearDown') + + events = [] + Foo(events).run() + expected = ['startTestRun', 'startTest', 'setUp', 'test', 'tearDown', + 'addError', 'stopTest', 'stopTestRun'] + self.assertEqual(events, expected) + + # "TestCase.run() still works when the defaultTestResult is a TestResult + # that does not support startTestRun and stopTestRun. + def test_run_call_order_default_result(self): + + class Foo(unittest.TestCase): + def defaultTestResult(self): + return ResultWithNoStartTestRunStopTestRun() + def test(self): + pass + + Foo('test').run() + + def test_deprecation_of_return_val_from_test(self): + # Issue 41322 - deprecate return of value!=None from a test + class Foo(unittest.TestCase): + def test1(self): + return 1 + def test2(self): + yield 1 + + with self.assertWarns(DeprecationWarning) as w: + Foo('test1').run() + self.assertIn('It is deprecated to return a value!=None', str(w.warning)) + self.assertIn('test1', str(w.warning)) + self.assertEqual(w.filename, __file__) + + with self.assertWarns(DeprecationWarning) as w: + Foo('test2').run() + self.assertIn('It is deprecated to return a value!=None', str(w.warning)) + self.assertIn('test2', str(w.warning)) + self.assertEqual(w.filename, __file__) + + def _check_call_order__subtests(self, result, events, expected_events): + class Foo(Test.LoggingTestCase): + def test(self): + super(Foo, self).test() + for i in [1, 2, 3]: + with self.subTest(i=i): + if i == 1: + self.fail('failure') + for j in [2, 3]: + with self.subTest(j=j): + if i * j == 6: + raise RuntimeError('raised by Foo.test') + 1 / 0 + + # Order is the following: + # i=1 => subtest failure + # i=2, j=2 => subtest success + # i=2, j=3 => subtest error + # i=3, j=2 => subtest error + # i=3, j=3 => subtest success + # toplevel => error + Foo(events).run(result) + self.assertEqual(events, expected_events) + + def test_run_call_order__subtests(self): + events = [] + result = LoggingResult(events) + expected = ['startTest', 'setUp', 'test', + 'addSubTestFailure', 'addSubTestSuccess', + 'addSubTestFailure', 'addSubTestFailure', + 'addSubTestSuccess', 'addError', 'tearDown', 'stopTest'] + self._check_call_order__subtests(result, events, expected) + + def test_run_call_order__subtests_legacy(self): + # With a legacy result object (without an addSubTest method), + # text execution stops after the first subtest failure. + events = [] + result = LegacyLoggingResult(events) + expected = ['startTest', 'setUp', 'test', + 'addFailure', 'tearDown', 'stopTest'] + self._check_call_order__subtests(result, events, expected) + + def _check_call_order__subtests_success(self, result, events, expected_events): + class Foo(Test.LoggingTestCase): + def test(self): + super(Foo, self).test() + for i in [1, 2]: + with self.subTest(i=i): + for j in [2, 3]: + with self.subTest(j=j): + pass + + Foo(events).run(result) + self.assertEqual(events, expected_events) + + def test_run_call_order__subtests_success(self): + events = [] + result = LoggingResult(events) + # The 6 subtest successes are individually recorded, in addition + # to the whole test success. + expected = (['startTest', 'setUp', 'test'] + + 6 * ['addSubTestSuccess'] + + ['tearDown', 'addSuccess', 'stopTest']) + self._check_call_order__subtests_success(result, events, expected) + + def test_run_call_order__subtests_success_legacy(self): + # With a legacy result, only the whole test success is recorded. + events = [] + result = LegacyLoggingResult(events) + expected = ['startTest', 'setUp', 'test', 'tearDown', + 'addSuccess', 'stopTest'] + self._check_call_order__subtests_success(result, events, expected) + + def test_run_call_order__subtests_failfast(self): + events = [] + result = LoggingResult(events) + result.failfast = True + + class Foo(Test.LoggingTestCase): + def test(self): + super(Foo, self).test() + with self.subTest(i=1): + self.fail('failure') + with self.subTest(i=2): + self.fail('failure') + self.fail('failure') + + expected = ['startTest', 'setUp', 'test', + 'addSubTestFailure', 'tearDown', 'stopTest'] + Foo(events).run(result) + self.assertEqual(events, expected) + + def test_subtests_failfast(self): + # Ensure proper test flow with subtests and failfast (issue #22894) + events = [] + + class Foo(unittest.TestCase): + def test_a(self): + with self.subTest(): + events.append('a1') + events.append('a2') + + def test_b(self): + with self.subTest(): + events.append('b1') + with self.subTest(): + self.fail('failure') + events.append('b2') + + def test_c(self): + events.append('c') + + result = unittest.TestResult() + result.failfast = True + suite = unittest.TestLoader().loadTestsFromTestCase(Foo) + suite.run(result) + + expected = ['a1', 'a2', 'b1'] + self.assertEqual(events, expected) + + def test_subtests_debug(self): + # Test debug() with a test that uses subTest() (bpo-34900) + events = [] + + class Foo(unittest.TestCase): + def test_a(self): + events.append('test case') + with self.subTest(): + events.append('subtest 1') + + Foo('test_a').debug() + + self.assertEqual(events, ['test case', 'subtest 1']) + + # "This class attribute gives the exception raised by the test() method. + # If a test framework needs to use a specialized exception, possibly to + # carry additional information, it must subclass this exception in + # order to ``play fair'' with the framework. The initial value of this + # attribute is AssertionError" + def test_failureException__default(self): + class Foo(unittest.TestCase): + def test(self): + pass + + self.assertIs(Foo('test').failureException, AssertionError) + + # "This class attribute gives the exception raised by the test() method. + # If a test framework needs to use a specialized exception, possibly to + # carry additional information, it must subclass this exception in + # order to ``play fair'' with the framework." + # + # Make sure TestCase.run() respects the designated failureException + def test_failureException__subclassing__explicit_raise(self): + events = [] + result = LoggingResult(events) + + class Foo(unittest.TestCase): + def test(self): + raise RuntimeError() + + failureException = RuntimeError + + self.assertIs(Foo('test').failureException, RuntimeError) + + + Foo('test').run(result) + expected = ['startTest', 'addFailure', 'stopTest'] + self.assertEqual(events, expected) + + # "This class attribute gives the exception raised by the test() method. + # If a test framework needs to use a specialized exception, possibly to + # carry additional information, it must subclass this exception in + # order to ``play fair'' with the framework." + # + # Make sure TestCase.run() respects the designated failureException + def test_failureException__subclassing__implicit_raise(self): + events = [] + result = LoggingResult(events) + + class Foo(unittest.TestCase): + def test(self): + self.fail("foo") + + failureException = RuntimeError + + self.assertIs(Foo('test').failureException, RuntimeError) + + + Foo('test').run(result) + expected = ['startTest', 'addFailure', 'stopTest'] + self.assertEqual(events, expected) + + # "The default implementation does nothing." + def test_setUp(self): + class Foo(unittest.TestCase): + def runTest(self): + pass + + # ... and nothing should happen + Foo().setUp() + + # "The default implementation does nothing." + def test_tearDown(self): + class Foo(unittest.TestCase): + def runTest(self): + pass + + # ... and nothing should happen + Foo().tearDown() + + # "Return a string identifying the specific test case." + # + # Because of the vague nature of the docs, I'm not going to lock this + # test down too much. Really all that can be asserted is that the id() + # will be a string (either 8-byte or unicode -- again, because the docs + # just say "string") + def test_id(self): + class Foo(unittest.TestCase): + def runTest(self): + pass + + self.assertIsInstance(Foo().id(), str) + + + # "If result is omitted or None, a temporary result object is created, + # used, and is made available to the caller. As TestCase owns the + # temporary result startTestRun and stopTestRun are called. + + def test_run__uses_defaultTestResult(self): + events = [] + defaultResult = LoggingResult(events) + + class Foo(unittest.TestCase): + def test(self): + events.append('test') + + def defaultTestResult(self): + return defaultResult + + # Make run() find a result object on its own + result = Foo('test').run() + + self.assertIs(result, defaultResult) + expected = ['startTestRun', 'startTest', 'test', 'addSuccess', + 'stopTest', 'stopTestRun'] + self.assertEqual(events, expected) + + + # "The result object is returned to run's caller" + def test_run__returns_given_result(self): + + class Foo(unittest.TestCase): + def test(self): + pass + + result = unittest.TestResult() + + retval = Foo('test').run(result) + self.assertIs(retval, result) + + + # "The same effect [as method run] may be had by simply calling the + # TestCase instance." + def test_call__invoking_an_instance_delegates_to_run(self): + resultIn = unittest.TestResult() + resultOut = unittest.TestResult() + + class Foo(unittest.TestCase): + def test(self): + pass + + def run(self, result): + self.assertIs(result, resultIn) + return resultOut + + retval = Foo('test')(resultIn) + + self.assertIs(retval, resultOut) + + + def testShortDescriptionWithoutDocstring(self): + self.assertIsNone(self.shortDescription()) + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def testShortDescriptionWithOneLineDocstring(self): + """Tests shortDescription() for a method with a docstring.""" + self.assertEqual( + self.shortDescription(), + 'Tests shortDescription() for a method with a docstring.') + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def testShortDescriptionWithMultiLineDocstring(self): + """Tests shortDescription() for a method with a longer docstring. + + This method ensures that only the first line of a docstring is + returned used in the short description, no matter how long the + whole thing is. + """ + self.assertEqual( + self.shortDescription(), + 'Tests shortDescription() for a method with a longer ' + 'docstring.') + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def testShortDescriptionWhitespaceTrimming(self): + """ + Tests shortDescription() whitespace is trimmed, so that the first + line of nonwhite-space text becomes the docstring. + """ + self.assertEqual( + self.shortDescription(), + 'Tests shortDescription() whitespace is trimmed, so that the first') + + def testAddTypeEqualityFunc(self): + class SadSnake(object): + """Dummy class for test_addTypeEqualityFunc.""" + s1, s2 = SadSnake(), SadSnake() + self.assertFalse(s1 == s2) + def AllSnakesCreatedEqual(a, b, msg=None): + return type(a) == type(b) == SadSnake + self.addTypeEqualityFunc(SadSnake, AllSnakesCreatedEqual) + self.assertEqual(s1, s2) + # No this doesn't clean up and remove the SadSnake equality func + # from this TestCase instance but since it's local nothing else + # will ever notice that. + + def testAssertIs(self): + thing = object() + self.assertIs(thing, thing) + self.assertRaises(self.failureException, self.assertIs, thing, object()) + + def testAssertIsNot(self): + thing = object() + self.assertIsNot(thing, object()) + self.assertRaises(self.failureException, self.assertIsNot, thing, thing) + + def testAssertIsInstance(self): + thing = [] + self.assertIsInstance(thing, list) + self.assertRaises(self.failureException, self.assertIsInstance, + thing, dict) + + def testAssertNotIsInstance(self): + thing = [] + self.assertNotIsInstance(thing, dict) + self.assertRaises(self.failureException, self.assertNotIsInstance, + thing, list) + + def testAssertIn(self): + animals = {'monkey': 'banana', 'cow': 'grass', 'seal': 'fish'} + + self.assertIn('a', 'abc') + self.assertIn(2, [1, 2, 3]) + self.assertIn('monkey', animals) + + self.assertNotIn('d', 'abc') + self.assertNotIn(0, [1, 2, 3]) + self.assertNotIn('otter', animals) + + self.assertRaises(self.failureException, self.assertIn, 'x', 'abc') + self.assertRaises(self.failureException, self.assertIn, 4, [1, 2, 3]) + self.assertRaises(self.failureException, self.assertIn, 'elephant', + animals) + + self.assertRaises(self.failureException, self.assertNotIn, 'c', 'abc') + self.assertRaises(self.failureException, self.assertNotIn, 1, [1, 2, 3]) + self.assertRaises(self.failureException, self.assertNotIn, 'cow', + animals) + + def testAssertDictContainsSubset(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + self.assertDictContainsSubset({}, {}) + self.assertDictContainsSubset({}, {'a': 1}) + self.assertDictContainsSubset({'a': 1}, {'a': 1}) + self.assertDictContainsSubset({'a': 1}, {'a': 1, 'b': 2}) + self.assertDictContainsSubset({'a': 1, 'b': 2}, {'a': 1, 'b': 2}) + + with self.assertRaises(self.failureException): + self.assertDictContainsSubset({1: "one"}, {}) + + with self.assertRaises(self.failureException): + self.assertDictContainsSubset({'a': 2}, {'a': 1}) + + with self.assertRaises(self.failureException): + self.assertDictContainsSubset({'c': 1}, {'a': 1}) + + with self.assertRaises(self.failureException): + self.assertDictContainsSubset({'a': 1, 'c': 1}, {'a': 1}) + + with self.assertRaises(self.failureException): + self.assertDictContainsSubset({'a': 1, 'c': 1}, {'a': 1}) + + one = ''.join(chr(i) for i in range(255)) + # this used to cause a UnicodeDecodeError constructing the failure msg + with self.assertRaises(self.failureException): + self.assertDictContainsSubset({'foo': one}, {'foo': '\uFFFD'}) + + def testAssertEqual(self): + equal_pairs = [ + ((), ()), + ({}, {}), + ([], []), + (set(), set()), + (frozenset(), frozenset())] + for a, b in equal_pairs: + # This mess of try excepts is to test the assertEqual behavior + # itself. + try: + self.assertEqual(a, b) + except self.failureException: + self.fail('assertEqual(%r, %r) failed' % (a, b)) + try: + self.assertEqual(a, b, msg='foo') + except self.failureException: + self.fail('assertEqual(%r, %r) with msg= failed' % (a, b)) + try: + self.assertEqual(a, b, 'foo') + except self.failureException: + self.fail('assertEqual(%r, %r) with third parameter failed' % + (a, b)) + + unequal_pairs = [ + ((), []), + ({}, set()), + (set([4,1]), frozenset([4,2])), + (frozenset([4,5]), set([2,3])), + (set([3,4]), set([5,4]))] + for a, b in unequal_pairs: + self.assertRaises(self.failureException, self.assertEqual, a, b) + self.assertRaises(self.failureException, self.assertEqual, a, b, + 'foo') + self.assertRaises(self.failureException, self.assertEqual, a, b, + msg='foo') + + def testEquality(self): + self.assertListEqual([], []) + self.assertTupleEqual((), ()) + self.assertSequenceEqual([], ()) + + a = [0, 'a', []] + b = [] + self.assertRaises(unittest.TestCase.failureException, + self.assertListEqual, a, b) + self.assertRaises(unittest.TestCase.failureException, + self.assertListEqual, tuple(a), tuple(b)) + self.assertRaises(unittest.TestCase.failureException, + self.assertSequenceEqual, a, tuple(b)) + + b.extend(a) + self.assertListEqual(a, b) + self.assertTupleEqual(tuple(a), tuple(b)) + self.assertSequenceEqual(a, tuple(b)) + self.assertSequenceEqual(tuple(a), b) + + self.assertRaises(self.failureException, self.assertListEqual, + a, tuple(b)) + self.assertRaises(self.failureException, self.assertTupleEqual, + tuple(a), b) + self.assertRaises(self.failureException, self.assertListEqual, None, b) + self.assertRaises(self.failureException, self.assertTupleEqual, None, + tuple(b)) + self.assertRaises(self.failureException, self.assertSequenceEqual, + None, tuple(b)) + self.assertRaises(self.failureException, self.assertListEqual, 1, 1) + self.assertRaises(self.failureException, self.assertTupleEqual, 1, 1) + self.assertRaises(self.failureException, self.assertSequenceEqual, + 1, 1) + + self.assertDictEqual({}, {}) + + c = { 'x': 1 } + d = {} + self.assertRaises(unittest.TestCase.failureException, + self.assertDictEqual, c, d) + + d.update(c) + self.assertDictEqual(c, d) + + d['x'] = 0 + self.assertRaises(unittest.TestCase.failureException, + self.assertDictEqual, c, d, 'These are unequal') + + self.assertRaises(self.failureException, self.assertDictEqual, None, d) + self.assertRaises(self.failureException, self.assertDictEqual, [], d) + self.assertRaises(self.failureException, self.assertDictEqual, 1, 1) + + def testAssertSequenceEqualMaxDiff(self): + self.assertEqual(self.maxDiff, 80*8) + seq1 = 'a' + 'x' * 80**2 + seq2 = 'b' + 'x' * 80**2 + diff = '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(), + pprint.pformat(seq2).splitlines())) + # the +1 is the leading \n added by assertSequenceEqual + omitted = unittest.case.DIFF_OMITTED % (len(diff) + 1,) + + self.maxDiff = len(diff)//2 + try: + + self.assertSequenceEqual(seq1, seq2) + except self.failureException as e: + msg = e.args[0] + else: + self.fail('assertSequenceEqual did not fail.') + self.assertLess(len(msg), len(diff)) + self.assertIn(omitted, msg) + + self.maxDiff = len(diff) * 2 + try: + self.assertSequenceEqual(seq1, seq2) + except self.failureException as e: + msg = e.args[0] + else: + self.fail('assertSequenceEqual did not fail.') + self.assertGreater(len(msg), len(diff)) + self.assertNotIn(omitted, msg) + + self.maxDiff = None + try: + self.assertSequenceEqual(seq1, seq2) + except self.failureException as e: + msg = e.args[0] + else: + self.fail('assertSequenceEqual did not fail.') + self.assertGreater(len(msg), len(diff)) + self.assertNotIn(omitted, msg) + + def testTruncateMessage(self): + self.maxDiff = 1 + message = self._truncateMessage('foo', 'bar') + omitted = unittest.case.DIFF_OMITTED % len('bar') + self.assertEqual(message, 'foo' + omitted) + + self.maxDiff = None + message = self._truncateMessage('foo', 'bar') + self.assertEqual(message, 'foobar') + + self.maxDiff = 4 + message = self._truncateMessage('foo', 'bar') + self.assertEqual(message, 'foobar') + + def testAssertDictEqualTruncates(self): + test = unittest.TestCase('assertEqual') + def truncate(msg, diff): + return 'foo' + test._truncateMessage = truncate + try: + test.assertDictEqual({}, {1: 0}) + except self.failureException as e: + self.assertEqual(str(e), 'foo') + else: + self.fail('assertDictEqual did not fail') + + def testAssertMultiLineEqualTruncates(self): + test = unittest.TestCase('assertEqual') + def truncate(msg, diff): + return 'foo' + test._truncateMessage = truncate + try: + test.assertMultiLineEqual('foo', 'bar') + except self.failureException as e: + self.assertEqual(str(e), 'foo') + else: + self.fail('assertMultiLineEqual did not fail') + + def testAssertEqual_diffThreshold(self): + # check threshold value + self.assertEqual(self._diffThreshold, 2**16) + # disable madDiff to get diff markers + self.maxDiff = None + + # set a lower threshold value and add a cleanup to restore it + old_threshold = self._diffThreshold + self._diffThreshold = 2**5 + self.addCleanup(lambda: setattr(self, '_diffThreshold', old_threshold)) + + # under the threshold: diff marker (^) in error message + s = 'x' * (2**4) + with self.assertRaises(self.failureException) as cm: + self.assertEqual(s + 'a', s + 'b') + self.assertIn('^', str(cm.exception)) + self.assertEqual(s + 'a', s + 'a') + + # over the threshold: diff not used and marker (^) not in error message + s = 'x' * (2**6) + # if the path that uses difflib is taken, _truncateMessage will be + # called -- replace it with explodingTruncation to verify that this + # doesn't happen + def explodingTruncation(message, diff): + raise SystemError('this should not be raised') + old_truncate = self._truncateMessage + self._truncateMessage = explodingTruncation + self.addCleanup(lambda: setattr(self, '_truncateMessage', old_truncate)) + + s1, s2 = s + 'a', s + 'b' + with self.assertRaises(self.failureException) as cm: + self.assertEqual(s1, s2) + self.assertNotIn('^', str(cm.exception)) + self.assertEqual(str(cm.exception), '%r != %r' % (s1, s2)) + self.assertEqual(s + 'a', s + 'a') + + def testAssertEqual_shorten(self): + # set a lower threshold value and add a cleanup to restore it + old_threshold = self._diffThreshold + self._diffThreshold = 0 + self.addCleanup(lambda: setattr(self, '_diffThreshold', old_threshold)) + + s = 'x' * 100 + s1, s2 = s + 'a', s + 'b' + with self.assertRaises(self.failureException) as cm: + self.assertEqual(s1, s2) + c = 'xxxx[35 chars]' + 'x' * 61 + self.assertEqual(str(cm.exception), "'%sa' != '%sb'" % (c, c)) + self.assertEqual(s + 'a', s + 'a') + + p = 'y' * 50 + s1, s2 = s + 'a' + p, s + 'b' + p + with self.assertRaises(self.failureException) as cm: + self.assertEqual(s1, s2) + c = 'xxxx[85 chars]xxxxxxxxxxx' + self.assertEqual(str(cm.exception), "'%sa%s' != '%sb%s'" % (c, p, c, p)) + + p = 'y' * 100 + s1, s2 = s + 'a' + p, s + 'b' + p + with self.assertRaises(self.failureException) as cm: + self.assertEqual(s1, s2) + c = 'xxxx[91 chars]xxxxx' + d = 'y' * 40 + '[56 chars]yyyy' + self.assertEqual(str(cm.exception), "'%sa%s' != '%sb%s'" % (c, d, c, d)) + + def testAssertCountEqual(self): + a = object() + self.assertCountEqual([1, 2, 3], [3, 2, 1]) + self.assertCountEqual(['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']) + self.assertCountEqual([a, a, 2, 2, 3], (a, 2, 3, a, 2)) + self.assertCountEqual([1, "2", "a", "a"], ["a", "2", True, "a"]) + self.assertRaises(self.failureException, self.assertCountEqual, + [1, 2] + [3] * 100, [1] * 100 + [2, 3]) + self.assertRaises(self.failureException, self.assertCountEqual, + [1, "2", "a", "a"], ["a", "2", True, 1]) + self.assertRaises(self.failureException, self.assertCountEqual, + [10], [10, 11]) + self.assertRaises(self.failureException, self.assertCountEqual, + [10, 11], [10]) + self.assertRaises(self.failureException, self.assertCountEqual, + [10, 11, 10], [10, 11]) + + # Test that sequences of unhashable objects can be tested for sameness: + self.assertCountEqual([[1, 2], [3, 4], 0], [False, [3, 4], [1, 2]]) + # Test that iterator of unhashable objects can be tested for sameness: + self.assertCountEqual(iter([1, 2, [], 3, 4]), + iter([1, 2, [], 3, 4])) + + # hashable types, but not orderable + self.assertRaises(self.failureException, self.assertCountEqual, + [], [divmod, 'x', 1, 5j, 2j, frozenset()]) + # comparing dicts + self.assertCountEqual([{'a': 1}, {'b': 2}], [{'b': 2}, {'a': 1}]) + # comparing heterogeneous non-hashable sequences + self.assertCountEqual([1, 'x', divmod, []], [divmod, [], 'x', 1]) + self.assertRaises(self.failureException, self.assertCountEqual, + [], [divmod, [], 'x', 1, 5j, 2j, set()]) + self.assertRaises(self.failureException, self.assertCountEqual, + [[1]], [[2]]) + + # Same elements, but not same sequence length + self.assertRaises(self.failureException, self.assertCountEqual, + [1, 1, 2], [2, 1]) + self.assertRaises(self.failureException, self.assertCountEqual, + [1, 1, "2", "a", "a"], ["2", "2", True, "a"]) + self.assertRaises(self.failureException, self.assertCountEqual, + [1, {'b': 2}, None, True], [{'b': 2}, True, None]) + + # Same elements which don't reliably compare, in + # different order, see issue 10242 + a = [{2,4}, {1,2}] + b = a[::-1] + self.assertCountEqual(a, b) + + # test utility functions supporting assertCountEqual() + + diffs = set(unittest.util._count_diff_all_purpose('aaabccd', 'abbbcce')) + expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')} + self.assertEqual(diffs, expected) + + diffs = unittest.util._count_diff_all_purpose([[]], []) + self.assertEqual(diffs, [(1, 0, [])]) + + diffs = set(unittest.util._count_diff_hashable('aaabccd', 'abbbcce')) + expected = {(3,1,'a'), (1,3,'b'), (1,0,'d'), (0,1,'e')} + self.assertEqual(diffs, expected) + + def testAssertSetEqual(self): + set1 = set() + set2 = set() + self.assertSetEqual(set1, set2) + + self.assertRaises(self.failureException, self.assertSetEqual, None, set2) + self.assertRaises(self.failureException, self.assertSetEqual, [], set2) + self.assertRaises(self.failureException, self.assertSetEqual, set1, None) + self.assertRaises(self.failureException, self.assertSetEqual, set1, []) + + set1 = set(['a']) + set2 = set() + self.assertRaises(self.failureException, self.assertSetEqual, set1, set2) + + set1 = set(['a']) + set2 = set(['a']) + self.assertSetEqual(set1, set2) + + set1 = set(['a']) + set2 = set(['a', 'b']) + self.assertRaises(self.failureException, self.assertSetEqual, set1, set2) + + set1 = set(['a']) + set2 = frozenset(['a', 'b']) + self.assertRaises(self.failureException, self.assertSetEqual, set1, set2) + + set1 = set(['a', 'b']) + set2 = frozenset(['a', 'b']) + self.assertSetEqual(set1, set2) + + set1 = set() + set2 = "foo" + self.assertRaises(self.failureException, self.assertSetEqual, set1, set2) + self.assertRaises(self.failureException, self.assertSetEqual, set2, set1) + + # make sure any string formatting is tuple-safe + set1 = set([(0, 1), (2, 3)]) + set2 = set([(4, 5)]) + self.assertRaises(self.failureException, self.assertSetEqual, set1, set2) + + def testInequality(self): + # Try ints + self.assertGreater(2, 1) + self.assertGreaterEqual(2, 1) + self.assertGreaterEqual(1, 1) + self.assertLess(1, 2) + self.assertLessEqual(1, 2) + self.assertLessEqual(1, 1) + self.assertRaises(self.failureException, self.assertGreater, 1, 2) + self.assertRaises(self.failureException, self.assertGreater, 1, 1) + self.assertRaises(self.failureException, self.assertGreaterEqual, 1, 2) + self.assertRaises(self.failureException, self.assertLess, 2, 1) + self.assertRaises(self.failureException, self.assertLess, 1, 1) + self.assertRaises(self.failureException, self.assertLessEqual, 2, 1) + + # Try Floats + self.assertGreater(1.1, 1.0) + self.assertGreaterEqual(1.1, 1.0) + self.assertGreaterEqual(1.0, 1.0) + self.assertLess(1.0, 1.1) + self.assertLessEqual(1.0, 1.1) + self.assertLessEqual(1.0, 1.0) + self.assertRaises(self.failureException, self.assertGreater, 1.0, 1.1) + self.assertRaises(self.failureException, self.assertGreater, 1.0, 1.0) + self.assertRaises(self.failureException, self.assertGreaterEqual, 1.0, 1.1) + self.assertRaises(self.failureException, self.assertLess, 1.1, 1.0) + self.assertRaises(self.failureException, self.assertLess, 1.0, 1.0) + self.assertRaises(self.failureException, self.assertLessEqual, 1.1, 1.0) + + # Try Strings + self.assertGreater('bug', 'ant') + self.assertGreaterEqual('bug', 'ant') + self.assertGreaterEqual('ant', 'ant') + self.assertLess('ant', 'bug') + self.assertLessEqual('ant', 'bug') + self.assertLessEqual('ant', 'ant') + self.assertRaises(self.failureException, self.assertGreater, 'ant', 'bug') + self.assertRaises(self.failureException, self.assertGreater, 'ant', 'ant') + self.assertRaises(self.failureException, self.assertGreaterEqual, 'ant', 'bug') + self.assertRaises(self.failureException, self.assertLess, 'bug', 'ant') + self.assertRaises(self.failureException, self.assertLess, 'ant', 'ant') + self.assertRaises(self.failureException, self.assertLessEqual, 'bug', 'ant') + + # Try bytes + self.assertGreater(b'bug', b'ant') + self.assertGreaterEqual(b'bug', b'ant') + self.assertGreaterEqual(b'ant', b'ant') + self.assertLess(b'ant', b'bug') + self.assertLessEqual(b'ant', b'bug') + self.assertLessEqual(b'ant', b'ant') + self.assertRaises(self.failureException, self.assertGreater, b'ant', b'bug') + self.assertRaises(self.failureException, self.assertGreater, b'ant', b'ant') + self.assertRaises(self.failureException, self.assertGreaterEqual, b'ant', + b'bug') + self.assertRaises(self.failureException, self.assertLess, b'bug', b'ant') + self.assertRaises(self.failureException, self.assertLess, b'ant', b'ant') + self.assertRaises(self.failureException, self.assertLessEqual, b'bug', b'ant') + + def testAssertMultiLineEqual(self): + sample_text = """\ +http://www.python.org/doc/2.3/lib/module-unittest.html +test case + A test case is the smallest unit of testing. [...] +""" + revised_sample_text = """\ +http://www.python.org/doc/2.4.1/lib/module-unittest.html +test case + A test case is the smallest unit of testing. [...] You may provide your + own implementation that does not subclass from TestCase, of course. +""" + sample_text_error = """\ +- http://www.python.org/doc/2.3/lib/module-unittest.html +? ^ ++ http://www.python.org/doc/2.4.1/lib/module-unittest.html +? ^^^ + test case +- A test case is the smallest unit of testing. [...] ++ A test case is the smallest unit of testing. [...] You may provide your +? +++++++++++++++++++++ ++ own implementation that does not subclass from TestCase, of course. +""" + self.maxDiff = None + try: + self.assertMultiLineEqual(sample_text, revised_sample_text) + except self.failureException as e: + # need to remove the first line of the error message + error = str(e).split('\n', 1)[1] + self.assertEqual(sample_text_error, error) + + def testAssertEqualSingleLine(self): + sample_text = "laden swallows fly slowly" + revised_sample_text = "unladen swallows fly quickly" + sample_text_error = """\ +- laden swallows fly slowly +? ^^^^ ++ unladen swallows fly quickly +? ++ ^^^^^ +""" + try: + self.assertEqual(sample_text, revised_sample_text) + except self.failureException as e: + # need to remove the first line of the error message + error = str(e).split('\n', 1)[1] + self.assertEqual(sample_text_error, error) + + def testEqualityBytesWarning(self): + if sys.flags.bytes_warning: + def bytes_warning(): + return self.assertWarnsRegex(BytesWarning, + 'Comparison between bytes and string') + else: + def bytes_warning(): + return contextlib.ExitStack() + + with bytes_warning(), self.assertRaises(self.failureException): + self.assertEqual('a', b'a') + with bytes_warning(): + self.assertNotEqual('a', b'a') + + a = [0, 'a'] + b = [0, b'a'] + with bytes_warning(), self.assertRaises(self.failureException): + self.assertListEqual(a, b) + with bytes_warning(), self.assertRaises(self.failureException): + self.assertTupleEqual(tuple(a), tuple(b)) + with bytes_warning(), self.assertRaises(self.failureException): + self.assertSequenceEqual(a, tuple(b)) + with bytes_warning(), self.assertRaises(self.failureException): + self.assertSequenceEqual(tuple(a), b) + with bytes_warning(), self.assertRaises(self.failureException): + self.assertSequenceEqual('a', b'a') + with bytes_warning(), self.assertRaises(self.failureException): + self.assertSetEqual(set(a), set(b)) + + with self.assertRaises(self.failureException): + self.assertListEqual(a, tuple(b)) + with self.assertRaises(self.failureException): + self.assertTupleEqual(tuple(a), b) + + a = [0, b'a'] + b = [0] + with self.assertRaises(self.failureException): + self.assertListEqual(a, b) + with self.assertRaises(self.failureException): + self.assertTupleEqual(tuple(a), tuple(b)) + with self.assertRaises(self.failureException): + self.assertSequenceEqual(a, tuple(b)) + with self.assertRaises(self.failureException): + self.assertSequenceEqual(tuple(a), b) + with self.assertRaises(self.failureException): + self.assertSetEqual(set(a), set(b)) + + a = [0] + b = [0, b'a'] + with self.assertRaises(self.failureException): + self.assertListEqual(a, b) + with self.assertRaises(self.failureException): + self.assertTupleEqual(tuple(a), tuple(b)) + with self.assertRaises(self.failureException): + self.assertSequenceEqual(a, tuple(b)) + with self.assertRaises(self.failureException): + self.assertSequenceEqual(tuple(a), b) + with self.assertRaises(self.failureException): + self.assertSetEqual(set(a), set(b)) + + with bytes_warning(), self.assertRaises(self.failureException): + self.assertDictEqual({'a': 0}, {b'a': 0}) + with self.assertRaises(self.failureException): + self.assertDictEqual({}, {b'a': 0}) + with self.assertRaises(self.failureException): + self.assertDictEqual({b'a': 0}, {}) + + with self.assertRaises(self.failureException): + self.assertCountEqual([b'a', b'a'], [b'a', b'a', b'a']) + with bytes_warning(): + self.assertCountEqual(['a', b'a'], ['a', b'a']) + with bytes_warning(), self.assertRaises(self.failureException): + self.assertCountEqual(['a', 'a'], [b'a', b'a']) + with bytes_warning(), self.assertRaises(self.failureException): + self.assertCountEqual(['a', 'a', []], [b'a', b'a', []]) + + def testAssertIsNone(self): + self.assertIsNone(None) + self.assertRaises(self.failureException, self.assertIsNone, False) + self.assertIsNotNone('DjZoPloGears on Rails') + self.assertRaises(self.failureException, self.assertIsNotNone, None) + + def testAssertRegex(self): + self.assertRegex('asdfabasdf', r'ab+') + self.assertRaises(self.failureException, self.assertRegex, + 'saaas', r'aaaa') + + def testAssertRaisesCallable(self): + class ExceptionMock(Exception): + pass + def Stub(): + raise ExceptionMock('We expect') + self.assertRaises(ExceptionMock, Stub) + # A tuple of exception classes is accepted + self.assertRaises((ValueError, ExceptionMock), Stub) + # *args and **kwargs also work + self.assertRaises(ValueError, int, '19', base=8) + # Failure when no exception is raised + with self.assertRaises(self.failureException): + self.assertRaises(ExceptionMock, lambda: 0) + # Failure when the function is None + with self.assertRaises(TypeError): + self.assertRaises(ExceptionMock, None) + # Failure when another exception is raised + with self.assertRaises(ExceptionMock): + self.assertRaises(ValueError, Stub) + + def testAssertRaisesContext(self): + class ExceptionMock(Exception): + pass + def Stub(): + raise ExceptionMock('We expect') + with self.assertRaises(ExceptionMock): + Stub() + # A tuple of exception classes is accepted + with self.assertRaises((ValueError, ExceptionMock)) as cm: + Stub() + # The context manager exposes caught exception + self.assertIsInstance(cm.exception, ExceptionMock) + self.assertEqual(cm.exception.args[0], 'We expect') + # *args and **kwargs also work + with self.assertRaises(ValueError): + int('19', base=8) + # Failure when no exception is raised + with self.assertRaises(self.failureException): + with self.assertRaises(ExceptionMock): + pass + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertRaises(ExceptionMock, msg='foobar'): + pass + # Invalid keyword argument + with self.assertRaisesRegex(TypeError, 'foobar'): + with self.assertRaises(ExceptionMock, foobar=42): + pass + # Failure when another exception is raised + with self.assertRaises(ExceptionMock): + self.assertRaises(ValueError, Stub) + + def testAssertRaisesNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertRaises() + with self.assertRaises(TypeError): + self.assertRaises(1) + with self.assertRaises(TypeError): + self.assertRaises(object) + with self.assertRaises(TypeError): + self.assertRaises((ValueError, 1)) + with self.assertRaises(TypeError): + self.assertRaises((ValueError, object)) + + def testAssertRaisesRefcount(self): + # bpo-23890: assertRaises() must not keep objects alive longer + # than expected + def func() : + try: + raise ValueError + except ValueError: + raise ValueError + + refcount = sys.getrefcount(func) + self.assertRaises(ValueError, func) + self.assertEqual(refcount, sys.getrefcount(func)) + + def testAssertRaisesRegex(self): + class ExceptionMock(Exception): + pass + + def Stub(): + raise ExceptionMock('We expect') + + self.assertRaisesRegex(ExceptionMock, re.compile('expect$'), Stub) + self.assertRaisesRegex(ExceptionMock, 'expect$', Stub) + with self.assertRaises(TypeError): + self.assertRaisesRegex(ExceptionMock, 'expect$', None) + + def testAssertNotRaisesRegex(self): + self.assertRaisesRegex( + self.failureException, '^Exception not raised by <lambda>$', + self.assertRaisesRegex, Exception, re.compile('x'), + lambda: None) + self.assertRaisesRegex( + self.failureException, '^Exception not raised by <lambda>$', + self.assertRaisesRegex, Exception, 'x', + lambda: None) + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertRaisesRegex(Exception, 'expect', msg='foobar'): + pass + # Invalid keyword argument + with self.assertRaisesRegex(TypeError, 'foobar'): + with self.assertRaisesRegex(Exception, 'expect', foobar=42): + pass + + def testAssertRaisesRegexInvalidRegex(self): + # Issue 20145. + class MyExc(Exception): + pass + self.assertRaises(TypeError, self.assertRaisesRegex, MyExc, lambda: True) + + def testAssertWarnsRegexInvalidRegex(self): + # Issue 20145. + class MyWarn(Warning): + pass + self.assertRaises(TypeError, self.assertWarnsRegex, MyWarn, lambda: True) + + def testAssertWarnsModifySysModules(self): + # bpo-29620: handle modified sys.modules during iteration + class Foo(types.ModuleType): + @property + def __warningregistry__(self): + sys.modules['@bar@'] = 'bar' + + sys.modules['@foo@'] = Foo('foo') + try: + self.assertWarns(UserWarning, warnings.warn, 'expected') + finally: + del sys.modules['@foo@'] + del sys.modules['@bar@'] + + def testAssertRaisesRegexMismatch(self): + def Stub(): + raise Exception('Unexpected') + + self.assertRaisesRegex( + self.failureException, + r'"\^Expected\$" does not match "Unexpected"', + self.assertRaisesRegex, Exception, '^Expected$', + Stub) + self.assertRaisesRegex( + self.failureException, + r'"\^Expected\$" does not match "Unexpected"', + self.assertRaisesRegex, Exception, + re.compile('^Expected$'), Stub) + + def testAssertRaisesExcValue(self): + class ExceptionMock(Exception): + pass + + def Stub(foo): + raise ExceptionMock(foo) + v = "particular value" + + ctx = self.assertRaises(ExceptionMock) + with ctx: + Stub(v) + e = ctx.exception + self.assertIsInstance(e, ExceptionMock) + self.assertEqual(e.args[0], v) + + def testAssertRaisesRegexNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertRaisesRegex() + with self.assertRaises(TypeError): + self.assertRaisesRegex(ValueError) + with self.assertRaises(TypeError): + self.assertRaisesRegex(1, 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex(object, 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex((ValueError, 1), 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex((ValueError, object), 'expect') + + def testAssertWarnsCallable(self): + def _runtime_warn(): + warnings.warn("foo", RuntimeWarning) + # Success when the right warning is triggered, even several times + self.assertWarns(RuntimeWarning, _runtime_warn) + self.assertWarns(RuntimeWarning, _runtime_warn) + # A tuple of warning classes is accepted + self.assertWarns((DeprecationWarning, RuntimeWarning), _runtime_warn) + # *args and **kwargs also work + self.assertWarns(RuntimeWarning, + warnings.warn, "foo", category=RuntimeWarning) + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + self.assertWarns(RuntimeWarning, lambda: 0) + # Failure when the function is None + with self.assertRaises(TypeError): + self.assertWarns(RuntimeWarning, None) + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + self.assertWarns(DeprecationWarning, _runtime_warn) + # Filters for other warnings are not modified + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises(RuntimeWarning): + self.assertWarns(DeprecationWarning, _runtime_warn) + + def testAssertWarnsContext(self): + # Believe it or not, it is preferable to duplicate all tests above, + # to make sure the __warningregistry__ $@ is circumvented correctly. + def _runtime_warn(): + warnings.warn("foo", RuntimeWarning) + _runtime_warn_lineno = inspect.getsourcelines(_runtime_warn)[1] + with self.assertWarns(RuntimeWarning) as cm: + _runtime_warn() + # A tuple of warning classes is accepted + with self.assertWarns((DeprecationWarning, RuntimeWarning)) as cm: + _runtime_warn() + # The context manager exposes various useful attributes + self.assertIsInstance(cm.warning, RuntimeWarning) + self.assertEqual(cm.warning.args[0], "foo") + self.assertIn("test_case.py", cm.filename) + self.assertEqual(cm.lineno, _runtime_warn_lineno + 1) + # Same with several warnings + with self.assertWarns(RuntimeWarning): + _runtime_warn() + _runtime_warn() + with self.assertWarns(RuntimeWarning): + warnings.warn("foo", category=RuntimeWarning) + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + with self.assertWarns(RuntimeWarning): + pass + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertWarns(RuntimeWarning, msg='foobar'): + pass + # Invalid keyword argument + with self.assertRaisesRegex(TypeError, 'foobar'): + with self.assertWarns(RuntimeWarning, foobar=42): + pass + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + with self.assertWarns(DeprecationWarning): + _runtime_warn() + # Filters for other warnings are not modified + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises(RuntimeWarning): + with self.assertWarns(DeprecationWarning): + _runtime_warn() + + def testAssertWarnsNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertWarns() + with self.assertRaises(TypeError): + self.assertWarns(1) + with self.assertRaises(TypeError): + self.assertWarns(object) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, 1)) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, object)) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, Exception)) + + def testAssertWarnsRegexCallable(self): + def _runtime_warn(msg): + warnings.warn(msg, RuntimeWarning) + self.assertWarnsRegex(RuntimeWarning, "o+", + _runtime_warn, "foox") + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + self.assertWarnsRegex(RuntimeWarning, "o+", + lambda: 0) + # Failure when the function is None + with self.assertRaises(TypeError): + self.assertWarnsRegex(RuntimeWarning, "o+", None) + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + self.assertWarnsRegex(DeprecationWarning, "o+", + _runtime_warn, "foox") + # Failure when message doesn't match + with self.assertRaises(self.failureException): + self.assertWarnsRegex(RuntimeWarning, "o+", + _runtime_warn, "barz") + # A little trickier: we ask RuntimeWarnings to be raised, and then + # check for some of them. It is implementation-defined whether + # non-matching RuntimeWarnings are simply re-raised, or produce a + # failureException. + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises((RuntimeWarning, self.failureException)): + self.assertWarnsRegex(RuntimeWarning, "o+", + _runtime_warn, "barz") + + def testAssertWarnsRegexContext(self): + # Same as above, but with assertWarnsRegex as a context manager + def _runtime_warn(msg): + warnings.warn(msg, RuntimeWarning) + _runtime_warn_lineno = inspect.getsourcelines(_runtime_warn)[1] + with self.assertWarnsRegex(RuntimeWarning, "o+") as cm: + _runtime_warn("foox") + self.assertIsInstance(cm.warning, RuntimeWarning) + self.assertEqual(cm.warning.args[0], "foox") + self.assertIn("test_case.py", cm.filename) + self.assertEqual(cm.lineno, _runtime_warn_lineno + 1) + # Failure when no warning is triggered + with self.assertRaises(self.failureException): + with self.assertWarnsRegex(RuntimeWarning, "o+"): + pass + # Custom message + with self.assertRaisesRegex(self.failureException, 'foobar'): + with self.assertWarnsRegex(RuntimeWarning, 'o+', msg='foobar'): + pass + # Invalid keyword argument + with self.assertRaisesRegex(TypeError, 'foobar'): + with self.assertWarnsRegex(RuntimeWarning, 'o+', foobar=42): + pass + # Failure when another warning is triggered + with warnings.catch_warnings(): + # Force default filter (in case tests are run with -We) + warnings.simplefilter("default", RuntimeWarning) + with self.assertRaises(self.failureException): + with self.assertWarnsRegex(DeprecationWarning, "o+"): + _runtime_warn("foox") + # Failure when message doesn't match + with self.assertRaises(self.failureException): + with self.assertWarnsRegex(RuntimeWarning, "o+"): + _runtime_warn("barz") + # A little trickier: we ask RuntimeWarnings to be raised, and then + # check for some of them. It is implementation-defined whether + # non-matching RuntimeWarnings are simply re-raised, or produce a + # failureException. + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + with self.assertRaises((RuntimeWarning, self.failureException)): + with self.assertWarnsRegex(RuntimeWarning, "o+"): + _runtime_warn("barz") + + def testAssertWarnsRegexNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertWarnsRegex() + with self.assertRaises(TypeError): + self.assertWarnsRegex(UserWarning) + with self.assertRaises(TypeError): + self.assertWarnsRegex(1, 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex(object, 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, 1), 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, object), 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, Exception), 'expect') + + @contextlib.contextmanager + def assertNoStderr(self): + with captured_stderr() as buf: + yield + self.assertEqual(buf.getvalue(), "") + + def assertLogRecords(self, records, matches): + self.assertEqual(len(records), len(matches)) + for rec, match in zip(records, matches): + self.assertIsInstance(rec, logging.LogRecord) + for k, v in match.items(): + self.assertEqual(getattr(rec, k), v) + + def testAssertLogsDefaults(self): + # defaults: root logger, level INFO + with self.assertNoStderr(): + with self.assertLogs() as cm: + log_foo.info("1") + log_foobar.debug("2") + self.assertEqual(cm.output, ["INFO:foo:1"]) + self.assertLogRecords(cm.records, [{'name': 'foo'}]) + + def testAssertLogsTwoMatchingMessages(self): + # Same, but with two matching log messages + with self.assertNoStderr(): + with self.assertLogs() as cm: + log_foo.info("1") + log_foobar.debug("2") + log_quux.warning("3") + self.assertEqual(cm.output, ["INFO:foo:1", "WARNING:quux:3"]) + self.assertLogRecords(cm.records, + [{'name': 'foo'}, {'name': 'quux'}]) + + def checkAssertLogsPerLevel(self, level): + # Check level filtering + with self.assertNoStderr(): + with self.assertLogs(level=level) as cm: + log_foo.warning("1") + log_foobar.error("2") + log_quux.critical("3") + self.assertEqual(cm.output, ["ERROR:foo.bar:2", "CRITICAL:quux:3"]) + self.assertLogRecords(cm.records, + [{'name': 'foo.bar'}, {'name': 'quux'}]) + + def testAssertLogsPerLevel(self): + self.checkAssertLogsPerLevel(logging.ERROR) + self.checkAssertLogsPerLevel('ERROR') + + def checkAssertLogsPerLogger(self, logger): + # Check per-logger filtering + with self.assertNoStderr(): + with self.assertLogs(level='DEBUG') as outer_cm: + with self.assertLogs(logger, level='DEBUG') as cm: + log_foo.info("1") + log_foobar.debug("2") + log_quux.warning("3") + self.assertEqual(cm.output, ["INFO:foo:1", "DEBUG:foo.bar:2"]) + self.assertLogRecords(cm.records, + [{'name': 'foo'}, {'name': 'foo.bar'}]) + # The outer catchall caught the quux log + self.assertEqual(outer_cm.output, ["WARNING:quux:3"]) + + def testAssertLogsPerLogger(self): + self.checkAssertLogsPerLogger(logging.getLogger('foo')) + self.checkAssertLogsPerLogger('foo') + + def testAssertLogsFailureNoLogs(self): + # Failure due to no logs + with self.assertNoStderr(): + with self.assertRaises(self.failureException): + with self.assertLogs(): + pass + + def testAssertLogsFailureLevelTooHigh(self): + # Failure due to level too high + with self.assertNoStderr(): + with self.assertRaises(self.failureException): + with self.assertLogs(level='WARNING'): + log_foo.info("1") + + def testAssertLogsFailureLevelTooHigh_FilterInRootLogger(self): + # Failure due to level too high - message propagated to root + with self.assertNoStderr(): + oldLevel = log_foo.level + log_foo.setLevel(logging.INFO) + try: + with self.assertRaises(self.failureException): + with self.assertLogs(level='WARNING'): + log_foo.info("1") + finally: + log_foo.setLevel(oldLevel) + + def testAssertLogsFailureMismatchingLogger(self): + # Failure due to mismatching logger (and the logged message is + # passed through) + with self.assertLogs('quux', level='ERROR'): + with self.assertRaises(self.failureException): + with self.assertLogs('foo'): + log_quux.error("1") + + def testAssertLogsUnexpectedException(self): + # Check unexpected exception will go through. + with self.assertRaises(ZeroDivisionError): + with self.assertLogs(): + raise ZeroDivisionError("Unexpected") + + def testAssertNoLogsDefault(self): + with self.assertRaises(self.failureException) as cm: + with self.assertNoLogs(): + log_foo.info("1") + log_foobar.debug("2") + self.assertEqual( + str(cm.exception), + "Unexpected logs found: ['INFO:foo:1']", + ) + + def testAssertNoLogsFailureFoundLogs(self): + with self.assertRaises(self.failureException) as cm: + with self.assertNoLogs(): + log_quux.error("1") + log_foo.error("foo") + + self.assertEqual( + str(cm.exception), + "Unexpected logs found: ['ERROR:quux:1', 'ERROR:foo:foo']", + ) + + def testAssertNoLogsPerLogger(self): + with self.assertNoStderr(): + with self.assertLogs(log_quux): + with self.assertNoLogs(logger=log_foo): + log_quux.error("1") + + def testAssertNoLogsFailurePerLogger(self): + # Failure due to unexpected logs for the given logger or its + # children. + with self.assertRaises(self.failureException) as cm: + with self.assertLogs(log_quux): + with self.assertNoLogs(logger=log_foo): + log_quux.error("1") + log_foobar.info("2") + self.assertEqual( + str(cm.exception), + "Unexpected logs found: ['INFO:foo.bar:2']", + ) + + def testAssertNoLogsPerLevel(self): + # Check per-level filtering + with self.assertNoStderr(): + with self.assertNoLogs(level="ERROR"): + log_foo.info("foo") + log_quux.debug("1") + + def testAssertNoLogsFailurePerLevel(self): + # Failure due to unexpected logs at the specified level. + with self.assertRaises(self.failureException) as cm: + with self.assertNoLogs(level="DEBUG"): + log_foo.debug("foo") + log_quux.debug("1") + self.assertEqual( + str(cm.exception), + "Unexpected logs found: ['DEBUG:foo:foo', 'DEBUG:quux:1']", + ) + + def testAssertNoLogsUnexpectedException(self): + # Check unexpected exception will go through. + with self.assertRaises(ZeroDivisionError): + with self.assertNoLogs(): + raise ZeroDivisionError("Unexpected") + + def testAssertNoLogsYieldsNone(self): + with self.assertNoLogs() as value: + pass + self.assertIsNone(value) + + def testDeprecatedMethodNames(self): + """ + Test that the deprecated methods raise a DeprecationWarning. See #9424. + """ + old = ( + (self.failIfEqual, (3, 5)), + (self.assertNotEquals, (3, 5)), + (self.failUnlessEqual, (3, 3)), + (self.assertEquals, (3, 3)), + (self.failUnlessAlmostEqual, (2.0, 2.0)), + (self.assertAlmostEquals, (2.0, 2.0)), + (self.failIfAlmostEqual, (3.0, 5.0)), + (self.assertNotAlmostEquals, (3.0, 5.0)), + (self.failUnless, (True,)), + (self.assert_, (True,)), + (self.failUnlessRaises, (TypeError, lambda _: 3.14 + 'spam')), + (self.failIf, (False,)), + (self.assertDictContainsSubset, (dict(a=1, b=2), dict(a=1, b=2, c=3))), + (self.assertRaisesRegexp, (KeyError, 'foo', lambda: {}['foo'])), + (self.assertRegexpMatches, ('bar', 'bar')), + ) + for meth, args in old: + with self.assertWarns(DeprecationWarning): + meth(*args) + + # disable this test for now. When the version where the fail* methods will + # be removed is decided, re-enable it and update the version + def _testDeprecatedFailMethods(self): + """Test that the deprecated fail* methods get removed in 3.x""" + if sys.version_info[:2] < (3, 3): + return + deprecated_names = [ + 'failIfEqual', 'failUnlessEqual', 'failUnlessAlmostEqual', + 'failIfAlmostEqual', 'failUnless', 'failUnlessRaises', 'failIf', + 'assertDictContainsSubset', + ] + for deprecated_name in deprecated_names: + with self.assertRaises(AttributeError): + getattr(self, deprecated_name) # remove these in 3.x + + def testDeepcopy(self): + # Issue: 5660 + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + + # This shouldn't blow up + deepcopy(test) + + def testPickle(self): + # Issue 10326 + + # Can't use TestCase classes defined in Test class as + # pickle does not work with inner classes + test = unittest.TestCase('run') + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + + # blew up prior to fix + pickled_test = pickle.dumps(test, protocol=protocol) + unpickled_test = pickle.loads(pickled_test) + self.assertEqual(test, unpickled_test) + + # exercise the TestCase instance in a way that will invoke + # the type equality lookup mechanism + unpickled_test.assertEqual(set(), set()) + + def testKeyboardInterrupt(self): + def _raise(self=None): + raise KeyboardInterrupt + def nothing(self): + pass + + class Test1(unittest.TestCase): + test_something = _raise + + class Test2(unittest.TestCase): + setUp = _raise + test_something = nothing + + class Test3(unittest.TestCase): + test_something = nothing + tearDown = _raise + + class Test4(unittest.TestCase): + def test_something(self): + self.addCleanup(_raise) + + for klass in (Test1, Test2, Test3, Test4): + with self.assertRaises(KeyboardInterrupt): + klass('test_something').run() + + def testSkippingEverywhere(self): + def _skip(self=None): + raise unittest.SkipTest('some reason') + def nothing(self): + pass + + class Test1(unittest.TestCase): + test_something = _skip + + class Test2(unittest.TestCase): + setUp = _skip + test_something = nothing + + class Test3(unittest.TestCase): + test_something = nothing + tearDown = _skip + + class Test4(unittest.TestCase): + def test_something(self): + self.addCleanup(_skip) + + for klass in (Test1, Test2, Test3, Test4): + result = unittest.TestResult() + klass('test_something').run(result) + self.assertEqual(len(result.skipped), 1) + self.assertEqual(result.testsRun, 1) + + def testSystemExit(self): + def _raise(self=None): + raise SystemExit + def nothing(self): + pass + + class Test1(unittest.TestCase): + test_something = _raise + + class Test2(unittest.TestCase): + setUp = _raise + test_something = nothing + + class Test3(unittest.TestCase): + test_something = nothing + tearDown = _raise + + class Test4(unittest.TestCase): + def test_something(self): + self.addCleanup(_raise) + + for klass in (Test1, Test2, Test3, Test4): + result = unittest.TestResult() + klass('test_something').run(result) + self.assertEqual(len(result.errors), 1) + self.assertEqual(result.testsRun, 1) + + @support.cpython_only + def testNoCycles(self): + case = unittest.TestCase() + wr = weakref.ref(case) + with support.disable_gc(): + del case + self.assertFalse(wr()) + + def test_no_exception_leak(self): + # Issue #19880: TestCase.run() should not keep a reference + # to the exception + class MyException(Exception): + ninstance = 0 + + def __init__(self): + MyException.ninstance += 1 + Exception.__init__(self) + + def __del__(self): + MyException.ninstance -= 1 + + class TestCase(unittest.TestCase): + def test1(self): + raise MyException() + + @unittest.expectedFailure + def test2(self): + raise MyException() + + for method_name in ('test1', 'test2'): + testcase = TestCase(method_name) + testcase.run() + gc_collect() # For PyPy or other GCs. + self.assertEqual(MyException.ninstance, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_discovery.py b/Lib/test/test_unittest/test_discovery.py new file mode 100644 index 0000000..946fa12 --- /dev/null +++ b/Lib/test/test_unittest/test_discovery.py @@ -0,0 +1,849 @@ +import os.path +from os.path import abspath +import re +import sys +import types +import pickle +from test import support +from test.support import import_helper +import test.test_importlib.util + +import unittest +import unittest.mock +import test.test_unittest + + +class TestableTestProgram(unittest.TestProgram): + module = None + exit = True + defaultTest = failfast = catchbreak = buffer = None + verbosity = 1 + progName = '' + testRunner = testLoader = None + + def __init__(self): + pass + + +class TestDiscovery(unittest.TestCase): + + # Heavily mocked tests so I can avoid hitting the filesystem + def test_get_name_from_path(self): + loader = unittest.TestLoader() + loader._top_level_dir = '/foo' + name = loader._get_name_from_path('/foo/bar/baz.py') + self.assertEqual(name, 'bar.baz') + + if not __debug__: + # asserts are off + return + + with self.assertRaises(AssertionError): + loader._get_name_from_path('/bar/baz.py') + + def test_find_tests(self): + loader = unittest.TestLoader() + + original_listdir = os.listdir + def restore_listdir(): + os.listdir = original_listdir + original_isfile = os.path.isfile + def restore_isfile(): + os.path.isfile = original_isfile + original_isdir = os.path.isdir + def restore_isdir(): + os.path.isdir = original_isdir + + path_lists = [['test2.py', 'test1.py', 'not_a_test.py', 'test_dir', + 'test.foo', 'test-not-a-module.py', 'another_dir'], + ['test4.py', 'test3.py', ]] + os.listdir = lambda path: path_lists.pop(0) + self.addCleanup(restore_listdir) + + def isdir(path): + return path.endswith('dir') + os.path.isdir = isdir + self.addCleanup(restore_isdir) + + def isfile(path): + # another_dir is not a package and so shouldn't be recursed into + return not path.endswith('dir') and not 'another_dir' in path + os.path.isfile = isfile + self.addCleanup(restore_isfile) + + loader._get_module_from_name = lambda path: path + ' module' + orig_load_tests = loader.loadTestsFromModule + def loadTestsFromModule(module, pattern=None): + # This is where load_tests is called. + base = orig_load_tests(module, pattern=pattern) + return base + [module + ' tests'] + loader.loadTestsFromModule = loadTestsFromModule + loader.suiteClass = lambda thing: thing + + top_level = os.path.abspath('/foo') + loader._top_level_dir = top_level + suite = list(loader._find_tests(top_level, 'test*.py')) + + # The test suites found should be sorted alphabetically for reliable + # execution order. + expected = [[name + ' module tests'] for name in + ('test1', 'test2', 'test_dir')] + expected.extend([[('test_dir.%s' % name) + ' module tests'] for name in + ('test3', 'test4')]) + self.assertEqual(suite, expected) + + def test_find_tests_socket(self): + # A socket is neither a directory nor a regular file. + # https://bugs.python.org/issue25320 + loader = unittest.TestLoader() + + original_listdir = os.listdir + def restore_listdir(): + os.listdir = original_listdir + original_isfile = os.path.isfile + def restore_isfile(): + os.path.isfile = original_isfile + original_isdir = os.path.isdir + def restore_isdir(): + os.path.isdir = original_isdir + + path_lists = [['socket']] + os.listdir = lambda path: path_lists.pop(0) + self.addCleanup(restore_listdir) + + os.path.isdir = lambda path: False + self.addCleanup(restore_isdir) + + os.path.isfile = lambda path: False + self.addCleanup(restore_isfile) + + loader._get_module_from_name = lambda path: path + ' module' + orig_load_tests = loader.loadTestsFromModule + def loadTestsFromModule(module, pattern=None): + # This is where load_tests is called. + base = orig_load_tests(module, pattern=pattern) + return base + [module + ' tests'] + loader.loadTestsFromModule = loadTestsFromModule + loader.suiteClass = lambda thing: thing + + top_level = os.path.abspath('/foo') + loader._top_level_dir = top_level + suite = list(loader._find_tests(top_level, 'test*.py')) + + self.assertEqual(suite, []) + + def test_find_tests_with_package(self): + loader = unittest.TestLoader() + + original_listdir = os.listdir + def restore_listdir(): + os.listdir = original_listdir + original_isfile = os.path.isfile + def restore_isfile(): + os.path.isfile = original_isfile + original_isdir = os.path.isdir + def restore_isdir(): + os.path.isdir = original_isdir + + directories = ['a_directory', 'test_directory', 'test_directory2'] + path_lists = [directories, [], [], []] + os.listdir = lambda path: path_lists.pop(0) + self.addCleanup(restore_listdir) + + os.path.isdir = lambda path: True + self.addCleanup(restore_isdir) + + os.path.isfile = lambda path: os.path.basename(path) not in directories + self.addCleanup(restore_isfile) + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + self.paths.append(path) + if os.path.basename(path) == 'test_directory': + def load_tests(loader, tests, pattern): + self.load_tests_args.append((loader, tests, pattern)) + return [self.path + ' load_tests'] + self.load_tests = load_tests + + def __eq__(self, other): + return self.path == other.path + + loader._get_module_from_name = lambda name: Module(name) + orig_load_tests = loader.loadTestsFromModule + def loadTestsFromModule(module, pattern=None): + # This is where load_tests is called. + base = orig_load_tests(module, pattern=pattern) + return base + [module.path + ' module tests'] + loader.loadTestsFromModule = loadTestsFromModule + loader.suiteClass = lambda thing: thing + + loader._top_level_dir = '/foo' + # this time no '.py' on the pattern so that it can match + # a test package + suite = list(loader._find_tests('/foo', 'test*')) + + # We should have loaded tests from the a_directory and test_directory2 + # directly and via load_tests for the test_directory package, which + # still calls the baseline module loader. + self.assertEqual(suite, + [['a_directory module tests'], + ['test_directory load_tests', + 'test_directory module tests'], + ['test_directory2 module tests']]) + + + # The test module paths should be sorted for reliable execution order + self.assertEqual(Module.paths, + ['a_directory', 'test_directory', 'test_directory2']) + + # load_tests should have been called once with loader, tests and pattern + # (but there are no tests in our stub module itself, so that is [] at + # the time of call). + self.assertEqual(Module.load_tests_args, + [(loader, [], 'test*')]) + + def test_find_tests_default_calls_package_load_tests(self): + loader = unittest.TestLoader() + + original_listdir = os.listdir + def restore_listdir(): + os.listdir = original_listdir + original_isfile = os.path.isfile + def restore_isfile(): + os.path.isfile = original_isfile + original_isdir = os.path.isdir + def restore_isdir(): + os.path.isdir = original_isdir + + directories = ['a_directory', 'test_directory', 'test_directory2'] + path_lists = [directories, [], [], []] + os.listdir = lambda path: path_lists.pop(0) + self.addCleanup(restore_listdir) + + os.path.isdir = lambda path: True + self.addCleanup(restore_isdir) + + os.path.isfile = lambda path: os.path.basename(path) not in directories + self.addCleanup(restore_isfile) + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + self.paths.append(path) + if os.path.basename(path) == 'test_directory': + def load_tests(loader, tests, pattern): + self.load_tests_args.append((loader, tests, pattern)) + return [self.path + ' load_tests'] + self.load_tests = load_tests + + def __eq__(self, other): + return self.path == other.path + + loader._get_module_from_name = lambda name: Module(name) + orig_load_tests = loader.loadTestsFromModule + def loadTestsFromModule(module, pattern=None): + # This is where load_tests is called. + base = orig_load_tests(module, pattern=pattern) + return base + [module.path + ' module tests'] + loader.loadTestsFromModule = loadTestsFromModule + loader.suiteClass = lambda thing: thing + + loader._top_level_dir = '/foo' + # this time no '.py' on the pattern so that it can match + # a test package + suite = list(loader._find_tests('/foo', 'test*.py')) + + # We should have loaded tests from the a_directory and test_directory2 + # directly and via load_tests for the test_directory package, which + # still calls the baseline module loader. + self.assertEqual(suite, + [['a_directory module tests'], + ['test_directory load_tests', + 'test_directory module tests'], + ['test_directory2 module tests']]) + # The test module paths should be sorted for reliable execution order + self.assertEqual(Module.paths, + ['a_directory', 'test_directory', 'test_directory2']) + + + # load_tests should have been called once with loader, tests and pattern + self.assertEqual(Module.load_tests_args, + [(loader, [], 'test*.py')]) + + def test_find_tests_customize_via_package_pattern(self): + # This test uses the example 'do-nothing' load_tests from + # https://docs.python.org/3/library/unittest.html#load-tests-protocol + # to make sure that that actually works. + # Housekeeping + original_listdir = os.listdir + def restore_listdir(): + os.listdir = original_listdir + self.addCleanup(restore_listdir) + original_isfile = os.path.isfile + def restore_isfile(): + os.path.isfile = original_isfile + self.addCleanup(restore_isfile) + original_isdir = os.path.isdir + def restore_isdir(): + os.path.isdir = original_isdir + self.addCleanup(restore_isdir) + self.addCleanup(sys.path.remove, abspath('/foo')) + + # Test data: we expect the following: + # a listdir to find our package, and isfile and isdir checks on it. + # a module-from-name call to turn that into a module + # followed by load_tests. + # then our load_tests will call discover() which is messy + # but that finally chains into find_tests again for the child dir - + # which is why we don't have an infinite loop. + # We expect to see: + # the module load tests for both package and plain module called, + # and the plain module result nested by the package module load_tests + # indicating that it was processed and could have been mutated. + vfs = {abspath('/foo'): ['my_package'], + abspath('/foo/my_package'): ['__init__.py', 'test_module.py']} + def list_dir(path): + return list(vfs[path]) + os.listdir = list_dir + os.path.isdir = lambda path: not path.endswith('.py') + os.path.isfile = lambda path: path.endswith('.py') + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + self.paths.append(path) + if path.endswith('test_module'): + def load_tests(loader, tests, pattern): + self.load_tests_args.append((loader, tests, pattern)) + return [self.path + ' load_tests'] + else: + def load_tests(loader, tests, pattern): + self.load_tests_args.append((loader, tests, pattern)) + # top level directory cached on loader instance + __file__ = '/foo/my_package/__init__.py' + this_dir = os.path.dirname(__file__) + pkg_tests = loader.discover( + start_dir=this_dir, pattern=pattern) + return [self.path + ' load_tests', tests + ] + pkg_tests + self.load_tests = load_tests + + def __eq__(self, other): + return self.path == other.path + + loader = unittest.TestLoader() + loader._get_module_from_name = lambda name: Module(name) + loader.suiteClass = lambda thing: thing + + loader._top_level_dir = abspath('/foo') + # this time no '.py' on the pattern so that it can match + # a test package + suite = list(loader._find_tests(abspath('/foo'), 'test*.py')) + + # We should have loaded tests from both my_package and + # my_package.test_module, and also run the load_tests hook in both. + # (normally this would be nested TestSuites.) + self.assertEqual(suite, + [['my_package load_tests', [], + ['my_package.test_module load_tests']]]) + # Parents before children. + self.assertEqual(Module.paths, + ['my_package', 'my_package.test_module']) + + # load_tests should have been called twice with loader, tests and pattern + self.assertEqual(Module.load_tests_args, + [(loader, [], 'test*.py'), + (loader, [], 'test*.py')]) + + def test_discover(self): + loader = unittest.TestLoader() + + original_isfile = os.path.isfile + original_isdir = os.path.isdir + def restore_isfile(): + os.path.isfile = original_isfile + + os.path.isfile = lambda path: False + self.addCleanup(restore_isfile) + + orig_sys_path = sys.path[:] + def restore_path(): + sys.path[:] = orig_sys_path + self.addCleanup(restore_path) + + full_path = os.path.abspath(os.path.normpath('/foo')) + with self.assertRaises(ImportError): + loader.discover('/foo/bar', top_level_dir='/foo') + + self.assertEqual(loader._top_level_dir, full_path) + self.assertIn(full_path, sys.path) + + os.path.isfile = lambda path: True + os.path.isdir = lambda path: True + + def restore_isdir(): + os.path.isdir = original_isdir + self.addCleanup(restore_isdir) + + _find_tests_args = [] + def _find_tests(start_dir, pattern): + _find_tests_args.append((start_dir, pattern)) + return ['tests'] + loader._find_tests = _find_tests + loader.suiteClass = str + + suite = loader.discover('/foo/bar/baz', 'pattern', '/foo/bar') + + top_level_dir = os.path.abspath('/foo/bar') + start_dir = os.path.abspath('/foo/bar/baz') + self.assertEqual(suite, "['tests']") + self.assertEqual(loader._top_level_dir, top_level_dir) + self.assertEqual(_find_tests_args, [(start_dir, 'pattern')]) + self.assertIn(top_level_dir, sys.path) + + def test_discover_start_dir_is_package_calls_package_load_tests(self): + # This test verifies that the package load_tests in a package is indeed + # invoked when the start_dir is a package (and not the top level). + # http://bugs.python.org/issue22457 + + # Test data: we expect the following: + # an isfile to verify the package, then importing and scanning + # as per _find_tests' normal behaviour. + # We expect to see our load_tests hook called once. + vfs = {abspath('/toplevel'): ['startdir'], + abspath('/toplevel/startdir'): ['__init__.py']} + def list_dir(path): + return list(vfs[path]) + self.addCleanup(setattr, os, 'listdir', os.listdir) + os.listdir = list_dir + self.addCleanup(setattr, os.path, 'isfile', os.path.isfile) + os.path.isfile = lambda path: path.endswith('.py') + self.addCleanup(setattr, os.path, 'isdir', os.path.isdir) + os.path.isdir = lambda path: not path.endswith('.py') + self.addCleanup(sys.path.remove, abspath('/toplevel')) + + class Module(object): + paths = [] + load_tests_args = [] + + def __init__(self, path): + self.path = path + + def load_tests(self, loader, tests, pattern): + return ['load_tests called ' + self.path] + + def __eq__(self, other): + return self.path == other.path + + loader = unittest.TestLoader() + loader._get_module_from_name = lambda name: Module(name) + loader.suiteClass = lambda thing: thing + + suite = loader.discover('/toplevel/startdir', top_level_dir='/toplevel') + + # We should have loaded tests from the package __init__. + # (normally this would be nested TestSuites.) + self.assertEqual(suite, + [['load_tests called startdir']]) + + def setup_import_issue_tests(self, fakefile): + listdir = os.listdir + os.listdir = lambda _: [fakefile] + isfile = os.path.isfile + os.path.isfile = lambda _: True + orig_sys_path = sys.path[:] + def restore(): + os.path.isfile = isfile + os.listdir = listdir + sys.path[:] = orig_sys_path + self.addCleanup(restore) + + def setup_import_issue_package_tests(self, vfs): + self.addCleanup(setattr, os, 'listdir', os.listdir) + self.addCleanup(setattr, os.path, 'isfile', os.path.isfile) + self.addCleanup(setattr, os.path, 'isdir', os.path.isdir) + self.addCleanup(sys.path.__setitem__, slice(None), list(sys.path)) + def list_dir(path): + return list(vfs[path]) + os.listdir = list_dir + os.path.isdir = lambda path: not path.endswith('.py') + os.path.isfile = lambda path: path.endswith('.py') + + def test_discover_with_modules_that_fail_to_import(self): + loader = unittest.TestLoader() + + self.setup_import_issue_tests('test_this_does_not_exist.py') + + suite = loader.discover('.') + self.assertIn(os.getcwd(), sys.path) + self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + self.assertTrue( + 'Failed to import test module: test_this_does_not_exist' in error, + 'missing error string in %r' % error) + test = list(list(suite)[0])[0] # extract test from suite + + with self.assertRaises(ImportError): + test.test_this_does_not_exist() + + def test_discover_with_init_modules_that_fail_to_import(self): + vfs = {abspath('/foo'): ['my_package'], + abspath('/foo/my_package'): ['__init__.py', 'test_module.py']} + self.setup_import_issue_package_tests(vfs) + import_calls = [] + def _get_module_from_name(name): + import_calls.append(name) + raise ImportError("Cannot import Name") + loader = unittest.TestLoader() + loader._get_module_from_name = _get_module_from_name + suite = loader.discover(abspath('/foo')) + + self.assertIn(abspath('/foo'), sys.path) + self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + self.assertTrue( + 'Failed to import test module: my_package' in error, + 'missing error string in %r' % error) + test = list(list(suite)[0])[0] # extract test from suite + with self.assertRaises(ImportError): + test.my_package() + self.assertEqual(import_calls, ['my_package']) + + # Check picklability + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickle.loads(pickle.dumps(test, proto)) + + def test_discover_with_module_that_raises_SkipTest_on_import(self): + if not unittest.BaseTestSuite._cleanup: + raise unittest.SkipTest("Suite cleanup is disabled") + + loader = unittest.TestLoader() + + def _get_module_from_name(name): + raise unittest.SkipTest('skipperoo') + loader._get_module_from_name = _get_module_from_name + + self.setup_import_issue_tests('test_skip_dummy.py') + + suite = loader.discover('.') + self.assertEqual(suite.countTestCases(), 1) + + result = unittest.TestResult() + suite.run(result) + self.assertEqual(len(result.skipped), 1) + + # Check picklability + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickle.loads(pickle.dumps(suite, proto)) + + def test_discover_with_init_module_that_raises_SkipTest_on_import(self): + if not unittest.BaseTestSuite._cleanup: + raise unittest.SkipTest("Suite cleanup is disabled") + + vfs = {abspath('/foo'): ['my_package'], + abspath('/foo/my_package'): ['__init__.py', 'test_module.py']} + self.setup_import_issue_package_tests(vfs) + import_calls = [] + def _get_module_from_name(name): + import_calls.append(name) + raise unittest.SkipTest('skipperoo') + loader = unittest.TestLoader() + loader._get_module_from_name = _get_module_from_name + suite = loader.discover(abspath('/foo')) + + self.assertIn(abspath('/foo'), sys.path) + self.assertEqual(suite.countTestCases(), 1) + result = unittest.TestResult() + suite.run(result) + self.assertEqual(len(result.skipped), 1) + self.assertEqual(result.testsRun, 1) + self.assertEqual(import_calls, ['my_package']) + + # Check picklability + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + pickle.loads(pickle.dumps(suite, proto)) + + def test_command_line_handling_parseArgs(self): + program = TestableTestProgram() + + args = [] + program._do_discovery = args.append + program.parseArgs(['something', 'discover']) + self.assertEqual(args, [[]]) + + args[:] = [] + program.parseArgs(['something', 'discover', 'foo', 'bar']) + self.assertEqual(args, [['foo', 'bar']]) + + def test_command_line_handling_discover_by_default(self): + program = TestableTestProgram() + + args = [] + program._do_discovery = args.append + program.parseArgs(['something']) + self.assertEqual(args, [[]]) + self.assertEqual(program.verbosity, 1) + self.assertIs(program.buffer, False) + self.assertIs(program.catchbreak, False) + self.assertIs(program.failfast, False) + + def test_command_line_handling_discover_by_default_with_options(self): + program = TestableTestProgram() + + args = [] + program._do_discovery = args.append + program.parseArgs(['something', '-v', '-b', '-v', '-c', '-f']) + self.assertEqual(args, [[]]) + self.assertEqual(program.verbosity, 2) + self.assertIs(program.buffer, True) + self.assertIs(program.catchbreak, True) + self.assertIs(program.failfast, True) + + + def test_command_line_handling_do_discovery_too_many_arguments(self): + program = TestableTestProgram() + program.testLoader = None + + with support.captured_stderr() as stderr, \ + self.assertRaises(SystemExit) as cm: + # too many args + program._do_discovery(['one', 'two', 'three', 'four']) + self.assertEqual(cm.exception.args, (2,)) + self.assertIn('usage:', stderr.getvalue()) + + + def test_command_line_handling_do_discovery_uses_default_loader(self): + program = object.__new__(unittest.TestProgram) + program._initArgParsers() + + class Loader(object): + args = [] + def discover(self, start_dir, pattern, top_level_dir): + self.args.append((start_dir, pattern, top_level_dir)) + return 'tests' + + program.testLoader = Loader() + program._do_discovery(['-v']) + self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + + def test_command_line_handling_do_discovery_calls_loader(self): + program = TestableTestProgram() + + class Loader(object): + args = [] + def discover(self, start_dir, pattern, top_level_dir): + self.args.append((start_dir, pattern, top_level_dir)) + return 'tests' + + program._do_discovery(['-v'], Loader=Loader) + self.assertEqual(program.verbosity, 2) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['--verbose'], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery([], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('.', 'test*.py', None)]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['fish'], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('fish', 'test*.py', None)]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['fish', 'eggs'], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('fish', 'eggs', None)]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['fish', 'eggs', 'ham'], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('fish', 'eggs', 'ham')]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['-s', 'fish'], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('fish', 'test*.py', None)]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['-t', 'fish'], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('.', 'test*.py', 'fish')]) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['-p', 'fish'], Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('.', 'fish', None)]) + self.assertFalse(program.failfast) + self.assertFalse(program.catchbreak) + + Loader.args = [] + program = TestableTestProgram() + program._do_discovery(['-p', 'eggs', '-s', 'fish', '-v', '-f', '-c'], + Loader=Loader) + self.assertEqual(program.test, 'tests') + self.assertEqual(Loader.args, [('fish', 'eggs', None)]) + self.assertEqual(program.verbosity, 2) + self.assertTrue(program.failfast) + self.assertTrue(program.catchbreak) + + def setup_module_clash(self): + class Module(object): + __file__ = 'bar/foo.py' + sys.modules['foo'] = Module + full_path = os.path.abspath('foo') + original_listdir = os.listdir + original_isfile = os.path.isfile + original_isdir = os.path.isdir + original_realpath = os.path.realpath + + def cleanup(): + os.listdir = original_listdir + os.path.isfile = original_isfile + os.path.isdir = original_isdir + os.path.realpath = original_realpath + del sys.modules['foo'] + if full_path in sys.path: + sys.path.remove(full_path) + self.addCleanup(cleanup) + + def listdir(_): + return ['foo.py'] + def isfile(_): + return True + def isdir(_): + return True + os.listdir = listdir + os.path.isfile = isfile + os.path.isdir = isdir + if os.name == 'nt': + # ntpath.realpath may inject path prefixes when failing to + # resolve real files, so we substitute abspath() here instead. + os.path.realpath = os.path.abspath + return full_path + + def test_detect_module_clash(self): + full_path = self.setup_module_clash() + loader = unittest.TestLoader() + + mod_dir = os.path.abspath('bar') + expected_dir = os.path.abspath('foo') + msg = re.escape(r"'foo' module incorrectly imported from %r. Expected %r. " + "Is this module globally installed?" % (mod_dir, expected_dir)) + self.assertRaisesRegex( + ImportError, '^%s$' % msg, loader.discover, + start_dir='foo', pattern='foo.py' + ) + self.assertEqual(sys.path[0], full_path) + + def test_module_symlink_ok(self): + full_path = self.setup_module_clash() + + original_realpath = os.path.realpath + + mod_dir = os.path.abspath('bar') + expected_dir = os.path.abspath('foo') + + def cleanup(): + os.path.realpath = original_realpath + self.addCleanup(cleanup) + + def realpath(path): + if path == os.path.join(mod_dir, 'foo.py'): + return os.path.join(expected_dir, 'foo.py') + return path + os.path.realpath = realpath + loader = unittest.TestLoader() + loader.discover(start_dir='foo', pattern='foo.py') + + def test_discovery_from_dotted_path(self): + loader = unittest.TestLoader() + + tests = [self] + expectedPath = os.path.abspath(os.path.dirname(test.test_unittest.__file__)) + + self.wasRun = False + def _find_tests(start_dir, pattern): + self.wasRun = True + self.assertEqual(start_dir, expectedPath) + return tests + loader._find_tests = _find_tests + suite = loader.discover('test.test_unittest') + self.assertTrue(self.wasRun) + self.assertEqual(suite._tests, tests) + + + def test_discovery_from_dotted_path_builtin_modules(self): + + loader = unittest.TestLoader() + + listdir = os.listdir + os.listdir = lambda _: ['test_this_does_not_exist.py'] + isfile = os.path.isfile + isdir = os.path.isdir + os.path.isdir = lambda _: False + orig_sys_path = sys.path[:] + def restore(): + os.path.isfile = isfile + os.path.isdir = isdir + os.listdir = listdir + sys.path[:] = orig_sys_path + self.addCleanup(restore) + + with self.assertRaises(TypeError) as cm: + loader.discover('sys') + self.assertEqual(str(cm.exception), + 'Can not use builtin modules ' + 'as dotted module names') + + def test_discovery_failed_discovery(self): + loader = unittest.TestLoader() + package = types.ModuleType('package') + + def _import(packagename, *args, **kwargs): + sys.modules[packagename] = package + return package + + with unittest.mock.patch('builtins.__import__', _import): + # Since loader.discover() can modify sys.path, restore it when done. + with import_helper.DirsOnSysPath(): + # Make sure to remove 'package' from sys.modules when done. + with test.test_importlib.util.uncache('package'): + with self.assertRaises(TypeError) as cm: + loader.discover('package') + self.assertEqual(str(cm.exception), + 'don\'t know how to discover from {!r}' + .format(package)) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/test_functiontestcase.py b/Lib/test/test_unittest/test_functiontestcase.py new file mode 100644 index 0000000..2ebed95 --- /dev/null +++ b/Lib/test/test_unittest/test_functiontestcase.py @@ -0,0 +1,148 @@ +import unittest + +from test.test_unittest.support import LoggingResult + + +class Test_FunctionTestCase(unittest.TestCase): + + # "Return the number of tests represented by the this test object. For + # TestCase instances, this will always be 1" + def test_countTestCases(self): + test = unittest.FunctionTestCase(lambda: None) + + self.assertEqual(test.countTestCases(), 1) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if setUp() raises + # an exception. + def test_run_call_order__error_in_setUp(self): + events = [] + result = LoggingResult(events) + + def setUp(): + events.append('setUp') + raise RuntimeError('raised by setUp') + + def test(): + events.append('test') + + def tearDown(): + events.append('tearDown') + + expected = ['startTest', 'setUp', 'addError', 'stopTest'] + unittest.FunctionTestCase(test, setUp, tearDown).run(result) + self.assertEqual(events, expected) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if the test raises + # an error (as opposed to a failure). + def test_run_call_order__error_in_test(self): + events = [] + result = LoggingResult(events) + + def setUp(): + events.append('setUp') + + def test(): + events.append('test') + raise RuntimeError('raised by test') + + def tearDown(): + events.append('tearDown') + + expected = ['startTest', 'setUp', 'test', + 'addError', 'tearDown', 'stopTest'] + unittest.FunctionTestCase(test, setUp, tearDown).run(result) + self.assertEqual(events, expected) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if the test signals + # a failure (as opposed to an error). + def test_run_call_order__failure_in_test(self): + events = [] + result = LoggingResult(events) + + def setUp(): + events.append('setUp') + + def test(): + events.append('test') + self.fail('raised by test') + + def tearDown(): + events.append('tearDown') + + expected = ['startTest', 'setUp', 'test', + 'addFailure', 'tearDown', 'stopTest'] + unittest.FunctionTestCase(test, setUp, tearDown).run(result) + self.assertEqual(events, expected) + + # "When a setUp() method is defined, the test runner will run that method + # prior to each test. Likewise, if a tearDown() method is defined, the + # test runner will invoke that method after each test. In the example, + # setUp() was used to create a fresh sequence for each test." + # + # Make sure the proper call order is maintained, even if tearDown() raises + # an exception. + def test_run_call_order__error_in_tearDown(self): + events = [] + result = LoggingResult(events) + + def setUp(): + events.append('setUp') + + def test(): + events.append('test') + + def tearDown(): + events.append('tearDown') + raise RuntimeError('raised by tearDown') + + expected = ['startTest', 'setUp', 'test', 'tearDown', 'addError', + 'stopTest'] + unittest.FunctionTestCase(test, setUp, tearDown).run(result) + self.assertEqual(events, expected) + + # "Return a string identifying the specific test case." + # + # Because of the vague nature of the docs, I'm not going to lock this + # test down too much. Really all that can be asserted is that the id() + # will be a string (either 8-byte or unicode -- again, because the docs + # just say "string") + def test_id(self): + test = unittest.FunctionTestCase(lambda: None) + + self.assertIsInstance(test.id(), str) + + # "Returns a one-line description of the test, or None if no description + # has been provided. The default implementation of this method returns + # the first line of the test method's docstring, if available, or None." + def test_shortDescription__no_docstring(self): + test = unittest.FunctionTestCase(lambda: None) + + self.assertEqual(test.shortDescription(), None) + + # "Returns a one-line description of the test, or None if no description + # has been provided. The default implementation of this method returns + # the first line of the test method's docstring, if available, or None." + def test_shortDescription__singleline_docstring(self): + desc = "this tests foo" + test = unittest.FunctionTestCase(lambda: None, description=desc) + + self.assertEqual(test.shortDescription(), "this tests foo") + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_loader.py b/Lib/test/test_unittest/test_loader.py new file mode 100644 index 0000000..c06ebb6 --- /dev/null +++ b/Lib/test/test_unittest/test_loader.py @@ -0,0 +1,1642 @@ +import functools +import sys +import types +import warnings + +import unittest + +# Decorator used in the deprecation tests to reset the warning registry for +# test isolation and reproducibility. +def warningregistry(func): + def wrapper(*args, **kws): + missing = [] + saved = getattr(warnings, '__warningregistry__', missing).copy() + try: + return func(*args, **kws) + finally: + if saved is missing: + try: + del warnings.__warningregistry__ + except AttributeError: + pass + else: + warnings.__warningregistry__ = saved + return wrapper + + +class Test_TestLoader(unittest.TestCase): + + ### Basic object tests + ################################################################ + + def test___init__(self): + loader = unittest.TestLoader() + self.assertEqual([], loader.errors) + + ### Tests for TestLoader.loadTestsFromTestCase + ################################################################ + + # "Return a suite of all test cases contained in the TestCase-derived + # class testCaseClass" + def test_loadTestsFromTestCase(self): + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + + tests = unittest.TestSuite([Foo('test_1'), Foo('test_2')]) + + loader = unittest.TestLoader() + self.assertEqual(loader.loadTestsFromTestCase(Foo), tests) + + # "Return a suite of all test cases contained in the TestCase-derived + # class testCaseClass" + # + # Make sure it does the right thing even if no tests were found + def test_loadTestsFromTestCase__no_matches(self): + class Foo(unittest.TestCase): + def foo_bar(self): pass + + empty_suite = unittest.TestSuite() + + loader = unittest.TestLoader() + self.assertEqual(loader.loadTestsFromTestCase(Foo), empty_suite) + + # "Return a suite of all test cases contained in the TestCase-derived + # class testCaseClass" + # + # What happens if loadTestsFromTestCase() is given an object + # that isn't a subclass of TestCase? Specifically, what happens + # if testCaseClass is a subclass of TestSuite? + # + # This is checked for specifically in the code, so we better add a + # test for it. + def test_loadTestsFromTestCase__TestSuite_subclass(self): + class NotATestCase(unittest.TestSuite): + pass + + loader = unittest.TestLoader() + try: + loader.loadTestsFromTestCase(NotATestCase) + except TypeError: + pass + else: + self.fail('Should raise TypeError') + + # "Return a suite of all test cases contained in the TestCase-derived + # class testCaseClass" + # + # Make sure loadTestsFromTestCase() picks up the default test method + # name (as specified by TestCase), even though the method name does + # not match the default TestLoader.testMethodPrefix string + def test_loadTestsFromTestCase__default_method_name(self): + class Foo(unittest.TestCase): + def runTest(self): + pass + + loader = unittest.TestLoader() + # This has to be false for the test to succeed + self.assertFalse('runTest'.startswith(loader.testMethodPrefix)) + + suite = loader.loadTestsFromTestCase(Foo) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), [Foo('runTest')]) + + ################################################################ + ### /Tests for TestLoader.loadTestsFromTestCase + + ### Tests for TestLoader.loadTestsFromModule + ################################################################ + + # "This method searches `module` for classes derived from TestCase" + def test_loadTestsFromModule__TestCase_subclass(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m) + self.assertIsInstance(suite, loader.suiteClass) + + expected = [loader.suiteClass([MyTestCase('test')])] + self.assertEqual(list(suite), expected) + + # "This method searches `module` for classes derived from TestCase" + # + # What happens if no tests are found (no TestCase instances)? + def test_loadTestsFromModule__no_TestCase_instances(self): + m = types.ModuleType('m') + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), []) + + # "This method searches `module` for classes derived from TestCase" + # + # What happens if no tests are found (TestCases instances, but no tests)? + def test_loadTestsFromModule__no_TestCase_tests(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m) + self.assertIsInstance(suite, loader.suiteClass) + + self.assertEqual(list(suite), [loader.suiteClass()]) + + # "This method searches `module` for classes derived from TestCase"s + # + # What happens if loadTestsFromModule() is given something other + # than a module? + # + # XXX Currently, it succeeds anyway. This flexibility + # should either be documented or loadTestsFromModule() should + # raise a TypeError + # + # XXX Certain people are using this behaviour. We'll add a test for it + def test_loadTestsFromModule__not_a_module(self): + class MyTestCase(unittest.TestCase): + def test(self): + pass + + class NotAModule(object): + test_2 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(NotAModule) + + reference = [unittest.TestSuite([MyTestCase('test')])] + self.assertEqual(list(suite), reference) + + + # Check that loadTestsFromModule honors (or not) a module + # with a load_tests function. + @warningregistry + def test_loadTestsFromModule__load_tests(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m) + self.assertIsInstance(suite, unittest.TestSuite) + self.assertEqual(load_tests_args, [loader, suite, None]) + # With Python 3.5, the undocumented and unofficial use_load_tests is + # ignored (and deprecated). + load_tests_args = [] + with warnings.catch_warnings(record=False): + warnings.simplefilter('ignore') + suite = loader.loadTestsFromModule(m, use_load_tests=False) + self.assertEqual(load_tests_args, [loader, suite, None]) + + @warningregistry + def test_loadTestsFromModule__use_load_tests_deprecated_positional(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + # The method still works. + loader = unittest.TestLoader() + # use_load_tests=True as a positional argument. + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + suite = loader.loadTestsFromModule(m, False) + self.assertIsInstance(suite, unittest.TestSuite) + # load_tests was still called because use_load_tests is deprecated + # and ignored. + self.assertEqual(load_tests_args, [loader, suite, None]) + # We got a warning. + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEqual(str(w[-1].message), + 'use_load_tests is deprecated and ignored') + + @warningregistry + def test_loadTestsFromModule__use_load_tests_deprecated_keyword(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + # The method still works. + loader = unittest.TestLoader() + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + suite = loader.loadTestsFromModule(m, use_load_tests=False) + self.assertIsInstance(suite, unittest.TestSuite) + # load_tests was still called because use_load_tests is deprecated + # and ignored. + self.assertEqual(load_tests_args, [loader, suite, None]) + # We got a warning. + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEqual(str(w[-1].message), + 'use_load_tests is deprecated and ignored') + + @warningregistry + def test_loadTestsFromModule__too_many_positional_args(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + loader = unittest.TestLoader() + with self.assertRaises(TypeError) as cm, \ + warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + loader.loadTestsFromModule(m, False, 'testme.*') + # We still got the deprecation warning. + self.assertIs(w[-1].category, DeprecationWarning) + self.assertEqual(str(w[-1].message), + 'use_load_tests is deprecated and ignored') + # We also got a TypeError for too many positional arguments. + self.assertEqual(type(cm.exception), TypeError) + self.assertEqual( + str(cm.exception), + 'loadTestsFromModule() takes 1 positional argument but 3 were given') + + @warningregistry + def test_loadTestsFromModule__use_load_tests_other_bad_keyword(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + loader = unittest.TestLoader() + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + with self.assertRaises(TypeError) as cm: + loader.loadTestsFromModule( + m, use_load_tests=False, very_bad=True, worse=False) + self.assertEqual(type(cm.exception), TypeError) + # The error message names the first bad argument alphabetically, + # however use_load_tests (which sorts first) is ignored. + self.assertEqual( + str(cm.exception), + "loadTestsFromModule() got an unexpected keyword argument 'very_bad'") + + def test_loadTestsFromModule__pattern(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + load_tests_args = [] + def load_tests(loader, tests, pattern): + self.assertIsInstance(tests, unittest.TestSuite) + load_tests_args.extend((loader, tests, pattern)) + return tests + m.load_tests = load_tests + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m, pattern='testme.*') + self.assertIsInstance(suite, unittest.TestSuite) + self.assertEqual(load_tests_args, [loader, suite, 'testme.*']) + + def test_loadTestsFromModule__faulty_load_tests(self): + m = types.ModuleType('m') + + def load_tests(loader, tests, pattern): + raise TypeError('some failure') + m.load_tests = load_tests + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m) + self.assertIsInstance(suite, unittest.TestSuite) + self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + self.assertTrue( + 'Failed to call load_tests:' in error, + 'missing error string in %r' % error) + test = list(suite)[0] + + self.assertRaisesRegex(TypeError, "some failure", test.m) + + ################################################################ + ### /Tests for TestLoader.loadTestsFromModule() + + ### Tests for TestLoader.loadTestsFromName() + ################################################################ + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # Is ValueError raised in response to an empty name? + def test_loadTestsFromName__empty_name(self): + loader = unittest.TestLoader() + + try: + loader.loadTestsFromName('') + except ValueError as e: + self.assertEqual(str(e), "Empty module name") + else: + self.fail("TestLoader.loadTestsFromName failed to raise ValueError") + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when the name contains invalid characters? + def test_loadTestsFromName__malformed_name(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromName('abc () //') + error, test = self.check_deferred_error(loader, suite) + expected = "Failed to import test module: abc () //" + expected_regex = r"Failed to import test module: abc \(\) //" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + ImportError, expected_regex, getattr(test, 'abc () //')) + + # "The specifier name is a ``dotted name'' that may resolve ... to a + # module" + # + # What happens when a module by that name can't be found? + def test_loadTestsFromName__unknown_module_name(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromName('sdasfasfasdf') + expected = "No module named 'sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(ImportError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when the module is found, but the attribute isn't? + def test_loadTestsFromName__unknown_attr_name_on_module(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromName('unittest.loader.sdasfasfasdf') + expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when the module is found, but the attribute isn't? + def test_loadTestsFromName__unknown_attr_name_on_package(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromName('unittest.sdasfasfasdf') + expected = "No module named 'unittest.sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(ImportError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when we provide the module, but the attribute can't be + # found? + def test_loadTestsFromName__relative_unknown_name(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromName('sdasfasfasdf', unittest) + expected = "module 'unittest' has no attribute 'sdasfasfasdf'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # ... + # "The method optionally resolves name relative to the given module" + # + # Does loadTestsFromName raise ValueError when passed an empty + # name relative to a provided module? + # + # XXX Should probably raise a ValueError instead of an AttributeError + def test_loadTestsFromName__relative_empty_name(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromName('', unittest) + error, test = self.check_deferred_error(loader, suite) + expected = "has no attribute ''" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, getattr(test, '')) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # ... + # "The method optionally resolves name relative to the given module" + # + # What happens when an impossible name is given, relative to the provided + # `module`? + def test_loadTestsFromName__relative_malformed_name(self): + loader = unittest.TestLoader() + + # XXX Should this raise AttributeError or ValueError? + suite = loader.loadTestsFromName('abc () //', unittest) + error, test = self.check_deferred_error(loader, suite) + expected = "module 'unittest' has no attribute 'abc () //'" + expected_regex = r"module 'unittest' has no attribute 'abc \(\) //'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + AttributeError, expected_regex, getattr(test, 'abc () //')) + + # "The method optionally resolves name relative to the given module" + # + # Does loadTestsFromName raise TypeError when the `module` argument + # isn't a module object? + # + # XXX Accepts the not-a-module object, ignoring the object's type + # This should raise an exception or the method name should be changed + # + # XXX Some people are relying on this, so keep it for now + def test_loadTestsFromName__relative_not_a_module(self): + class MyTestCase(unittest.TestCase): + def test(self): + pass + + class NotAModule(object): + test_2 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromName('test_2', NotAModule) + + reference = [MyTestCase('test')] + self.assertEqual(list(suite), reference) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # Does it raise an exception if the name resolves to an invalid + # object? + def test_loadTestsFromName__relative_bad_object(self): + m = types.ModuleType('m') + m.testcase_1 = object() + + loader = unittest.TestLoader() + try: + loader.loadTestsFromName('testcase_1', m) + except TypeError: + pass + else: + self.fail("Should have raised TypeError") + + # "The specifier name is a ``dotted name'' that may + # resolve either to ... a test case class" + def test_loadTestsFromName__relative_TestCase_subclass(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromName('testcase_1', m) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), [MyTestCase('test')]) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + def test_loadTestsFromName__relative_TestSuite(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testsuite = unittest.TestSuite([MyTestCase('test')]) + + loader = unittest.TestLoader() + suite = loader.loadTestsFromName('testsuite', m) + self.assertIsInstance(suite, loader.suiteClass) + + self.assertEqual(list(suite), [MyTestCase('test')]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a test method within a test case class" + def test_loadTestsFromName__relative_testmethod(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromName('testcase_1.test', m) + self.assertIsInstance(suite, loader.suiteClass) + + self.assertEqual(list(suite), [MyTestCase('test')]) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # Does loadTestsFromName() raise the proper exception when trying to + # resolve "a test method within a test case class" that doesn't exist + # for the given name (relative to a provided module)? + def test_loadTestsFromName__relative_invalid_testmethod(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromName('testcase_1.testfoo', m) + expected = "type object 'MyTestCase' has no attribute 'testfoo'" + error, test = self.check_deferred_error(loader, suite) + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.testfoo) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a ... TestSuite instance" + def test_loadTestsFromName__callable__TestSuite(self): + m = types.ModuleType('m') + testcase_1 = unittest.FunctionTestCase(lambda: None) + testcase_2 = unittest.FunctionTestCase(lambda: None) + def return_TestSuite(): + return unittest.TestSuite([testcase_1, testcase_2]) + m.return_TestSuite = return_TestSuite + + loader = unittest.TestLoader() + suite = loader.loadTestsFromName('return_TestSuite', m) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), [testcase_1, testcase_2]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a TestCase ... instance" + def test_loadTestsFromName__callable__TestCase_instance(self): + m = types.ModuleType('m') + testcase_1 = unittest.FunctionTestCase(lambda: None) + def return_TestCase(): + return testcase_1 + m.return_TestCase = return_TestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromName('return_TestCase', m) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), [testcase_1]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a TestCase ... instance" + #***************************************************************** + #Override the suiteClass attribute to ensure that the suiteClass + #attribute is used + def test_loadTestsFromName__callable__TestCase_instance_ProperSuiteClass(self): + class SubTestSuite(unittest.TestSuite): + pass + m = types.ModuleType('m') + testcase_1 = unittest.FunctionTestCase(lambda: None) + def return_TestCase(): + return testcase_1 + m.return_TestCase = return_TestCase + + loader = unittest.TestLoader() + loader.suiteClass = SubTestSuite + suite = loader.loadTestsFromName('return_TestCase', m) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), [testcase_1]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a test method within a test case class" + #***************************************************************** + #Override the suiteClass attribute to ensure that the suiteClass + #attribute is used + def test_loadTestsFromName__relative_testmethod_ProperSuiteClass(self): + class SubTestSuite(unittest.TestSuite): + pass + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + loader.suiteClass=SubTestSuite + suite = loader.loadTestsFromName('testcase_1.test', m) + self.assertIsInstance(suite, loader.suiteClass) + + self.assertEqual(list(suite), [MyTestCase('test')]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a TestCase or TestSuite instance" + # + # What happens if the callable returns something else? + def test_loadTestsFromName__callable__wrong_type(self): + m = types.ModuleType('m') + def return_wrong(): + return 6 + m.return_wrong = return_wrong + + loader = unittest.TestLoader() + try: + suite = loader.loadTestsFromName('return_wrong', m) + except TypeError: + pass + else: + self.fail("TestLoader.loadTestsFromName failed to raise TypeError") + + # "The specifier can refer to modules and packages which have not been + # imported; they will be imported as a side-effect" + def test_loadTestsFromName__module_not_loaded(self): + # We're going to try to load this module as a side-effect, so it + # better not be loaded before we try. + # + module_name = 'test.test_unittest.dummy' + sys.modules.pop(module_name, None) + + loader = unittest.TestLoader() + try: + suite = loader.loadTestsFromName(module_name) + + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), []) + + # module should now be loaded, thanks to loadTestsFromName() + self.assertIn(module_name, sys.modules) + finally: + if module_name in sys.modules: + del sys.modules[module_name] + + ################################################################ + ### Tests for TestLoader.loadTestsFromName() + + ### Tests for TestLoader.loadTestsFromNames() + ################################################################ + + def check_deferred_error(self, loader, suite): + """Helper function for checking that errors in loading are reported. + + :param loader: A loader with some errors. + :param suite: A suite that should have a late bound error. + :return: The first error message from the loader and the test object + from the suite. + """ + self.assertIsInstance(suite, unittest.TestSuite) + self.assertEqual(suite.countTestCases(), 1) + # Errors loading the suite are also captured for introspection. + self.assertNotEqual([], loader.errors) + self.assertEqual(1, len(loader.errors)) + error = loader.errors[0] + test = list(suite)[0] + return error, test + + # "Similar to loadTestsFromName(), but takes a sequence of names rather + # than a single name." + # + # What happens if that sequence of names is empty? + def test_loadTestsFromNames__empty_name_list(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromNames([]) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), []) + + # "Similar to loadTestsFromName(), but takes a sequence of names rather + # than a single name." + # ... + # "The method optionally resolves name relative to the given module" + # + # What happens if that sequence of names is empty? + # + # XXX Should this raise a ValueError or just return an empty TestSuite? + def test_loadTestsFromNames__relative_empty_name_list(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromNames([], unittest) + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), []) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # Is ValueError raised in response to an empty name? + def test_loadTestsFromNames__empty_name(self): + loader = unittest.TestLoader() + + try: + loader.loadTestsFromNames(['']) + except ValueError as e: + self.assertEqual(str(e), "Empty module name") + else: + self.fail("TestLoader.loadTestsFromNames failed to raise ValueError") + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when presented with an impossible module name? + def test_loadTestsFromNames__malformed_name(self): + loader = unittest.TestLoader() + + # XXX Should this raise ValueError or ImportError? + suite = loader.loadTestsFromNames(['abc () //']) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "Failed to import test module: abc () //" + expected_regex = r"Failed to import test module: abc \(\) //" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + ImportError, expected_regex, getattr(test, 'abc () //')) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when no module can be found for the given name? + def test_loadTestsFromNames__unknown_module_name(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromNames(['sdasfasfasdf']) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "Failed to import test module: sdasfasfasdf" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(ImportError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # What happens when the module can be found, but not the attribute? + def test_loadTestsFromNames__unknown_attr_name(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromNames( + ['unittest.loader.sdasfasfasdf', 'test.test_unittest.dummy']) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "module 'unittest.loader' has no attribute 'sdasfasfasdf'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # ... + # "The method optionally resolves name relative to the given module" + # + # What happens when given an unknown attribute on a specified `module` + # argument? + def test_loadTestsFromNames__unknown_name_relative_1(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromNames(['sdasfasfasdf'], unittest) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "module 'unittest' has no attribute 'sdasfasfasdf'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # ... + # "The method optionally resolves name relative to the given module" + # + # Do unknown attributes (relative to a provided module) still raise an + # exception even in the presence of valid attribute names? + def test_loadTestsFromNames__unknown_name_relative_2(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromNames(['TestCase', 'sdasfasfasdf'], unittest) + error, test = self.check_deferred_error(loader, list(suite)[1]) + expected = "module 'unittest' has no attribute 'sdasfasfasdf'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.sdasfasfasdf) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # ... + # "The method optionally resolves name relative to the given module" + # + # What happens when faced with the empty string? + # + # XXX This currently raises AttributeError, though ValueError is probably + # more appropriate + def test_loadTestsFromNames__relative_empty_name(self): + loader = unittest.TestLoader() + + suite = loader.loadTestsFromNames([''], unittest) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "has no attribute ''" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, getattr(test, '')) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # ... + # "The method optionally resolves name relative to the given module" + # + # What happens when presented with an impossible attribute name? + def test_loadTestsFromNames__relative_malformed_name(self): + loader = unittest.TestLoader() + + # XXX Should this raise AttributeError or ValueError? + suite = loader.loadTestsFromNames(['abc () //'], unittest) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "module 'unittest' has no attribute 'abc () //'" + expected_regex = r"module 'unittest' has no attribute 'abc \(\) //'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex( + AttributeError, expected_regex, getattr(test, 'abc () //')) + + # "The method optionally resolves name relative to the given module" + # + # Does loadTestsFromNames() make sure the provided `module` is in fact + # a module? + # + # XXX This validation is currently not done. This flexibility should + # either be documented or a TypeError should be raised. + def test_loadTestsFromNames__relative_not_a_module(self): + class MyTestCase(unittest.TestCase): + def test(self): + pass + + class NotAModule(object): + test_2 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['test_2'], NotAModule) + + reference = [unittest.TestSuite([MyTestCase('test')])] + self.assertEqual(list(suite), reference) + + # "The specifier name is a ``dotted name'' that may resolve either to + # a module, a test case class, a TestSuite instance, a test method + # within a test case class, or a callable object which returns a + # TestCase or TestSuite instance." + # + # Does it raise an exception if the name resolves to an invalid + # object? + def test_loadTestsFromNames__relative_bad_object(self): + m = types.ModuleType('m') + m.testcase_1 = object() + + loader = unittest.TestLoader() + try: + loader.loadTestsFromNames(['testcase_1'], m) + except TypeError: + pass + else: + self.fail("Should have raised TypeError") + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a test case class" + def test_loadTestsFromNames__relative_TestCase_subclass(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['testcase_1'], m) + self.assertIsInstance(suite, loader.suiteClass) + + expected = loader.suiteClass([MyTestCase('test')]) + self.assertEqual(list(suite), [expected]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a TestSuite instance" + def test_loadTestsFromNames__relative_TestSuite(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testsuite = unittest.TestSuite([MyTestCase('test')]) + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['testsuite'], m) + self.assertIsInstance(suite, loader.suiteClass) + + self.assertEqual(list(suite), [m.testsuite]) + + # "The specifier name is a ``dotted name'' that may resolve ... to ... a + # test method within a test case class" + def test_loadTestsFromNames__relative_testmethod(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['testcase_1.test'], m) + self.assertIsInstance(suite, loader.suiteClass) + + ref_suite = unittest.TestSuite([MyTestCase('test')]) + self.assertEqual(list(suite), [ref_suite]) + + # #14971: Make sure the dotted name resolution works even if the actual + # function doesn't have the same name as is used to find it. + def test_loadTestsFromName__function_with_different_name_than_method(self): + # lambdas have the name '<lambda>'. + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + test = lambda: 1 + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['testcase_1.test'], m) + self.assertIsInstance(suite, loader.suiteClass) + + ref_suite = unittest.TestSuite([MyTestCase('test')]) + self.assertEqual(list(suite), [ref_suite]) + + # "The specifier name is a ``dotted name'' that may resolve ... to ... a + # test method within a test case class" + # + # Does the method gracefully handle names that initially look like they + # resolve to "a test method within a test case class" but don't? + def test_loadTestsFromNames__relative_invalid_testmethod(self): + m = types.ModuleType('m') + class MyTestCase(unittest.TestCase): + def test(self): + pass + m.testcase_1 = MyTestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['testcase_1.testfoo'], m) + error, test = self.check_deferred_error(loader, list(suite)[0]) + expected = "type object 'MyTestCase' has no attribute 'testfoo'" + self.assertIn( + expected, error, + 'missing error string in %r' % error) + self.assertRaisesRegex(AttributeError, expected, test.testfoo) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a ... TestSuite instance" + def test_loadTestsFromNames__callable__TestSuite(self): + m = types.ModuleType('m') + testcase_1 = unittest.FunctionTestCase(lambda: None) + testcase_2 = unittest.FunctionTestCase(lambda: None) + def return_TestSuite(): + return unittest.TestSuite([testcase_1, testcase_2]) + m.return_TestSuite = return_TestSuite + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['return_TestSuite'], m) + self.assertIsInstance(suite, loader.suiteClass) + + expected = unittest.TestSuite([testcase_1, testcase_2]) + self.assertEqual(list(suite), [expected]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a TestCase ... instance" + def test_loadTestsFromNames__callable__TestCase_instance(self): + m = types.ModuleType('m') + testcase_1 = unittest.FunctionTestCase(lambda: None) + def return_TestCase(): + return testcase_1 + m.return_TestCase = return_TestCase + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['return_TestCase'], m) + self.assertIsInstance(suite, loader.suiteClass) + + ref_suite = unittest.TestSuite([testcase_1]) + self.assertEqual(list(suite), [ref_suite]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a TestCase or TestSuite instance" + # + # Are staticmethods handled correctly? + def test_loadTestsFromNames__callable__call_staticmethod(self): + m = types.ModuleType('m') + class Test1(unittest.TestCase): + def test(self): + pass + + testcase_1 = Test1('test') + class Foo(unittest.TestCase): + @staticmethod + def foo(): + return testcase_1 + m.Foo = Foo + + loader = unittest.TestLoader() + suite = loader.loadTestsFromNames(['Foo.foo'], m) + self.assertIsInstance(suite, loader.suiteClass) + + ref_suite = unittest.TestSuite([testcase_1]) + self.assertEqual(list(suite), [ref_suite]) + + # "The specifier name is a ``dotted name'' that may resolve ... to + # ... a callable object which returns a TestCase or TestSuite instance" + # + # What happens when the callable returns something else? + def test_loadTestsFromNames__callable__wrong_type(self): + m = types.ModuleType('m') + def return_wrong(): + return 6 + m.return_wrong = return_wrong + + loader = unittest.TestLoader() + try: + suite = loader.loadTestsFromNames(['return_wrong'], m) + except TypeError: + pass + else: + self.fail("TestLoader.loadTestsFromNames failed to raise TypeError") + + # "The specifier can refer to modules and packages which have not been + # imported; they will be imported as a side-effect" + def test_loadTestsFromNames__module_not_loaded(self): + # We're going to try to load this module as a side-effect, so it + # better not be loaded before we try. + # + module_name = 'test.test_unittest.dummy' + sys.modules.pop(module_name, None) + + loader = unittest.TestLoader() + try: + suite = loader.loadTestsFromNames([module_name]) + + self.assertIsInstance(suite, loader.suiteClass) + self.assertEqual(list(suite), [unittest.TestSuite()]) + + # module should now be loaded, thanks to loadTestsFromName() + self.assertIn(module_name, sys.modules) + finally: + if module_name in sys.modules: + del sys.modules[module_name] + + ################################################################ + ### /Tests for TestLoader.loadTestsFromNames() + + ### Tests for TestLoader.getTestCaseNames() + ################################################################ + + # "Return a sorted sequence of method names found within testCaseClass" + # + # Test.foobar is defined to make sure getTestCaseNames() respects + # loader.testMethodPrefix + def test_getTestCaseNames(self): + class Test(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foobar(self): pass + + loader = unittest.TestLoader() + + self.assertEqual(loader.getTestCaseNames(Test), ['test_1', 'test_2']) + + # "Return a sorted sequence of method names found within testCaseClass" + # + # Does getTestCaseNames() behave appropriately if no tests are found? + def test_getTestCaseNames__no_tests(self): + class Test(unittest.TestCase): + def foobar(self): pass + + loader = unittest.TestLoader() + + self.assertEqual(loader.getTestCaseNames(Test), []) + + # "Return a sorted sequence of method names found within testCaseClass" + # + # Are not-TestCases handled gracefully? + # + # XXX This should raise a TypeError, not return a list + # + # XXX It's too late in the 2.5 release cycle to fix this, but it should + # probably be revisited for 2.6 + def test_getTestCaseNames__not_a_TestCase(self): + class BadCase(int): + def test_foo(self): + pass + + loader = unittest.TestLoader() + names = loader.getTestCaseNames(BadCase) + + self.assertEqual(names, ['test_foo']) + + # "Return a sorted sequence of method names found within testCaseClass" + # + # Make sure inherited names are handled. + # + # TestP.foobar is defined to make sure getTestCaseNames() respects + # loader.testMethodPrefix + def test_getTestCaseNames__inheritance(self): + class TestP(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foobar(self): pass + + class TestC(TestP): + def test_1(self): pass + def test_3(self): pass + + loader = unittest.TestLoader() + + names = ['test_1', 'test_2', 'test_3'] + self.assertEqual(loader.getTestCaseNames(TestC), names) + + # "Return a sorted sequence of method names found within testCaseClass" + # + # If TestLoader.testNamePatterns is set, only tests that match one of these + # patterns should be included. + def test_getTestCaseNames__testNamePatterns(self): + class MyTest(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foobar(self): pass + + loader = unittest.TestLoader() + + loader.testNamePatterns = [] + self.assertEqual(loader.getTestCaseNames(MyTest), []) + + loader.testNamePatterns = ['*1'] + self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1']) + + loader.testNamePatterns = ['*1', '*2'] + self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1', 'test_2']) + + loader.testNamePatterns = ['*My*'] + self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1', 'test_2']) + + loader.testNamePatterns = ['*my*'] + self.assertEqual(loader.getTestCaseNames(MyTest), []) + + # "Return a sorted sequence of method names found within testCaseClass" + # + # If TestLoader.testNamePatterns is set, only tests that match one of these + # patterns should be included. + # + # For backwards compatibility reasons (see bpo-32071), the check may only + # touch a TestCase's attribute if it starts with the test method prefix. + def test_getTestCaseNames__testNamePatterns__attribute_access_regression(self): + class Trap: + def __get__(*ignored): + self.fail('Non-test attribute accessed') + + class MyTest(unittest.TestCase): + def test_1(self): pass + foobar = Trap() + + loader = unittest.TestLoader() + self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1']) + + loader = unittest.TestLoader() + loader.testNamePatterns = [] + self.assertEqual(loader.getTestCaseNames(MyTest), []) + + ################################################################ + ### /Tests for TestLoader.getTestCaseNames() + + ### Tests for TestLoader.testMethodPrefix + ################################################################ + + # "String giving the prefix of method names which will be interpreted as + # test methods" + # + # Implicit in the documentation is that testMethodPrefix is respected by + # all loadTestsFrom* methods. + def test_testMethodPrefix__loadTestsFromTestCase(self): + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + + tests_1 = unittest.TestSuite([Foo('foo_bar')]) + tests_2 = unittest.TestSuite([Foo('test_1'), Foo('test_2')]) + + loader = unittest.TestLoader() + loader.testMethodPrefix = 'foo' + self.assertEqual(loader.loadTestsFromTestCase(Foo), tests_1) + + loader.testMethodPrefix = 'test' + self.assertEqual(loader.loadTestsFromTestCase(Foo), tests_2) + + # "String giving the prefix of method names which will be interpreted as + # test methods" + # + # Implicit in the documentation is that testMethodPrefix is respected by + # all loadTestsFrom* methods. + def test_testMethodPrefix__loadTestsFromModule(self): + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + m.Foo = Foo + + tests_1 = [unittest.TestSuite([Foo('foo_bar')])] + tests_2 = [unittest.TestSuite([Foo('test_1'), Foo('test_2')])] + + loader = unittest.TestLoader() + loader.testMethodPrefix = 'foo' + self.assertEqual(list(loader.loadTestsFromModule(m)), tests_1) + + loader.testMethodPrefix = 'test' + self.assertEqual(list(loader.loadTestsFromModule(m)), tests_2) + + # "String giving the prefix of method names which will be interpreted as + # test methods" + # + # Implicit in the documentation is that testMethodPrefix is respected by + # all loadTestsFrom* methods. + def test_testMethodPrefix__loadTestsFromName(self): + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + m.Foo = Foo + + tests_1 = unittest.TestSuite([Foo('foo_bar')]) + tests_2 = unittest.TestSuite([Foo('test_1'), Foo('test_2')]) + + loader = unittest.TestLoader() + loader.testMethodPrefix = 'foo' + self.assertEqual(loader.loadTestsFromName('Foo', m), tests_1) + + loader.testMethodPrefix = 'test' + self.assertEqual(loader.loadTestsFromName('Foo', m), tests_2) + + # "String giving the prefix of method names which will be interpreted as + # test methods" + # + # Implicit in the documentation is that testMethodPrefix is respected by + # all loadTestsFrom* methods. + def test_testMethodPrefix__loadTestsFromNames(self): + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + m.Foo = Foo + + tests_1 = unittest.TestSuite([unittest.TestSuite([Foo('foo_bar')])]) + tests_2 = unittest.TestSuite([Foo('test_1'), Foo('test_2')]) + tests_2 = unittest.TestSuite([tests_2]) + + loader = unittest.TestLoader() + loader.testMethodPrefix = 'foo' + self.assertEqual(loader.loadTestsFromNames(['Foo'], m), tests_1) + + loader.testMethodPrefix = 'test' + self.assertEqual(loader.loadTestsFromNames(['Foo'], m), tests_2) + + # "The default value is 'test'" + def test_testMethodPrefix__default_value(self): + loader = unittest.TestLoader() + self.assertEqual(loader.testMethodPrefix, 'test') + + ################################################################ + ### /Tests for TestLoader.testMethodPrefix + + ### Tests for TestLoader.sortTestMethodsUsing + ################################################################ + + # "Function to be used to compare method names when sorting them in + # getTestCaseNames() and all the loadTestsFromX() methods" + def test_sortTestMethodsUsing__loadTestsFromTestCase(self): + def reversed_cmp(x, y): + return -((x > y) - (x < y)) + + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = reversed_cmp + + tests = loader.suiteClass([Foo('test_2'), Foo('test_1')]) + self.assertEqual(loader.loadTestsFromTestCase(Foo), tests) + + # "Function to be used to compare method names when sorting them in + # getTestCaseNames() and all the loadTestsFromX() methods" + def test_sortTestMethodsUsing__loadTestsFromModule(self): + def reversed_cmp(x, y): + return -((x > y) - (x < y)) + + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + m.Foo = Foo + + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = reversed_cmp + + tests = [loader.suiteClass([Foo('test_2'), Foo('test_1')])] + self.assertEqual(list(loader.loadTestsFromModule(m)), tests) + + # "Function to be used to compare method names when sorting them in + # getTestCaseNames() and all the loadTestsFromX() methods" + def test_sortTestMethodsUsing__loadTestsFromName(self): + def reversed_cmp(x, y): + return -((x > y) - (x < y)) + + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + m.Foo = Foo + + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = reversed_cmp + + tests = loader.suiteClass([Foo('test_2'), Foo('test_1')]) + self.assertEqual(loader.loadTestsFromName('Foo', m), tests) + + # "Function to be used to compare method names when sorting them in + # getTestCaseNames() and all the loadTestsFromX() methods" + def test_sortTestMethodsUsing__loadTestsFromNames(self): + def reversed_cmp(x, y): + return -((x > y) - (x < y)) + + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + m.Foo = Foo + + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = reversed_cmp + + tests = [loader.suiteClass([Foo('test_2'), Foo('test_1')])] + self.assertEqual(list(loader.loadTestsFromNames(['Foo'], m)), tests) + + # "Function to be used to compare method names when sorting them in + # getTestCaseNames()" + # + # Does it actually affect getTestCaseNames()? + def test_sortTestMethodsUsing__getTestCaseNames(self): + def reversed_cmp(x, y): + return -((x > y) - (x < y)) + + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = reversed_cmp + + test_names = ['test_2', 'test_1'] + self.assertEqual(loader.getTestCaseNames(Foo), test_names) + + # "The default value is the built-in cmp() function" + # Since cmp is now defunct, we simply verify that the results + # occur in the same order as they would with the default sort. + def test_sortTestMethodsUsing__default_value(self): + loader = unittest.TestLoader() + + class Foo(unittest.TestCase): + def test_2(self): pass + def test_3(self): pass + def test_1(self): pass + + test_names = ['test_2', 'test_3', 'test_1'] + self.assertEqual(loader.getTestCaseNames(Foo), sorted(test_names)) + + + # "it can be set to None to disable the sort." + # + # XXX How is this different from reassigning cmp? Are the tests returned + # in a random order or something? This behaviour should die + def test_sortTestMethodsUsing__None(self): + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = None + + test_names = ['test_2', 'test_1'] + self.assertEqual(set(loader.getTestCaseNames(Foo)), set(test_names)) + + ################################################################ + ### /Tests for TestLoader.sortTestMethodsUsing + + ### Tests for TestLoader.suiteClass + ################################################################ + + # "Callable object that constructs a test suite from a list of tests." + def test_suiteClass__loadTestsFromTestCase(self): + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + + tests = [Foo('test_1'), Foo('test_2')] + + loader = unittest.TestLoader() + loader.suiteClass = list + self.assertEqual(loader.loadTestsFromTestCase(Foo), tests) + + # It is implicit in the documentation for TestLoader.suiteClass that + # all TestLoader.loadTestsFrom* methods respect it. Let's make sure + def test_suiteClass__loadTestsFromModule(self): + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + m.Foo = Foo + + tests = [[Foo('test_1'), Foo('test_2')]] + + loader = unittest.TestLoader() + loader.suiteClass = list + self.assertEqual(loader.loadTestsFromModule(m), tests) + + # It is implicit in the documentation for TestLoader.suiteClass that + # all TestLoader.loadTestsFrom* methods respect it. Let's make sure + def test_suiteClass__loadTestsFromName(self): + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + m.Foo = Foo + + tests = [Foo('test_1'), Foo('test_2')] + + loader = unittest.TestLoader() + loader.suiteClass = list + self.assertEqual(loader.loadTestsFromName('Foo', m), tests) + + # It is implicit in the documentation for TestLoader.suiteClass that + # all TestLoader.loadTestsFrom* methods respect it. Let's make sure + def test_suiteClass__loadTestsFromNames(self): + m = types.ModuleType('m') + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foo_bar(self): pass + m.Foo = Foo + + tests = [[Foo('test_1'), Foo('test_2')]] + + loader = unittest.TestLoader() + loader.suiteClass = list + self.assertEqual(loader.loadTestsFromNames(['Foo'], m), tests) + + # "The default value is the TestSuite class" + def test_suiteClass__default_value(self): + loader = unittest.TestLoader() + self.assertIs(loader.suiteClass, unittest.TestSuite) + + + def test_partial_functions(self): + def noop(arg): + pass + + class Foo(unittest.TestCase): + pass + + setattr(Foo, 'test_partial', functools.partial(noop, None)) + + loader = unittest.TestLoader() + + test_names = ['test_partial'] + self.assertEqual(loader.getTestCaseNames(Foo), test_names) + + +class TestObsoleteFunctions(unittest.TestCase): + class MyTestSuite(unittest.TestSuite): + pass + + class MyTestCase(unittest.TestCase): + def check_1(self): pass + def check_2(self): pass + def test(self): pass + + @staticmethod + def reverse_three_way_cmp(a, b): + return unittest.util.three_way_cmp(b, a) + + def test_getTestCaseNames(self): + with self.assertWarns(DeprecationWarning) as w: + tests = unittest.getTestCaseNames(self.MyTestCase, + prefix='check', sortUsing=self.reverse_three_way_cmp, + testNamePatterns=None) + self.assertEqual(w.filename, __file__) + self.assertEqual(tests, ['check_2', 'check_1']) + + def test_makeSuite(self): + with self.assertWarns(DeprecationWarning) as w: + suite = unittest.makeSuite(self.MyTestCase, + prefix='check', sortUsing=self.reverse_three_way_cmp, + suiteClass=self.MyTestSuite) + self.assertEqual(w.filename, __file__) + self.assertIsInstance(suite, self.MyTestSuite) + expected = self.MyTestSuite([self.MyTestCase('check_2'), + self.MyTestCase('check_1')]) + self.assertEqual(suite, expected) + + def test_findTestCases(self): + m = types.ModuleType('m') + m.testcase_1 = self.MyTestCase + + with self.assertWarns(DeprecationWarning) as w: + suite = unittest.findTestCases(m, + prefix='check', sortUsing=self.reverse_three_way_cmp, + suiteClass=self.MyTestSuite) + self.assertEqual(w.filename, __file__) + self.assertIsInstance(suite, self.MyTestSuite) + expected = [self.MyTestSuite([self.MyTestCase('check_2'), + self.MyTestCase('check_1')])] + self.assertEqual(list(suite), expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_program.py b/Lib/test/test_unittest/test_program.py new file mode 100644 index 0000000..169fc4e --- /dev/null +++ b/Lib/test/test_unittest/test_program.py @@ -0,0 +1,477 @@ +import io + +import os +import sys +import subprocess +from test import support +import unittest +import test.test_unittest +from test.test_unittest.test_result import BufferedWriter + + +class Test_TestProgram(unittest.TestCase): + + def test_discovery_from_dotted_path(self): + loader = unittest.TestLoader() + + tests = [self] + expectedPath = os.path.abspath(os.path.dirname(test.test_unittest.__file__)) + + self.wasRun = False + def _find_tests(start_dir, pattern): + self.wasRun = True + self.assertEqual(start_dir, expectedPath) + return tests + loader._find_tests = _find_tests + suite = loader.discover('test.test_unittest') + self.assertTrue(self.wasRun) + self.assertEqual(suite._tests, tests) + + # Horrible white box test + def testNoExit(self): + result = object() + test = object() + + class FakeRunner(object): + def run(self, test): + self.test = test + return result + + runner = FakeRunner() + + oldParseArgs = unittest.TestProgram.parseArgs + def restoreParseArgs(): + unittest.TestProgram.parseArgs = oldParseArgs + unittest.TestProgram.parseArgs = lambda *args: None + self.addCleanup(restoreParseArgs) + + def removeTest(): + del unittest.TestProgram.test + unittest.TestProgram.test = test + self.addCleanup(removeTest) + + program = unittest.TestProgram(testRunner=runner, exit=False, verbosity=2) + + self.assertEqual(program.result, result) + self.assertEqual(runner.test, test) + self.assertEqual(program.verbosity, 2) + + class FooBar(unittest.TestCase): + def testPass(self): + pass + def testFail(self): + raise AssertionError + def testError(self): + 1/0 + @unittest.skip('skipping') + def testSkipped(self): + raise AssertionError + @unittest.expectedFailure + def testExpectedFailure(self): + raise AssertionError + @unittest.expectedFailure + def testUnexpectedSuccess(self): + pass + + class FooBarLoader(unittest.TestLoader): + """Test loader that returns a suite containing FooBar.""" + def loadTestsFromModule(self, module): + return self.suiteClass( + [self.loadTestsFromTestCase(Test_TestProgram.FooBar)]) + + def loadTestsFromNames(self, names, module): + return self.suiteClass( + [self.loadTestsFromTestCase(Test_TestProgram.FooBar)]) + + def test_defaultTest_with_string(self): + class FakeRunner(object): + def run(self, test): + self.test = test + return True + + old_argv = sys.argv + sys.argv = ['faketest'] + runner = FakeRunner() + program = unittest.TestProgram(testRunner=runner, exit=False, + defaultTest='test.test_unittest', + testLoader=self.FooBarLoader()) + sys.argv = old_argv + self.assertEqual(('test.test_unittest',), program.testNames) + + def test_defaultTest_with_iterable(self): + class FakeRunner(object): + def run(self, test): + self.test = test + return True + + old_argv = sys.argv + sys.argv = ['faketest'] + runner = FakeRunner() + program = unittest.TestProgram( + testRunner=runner, exit=False, + defaultTest=['test.test_unittest', 'test.test_unittest2'], + testLoader=self.FooBarLoader()) + sys.argv = old_argv + self.assertEqual(['test.test_unittest', 'test.test_unittest2'], + program.testNames) + + def test_NonExit(self): + stream = BufferedWriter() + program = unittest.main(exit=False, + argv=["foobar"], + testRunner=unittest.TextTestRunner(stream=stream), + testLoader=self.FooBarLoader()) + self.assertTrue(hasattr(program, 'result')) + out = stream.getvalue() + self.assertIn('\nFAIL: testFail ', out) + self.assertIn('\nERROR: testError ', out) + self.assertIn('\nUNEXPECTED SUCCESS: testUnexpectedSuccess ', out) + expected = ('\n\nFAILED (failures=1, errors=1, skipped=1, ' + 'expected failures=1, unexpected successes=1)\n') + self.assertTrue(out.endswith(expected)) + + def test_Exit(self): + stream = BufferedWriter() + self.assertRaises( + SystemExit, + unittest.main, + argv=["foobar"], + testRunner=unittest.TextTestRunner(stream=stream), + exit=True, + testLoader=self.FooBarLoader()) + out = stream.getvalue() + self.assertIn('\nFAIL: testFail ', out) + self.assertIn('\nERROR: testError ', out) + self.assertIn('\nUNEXPECTED SUCCESS: testUnexpectedSuccess ', out) + expected = ('\n\nFAILED (failures=1, errors=1, skipped=1, ' + 'expected failures=1, unexpected successes=1)\n') + self.assertTrue(out.endswith(expected)) + + def test_ExitAsDefault(self): + stream = BufferedWriter() + self.assertRaises( + SystemExit, + unittest.main, + argv=["foobar"], + testRunner=unittest.TextTestRunner(stream=stream), + testLoader=self.FooBarLoader()) + out = stream.getvalue() + self.assertIn('\nFAIL: testFail ', out) + self.assertIn('\nERROR: testError ', out) + self.assertIn('\nUNEXPECTED SUCCESS: testUnexpectedSuccess ', out) + expected = ('\n\nFAILED (failures=1, errors=1, skipped=1, ' + 'expected failures=1, unexpected successes=1)\n') + self.assertTrue(out.endswith(expected)) + + +class InitialisableProgram(unittest.TestProgram): + exit = False + result = None + verbosity = 1 + defaultTest = None + tb_locals = False + testRunner = None + testLoader = unittest.defaultTestLoader + module = '__main__' + progName = 'test' + test = 'test' + def __init__(self, *args): + pass + +RESULT = object() + +class FakeRunner(object): + initArgs = None + test = None + raiseError = 0 + + def __init__(self, **kwargs): + FakeRunner.initArgs = kwargs + if FakeRunner.raiseError: + FakeRunner.raiseError -= 1 + raise TypeError + + def run(self, test): + FakeRunner.test = test + return RESULT + + +@support.requires_subprocess() +class TestCommandLineArgs(unittest.TestCase): + + def setUp(self): + self.program = InitialisableProgram() + self.program.createTests = lambda: None + FakeRunner.initArgs = None + FakeRunner.test = None + FakeRunner.raiseError = 0 + + def testVerbosity(self): + program = self.program + + for opt in '-q', '--quiet': + program.verbosity = 1 + program.parseArgs([None, opt]) + self.assertEqual(program.verbosity, 0) + + for opt in '-v', '--verbose': + program.verbosity = 1 + program.parseArgs([None, opt]) + self.assertEqual(program.verbosity, 2) + + def testBufferCatchFailfast(self): + program = self.program + for arg, attr in (('buffer', 'buffer'), ('failfast', 'failfast'), + ('catch', 'catchbreak')): + + setattr(program, attr, None) + program.parseArgs([None]) + self.assertIs(getattr(program, attr), False) + + false = [] + setattr(program, attr, false) + program.parseArgs([None]) + self.assertIs(getattr(program, attr), false) + + true = [42] + setattr(program, attr, true) + program.parseArgs([None]) + self.assertIs(getattr(program, attr), true) + + short_opt = '-%s' % arg[0] + long_opt = '--%s' % arg + for opt in short_opt, long_opt: + setattr(program, attr, None) + program.parseArgs([None, opt]) + self.assertIs(getattr(program, attr), True) + + setattr(program, attr, False) + with support.captured_stderr() as stderr, \ + self.assertRaises(SystemExit) as cm: + program.parseArgs([None, opt]) + self.assertEqual(cm.exception.args, (2,)) + + setattr(program, attr, True) + with support.captured_stderr() as stderr, \ + self.assertRaises(SystemExit) as cm: + program.parseArgs([None, opt]) + self.assertEqual(cm.exception.args, (2,)) + + def testWarning(self): + """Test the warnings argument""" + # see #10535 + class FakeTP(unittest.TestProgram): + def parseArgs(self, *args, **kw): pass + def runTests(self, *args, **kw): pass + warnoptions = sys.warnoptions[:] + try: + sys.warnoptions[:] = [] + # no warn options, no arg -> default + self.assertEqual(FakeTP().warnings, 'default') + # no warn options, w/ arg -> arg value + self.assertEqual(FakeTP(warnings='ignore').warnings, 'ignore') + sys.warnoptions[:] = ['somevalue'] + # warn options, no arg -> None + # warn options, w/ arg -> arg value + self.assertEqual(FakeTP().warnings, None) + self.assertEqual(FakeTP(warnings='ignore').warnings, 'ignore') + finally: + sys.warnoptions[:] = warnoptions + + def testRunTestsRunnerClass(self): + program = self.program + + program.testRunner = FakeRunner + program.verbosity = 'verbosity' + program.failfast = 'failfast' + program.buffer = 'buffer' + program.warnings = 'warnings' + + program.runTests() + + self.assertEqual(FakeRunner.initArgs, {'verbosity': 'verbosity', + 'failfast': 'failfast', + 'buffer': 'buffer', + 'tb_locals': False, + 'warnings': 'warnings'}) + self.assertEqual(FakeRunner.test, 'test') + self.assertIs(program.result, RESULT) + + def testRunTestsRunnerInstance(self): + program = self.program + + program.testRunner = FakeRunner() + FakeRunner.initArgs = None + + program.runTests() + + # A new FakeRunner should not have been instantiated + self.assertIsNone(FakeRunner.initArgs) + + self.assertEqual(FakeRunner.test, 'test') + self.assertIs(program.result, RESULT) + + def test_locals(self): + program = self.program + + program.testRunner = FakeRunner + program.parseArgs([None, '--locals']) + self.assertEqual(True, program.tb_locals) + program.runTests() + self.assertEqual(FakeRunner.initArgs, {'buffer': False, + 'failfast': False, + 'tb_locals': True, + 'verbosity': 1, + 'warnings': None}) + + def testRunTestsOldRunnerClass(self): + program = self.program + + # Two TypeErrors are needed to fall all the way back to old-style + # runners - one to fail tb_locals, one to fail buffer etc. + FakeRunner.raiseError = 2 + program.testRunner = FakeRunner + program.verbosity = 'verbosity' + program.failfast = 'failfast' + program.buffer = 'buffer' + program.test = 'test' + + program.runTests() + + # If initialising raises a type error it should be retried + # without the new keyword arguments + self.assertEqual(FakeRunner.initArgs, {}) + self.assertEqual(FakeRunner.test, 'test') + self.assertIs(program.result, RESULT) + + def testCatchBreakInstallsHandler(self): + module = sys.modules['unittest.main'] + original = module.installHandler + def restore(): + module.installHandler = original + self.addCleanup(restore) + + self.installed = False + def fakeInstallHandler(): + self.installed = True + module.installHandler = fakeInstallHandler + + program = self.program + program.catchbreak = True + + program.testRunner = FakeRunner + + program.runTests() + self.assertTrue(self.installed) + + def _patch_isfile(self, names, exists=True): + def isfile(path): + return path in names + original = os.path.isfile + os.path.isfile = isfile + def restore(): + os.path.isfile = original + self.addCleanup(restore) + + + def testParseArgsFileNames(self): + # running tests with filenames instead of module names + program = self.program + argv = ['progname', 'foo.py', 'bar.Py', 'baz.PY', 'wing.txt'] + self._patch_isfile(argv) + + program.createTests = lambda: None + program.parseArgs(argv) + + # note that 'wing.txt' is not a Python file so the name should + # *not* be converted to a module name + expected = ['foo', 'bar', 'baz', 'wing.txt'] + self.assertEqual(program.testNames, expected) + + + def testParseArgsFilePaths(self): + program = self.program + argv = ['progname', 'foo/bar/baz.py', 'green\\red.py'] + self._patch_isfile(argv) + + program.createTests = lambda: None + program.parseArgs(argv) + + expected = ['foo.bar.baz', 'green.red'] + self.assertEqual(program.testNames, expected) + + + def testParseArgsNonExistentFiles(self): + program = self.program + argv = ['progname', 'foo/bar/baz.py', 'green\\red.py'] + self._patch_isfile([]) + + program.createTests = lambda: None + program.parseArgs(argv) + + self.assertEqual(program.testNames, argv[1:]) + + def testParseArgsAbsolutePathsThatCanBeConverted(self): + cur_dir = os.getcwd() + program = self.program + def _join(name): + return os.path.join(cur_dir, name) + argv = ['progname', _join('foo/bar/baz.py'), _join('green\\red.py')] + self._patch_isfile(argv) + + program.createTests = lambda: None + program.parseArgs(argv) + + expected = ['foo.bar.baz', 'green.red'] + self.assertEqual(program.testNames, expected) + + def testParseArgsAbsolutePathsThatCannotBeConverted(self): + program = self.program + # even on Windows '/...' is considered absolute by os.path.abspath + argv = ['progname', '/foo/bar/baz.py', '/green/red.py'] + self._patch_isfile(argv) + + program.createTests = lambda: None + program.parseArgs(argv) + + self.assertEqual(program.testNames, argv[1:]) + + # it may be better to use platform specific functions to normalise paths + # rather than accepting '.PY' and '\' as file separator on Linux / Mac + # it would also be better to check that a filename is a valid module + # identifier (we have a regex for this in loader.py) + # for invalid filenames should we raise a useful error rather than + # leaving the current error message (import of filename fails) in place? + + def testParseArgsSelectedTestNames(self): + program = self.program + argv = ['progname', '-k', 'foo', '-k', 'bar', '-k', '*pat*'] + + program.createTests = lambda: None + program.parseArgs(argv) + + self.assertEqual(program.testNamePatterns, ['*foo*', '*bar*', '*pat*']) + + def testSelectedTestNamesFunctionalTest(self): + def run_unittest(args): + # Use -E to ignore PYTHONSAFEPATH env var + cmd = [sys.executable, '-E', '-m', 'unittest'] + args + p = subprocess.Popen(cmd, + stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, cwd=os.path.dirname(__file__)) + with p: + _, stderr = p.communicate() + return stderr.decode() + + t = '_test_warnings' + self.assertIn('Ran 7 tests', run_unittest([t])) + self.assertIn('Ran 7 tests', run_unittest(['-k', 'TestWarnings', t])) + self.assertIn('Ran 7 tests', run_unittest(['discover', '-p', '*_test*', '-k', 'TestWarnings'])) + self.assertIn('Ran 2 tests', run_unittest(['-k', 'f', t])) + self.assertIn('Ran 7 tests', run_unittest(['-k', 't', t])) + self.assertIn('Ran 3 tests', run_unittest(['-k', '*t', t])) + self.assertIn('Ran 7 tests', run_unittest(['-k', '*test_warnings.*Warning*', t])) + self.assertIn('Ran 1 test', run_unittest(['-k', '*test_warnings.*warning*', t])) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/test_result.py b/Lib/test/test_unittest/test_result.py new file mode 100644 index 0000000..b0cc3d8 --- /dev/null +++ b/Lib/test/test_unittest/test_result.py @@ -0,0 +1,1331 @@ +import io +import sys +import textwrap + +from test.support import warnings_helper, captured_stdout, captured_stderr + +import traceback +import unittest +from unittest.util import strclass + + +class MockTraceback(object): + class TracebackException: + def __init__(self, *args, **kwargs): + self.capture_locals = kwargs.get('capture_locals', False) + def format(self): + result = ['A traceback'] + if self.capture_locals: + result.append('locals') + return result + +def restore_traceback(): + unittest.result.traceback = traceback + + +def bad_cleanup1(): + print('do cleanup1') + raise TypeError('bad cleanup1') + + +def bad_cleanup2(): + print('do cleanup2') + raise ValueError('bad cleanup2') + + +class BufferedWriter: + def __init__(self): + self.result = '' + self.buffer = '' + + def write(self, arg): + self.buffer += arg + + def flush(self): + self.result += self.buffer + self.buffer = '' + + def getvalue(self): + return self.result + + +class Test_TestResult(unittest.TestCase): + # Note: there are not separate tests for TestResult.wasSuccessful(), + # TestResult.errors, TestResult.failures, TestResult.testsRun or + # TestResult.shouldStop because these only have meaning in terms of + # other TestResult methods. + # + # Accordingly, tests for the aforenamed attributes are incorporated + # in with the tests for the defining methods. + ################################################################ + + def test_init(self): + result = unittest.TestResult() + + self.assertTrue(result.wasSuccessful()) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 0) + self.assertEqual(result.shouldStop, False) + self.assertIsNone(result._stdout_buffer) + self.assertIsNone(result._stderr_buffer) + + # "This method can be called to signal that the set of tests being + # run should be aborted by setting the TestResult's shouldStop + # attribute to True." + def test_stop(self): + result = unittest.TestResult() + + result.stop() + + self.assertEqual(result.shouldStop, True) + + # "Called when the test case test is about to be run. The default + # implementation simply increments the instance's testsRun counter." + def test_startTest(self): + class Foo(unittest.TestCase): + def test_1(self): + pass + + test = Foo('test_1') + + result = unittest.TestResult() + + result.startTest(test) + + self.assertTrue(result.wasSuccessful()) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 1) + self.assertEqual(result.shouldStop, False) + + result.stopTest(test) + + # "Called after the test case test has been executed, regardless of + # the outcome. The default implementation does nothing." + def test_stopTest(self): + class Foo(unittest.TestCase): + def test_1(self): + pass + + test = Foo('test_1') + + result = unittest.TestResult() + + result.startTest(test) + + self.assertTrue(result.wasSuccessful()) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 1) + self.assertEqual(result.shouldStop, False) + + result.stopTest(test) + + # Same tests as above; make sure nothing has changed + self.assertTrue(result.wasSuccessful()) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 1) + self.assertEqual(result.shouldStop, False) + + # "Called before and after tests are run. The default implementation does nothing." + def test_startTestRun_stopTestRun(self): + result = unittest.TestResult() + result.startTestRun() + result.stopTestRun() + + # "addSuccess(test)" + # ... + # "Called when the test case test succeeds" + # ... + # "wasSuccessful() - Returns True if all tests run so far have passed, + # otherwise returns False" + # ... + # "testsRun - The total number of tests run so far." + # ... + # "errors - A list containing 2-tuples of TestCase instances and + # formatted tracebacks. Each tuple represents a test which raised an + # unexpected exception. Contains formatted + # tracebacks instead of sys.exc_info() results." + # ... + # "failures - A list containing 2-tuples of TestCase instances and + # formatted tracebacks. Each tuple represents a test where a failure was + # explicitly signalled using the TestCase.fail*() or TestCase.assert*() + # methods. Contains formatted tracebacks instead + # of sys.exc_info() results." + def test_addSuccess(self): + class Foo(unittest.TestCase): + def test_1(self): + pass + + test = Foo('test_1') + + result = unittest.TestResult() + + result.startTest(test) + result.addSuccess(test) + result.stopTest(test) + + self.assertTrue(result.wasSuccessful()) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 1) + self.assertEqual(result.shouldStop, False) + + # "addFailure(test, err)" + # ... + # "Called when the test case test signals a failure. err is a tuple of + # the form returned by sys.exc_info(): (type, value, traceback)" + # ... + # "wasSuccessful() - Returns True if all tests run so far have passed, + # otherwise returns False" + # ... + # "testsRun - The total number of tests run so far." + # ... + # "errors - A list containing 2-tuples of TestCase instances and + # formatted tracebacks. Each tuple represents a test which raised an + # unexpected exception. Contains formatted + # tracebacks instead of sys.exc_info() results." + # ... + # "failures - A list containing 2-tuples of TestCase instances and + # formatted tracebacks. Each tuple represents a test where a failure was + # explicitly signalled using the TestCase.fail*() or TestCase.assert*() + # methods. Contains formatted tracebacks instead + # of sys.exc_info() results." + def test_addFailure(self): + class Foo(unittest.TestCase): + def test_1(self): + pass + + test = Foo('test_1') + try: + test.fail("foo") + except: + exc_info_tuple = sys.exc_info() + + result = unittest.TestResult() + + result.startTest(test) + result.addFailure(test, exc_info_tuple) + result.stopTest(test) + + self.assertFalse(result.wasSuccessful()) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.failures), 1) + self.assertEqual(result.testsRun, 1) + self.assertEqual(result.shouldStop, False) + + test_case, formatted_exc = result.failures[0] + self.assertIs(test_case, test) + self.assertIsInstance(formatted_exc, str) + + def test_addFailure_filter_traceback_frames(self): + class Foo(unittest.TestCase): + def test_1(self): + pass + + test = Foo('test_1') + def get_exc_info(): + try: + test.fail("foo") + except: + return sys.exc_info() + + exc_info_tuple = get_exc_info() + + full_exc = traceback.format_exception(*exc_info_tuple) + + result = unittest.TestResult() + result.startTest(test) + result.addFailure(test, exc_info_tuple) + result.stopTest(test) + + formatted_exc = result.failures[0][1] + dropped = [l for l in full_exc if l not in formatted_exc] + self.assertEqual(len(dropped), 1) + self.assertIn("raise self.failureException(msg)", dropped[0]) + + def test_addFailure_filter_traceback_frames_context(self): + class Foo(unittest.TestCase): + def test_1(self): + pass + + test = Foo('test_1') + def get_exc_info(): + try: + try: + test.fail("foo") + except: + raise ValueError(42) + except: + return sys.exc_info() + + exc_info_tuple = get_exc_info() + + full_exc = traceback.format_exception(*exc_info_tuple) + + result = unittest.TestResult() + result.startTest(test) + result.addFailure(test, exc_info_tuple) + result.stopTest(test) + + formatted_exc = result.failures[0][1] + dropped = [l for l in full_exc if l not in formatted_exc] + self.assertEqual(len(dropped), 1) + self.assertIn("raise self.failureException(msg)", dropped[0]) + + # "addError(test, err)" + # ... + # "Called when the test case test raises an unexpected exception err + # is a tuple of the form returned by sys.exc_info(): + # (type, value, traceback)" + # ... + # "wasSuccessful() - Returns True if all tests run so far have passed, + # otherwise returns False" + # ... + # "testsRun - The total number of tests run so far." + # ... + # "errors - A list containing 2-tuples of TestCase instances and + # formatted tracebacks. Each tuple represents a test which raised an + # unexpected exception. Contains formatted + # tracebacks instead of sys.exc_info() results." + # ... + # "failures - A list containing 2-tuples of TestCase instances and + # formatted tracebacks. Each tuple represents a test where a failure was + # explicitly signalled using the TestCase.fail*() or TestCase.assert*() + # methods. Contains formatted tracebacks instead + # of sys.exc_info() results." + def test_addError(self): + class Foo(unittest.TestCase): + def test_1(self): + pass + + test = Foo('test_1') + try: + raise TypeError() + except: + exc_info_tuple = sys.exc_info() + + result = unittest.TestResult() + + result.startTest(test) + result.addError(test, exc_info_tuple) + result.stopTest(test) + + self.assertFalse(result.wasSuccessful()) + self.assertEqual(len(result.errors), 1) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 1) + self.assertEqual(result.shouldStop, False) + + test_case, formatted_exc = result.errors[0] + self.assertIs(test_case, test) + self.assertIsInstance(formatted_exc, str) + + def test_addError_locals(self): + class Foo(unittest.TestCase): + def test_1(self): + 1/0 + + test = Foo('test_1') + result = unittest.TestResult() + result.tb_locals = True + + unittest.result.traceback = MockTraceback + self.addCleanup(restore_traceback) + result.startTestRun() + test.run(result) + result.stopTestRun() + + self.assertEqual(len(result.errors), 1) + test_case, formatted_exc = result.errors[0] + self.assertEqual('A tracebacklocals', formatted_exc) + + def test_addSubTest(self): + class Foo(unittest.TestCase): + def test_1(self): + nonlocal subtest + with self.subTest(foo=1): + subtest = self._subtest + try: + 1/0 + except ZeroDivisionError: + exc_info_tuple = sys.exc_info() + # Register an error by hand (to check the API) + result.addSubTest(test, subtest, exc_info_tuple) + # Now trigger a failure + self.fail("some recognizable failure") + + subtest = None + test = Foo('test_1') + result = unittest.TestResult() + + test.run(result) + + self.assertFalse(result.wasSuccessful()) + self.assertEqual(len(result.errors), 1) + self.assertEqual(len(result.failures), 1) + self.assertEqual(result.testsRun, 1) + self.assertEqual(result.shouldStop, False) + + test_case, formatted_exc = result.errors[0] + self.assertIs(test_case, subtest) + self.assertIn("ZeroDivisionError", formatted_exc) + test_case, formatted_exc = result.failures[0] + self.assertIs(test_case, subtest) + self.assertIn("some recognizable failure", formatted_exc) + + def testStackFrameTrimming(self): + class Frame(object): + class tb_frame(object): + f_globals = {} + result = unittest.TestResult() + self.assertFalse(result._is_relevant_tb_level(Frame)) + + Frame.tb_frame.f_globals['__unittest'] = True + self.assertTrue(result._is_relevant_tb_level(Frame)) + + def testFailFast(self): + result = unittest.TestResult() + result._exc_info_to_string = lambda *_: '' + result.failfast = True + result.addError(None, None) + self.assertTrue(result.shouldStop) + + result = unittest.TestResult() + result._exc_info_to_string = lambda *_: '' + result.failfast = True + result.addFailure(None, None) + self.assertTrue(result.shouldStop) + + result = unittest.TestResult() + result._exc_info_to_string = lambda *_: '' + result.failfast = True + result.addUnexpectedSuccess(None) + self.assertTrue(result.shouldStop) + + def testFailFastSetByRunner(self): + stream = BufferedWriter() + runner = unittest.TextTestRunner(stream=stream, failfast=True) + def test(result): + self.assertTrue(result.failfast) + result = runner.run(test) + stream.flush() + self.assertTrue(stream.getvalue().endswith('\n\nOK\n')) + + +class Test_TextTestResult(unittest.TestCase): + maxDiff = None + + def testGetDescriptionWithoutDocstring(self): + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self), + 'testGetDescriptionWithoutDocstring (' + __name__ + + '.Test_TextTestResult.testGetDescriptionWithoutDocstring)') + + def testGetSubTestDescriptionWithoutDocstring(self): + with self.subTest(foo=1, bar=2): + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self._subtest), + 'testGetSubTestDescriptionWithoutDocstring (' + __name__ + + '.Test_TextTestResult.testGetSubTestDescriptionWithoutDocstring) (foo=1, bar=2)') + + with self.subTest('some message'): + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self._subtest), + 'testGetSubTestDescriptionWithoutDocstring (' + __name__ + + '.Test_TextTestResult.testGetSubTestDescriptionWithoutDocstring) [some message]') + + def testGetSubTestDescriptionWithoutDocstringAndParams(self): + with self.subTest(): + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self._subtest), + 'testGetSubTestDescriptionWithoutDocstringAndParams ' + '(' + __name__ + '.Test_TextTestResult.testGetSubTestDescriptionWithoutDocstringAndParams) ' + '(<subtest>)') + + def testGetSubTestDescriptionForFalsyValues(self): + expected = 'testGetSubTestDescriptionForFalsyValues (%s.Test_TextTestResult.testGetSubTestDescriptionForFalsyValues) [%s]' + result = unittest.TextTestResult(None, True, 1) + for arg in [0, None, []]: + with self.subTest(arg): + self.assertEqual( + result.getDescription(self._subtest), + expected % (__name__, arg) + ) + + def testGetNestedSubTestDescriptionWithoutDocstring(self): + with self.subTest(foo=1): + with self.subTest(baz=2, bar=3): + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self._subtest), + 'testGetNestedSubTestDescriptionWithoutDocstring ' + '(' + __name__ + '.Test_TextTestResult.testGetNestedSubTestDescriptionWithoutDocstring) ' + '(baz=2, bar=3, foo=1)') + + def testGetDuplicatedNestedSubTestDescriptionWithoutDocstring(self): + with self.subTest(foo=1, bar=2): + with self.subTest(baz=3, bar=4): + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self._subtest), + 'testGetDuplicatedNestedSubTestDescriptionWithoutDocstring ' + '(' + __name__ + '.Test_TextTestResult.testGetDuplicatedNestedSubTestDescriptionWithoutDocstring) (baz=3, bar=4, foo=1)') + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def testGetDescriptionWithOneLineDocstring(self): + """Tests getDescription() for a method with a docstring.""" + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self), + ('testGetDescriptionWithOneLineDocstring ' + '(' + __name__ + '.Test_TextTestResult.testGetDescriptionWithOneLineDocstring)\n' + 'Tests getDescription() for a method with a docstring.')) + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def testGetSubTestDescriptionWithOneLineDocstring(self): + """Tests getDescription() for a method with a docstring.""" + result = unittest.TextTestResult(None, True, 1) + with self.subTest(foo=1, bar=2): + self.assertEqual( + result.getDescription(self._subtest), + ('testGetSubTestDescriptionWithOneLineDocstring ' + '(' + __name__ + '.Test_TextTestResult.testGetSubTestDescriptionWithOneLineDocstring) ' + '(foo=1, bar=2)\n' + + 'Tests getDescription() for a method with a docstring.')) + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def testGetDescriptionWithMultiLineDocstring(self): + """Tests getDescription() for a method with a longer docstring. + The second line of the docstring. + """ + result = unittest.TextTestResult(None, True, 1) + self.assertEqual( + result.getDescription(self), + ('testGetDescriptionWithMultiLineDocstring ' + '(' + __name__ + '.Test_TextTestResult.testGetDescriptionWithMultiLineDocstring)\n' + 'Tests getDescription() for a method with a longer ' + 'docstring.')) + + @unittest.skipIf(sys.flags.optimize >= 2, + "Docstrings are omitted with -O2 and above") + def testGetSubTestDescriptionWithMultiLineDocstring(self): + """Tests getDescription() for a method with a longer docstring. + The second line of the docstring. + """ + result = unittest.TextTestResult(None, True, 1) + with self.subTest(foo=1, bar=2): + self.assertEqual( + result.getDescription(self._subtest), + ('testGetSubTestDescriptionWithMultiLineDocstring ' + '(' + __name__ + '.Test_TextTestResult.testGetSubTestDescriptionWithMultiLineDocstring) ' + '(foo=1, bar=2)\n' + 'Tests getDescription() for a method with a longer ' + 'docstring.')) + + class Test(unittest.TestCase): + def testSuccess(self): + pass + def testSkip(self): + self.skipTest('skip') + def testFail(self): + self.fail('fail') + def testError(self): + raise Exception('error') + @unittest.expectedFailure + def testExpectedFailure(self): + self.fail('fail') + @unittest.expectedFailure + def testUnexpectedSuccess(self): + pass + def testSubTestSuccess(self): + with self.subTest('one', a=1): + pass + with self.subTest('two', b=2): + pass + def testSubTestMixed(self): + with self.subTest('success', a=1): + pass + with self.subTest('skip', b=2): + self.skipTest('skip') + with self.subTest('fail', c=3): + self.fail('fail') + with self.subTest('error', d=4): + raise Exception('error') + + tearDownError = None + def tearDown(self): + if self.tearDownError is not None: + raise self.tearDownError + + def _run_test(self, test_name, verbosity, tearDownError=None): + stream = BufferedWriter() + stream = unittest.runner._WritelnDecorator(stream) + result = unittest.TextTestResult(stream, True, verbosity) + test = self.Test(test_name) + test.tearDownError = tearDownError + test.run(result) + return stream.getvalue() + + def testDotsOutput(self): + self.assertEqual(self._run_test('testSuccess', 1), '.') + self.assertEqual(self._run_test('testSkip', 1), 's') + self.assertEqual(self._run_test('testFail', 1), 'F') + self.assertEqual(self._run_test('testError', 1), 'E') + self.assertEqual(self._run_test('testExpectedFailure', 1), 'x') + self.assertEqual(self._run_test('testUnexpectedSuccess', 1), 'u') + + def testLongOutput(self): + classname = f'{__name__}.{self.Test.__qualname__}' + self.assertEqual(self._run_test('testSuccess', 2), + f'testSuccess ({classname}.testSuccess) ... ok\n') + self.assertEqual(self._run_test('testSkip', 2), + f"testSkip ({classname}.testSkip) ... skipped 'skip'\n") + self.assertEqual(self._run_test('testFail', 2), + f'testFail ({classname}.testFail) ... FAIL\n') + self.assertEqual(self._run_test('testError', 2), + f'testError ({classname}.testError) ... ERROR\n') + self.assertEqual(self._run_test('testExpectedFailure', 2), + f'testExpectedFailure ({classname}.testExpectedFailure) ... expected failure\n') + self.assertEqual(self._run_test('testUnexpectedSuccess', 2), + f'testUnexpectedSuccess ({classname}.testUnexpectedSuccess) ... unexpected success\n') + + def testDotsOutputSubTestSuccess(self): + self.assertEqual(self._run_test('testSubTestSuccess', 1), '.') + + def testLongOutputSubTestSuccess(self): + classname = f'{__name__}.{self.Test.__qualname__}' + self.assertEqual(self._run_test('testSubTestSuccess', 2), + f'testSubTestSuccess ({classname}.testSubTestSuccess) ... ok\n') + + def testDotsOutputSubTestMixed(self): + self.assertEqual(self._run_test('testSubTestMixed', 1), 'sFE') + + def testLongOutputSubTestMixed(self): + classname = f'{__name__}.{self.Test.__qualname__}' + self.assertEqual(self._run_test('testSubTestMixed', 2), + f'testSubTestMixed ({classname}.testSubTestMixed) ... \n' + f" testSubTestMixed ({classname}.testSubTestMixed) [skip] (b=2) ... skipped 'skip'\n" + f' testSubTestMixed ({classname}.testSubTestMixed) [fail] (c=3) ... FAIL\n' + f' testSubTestMixed ({classname}.testSubTestMixed) [error] (d=4) ... ERROR\n') + + def testDotsOutputTearDownFail(self): + out = self._run_test('testSuccess', 1, AssertionError('fail')) + self.assertEqual(out, 'F') + out = self._run_test('testError', 1, AssertionError('fail')) + self.assertEqual(out, 'EF') + out = self._run_test('testFail', 1, Exception('error')) + self.assertEqual(out, 'FE') + out = self._run_test('testSkip', 1, AssertionError('fail')) + self.assertEqual(out, 'sF') + + def testLongOutputTearDownFail(self): + classname = f'{__name__}.{self.Test.__qualname__}' + out = self._run_test('testSuccess', 2, AssertionError('fail')) + self.assertEqual(out, + f'testSuccess ({classname}.testSuccess) ... FAIL\n') + out = self._run_test('testError', 2, AssertionError('fail')) + self.assertEqual(out, + f'testError ({classname}.testError) ... ERROR\n' + f'testError ({classname}.testError) ... FAIL\n') + out = self._run_test('testFail', 2, Exception('error')) + self.assertEqual(out, + f'testFail ({classname}.testFail) ... FAIL\n' + f'testFail ({classname}.testFail) ... ERROR\n') + out = self._run_test('testSkip', 2, AssertionError('fail')) + self.assertEqual(out, + f"testSkip ({classname}.testSkip) ... skipped 'skip'\n" + f'testSkip ({classname}.testSkip) ... FAIL\n') + + +classDict = dict(unittest.TestResult.__dict__) +for m in ('addSkip', 'addExpectedFailure', 'addUnexpectedSuccess', + '__init__'): + del classDict[m] + +def __init__(self, stream=None, descriptions=None, verbosity=None): + self.failures = [] + self.errors = [] + self.testsRun = 0 + self.shouldStop = False + self.buffer = False + self.tb_locals = False + +classDict['__init__'] = __init__ +OldResult = type('OldResult', (object,), classDict) + +class Test_OldTestResult(unittest.TestCase): + + def assertOldResultWarning(self, test, failures): + with warnings_helper.check_warnings( + ("TestResult has no add.+ method,", RuntimeWarning)): + result = OldResult() + test.run(result) + self.assertEqual(len(result.failures), failures) + + def testOldTestResult(self): + class Test(unittest.TestCase): + def testSkip(self): + self.skipTest('foobar') + @unittest.expectedFailure + def testExpectedFail(self): + raise TypeError + @unittest.expectedFailure + def testUnexpectedSuccess(self): + pass + + for test_name, should_pass in (('testSkip', True), + ('testExpectedFail', True), + ('testUnexpectedSuccess', False)): + test = Test(test_name) + self.assertOldResultWarning(test, int(not should_pass)) + + def testOldTestTesultSetup(self): + class Test(unittest.TestCase): + def setUp(self): + self.skipTest('no reason') + def testFoo(self): + pass + self.assertOldResultWarning(Test('testFoo'), 0) + + def testOldTestResultClass(self): + @unittest.skip('no reason') + class Test(unittest.TestCase): + def testFoo(self): + pass + self.assertOldResultWarning(Test('testFoo'), 0) + + def testOldResultWithRunner(self): + class Test(unittest.TestCase): + def testFoo(self): + pass + runner = unittest.TextTestRunner(resultclass=OldResult, + stream=io.StringIO()) + # This will raise an exception if TextTestRunner can't handle old + # test result objects + runner.run(Test('testFoo')) + + +class TestOutputBuffering(unittest.TestCase): + + def setUp(self): + self._real_out = sys.stdout + self._real_err = sys.stderr + + def tearDown(self): + sys.stdout = self._real_out + sys.stderr = self._real_err + + def testBufferOutputOff(self): + real_out = self._real_out + real_err = self._real_err + + result = unittest.TestResult() + self.assertFalse(result.buffer) + + self.assertIs(real_out, sys.stdout) + self.assertIs(real_err, sys.stderr) + + result.startTest(self) + + self.assertIs(real_out, sys.stdout) + self.assertIs(real_err, sys.stderr) + + def testBufferOutputStartTestAddSuccess(self): + real_out = self._real_out + real_err = self._real_err + + result = unittest.TestResult() + self.assertFalse(result.buffer) + + result.buffer = True + + self.assertIs(real_out, sys.stdout) + self.assertIs(real_err, sys.stderr) + + result.startTest(self) + + self.assertIsNot(real_out, sys.stdout) + self.assertIsNot(real_err, sys.stderr) + self.assertIsInstance(sys.stdout, io.StringIO) + self.assertIsInstance(sys.stderr, io.StringIO) + self.assertIsNot(sys.stdout, sys.stderr) + + out_stream = sys.stdout + err_stream = sys.stderr + + result._original_stdout = io.StringIO() + result._original_stderr = io.StringIO() + + print('foo') + print('bar', file=sys.stderr) + + self.assertEqual(out_stream.getvalue(), 'foo\n') + self.assertEqual(err_stream.getvalue(), 'bar\n') + + self.assertEqual(result._original_stdout.getvalue(), '') + self.assertEqual(result._original_stderr.getvalue(), '') + + result.addSuccess(self) + result.stopTest(self) + + self.assertIs(sys.stdout, result._original_stdout) + self.assertIs(sys.stderr, result._original_stderr) + + self.assertEqual(result._original_stdout.getvalue(), '') + self.assertEqual(result._original_stderr.getvalue(), '') + + self.assertEqual(out_stream.getvalue(), '') + self.assertEqual(err_stream.getvalue(), '') + + + def getStartedResult(self): + result = unittest.TestResult() + result.buffer = True + result.startTest(self) + return result + + def testBufferOutputAddErrorOrFailure(self): + unittest.result.traceback = MockTraceback + self.addCleanup(restore_traceback) + + for message_attr, add_attr, include_error in [ + ('errors', 'addError', True), + ('failures', 'addFailure', False), + ('errors', 'addError', True), + ('failures', 'addFailure', False) + ]: + result = self.getStartedResult() + buffered_out = sys.stdout + buffered_err = sys.stderr + result._original_stdout = io.StringIO() + result._original_stderr = io.StringIO() + + print('foo', file=sys.stdout) + if include_error: + print('bar', file=sys.stderr) + + + addFunction = getattr(result, add_attr) + addFunction(self, (None, None, None)) + result.stopTest(self) + + result_list = getattr(result, message_attr) + self.assertEqual(len(result_list), 1) + + test, message = result_list[0] + expectedOutMessage = textwrap.dedent(""" + Stdout: + foo + """) + expectedErrMessage = '' + if include_error: + expectedErrMessage = textwrap.dedent(""" + Stderr: + bar + """) + + expectedFullMessage = 'A traceback%s%s' % (expectedOutMessage, expectedErrMessage) + + self.assertIs(test, self) + self.assertEqual(result._original_stdout.getvalue(), expectedOutMessage) + self.assertEqual(result._original_stderr.getvalue(), expectedErrMessage) + self.assertMultiLineEqual(message, expectedFullMessage) + + def testBufferSetUp(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def setUp(self): + print('set up') + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 1) + description = f'test_foo ({strclass(Foo)}.test_foo)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(str(test_case), description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferTearDown(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def tearDown(self): + print('tear down') + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\ntear down\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 1) + description = f'test_foo ({strclass(Foo)}.test_foo)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(str(test_case), description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferDoCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def setUp(self): + print('set up') + self.addCleanup(bad_cleanup1) + self.addCleanup(bad_cleanup2) + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 2) + description = f'test_foo ({strclass(Foo)}.test_foo)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(str(test_case), description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\nset up\ndo cleanup2\n', formatted_exc) + self.assertNotIn('\ndo cleanup1\n', formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(str(test_case), description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferSetUp_DoCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def setUp(self): + print('set up') + self.addCleanup(bad_cleanup1) + self.addCleanup(bad_cleanup2) + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 3) + description = f'test_foo ({strclass(Foo)}.test_foo)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(str(test_case), description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\nset up\n', formatted_exc) + self.assertNotIn('\ndo cleanup2\n', formatted_exc) + self.assertNotIn('\ndo cleanup1\n', formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(str(test_case), description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\nset up\ndo cleanup2\n', formatted_exc) + self.assertNotIn('\ndo cleanup1\n', formatted_exc) + test_case, formatted_exc = result.errors[2] + self.assertEqual(str(test_case), description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferTearDown_DoCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def setUp(self): + print('set up') + self.addCleanup(bad_cleanup1) + self.addCleanup(bad_cleanup2) + def tearDown(self): + print('tear down') + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up\ntear down\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 3) + description = f'test_foo ({strclass(Foo)}.test_foo)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(str(test_case), description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\nset up\ntear down\n', formatted_exc) + self.assertNotIn('\ndo cleanup2\n', formatted_exc) + self.assertNotIn('\ndo cleanup1\n', formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(str(test_case), description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\nset up\ntear down\ndo cleanup2\n', formatted_exc) + self.assertNotIn('\ndo cleanup1\n', formatted_exc) + test_case, formatted_exc = result.errors[2] + self.assertEqual(str(test_case), description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferSetupClass(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + @classmethod + def setUpClass(cls): + print('set up class') + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up class\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 1) + description = f'setUpClass ({strclass(Foo)})' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferTearDownClass(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + @classmethod + def tearDownClass(cls): + print('tear down class') + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\ntear down class\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 1) + description = f'tearDownClass ({strclass(Foo)})' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferDoClassCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + @classmethod + def setUpClass(cls): + print('set up class') + cls.addClassCleanup(bad_cleanup1) + cls.addClassCleanup(bad_cleanup2) + @classmethod + def tearDownClass(cls): + print('tear down class') + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\ntear down class\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 2) + description = f'tearDownClass ({strclass(Foo)})' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(test_case.description, description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferSetupClass_DoClassCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + @classmethod + def setUpClass(cls): + print('set up class') + cls.addClassCleanup(bad_cleanup1) + cls.addClassCleanup(bad_cleanup2) + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up class\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 3) + description = f'setUpClass ({strclass(Foo)})' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\nset up class\n', formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(test_case.description, description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + test_case, formatted_exc = result.errors[2] + self.assertEqual(test_case.description, description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferTearDownClass_DoClassCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + @classmethod + def setUpClass(cls): + print('set up class') + cls.addClassCleanup(bad_cleanup1) + cls.addClassCleanup(bad_cleanup2) + @classmethod + def tearDownClass(cls): + print('tear down class') + 1/0 + def test_foo(self): + pass + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\ntear down class\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 3) + description = f'tearDownClass ({strclass(Foo)})' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\ntear down class\n', formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(test_case.description, description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + test_case, formatted_exc = result.errors[2] + self.assertEqual(test_case.description, description) + self.assertIn('TypeError: bad cleanup1', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferSetUpModule(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def test_foo(self): + pass + class Module(object): + @staticmethod + def setUpModule(): + print('set up module') + 1/0 + + Foo.__module__ = 'Module' + sys.modules['Module'] = Module + self.addCleanup(sys.modules.pop, 'Module') + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up module\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 1) + description = 'setUpModule (Module)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferTearDownModule(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def test_foo(self): + pass + class Module(object): + @staticmethod + def tearDownModule(): + print('tear down module') + 1/0 + + Foo.__module__ = 'Module' + sys.modules['Module'] = Module + self.addCleanup(sys.modules.pop, 'Module') + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\ntear down module\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 1) + description = 'tearDownModule (Module)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferDoModuleCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def test_foo(self): + pass + class Module(object): + @staticmethod + def setUpModule(): + print('set up module') + unittest.addModuleCleanup(bad_cleanup1) + unittest.addModuleCleanup(bad_cleanup2) + + Foo.__module__ = 'Module' + sys.modules['Module'] = Module + self.addCleanup(sys.modules.pop, 'Module') + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 1) + description = 'tearDownModule (Module)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferSetUpModule_DoModuleCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def test_foo(self): + pass + class Module(object): + @staticmethod + def setUpModule(): + print('set up module') + unittest.addModuleCleanup(bad_cleanup1) + unittest.addModuleCleanup(bad_cleanup2) + 1/0 + + Foo.__module__ = 'Module' + sys.modules['Module'] = Module + self.addCleanup(sys.modules.pop, 'Module') + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\nset up module\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 2) + description = 'setUpModule (Module)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\nset up module\n', formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertIn(expected_out, formatted_exc) + self.assertEqual(test_case.description, description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + def testBufferTearDownModule_DoModuleCleanups(self): + with captured_stdout() as stdout: + result = unittest.TestResult() + result.buffer = True + + class Foo(unittest.TestCase): + def test_foo(self): + pass + class Module(object): + @staticmethod + def setUpModule(): + print('set up module') + unittest.addModuleCleanup(bad_cleanup1) + unittest.addModuleCleanup(bad_cleanup2) + @staticmethod + def tearDownModule(): + print('tear down module') + 1/0 + + Foo.__module__ = 'Module' + sys.modules['Module'] = Module + self.addCleanup(sys.modules.pop, 'Module') + suite = unittest.TestSuite([Foo('test_foo')]) + suite(result) + expected_out = '\nStdout:\ntear down module\ndo cleanup2\ndo cleanup1\n' + self.assertEqual(stdout.getvalue(), expected_out) + self.assertEqual(len(result.errors), 2) + description = 'tearDownModule (Module)' + test_case, formatted_exc = result.errors[0] + self.assertEqual(test_case.description, description) + self.assertIn('ZeroDivisionError: division by zero', formatted_exc) + self.assertNotIn('ValueError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn('\nStdout:\ntear down module\n', formatted_exc) + test_case, formatted_exc = result.errors[1] + self.assertEqual(test_case.description, description) + self.assertIn('ValueError: bad cleanup2', formatted_exc) + self.assertNotIn('ZeroDivisionError', formatted_exc) + self.assertNotIn('TypeError', formatted_exc) + self.assertIn(expected_out, formatted_exc) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/test_runner.py b/Lib/test/test_unittest/test_runner.py new file mode 100644 index 0000000..9e3a0a9 --- /dev/null +++ b/Lib/test/test_unittest/test_runner.py @@ -0,0 +1,1330 @@ +import io +import os +import sys +import pickle +import subprocess +from test import support + +import unittest +from unittest.case import _Outcome + +from test.test_unittest.support import (LoggingResult, + ResultWithNoStartTestRunStopTestRun) + + +def resultFactory(*_): + return unittest.TestResult() + + +def getRunner(): + return unittest.TextTestRunner(resultclass=resultFactory, + stream=io.StringIO()) + + +def runTests(*cases): + suite = unittest.TestSuite() + for case in cases: + tests = unittest.defaultTestLoader.loadTestsFromTestCase(case) + suite.addTests(tests) + + runner = getRunner() + + # creating a nested suite exposes some potential bugs + realSuite = unittest.TestSuite() + realSuite.addTest(suite) + # adding empty suites to the end exposes potential bugs + suite.addTest(unittest.TestSuite()) + realSuite.addTest(unittest.TestSuite()) + return runner.run(realSuite) + + +def cleanup(ordering, blowUp=False): + if not blowUp: + ordering.append('cleanup_good') + else: + ordering.append('cleanup_exc') + raise Exception('CleanUpExc') + + +class TestCM: + def __init__(self, ordering, enter_result=None): + self.ordering = ordering + self.enter_result = enter_result + + def __enter__(self): + self.ordering.append('enter') + return self.enter_result + + def __exit__(self, *exc_info): + self.ordering.append('exit') + + +class LacksEnterAndExit: + pass +class LacksEnter: + def __exit__(self, *exc_info): + pass +class LacksExit: + def __enter__(self): + pass + + +class TestCleanUp(unittest.TestCase): + def testCleanUp(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + self.assertEqual(test._cleanups, []) + + cleanups = [] + + def cleanup1(*args, **kwargs): + cleanups.append((1, args, kwargs)) + + def cleanup2(*args, **kwargs): + cleanups.append((2, args, kwargs)) + + test.addCleanup(cleanup1, 1, 2, 3, four='hello', five='goodbye') + test.addCleanup(cleanup2) + + self.assertEqual(test._cleanups, + [(cleanup1, (1, 2, 3), dict(four='hello', five='goodbye')), + (cleanup2, (), {})]) + + self.assertTrue(test.doCleanups()) + self.assertEqual(cleanups, [(2, (), {}), (1, (1, 2, 3), dict(four='hello', five='goodbye'))]) + + def testCleanUpWithErrors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + result = unittest.TestResult() + outcome = test._outcome = _Outcome(result=result) + + CleanUpExc = Exception('foo') + exc2 = Exception('bar') + def cleanup1(): + raise CleanUpExc + + def cleanup2(): + raise exc2 + + test.addCleanup(cleanup1) + test.addCleanup(cleanup2) + + self.assertFalse(test.doCleanups()) + self.assertFalse(outcome.success) + + (_, msg2), (_, msg1) = result.errors + self.assertIn('in cleanup1', msg1) + self.assertIn('raise CleanUpExc', msg1) + self.assertIn('Exception: foo', msg1) + self.assertIn('in cleanup2', msg2) + self.assertIn('raise exc2', msg2) + self.assertIn('Exception: bar', msg2) + + def testCleanupInRun(self): + blowUp = False + ordering = [] + + class TestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp') + if blowUp: + raise Exception('foo') + + def testNothing(self): + ordering.append('test') + + def tearDown(self): + ordering.append('tearDown') + + test = TestableTest('testNothing') + + def cleanup1(): + ordering.append('cleanup1') + def cleanup2(): + ordering.append('cleanup2') + test.addCleanup(cleanup1) + test.addCleanup(cleanup2) + + def success(some_test): + self.assertEqual(some_test, test) + ordering.append('success') + + result = unittest.TestResult() + result.addSuccess = success + + test.run(result) + self.assertEqual(ordering, ['setUp', 'test', 'tearDown', + 'cleanup2', 'cleanup1', 'success']) + + blowUp = True + ordering = [] + test = TestableTest('testNothing') + test.addCleanup(cleanup1) + test.run(result) + self.assertEqual(ordering, ['setUp', 'cleanup1']) + + def testTestCaseDebugExecutesCleanups(self): + ordering = [] + + class TestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup1) + + def testNothing(self): + ordering.append('test') + + def tearDown(self): + ordering.append('tearDown') + + test = TestableTest('testNothing') + + def cleanup1(): + ordering.append('cleanup1') + test.addCleanup(cleanup2) + def cleanup2(): + ordering.append('cleanup2') + + test.debug() + self.assertEqual(ordering, ['setUp', 'test', 'tearDown', 'cleanup1', 'cleanup2']) + + + def test_enterContext(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + cleanups = [] + + test.addCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(test.enterContext(cm), 42) + test.addCleanup(cleanups.append, 'cleanup2') + + self.assertTrue(test.doCleanups()) + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + test = TestableTest('testNothing') + + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + test.enterContext(LacksExit()) + + self.assertEqual(test._cleanups, []) + + +class TestClassCleanup(unittest.TestCase): + def test_addClassCleanUp(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + test = TestableTest('testNothing') + self.assertEqual(test._class_cleanups, []) + class_cleanups = [] + + def class_cleanup1(*args, **kwargs): + class_cleanups.append((3, args, kwargs)) + + def class_cleanup2(*args, **kwargs): + class_cleanups.append((4, args, kwargs)) + + TestableTest.addClassCleanup(class_cleanup1, 1, 2, 3, + four='hello', five='goodbye') + TestableTest.addClassCleanup(class_cleanup2) + + self.assertEqual(test._class_cleanups, + [(class_cleanup1, (1, 2, 3), + dict(four='hello', five='goodbye')), + (class_cleanup2, (), {})]) + + TestableTest.doClassCleanups() + self.assertEqual(class_cleanups, [(4, (), {}), (3, (1, 2, 3), + dict(four='hello', five='goodbye'))]) + + def test_run_class_cleanUp(self): + ordering = [] + blowUp = True + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering) + if blowUp: + raise Exception() + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + runTests(TestableTest) + self.assertEqual(ordering, ['setUpClass', 'cleanup_good']) + + ordering = [] + blowUp = False + runTests(TestableTest) + self.assertEqual(ordering, + ['setUpClass', 'test', 'tearDownClass', 'cleanup_good']) + + def test_run_class_cleanUp_without_tearDownClass(self): + ordering = [] + blowUp = True + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering) + if blowUp: + raise Exception() + def testNothing(self): + ordering.append('test') + @classmethod + @property + def tearDownClass(cls): + raise AttributeError + + runTests(TestableTest) + self.assertEqual(ordering, ['setUpClass', 'cleanup_good']) + + ordering = [] + blowUp = False + runTests(TestableTest) + self.assertEqual(ordering, + ['setUpClass', 'test', 'cleanup_good']) + + def test_debug_executes_classCleanUp(self): + ordering = [] + blowUp = False + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=blowUp) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + suite.debug() + self.assertEqual(ordering, + ['setUpClass', 'test', 'tearDownClass', 'cleanup_good']) + + ordering = [] + blowUp = True + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + with self.assertRaises(Exception) as cm: + suite.debug() + self.assertEqual(str(cm.exception), 'CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'test', 'tearDownClass', 'cleanup_exc']) + + def test_debug_executes_classCleanUp_when_teardown_exception(self): + ordering = [] + blowUp = False + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=blowUp) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + raise Exception('TearDownClassExc') + + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + with self.assertRaises(Exception) as cm: + suite.debug() + self.assertEqual(str(cm.exception), 'TearDownClassExc') + self.assertEqual(ordering, ['setUpClass', 'test']) + self.assertTrue(TestableTest._class_cleanups) + TestableTest._class_cleanups.clear() + + ordering = [] + blowUp = True + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + with self.assertRaises(Exception) as cm: + suite.debug() + self.assertEqual(str(cm.exception), 'TearDownClassExc') + self.assertEqual(ordering, ['setUpClass', 'test']) + self.assertTrue(TestableTest._class_cleanups) + TestableTest._class_cleanups.clear() + + def test_doClassCleanups_with_errors_addClassCleanUp(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + def cleanup1(): + raise Exception('cleanup1') + + def cleanup2(): + raise Exception('cleanup2') + + TestableTest.addClassCleanup(cleanup1) + TestableTest.addClassCleanup(cleanup2) + with self.assertRaises(Exception) as e: + TestableTest.doClassCleanups() + self.assertEqual(e, 'cleanup1') + + def test_with_errors_addCleanUp(self): + ordering = [] + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering) + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup, ordering, blowUp=True) + def testNothing(self): + pass + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'cleanup_exc', + 'tearDownClass', 'cleanup_good']) + + def test_run_with_errors_addClassCleanUp(self): + ordering = [] + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=True) + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup, ordering) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'test', 'cleanup_good', + 'tearDownClass', 'cleanup_exc']) + + def test_with_errors_in_addClassCleanup_and_setUps(self): + ordering = [] + class_blow_up = False + method_blow_up = False + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=True) + if class_blow_up: + raise Exception('ClassExc') + def setUp(self): + ordering.append('setUp') + if method_blow_up: + raise Exception('MethodExc') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'test', + 'tearDownClass', 'cleanup_exc']) + + ordering = [] + class_blow_up = True + method_blow_up = False + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: ClassExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'cleanup_exc']) + + ordering = [] + class_blow_up = False + method_blow_up = True + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: MethodExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpClass', 'setUp', 'tearDownClass', + 'cleanup_exc']) + + def test_with_errors_in_tearDownClass(self): + ordering = [] + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + raise Exception('TearDownExc') + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: TearDownExc') + self.assertEqual(ordering, + ['setUpClass', 'test', 'tearDownClass', 'cleanup_good']) + + def test_enterClassContext(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + cleanups = [] + + TestableTest.addClassCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(TestableTest.enterClassContext(cm), 42) + TestableTest.addClassCleanup(cleanups.append, 'cleanup2') + + TestableTest.doClassCleanups() + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterClassContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + TestableTest.enterClassContext(LacksExit()) + + self.assertEqual(TestableTest._class_cleanups, []) + + +class TestModuleCleanUp(unittest.TestCase): + def test_add_and_do_ModuleCleanup(self): + module_cleanups = [] + + def module_cleanup1(*args, **kwargs): + module_cleanups.append((3, args, kwargs)) + + def module_cleanup2(*args, **kwargs): + module_cleanups.append((4, args, kwargs)) + + class Module(object): + unittest.addModuleCleanup(module_cleanup1, 1, 2, 3, + four='hello', five='goodbye') + unittest.addModuleCleanup(module_cleanup2) + + self.assertEqual(unittest.case._module_cleanups, + [(module_cleanup1, (1, 2, 3), + dict(four='hello', five='goodbye')), + (module_cleanup2, (), {})]) + + unittest.case.doModuleCleanups() + self.assertEqual(module_cleanups, [(4, (), {}), (3, (1, 2, 3), + dict(four='hello', five='goodbye'))]) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_doModuleCleanup_with_errors_in_addModuleCleanup(self): + module_cleanups = [] + + def module_cleanup_good(*args, **kwargs): + module_cleanups.append((3, args, kwargs)) + + def module_cleanup_bad(*args, **kwargs): + raise Exception('CleanUpExc') + + class Module(object): + unittest.addModuleCleanup(module_cleanup_good, 1, 2, 3, + four='hello', five='goodbye') + unittest.addModuleCleanup(module_cleanup_bad) + self.assertEqual(unittest.case._module_cleanups, + [(module_cleanup_good, (1, 2, 3), + dict(four='hello', five='goodbye')), + (module_cleanup_bad, (), {})]) + with self.assertRaises(Exception) as e: + unittest.case.doModuleCleanups() + self.assertEqual(str(e.exception), 'CleanUpExc') + self.assertEqual(unittest.case._module_cleanups, []) + + def test_addModuleCleanup_arg_errors(self): + cleanups = [] + def cleanup(*args, **kwargs): + cleanups.append((args, kwargs)) + + class Module(object): + unittest.addModuleCleanup(cleanup, 1, 2, function='hello') + with self.assertRaises(TypeError): + unittest.addModuleCleanup(function=cleanup, arg='hello') + with self.assertRaises(TypeError): + unittest.addModuleCleanup() + unittest.case.doModuleCleanups() + self.assertEqual(cleanups, + [((1, 2), {'function': 'hello'})]) + + def test_run_module_cleanUp(self): + blowUp = True + ordering = [] + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + if blowUp: + raise Exception('setUpModule Exc') + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + result = runTests(TestableTest) + self.assertEqual(ordering, ['setUpModule', 'cleanup_good']) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: setUpModule Exc') + + ordering = [] + blowUp = False + runTests(TestableTest) + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'tearDownModule', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_run_multiple_module_cleanUp(self): + blowUp = True + blowUp2 = False + ordering = [] + class Module1(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + if blowUp: + raise Exception() + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class Module2(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule2') + unittest.addModuleCleanup(cleanup, ordering) + if blowUp2: + raise Exception() + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule2') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + class TestableTest2(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass2') + def testNothing(self): + ordering.append('test2') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass2') + + TestableTest.__module__ = 'Module1' + sys.modules['Module1'] = Module1 + TestableTest2.__module__ = 'Module2' + sys.modules['Module2'] = Module2 + runTests(TestableTest, TestableTest2) + self.assertEqual(ordering, ['setUpModule', 'cleanup_good', + 'setUpModule2', 'setUpClass2', 'test2', + 'tearDownClass2', 'tearDownModule2', + 'cleanup_good']) + ordering = [] + blowUp = False + blowUp2 = True + runTests(TestableTest, TestableTest2) + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', + 'tearDownClass', 'tearDownModule', + 'cleanup_good', 'setUpModule2', + 'cleanup_good']) + + ordering = [] + blowUp = False + blowUp2 = False + runTests(TestableTest, TestableTest2) + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'tearDownModule', 'cleanup_good', 'setUpModule2', + 'setUpClass2', 'test2', 'tearDownClass2', + 'tearDownModule2', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_run_module_cleanUp_without_teardown(self): + ordering = [] + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + runTests(TestableTest) + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', + 'tearDownClass', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_run_module_cleanUp_when_teardown_exception(self): + ordering = [] + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + @staticmethod + def tearDownModule(): + raise Exception('CleanUpExc') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', + 'tearDownClass', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_debug_module_executes_cleanUp(self): + ordering = [] + blowUp = False + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering, blowUp=blowUp) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + suite.debug() + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'tearDownModule', 'cleanup_good']) + self.assertEqual(unittest.case._module_cleanups, []) + + ordering = [] + blowUp = True + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + with self.assertRaises(Exception) as cm: + suite.debug() + self.assertEqual(str(cm.exception), 'CleanUpExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', + 'tearDownClass', 'tearDownModule', 'cleanup_exc']) + self.assertEqual(unittest.case._module_cleanups, []) + + def test_debug_module_cleanUp_when_teardown_exception(self): + ordering = [] + blowUp = False + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering, blowUp=blowUp) + @staticmethod + def tearDownModule(): + raise Exception('TearDownModuleExc') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + with self.assertRaises(Exception) as cm: + suite.debug() + self.assertEqual(str(cm.exception), 'TearDownModuleExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', + 'tearDownClass']) + self.assertTrue(unittest.case._module_cleanups) + unittest.case._module_cleanups.clear() + + ordering = [] + blowUp = True + suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestableTest) + with self.assertRaises(Exception) as cm: + suite.debug() + self.assertEqual(str(cm.exception), 'TearDownModuleExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'test', + 'tearDownClass']) + self.assertTrue(unittest.case._module_cleanups) + unittest.case._module_cleanups.clear() + + def test_addClassCleanup_arg_errors(self): + cleanups = [] + def cleanup(*args, **kwargs): + cleanups.append((args, kwargs)) + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.addClassCleanup(cleanup, 1, 2, function=3, cls=4) + with self.assertRaises(TypeError): + cls.addClassCleanup(function=cleanup, arg='hello') + def testNothing(self): + pass + + with self.assertRaises(TypeError): + TestableTest.addClassCleanup() + with self.assertRaises(TypeError): + unittest.TestCase.addCleanup(cls=TestableTest(), function=cleanup) + runTests(TestableTest) + self.assertEqual(cleanups, + [((1, 2), {'function': 3, 'cls': 4})]) + + def test_addCleanup_arg_errors(self): + cleanups = [] + def cleanup(*args, **kwargs): + cleanups.append((args, kwargs)) + + class TestableTest(unittest.TestCase): + def setUp(self2): + self2.addCleanup(cleanup, 1, 2, function=3, self=4) + with self.assertRaises(TypeError): + self2.addCleanup(function=cleanup, arg='hello') + def testNothing(self): + pass + + with self.assertRaises(TypeError): + TestableTest().addCleanup() + with self.assertRaises(TypeError): + unittest.TestCase.addCleanup(self=TestableTest(), function=cleanup) + runTests(TestableTest) + self.assertEqual(cleanups, + [((1, 2), {'function': 3, 'self': 4})]) + + def test_with_errors_in_addClassCleanup(self): + ordering = [] + + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + cls.addClassCleanup(cleanup, ordering, blowUp=True) + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'test', 'tearDownClass', + 'cleanup_exc', 'tearDownModule', 'cleanup_good']) + + def test_with_errors_in_addCleanup(self): + ordering = [] + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup, ordering, blowUp=True) + def testNothing(self): + ordering.append('test') + def tearDown(self): + ordering.append('tearDown') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpModule', 'setUp', 'test', 'tearDown', + 'cleanup_exc', 'tearDownModule', 'cleanup_good']) + + def test_with_errors_in_addModuleCleanup_and_setUps(self): + ordering = [] + module_blow_up = False + class_blow_up = False + method_blow_up = False + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup, ordering, blowUp=True) + if module_blow_up: + raise Exception('ModuleExc') + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + if class_blow_up: + raise Exception('ClassExc') + def setUp(self): + ordering.append('setUp') + if method_blow_up: + raise Exception('MethodExc') + def testNothing(self): + ordering.append('test') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + + TestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, + ['setUpModule', 'setUpClass', 'setUp', 'test', + 'tearDownClass', 'tearDownModule', + 'cleanup_exc']) + + ordering = [] + module_blow_up = True + class_blow_up = False + method_blow_up = False + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: ModuleExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, ['setUpModule', 'cleanup_exc']) + + ordering = [] + module_blow_up = False + class_blow_up = True + method_blow_up = False + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: ClassExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', + 'tearDownModule', 'cleanup_exc']) + + ordering = [] + module_blow_up = False + class_blow_up = False + method_blow_up = True + result = runTests(TestableTest) + self.assertEqual(result.errors[0][1].splitlines()[-1], + 'Exception: MethodExc') + self.assertEqual(result.errors[1][1].splitlines()[-1], + 'Exception: CleanUpExc') + self.assertEqual(ordering, ['setUpModule', 'setUpClass', 'setUp', + 'tearDownClass', 'tearDownModule', + 'cleanup_exc']) + + def test_module_cleanUp_with_multiple_classes(self): + ordering =[] + def cleanup1(): + ordering.append('cleanup1') + + def cleanup2(): + ordering.append('cleanup2') + + def cleanup3(): + ordering.append('cleanup3') + + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + unittest.addModuleCleanup(cleanup1) + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class TestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp') + self.addCleanup(cleanup2) + def testNothing(self): + ordering.append('test') + def tearDown(self): + ordering.append('tearDown') + + class OtherTestableTest(unittest.TestCase): + def setUp(self): + ordering.append('setUp2') + self.addCleanup(cleanup3) + def testNothing(self): + ordering.append('test2') + def tearDown(self): + ordering.append('tearDown2') + + TestableTest.__module__ = 'Module' + OtherTestableTest.__module__ = 'Module' + sys.modules['Module'] = Module + runTests(TestableTest, OtherTestableTest) + self.assertEqual(ordering, + ['setUpModule', 'setUp', 'test', 'tearDown', + 'cleanup2', 'setUp2', 'test2', 'tearDown2', + 'cleanup3', 'tearDownModule', 'cleanup1']) + + def test_enterModuleContext(self): + cleanups = [] + + unittest.addModuleCleanup(cleanups.append, 'cleanup1') + cm = TestCM(cleanups, 42) + self.assertEqual(unittest.enterModuleContext(cm), 42) + unittest.addModuleCleanup(cleanups.append, 'cleanup2') + + unittest.case.doModuleCleanups() + self.assertEqual(cleanups, ['enter', 'cleanup2', 'exit', 'cleanup1']) + + def test_enterModuleContext_arg_errors(self): + class TestableTest(unittest.TestCase): + def testNothing(self): + pass + + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksEnterAndExit()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksEnter()) + with self.assertRaisesRegex(TypeError, 'the context manager'): + unittest.enterModuleContext(LacksExit()) + + self.assertEqual(unittest.case._module_cleanups, []) + + +class Test_TextTestRunner(unittest.TestCase): + """Tests for TextTestRunner.""" + + def setUp(self): + # clean the environment from pre-existing PYTHONWARNINGS to make + # test_warnings results consistent + self.pythonwarnings = os.environ.get('PYTHONWARNINGS') + if self.pythonwarnings: + del os.environ['PYTHONWARNINGS'] + + def tearDown(self): + # bring back pre-existing PYTHONWARNINGS if present + if self.pythonwarnings: + os.environ['PYTHONWARNINGS'] = self.pythonwarnings + + def test_init(self): + runner = unittest.TextTestRunner() + self.assertFalse(runner.failfast) + self.assertFalse(runner.buffer) + self.assertEqual(runner.verbosity, 1) + self.assertEqual(runner.warnings, None) + self.assertTrue(runner.descriptions) + self.assertEqual(runner.resultclass, unittest.TextTestResult) + self.assertFalse(runner.tb_locals) + + def test_multiple_inheritance(self): + class AResult(unittest.TestResult): + def __init__(self, stream, descriptions, verbosity): + super(AResult, self).__init__(stream, descriptions, verbosity) + + class ATextResult(unittest.TextTestResult, AResult): + pass + + # This used to raise an exception due to TextTestResult not passing + # on arguments in its __init__ super call + ATextResult(None, None, 1) + + def testBufferAndFailfast(self): + class Test(unittest.TestCase): + def testFoo(self): + pass + result = unittest.TestResult() + runner = unittest.TextTestRunner(stream=io.StringIO(), failfast=True, + buffer=True) + # Use our result object + runner._makeResult = lambda: result + runner.run(Test('testFoo')) + + self.assertTrue(result.failfast) + self.assertTrue(result.buffer) + + def test_locals(self): + runner = unittest.TextTestRunner(stream=io.StringIO(), tb_locals=True) + result = runner.run(unittest.TestSuite()) + self.assertEqual(True, result.tb_locals) + + def testRunnerRegistersResult(self): + class Test(unittest.TestCase): + def testFoo(self): + pass + originalRegisterResult = unittest.runner.registerResult + def cleanup(): + unittest.runner.registerResult = originalRegisterResult + self.addCleanup(cleanup) + + result = unittest.TestResult() + runner = unittest.TextTestRunner(stream=io.StringIO()) + # Use our result object + runner._makeResult = lambda: result + + self.wasRegistered = 0 + def fakeRegisterResult(thisResult): + self.wasRegistered += 1 + self.assertEqual(thisResult, result) + unittest.runner.registerResult = fakeRegisterResult + + runner.run(unittest.TestSuite()) + self.assertEqual(self.wasRegistered, 1) + + def test_works_with_result_without_startTestRun_stopTestRun(self): + class OldTextResult(ResultWithNoStartTestRunStopTestRun): + separator2 = '' + def printErrors(self): + pass + + class Runner(unittest.TextTestRunner): + def __init__(self): + super(Runner, self).__init__(io.StringIO()) + + def _makeResult(self): + return OldTextResult() + + runner = Runner() + runner.run(unittest.TestSuite()) + + def test_startTestRun_stopTestRun_called(self): + class LoggingTextResult(LoggingResult): + separator2 = '' + def printErrors(self): + pass + + class LoggingRunner(unittest.TextTestRunner): + def __init__(self, events): + super(LoggingRunner, self).__init__(io.StringIO()) + self._events = events + + def _makeResult(self): + return LoggingTextResult(self._events) + + events = [] + runner = LoggingRunner(events) + runner.run(unittest.TestSuite()) + expected = ['startTestRun', 'stopTestRun'] + self.assertEqual(events, expected) + + def test_pickle_unpickle(self): + # Issue #7197: a TextTestRunner should be (un)pickleable. This is + # required by test_multiprocessing under Windows (in verbose mode). + stream = io.StringIO("foo") + runner = unittest.TextTestRunner(stream) + for protocol in range(2, pickle.HIGHEST_PROTOCOL + 1): + s = pickle.dumps(runner, protocol) + obj = pickle.loads(s) + # StringIO objects never compare equal, a cheap test instead. + self.assertEqual(obj.stream.getvalue(), stream.getvalue()) + + def test_resultclass(self): + def MockResultClass(*args): + return args + STREAM = object() + DESCRIPTIONS = object() + VERBOSITY = object() + runner = unittest.TextTestRunner(STREAM, DESCRIPTIONS, VERBOSITY, + resultclass=MockResultClass) + self.assertEqual(runner.resultclass, MockResultClass) + + expectedresult = (runner.stream, DESCRIPTIONS, VERBOSITY) + self.assertEqual(runner._makeResult(), expectedresult) + + @support.requires_subprocess() + def test_warnings(self): + """ + Check that warnings argument of TextTestRunner correctly affects the + behavior of the warnings. + """ + # see #10535 and the _test_warnings file for more information + + def get_parse_out_err(p): + return [b.splitlines() for b in p.communicate()] + opts = dict(stdout=subprocess.PIPE, stderr=subprocess.PIPE, + cwd=os.path.dirname(__file__)) + ae_msg = b'Please use assertEqual instead.' + at_msg = b'Please use assertTrue instead.' + + # no args -> all the warnings are printed, unittest warnings only once + p = subprocess.Popen([sys.executable, '-E', '_test_warnings.py'], **opts) + with p: + out, err = get_parse_out_err(p) + self.assertIn(b'OK', err) + # check that the total number of warnings in the output is correct + self.assertEqual(len(out), 12) + # check that the numbers of the different kind of warnings is correct + for msg in [b'dw', b'iw', b'uw']: + self.assertEqual(out.count(msg), 3) + for msg in [ae_msg, at_msg, b'rw']: + self.assertEqual(out.count(msg), 1) + + args_list = ( + # passing 'ignore' as warnings arg -> no warnings + [sys.executable, '_test_warnings.py', 'ignore'], + # -W doesn't affect the result if the arg is passed + [sys.executable, '-Wa', '_test_warnings.py', 'ignore'], + # -W affects the result if the arg is not passed + [sys.executable, '-Wi', '_test_warnings.py'] + ) + # in all these cases no warnings are printed + for args in args_list: + p = subprocess.Popen(args, **opts) + with p: + out, err = get_parse_out_err(p) + self.assertIn(b'OK', err) + self.assertEqual(len(out), 0) + + + # passing 'always' as warnings arg -> all the warnings printed, + # unittest warnings only once + p = subprocess.Popen([sys.executable, '_test_warnings.py', 'always'], + **opts) + with p: + out, err = get_parse_out_err(p) + self.assertIn(b'OK', err) + self.assertEqual(len(out), 14) + for msg in [b'dw', b'iw', b'uw', b'rw']: + self.assertEqual(out.count(msg), 3) + for msg in [ae_msg, at_msg]: + self.assertEqual(out.count(msg), 1) + + def testStdErrLookedUpAtInstantiationTime(self): + # see issue 10786 + old_stderr = sys.stderr + f = io.StringIO() + sys.stderr = f + try: + runner = unittest.TextTestRunner() + self.assertTrue(runner.stream.stream is f) + finally: + sys.stderr = old_stderr + + def testSpecifiedStreamUsed(self): + # see issue 10786 + f = io.StringIO() + runner = unittest.TextTestRunner(f) + self.assertTrue(runner.stream.stream is f) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_setups.py b/Lib/test/test_unittest/test_setups.py new file mode 100644 index 0000000..2df703e --- /dev/null +++ b/Lib/test/test_unittest/test_setups.py @@ -0,0 +1,507 @@ +import io +import sys + +import unittest + + +def resultFactory(*_): + return unittest.TestResult() + + +class TestSetups(unittest.TestCase): + + def getRunner(self): + return unittest.TextTestRunner(resultclass=resultFactory, + stream=io.StringIO()) + def runTests(self, *cases): + suite = unittest.TestSuite() + for case in cases: + tests = unittest.defaultTestLoader.loadTestsFromTestCase(case) + suite.addTests(tests) + + runner = self.getRunner() + + # creating a nested suite exposes some potential bugs + realSuite = unittest.TestSuite() + realSuite.addTest(suite) + # adding empty suites to the end exposes potential bugs + suite.addTest(unittest.TestSuite()) + realSuite.addTest(unittest.TestSuite()) + return runner.run(realSuite) + + def test_setup_class(self): + class Test(unittest.TestCase): + setUpCalled = 0 + @classmethod + def setUpClass(cls): + Test.setUpCalled += 1 + unittest.TestCase.setUpClass() + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test) + + self.assertEqual(Test.setUpCalled, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_teardown_class(self): + class Test(unittest.TestCase): + tearDownCalled = 0 + @classmethod + def tearDownClass(cls): + Test.tearDownCalled += 1 + unittest.TestCase.tearDownClass() + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test) + + self.assertEqual(Test.tearDownCalled, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_teardown_class_two_classes(self): + class Test(unittest.TestCase): + tearDownCalled = 0 + @classmethod + def tearDownClass(cls): + Test.tearDownCalled += 1 + unittest.TestCase.tearDownClass() + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + tearDownCalled = 0 + @classmethod + def tearDownClass(cls): + Test2.tearDownCalled += 1 + unittest.TestCase.tearDownClass() + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test, Test2) + + self.assertEqual(Test.tearDownCalled, 1) + self.assertEqual(Test2.tearDownCalled, 1) + self.assertEqual(result.testsRun, 4) + self.assertEqual(len(result.errors), 0) + + def test_error_in_setupclass(self): + class BrokenTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(BrokenTest) + + self.assertEqual(result.testsRun, 0) + self.assertEqual(len(result.errors), 1) + error, _ = result.errors[0] + self.assertEqual(str(error), + 'setUpClass (%s.%s)' % (__name__, BrokenTest.__qualname__)) + + def test_error_in_teardown_class(self): + class Test(unittest.TestCase): + tornDown = 0 + @classmethod + def tearDownClass(cls): + Test.tornDown += 1 + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + tornDown = 0 + @classmethod + def tearDownClass(cls): + Test2.tornDown += 1 + raise TypeError('foo') + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test, Test2) + self.assertEqual(result.testsRun, 4) + self.assertEqual(len(result.errors), 2) + self.assertEqual(Test.tornDown, 1) + self.assertEqual(Test2.tornDown, 1) + + error, _ = result.errors[0] + self.assertEqual(str(error), + 'tearDownClass (%s.%s)' % (__name__, Test.__qualname__)) + + def test_class_not_torndown_when_setup_fails(self): + class Test(unittest.TestCase): + tornDown = False + @classmethod + def setUpClass(cls): + raise TypeError + @classmethod + def tearDownClass(cls): + Test.tornDown = True + raise TypeError('foo') + def test_one(self): + pass + + self.runTests(Test) + self.assertFalse(Test.tornDown) + + def test_class_not_setup_or_torndown_when_skipped(self): + class Test(unittest.TestCase): + classSetUp = False + tornDown = False + @classmethod + def setUpClass(cls): + Test.classSetUp = True + @classmethod + def tearDownClass(cls): + Test.tornDown = True + def test_one(self): + pass + + Test = unittest.skip("hop")(Test) + self.runTests(Test) + self.assertFalse(Test.classSetUp) + self.assertFalse(Test.tornDown) + + def test_setup_teardown_order_with_pathological_suite(self): + results = [] + + class Module1(object): + @staticmethod + def setUpModule(): + results.append('Module1.setUpModule') + @staticmethod + def tearDownModule(): + results.append('Module1.tearDownModule') + + class Module2(object): + @staticmethod + def setUpModule(): + results.append('Module2.setUpModule') + @staticmethod + def tearDownModule(): + results.append('Module2.tearDownModule') + + class Test1(unittest.TestCase): + @classmethod + def setUpClass(cls): + results.append('setup 1') + @classmethod + def tearDownClass(cls): + results.append('teardown 1') + def testOne(self): + results.append('Test1.testOne') + def testTwo(self): + results.append('Test1.testTwo') + + class Test2(unittest.TestCase): + @classmethod + def setUpClass(cls): + results.append('setup 2') + @classmethod + def tearDownClass(cls): + results.append('teardown 2') + def testOne(self): + results.append('Test2.testOne') + def testTwo(self): + results.append('Test2.testTwo') + + class Test3(unittest.TestCase): + @classmethod + def setUpClass(cls): + results.append('setup 3') + @classmethod + def tearDownClass(cls): + results.append('teardown 3') + def testOne(self): + results.append('Test3.testOne') + def testTwo(self): + results.append('Test3.testTwo') + + Test1.__module__ = Test2.__module__ = 'Module' + Test3.__module__ = 'Module2' + sys.modules['Module'] = Module1 + sys.modules['Module2'] = Module2 + + first = unittest.TestSuite((Test1('testOne'),)) + second = unittest.TestSuite((Test1('testTwo'),)) + third = unittest.TestSuite((Test2('testOne'),)) + fourth = unittest.TestSuite((Test2('testTwo'),)) + fifth = unittest.TestSuite((Test3('testOne'),)) + sixth = unittest.TestSuite((Test3('testTwo'),)) + suite = unittest.TestSuite((first, second, third, fourth, fifth, sixth)) + + runner = self.getRunner() + result = runner.run(suite) + self.assertEqual(result.testsRun, 6) + self.assertEqual(len(result.errors), 0) + + self.assertEqual(results, + ['Module1.setUpModule', 'setup 1', + 'Test1.testOne', 'Test1.testTwo', 'teardown 1', + 'setup 2', 'Test2.testOne', 'Test2.testTwo', + 'teardown 2', 'Module1.tearDownModule', + 'Module2.setUpModule', 'setup 3', + 'Test3.testOne', 'Test3.testTwo', + 'teardown 3', 'Module2.tearDownModule']) + + def test_setup_module(self): + class Module(object): + moduleSetup = 0 + @staticmethod + def setUpModule(): + Module.moduleSetup += 1 + + class Test(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test) + self.assertEqual(Module.moduleSetup, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_error_in_setup_module(self): + class Module(object): + moduleSetup = 0 + moduleTornDown = 0 + @staticmethod + def setUpModule(): + Module.moduleSetup += 1 + raise TypeError('foo') + @staticmethod + def tearDownModule(): + Module.moduleTornDown += 1 + + class Test(unittest.TestCase): + classSetUp = False + classTornDown = False + @classmethod + def setUpClass(cls): + Test.classSetUp = True + @classmethod + def tearDownClass(cls): + Test.classTornDown = True + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + Test2.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test, Test2) + self.assertEqual(Module.moduleSetup, 1) + self.assertEqual(Module.moduleTornDown, 0) + self.assertEqual(result.testsRun, 0) + self.assertFalse(Test.classSetUp) + self.assertFalse(Test.classTornDown) + self.assertEqual(len(result.errors), 1) + error, _ = result.errors[0] + self.assertEqual(str(error), 'setUpModule (Module)') + + def test_testcase_with_missing_module(self): + class Test(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + sys.modules.pop('Module', None) + + result = self.runTests(Test) + self.assertEqual(result.testsRun, 2) + + def test_teardown_module(self): + class Module(object): + moduleTornDown = 0 + @staticmethod + def tearDownModule(): + Module.moduleTornDown += 1 + + class Test(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test) + self.assertEqual(Module.moduleTornDown, 1) + self.assertEqual(result.testsRun, 2) + self.assertEqual(len(result.errors), 0) + + def test_error_in_teardown_module(self): + class Module(object): + moduleTornDown = 0 + @staticmethod + def tearDownModule(): + Module.moduleTornDown += 1 + raise TypeError('foo') + + class Test(unittest.TestCase): + classSetUp = False + classTornDown = False + @classmethod + def setUpClass(cls): + Test.classSetUp = True + @classmethod + def tearDownClass(cls): + Test.classTornDown = True + def test_one(self): + pass + def test_two(self): + pass + + class Test2(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + Test.__module__ = 'Module' + Test2.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test, Test2) + self.assertEqual(Module.moduleTornDown, 1) + self.assertEqual(result.testsRun, 4) + self.assertTrue(Test.classSetUp) + self.assertTrue(Test.classTornDown) + self.assertEqual(len(result.errors), 1) + error, _ = result.errors[0] + self.assertEqual(str(error), 'tearDownModule (Module)') + + def test_skiptest_in_setupclass(self): + class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + raise unittest.SkipTest('foo') + def test_one(self): + pass + def test_two(self): + pass + + result = self.runTests(Test) + self.assertEqual(result.testsRun, 0) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.skipped), 1) + skipped = result.skipped[0][0] + self.assertEqual(str(skipped), + 'setUpClass (%s.%s)' % (__name__, Test.__qualname__)) + + def test_skiptest_in_setupmodule(self): + class Test(unittest.TestCase): + def test_one(self): + pass + def test_two(self): + pass + + class Module(object): + @staticmethod + def setUpModule(): + raise unittest.SkipTest('foo') + + Test.__module__ = 'Module' + sys.modules['Module'] = Module + + result = self.runTests(Test) + self.assertEqual(result.testsRun, 0) + self.assertEqual(len(result.errors), 0) + self.assertEqual(len(result.skipped), 1) + skipped = result.skipped[0][0] + self.assertEqual(str(skipped), 'setUpModule (Module)') + + def test_suite_debug_executes_setups_and_teardowns(self): + ordering = [] + + class Module(object): + @staticmethod + def setUpModule(): + ordering.append('setUpModule') + @staticmethod + def tearDownModule(): + ordering.append('tearDownModule') + + class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + ordering.append('setUpClass') + @classmethod + def tearDownClass(cls): + ordering.append('tearDownClass') + def test_something(self): + ordering.append('test_something') + + Test.__module__ = 'Module' + sys.modules['Module'] = Module + + suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test) + suite.debug() + expectedOrder = ['setUpModule', 'setUpClass', 'test_something', 'tearDownClass', 'tearDownModule'] + self.assertEqual(ordering, expectedOrder) + + def test_suite_debug_propagates_exceptions(self): + class Module(object): + @staticmethod + def setUpModule(): + if phase == 0: + raise Exception('setUpModule') + @staticmethod + def tearDownModule(): + if phase == 1: + raise Exception('tearDownModule') + + class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + if phase == 2: + raise Exception('setUpClass') + @classmethod + def tearDownClass(cls): + if phase == 3: + raise Exception('tearDownClass') + def test_something(self): + if phase == 4: + raise Exception('test_something') + + Test.__module__ = 'Module' + sys.modules['Module'] = Module + + messages = ('setUpModule', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something') + for phase, msg in enumerate(messages): + _suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test) + suite = unittest.TestSuite([_suite]) + with self.assertRaisesRegex(Exception, msg): + suite.debug() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/test_skipping.py b/Lib/test/test_unittest/test_skipping.py new file mode 100644 index 0000000..f146dca --- /dev/null +++ b/Lib/test/test_unittest/test_skipping.py @@ -0,0 +1,530 @@ +import unittest + +from test.test_unittest.support import LoggingResult + + +class Test_TestSkipping(unittest.TestCase): + + def test_skipping(self): + class Foo(unittest.TestCase): + def defaultTestResult(self): + return LoggingResult(events) + def test_skip_me(self): + self.skipTest("skip") + events = [] + result = LoggingResult(events) + test = Foo("test_skip_me") + self.assertIs(test.run(result), result) + self.assertEqual(events, ['startTest', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "skip")]) + + events = [] + result = test.run() + self.assertEqual(events, ['startTestRun', 'startTest', 'addSkip', + 'stopTest', 'stopTestRun']) + self.assertEqual(result.skipped, [(test, "skip")]) + self.assertEqual(result.testsRun, 1) + + # Try letting setUp skip the test now. + class Foo(unittest.TestCase): + def defaultTestResult(self): + return LoggingResult(events) + def setUp(self): + self.skipTest("testing") + def test_nothing(self): pass + events = [] + result = LoggingResult(events) + test = Foo("test_nothing") + self.assertIs(test.run(result), result) + self.assertEqual(events, ['startTest', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "testing")]) + self.assertEqual(result.testsRun, 1) + + events = [] + result = test.run() + self.assertEqual(events, ['startTestRun', 'startTest', 'addSkip', + 'stopTest', 'stopTestRun']) + self.assertEqual(result.skipped, [(test, "testing")]) + self.assertEqual(result.testsRun, 1) + + def test_skipping_subtests(self): + class Foo(unittest.TestCase): + def defaultTestResult(self): + return LoggingResult(events) + def test_skip_me(self): + with self.subTest(a=1): + with self.subTest(b=2): + self.skipTest("skip 1") + self.skipTest("skip 2") + self.skipTest("skip 3") + events = [] + result = LoggingResult(events) + test = Foo("test_skip_me") + self.assertIs(test.run(result), result) + self.assertEqual(events, ['startTest', 'addSkip', 'addSkip', + 'addSkip', 'stopTest']) + self.assertEqual(len(result.skipped), 3) + subtest, msg = result.skipped[0] + self.assertEqual(msg, "skip 1") + self.assertIsInstance(subtest, unittest.TestCase) + self.assertIsNot(subtest, test) + subtest, msg = result.skipped[1] + self.assertEqual(msg, "skip 2") + self.assertIsInstance(subtest, unittest.TestCase) + self.assertIsNot(subtest, test) + self.assertEqual(result.skipped[2], (test, "skip 3")) + + events = [] + result = test.run() + self.assertEqual(events, + ['startTestRun', 'startTest', 'addSkip', 'addSkip', + 'addSkip', 'stopTest', 'stopTestRun']) + self.assertEqual([msg for subtest, msg in result.skipped], + ['skip 1', 'skip 2', 'skip 3']) + + def test_skipping_decorators(self): + op_table = ((unittest.skipUnless, False, True), + (unittest.skipIf, True, False)) + for deco, do_skip, dont_skip in op_table: + class Foo(unittest.TestCase): + def defaultTestResult(self): + return LoggingResult(events) + + @deco(do_skip, "testing") + def test_skip(self): pass + + @deco(dont_skip, "testing") + def test_dont_skip(self): pass + test_do_skip = Foo("test_skip") + test_dont_skip = Foo("test_dont_skip") + + suite = unittest.TestSuite([test_do_skip, test_dont_skip]) + events = [] + result = LoggingResult(events) + self.assertIs(suite.run(result), result) + self.assertEqual(len(result.skipped), 1) + expected = ['startTest', 'addSkip', 'stopTest', + 'startTest', 'addSuccess', 'stopTest'] + self.assertEqual(events, expected) + self.assertEqual(result.testsRun, 2) + self.assertEqual(result.skipped, [(test_do_skip, "testing")]) + self.assertTrue(result.wasSuccessful()) + + events = [] + result = test_do_skip.run() + self.assertEqual(events, ['startTestRun', 'startTest', 'addSkip', + 'stopTest', 'stopTestRun']) + self.assertEqual(result.skipped, [(test_do_skip, "testing")]) + + events = [] + result = test_dont_skip.run() + self.assertEqual(events, ['startTestRun', 'startTest', 'addSuccess', + 'stopTest', 'stopTestRun']) + self.assertEqual(result.skipped, []) + + def test_skip_class(self): + @unittest.skip("testing") + class Foo(unittest.TestCase): + def defaultTestResult(self): + return LoggingResult(events) + def test_1(self): + record.append(1) + events = [] + record = [] + result = LoggingResult(events) + test = Foo("test_1") + suite = unittest.TestSuite([test]) + self.assertIs(suite.run(result), result) + self.assertEqual(events, ['startTest', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "testing")]) + self.assertEqual(record, []) + + events = [] + result = test.run() + self.assertEqual(events, ['startTestRun', 'startTest', 'addSkip', + 'stopTest', 'stopTestRun']) + self.assertEqual(result.skipped, [(test, "testing")]) + self.assertEqual(record, []) + + def test_skip_non_unittest_class(self): + @unittest.skip("testing") + class Mixin: + def test_1(self): + record.append(1) + class Foo(Mixin, unittest.TestCase): + pass + record = [] + result = unittest.TestResult() + test = Foo("test_1") + suite = unittest.TestSuite([test]) + self.assertIs(suite.run(result), result) + self.assertEqual(result.skipped, [(test, "testing")]) + self.assertEqual(record, []) + + def test_skip_in_setup(self): + class Foo(unittest.TestCase): + def setUp(self): + self.skipTest("skip") + def test_skip_me(self): + self.fail("shouldn't come here") + events = [] + result = LoggingResult(events) + test = Foo("test_skip_me") + self.assertIs(test.run(result), result) + self.assertEqual(events, ['startTest', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "skip")]) + + def test_skip_in_cleanup(self): + class Foo(unittest.TestCase): + def test_skip_me(self): + pass + def tearDown(self): + self.skipTest("skip") + events = [] + result = LoggingResult(events) + test = Foo("test_skip_me") + self.assertIs(test.run(result), result) + self.assertEqual(events, ['startTest', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "skip")]) + + def test_failure_and_skip_in_cleanup(self): + class Foo(unittest.TestCase): + def test_skip_me(self): + self.fail("fail") + def tearDown(self): + self.skipTest("skip") + events = [] + result = LoggingResult(events) + test = Foo("test_skip_me") + self.assertIs(test.run(result), result) + self.assertEqual(events, ['startTest', 'addFailure', 'addSkip', 'stopTest']) + self.assertEqual(result.skipped, [(test, "skip")]) + + def test_skipping_and_fail_in_cleanup(self): + class Foo(unittest.TestCase): + def test_skip_me(self): + self.skipTest("skip") + def tearDown(self): + self.fail("fail") + events = [] + result = LoggingResult(events) + test = Foo("test_skip_me") + self.assertIs(test.run(result), result) + self.assertEqual(events, ['startTest', 'addSkip', 'addFailure', 'stopTest']) + self.assertEqual(result.skipped, [(test, "skip")]) + + def test_expected_failure(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + self.fail("help me!") + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addExpectedFailure', 'stopTest']) + self.assertFalse(result.failures) + self.assertEqual(result.expectedFailures[0][0], test) + self.assertFalse(result.unexpectedSuccesses) + self.assertTrue(result.wasSuccessful()) + + def test_expected_failure_with_wrapped_class(self): + @unittest.expectedFailure + class Foo(unittest.TestCase): + def test_1(self): + self.assertTrue(False) + + events = [] + result = LoggingResult(events) + test = Foo("test_1") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addExpectedFailure', 'stopTest']) + self.assertFalse(result.failures) + self.assertEqual(result.expectedFailures[0][0], test) + self.assertFalse(result.unexpectedSuccesses) + self.assertTrue(result.wasSuccessful()) + + def test_expected_failure_with_wrapped_subclass(self): + class Foo(unittest.TestCase): + def test_1(self): + self.assertTrue(False) + + @unittest.expectedFailure + class Bar(Foo): + pass + + events = [] + result = LoggingResult(events) + test = Bar("test_1") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addExpectedFailure', 'stopTest']) + self.assertFalse(result.failures) + self.assertEqual(result.expectedFailures[0][0], test) + self.assertFalse(result.unexpectedSuccesses) + self.assertTrue(result.wasSuccessful()) + + def test_expected_failure_subtests(self): + # A failure in any subtest counts as the expected failure of the + # whole test. + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + with self.subTest(): + # This one succeeds + pass + with self.subTest(): + self.fail("help me!") + with self.subTest(): + # This one doesn't get executed + self.fail("shouldn't come here") + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addSubTestSuccess', + 'addExpectedFailure', 'stopTest']) + self.assertFalse(result.failures) + self.assertEqual(len(result.expectedFailures), 1) + self.assertIs(result.expectedFailures[0][0], test) + self.assertFalse(result.unexpectedSuccesses) + self.assertTrue(result.wasSuccessful()) + + def test_expected_failure_and_fail_in_cleanup(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + self.fail("help me!") + def tearDown(self): + self.fail("bad tearDown") + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addFailure', 'stopTest']) + self.assertEqual(len(result.failures), 1) + self.assertIn('AssertionError: bad tearDown', result.failures[0][1]) + self.assertFalse(result.expectedFailures) + self.assertFalse(result.unexpectedSuccesses) + self.assertFalse(result.wasSuccessful()) + + def test_expected_failure_and_skip_in_cleanup(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + self.fail("help me!") + def tearDown(self): + self.skipTest("skip") + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addSkip', 'stopTest']) + self.assertFalse(result.failures) + self.assertFalse(result.expectedFailures) + self.assertFalse(result.unexpectedSuccesses) + self.assertEqual(result.skipped, [(test, "skip")]) + self.assertTrue(result.wasSuccessful()) + + def test_unexpected_success(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + pass + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addUnexpectedSuccess', 'stopTest']) + self.assertFalse(result.failures) + self.assertFalse(result.expectedFailures) + self.assertEqual(result.unexpectedSuccesses, [test]) + self.assertFalse(result.wasSuccessful()) + + def test_unexpected_success_subtests(self): + # Success in all subtests counts as the unexpected success of + # the whole test. + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + with self.subTest(): + # This one succeeds + pass + with self.subTest(): + # So does this one + pass + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', + 'addSubTestSuccess', 'addSubTestSuccess', + 'addUnexpectedSuccess', 'stopTest']) + self.assertFalse(result.failures) + self.assertFalse(result.expectedFailures) + self.assertEqual(result.unexpectedSuccesses, [test]) + self.assertFalse(result.wasSuccessful()) + + def test_unexpected_success_and_fail_in_cleanup(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + pass + def tearDown(self): + self.fail("bad tearDown") + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addFailure', 'stopTest']) + self.assertEqual(len(result.failures), 1) + self.assertIn('AssertionError: bad tearDown', result.failures[0][1]) + self.assertFalse(result.expectedFailures) + self.assertFalse(result.unexpectedSuccesses) + self.assertFalse(result.wasSuccessful()) + + def test_unexpected_success_and_skip_in_cleanup(self): + class Foo(unittest.TestCase): + @unittest.expectedFailure + def test_die(self): + pass + def tearDown(self): + self.skipTest("skip") + events = [] + result = LoggingResult(events) + test = Foo("test_die") + self.assertIs(test.run(result), result) + self.assertEqual(events, + ['startTest', 'addSkip', 'stopTest']) + self.assertFalse(result.failures) + self.assertFalse(result.expectedFailures) + self.assertFalse(result.unexpectedSuccesses) + self.assertEqual(result.skipped, [(test, "skip")]) + self.assertTrue(result.wasSuccessful()) + + def test_skip_doesnt_run_setup(self): + class Foo(unittest.TestCase): + wasSetUp = False + wasTornDown = False + def setUp(self): + Foo.wasSetUp = True + def tornDown(self): + Foo.wasTornDown = True + @unittest.skip('testing') + def test_1(self): + pass + + result = unittest.TestResult() + test = Foo("test_1") + suite = unittest.TestSuite([test]) + self.assertIs(suite.run(result), result) + self.assertEqual(result.skipped, [(test, "testing")]) + self.assertFalse(Foo.wasSetUp) + self.assertFalse(Foo.wasTornDown) + + def test_decorated_skip(self): + def decorator(func): + def inner(*a): + return func(*a) + return inner + + class Foo(unittest.TestCase): + @decorator + @unittest.skip('testing') + def test_1(self): + pass + + result = unittest.TestResult() + test = Foo("test_1") + suite = unittest.TestSuite([test]) + self.assertIs(suite.run(result), result) + self.assertEqual(result.skipped, [(test, "testing")]) + + def test_skip_without_reason(self): + class Foo(unittest.TestCase): + @unittest.skip + def test_1(self): + pass + + result = unittest.TestResult() + test = Foo("test_1") + suite = unittest.TestSuite([test]) + self.assertIs(suite.run(result), result) + self.assertEqual(result.skipped, [(test, "")]) + + def test_debug_skipping(self): + class Foo(unittest.TestCase): + def setUp(self): + events.append("setUp") + def tearDown(self): + events.append("tearDown") + def test1(self): + self.skipTest('skipping exception') + events.append("test1") + @unittest.skip("skipping decorator") + def test2(self): + events.append("test2") + + events = [] + test = Foo("test1") + with self.assertRaises(unittest.SkipTest) as cm: + test.debug() + self.assertIn("skipping exception", str(cm.exception)) + self.assertEqual(events, ["setUp"]) + + events = [] + test = Foo("test2") + with self.assertRaises(unittest.SkipTest) as cm: + test.debug() + self.assertIn("skipping decorator", str(cm.exception)) + self.assertEqual(events, []) + + def test_debug_skipping_class(self): + @unittest.skip("testing") + class Foo(unittest.TestCase): + def setUp(self): + events.append("setUp") + def tearDown(self): + events.append("tearDown") + def test(self): + events.append("test") + + events = [] + test = Foo("test") + with self.assertRaises(unittest.SkipTest) as cm: + test.debug() + self.assertIn("testing", str(cm.exception)) + self.assertEqual(events, []) + + def test_debug_skipping_subtests(self): + class Foo(unittest.TestCase): + def setUp(self): + events.append("setUp") + def tearDown(self): + events.append("tearDown") + def test(self): + with self.subTest(a=1): + events.append('subtest') + self.skipTest("skip subtest") + events.append('end subtest') + events.append('end test') + + events = [] + result = LoggingResult(events) + test = Foo("test") + with self.assertRaises(unittest.SkipTest) as cm: + test.debug() + self.assertIn("skip subtest", str(cm.exception)) + self.assertEqual(events, ['setUp', 'subtest']) + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/test_suite.py b/Lib/test/test_unittest/test_suite.py new file mode 100644 index 0000000..ca52ee9 --- /dev/null +++ b/Lib/test/test_unittest/test_suite.py @@ -0,0 +1,447 @@ +import unittest + +import gc +import sys +import weakref +from test.test_unittest.support import LoggingResult, TestEquality + + +### Support code for Test_TestSuite +################################################################ + +class Test(object): + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def test_3(self): pass + def runTest(self): pass + +def _mk_TestSuite(*names): + return unittest.TestSuite(Test.Foo(n) for n in names) + +################################################################ + + +class Test_TestSuite(unittest.TestCase, TestEquality): + + ### Set up attributes needed by inherited tests + ################################################################ + + # Used by TestEquality.test_eq + eq_pairs = [(unittest.TestSuite(), unittest.TestSuite()) + ,(unittest.TestSuite(), unittest.TestSuite([])) + ,(_mk_TestSuite('test_1'), _mk_TestSuite('test_1'))] + + # Used by TestEquality.test_ne + ne_pairs = [(unittest.TestSuite(), _mk_TestSuite('test_1')) + ,(unittest.TestSuite([]), _mk_TestSuite('test_1')) + ,(_mk_TestSuite('test_1', 'test_2'), _mk_TestSuite('test_1', 'test_3')) + ,(_mk_TestSuite('test_1'), _mk_TestSuite('test_2'))] + + ################################################################ + ### /Set up attributes needed by inherited tests + + ### Tests for TestSuite.__init__ + ################################################################ + + # "class TestSuite([tests])" + # + # The tests iterable should be optional + def test_init__tests_optional(self): + suite = unittest.TestSuite() + + self.assertEqual(suite.countTestCases(), 0) + # countTestCases() still works after tests are run + suite.run(unittest.TestResult()) + self.assertEqual(suite.countTestCases(), 0) + + # "class TestSuite([tests])" + # ... + # "If tests is given, it must be an iterable of individual test cases + # or other test suites that will be used to build the suite initially" + # + # TestSuite should deal with empty tests iterables by allowing the + # creation of an empty suite + def test_init__empty_tests(self): + suite = unittest.TestSuite([]) + + self.assertEqual(suite.countTestCases(), 0) + # countTestCases() still works after tests are run + suite.run(unittest.TestResult()) + self.assertEqual(suite.countTestCases(), 0) + + # "class TestSuite([tests])" + # ... + # "If tests is given, it must be an iterable of individual test cases + # or other test suites that will be used to build the suite initially" + # + # TestSuite should allow any iterable to provide tests + def test_init__tests_from_any_iterable(self): + def tests(): + yield unittest.FunctionTestCase(lambda: None) + yield unittest.FunctionTestCase(lambda: None) + + suite_1 = unittest.TestSuite(tests()) + self.assertEqual(suite_1.countTestCases(), 2) + + suite_2 = unittest.TestSuite(suite_1) + self.assertEqual(suite_2.countTestCases(), 2) + + suite_3 = unittest.TestSuite(set(suite_1)) + self.assertEqual(suite_3.countTestCases(), 2) + + # countTestCases() still works after tests are run + suite_1.run(unittest.TestResult()) + self.assertEqual(suite_1.countTestCases(), 2) + suite_2.run(unittest.TestResult()) + self.assertEqual(suite_2.countTestCases(), 2) + suite_3.run(unittest.TestResult()) + self.assertEqual(suite_3.countTestCases(), 2) + + # "class TestSuite([tests])" + # ... + # "If tests is given, it must be an iterable of individual test cases + # or other test suites that will be used to build the suite initially" + # + # Does TestSuite() also allow other TestSuite() instances to be present + # in the tests iterable? + def test_init__TestSuite_instances_in_tests(self): + def tests(): + ftc = unittest.FunctionTestCase(lambda: None) + yield unittest.TestSuite([ftc]) + yield unittest.FunctionTestCase(lambda: None) + + suite = unittest.TestSuite(tests()) + self.assertEqual(suite.countTestCases(), 2) + # countTestCases() still works after tests are run + suite.run(unittest.TestResult()) + self.assertEqual(suite.countTestCases(), 2) + + ################################################################ + ### /Tests for TestSuite.__init__ + + # Container types should support the iter protocol + def test_iter(self): + test1 = unittest.FunctionTestCase(lambda: None) + test2 = unittest.FunctionTestCase(lambda: None) + suite = unittest.TestSuite((test1, test2)) + + self.assertEqual(list(suite), [test1, test2]) + + # "Return the number of tests represented by the this test object. + # ...this method is also implemented by the TestSuite class, which can + # return larger [greater than 1] values" + # + # Presumably an empty TestSuite returns 0? + def test_countTestCases_zero_simple(self): + suite = unittest.TestSuite() + + self.assertEqual(suite.countTestCases(), 0) + + # "Return the number of tests represented by the this test object. + # ...this method is also implemented by the TestSuite class, which can + # return larger [greater than 1] values" + # + # Presumably an empty TestSuite (even if it contains other empty + # TestSuite instances) returns 0? + def test_countTestCases_zero_nested(self): + class Test1(unittest.TestCase): + def test(self): + pass + + suite = unittest.TestSuite([unittest.TestSuite()]) + + self.assertEqual(suite.countTestCases(), 0) + + # "Return the number of tests represented by the this test object. + # ...this method is also implemented by the TestSuite class, which can + # return larger [greater than 1] values" + def test_countTestCases_simple(self): + test1 = unittest.FunctionTestCase(lambda: None) + test2 = unittest.FunctionTestCase(lambda: None) + suite = unittest.TestSuite((test1, test2)) + + self.assertEqual(suite.countTestCases(), 2) + # countTestCases() still works after tests are run + suite.run(unittest.TestResult()) + self.assertEqual(suite.countTestCases(), 2) + + # "Return the number of tests represented by the this test object. + # ...this method is also implemented by the TestSuite class, which can + # return larger [greater than 1] values" + # + # Make sure this holds for nested TestSuite instances, too + def test_countTestCases_nested(self): + class Test1(unittest.TestCase): + def test1(self): pass + def test2(self): pass + + test2 = unittest.FunctionTestCase(lambda: None) + test3 = unittest.FunctionTestCase(lambda: None) + child = unittest.TestSuite((Test1('test2'), test2)) + parent = unittest.TestSuite((test3, child, Test1('test1'))) + + self.assertEqual(parent.countTestCases(), 4) + # countTestCases() still works after tests are run + parent.run(unittest.TestResult()) + self.assertEqual(parent.countTestCases(), 4) + self.assertEqual(child.countTestCases(), 2) + + # "Run the tests associated with this suite, collecting the result into + # the test result object passed as result." + # + # And if there are no tests? What then? + def test_run__empty_suite(self): + events = [] + result = LoggingResult(events) + + suite = unittest.TestSuite() + + suite.run(result) + + self.assertEqual(events, []) + + # "Note that unlike TestCase.run(), TestSuite.run() requires the + # "result object to be passed in." + def test_run__requires_result(self): + suite = unittest.TestSuite() + + try: + suite.run() + except TypeError: + pass + else: + self.fail("Failed to raise TypeError") + + # "Run the tests associated with this suite, collecting the result into + # the test result object passed as result." + def test_run(self): + events = [] + result = LoggingResult(events) + + class LoggingCase(unittest.TestCase): + def run(self, result): + events.append('run %s' % self._testMethodName) + + def test1(self): pass + def test2(self): pass + + tests = [LoggingCase('test1'), LoggingCase('test2')] + + unittest.TestSuite(tests).run(result) + + self.assertEqual(events, ['run test1', 'run test2']) + + # "Add a TestCase ... to the suite" + def test_addTest__TestCase(self): + class Foo(unittest.TestCase): + def test(self): pass + + test = Foo('test') + suite = unittest.TestSuite() + + suite.addTest(test) + + self.assertEqual(suite.countTestCases(), 1) + self.assertEqual(list(suite), [test]) + # countTestCases() still works after tests are run + suite.run(unittest.TestResult()) + self.assertEqual(suite.countTestCases(), 1) + + # "Add a ... TestSuite to the suite" + def test_addTest__TestSuite(self): + class Foo(unittest.TestCase): + def test(self): pass + + suite_2 = unittest.TestSuite([Foo('test')]) + + suite = unittest.TestSuite() + suite.addTest(suite_2) + + self.assertEqual(suite.countTestCases(), 1) + self.assertEqual(list(suite), [suite_2]) + # countTestCases() still works after tests are run + suite.run(unittest.TestResult()) + self.assertEqual(suite.countTestCases(), 1) + + # "Add all the tests from an iterable of TestCase and TestSuite + # instances to this test suite." + # + # "This is equivalent to iterating over tests, calling addTest() for + # each element" + def test_addTests(self): + class Foo(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + + test_1 = Foo('test_1') + test_2 = Foo('test_2') + inner_suite = unittest.TestSuite([test_2]) + + def gen(): + yield test_1 + yield test_2 + yield inner_suite + + suite_1 = unittest.TestSuite() + suite_1.addTests(gen()) + + self.assertEqual(list(suite_1), list(gen())) + + # "This is equivalent to iterating over tests, calling addTest() for + # each element" + suite_2 = unittest.TestSuite() + for t in gen(): + suite_2.addTest(t) + + self.assertEqual(suite_1, suite_2) + + # "Add all the tests from an iterable of TestCase and TestSuite + # instances to this test suite." + # + # What happens if it doesn't get an iterable? + def test_addTest__noniterable(self): + suite = unittest.TestSuite() + + try: + suite.addTests(5) + except TypeError: + pass + else: + self.fail("Failed to raise TypeError") + + def test_addTest__noncallable(self): + suite = unittest.TestSuite() + self.assertRaises(TypeError, suite.addTest, 5) + + def test_addTest__casesuiteclass(self): + suite = unittest.TestSuite() + self.assertRaises(TypeError, suite.addTest, Test_TestSuite) + self.assertRaises(TypeError, suite.addTest, unittest.TestSuite) + + def test_addTests__string(self): + suite = unittest.TestSuite() + self.assertRaises(TypeError, suite.addTests, "foo") + + def test_function_in_suite(self): + def f(_): + pass + suite = unittest.TestSuite() + suite.addTest(f) + + # when the bug is fixed this line will not crash + suite.run(unittest.TestResult()) + + def test_remove_test_at_index(self): + if not unittest.BaseTestSuite._cleanup: + raise unittest.SkipTest("Suite cleanup is disabled") + + suite = unittest.TestSuite() + + suite._tests = [1, 2, 3] + suite._removeTestAtIndex(1) + + self.assertEqual([1, None, 3], suite._tests) + + def test_remove_test_at_index_not_indexable(self): + if not unittest.BaseTestSuite._cleanup: + raise unittest.SkipTest("Suite cleanup is disabled") + + suite = unittest.TestSuite() + suite._tests = None + + # if _removeAtIndex raises for noniterables this next line will break + suite._removeTestAtIndex(2) + + def assert_garbage_collect_test_after_run(self, TestSuiteClass): + if not unittest.BaseTestSuite._cleanup: + raise unittest.SkipTest("Suite cleanup is disabled") + + class Foo(unittest.TestCase): + def test_nothing(self): + pass + + test = Foo('test_nothing') + wref = weakref.ref(test) + + suite = TestSuiteClass([wref()]) + suite.run(unittest.TestResult()) + + del test + + # for the benefit of non-reference counting implementations + gc.collect() + + self.assertEqual(suite._tests, [None]) + self.assertIsNone(wref()) + + def test_garbage_collect_test_after_run_BaseTestSuite(self): + self.assert_garbage_collect_test_after_run(unittest.BaseTestSuite) + + def test_garbage_collect_test_after_run_TestSuite(self): + self.assert_garbage_collect_test_after_run(unittest.TestSuite) + + def test_basetestsuite(self): + class Test(unittest.TestCase): + wasSetUp = False + wasTornDown = False + @classmethod + def setUpClass(cls): + cls.wasSetUp = True + @classmethod + def tearDownClass(cls): + cls.wasTornDown = True + def testPass(self): + pass + def testFail(self): + fail + class Module(object): + wasSetUp = False + wasTornDown = False + @staticmethod + def setUpModule(): + Module.wasSetUp = True + @staticmethod + def tearDownModule(): + Module.wasTornDown = True + + Test.__module__ = 'Module' + sys.modules['Module'] = Module + self.addCleanup(sys.modules.pop, 'Module') + + suite = unittest.BaseTestSuite() + suite.addTests([Test('testPass'), Test('testFail')]) + self.assertEqual(suite.countTestCases(), 2) + + result = unittest.TestResult() + suite.run(result) + self.assertFalse(Module.wasSetUp) + self.assertFalse(Module.wasTornDown) + self.assertFalse(Test.wasSetUp) + self.assertFalse(Test.wasTornDown) + self.assertEqual(len(result.errors), 1) + self.assertEqual(len(result.failures), 0) + self.assertEqual(result.testsRun, 2) + self.assertEqual(suite.countTestCases(), 2) + + + def test_overriding_call(self): + class MySuite(unittest.TestSuite): + called = False + def __call__(self, *args, **kw): + self.called = True + unittest.TestSuite.__call__(self, *args, **kw) + + suite = MySuite() + result = unittest.TestResult() + wrapper = unittest.TestSuite() + wrapper.addTest(suite) + wrapper(result) + self.assertTrue(suite.called) + + # reusing results should be permitted even if abominable + self.assertFalse(result._testRunEntered) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/testmock/__init__.py b/Lib/test/test_unittest/testmock/__init__.py new file mode 100644 index 0000000..6ee60b2 --- /dev/null +++ b/Lib/test/test_unittest/testmock/__init__.py @@ -0,0 +1,17 @@ +import os +import sys +import unittest + + +here = os.path.dirname(__file__) +loader = unittest.defaultTestLoader + +def load_tests(*args): + suite = unittest.TestSuite() + for fn in os.listdir(here): + if fn.startswith("test") and fn.endswith(".py"): + modname = "test.test_unittest.testmock." + fn[:-3] + __import__(modname) + module = sys.modules[modname] + suite.addTest(loader.loadTestsFromModule(module)) + return suite diff --git a/Lib/test/test_unittest/testmock/__main__.py b/Lib/test/test_unittest/testmock/__main__.py new file mode 100644 index 0000000..1e3068b --- /dev/null +++ b/Lib/test/test_unittest/testmock/__main__.py @@ -0,0 +1,18 @@ +import os +import unittest + + +def load_tests(loader, standard_tests, pattern): + # top level directory cached on loader instance + this_dir = os.path.dirname(__file__) + pattern = pattern or "test*.py" + # We are inside test.test_unittest.testmock, so the top-level is three notches up + top_level_dir = os.path.dirname(os.path.dirname(os.path.dirname(this_dir))) + package_tests = loader.discover(start_dir=this_dir, pattern=pattern, + top_level_dir=top_level_dir) + standard_tests.addTests(package_tests) + return standard_tests + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/testmock/support.py b/Lib/test/test_unittest/testmock/support.py new file mode 100644 index 0000000..49986d6 --- /dev/null +++ b/Lib/test/test_unittest/testmock/support.py @@ -0,0 +1,16 @@ +target = {'foo': 'FOO'} + + +def is_instance(obj, klass): + """Version of is_instance that doesn't access __class__""" + return issubclass(type(obj), klass) + + +class SomeClass(object): + class_attribute = None + + def wibble(self): pass + + +class X(object): + pass diff --git a/Lib/test/test_unittest/testmock/testasync.py b/Lib/test/test_unittest/testmock/testasync.py new file mode 100644 index 0000000..1bab671 --- /dev/null +++ b/Lib/test/test_unittest/testmock/testasync.py @@ -0,0 +1,1061 @@ +import asyncio +import gc +import inspect +import re +import unittest +from contextlib import contextmanager +from test import support + +support.requires_working_socket(module=True) + +from asyncio import run, iscoroutinefunction +from unittest import IsolatedAsyncioTestCase +from unittest.mock import (ANY, call, AsyncMock, patch, MagicMock, Mock, + create_autospec, sentinel, _CallList) + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class AsyncClass: + def __init__(self): pass + async def async_method(self): pass + def normal_method(self): pass + + @classmethod + async def async_class_method(cls): pass + + @staticmethod + async def async_static_method(): pass + + +class AwaitableClass: + def __await__(self): yield + +async def async_func(): pass + +async def async_func_args(a, b, *, c): pass + +def normal_func(): pass + +class NormalClass(object): + def a(self): pass + + +async_foo_name = f'{__name__}.AsyncClass' +normal_foo_name = f'{__name__}.NormalClass' + + +@contextmanager +def assertNeverAwaited(test): + with test.assertWarnsRegex(RuntimeWarning, "was never awaited$"): + yield + # In non-CPython implementations of Python, this is needed because timely + # deallocation is not guaranteed by the garbage collector. + gc.collect() + + +class AsyncPatchDecoratorTest(unittest.TestCase): + def test_is_coroutine_function_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertTrue(iscoroutinefunction(mock_method)) + test_async() + + def test_is_async_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + run(m) + + @patch(f'{async_foo_name}.async_method') + def test_no_parent_attribute(mock_method): + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + run(m) + + test_async() + test_no_parent_attribute() + + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_is_AsyncMock_patch_staticmethod(self): + @patch.object(AsyncClass, 'async_static_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_is_AsyncMock_patch_classmethod(self): + @patch.object(AsyncClass, 'async_class_method') + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_async_def_patch(self): + @patch(f"{__name__}.async_func", return_value=1) + @patch(f"{__name__}.async_func_args", return_value=2) + async def test_async(func_args_mock, func_mock): + self.assertEqual(func_args_mock._mock_name, "async_func_args") + self.assertEqual(func_mock._mock_name, "async_func") + + self.assertIsInstance(async_func, AsyncMock) + self.assertIsInstance(async_func_args, AsyncMock) + + self.assertEqual(await async_func(), 1) + self.assertEqual(await async_func_args(1, 2, c=3), 2) + + run(test_async()) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + + +class AsyncPatchCMTest(unittest.TestCase): + def test_is_async_function_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertTrue(iscoroutinefunction(mock_method)) + + test_async() + + def test_is_async_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + m = mock_method() + self.assertTrue(inspect.isawaitable(m)) + run(m) + + test_async() + + def test_is_AsyncMock_cm(self): + def test_async(): + with patch.object(AsyncClass, 'async_method') as mock_method: + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_async_def_cm(self): + async def test_async(): + with patch(f"{__name__}.async_func", AsyncMock()): + self.assertIsInstance(async_func, AsyncMock) + self.assertTrue(inspect.iscoroutinefunction(async_func)) + + run(test_async()) + + +class AsyncMockTest(unittest.TestCase): + def test_iscoroutinefunction_default(self): + mock = AsyncMock() + self.assertTrue(iscoroutinefunction(mock)) + + def test_iscoroutinefunction_function(self): + async def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_isawaitable(self): + mock = AsyncMock() + m = mock() + self.assertTrue(inspect.isawaitable(m)) + run(m) + self.assertIn('assert_awaited', dir(mock)) + + def test_iscoroutinefunction_normal_function(self): + def foo(): pass + mock = AsyncMock(foo) + self.assertTrue(iscoroutinefunction(mock)) + self.assertTrue(inspect.iscoroutinefunction(mock)) + + def test_future_isfuture(self): + loop = asyncio.new_event_loop() + fut = loop.create_future() + loop.stop() + loop.close() + mock = AsyncMock(fut) + self.assertIsInstance(mock, asyncio.Future) + + +class AsyncAutospecTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch(async_foo_name, autospec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method.async_method, AsyncMock) + self.assertIsInstance(mock_method, MagicMock) + + @patch(async_foo_name, autospec=True) + def test_normal_method(mock_method): + self.assertIsInstance(mock_method.normal_method, MagicMock) + + test_async() + test_normal_method() + + def test_create_autospec_instance(self): + with self.assertRaises(RuntimeError): + create_autospec(async_func, instance=True) + + @unittest.skip('Broken test from https://bugs.python.org/issue37251') + def test_create_autospec_awaitable_class(self): + self.assertIsInstance(create_autospec(AwaitableClass), AsyncMock) + + def test_create_autospec(self): + spec = create_autospec(async_func_args) + awaitable = spec(1, 2, c=3) + async def main(): + await awaitable + + self.assertEqual(spec.await_count, 0) + self.assertIsNone(spec.await_args) + self.assertEqual(spec.await_args_list, []) + spec.assert_not_awaited() + + run(main()) + + self.assertTrue(iscoroutinefunction(spec)) + self.assertTrue(asyncio.iscoroutine(awaitable)) + self.assertEqual(spec.await_count, 1) + self.assertEqual(spec.await_args, call(1, 2, c=3)) + self.assertEqual(spec.await_args_list, [call(1, 2, c=3)]) + spec.assert_awaited_once() + spec.assert_awaited_once_with(1, 2, c=3) + spec.assert_awaited_with(1, 2, c=3) + spec.assert_awaited() + + with self.assertRaises(AssertionError): + spec.assert_any_await(e=1) + + + def test_patch_with_autospec(self): + + async def test_async(): + with patch(f"{__name__}.async_func_args", autospec=True) as mock_method: + awaitable = mock_method(1, 2, c=3) + self.assertIsInstance(mock_method.mock, AsyncMock) + + self.assertTrue(iscoroutinefunction(mock_method)) + self.assertTrue(asyncio.iscoroutine(awaitable)) + self.assertTrue(inspect.isawaitable(awaitable)) + + # Verify the default values during mock setup + self.assertEqual(mock_method.await_count, 0) + self.assertEqual(mock_method.await_args_list, []) + self.assertIsNone(mock_method.await_args) + mock_method.assert_not_awaited() + + await awaitable + + self.assertEqual(mock_method.await_count, 1) + self.assertEqual(mock_method.await_args, call(1, 2, c=3)) + self.assertEqual(mock_method.await_args_list, [call(1, 2, c=3)]) + mock_method.assert_awaited_once() + mock_method.assert_awaited_once_with(1, 2, c=3) + mock_method.assert_awaited_with(1, 2, c=3) + mock_method.assert_awaited() + + mock_method.reset_mock() + self.assertEqual(mock_method.await_count, 0) + self.assertIsNone(mock_method.await_args) + self.assertEqual(mock_method.await_args_list, []) + + run(test_async()) + + +class AsyncSpecTest(unittest.TestCase): + def test_spec_normal_methods_on_class(self): + def inner_test(mock_type): + mock = mock_type(AsyncClass) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, MagicMock) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test method types with {mock_type}"): + inner_test(mock_type) + + def test_spec_normal_methods_on_class_with_mock(self): + mock = Mock(AsyncClass) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, Mock) + + def test_spec_mock_type_kw(self): + def inner_test(mock_type): + async_mock = mock_type(spec=async_func) + self.assertIsInstance(async_mock, mock_type) + with assertNeverAwaited(self): + self.assertTrue(inspect.isawaitable(async_mock())) + + sync_mock = mock_type(spec=normal_func) + self.assertIsInstance(sync_mock, mock_type) + + for mock_type in [AsyncMock, MagicMock, Mock]: + with self.subTest(f"test spec kwarg with {mock_type}"): + inner_test(mock_type) + + def test_spec_mock_type_positional(self): + def inner_test(mock_type): + async_mock = mock_type(async_func) + self.assertIsInstance(async_mock, mock_type) + with assertNeverAwaited(self): + self.assertTrue(inspect.isawaitable(async_mock())) + + sync_mock = mock_type(normal_func) + self.assertIsInstance(sync_mock, mock_type) + + for mock_type in [AsyncMock, MagicMock, Mock]: + with self.subTest(f"test spec positional with {mock_type}"): + inner_test(mock_type) + + def test_spec_as_normal_kw_AsyncMock(self): + mock = AsyncMock(spec=normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + run(m) + + def test_spec_as_normal_positional_AsyncMock(self): + mock = AsyncMock(normal_func) + self.assertIsInstance(mock, AsyncMock) + m = mock() + self.assertTrue(inspect.isawaitable(m)) + run(m) + + def test_spec_async_mock(self): + @patch.object(AsyncClass, 'async_method', spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, AsyncMock) + + test_async() + + def test_spec_parent_not_async_attribute_is(self): + @patch(async_foo_name, spec=True) + def test_async(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertIsInstance(mock_method.async_method, AsyncMock) + + test_async() + + def test_target_async_spec_not(self): + @patch.object(AsyncClass, 'async_method', spec=NormalClass.a) + def test_async_attribute(mock_method): + self.assertIsInstance(mock_method, MagicMock) + self.assertFalse(inspect.iscoroutine(mock_method)) + self.assertFalse(inspect.isawaitable(mock_method)) + + test_async_attribute() + + def test_target_not_async_spec_is(self): + @patch.object(NormalClass, 'a', spec=async_func) + def test_attribute_not_async_spec_is(mock_async_func): + self.assertIsInstance(mock_async_func, AsyncMock) + test_attribute_not_async_spec_is() + + def test_spec_async_attributes(self): + @patch(normal_foo_name, spec=AsyncClass) + def test_async_attributes_coroutines(MockNormalClass): + self.assertIsInstance(MockNormalClass.async_method, AsyncMock) + self.assertIsInstance(MockNormalClass, MagicMock) + + test_async_attributes_coroutines() + + +class AsyncSpecSetTest(unittest.TestCase): + def test_is_AsyncMock_patch(self): + @patch.object(AsyncClass, 'async_method', spec_set=True) + def test_async(async_method): + self.assertIsInstance(async_method, AsyncMock) + test_async() + + def test_is_async_AsyncMock(self): + mock = AsyncMock(spec_set=AsyncClass.async_method) + self.assertTrue(iscoroutinefunction(mock)) + self.assertIsInstance(mock, AsyncMock) + + def test_is_child_AsyncMock(self): + mock = MagicMock(spec_set=AsyncClass) + self.assertTrue(iscoroutinefunction(mock.async_method)) + self.assertFalse(iscoroutinefunction(mock.normal_method)) + self.assertIsInstance(mock.async_method, AsyncMock) + self.assertIsInstance(mock.normal_method, MagicMock) + self.assertIsInstance(mock, MagicMock) + + def test_magicmock_lambda_spec(self): + mock_obj = MagicMock() + mock_obj.mock_func = MagicMock(spec=lambda x: x) + + with patch.object(mock_obj, "mock_func") as cm: + self.assertIsInstance(cm, MagicMock) + + +class AsyncArguments(IsolatedAsyncioTestCase): + async def test_add_return_value(self): + async def addition(self, var): pass + + mock = AsyncMock(addition, return_value=10) + output = await mock(5) + + self.assertEqual(output, 10) + + async def test_add_side_effect_exception(self): + async def addition(var): pass + mock = AsyncMock(addition, side_effect=Exception('err')) + with self.assertRaises(Exception): + await mock(5) + + async def test_add_side_effect_coroutine(self): + async def addition(var): + return var + 1 + mock = AsyncMock(side_effect=addition) + result = await mock(5) + self.assertEqual(result, 6) + + async def test_add_side_effect_normal_function(self): + def addition(var): + return var + 1 + mock = AsyncMock(side_effect=addition) + result = await mock(5) + self.assertEqual(result, 6) + + async def test_add_side_effect_iterable(self): + vals = [1, 2, 3] + mock = AsyncMock(side_effect=vals) + for item in vals: + self.assertEqual(await mock(), item) + + with self.assertRaises(StopAsyncIteration) as e: + await mock() + + async def test_add_side_effect_exception_iterable(self): + class SampleException(Exception): + pass + + vals = [1, SampleException("foo")] + mock = AsyncMock(side_effect=vals) + self.assertEqual(await mock(), 1) + + with self.assertRaises(SampleException) as e: + await mock() + + async def test_return_value_AsyncMock(self): + value = AsyncMock(return_value=10) + mock = AsyncMock(return_value=value) + result = await mock() + self.assertIs(result, value) + + async def test_return_value_awaitable(self): + fut = asyncio.Future() + fut.set_result(None) + mock = AsyncMock(return_value=fut) + result = await mock() + self.assertIsInstance(result, asyncio.Future) + + async def test_side_effect_awaitable_values(self): + fut = asyncio.Future() + fut.set_result(None) + + mock = AsyncMock(side_effect=[fut]) + result = await mock() + self.assertIsInstance(result, asyncio.Future) + + with self.assertRaises(StopAsyncIteration): + await mock() + + async def test_side_effect_is_AsyncMock(self): + effect = AsyncMock(return_value=10) + mock = AsyncMock(side_effect=effect) + + result = await mock() + self.assertEqual(result, 10) + + async def test_wraps_coroutine(self): + value = asyncio.Future() + + ran = False + async def inner(): + nonlocal ran + ran = True + return value + + mock = AsyncMock(wraps=inner) + result = await mock() + self.assertEqual(result, value) + mock.assert_awaited() + self.assertTrue(ran) + + async def test_wraps_normal_function(self): + value = 1 + + ran = False + def inner(): + nonlocal ran + ran = True + return value + + mock = AsyncMock(wraps=inner) + result = await mock() + self.assertEqual(result, value) + mock.assert_awaited() + self.assertTrue(ran) + + async def test_await_args_list_order(self): + async_mock = AsyncMock() + mock2 = async_mock(2) + mock1 = async_mock(1) + await mock1 + await mock2 + async_mock.assert_has_awaits([call(1), call(2)]) + self.assertEqual(async_mock.await_args_list, [call(1), call(2)]) + self.assertEqual(async_mock.call_args_list, [call(2), call(1)]) + + +class AsyncMagicMethods(unittest.TestCase): + def test_async_magic_methods_return_async_mocks(self): + m_mock = MagicMock() + self.assertIsInstance(m_mock.__aenter__, AsyncMock) + self.assertIsInstance(m_mock.__aexit__, AsyncMock) + self.assertIsInstance(m_mock.__anext__, AsyncMock) + # __aiter__ is actually a synchronous object + # so should return a MagicMock + self.assertIsInstance(m_mock.__aiter__, MagicMock) + + def test_sync_magic_methods_return_magic_mocks(self): + a_mock = AsyncMock() + self.assertIsInstance(a_mock.__enter__, MagicMock) + self.assertIsInstance(a_mock.__exit__, MagicMock) + self.assertIsInstance(a_mock.__next__, MagicMock) + self.assertIsInstance(a_mock.__len__, MagicMock) + + def test_magicmock_has_async_magic_methods(self): + m_mock = MagicMock() + self.assertTrue(hasattr(m_mock, "__aenter__")) + self.assertTrue(hasattr(m_mock, "__aexit__")) + self.assertTrue(hasattr(m_mock, "__anext__")) + + def test_asyncmock_has_sync_magic_methods(self): + a_mock = AsyncMock() + self.assertTrue(hasattr(a_mock, "__enter__")) + self.assertTrue(hasattr(a_mock, "__exit__")) + self.assertTrue(hasattr(a_mock, "__next__")) + self.assertTrue(hasattr(a_mock, "__len__")) + + def test_magic_methods_are_async_functions(self): + m_mock = MagicMock() + self.assertIsInstance(m_mock.__aenter__, AsyncMock) + self.assertIsInstance(m_mock.__aexit__, AsyncMock) + # AsyncMocks are also coroutine functions + self.assertTrue(iscoroutinefunction(m_mock.__aenter__)) + self.assertTrue(iscoroutinefunction(m_mock.__aexit__)) + +class AsyncContextManagerTest(unittest.TestCase): + + class WithAsyncContextManager: + async def __aenter__(self, *args, **kwargs): pass + + async def __aexit__(self, *args, **kwargs): pass + + class WithSyncContextManager: + def __enter__(self, *args, **kwargs): pass + + def __exit__(self, *args, **kwargs): pass + + class ProductionCode: + # Example real-world(ish) code + def __init__(self): + self.session = None + + async def main(self): + async with self.session.post('https://python.org') as response: + val = await response.json() + return val + + def test_set_return_value_of_aenter(self): + def inner_test(mock_type): + pc = self.ProductionCode() + pc.session = MagicMock(name='sessionmock') + cm = mock_type(name='magic_cm') + response = AsyncMock(name='response') + response.json = AsyncMock(return_value={'json': 123}) + cm.__aenter__.return_value = response + pc.session.post.return_value = cm + result = run(pc.main()) + self.assertEqual(result, {'json': 123}) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test set return value of aenter with {mock_type}"): + inner_test(mock_type) + + def test_mock_supports_async_context_manager(self): + def inner_test(mock_type): + called = False + cm = self.WithAsyncContextManager() + cm_mock = mock_type(cm) + + async def use_context_manager(): + nonlocal called + async with cm_mock as result: + called = True + return result + + cm_result = run(use_context_manager()) + self.assertTrue(called) + self.assertTrue(cm_mock.__aenter__.called) + self.assertTrue(cm_mock.__aexit__.called) + cm_mock.__aenter__.assert_awaited() + cm_mock.__aexit__.assert_awaited() + # We mock __aenter__ so it does not return self + self.assertIsNot(cm_mock, cm_result) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test context manager magics with {mock_type}"): + inner_test(mock_type) + + + def test_mock_customize_async_context_manager(self): + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + expected_result = object() + mock_instance.__aenter__.return_value = expected_result + + async def use_context_manager(): + async with mock_instance as result: + return result + + self.assertIs(run(use_context_manager()), expected_result) + + def test_mock_customize_async_context_manager_with_coroutine(self): + enter_called = False + exit_called = False + + async def enter_coroutine(*args): + nonlocal enter_called + enter_called = True + + async def exit_coroutine(*args): + nonlocal exit_called + exit_called = True + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + + mock_instance.__aenter__ = enter_coroutine + mock_instance.__aexit__ = exit_coroutine + + async def use_context_manager(): + async with mock_instance: + pass + + run(use_context_manager()) + self.assertTrue(enter_called) + self.assertTrue(exit_called) + + def test_context_manager_raise_exception_by_default(self): + async def raise_in(context_manager): + async with context_manager: + raise TypeError() + + instance = self.WithAsyncContextManager() + mock_instance = MagicMock(instance) + with self.assertRaises(TypeError): + run(raise_in(mock_instance)) + + +class AsyncIteratorTest(unittest.TestCase): + class WithAsyncIterator(object): + def __init__(self): + self.items = ["foo", "NormalFoo", "baz"] + + def __aiter__(self): pass + + async def __anext__(self): pass + + def test_aiter_set_return_value(self): + mock_iter = AsyncMock(name="tester") + mock_iter.__aiter__.return_value = [1, 2, 3] + async def main(): + return [i async for i in mock_iter] + result = run(main()) + self.assertEqual(result, [1, 2, 3]) + + def test_mock_aiter_and_anext_asyncmock(self): + def inner_test(mock_type): + instance = self.WithAsyncIterator() + mock_instance = mock_type(instance) + # Check that the mock and the real thing bahave the same + # __aiter__ is not actually async, so not a coroutinefunction + self.assertFalse(iscoroutinefunction(instance.__aiter__)) + self.assertFalse(iscoroutinefunction(mock_instance.__aiter__)) + # __anext__ is async + self.assertTrue(iscoroutinefunction(instance.__anext__)) + self.assertTrue(iscoroutinefunction(mock_instance.__anext__)) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"test aiter and anext corourtine with {mock_type}"): + inner_test(mock_type) + + + def test_mock_async_for(self): + async def iterate(iterator): + accumulator = [] + async for item in iterator: + accumulator.append(item) + + return accumulator + + expected = ["FOO", "BAR", "BAZ"] + def test_default(mock_type): + mock_instance = mock_type(self.WithAsyncIterator()) + self.assertEqual(run(iterate(mock_instance)), []) + + + def test_set_return_value(mock_type): + mock_instance = mock_type(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = expected[:] + self.assertEqual(run(iterate(mock_instance)), expected) + + def test_set_return_value_iter(mock_type): + mock_instance = mock_type(self.WithAsyncIterator()) + mock_instance.__aiter__.return_value = iter(expected[:]) + self.assertEqual(run(iterate(mock_instance)), expected) + + for mock_type in [AsyncMock, MagicMock]: + with self.subTest(f"default value with {mock_type}"): + test_default(mock_type) + + with self.subTest(f"set return_value with {mock_type}"): + test_set_return_value(mock_type) + + with self.subTest(f"set return_value iterator with {mock_type}"): + test_set_return_value_iter(mock_type) + + +class AsyncMockAssert(unittest.TestCase): + def setUp(self): + self.mock = AsyncMock() + + async def _runnable_test(self, *args, **kwargs): + await self.mock(*args, **kwargs) + + async def _await_coroutine(self, coroutine): + return await coroutine + + def test_assert_called_but_not_awaited(self): + mock = AsyncMock(AsyncClass) + with assertNeverAwaited(self): + mock.async_method() + self.assertTrue(iscoroutinefunction(mock.async_method)) + mock.async_method.assert_called() + mock.async_method.assert_called_once() + mock.async_method.assert_called_once_with() + with self.assertRaises(AssertionError): + mock.assert_awaited() + with self.assertRaises(AssertionError): + mock.async_method.assert_awaited() + + def test_assert_called_then_awaited(self): + mock = AsyncMock(AsyncClass) + mock_coroutine = mock.async_method() + mock.async_method.assert_called() + mock.async_method.assert_called_once() + mock.async_method.assert_called_once_with() + with self.assertRaises(AssertionError): + mock.async_method.assert_awaited() + + run(self._await_coroutine(mock_coroutine)) + # Assert we haven't re-called the function + mock.async_method.assert_called_once() + mock.async_method.assert_awaited() + mock.async_method.assert_awaited_once() + mock.async_method.assert_awaited_once_with() + + def test_assert_called_and_awaited_at_same_time(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + + with self.assertRaises(AssertionError): + self.mock.assert_called() + + run(self._runnable_test()) + self.mock.assert_called_once() + self.mock.assert_awaited_once() + + def test_assert_called_twice_and_awaited_once(self): + mock = AsyncMock(AsyncClass) + coroutine = mock.async_method() + # The first call will be awaited so no warning there + # But this call will never get awaited, so it will warn here + with assertNeverAwaited(self): + mock.async_method() + with self.assertRaises(AssertionError): + mock.async_method.assert_awaited() + mock.async_method.assert_called() + run(self._await_coroutine(coroutine)) + mock.async_method.assert_awaited() + mock.async_method.assert_awaited_once() + + def test_assert_called_once_and_awaited_twice(self): + mock = AsyncMock(AsyncClass) + coroutine = mock.async_method() + mock.async_method.assert_called_once() + run(self._await_coroutine(coroutine)) + with self.assertRaises(RuntimeError): + # Cannot reuse already awaited coroutine + run(self._await_coroutine(coroutine)) + mock.async_method.assert_awaited() + + def test_assert_awaited_but_not_called(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + with self.assertRaises(AssertionError): + self.mock.assert_called() + with self.assertRaises(TypeError): + # You cannot await an AsyncMock, it must be a coroutine + run(self._await_coroutine(self.mock)) + + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + with self.assertRaises(AssertionError): + self.mock.assert_called() + + def test_assert_has_calls_not_awaits(self): + kalls = [call('foo')] + with assertNeverAwaited(self): + self.mock('foo') + self.mock.assert_has_calls(kalls) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(kalls) + + def test_assert_has_mock_calls_on_async_mock_no_spec(self): + with assertNeverAwaited(self): + self.mock() + kalls_empty = [('', (), {})] + self.assertEqual(self.mock.mock_calls, kalls_empty) + + with assertNeverAwaited(self): + self.mock('foo') + with assertNeverAwaited(self): + self.mock('baz') + mock_kalls = ([call(), call('foo'), call('baz')]) + self.assertEqual(self.mock.mock_calls, mock_kalls) + + def test_assert_has_mock_calls_on_async_mock_with_spec(self): + a_class_mock = AsyncMock(AsyncClass) + with assertNeverAwaited(self): + a_class_mock.async_method() + kalls_empty = [('', (), {})] + self.assertEqual(a_class_mock.async_method.mock_calls, kalls_empty) + self.assertEqual(a_class_mock.mock_calls, [call.async_method()]) + + with assertNeverAwaited(self): + a_class_mock.async_method(1, 2, 3, a=4, b=5) + method_kalls = [call(), call(1, 2, 3, a=4, b=5)] + mock_kalls = [call.async_method(), call.async_method(1, 2, 3, a=4, b=5)] + self.assertEqual(a_class_mock.async_method.mock_calls, method_kalls) + self.assertEqual(a_class_mock.mock_calls, mock_kalls) + + def test_async_method_calls_recorded(self): + with assertNeverAwaited(self): + self.mock.something(3, fish=None) + with assertNeverAwaited(self): + self.mock.something_else.something(6, cake=sentinel.Cake) + + self.assertEqual(self.mock.method_calls, [ + ("something", (3,), {'fish': None}), + ("something_else.something", (6,), {'cake': sentinel.Cake}) + ], + "method calls not recorded correctly") + self.assertEqual(self.mock.something_else.method_calls, + [("something", (6,), {'cake': sentinel.Cake})], + "method calls not recorded correctly") + + def test_async_arg_lists(self): + def assert_attrs(mock): + names = ('call_args_list', 'method_calls', 'mock_calls') + for name in names: + attr = getattr(mock, name) + self.assertIsInstance(attr, _CallList) + self.assertIsInstance(attr, list) + self.assertEqual(attr, []) + + assert_attrs(self.mock) + with assertNeverAwaited(self): + self.mock() + with assertNeverAwaited(self): + self.mock(1, 2) + with assertNeverAwaited(self): + self.mock(a=3) + + self.mock.reset_mock() + assert_attrs(self.mock) + + a_mock = AsyncMock(AsyncClass) + with assertNeverAwaited(self): + a_mock.async_method() + with assertNeverAwaited(self): + a_mock.async_method(1, a=3) + + a_mock.reset_mock() + assert_attrs(a_mock) + + def test_assert_awaited(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited() + + run(self._runnable_test()) + self.mock.assert_awaited() + + def test_assert_awaited_once(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + run(self._runnable_test()) + self.mock.assert_awaited_once() + + run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once() + + def test_assert_awaited_with(self): + msg = 'Not awaited' + with self.assertRaisesRegex(AssertionError, msg): + self.mock.assert_awaited_with('foo') + + run(self._runnable_test()) + msg = 'expected await not found' + with self.assertRaisesRegex(AssertionError, msg): + self.mock.assert_awaited_with('foo') + + run(self._runnable_test('foo')) + self.mock.assert_awaited_with('foo') + + run(self._runnable_test('SomethingElse')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_with('foo') + + def test_assert_awaited_once_with(self): + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + run(self._runnable_test('foo')) + self.mock.assert_awaited_once_with('foo') + + run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_awaited_once_with('foo') + + def test_assert_any_wait(self): + with self.assertRaises(AssertionError): + self.mock.assert_any_await('foo') + + run(self._runnable_test('baz')) + with self.assertRaises(AssertionError): + self.mock.assert_any_await('foo') + + run(self._runnable_test('foo')) + self.mock.assert_any_await('foo') + + run(self._runnable_test('SomethingElse')) + self.mock.assert_any_await('foo') + + def test_assert_has_awaits_no_order(self): + calls = [call('foo'), call('baz')] + + with self.assertRaises(AssertionError) as cm: + self.mock.assert_has_awaits(calls) + self.assertEqual(len(cm.exception.args), 1) + + run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + run(self._runnable_test('foo')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls) + + run(self._runnable_test('baz')) + self.mock.assert_has_awaits(calls) + + run(self._runnable_test('SomethingElse')) + self.mock.assert_has_awaits(calls) + + def test_awaits_asserts_with_any(self): + class Foo: + def __eq__(self, other): pass + + run(self._runnable_test(Foo(), 1)) + + self.mock.assert_has_awaits([call(ANY, 1)]) + self.mock.assert_awaited_with(ANY, 1) + self.mock.assert_any_await(ANY, 1) + + def test_awaits_asserts_with_spec_and_any(self): + class Foo: + def __eq__(self, other): pass + + mock_with_spec = AsyncMock(spec=Foo) + + async def _custom_mock_runnable_test(*args): + await mock_with_spec(*args) + + run(_custom_mock_runnable_test(Foo(), 1)) + mock_with_spec.assert_has_awaits([call(ANY, 1)]) + mock_with_spec.assert_awaited_with(ANY, 1) + mock_with_spec.assert_any_await(ANY, 1) + + def test_assert_has_awaits_ordered(self): + calls = [call('foo'), call('baz')] + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + run(self._runnable_test('baz')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + run(self._runnable_test('bamf')) + with self.assertRaises(AssertionError): + self.mock.assert_has_awaits(calls, any_order=True) + + run(self._runnable_test('foo')) + self.mock.assert_has_awaits(calls, any_order=True) + + run(self._runnable_test('qux')) + self.mock.assert_has_awaits(calls, any_order=True) + + def test_assert_not_awaited(self): + self.mock.assert_not_awaited() + + run(self._runnable_test()) + with self.assertRaises(AssertionError): + self.mock.assert_not_awaited() + + def test_assert_has_awaits_not_matching_spec_error(self): + async def f(x=None): pass + + self.mock = AsyncMock(spec=f) + run(self._runnable_test(1)) + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape('Awaits not found.\n' + 'Expected: [call()]\n' + 'Actual: [call(1)]'))) as cm: + self.mock.assert_has_awaits([call()]) + self.assertIsNone(cm.exception.__cause__) + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape( + 'Error processing expected awaits.\n' + "Errors: [None, TypeError('too many positional " + "arguments')]\n" + 'Expected: [call(), call(1, 2)]\n' + 'Actual: [call(1)]'))) as cm: + self.mock.assert_has_awaits([call(), call(1, 2)]) + self.assertIsInstance(cm.exception.__cause__, TypeError) diff --git a/Lib/test/test_unittest/testmock/testcallable.py b/Lib/test/test_unittest/testmock/testcallable.py new file mode 100644 index 0000000..ca88511 --- /dev/null +++ b/Lib/test/test_unittest/testmock/testcallable.py @@ -0,0 +1,150 @@ +# Copyright (C) 2007-2012 Michael Foord & the mock team +# E-mail: fuzzyman AT voidspace DOT org DOT uk +# http://www.voidspace.org.uk/python/mock/ + +import unittest +from test.test_unittest.testmock.support import is_instance, X, SomeClass + +from unittest.mock import ( + Mock, MagicMock, NonCallableMagicMock, + NonCallableMock, patch, create_autospec, + CallableMixin +) + + + +class TestCallable(unittest.TestCase): + + def assertNotCallable(self, mock): + self.assertTrue(is_instance(mock, NonCallableMagicMock)) + self.assertFalse(is_instance(mock, CallableMixin)) + + + def test_non_callable(self): + for mock in NonCallableMagicMock(), NonCallableMock(): + self.assertRaises(TypeError, mock) + self.assertFalse(hasattr(mock, '__call__')) + self.assertIn(mock.__class__.__name__, repr(mock)) + + + def test_hierarchy(self): + self.assertTrue(issubclass(MagicMock, Mock)) + self.assertTrue(issubclass(NonCallableMagicMock, NonCallableMock)) + + + def test_attributes(self): + one = NonCallableMock() + self.assertTrue(issubclass(type(one.one), Mock)) + + two = NonCallableMagicMock() + self.assertTrue(issubclass(type(two.two), MagicMock)) + + + def test_subclasses(self): + class MockSub(Mock): + pass + + one = MockSub() + self.assertTrue(issubclass(type(one.one), MockSub)) + + class MagicSub(MagicMock): + pass + + two = MagicSub() + self.assertTrue(issubclass(type(two.two), MagicSub)) + + + def test_patch_spec(self): + patcher = patch('%s.X' % __name__, spec=True) + mock = patcher.start() + self.addCleanup(patcher.stop) + + instance = mock() + mock.assert_called_once_with() + + self.assertNotCallable(instance) + self.assertRaises(TypeError, instance) + + + def test_patch_spec_set(self): + patcher = patch('%s.X' % __name__, spec_set=True) + mock = patcher.start() + self.addCleanup(patcher.stop) + + instance = mock() + mock.assert_called_once_with() + + self.assertNotCallable(instance) + self.assertRaises(TypeError, instance) + + + def test_patch_spec_instance(self): + patcher = patch('%s.X' % __name__, spec=X()) + mock = patcher.start() + self.addCleanup(patcher.stop) + + self.assertNotCallable(mock) + self.assertRaises(TypeError, mock) + + + def test_patch_spec_set_instance(self): + patcher = patch('%s.X' % __name__, spec_set=X()) + mock = patcher.start() + self.addCleanup(patcher.stop) + + self.assertNotCallable(mock) + self.assertRaises(TypeError, mock) + + + def test_patch_spec_callable_class(self): + class CallableX(X): + def __call__(self): pass + + class Sub(CallableX): + pass + + class Multi(SomeClass, Sub): + pass + + for arg in 'spec', 'spec_set': + for Klass in CallableX, Sub, Multi: + with patch('%s.X' % __name__, **{arg: Klass}) as mock: + instance = mock() + mock.assert_called_once_with() + + self.assertTrue(is_instance(instance, MagicMock)) + # inherited spec + self.assertRaises(AttributeError, getattr, instance, + 'foobarbaz') + + result = instance() + # instance is callable, result has no spec + instance.assert_called_once_with() + + result(3, 2, 1) + result.assert_called_once_with(3, 2, 1) + result.foo(3, 2, 1) + result.foo.assert_called_once_with(3, 2, 1) + + + def test_create_autospec(self): + mock = create_autospec(X) + instance = mock() + self.assertRaises(TypeError, instance) + + mock = create_autospec(X()) + self.assertRaises(TypeError, mock) + + + def test_create_autospec_instance(self): + mock = create_autospec(SomeClass, instance=True) + + self.assertRaises(TypeError, mock) + mock.wibble() + mock.wibble.assert_called_once_with() + + self.assertRaises(TypeError, mock.wibble, 'some', 'args') + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/testmock/testhelpers.py b/Lib/test/test_unittest/testmock/testhelpers.py new file mode 100644 index 0000000..9e7ec5d --- /dev/null +++ b/Lib/test/test_unittest/testmock/testhelpers.py @@ -0,0 +1,1127 @@ +import inspect +import time +import types +import unittest + +from unittest.mock import ( + call, _Call, create_autospec, MagicMock, + Mock, ANY, _CallList, patch, PropertyMock, _callable +) + +from datetime import datetime +from functools import partial + +class SomeClass(object): + def one(self, a, b): pass + def two(self): pass + def three(self, a=None): pass + + + +class AnyTest(unittest.TestCase): + + def test_any(self): + self.assertEqual(ANY, object()) + + mock = Mock() + mock(ANY) + mock.assert_called_with(ANY) + + mock = Mock() + mock(foo=ANY) + mock.assert_called_with(foo=ANY) + + def test_repr(self): + self.assertEqual(repr(ANY), '<ANY>') + self.assertEqual(str(ANY), '<ANY>') + + + def test_any_and_datetime(self): + mock = Mock() + mock(datetime.now(), foo=datetime.now()) + + mock.assert_called_with(ANY, foo=ANY) + + + def test_any_mock_calls_comparison_order(self): + mock = Mock() + class Foo(object): + def __eq__(self, other): pass + def __ne__(self, other): pass + + for d in datetime.now(), Foo(): + mock.reset_mock() + + mock(d, foo=d, bar=d) + mock.method(d, zinga=d, alpha=d) + mock().method(a1=d, z99=d) + + expected = [ + call(ANY, foo=ANY, bar=ANY), + call.method(ANY, zinga=ANY, alpha=ANY), + call(), call().method(a1=ANY, z99=ANY) + ] + self.assertEqual(expected, mock.mock_calls) + self.assertEqual(mock.mock_calls, expected) + + def test_any_no_spec(self): + # This is a regression test for bpo-37555 + class Foo: + def __eq__(self, other): pass + + mock = Mock() + mock(Foo(), 1) + mock.assert_has_calls([call(ANY, 1)]) + mock.assert_called_with(ANY, 1) + mock.assert_any_call(ANY, 1) + + def test_any_and_spec_set(self): + # This is a regression test for bpo-37555 + class Foo: + def __eq__(self, other): pass + + mock = Mock(spec=Foo) + + mock(Foo(), 1) + mock.assert_has_calls([call(ANY, 1)]) + mock.assert_called_with(ANY, 1) + mock.assert_any_call(ANY, 1) + +class CallTest(unittest.TestCase): + + def test_call_with_call(self): + kall = _Call() + self.assertEqual(kall, _Call()) + self.assertEqual(kall, _Call(('',))) + self.assertEqual(kall, _Call(((),))) + self.assertEqual(kall, _Call(({},))) + self.assertEqual(kall, _Call(('', ()))) + self.assertEqual(kall, _Call(('', {}))) + self.assertEqual(kall, _Call(('', (), {}))) + self.assertEqual(kall, _Call(('foo',))) + self.assertEqual(kall, _Call(('bar', ()))) + self.assertEqual(kall, _Call(('baz', {}))) + self.assertEqual(kall, _Call(('spam', (), {}))) + + kall = _Call(((1, 2, 3),)) + self.assertEqual(kall, _Call(((1, 2, 3),))) + self.assertEqual(kall, _Call(('', (1, 2, 3)))) + self.assertEqual(kall, _Call(((1, 2, 3), {}))) + self.assertEqual(kall, _Call(('', (1, 2, 3), {}))) + + kall = _Call(((1, 2, 4),)) + self.assertNotEqual(kall, _Call(('', (1, 2, 3)))) + self.assertNotEqual(kall, _Call(('', (1, 2, 3), {}))) + + kall = _Call(('foo', (1, 2, 4),)) + self.assertNotEqual(kall, _Call(('', (1, 2, 4)))) + self.assertNotEqual(kall, _Call(('', (1, 2, 4), {}))) + self.assertNotEqual(kall, _Call(('bar', (1, 2, 4)))) + self.assertNotEqual(kall, _Call(('bar', (1, 2, 4), {}))) + + kall = _Call(({'a': 3},)) + self.assertEqual(kall, _Call(('', (), {'a': 3}))) + self.assertEqual(kall, _Call(('', {'a': 3}))) + self.assertEqual(kall, _Call(((), {'a': 3}))) + self.assertEqual(kall, _Call(({'a': 3},))) + + + def test_empty__Call(self): + args = _Call() + + self.assertEqual(args, ()) + self.assertEqual(args, ('foo',)) + self.assertEqual(args, ((),)) + self.assertEqual(args, ('foo', ())) + self.assertEqual(args, ('foo',(), {})) + self.assertEqual(args, ('foo', {})) + self.assertEqual(args, ({},)) + + + def test_named_empty_call(self): + args = _Call(('foo', (), {})) + + self.assertEqual(args, ('foo',)) + self.assertEqual(args, ('foo', ())) + self.assertEqual(args, ('foo',(), {})) + self.assertEqual(args, ('foo', {})) + + self.assertNotEqual(args, ((),)) + self.assertNotEqual(args, ()) + self.assertNotEqual(args, ({},)) + self.assertNotEqual(args, ('bar',)) + self.assertNotEqual(args, ('bar', ())) + self.assertNotEqual(args, ('bar', {})) + + + def test_call_with_args(self): + args = _Call(((1, 2, 3), {})) + + self.assertEqual(args, ((1, 2, 3),)) + self.assertEqual(args, ('foo', (1, 2, 3))) + self.assertEqual(args, ('foo', (1, 2, 3), {})) + self.assertEqual(args, ((1, 2, 3), {})) + self.assertEqual(args.args, (1, 2, 3)) + self.assertEqual(args.kwargs, {}) + + + def test_named_call_with_args(self): + args = _Call(('foo', (1, 2, 3), {})) + + self.assertEqual(args, ('foo', (1, 2, 3))) + self.assertEqual(args, ('foo', (1, 2, 3), {})) + self.assertEqual(args.args, (1, 2, 3)) + self.assertEqual(args.kwargs, {}) + + self.assertNotEqual(args, ((1, 2, 3),)) + self.assertNotEqual(args, ((1, 2, 3), {})) + + + def test_call_with_kwargs(self): + args = _Call(((), dict(a=3, b=4))) + + self.assertEqual(args, (dict(a=3, b=4),)) + self.assertEqual(args, ('foo', dict(a=3, b=4))) + self.assertEqual(args, ('foo', (), dict(a=3, b=4))) + self.assertEqual(args, ((), dict(a=3, b=4))) + self.assertEqual(args.args, ()) + self.assertEqual(args.kwargs, dict(a=3, b=4)) + + + def test_named_call_with_kwargs(self): + args = _Call(('foo', (), dict(a=3, b=4))) + + self.assertEqual(args, ('foo', dict(a=3, b=4))) + self.assertEqual(args, ('foo', (), dict(a=3, b=4))) + self.assertEqual(args.args, ()) + self.assertEqual(args.kwargs, dict(a=3, b=4)) + + self.assertNotEqual(args, (dict(a=3, b=4),)) + self.assertNotEqual(args, ((), dict(a=3, b=4))) + + + def test_call_with_args_call_empty_name(self): + args = _Call(((1, 2, 3), {})) + + self.assertEqual(args, call(1, 2, 3)) + self.assertEqual(call(1, 2, 3), args) + self.assertIn(call(1, 2, 3), [args]) + + + def test_call_ne(self): + self.assertNotEqual(_Call(((1, 2, 3),)), call(1, 2)) + self.assertFalse(_Call(((1, 2, 3),)) != call(1, 2, 3)) + self.assertTrue(_Call(((1, 2), {})) != call(1, 2, 3)) + + + def test_call_non_tuples(self): + kall = _Call(((1, 2, 3),)) + for value in 1, None, self, int: + self.assertNotEqual(kall, value) + self.assertFalse(kall == value) + + + def test_repr(self): + self.assertEqual(repr(_Call()), 'call()') + self.assertEqual(repr(_Call(('foo',))), 'call.foo()') + + self.assertEqual(repr(_Call(((1, 2, 3), {'a': 'b'}))), + "call(1, 2, 3, a='b')") + self.assertEqual(repr(_Call(('bar', (1, 2, 3), {'a': 'b'}))), + "call.bar(1, 2, 3, a='b')") + + self.assertEqual(repr(call), 'call') + self.assertEqual(str(call), 'call') + + self.assertEqual(repr(call()), 'call()') + self.assertEqual(repr(call(1)), 'call(1)') + self.assertEqual(repr(call(zz='thing')), "call(zz='thing')") + + self.assertEqual(repr(call().foo), 'call().foo') + self.assertEqual(repr(call(1).foo.bar(a=3).bing), + 'call().foo.bar().bing') + self.assertEqual( + repr(call().foo(1, 2, a=3)), + "call().foo(1, 2, a=3)" + ) + self.assertEqual(repr(call()()), "call()()") + self.assertEqual(repr(call(1)(2)), "call()(2)") + self.assertEqual( + repr(call()().bar().baz.beep(1)), + "call()().bar().baz.beep(1)" + ) + + + def test_call(self): + self.assertEqual(call(), ('', (), {})) + self.assertEqual(call('foo', 'bar', one=3, two=4), + ('', ('foo', 'bar'), {'one': 3, 'two': 4})) + + mock = Mock() + mock(1, 2, 3) + mock(a=3, b=6) + self.assertEqual(mock.call_args_list, + [call(1, 2, 3), call(a=3, b=6)]) + + def test_attribute_call(self): + self.assertEqual(call.foo(1), ('foo', (1,), {})) + self.assertEqual(call.bar.baz(fish='eggs'), + ('bar.baz', (), {'fish': 'eggs'})) + + mock = Mock() + mock.foo(1, 2 ,3) + mock.bar.baz(a=3, b=6) + self.assertEqual(mock.method_calls, + [call.foo(1, 2, 3), call.bar.baz(a=3, b=6)]) + + + def test_extended_call(self): + result = call(1).foo(2).bar(3, a=4) + self.assertEqual(result, ('().foo().bar', (3,), dict(a=4))) + + mock = MagicMock() + mock(1, 2, a=3, b=4) + self.assertEqual(mock.call_args, call(1, 2, a=3, b=4)) + self.assertNotEqual(mock.call_args, call(1, 2, 3)) + + self.assertEqual(mock.call_args_list, [call(1, 2, a=3, b=4)]) + self.assertEqual(mock.mock_calls, [call(1, 2, a=3, b=4)]) + + mock = MagicMock() + mock.foo(1).bar()().baz.beep(a=6) + + last_call = call.foo(1).bar()().baz.beep(a=6) + self.assertEqual(mock.mock_calls[-1], last_call) + self.assertEqual(mock.mock_calls, last_call.call_list()) + + + def test_extended_not_equal(self): + a = call(x=1).foo + b = call(x=2).foo + self.assertEqual(a, a) + self.assertEqual(b, b) + self.assertNotEqual(a, b) + + + def test_nested_calls_not_equal(self): + a = call(x=1).foo().bar + b = call(x=2).foo().bar + self.assertEqual(a, a) + self.assertEqual(b, b) + self.assertNotEqual(a, b) + + + def test_call_list(self): + mock = MagicMock() + mock(1) + self.assertEqual(call(1).call_list(), mock.mock_calls) + + mock = MagicMock() + mock(1).method(2) + self.assertEqual(call(1).method(2).call_list(), + mock.mock_calls) + + mock = MagicMock() + mock(1).method(2)(3) + self.assertEqual(call(1).method(2)(3).call_list(), + mock.mock_calls) + + mock = MagicMock() + int(mock(1).method(2)(3).foo.bar.baz(4)(5)) + kall = call(1).method(2)(3).foo.bar.baz(4)(5).__int__() + self.assertEqual(kall.call_list(), mock.mock_calls) + + + def test_call_any(self): + self.assertEqual(call, ANY) + + m = MagicMock() + int(m) + self.assertEqual(m.mock_calls, [ANY]) + self.assertEqual([ANY], m.mock_calls) + + + def test_two_args_call(self): + args = _Call(((1, 2), {'a': 3}), two=True) + self.assertEqual(len(args), 2) + self.assertEqual(args[0], (1, 2)) + self.assertEqual(args[1], {'a': 3}) + + other_args = _Call(((1, 2), {'a': 3})) + self.assertEqual(args, other_args) + + def test_call_with_name(self): + self.assertEqual(_Call((), 'foo')[0], 'foo') + self.assertEqual(_Call((('bar', 'barz'),),)[0], '') + self.assertEqual(_Call((('bar', 'barz'), {'hello': 'world'}),)[0], '') + + def test_dunder_call(self): + m = MagicMock() + m().foo()['bar']() + self.assertEqual( + m.mock_calls, + [call(), call().foo(), call().foo().__getitem__('bar'), call().foo().__getitem__()()] + ) + m = MagicMock() + m().foo()['bar'] = 1 + self.assertEqual( + m.mock_calls, + [call(), call().foo(), call().foo().__setitem__('bar', 1)] + ) + m = MagicMock() + iter(m().foo()) + self.assertEqual( + m.mock_calls, + [call(), call().foo(), call().foo().__iter__()] + ) + + +class SpecSignatureTest(unittest.TestCase): + + def _check_someclass_mock(self, mock): + self.assertRaises(AttributeError, getattr, mock, 'foo') + mock.one(1, 2) + mock.one.assert_called_with(1, 2) + self.assertRaises(AssertionError, + mock.one.assert_called_with, 3, 4) + self.assertRaises(TypeError, mock.one, 1) + + mock.two() + mock.two.assert_called_with() + self.assertRaises(AssertionError, + mock.two.assert_called_with, 3) + self.assertRaises(TypeError, mock.two, 1) + + mock.three() + mock.three.assert_called_with() + self.assertRaises(AssertionError, + mock.three.assert_called_with, 3) + self.assertRaises(TypeError, mock.three, 3, 2) + + mock.three(1) + mock.three.assert_called_with(1) + + mock.three(a=1) + mock.three.assert_called_with(a=1) + + + def test_basic(self): + mock = create_autospec(SomeClass) + self._check_someclass_mock(mock) + mock = create_autospec(SomeClass()) + self._check_someclass_mock(mock) + + + def test_create_autospec_return_value(self): + def f(): pass + mock = create_autospec(f, return_value='foo') + self.assertEqual(mock(), 'foo') + + class Foo(object): + pass + + mock = create_autospec(Foo, return_value='foo') + self.assertEqual(mock(), 'foo') + + + def test_autospec_reset_mock(self): + m = create_autospec(int) + int(m) + m.reset_mock() + self.assertEqual(m.__int__.call_count, 0) + + + def test_mocking_unbound_methods(self): + class Foo(object): + def foo(self, foo): pass + p = patch.object(Foo, 'foo') + mock_foo = p.start() + Foo().foo(1) + + mock_foo.assert_called_with(1) + + + def test_create_autospec_keyword_arguments(self): + class Foo(object): + a = 3 + m = create_autospec(Foo, a='3') + self.assertEqual(m.a, '3') + + + def test_create_autospec_keyword_only_arguments(self): + def foo(a, *, b=None): pass + + m = create_autospec(foo) + m(1) + m.assert_called_with(1) + self.assertRaises(TypeError, m, 1, 2) + + m(2, b=3) + m.assert_called_with(2, b=3) + + + def test_function_as_instance_attribute(self): + obj = SomeClass() + def f(a): pass + obj.f = f + + mock = create_autospec(obj) + mock.f('bing') + mock.f.assert_called_with('bing') + + + def test_spec_as_list(self): + # because spec as a list of strings in the mock constructor means + # something very different we treat a list instance as the type. + mock = create_autospec([]) + mock.append('foo') + mock.append.assert_called_with('foo') + + self.assertRaises(AttributeError, getattr, mock, 'foo') + + class Foo(object): + foo = [] + + mock = create_autospec(Foo) + mock.foo.append(3) + mock.foo.append.assert_called_with(3) + self.assertRaises(AttributeError, getattr, mock.foo, 'foo') + + + def test_attributes(self): + class Sub(SomeClass): + attr = SomeClass() + + sub_mock = create_autospec(Sub) + + for mock in (sub_mock, sub_mock.attr): + self._check_someclass_mock(mock) + + + def test_spec_has_descriptor_returning_function(self): + + class CrazyDescriptor(object): + + def __get__(self, obj, type_): + if obj is None: + return lambda x: None + + class MyClass(object): + + some_attr = CrazyDescriptor() + + mock = create_autospec(MyClass) + mock.some_attr(1) + with self.assertRaises(TypeError): + mock.some_attr() + with self.assertRaises(TypeError): + mock.some_attr(1, 2) + + + def test_spec_has_function_not_in_bases(self): + + class CrazyClass(object): + + def __dir__(self): + return super(CrazyClass, self).__dir__()+['crazy'] + + def __getattr__(self, item): + if item == 'crazy': + return lambda x: x + raise AttributeError(item) + + inst = CrazyClass() + with self.assertRaises(AttributeError): + inst.other + self.assertEqual(inst.crazy(42), 42) + + mock = create_autospec(inst) + mock.crazy(42) + with self.assertRaises(TypeError): + mock.crazy() + with self.assertRaises(TypeError): + mock.crazy(1, 2) + + + def test_builtin_functions_types(self): + # we could replace builtin functions / methods with a function + # with *args / **kwargs signature. Using the builtin method type + # as a spec seems to work fairly well though. + class BuiltinSubclass(list): + def bar(self, arg): pass + sorted = sorted + attr = {} + + mock = create_autospec(BuiltinSubclass) + mock.append(3) + mock.append.assert_called_with(3) + self.assertRaises(AttributeError, getattr, mock.append, 'foo') + + mock.bar('foo') + mock.bar.assert_called_with('foo') + self.assertRaises(TypeError, mock.bar, 'foo', 'bar') + self.assertRaises(AttributeError, getattr, mock.bar, 'foo') + + mock.sorted([1, 2]) + mock.sorted.assert_called_with([1, 2]) + self.assertRaises(AttributeError, getattr, mock.sorted, 'foo') + + mock.attr.pop(3) + mock.attr.pop.assert_called_with(3) + self.assertRaises(AttributeError, getattr, mock.attr, 'foo') + + + def test_method_calls(self): + class Sub(SomeClass): + attr = SomeClass() + + mock = create_autospec(Sub) + mock.one(1, 2) + mock.two() + mock.three(3) + + expected = [call.one(1, 2), call.two(), call.three(3)] + self.assertEqual(mock.method_calls, expected) + + mock.attr.one(1, 2) + mock.attr.two() + mock.attr.three(3) + + expected.extend( + [call.attr.one(1, 2), call.attr.two(), call.attr.three(3)] + ) + self.assertEqual(mock.method_calls, expected) + + + def test_magic_methods(self): + class BuiltinSubclass(list): + attr = {} + + mock = create_autospec(BuiltinSubclass) + self.assertEqual(list(mock), []) + self.assertRaises(TypeError, int, mock) + self.assertRaises(TypeError, int, mock.attr) + self.assertEqual(list(mock), []) + + self.assertIsInstance(mock['foo'], MagicMock) + self.assertIsInstance(mock.attr['foo'], MagicMock) + + + def test_spec_set(self): + class Sub(SomeClass): + attr = SomeClass() + + for spec in (Sub, Sub()): + mock = create_autospec(spec, spec_set=True) + self._check_someclass_mock(mock) + + self.assertRaises(AttributeError, setattr, mock, 'foo', 'bar') + self.assertRaises(AttributeError, setattr, mock.attr, 'foo', 'bar') + + + def test_descriptors(self): + class Foo(object): + @classmethod + def f(cls, a, b): pass + @staticmethod + def g(a, b): pass + + class Bar(Foo): pass + + class Baz(SomeClass, Bar): pass + + for spec in (Foo, Foo(), Bar, Bar(), Baz, Baz()): + mock = create_autospec(spec) + mock.f(1, 2) + mock.f.assert_called_once_with(1, 2) + + mock.g(3, 4) + mock.g.assert_called_once_with(3, 4) + + + def test_recursive(self): + class A(object): + def a(self): pass + foo = 'foo bar baz' + bar = foo + + A.B = A + mock = create_autospec(A) + + mock() + self.assertFalse(mock.B.called) + + mock.a() + mock.B.a() + self.assertEqual(mock.method_calls, [call.a(), call.B.a()]) + + self.assertIs(A.foo, A.bar) + self.assertIsNot(mock.foo, mock.bar) + mock.foo.lower() + self.assertRaises(AssertionError, mock.bar.lower.assert_called_with) + + + def test_spec_inheritance_for_classes(self): + class Foo(object): + def a(self, x): pass + class Bar(object): + def f(self, y): pass + + class_mock = create_autospec(Foo) + + self.assertIsNot(class_mock, class_mock()) + + for this_mock in class_mock, class_mock(): + this_mock.a(x=5) + this_mock.a.assert_called_with(x=5) + this_mock.a.assert_called_with(5) + self.assertRaises(TypeError, this_mock.a, 'foo', 'bar') + self.assertRaises(AttributeError, getattr, this_mock, 'b') + + instance_mock = create_autospec(Foo()) + instance_mock.a(5) + instance_mock.a.assert_called_with(5) + instance_mock.a.assert_called_with(x=5) + self.assertRaises(TypeError, instance_mock.a, 'foo', 'bar') + self.assertRaises(AttributeError, getattr, instance_mock, 'b') + + # The return value isn't isn't callable + self.assertRaises(TypeError, instance_mock) + + instance_mock.Bar.f(6) + instance_mock.Bar.f.assert_called_with(6) + instance_mock.Bar.f.assert_called_with(y=6) + self.assertRaises(AttributeError, getattr, instance_mock.Bar, 'g') + + instance_mock.Bar().f(6) + instance_mock.Bar().f.assert_called_with(6) + instance_mock.Bar().f.assert_called_with(y=6) + self.assertRaises(AttributeError, getattr, instance_mock.Bar(), 'g') + + + def test_inherit(self): + class Foo(object): + a = 3 + + Foo.Foo = Foo + + # class + mock = create_autospec(Foo) + instance = mock() + self.assertRaises(AttributeError, getattr, instance, 'b') + + attr_instance = mock.Foo() + self.assertRaises(AttributeError, getattr, attr_instance, 'b') + + # instance + mock = create_autospec(Foo()) + self.assertRaises(AttributeError, getattr, mock, 'b') + self.assertRaises(TypeError, mock) + + # attribute instance + call_result = mock.Foo() + self.assertRaises(AttributeError, getattr, call_result, 'b') + + + def test_builtins(self): + # used to fail with infinite recursion + create_autospec(1) + + create_autospec(int) + create_autospec('foo') + create_autospec(str) + create_autospec({}) + create_autospec(dict) + create_autospec([]) + create_autospec(list) + create_autospec(set()) + create_autospec(set) + create_autospec(1.0) + create_autospec(float) + create_autospec(1j) + create_autospec(complex) + create_autospec(False) + create_autospec(True) + + + def test_function(self): + def f(a, b): pass + + mock = create_autospec(f) + self.assertRaises(TypeError, mock) + mock(1, 2) + mock.assert_called_with(1, 2) + mock.assert_called_with(1, b=2) + mock.assert_called_with(a=1, b=2) + + f.f = f + mock = create_autospec(f) + self.assertRaises(TypeError, mock.f) + mock.f(3, 4) + mock.f.assert_called_with(3, 4) + mock.f.assert_called_with(a=3, b=4) + + + def test_skip_attributeerrors(self): + class Raiser(object): + def __get__(self, obj, type=None): + if obj is None: + raise AttributeError('Can only be accessed via an instance') + + class RaiserClass(object): + raiser = Raiser() + + @staticmethod + def existing(a, b): + return a + b + + self.assertEqual(RaiserClass.existing(1, 2), 3) + s = create_autospec(RaiserClass) + self.assertRaises(TypeError, lambda x: s.existing(1, 2, 3)) + self.assertEqual(s.existing(1, 2), s.existing.return_value) + self.assertRaises(AttributeError, lambda: s.nonexisting) + + # check we can fetch the raiser attribute and it has no spec + obj = s.raiser + obj.foo, obj.bar + + + def test_signature_class(self): + class Foo(object): + def __init__(self, a, b=3): pass + + mock = create_autospec(Foo) + + self.assertRaises(TypeError, mock) + mock(1) + mock.assert_called_once_with(1) + mock.assert_called_once_with(a=1) + self.assertRaises(AssertionError, mock.assert_called_once_with, 2) + + mock(4, 5) + mock.assert_called_with(4, 5) + mock.assert_called_with(a=4, b=5) + self.assertRaises(AssertionError, mock.assert_called_with, a=5, b=4) + + + def test_class_with_no_init(self): + # this used to raise an exception + # due to trying to get a signature from object.__init__ + class Foo(object): + pass + create_autospec(Foo) + + + def test_signature_callable(self): + class Callable(object): + def __init__(self, x, y): pass + def __call__(self, a): pass + + mock = create_autospec(Callable) + mock(1, 2) + mock.assert_called_once_with(1, 2) + mock.assert_called_once_with(x=1, y=2) + self.assertRaises(TypeError, mock, 'a') + + instance = mock(1, 2) + self.assertRaises(TypeError, instance) + instance(a='a') + instance.assert_called_once_with('a') + instance.assert_called_once_with(a='a') + instance('a') + instance.assert_called_with('a') + instance.assert_called_with(a='a') + + mock = create_autospec(Callable(1, 2)) + mock(a='a') + mock.assert_called_once_with(a='a') + self.assertRaises(TypeError, mock) + mock('a') + mock.assert_called_with('a') + + + def test_signature_noncallable(self): + class NonCallable(object): + def __init__(self): + pass + + mock = create_autospec(NonCallable) + instance = mock() + mock.assert_called_once_with() + self.assertRaises(TypeError, mock, 'a') + self.assertRaises(TypeError, instance) + self.assertRaises(TypeError, instance, 'a') + + mock = create_autospec(NonCallable()) + self.assertRaises(TypeError, mock) + self.assertRaises(TypeError, mock, 'a') + + + def test_create_autospec_none(self): + class Foo(object): + bar = None + + mock = create_autospec(Foo) + none = mock.bar + self.assertNotIsInstance(none, type(None)) + + none.foo() + none.foo.assert_called_once_with() + + + def test_autospec_functions_with_self_in_odd_place(self): + class Foo(object): + def f(a, self): pass + + a = create_autospec(Foo) + a.f(10) + a.f.assert_called_with(10) + a.f.assert_called_with(self=10) + a.f(self=10) + a.f.assert_called_with(10) + a.f.assert_called_with(self=10) + + + def test_autospec_data_descriptor(self): + class Descriptor(object): + def __init__(self, value): + self.value = value + + def __get__(self, obj, cls=None): + return self + + def __set__(self, obj, value): pass + + class MyProperty(property): + pass + + class Foo(object): + __slots__ = ['slot'] + + @property + def prop(self): pass + + @MyProperty + def subprop(self): pass + + desc = Descriptor(42) + + foo = create_autospec(Foo) + + def check_data_descriptor(mock_attr): + # Data descriptors don't have a spec. + self.assertIsInstance(mock_attr, MagicMock) + mock_attr(1, 2, 3) + mock_attr.abc(4, 5, 6) + mock_attr.assert_called_once_with(1, 2, 3) + mock_attr.abc.assert_called_once_with(4, 5, 6) + + # property + check_data_descriptor(foo.prop) + # property subclass + check_data_descriptor(foo.subprop) + # class __slot__ + check_data_descriptor(foo.slot) + # plain data descriptor + check_data_descriptor(foo.desc) + + + def test_autospec_on_bound_builtin_function(self): + meth = types.MethodType(time.ctime, time.time()) + self.assertIsInstance(meth(), str) + mocked = create_autospec(meth) + + # no signature, so no spec to check against + mocked() + mocked.assert_called_once_with() + mocked.reset_mock() + mocked(4, 5, 6) + mocked.assert_called_once_with(4, 5, 6) + + + def test_autospec_getattr_partial_function(self): + # bpo-32153 : getattr returning partial functions without + # __name__ should not create AttributeError in create_autospec + class Foo: + + def __getattr__(self, attribute): + return partial(lambda name: name, attribute) + + proxy = Foo() + autospec = create_autospec(proxy) + self.assertFalse(hasattr(autospec, '__name__')) + + + def test_spec_inspect_signature(self): + + def myfunc(x, y): pass + + mock = create_autospec(myfunc) + mock(1, 2) + mock(x=1, y=2) + + self.assertEqual(inspect.signature(mock), inspect.signature(myfunc)) + self.assertEqual(mock.mock_calls, [call(1, 2), call(x=1, y=2)]) + self.assertRaises(TypeError, mock, 1) + + + def test_spec_inspect_signature_annotations(self): + + def foo(a: int, b: int=10, *, c:int) -> int: + return a + b + c + + self.assertEqual(foo(1, 2 , c=3), 6) + mock = create_autospec(foo) + mock(1, 2, c=3) + mock(1, c=3) + + self.assertEqual(inspect.signature(mock), inspect.signature(foo)) + self.assertEqual(mock.mock_calls, [call(1, 2, c=3), call(1, c=3)]) + self.assertRaises(TypeError, mock, 1) + self.assertRaises(TypeError, mock, 1, 2, 3, c=4) + + + def test_spec_function_no_name(self): + func = lambda: 'nope' + mock = create_autospec(func) + self.assertEqual(mock.__name__, 'funcopy') + + + def test_spec_function_assert_has_calls(self): + def f(a): pass + mock = create_autospec(f) + mock(1) + mock.assert_has_calls([call(1)]) + with self.assertRaises(AssertionError): + mock.assert_has_calls([call(2)]) + + + def test_spec_function_assert_any_call(self): + def f(a): pass + mock = create_autospec(f) + mock(1) + mock.assert_any_call(1) + with self.assertRaises(AssertionError): + mock.assert_any_call(2) + + + def test_spec_function_reset_mock(self): + def f(a): pass + rv = Mock() + mock = create_autospec(f, return_value=rv) + mock(1)(2) + self.assertEqual(mock.mock_calls, [call(1)]) + self.assertEqual(rv.mock_calls, [call(2)]) + mock.reset_mock() + self.assertEqual(mock.mock_calls, []) + self.assertEqual(rv.mock_calls, []) + + +class TestCallList(unittest.TestCase): + + def test_args_list_contains_call_list(self): + mock = Mock() + self.assertIsInstance(mock.call_args_list, _CallList) + + mock(1, 2) + mock(a=3) + mock(3, 4) + mock(b=6) + + for kall in call(1, 2), call(a=3), call(3, 4), call(b=6): + self.assertIn(kall, mock.call_args_list) + + calls = [call(a=3), call(3, 4)] + self.assertIn(calls, mock.call_args_list) + calls = [call(1, 2), call(a=3)] + self.assertIn(calls, mock.call_args_list) + calls = [call(3, 4), call(b=6)] + self.assertIn(calls, mock.call_args_list) + calls = [call(3, 4)] + self.assertIn(calls, mock.call_args_list) + + self.assertNotIn(call('fish'), mock.call_args_list) + self.assertNotIn([call('fish')], mock.call_args_list) + + + def test_call_list_str(self): + mock = Mock() + mock(1, 2) + mock.foo(a=3) + mock.foo.bar().baz('fish', cat='dog') + + expected = ( + "[call(1, 2),\n" + " call.foo(a=3),\n" + " call.foo.bar(),\n" + " call.foo.bar().baz('fish', cat='dog')]" + ) + self.assertEqual(str(mock.mock_calls), expected) + + + def test_propertymock(self): + p = patch('%s.SomeClass.one' % __name__, new_callable=PropertyMock) + mock = p.start() + try: + SomeClass.one + mock.assert_called_once_with() + + s = SomeClass() + s.one + mock.assert_called_with() + self.assertEqual(mock.mock_calls, [call(), call()]) + + s.one = 3 + self.assertEqual(mock.mock_calls, [call(), call(), call(3)]) + finally: + p.stop() + + + def test_propertymock_returnvalue(self): + m = MagicMock() + p = PropertyMock() + type(m).foo = p + + returned = m.foo + p.assert_called_once_with() + self.assertIsInstance(returned, MagicMock) + self.assertNotIsInstance(returned, PropertyMock) + + +class TestCallablePredicate(unittest.TestCase): + + def test_type(self): + for obj in [str, bytes, int, list, tuple, SomeClass]: + self.assertTrue(_callable(obj)) + + def test_call_magic_method(self): + class Callable: + def __call__(self): pass + instance = Callable() + self.assertTrue(_callable(instance)) + + def test_staticmethod(self): + class WithStaticMethod: + @staticmethod + def staticfunc(): pass + self.assertTrue(_callable(WithStaticMethod.staticfunc)) + + def test_non_callable_staticmethod(self): + class BadStaticMethod: + not_callable = staticmethod(None) + self.assertFalse(_callable(BadStaticMethod.not_callable)) + + def test_classmethod(self): + class WithClassMethod: + @classmethod + def classfunc(cls): pass + self.assertTrue(_callable(WithClassMethod.classfunc)) + + def test_non_callable_classmethod(self): + class BadClassMethod: + not_callable = classmethod(None) + self.assertFalse(_callable(BadClassMethod.not_callable)) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/testmock/testmagicmethods.py b/Lib/test/test_unittest/testmock/testmagicmethods.py new file mode 100644 index 0000000..a4feae7 --- /dev/null +++ b/Lib/test/test_unittest/testmock/testmagicmethods.py @@ -0,0 +1,509 @@ +import math +import unittest +import os +from asyncio import iscoroutinefunction +from unittest.mock import AsyncMock, Mock, MagicMock, _magics + + + +class TestMockingMagicMethods(unittest.TestCase): + + def test_deleting_magic_methods(self): + mock = Mock() + self.assertFalse(hasattr(mock, '__getitem__')) + + mock.__getitem__ = Mock() + self.assertTrue(hasattr(mock, '__getitem__')) + + del mock.__getitem__ + self.assertFalse(hasattr(mock, '__getitem__')) + + + def test_magicmock_del(self): + mock = MagicMock() + # before using getitem + del mock.__getitem__ + self.assertRaises(TypeError, lambda: mock['foo']) + + mock = MagicMock() + # this time use it first + mock['foo'] + del mock.__getitem__ + self.assertRaises(TypeError, lambda: mock['foo']) + + + def test_magic_method_wrapping(self): + mock = Mock() + def f(self, name): + return self, 'fish' + + mock.__getitem__ = f + self.assertIsNot(mock.__getitem__, f) + self.assertEqual(mock['foo'], (mock, 'fish')) + self.assertEqual(mock.__getitem__('foo'), (mock, 'fish')) + + mock.__getitem__ = mock + self.assertIs(mock.__getitem__, mock) + + + def test_magic_methods_isolated_between_mocks(self): + mock1 = Mock() + mock2 = Mock() + + mock1.__iter__ = Mock(return_value=iter([])) + self.assertEqual(list(mock1), []) + self.assertRaises(TypeError, lambda: list(mock2)) + + + def test_repr(self): + mock = Mock() + self.assertEqual(repr(mock), "<Mock id='%s'>" % id(mock)) + mock.__repr__ = lambda s: 'foo' + self.assertEqual(repr(mock), 'foo') + + + def test_str(self): + mock = Mock() + self.assertEqual(str(mock), object.__str__(mock)) + mock.__str__ = lambda s: 'foo' + self.assertEqual(str(mock), 'foo') + + + def test_dict_methods(self): + mock = Mock() + + self.assertRaises(TypeError, lambda: mock['foo']) + def _del(): + del mock['foo'] + def _set(): + mock['foo'] = 3 + self.assertRaises(TypeError, _del) + self.assertRaises(TypeError, _set) + + _dict = {} + def getitem(s, name): + return _dict[name] + def setitem(s, name, value): + _dict[name] = value + def delitem(s, name): + del _dict[name] + + mock.__setitem__ = setitem + mock.__getitem__ = getitem + mock.__delitem__ = delitem + + self.assertRaises(KeyError, lambda: mock['foo']) + mock['foo'] = 'bar' + self.assertEqual(_dict, {'foo': 'bar'}) + self.assertEqual(mock['foo'], 'bar') + del mock['foo'] + self.assertEqual(_dict, {}) + + + def test_numeric(self): + original = mock = Mock() + mock.value = 0 + + self.assertRaises(TypeError, lambda: mock + 3) + + def add(self, other): + mock.value += other + return self + mock.__add__ = add + self.assertEqual(mock + 3, mock) + self.assertEqual(mock.value, 3) + + del mock.__add__ + def iadd(mock): + mock += 3 + self.assertRaises(TypeError, iadd, mock) + mock.__iadd__ = add + mock += 6 + self.assertEqual(mock, original) + self.assertEqual(mock.value, 9) + + self.assertRaises(TypeError, lambda: 3 + mock) + mock.__radd__ = add + self.assertEqual(7 + mock, mock) + self.assertEqual(mock.value, 16) + + def test_division(self): + original = mock = Mock() + mock.value = 32 + self.assertRaises(TypeError, lambda: mock / 2) + + def truediv(self, other): + mock.value /= other + return self + mock.__truediv__ = truediv + self.assertEqual(mock / 2, mock) + self.assertEqual(mock.value, 16) + + del mock.__truediv__ + def itruediv(mock): + mock /= 4 + self.assertRaises(TypeError, itruediv, mock) + mock.__itruediv__ = truediv + mock /= 8 + self.assertEqual(mock, original) + self.assertEqual(mock.value, 2) + + self.assertRaises(TypeError, lambda: 8 / mock) + mock.__rtruediv__ = truediv + self.assertEqual(0.5 / mock, mock) + self.assertEqual(mock.value, 4) + + def test_hash(self): + mock = Mock() + # test delegation + self.assertEqual(hash(mock), Mock.__hash__(mock)) + + def _hash(s): + return 3 + mock.__hash__ = _hash + self.assertEqual(hash(mock), 3) + + + def test_nonzero(self): + m = Mock() + self.assertTrue(bool(m)) + + m.__bool__ = lambda s: False + self.assertFalse(bool(m)) + + + def test_comparison(self): + mock = Mock() + def comp(s, o): + return True + mock.__lt__ = mock.__gt__ = mock.__le__ = mock.__ge__ = comp + self. assertTrue(mock < 3) + self. assertTrue(mock > 3) + self. assertTrue(mock <= 3) + self. assertTrue(mock >= 3) + + self.assertRaises(TypeError, lambda: MagicMock() < object()) + self.assertRaises(TypeError, lambda: object() < MagicMock()) + self.assertRaises(TypeError, lambda: MagicMock() < MagicMock()) + self.assertRaises(TypeError, lambda: MagicMock() > object()) + self.assertRaises(TypeError, lambda: object() > MagicMock()) + self.assertRaises(TypeError, lambda: MagicMock() > MagicMock()) + self.assertRaises(TypeError, lambda: MagicMock() <= object()) + self.assertRaises(TypeError, lambda: object() <= MagicMock()) + self.assertRaises(TypeError, lambda: MagicMock() <= MagicMock()) + self.assertRaises(TypeError, lambda: MagicMock() >= object()) + self.assertRaises(TypeError, lambda: object() >= MagicMock()) + self.assertRaises(TypeError, lambda: MagicMock() >= MagicMock()) + + + def test_equality(self): + for mock in Mock(), MagicMock(): + self.assertEqual(mock == mock, True) + self.assertIsInstance(mock == mock, bool) + self.assertEqual(mock != mock, False) + self.assertIsInstance(mock != mock, bool) + self.assertEqual(mock == object(), False) + self.assertEqual(mock != object(), True) + + def eq(self, other): + return other == 3 + mock.__eq__ = eq + self.assertTrue(mock == 3) + self.assertFalse(mock == 4) + + def ne(self, other): + return other == 3 + mock.__ne__ = ne + self.assertTrue(mock != 3) + self.assertFalse(mock != 4) + + mock = MagicMock() + mock.__eq__.return_value = True + self.assertIsInstance(mock == 3, bool) + self.assertEqual(mock == 3, True) + + mock.__ne__.return_value = False + self.assertIsInstance(mock != 3, bool) + self.assertEqual(mock != 3, False) + + + def test_len_contains_iter(self): + mock = Mock() + + self.assertRaises(TypeError, len, mock) + self.assertRaises(TypeError, iter, mock) + self.assertRaises(TypeError, lambda: 'foo' in mock) + + mock.__len__ = lambda s: 6 + self.assertEqual(len(mock), 6) + + mock.__contains__ = lambda s, o: o == 3 + self.assertIn(3, mock) + self.assertNotIn(6, mock) + + mock.__iter__ = lambda s: iter('foobarbaz') + self.assertEqual(list(mock), list('foobarbaz')) + + + def test_magicmock(self): + mock = MagicMock() + + mock.__iter__.return_value = iter([1, 2, 3]) + self.assertEqual(list(mock), [1, 2, 3]) + + getattr(mock, '__bool__').return_value = False + self.assertFalse(hasattr(mock, '__nonzero__')) + self.assertFalse(bool(mock)) + + for entry in _magics: + self.assertTrue(hasattr(mock, entry)) + self.assertFalse(hasattr(mock, '__imaginary__')) + + + def test_magic_mock_equality(self): + mock = MagicMock() + self.assertIsInstance(mock == object(), bool) + self.assertIsInstance(mock != object(), bool) + + self.assertEqual(mock == object(), False) + self.assertEqual(mock != object(), True) + self.assertEqual(mock == mock, True) + self.assertEqual(mock != mock, False) + + def test_asyncmock_defaults(self): + mock = AsyncMock() + self.assertEqual(int(mock), 1) + self.assertEqual(complex(mock), 1j) + self.assertEqual(float(mock), 1.0) + self.assertNotIn(object(), mock) + self.assertEqual(len(mock), 0) + self.assertEqual(list(mock), []) + self.assertEqual(hash(mock), object.__hash__(mock)) + self.assertEqual(str(mock), object.__str__(mock)) + self.assertTrue(bool(mock)) + self.assertEqual(round(mock), mock.__round__()) + self.assertEqual(math.trunc(mock), mock.__trunc__()) + self.assertEqual(math.floor(mock), mock.__floor__()) + self.assertEqual(math.ceil(mock), mock.__ceil__()) + self.assertTrue(iscoroutinefunction(mock.__aexit__)) + self.assertTrue(iscoroutinefunction(mock.__aenter__)) + self.assertIsInstance(mock.__aenter__, AsyncMock) + self.assertIsInstance(mock.__aexit__, AsyncMock) + + # in Python 3 oct and hex use __index__ + # so these tests are for __index__ in py3k + self.assertEqual(oct(mock), '0o1') + self.assertEqual(hex(mock), '0x1') + # how to test __sizeof__ ? + + def test_magicmock_defaults(self): + mock = MagicMock() + self.assertEqual(int(mock), 1) + self.assertEqual(complex(mock), 1j) + self.assertEqual(float(mock), 1.0) + self.assertNotIn(object(), mock) + self.assertEqual(len(mock), 0) + self.assertEqual(list(mock), []) + self.assertEqual(hash(mock), object.__hash__(mock)) + self.assertEqual(str(mock), object.__str__(mock)) + self.assertTrue(bool(mock)) + self.assertEqual(round(mock), mock.__round__()) + self.assertEqual(math.trunc(mock), mock.__trunc__()) + self.assertEqual(math.floor(mock), mock.__floor__()) + self.assertEqual(math.ceil(mock), mock.__ceil__()) + self.assertTrue(iscoroutinefunction(mock.__aexit__)) + self.assertTrue(iscoroutinefunction(mock.__aenter__)) + self.assertIsInstance(mock.__aenter__, AsyncMock) + self.assertIsInstance(mock.__aexit__, AsyncMock) + + # in Python 3 oct and hex use __index__ + # so these tests are for __index__ in py3k + self.assertEqual(oct(mock), '0o1') + self.assertEqual(hex(mock), '0x1') + # how to test __sizeof__ ? + + + def test_magic_methods_fspath(self): + mock = MagicMock() + expected_path = mock.__fspath__() + mock.reset_mock() + + self.assertEqual(os.fspath(mock), expected_path) + mock.__fspath__.assert_called_once() + + + def test_magic_methods_and_spec(self): + class Iterable(object): + def __iter__(self): pass + + mock = Mock(spec=Iterable) + self.assertRaises(AttributeError, lambda: mock.__iter__) + + mock.__iter__ = Mock(return_value=iter([])) + self.assertEqual(list(mock), []) + + class NonIterable(object): + pass + mock = Mock(spec=NonIterable) + self.assertRaises(AttributeError, lambda: mock.__iter__) + + def set_int(): + mock.__int__ = Mock(return_value=iter([])) + self.assertRaises(AttributeError, set_int) + + mock = MagicMock(spec=Iterable) + self.assertEqual(list(mock), []) + self.assertRaises(AttributeError, set_int) + + + def test_magic_methods_and_spec_set(self): + class Iterable(object): + def __iter__(self): pass + + mock = Mock(spec_set=Iterable) + self.assertRaises(AttributeError, lambda: mock.__iter__) + + mock.__iter__ = Mock(return_value=iter([])) + self.assertEqual(list(mock), []) + + class NonIterable(object): + pass + mock = Mock(spec_set=NonIterable) + self.assertRaises(AttributeError, lambda: mock.__iter__) + + def set_int(): + mock.__int__ = Mock(return_value=iter([])) + self.assertRaises(AttributeError, set_int) + + mock = MagicMock(spec_set=Iterable) + self.assertEqual(list(mock), []) + self.assertRaises(AttributeError, set_int) + + + def test_setting_unsupported_magic_method(self): + mock = MagicMock() + def set_setattr(): + mock.__setattr__ = lambda self, name: None + self.assertRaisesRegex(AttributeError, + "Attempting to set unsupported magic method '__setattr__'.", + set_setattr + ) + + + def test_attributes_and_return_value(self): + mock = MagicMock() + attr = mock.foo + def _get_type(obj): + # the type of every mock (or magicmock) is a custom subclass + # so the real type is the second in the mro + return type(obj).__mro__[1] + self.assertEqual(_get_type(attr), MagicMock) + + returned = mock() + self.assertEqual(_get_type(returned), MagicMock) + + + def test_magic_methods_are_magic_mocks(self): + mock = MagicMock() + self.assertIsInstance(mock.__getitem__, MagicMock) + + mock[1][2].__getitem__.return_value = 3 + self.assertEqual(mock[1][2][3], 3) + + + def test_magic_method_reset_mock(self): + mock = MagicMock() + str(mock) + self.assertTrue(mock.__str__.called) + mock.reset_mock() + self.assertFalse(mock.__str__.called) + + + def test_dir(self): + # overriding the default implementation + for mock in Mock(), MagicMock(): + def _dir(self): + return ['foo'] + mock.__dir__ = _dir + self.assertEqual(dir(mock), ['foo']) + + + def test_bound_methods(self): + m = Mock() + + # XXXX should this be an expected failure instead? + + # this seems like it should work, but is hard to do without introducing + # other api inconsistencies. Failure message could be better though. + m.__iter__ = [3].__iter__ + self.assertRaises(TypeError, iter, m) + + + def test_magic_method_type(self): + class Foo(MagicMock): + pass + + foo = Foo() + self.assertIsInstance(foo.__int__, Foo) + + + def test_descriptor_from_class(self): + m = MagicMock() + type(m).__str__.return_value = 'foo' + self.assertEqual(str(m), 'foo') + + + def test_iterable_as_iter_return_value(self): + m = MagicMock() + m.__iter__.return_value = [1, 2, 3] + self.assertEqual(list(m), [1, 2, 3]) + self.assertEqual(list(m), [1, 2, 3]) + + m.__iter__.return_value = iter([4, 5, 6]) + self.assertEqual(list(m), [4, 5, 6]) + self.assertEqual(list(m), []) + + + def test_matmul(self): + m = MagicMock() + self.assertIsInstance(m @ 1, MagicMock) + m.__matmul__.return_value = 42 + m.__rmatmul__.return_value = 666 + m.__imatmul__.return_value = 24 + self.assertEqual(m @ 1, 42) + self.assertEqual(1 @ m, 666) + m @= 24 + self.assertEqual(m, 24) + + def test_divmod_and_rdivmod(self): + m = MagicMock() + self.assertIsInstance(divmod(5, m), MagicMock) + m.__divmod__.return_value = (2, 1) + self.assertEqual(divmod(m, 2), (2, 1)) + m = MagicMock() + foo = divmod(2, m) + self.assertIsInstance(foo, MagicMock) + foo_direct = m.__divmod__(2) + self.assertIsInstance(foo_direct, MagicMock) + bar = divmod(m, 2) + self.assertIsInstance(bar, MagicMock) + bar_direct = m.__rdivmod__(2) + self.assertIsInstance(bar_direct, MagicMock) + + # http://bugs.python.org/issue23310 + # Check if you can change behaviour of magic methods in MagicMock init + def test_magic_in_initialization(self): + m = MagicMock(**{'__str__.return_value': "12"}) + self.assertEqual(str(m), "12") + + def test_changing_magic_set_in_initialization(self): + m = MagicMock(**{'__str__.return_value': "12"}) + m.__str__.return_value = "13" + self.assertEqual(str(m), "13") + m = MagicMock(**{'__str__.return_value': "12"}) + m.configure_mock(**{'__str__.return_value': "14"}) + self.assertEqual(str(m), "14") + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/testmock/testmock.py b/Lib/test/test_unittest/testmock/testmock.py new file mode 100644 index 0000000..8a92490 --- /dev/null +++ b/Lib/test/test_unittest/testmock/testmock.py @@ -0,0 +1,2262 @@ +import copy +import re +import sys +import tempfile + +from test.support import ALWAYS_EQ +import unittest +from test.test_unittest.testmock.support import is_instance +from unittest import mock +from unittest.mock import ( + call, DEFAULT, patch, sentinel, + MagicMock, Mock, NonCallableMock, + NonCallableMagicMock, AsyncMock, _Call, _CallList, + create_autospec, InvalidSpecError +) + + +class Iter(object): + def __init__(self): + self.thing = iter(['this', 'is', 'an', 'iter']) + + def __iter__(self): + return self + + def next(self): + return next(self.thing) + + __next__ = next + + +class Something(object): + def meth(self, a, b, c, d=None): pass + + @classmethod + def cmeth(cls, a, b, c, d=None): pass + + @staticmethod + def smeth(a, b, c, d=None): pass + + +class Typos(): + autospect = None + auto_spec = None + set_spec = None + + +def something(a): pass + + +class MockTest(unittest.TestCase): + + def test_all(self): + # if __all__ is badly defined then import * will raise an error + # We have to exec it because you can't import * inside a method + # in Python 3 + exec("from unittest.mock import *") + + + def test_constructor(self): + mock = Mock() + + self.assertFalse(mock.called, "called not initialised correctly") + self.assertEqual(mock.call_count, 0, + "call_count not initialised correctly") + self.assertTrue(is_instance(mock.return_value, Mock), + "return_value not initialised correctly") + + self.assertEqual(mock.call_args, None, + "call_args not initialised correctly") + self.assertEqual(mock.call_args_list, [], + "call_args_list not initialised correctly") + self.assertEqual(mock.method_calls, [], + "method_calls not initialised correctly") + + # Can't use hasattr for this test as it always returns True on a mock + self.assertNotIn('_items', mock.__dict__, + "default mock should not have '_items' attribute") + + self.assertIsNone(mock._mock_parent, + "parent not initialised correctly") + self.assertIsNone(mock._mock_methods, + "methods not initialised correctly") + self.assertEqual(mock._mock_children, {}, + "children not initialised incorrectly") + + + def test_return_value_in_constructor(self): + mock = Mock(return_value=None) + self.assertIsNone(mock.return_value, + "return value in constructor not honoured") + + + def test_change_return_value_via_delegate(self): + def f(): pass + mock = create_autospec(f) + mock.mock.return_value = 1 + self.assertEqual(mock(), 1) + + + def test_change_side_effect_via_delegate(self): + def f(): pass + mock = create_autospec(f) + mock.mock.side_effect = TypeError() + with self.assertRaises(TypeError): + mock() + + + def test_repr(self): + mock = Mock(name='foo') + self.assertIn('foo', repr(mock)) + self.assertIn("'%s'" % id(mock), repr(mock)) + + mocks = [(Mock(), 'mock'), (Mock(name='bar'), 'bar')] + for mock, name in mocks: + self.assertIn('%s.bar' % name, repr(mock.bar)) + self.assertIn('%s.foo()' % name, repr(mock.foo())) + self.assertIn('%s.foo().bing' % name, repr(mock.foo().bing)) + self.assertIn('%s()' % name, repr(mock())) + self.assertIn('%s()()' % name, repr(mock()())) + self.assertIn('%s()().foo.bar.baz().bing' % name, + repr(mock()().foo.bar.baz().bing)) + + + def test_repr_with_spec(self): + class X(object): + pass + + mock = Mock(spec=X) + self.assertIn(" spec='X' ", repr(mock)) + + mock = Mock(spec=X()) + self.assertIn(" spec='X' ", repr(mock)) + + mock = Mock(spec_set=X) + self.assertIn(" spec_set='X' ", repr(mock)) + + mock = Mock(spec_set=X()) + self.assertIn(" spec_set='X' ", repr(mock)) + + mock = Mock(spec=X, name='foo') + self.assertIn(" spec='X' ", repr(mock)) + self.assertIn(" name='foo' ", repr(mock)) + + mock = Mock(name='foo') + self.assertNotIn("spec", repr(mock)) + + mock = Mock() + self.assertNotIn("spec", repr(mock)) + + mock = Mock(spec=['foo']) + self.assertNotIn("spec", repr(mock)) + + + def test_side_effect(self): + mock = Mock() + + def effect(*args, **kwargs): + raise SystemError('kablooie') + + mock.side_effect = effect + self.assertRaises(SystemError, mock, 1, 2, fish=3) + mock.assert_called_with(1, 2, fish=3) + + results = [1, 2, 3] + def effect(): + return results.pop() + mock.side_effect = effect + + self.assertEqual([mock(), mock(), mock()], [3, 2, 1], + "side effect not used correctly") + + mock = Mock(side_effect=sentinel.SideEffect) + self.assertEqual(mock.side_effect, sentinel.SideEffect, + "side effect in constructor not used") + + def side_effect(): + return DEFAULT + mock = Mock(side_effect=side_effect, return_value=sentinel.RETURN) + self.assertEqual(mock(), sentinel.RETURN) + + def test_autospec_side_effect(self): + # Test for issue17826 + results = [1, 2, 3] + def effect(): + return results.pop() + def f(): pass + + mock = create_autospec(f) + mock.side_effect = [1, 2, 3] + self.assertEqual([mock(), mock(), mock()], [1, 2, 3], + "side effect not used correctly in create_autospec") + # Test where side effect is a callable + results = [1, 2, 3] + mock = create_autospec(f) + mock.side_effect = effect + self.assertEqual([mock(), mock(), mock()], [3, 2, 1], + "callable side effect not used correctly") + + def test_autospec_side_effect_exception(self): + # Test for issue 23661 + def f(): pass + + mock = create_autospec(f) + mock.side_effect = ValueError('Bazinga!') + self.assertRaisesRegex(ValueError, 'Bazinga!', mock) + + + def test_autospec_mock(self): + class A(object): + class B(object): + C = None + + with mock.patch.object(A, 'B'): + with self.assertRaisesRegex(InvalidSpecError, + "Cannot autospec attr 'B' from target <MagicMock spec='A'"): + create_autospec(A).B + with self.assertRaisesRegex(InvalidSpecError, + "Cannot autospec attr 'B' from target 'A'"): + mock.patch.object(A, 'B', autospec=True).start() + with self.assertRaisesRegex(InvalidSpecError, + "Cannot autospec attr 'C' as the patch target "): + mock.patch.object(A.B, 'C', autospec=True).start() + with self.assertRaisesRegex(InvalidSpecError, + "Cannot spec attr 'B' as the spec "): + mock.patch.object(A, 'B', spec=A.B).start() + with self.assertRaisesRegex(InvalidSpecError, + "Cannot spec attr 'B' as the spec_set "): + mock.patch.object(A, 'B', spec_set=A.B).start() + with self.assertRaisesRegex(InvalidSpecError, + "Cannot spec attr 'B' as the spec_set "): + mock.patch.object(A, 'B', spec_set=A.B).start() + with self.assertRaisesRegex(InvalidSpecError, "Cannot spec a Mock object."): + mock.Mock(A.B) + with mock.patch('builtins.open', mock.mock_open()): + mock.mock_open() # should still be valid with open() mocked + + + def test_reset_mock(self): + parent = Mock() + spec = ["something"] + mock = Mock(name="child", parent=parent, spec=spec) + mock(sentinel.Something, something=sentinel.SomethingElse) + something = mock.something + mock.something() + mock.side_effect = sentinel.SideEffect + return_value = mock.return_value + return_value() + + mock.reset_mock() + + self.assertEqual(mock._mock_name, "child", + "name incorrectly reset") + self.assertEqual(mock._mock_parent, parent, + "parent incorrectly reset") + self.assertEqual(mock._mock_methods, spec, + "methods incorrectly reset") + + self.assertFalse(mock.called, "called not reset") + self.assertEqual(mock.call_count, 0, "call_count not reset") + self.assertEqual(mock.call_args, None, "call_args not reset") + self.assertEqual(mock.call_args_list, [], "call_args_list not reset") + self.assertEqual(mock.method_calls, [], + "method_calls not initialised correctly: %r != %r" % + (mock.method_calls, [])) + self.assertEqual(mock.mock_calls, []) + + self.assertEqual(mock.side_effect, sentinel.SideEffect, + "side_effect incorrectly reset") + self.assertEqual(mock.return_value, return_value, + "return_value incorrectly reset") + self.assertFalse(return_value.called, "return value mock not reset") + self.assertEqual(mock._mock_children, {'something': something}, + "children reset incorrectly") + self.assertEqual(mock.something, something, + "children incorrectly cleared") + self.assertFalse(mock.something.called, "child not reset") + + + def test_reset_mock_recursion(self): + mock = Mock() + mock.return_value = mock + + # used to cause recursion + mock.reset_mock() + + def test_reset_mock_on_mock_open_issue_18622(self): + a = mock.mock_open() + a.reset_mock() + + def test_call(self): + mock = Mock() + self.assertTrue(is_instance(mock.return_value, Mock), + "Default return_value should be a Mock") + + result = mock() + self.assertEqual(mock(), result, + "different result from consecutive calls") + mock.reset_mock() + + ret_val = mock(sentinel.Arg) + self.assertTrue(mock.called, "called not set") + self.assertEqual(mock.call_count, 1, "call_count incorrect") + self.assertEqual(mock.call_args, ((sentinel.Arg,), {}), + "call_args not set") + self.assertEqual(mock.call_args.args, (sentinel.Arg,), + "call_args not set") + self.assertEqual(mock.call_args.kwargs, {}, + "call_args not set") + self.assertEqual(mock.call_args_list, [((sentinel.Arg,), {})], + "call_args_list not initialised correctly") + + mock.return_value = sentinel.ReturnValue + ret_val = mock(sentinel.Arg, key=sentinel.KeyArg) + self.assertEqual(ret_val, sentinel.ReturnValue, + "incorrect return value") + + self.assertEqual(mock.call_count, 2, "call_count incorrect") + self.assertEqual(mock.call_args, + ((sentinel.Arg,), {'key': sentinel.KeyArg}), + "call_args not set") + self.assertEqual(mock.call_args_list, [ + ((sentinel.Arg,), {}), + ((sentinel.Arg,), {'key': sentinel.KeyArg}) + ], + "call_args_list not set") + + + def test_call_args_comparison(self): + mock = Mock() + mock() + mock(sentinel.Arg) + mock(kw=sentinel.Kwarg) + mock(sentinel.Arg, kw=sentinel.Kwarg) + self.assertEqual(mock.call_args_list, [ + (), + ((sentinel.Arg,),), + ({"kw": sentinel.Kwarg},), + ((sentinel.Arg,), {"kw": sentinel.Kwarg}) + ]) + self.assertEqual(mock.call_args, + ((sentinel.Arg,), {"kw": sentinel.Kwarg})) + self.assertEqual(mock.call_args.args, (sentinel.Arg,)) + self.assertEqual(mock.call_args.kwargs, {"kw": sentinel.Kwarg}) + + # Comparing call_args to a long sequence should not raise + # an exception. See issue 24857. + self.assertFalse(mock.call_args == "a long sequence") + + + def test_calls_equal_with_any(self): + # Check that equality and non-equality is consistent even when + # comparing with mock.ANY + mm = mock.MagicMock() + self.assertTrue(mm == mm) + self.assertFalse(mm != mm) + self.assertFalse(mm == mock.MagicMock()) + self.assertTrue(mm != mock.MagicMock()) + self.assertTrue(mm == mock.ANY) + self.assertFalse(mm != mock.ANY) + self.assertTrue(mock.ANY == mm) + self.assertFalse(mock.ANY != mm) + self.assertTrue(mm == ALWAYS_EQ) + self.assertFalse(mm != ALWAYS_EQ) + + call1 = mock.call(mock.MagicMock()) + call2 = mock.call(mock.ANY) + self.assertTrue(call1 == call2) + self.assertFalse(call1 != call2) + self.assertTrue(call2 == call1) + self.assertFalse(call2 != call1) + + self.assertTrue(call1 == ALWAYS_EQ) + self.assertFalse(call1 != ALWAYS_EQ) + self.assertFalse(call1 == 1) + self.assertTrue(call1 != 1) + + + def test_assert_called_with(self): + mock = Mock() + mock() + + # Will raise an exception if it fails + mock.assert_called_with() + self.assertRaises(AssertionError, mock.assert_called_with, 1) + + mock.reset_mock() + self.assertRaises(AssertionError, mock.assert_called_with) + + mock(1, 2, 3, a='fish', b='nothing') + mock.assert_called_with(1, 2, 3, a='fish', b='nothing') + + + def test_assert_called_with_any(self): + m = MagicMock() + m(MagicMock()) + m.assert_called_with(mock.ANY) + + + def test_assert_called_with_function_spec(self): + def f(a, b, c, d=None): pass + + mock = Mock(spec=f) + + mock(1, b=2, c=3) + mock.assert_called_with(1, 2, 3) + mock.assert_called_with(a=1, b=2, c=3) + self.assertRaises(AssertionError, mock.assert_called_with, + 1, b=3, c=2) + # Expected call doesn't match the spec's signature + with self.assertRaises(AssertionError) as cm: + mock.assert_called_with(e=8) + self.assertIsInstance(cm.exception.__cause__, TypeError) + + + def test_assert_called_with_method_spec(self): + def _check(mock): + mock(1, b=2, c=3) + mock.assert_called_with(1, 2, 3) + mock.assert_called_with(a=1, b=2, c=3) + self.assertRaises(AssertionError, mock.assert_called_with, + 1, b=3, c=2) + + mock = Mock(spec=Something().meth) + _check(mock) + mock = Mock(spec=Something.cmeth) + _check(mock) + mock = Mock(spec=Something().cmeth) + _check(mock) + mock = Mock(spec=Something.smeth) + _check(mock) + mock = Mock(spec=Something().smeth) + _check(mock) + + + def test_assert_called_exception_message(self): + msg = "Expected '{0}' to have been called" + with self.assertRaisesRegex(AssertionError, msg.format('mock')): + Mock().assert_called() + with self.assertRaisesRegex(AssertionError, msg.format('test_name')): + Mock(name="test_name").assert_called() + + + def test_assert_called_once_with(self): + mock = Mock() + mock() + + # Will raise an exception if it fails + mock.assert_called_once_with() + + mock() + self.assertRaises(AssertionError, mock.assert_called_once_with) + + mock.reset_mock() + self.assertRaises(AssertionError, mock.assert_called_once_with) + + mock('foo', 'bar', baz=2) + mock.assert_called_once_with('foo', 'bar', baz=2) + + mock.reset_mock() + mock('foo', 'bar', baz=2) + self.assertRaises( + AssertionError, + lambda: mock.assert_called_once_with('bob', 'bar', baz=2) + ) + + def test_assert_called_once_with_call_list(self): + m = Mock() + m(1) + m(2) + self.assertRaisesRegex(AssertionError, + re.escape("Calls: [call(1), call(2)]"), + lambda: m.assert_called_once_with(2)) + + + def test_assert_called_once_with_function_spec(self): + def f(a, b, c, d=None): pass + + mock = Mock(spec=f) + + mock(1, b=2, c=3) + mock.assert_called_once_with(1, 2, 3) + mock.assert_called_once_with(a=1, b=2, c=3) + self.assertRaises(AssertionError, mock.assert_called_once_with, + 1, b=3, c=2) + # Expected call doesn't match the spec's signature + with self.assertRaises(AssertionError) as cm: + mock.assert_called_once_with(e=8) + self.assertIsInstance(cm.exception.__cause__, TypeError) + # Mock called more than once => always fails + mock(4, 5, 6) + self.assertRaises(AssertionError, mock.assert_called_once_with, + 1, 2, 3) + self.assertRaises(AssertionError, mock.assert_called_once_with, + 4, 5, 6) + + + def test_attribute_access_returns_mocks(self): + mock = Mock() + something = mock.something + self.assertTrue(is_instance(something, Mock), "attribute isn't a mock") + self.assertEqual(mock.something, something, + "different attributes returned for same name") + + # Usage example + mock = Mock() + mock.something.return_value = 3 + + self.assertEqual(mock.something(), 3, "method returned wrong value") + self.assertTrue(mock.something.called, + "method didn't record being called") + + + def test_attributes_have_name_and_parent_set(self): + mock = Mock() + something = mock.something + + self.assertEqual(something._mock_name, "something", + "attribute name not set correctly") + self.assertEqual(something._mock_parent, mock, + "attribute parent not set correctly") + + + def test_method_calls_recorded(self): + mock = Mock() + mock.something(3, fish=None) + mock.something_else.something(6, cake=sentinel.Cake) + + self.assertEqual(mock.something_else.method_calls, + [("something", (6,), {'cake': sentinel.Cake})], + "method calls not recorded correctly") + self.assertEqual(mock.method_calls, [ + ("something", (3,), {'fish': None}), + ("something_else.something", (6,), {'cake': sentinel.Cake}) + ], + "method calls not recorded correctly") + + + def test_method_calls_compare_easily(self): + mock = Mock() + mock.something() + self.assertEqual(mock.method_calls, [('something',)]) + self.assertEqual(mock.method_calls, [('something', (), {})]) + + mock = Mock() + mock.something('different') + self.assertEqual(mock.method_calls, [('something', ('different',))]) + self.assertEqual(mock.method_calls, + [('something', ('different',), {})]) + + mock = Mock() + mock.something(x=1) + self.assertEqual(mock.method_calls, [('something', {'x': 1})]) + self.assertEqual(mock.method_calls, [('something', (), {'x': 1})]) + + mock = Mock() + mock.something('different', some='more') + self.assertEqual(mock.method_calls, [ + ('something', ('different',), {'some': 'more'}) + ]) + + + def test_only_allowed_methods_exist(self): + for spec in ['something'], ('something',): + for arg in 'spec', 'spec_set': + mock = Mock(**{arg: spec}) + + # this should be allowed + mock.something + self.assertRaisesRegex( + AttributeError, + "Mock object has no attribute 'something_else'", + getattr, mock, 'something_else' + ) + + + def test_from_spec(self): + class Something(object): + x = 3 + __something__ = None + def y(self): pass + + def test_attributes(mock): + # should work + mock.x + mock.y + mock.__something__ + self.assertRaisesRegex( + AttributeError, + "Mock object has no attribute 'z'", + getattr, mock, 'z' + ) + self.assertRaisesRegex( + AttributeError, + "Mock object has no attribute '__foobar__'", + getattr, mock, '__foobar__' + ) + + test_attributes(Mock(spec=Something)) + test_attributes(Mock(spec=Something())) + + + def test_wraps_calls(self): + real = Mock() + + mock = Mock(wraps=real) + self.assertEqual(mock(), real()) + + real.reset_mock() + + mock(1, 2, fish=3) + real.assert_called_with(1, 2, fish=3) + + + def test_wraps_prevents_automatic_creation_of_mocks(self): + class Real(object): + pass + + real = Real() + mock = Mock(wraps=real) + + self.assertRaises(AttributeError, lambda: mock.new_attr()) + + + def test_wraps_call_with_nondefault_return_value(self): + real = Mock() + + mock = Mock(wraps=real) + mock.return_value = 3 + + self.assertEqual(mock(), 3) + self.assertFalse(real.called) + + + def test_wraps_attributes(self): + class Real(object): + attribute = Mock() + + real = Real() + + mock = Mock(wraps=real) + self.assertEqual(mock.attribute(), real.attribute()) + self.assertRaises(AttributeError, lambda: mock.fish) + + self.assertNotEqual(mock.attribute, real.attribute) + result = mock.attribute.frog(1, 2, fish=3) + Real.attribute.frog.assert_called_with(1, 2, fish=3) + self.assertEqual(result, Real.attribute.frog()) + + + def test_customize_wrapped_object_with_side_effect_iterable_with_default(self): + class Real(object): + def method(self): + return sentinel.ORIGINAL_VALUE + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, DEFAULT] + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.ORIGINAL_VALUE) + self.assertRaises(StopIteration, mock.method) + + + def test_customize_wrapped_object_with_side_effect_iterable(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, sentinel.VALUE2] + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.VALUE2) + self.assertRaises(StopIteration, mock.method) + + + def test_customize_wrapped_object_with_side_effect_exception(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = RuntimeError + + self.assertRaises(RuntimeError, mock.method) + + + def test_customize_wrapped_object_with_side_effect_function(self): + class Real(object): + def method(self): pass + def side_effect(): + return sentinel.VALUE + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = side_effect + + self.assertEqual(mock.method(), sentinel.VALUE) + + + def test_customize_wrapped_object_with_return_value(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.return_value = sentinel.VALUE + + self.assertEqual(mock.method(), sentinel.VALUE) + + + def test_customize_wrapped_object_with_return_value_and_side_effect(self): + # side_effect should always take precedence over return_value. + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, sentinel.VALUE2] + mock.method.return_value = sentinel.WRONG_VALUE + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.VALUE2) + self.assertRaises(StopIteration, mock.method) + + + def test_customize_wrapped_object_with_return_value_and_side_effect2(self): + # side_effect can return DEFAULT to default to return_value + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = lambda: DEFAULT + mock.method.return_value = sentinel.VALUE + + self.assertEqual(mock.method(), sentinel.VALUE) + + + def test_customize_wrapped_object_with_return_value_and_side_effect_default(self): + class Real(object): + def method(self): pass + + real = Real() + mock = Mock(wraps=real) + mock.method.side_effect = [sentinel.VALUE1, DEFAULT] + mock.method.return_value = sentinel.RETURN + + self.assertEqual(mock.method(), sentinel.VALUE1) + self.assertEqual(mock.method(), sentinel.RETURN) + self.assertRaises(StopIteration, mock.method) + + + def test_magic_method_wraps_dict(self): + # bpo-25597: MagicMock with wrap doesn't call wrapped object's + # method for magic methods with default values. + data = {'foo': 'bar'} + + wrapped_dict = MagicMock(wraps=data) + self.assertEqual(wrapped_dict.get('foo'), 'bar') + # Accessing key gives a MagicMock + self.assertIsInstance(wrapped_dict['foo'], MagicMock) + # __contains__ method has a default value of False + self.assertFalse('foo' in wrapped_dict) + + # return_value is non-sentinel and takes precedence over wrapped value. + wrapped_dict.get.return_value = 'return_value' + self.assertEqual(wrapped_dict.get('foo'), 'return_value') + + # return_value is sentinel and hence wrapped value is returned. + wrapped_dict.get.return_value = sentinel.DEFAULT + self.assertEqual(wrapped_dict.get('foo'), 'bar') + + self.assertEqual(wrapped_dict.get('baz'), None) + self.assertIsInstance(wrapped_dict['baz'], MagicMock) + self.assertFalse('bar' in wrapped_dict) + + data['baz'] = 'spam' + self.assertEqual(wrapped_dict.get('baz'), 'spam') + self.assertIsInstance(wrapped_dict['baz'], MagicMock) + self.assertFalse('bar' in wrapped_dict) + + del data['baz'] + self.assertEqual(wrapped_dict.get('baz'), None) + + + def test_magic_method_wraps_class(self): + + class Foo: + + def __getitem__(self, index): + return index + + def __custom_method__(self): + return "foo" + + + klass = MagicMock(wraps=Foo) + obj = klass() + self.assertEqual(obj.__getitem__(2), 2) + self.assertEqual(obj[2], 2) + self.assertEqual(obj.__custom_method__(), "foo") + + + def test_exceptional_side_effect(self): + mock = Mock(side_effect=AttributeError) + self.assertRaises(AttributeError, mock) + + mock = Mock(side_effect=AttributeError('foo')) + self.assertRaises(AttributeError, mock) + + + def test_baseexceptional_side_effect(self): + mock = Mock(side_effect=KeyboardInterrupt) + self.assertRaises(KeyboardInterrupt, mock) + + mock = Mock(side_effect=KeyboardInterrupt('foo')) + self.assertRaises(KeyboardInterrupt, mock) + + + def test_assert_called_with_message(self): + mock = Mock() + self.assertRaisesRegex(AssertionError, 'not called', + mock.assert_called_with) + + + def test_assert_called_once_with_message(self): + mock = Mock(name='geoffrey') + self.assertRaisesRegex(AssertionError, + r"Expected 'geoffrey' to be called once\.", + mock.assert_called_once_with) + + + def test__name__(self): + mock = Mock() + self.assertRaises(AttributeError, lambda: mock.__name__) + + mock.__name__ = 'foo' + self.assertEqual(mock.__name__, 'foo') + + + def test_spec_list_subclass(self): + class Sub(list): + pass + mock = Mock(spec=Sub(['foo'])) + + mock.append(3) + mock.append.assert_called_with(3) + self.assertRaises(AttributeError, getattr, mock, 'foo') + + + def test_spec_class(self): + class X(object): + pass + + mock = Mock(spec=X) + self.assertIsInstance(mock, X) + + mock = Mock(spec=X()) + self.assertIsInstance(mock, X) + + self.assertIs(mock.__class__, X) + self.assertEqual(Mock().__class__.__name__, 'Mock') + + mock = Mock(spec_set=X) + self.assertIsInstance(mock, X) + + mock = Mock(spec_set=X()) + self.assertIsInstance(mock, X) + + + def test_spec_class_no_object_base(self): + class X: + pass + + mock = Mock(spec=X) + self.assertIsInstance(mock, X) + + mock = Mock(spec=X()) + self.assertIsInstance(mock, X) + + self.assertIs(mock.__class__, X) + self.assertEqual(Mock().__class__.__name__, 'Mock') + + mock = Mock(spec_set=X) + self.assertIsInstance(mock, X) + + mock = Mock(spec_set=X()) + self.assertIsInstance(mock, X) + + + def test_setting_attribute_with_spec_set(self): + class X(object): + y = 3 + + mock = Mock(spec=X) + mock.x = 'foo' + + mock = Mock(spec_set=X) + def set_attr(): + mock.x = 'foo' + + mock.y = 'foo' + self.assertRaises(AttributeError, set_attr) + + + def test_copy(self): + current = sys.getrecursionlimit() + self.addCleanup(sys.setrecursionlimit, current) + + # can't use sys.maxint as this doesn't exist in Python 3 + sys.setrecursionlimit(int(10e8)) + # this segfaults without the fix in place + copy.copy(Mock()) + + + def test_subclass_with_properties(self): + class SubClass(Mock): + def _get(self): + return 3 + def _set(self, value): + raise NameError('strange error') + some_attribute = property(_get, _set) + + s = SubClass(spec_set=SubClass) + self.assertEqual(s.some_attribute, 3) + + def test(): + s.some_attribute = 3 + self.assertRaises(NameError, test) + + def test(): + s.foo = 'bar' + self.assertRaises(AttributeError, test) + + + def test_setting_call(self): + mock = Mock() + def __call__(self, a): + self._increment_mock_call(a) + return self._mock_call(a) + + type(mock).__call__ = __call__ + mock('one') + mock.assert_called_with('one') + + self.assertRaises(TypeError, mock, 'one', 'two') + + + def test_dir(self): + mock = Mock() + attrs = set(dir(mock)) + type_attrs = set([m for m in dir(Mock) if not m.startswith('_')]) + + # all public attributes from the type are included + self.assertEqual(set(), type_attrs - attrs) + + # creates these attributes + mock.a, mock.b + self.assertIn('a', dir(mock)) + self.assertIn('b', dir(mock)) + + # instance attributes + mock.c = mock.d = None + self.assertIn('c', dir(mock)) + self.assertIn('d', dir(mock)) + + # magic methods + mock.__iter__ = lambda s: iter([]) + self.assertIn('__iter__', dir(mock)) + + + def test_dir_from_spec(self): + mock = Mock(spec=unittest.TestCase) + testcase_attrs = set(dir(unittest.TestCase)) + attrs = set(dir(mock)) + + # all attributes from the spec are included + self.assertEqual(set(), testcase_attrs - attrs) + + # shadow a sys attribute + mock.version = 3 + self.assertEqual(dir(mock).count('version'), 1) + + + def test_filter_dir(self): + patcher = patch.object(mock, 'FILTER_DIR', False) + patcher.start() + try: + attrs = set(dir(Mock())) + type_attrs = set(dir(Mock)) + + # ALL attributes from the type are included + self.assertEqual(set(), type_attrs - attrs) + finally: + patcher.stop() + + + def test_dir_does_not_include_deleted_attributes(self): + mock = Mock() + mock.child.return_value = 1 + + self.assertIn('child', dir(mock)) + del mock.child + self.assertNotIn('child', dir(mock)) + + + def test_configure_mock(self): + mock = Mock(foo='bar') + self.assertEqual(mock.foo, 'bar') + + mock = MagicMock(foo='bar') + self.assertEqual(mock.foo, 'bar') + + kwargs = {'side_effect': KeyError, 'foo.bar.return_value': 33, + 'foo': MagicMock()} + mock = Mock(**kwargs) + self.assertRaises(KeyError, mock) + self.assertEqual(mock.foo.bar(), 33) + self.assertIsInstance(mock.foo, MagicMock) + + mock = Mock() + mock.configure_mock(**kwargs) + self.assertRaises(KeyError, mock) + self.assertEqual(mock.foo.bar(), 33) + self.assertIsInstance(mock.foo, MagicMock) + + + def assertRaisesWithMsg(self, exception, message, func, *args, **kwargs): + # needed because assertRaisesRegex doesn't work easily with newlines + with self.assertRaises(exception) as context: + func(*args, **kwargs) + msg = str(context.exception) + self.assertEqual(msg, message) + + + def test_assert_called_with_failure_message(self): + mock = NonCallableMock() + + actual = 'not called.' + expected = "mock(1, '2', 3, bar='foo')" + message = 'expected call not found.\nExpected: %s\nActual: %s' + self.assertRaisesWithMsg( + AssertionError, message % (expected, actual), + mock.assert_called_with, 1, '2', 3, bar='foo' + ) + + mock.foo(1, '2', 3, foo='foo') + + + asserters = [ + mock.foo.assert_called_with, mock.foo.assert_called_once_with + ] + for meth in asserters: + actual = "foo(1, '2', 3, foo='foo')" + expected = "foo(1, '2', 3, bar='foo')" + message = 'expected call not found.\nExpected: %s\nActual: %s' + self.assertRaisesWithMsg( + AssertionError, message % (expected, actual), + meth, 1, '2', 3, bar='foo' + ) + + # just kwargs + for meth in asserters: + actual = "foo(1, '2', 3, foo='foo')" + expected = "foo(bar='foo')" + message = 'expected call not found.\nExpected: %s\nActual: %s' + self.assertRaisesWithMsg( + AssertionError, message % (expected, actual), + meth, bar='foo' + ) + + # just args + for meth in asserters: + actual = "foo(1, '2', 3, foo='foo')" + expected = "foo(1, 2, 3)" + message = 'expected call not found.\nExpected: %s\nActual: %s' + self.assertRaisesWithMsg( + AssertionError, message % (expected, actual), + meth, 1, 2, 3 + ) + + # empty + for meth in asserters: + actual = "foo(1, '2', 3, foo='foo')" + expected = "foo()" + message = 'expected call not found.\nExpected: %s\nActual: %s' + self.assertRaisesWithMsg( + AssertionError, message % (expected, actual), meth + ) + + + def test_mock_calls(self): + mock = MagicMock() + + # need to do this because MagicMock.mock_calls used to just return + # a MagicMock which also returned a MagicMock when __eq__ was called + self.assertIs(mock.mock_calls == [], True) + + mock = MagicMock() + mock() + expected = [('', (), {})] + self.assertEqual(mock.mock_calls, expected) + + mock.foo() + expected.append(call.foo()) + self.assertEqual(mock.mock_calls, expected) + # intermediate mock_calls work too + self.assertEqual(mock.foo.mock_calls, [('', (), {})]) + + mock = MagicMock() + mock().foo(1, 2, 3, a=4, b=5) + expected = [ + ('', (), {}), ('().foo', (1, 2, 3), dict(a=4, b=5)) + ] + self.assertEqual(mock.mock_calls, expected) + self.assertEqual(mock.return_value.foo.mock_calls, + [('', (1, 2, 3), dict(a=4, b=5))]) + self.assertEqual(mock.return_value.mock_calls, + [('foo', (1, 2, 3), dict(a=4, b=5))]) + + mock = MagicMock() + mock().foo.bar().baz() + expected = [ + ('', (), {}), ('().foo.bar', (), {}), + ('().foo.bar().baz', (), {}) + ] + self.assertEqual(mock.mock_calls, expected) + self.assertEqual(mock().mock_calls, + call.foo.bar().baz().call_list()) + + for kwargs in dict(), dict(name='bar'): + mock = MagicMock(**kwargs) + int(mock.foo) + expected = [('foo.__int__', (), {})] + self.assertEqual(mock.mock_calls, expected) + + mock = MagicMock(**kwargs) + mock.a()() + expected = [('a', (), {}), ('a()', (), {})] + self.assertEqual(mock.mock_calls, expected) + self.assertEqual(mock.a().mock_calls, [call()]) + + mock = MagicMock(**kwargs) + mock(1)(2)(3) + self.assertEqual(mock.mock_calls, call(1)(2)(3).call_list()) + self.assertEqual(mock().mock_calls, call(2)(3).call_list()) + self.assertEqual(mock()().mock_calls, call(3).call_list()) + + mock = MagicMock(**kwargs) + mock(1)(2)(3).a.b.c(4) + self.assertEqual(mock.mock_calls, + call(1)(2)(3).a.b.c(4).call_list()) + self.assertEqual(mock().mock_calls, + call(2)(3).a.b.c(4).call_list()) + self.assertEqual(mock()().mock_calls, + call(3).a.b.c(4).call_list()) + + mock = MagicMock(**kwargs) + int(mock().foo.bar().baz()) + last_call = ('().foo.bar().baz().__int__', (), {}) + self.assertEqual(mock.mock_calls[-1], last_call) + self.assertEqual(mock().mock_calls, + call.foo.bar().baz().__int__().call_list()) + self.assertEqual(mock().foo.bar().mock_calls, + call.baz().__int__().call_list()) + self.assertEqual(mock().foo.bar().baz.mock_calls, + call().__int__().call_list()) + + + def test_child_mock_call_equal(self): + m = Mock() + result = m() + result.wibble() + # parent looks like this: + self.assertEqual(m.mock_calls, [call(), call().wibble()]) + # but child should look like this: + self.assertEqual(result.mock_calls, [call.wibble()]) + + + def test_mock_call_not_equal_leaf(self): + m = Mock() + m.foo().something() + self.assertNotEqual(m.mock_calls[1], call.foo().different()) + self.assertEqual(m.mock_calls[0], call.foo()) + + + def test_mock_call_not_equal_non_leaf(self): + m = Mock() + m.foo().bar() + self.assertNotEqual(m.mock_calls[1], call.baz().bar()) + self.assertNotEqual(m.mock_calls[0], call.baz()) + + + def test_mock_call_not_equal_non_leaf_params_different(self): + m = Mock() + m.foo(x=1).bar() + # This isn't ideal, but there's no way to fix it without breaking backwards compatibility: + self.assertEqual(m.mock_calls[1], call.foo(x=2).bar()) + + + def test_mock_call_not_equal_non_leaf_attr(self): + m = Mock() + m.foo.bar() + self.assertNotEqual(m.mock_calls[0], call.baz.bar()) + + + def test_mock_call_not_equal_non_leaf_call_versus_attr(self): + m = Mock() + m.foo.bar() + self.assertNotEqual(m.mock_calls[0], call.foo().bar()) + + + def test_mock_call_repr(self): + m = Mock() + m.foo().bar().baz.bob() + self.assertEqual(repr(m.mock_calls[0]), 'call.foo()') + self.assertEqual(repr(m.mock_calls[1]), 'call.foo().bar()') + self.assertEqual(repr(m.mock_calls[2]), 'call.foo().bar().baz.bob()') + + + def test_mock_call_repr_loop(self): + m = Mock() + m.foo = m + repr(m.foo()) + self.assertRegex(repr(m.foo()), r"<Mock name='mock\(\)' id='\d+'>") + + + def test_mock_calls_contains(self): + m = Mock() + self.assertFalse([call()] in m.mock_calls) + + + def test_subclassing(self): + class Subclass(Mock): + pass + + mock = Subclass() + self.assertIsInstance(mock.foo, Subclass) + self.assertIsInstance(mock(), Subclass) + + class Subclass(Mock): + def _get_child_mock(self, **kwargs): + return Mock(**kwargs) + + mock = Subclass() + self.assertNotIsInstance(mock.foo, Subclass) + self.assertNotIsInstance(mock(), Subclass) + + + def test_arg_lists(self): + mocks = [ + Mock(), + MagicMock(), + NonCallableMock(), + NonCallableMagicMock() + ] + + def assert_attrs(mock): + names = 'call_args_list', 'method_calls', 'mock_calls' + for name in names: + attr = getattr(mock, name) + self.assertIsInstance(attr, _CallList) + self.assertIsInstance(attr, list) + self.assertEqual(attr, []) + + for mock in mocks: + assert_attrs(mock) + + if callable(mock): + mock() + mock(1, 2) + mock(a=3) + + mock.reset_mock() + assert_attrs(mock) + + mock.foo() + mock.foo.bar(1, a=3) + mock.foo(1).bar().baz(3) + + mock.reset_mock() + assert_attrs(mock) + + + def test_call_args_two_tuple(self): + mock = Mock() + mock(1, a=3) + mock(2, b=4) + + self.assertEqual(len(mock.call_args), 2) + self.assertEqual(mock.call_args.args, (2,)) + self.assertEqual(mock.call_args.kwargs, dict(b=4)) + + expected_list = [((1,), dict(a=3)), ((2,), dict(b=4))] + for expected, call_args in zip(expected_list, mock.call_args_list): + self.assertEqual(len(call_args), 2) + self.assertEqual(expected[0], call_args[0]) + self.assertEqual(expected[1], call_args[1]) + + + def test_side_effect_iterator(self): + mock = Mock(side_effect=iter([1, 2, 3])) + self.assertEqual([mock(), mock(), mock()], [1, 2, 3]) + self.assertRaises(StopIteration, mock) + + mock = MagicMock(side_effect=['a', 'b', 'c']) + self.assertEqual([mock(), mock(), mock()], ['a', 'b', 'c']) + self.assertRaises(StopIteration, mock) + + mock = Mock(side_effect='ghi') + self.assertEqual([mock(), mock(), mock()], ['g', 'h', 'i']) + self.assertRaises(StopIteration, mock) + + class Foo(object): + pass + mock = MagicMock(side_effect=Foo) + self.assertIsInstance(mock(), Foo) + + mock = Mock(side_effect=Iter()) + self.assertEqual([mock(), mock(), mock(), mock()], + ['this', 'is', 'an', 'iter']) + self.assertRaises(StopIteration, mock) + + + def test_side_effect_iterator_exceptions(self): + for Klass in Mock, MagicMock: + iterable = (ValueError, 3, KeyError, 6) + m = Klass(side_effect=iterable) + self.assertRaises(ValueError, m) + self.assertEqual(m(), 3) + self.assertRaises(KeyError, m) + self.assertEqual(m(), 6) + + + def test_side_effect_setting_iterator(self): + mock = Mock() + mock.side_effect = iter([1, 2, 3]) + self.assertEqual([mock(), mock(), mock()], [1, 2, 3]) + self.assertRaises(StopIteration, mock) + side_effect = mock.side_effect + self.assertIsInstance(side_effect, type(iter([]))) + + mock.side_effect = ['a', 'b', 'c'] + self.assertEqual([mock(), mock(), mock()], ['a', 'b', 'c']) + self.assertRaises(StopIteration, mock) + side_effect = mock.side_effect + self.assertIsInstance(side_effect, type(iter([]))) + + this_iter = Iter() + mock.side_effect = this_iter + self.assertEqual([mock(), mock(), mock(), mock()], + ['this', 'is', 'an', 'iter']) + self.assertRaises(StopIteration, mock) + self.assertIs(mock.side_effect, this_iter) + + def test_side_effect_iterator_default(self): + mock = Mock(return_value=2) + mock.side_effect = iter([1, DEFAULT]) + self.assertEqual([mock(), mock()], [1, 2]) + + def test_assert_has_calls_any_order(self): + mock = Mock() + mock(1, 2) + mock(a=3) + mock(3, 4) + mock(b=6) + mock(b=6) + + kalls = [ + call(1, 2), ({'a': 3},), + ((3, 4),), ((), {'a': 3}), + ('', (1, 2)), ('', {'a': 3}), + ('', (1, 2), {}), ('', (), {'a': 3}) + ] + for kall in kalls: + mock.assert_has_calls([kall], any_order=True) + + for kall in call(1, '2'), call(b=3), call(), 3, None, 'foo': + self.assertRaises( + AssertionError, mock.assert_has_calls, + [kall], any_order=True + ) + + kall_lists = [ + [call(1, 2), call(b=6)], + [call(3, 4), call(1, 2)], + [call(b=6), call(b=6)], + ] + + for kall_list in kall_lists: + mock.assert_has_calls(kall_list, any_order=True) + + kall_lists = [ + [call(b=6), call(b=6), call(b=6)], + [call(1, 2), call(1, 2)], + [call(3, 4), call(1, 2), call(5, 7)], + [call(b=6), call(3, 4), call(b=6), call(1, 2), call(b=6)], + ] + for kall_list in kall_lists: + self.assertRaises( + AssertionError, mock.assert_has_calls, + kall_list, any_order=True + ) + + def test_assert_has_calls(self): + kalls1 = [ + call(1, 2), ({'a': 3},), + ((3, 4),), call(b=6), + ('', (1,), {'b': 6}), + ] + kalls2 = [call.foo(), call.bar(1)] + kalls2.extend(call.spam().baz(a=3).call_list()) + kalls2.extend(call.bam(set(), foo={}).fish([1]).call_list()) + + mocks = [] + for mock in Mock(), MagicMock(): + mock(1, 2) + mock(a=3) + mock(3, 4) + mock(b=6) + mock(1, b=6) + mocks.append((mock, kalls1)) + + mock = Mock() + mock.foo() + mock.bar(1) + mock.spam().baz(a=3) + mock.bam(set(), foo={}).fish([1]) + mocks.append((mock, kalls2)) + + for mock, kalls in mocks: + for i in range(len(kalls)): + for step in 1, 2, 3: + these = kalls[i:i+step] + mock.assert_has_calls(these) + + if len(these) > 1: + self.assertRaises( + AssertionError, + mock.assert_has_calls, + list(reversed(these)) + ) + + + def test_assert_has_calls_nested_spec(self): + class Something: + + def __init__(self): pass + def meth(self, a, b, c, d=None): pass + + class Foo: + + def __init__(self, a): pass + def meth1(self, a, b): pass + + mock_class = create_autospec(Something) + + for m in [mock_class, mock_class()]: + m.meth(1, 2, 3, d=1) + m.assert_has_calls([call.meth(1, 2, 3, d=1)]) + m.assert_has_calls([call.meth(1, 2, 3, 1)]) + + mock_class.reset_mock() + + for m in [mock_class, mock_class()]: + self.assertRaises(AssertionError, m.assert_has_calls, [call.Foo()]) + m.Foo(1).meth1(1, 2) + m.assert_has_calls([call.Foo(1), call.Foo(1).meth1(1, 2)]) + m.Foo.assert_has_calls([call(1), call().meth1(1, 2)]) + + mock_class.reset_mock() + + invalid_calls = [call.meth(1), + call.non_existent(1), + call.Foo().non_existent(1), + call.Foo().meth(1, 2, 3, 4)] + + for kall in invalid_calls: + self.assertRaises(AssertionError, + mock_class.assert_has_calls, + [kall] + ) + + + def test_assert_has_calls_nested_without_spec(self): + m = MagicMock() + m().foo().bar().baz() + m.one().two().three() + calls = call.one().two().three().call_list() + m.assert_has_calls(calls) + + + def test_assert_has_calls_with_function_spec(self): + def f(a, b, c, d=None): pass + + mock = Mock(spec=f) + + mock(1, b=2, c=3) + mock(4, 5, c=6, d=7) + mock(10, 11, c=12) + calls = [ + ('', (1, 2, 3), {}), + ('', (4, 5, 6), {'d': 7}), + ((10, 11, 12), {}), + ] + mock.assert_has_calls(calls) + mock.assert_has_calls(calls, any_order=True) + mock.assert_has_calls(calls[1:]) + mock.assert_has_calls(calls[1:], any_order=True) + mock.assert_has_calls(calls[:-1]) + mock.assert_has_calls(calls[:-1], any_order=True) + # Reversed order + calls = list(reversed(calls)) + with self.assertRaises(AssertionError): + mock.assert_has_calls(calls) + mock.assert_has_calls(calls, any_order=True) + with self.assertRaises(AssertionError): + mock.assert_has_calls(calls[1:]) + mock.assert_has_calls(calls[1:], any_order=True) + with self.assertRaises(AssertionError): + mock.assert_has_calls(calls[:-1]) + mock.assert_has_calls(calls[:-1], any_order=True) + + def test_assert_has_calls_not_matching_spec_error(self): + def f(x=None): pass + + mock = Mock(spec=f) + mock(1) + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape('Calls not found.\n' + 'Expected: [call()]\n' + 'Actual: [call(1)]'))) as cm: + mock.assert_has_calls([call()]) + self.assertIsNone(cm.exception.__cause__) + + + with self.assertRaisesRegex( + AssertionError, + '^{}$'.format( + re.escape( + 'Error processing expected calls.\n' + "Errors: [None, TypeError('too many positional arguments')]\n" + "Expected: [call(), call(1, 2)]\n" + 'Actual: [call(1)]'))) as cm: + mock.assert_has_calls([call(), call(1, 2)]) + self.assertIsInstance(cm.exception.__cause__, TypeError) + + def test_assert_any_call(self): + mock = Mock() + mock(1, 2) + mock(a=3) + mock(1, b=6) + + mock.assert_any_call(1, 2) + mock.assert_any_call(a=3) + mock.assert_any_call(1, b=6) + + self.assertRaises( + AssertionError, + mock.assert_any_call + ) + self.assertRaises( + AssertionError, + mock.assert_any_call, + 1, 3 + ) + self.assertRaises( + AssertionError, + mock.assert_any_call, + a=4 + ) + + + def test_assert_any_call_with_function_spec(self): + def f(a, b, c, d=None): pass + + mock = Mock(spec=f) + + mock(1, b=2, c=3) + mock(4, 5, c=6, d=7) + mock.assert_any_call(1, 2, 3) + mock.assert_any_call(a=1, b=2, c=3) + mock.assert_any_call(4, 5, 6, 7) + mock.assert_any_call(a=4, b=5, c=6, d=7) + self.assertRaises(AssertionError, mock.assert_any_call, + 1, b=3, c=2) + # Expected call doesn't match the spec's signature + with self.assertRaises(AssertionError) as cm: + mock.assert_any_call(e=8) + self.assertIsInstance(cm.exception.__cause__, TypeError) + + + def test_mock_calls_create_autospec(self): + def f(a, b): pass + obj = Iter() + obj.f = f + + funcs = [ + create_autospec(f), + create_autospec(obj).f + ] + for func in funcs: + func(1, 2) + func(3, 4) + + self.assertEqual( + func.mock_calls, [call(1, 2), call(3, 4)] + ) + + #Issue21222 + def test_create_autospec_with_name(self): + m = mock.create_autospec(object(), name='sweet_func') + self.assertIn('sweet_func', repr(m)) + + #Issue23078 + def test_create_autospec_classmethod_and_staticmethod(self): + class TestClass: + @classmethod + def class_method(cls): pass + + @staticmethod + def static_method(): pass + for method in ('class_method', 'static_method'): + with self.subTest(method=method): + mock_method = mock.create_autospec(getattr(TestClass, method)) + mock_method() + mock_method.assert_called_once_with() + self.assertRaises(TypeError, mock_method, 'extra_arg') + + #Issue21238 + def test_mock_unsafe(self): + m = Mock() + msg = "is not a valid assertion. Use a spec for the mock" + with self.assertRaisesRegex(AttributeError, msg): + m.assert_foo_call() + with self.assertRaisesRegex(AttributeError, msg): + m.assret_foo_call() + with self.assertRaisesRegex(AttributeError, msg): + m.asert_foo_call() + with self.assertRaisesRegex(AttributeError, msg): + m.aseert_foo_call() + with self.assertRaisesRegex(AttributeError, msg): + m.assrt_foo_call() + m = Mock(unsafe=True) + m.assert_foo_call() + m.assret_foo_call() + m.asert_foo_call() + m.aseert_foo_call() + m.assrt_foo_call() + + #Issue21262 + def test_assert_not_called(self): + m = Mock() + m.hello.assert_not_called() + m.hello() + with self.assertRaises(AssertionError): + m.hello.assert_not_called() + + def test_assert_not_called_message(self): + m = Mock() + m(1, 2) + self.assertRaisesRegex(AssertionError, + re.escape("Calls: [call(1, 2)]"), + m.assert_not_called) + + def test_assert_called(self): + m = Mock() + with self.assertRaises(AssertionError): + m.hello.assert_called() + m.hello() + m.hello.assert_called() + + m.hello() + m.hello.assert_called() + + def test_assert_called_once(self): + m = Mock() + with self.assertRaises(AssertionError): + m.hello.assert_called_once() + m.hello() + m.hello.assert_called_once() + + m.hello() + with self.assertRaises(AssertionError): + m.hello.assert_called_once() + + def test_assert_called_once_message(self): + m = Mock() + m(1, 2) + m(3) + self.assertRaisesRegex(AssertionError, + re.escape("Calls: [call(1, 2), call(3)]"), + m.assert_called_once) + + def test_assert_called_once_message_not_called(self): + m = Mock() + with self.assertRaises(AssertionError) as e: + m.assert_called_once() + self.assertNotIn("Calls:", str(e.exception)) + + #Issue37212 printout of keyword args now preserves the original order + def test_ordered_call_signature(self): + m = Mock() + m.hello(name='hello', daddy='hero') + text = "call(name='hello', daddy='hero')" + self.assertEqual(repr(m.hello.call_args), text) + + #Issue21270 overrides tuple methods for mock.call objects + def test_override_tuple_methods(self): + c = call.count() + i = call.index(132,'hello') + m = Mock() + m.count() + m.index(132,"hello") + self.assertEqual(m.method_calls[0], c) + self.assertEqual(m.method_calls[1], i) + + def test_reset_return_sideeffect(self): + m = Mock(return_value=10, side_effect=[2,3]) + m.reset_mock(return_value=True, side_effect=True) + self.assertIsInstance(m.return_value, Mock) + self.assertEqual(m.side_effect, None) + + def test_reset_return(self): + m = Mock(return_value=10, side_effect=[2,3]) + m.reset_mock(return_value=True) + self.assertIsInstance(m.return_value, Mock) + self.assertNotEqual(m.side_effect, None) + + def test_reset_sideeffect(self): + m = Mock(return_value=10, side_effect=[2, 3]) + m.reset_mock(side_effect=True) + self.assertEqual(m.return_value, 10) + self.assertEqual(m.side_effect, None) + + def test_reset_return_with_children(self): + m = MagicMock(f=MagicMock(return_value=1)) + self.assertEqual(m.f(), 1) + m.reset_mock(return_value=True) + self.assertNotEqual(m.f(), 1) + + def test_reset_return_with_children_side_effect(self): + m = MagicMock(f=MagicMock(side_effect=[2, 3])) + self.assertNotEqual(m.f.side_effect, None) + m.reset_mock(side_effect=True) + self.assertEqual(m.f.side_effect, None) + + def test_mock_add_spec(self): + class _One(object): + one = 1 + class _Two(object): + two = 2 + class Anything(object): + one = two = three = 'four' + + klasses = [ + Mock, MagicMock, NonCallableMock, NonCallableMagicMock + ] + for Klass in list(klasses): + klasses.append(lambda K=Klass: K(spec=Anything)) + klasses.append(lambda K=Klass: K(spec_set=Anything)) + + for Klass in klasses: + for kwargs in dict(), dict(spec_set=True): + mock = Klass() + #no error + mock.one, mock.two, mock.three + + for One, Two in [(_One, _Two), (['one'], ['two'])]: + for kwargs in dict(), dict(spec_set=True): + mock.mock_add_spec(One, **kwargs) + + mock.one + self.assertRaises( + AttributeError, getattr, mock, 'two' + ) + self.assertRaises( + AttributeError, getattr, mock, 'three' + ) + if 'spec_set' in kwargs: + self.assertRaises( + AttributeError, setattr, mock, 'three', None + ) + + mock.mock_add_spec(Two, **kwargs) + self.assertRaises( + AttributeError, getattr, mock, 'one' + ) + mock.two + self.assertRaises( + AttributeError, getattr, mock, 'three' + ) + if 'spec_set' in kwargs: + self.assertRaises( + AttributeError, setattr, mock, 'three', None + ) + # note that creating a mock, setting an instance attribute, and + # *then* setting a spec doesn't work. Not the intended use case + + + def test_mock_add_spec_magic_methods(self): + for Klass in MagicMock, NonCallableMagicMock: + mock = Klass() + int(mock) + + mock.mock_add_spec(object) + self.assertRaises(TypeError, int, mock) + + mock = Klass() + mock['foo'] + mock.__int__.return_value =4 + + mock.mock_add_spec(int) + self.assertEqual(int(mock), 4) + self.assertRaises(TypeError, lambda: mock['foo']) + + + def test_adding_child_mock(self): + for Klass in (NonCallableMock, Mock, MagicMock, NonCallableMagicMock, + AsyncMock): + mock = Klass() + + mock.foo = Mock() + mock.foo() + + self.assertEqual(mock.method_calls, [call.foo()]) + self.assertEqual(mock.mock_calls, [call.foo()]) + + mock = Klass() + mock.bar = Mock(name='name') + mock.bar() + self.assertEqual(mock.method_calls, []) + self.assertEqual(mock.mock_calls, []) + + # mock with an existing _new_parent but no name + mock = Klass() + mock.baz = MagicMock()() + mock.baz() + self.assertEqual(mock.method_calls, []) + self.assertEqual(mock.mock_calls, []) + + + def test_adding_return_value_mock(self): + for Klass in Mock, MagicMock: + mock = Klass() + mock.return_value = MagicMock() + + mock()() + self.assertEqual(mock.mock_calls, [call(), call()()]) + + + def test_manager_mock(self): + class Foo(object): + one = 'one' + two = 'two' + manager = Mock() + p1 = patch.object(Foo, 'one') + p2 = patch.object(Foo, 'two') + + mock_one = p1.start() + self.addCleanup(p1.stop) + mock_two = p2.start() + self.addCleanup(p2.stop) + + manager.attach_mock(mock_one, 'one') + manager.attach_mock(mock_two, 'two') + + Foo.two() + Foo.one() + + self.assertEqual(manager.mock_calls, [call.two(), call.one()]) + + + def test_magic_methods_mock_calls(self): + for Klass in Mock, MagicMock: + m = Klass() + m.__int__ = Mock(return_value=3) + m.__float__ = MagicMock(return_value=3.0) + int(m) + float(m) + + self.assertEqual(m.mock_calls, [call.__int__(), call.__float__()]) + self.assertEqual(m.method_calls, []) + + def test_mock_open_reuse_issue_21750(self): + mocked_open = mock.mock_open(read_data='data') + f1 = mocked_open('a-name') + f1_data = f1.read() + f2 = mocked_open('another-name') + f2_data = f2.read() + self.assertEqual(f1_data, f2_data) + + def test_mock_open_dunder_iter_issue(self): + # Test dunder_iter method generates the expected result and + # consumes the iterator. + mocked_open = mock.mock_open(read_data='Remarkable\nNorwegian Blue') + f1 = mocked_open('a-name') + lines = [line for line in f1] + self.assertEqual(lines[0], 'Remarkable\n') + self.assertEqual(lines[1], 'Norwegian Blue') + self.assertEqual(list(f1), []) + + def test_mock_open_using_next(self): + mocked_open = mock.mock_open(read_data='1st line\n2nd line\n3rd line') + f1 = mocked_open('a-name') + line1 = next(f1) + line2 = f1.__next__() + lines = [line for line in f1] + self.assertEqual(line1, '1st line\n') + self.assertEqual(line2, '2nd line\n') + self.assertEqual(lines[0], '3rd line') + self.assertEqual(list(f1), []) + with self.assertRaises(StopIteration): + next(f1) + + def test_mock_open_next_with_readline_with_return_value(self): + mopen = mock.mock_open(read_data='foo\nbarn') + mopen.return_value.readline.return_value = 'abc' + self.assertEqual('abc', next(mopen())) + + def test_mock_open_write(self): + # Test exception in file writing write() + mock_namedtemp = mock.mock_open(mock.MagicMock(name='JLV')) + with mock.patch('tempfile.NamedTemporaryFile', mock_namedtemp): + mock_filehandle = mock_namedtemp.return_value + mock_write = mock_filehandle.write + mock_write.side_effect = OSError('Test 2 Error') + def attempt(): + tempfile.NamedTemporaryFile().write('asd') + self.assertRaises(OSError, attempt) + + def test_mock_open_alter_readline(self): + mopen = mock.mock_open(read_data='foo\nbarn') + mopen.return_value.readline.side_effect = lambda *args:'abc' + first = mopen().readline() + second = mopen().readline() + self.assertEqual('abc', first) + self.assertEqual('abc', second) + + def test_mock_open_after_eof(self): + # read, readline and readlines should work after end of file. + _open = mock.mock_open(read_data='foo') + h = _open('bar') + h.read() + self.assertEqual('', h.read()) + self.assertEqual('', h.read()) + self.assertEqual('', h.readline()) + self.assertEqual('', h.readline()) + self.assertEqual([], h.readlines()) + self.assertEqual([], h.readlines()) + + def test_mock_parents(self): + for Klass in Mock, MagicMock: + m = Klass() + original_repr = repr(m) + m.return_value = m + self.assertIs(m(), m) + self.assertEqual(repr(m), original_repr) + + m.reset_mock() + self.assertIs(m(), m) + self.assertEqual(repr(m), original_repr) + + m = Klass() + m.b = m.a + self.assertIn("name='mock.a'", repr(m.b)) + self.assertIn("name='mock.a'", repr(m.a)) + m.reset_mock() + self.assertIn("name='mock.a'", repr(m.b)) + self.assertIn("name='mock.a'", repr(m.a)) + + m = Klass() + original_repr = repr(m) + m.a = m() + m.a.return_value = m + + self.assertEqual(repr(m), original_repr) + self.assertEqual(repr(m.a()), original_repr) + + + def test_attach_mock(self): + classes = Mock, MagicMock, NonCallableMagicMock, NonCallableMock + for Klass in classes: + for Klass2 in classes: + m = Klass() + + m2 = Klass2(name='foo') + m.attach_mock(m2, 'bar') + + self.assertIs(m.bar, m2) + self.assertIn("name='mock.bar'", repr(m2)) + + m.bar.baz(1) + self.assertEqual(m.mock_calls, [call.bar.baz(1)]) + self.assertEqual(m.method_calls, [call.bar.baz(1)]) + + + def test_attach_mock_return_value(self): + classes = Mock, MagicMock, NonCallableMagicMock, NonCallableMock + for Klass in Mock, MagicMock: + for Klass2 in classes: + m = Klass() + + m2 = Klass2(name='foo') + m.attach_mock(m2, 'return_value') + + self.assertIs(m(), m2) + self.assertIn("name='mock()'", repr(m2)) + + m2.foo() + self.assertEqual(m.mock_calls, call().foo().call_list()) + + + def test_attach_mock_patch_autospec(self): + parent = Mock() + + with mock.patch(f'{__name__}.something', autospec=True) as mock_func: + self.assertEqual(mock_func.mock._extract_mock_name(), 'something') + parent.attach_mock(mock_func, 'child') + parent.child(1) + something(2) + mock_func(3) + + parent_calls = [call.child(1), call.child(2), call.child(3)] + child_calls = [call(1), call(2), call(3)] + self.assertEqual(parent.mock_calls, parent_calls) + self.assertEqual(parent.child.mock_calls, child_calls) + self.assertEqual(something.mock_calls, child_calls) + self.assertEqual(mock_func.mock_calls, child_calls) + self.assertIn('mock.child', repr(parent.child.mock)) + self.assertEqual(mock_func.mock._extract_mock_name(), 'mock.child') + + + def test_attach_mock_patch_autospec_signature(self): + with mock.patch(f'{__name__}.Something.meth', autospec=True) as mocked: + manager = Mock() + manager.attach_mock(mocked, 'attach_meth') + obj = Something() + obj.meth(1, 2, 3, d=4) + manager.assert_has_calls([call.attach_meth(mock.ANY, 1, 2, 3, d=4)]) + obj.meth.assert_has_calls([call(mock.ANY, 1, 2, 3, d=4)]) + mocked.assert_has_calls([call(mock.ANY, 1, 2, 3, d=4)]) + + with mock.patch(f'{__name__}.something', autospec=True) as mocked: + manager = Mock() + manager.attach_mock(mocked, 'attach_func') + something(1) + manager.assert_has_calls([call.attach_func(1)]) + something.assert_has_calls([call(1)]) + mocked.assert_has_calls([call(1)]) + + with mock.patch(f'{__name__}.Something', autospec=True) as mocked: + manager = Mock() + manager.attach_mock(mocked, 'attach_obj') + obj = Something() + obj.meth(1, 2, 3, d=4) + manager.assert_has_calls([call.attach_obj(), + call.attach_obj().meth(1, 2, 3, d=4)]) + obj.meth.assert_has_calls([call(1, 2, 3, d=4)]) + mocked.assert_has_calls([call(), call().meth(1, 2, 3, d=4)]) + + + def test_attribute_deletion(self): + for mock in (Mock(), MagicMock(), NonCallableMagicMock(), + NonCallableMock()): + self.assertTrue(hasattr(mock, 'm')) + + del mock.m + self.assertFalse(hasattr(mock, 'm')) + + del mock.f + self.assertFalse(hasattr(mock, 'f')) + self.assertRaises(AttributeError, getattr, mock, 'f') + + + def test_mock_does_not_raise_on_repeated_attribute_deletion(self): + # bpo-20239: Assigning and deleting twice an attribute raises. + for mock in (Mock(), MagicMock(), NonCallableMagicMock(), + NonCallableMock()): + mock.foo = 3 + self.assertTrue(hasattr(mock, 'foo')) + self.assertEqual(mock.foo, 3) + + del mock.foo + self.assertFalse(hasattr(mock, 'foo')) + + mock.foo = 4 + self.assertTrue(hasattr(mock, 'foo')) + self.assertEqual(mock.foo, 4) + + del mock.foo + self.assertFalse(hasattr(mock, 'foo')) + + + def test_mock_raises_when_deleting_nonexistent_attribute(self): + for mock in (Mock(), MagicMock(), NonCallableMagicMock(), + NonCallableMock()): + del mock.foo + with self.assertRaises(AttributeError): + del mock.foo + + + def test_reset_mock_does_not_raise_on_attr_deletion(self): + # bpo-31177: reset_mock should not raise AttributeError when attributes + # were deleted in a mock instance + mock = Mock() + mock.child = True + del mock.child + mock.reset_mock() + self.assertFalse(hasattr(mock, 'child')) + + + def test_class_assignable(self): + for mock in Mock(), MagicMock(): + self.assertNotIsInstance(mock, int) + + mock.__class__ = int + self.assertIsInstance(mock, int) + mock.foo + + def test_name_attribute_of_call(self): + # bpo-35357: _Call should not disclose any attributes whose names + # may clash with popular ones (such as ".name") + self.assertIsNotNone(call.name) + self.assertEqual(type(call.name), _Call) + self.assertEqual(type(call.name().name), _Call) + + def test_parent_attribute_of_call(self): + # bpo-35357: _Call should not disclose any attributes whose names + # may clash with popular ones (such as ".parent") + self.assertIsNotNone(call.parent) + self.assertEqual(type(call.parent), _Call) + self.assertEqual(type(call.parent().parent), _Call) + + + def test_parent_propagation_with_create_autospec(self): + + def foo(a, b): pass + + mock = Mock() + mock.child = create_autospec(foo) + mock.child(1, 2) + + self.assertRaises(TypeError, mock.child, 1) + self.assertEqual(mock.mock_calls, [call.child(1, 2)]) + self.assertIn('mock.child', repr(mock.child.mock)) + + def test_parent_propagation_with_autospec_attach_mock(self): + + def foo(a, b): pass + + parent = Mock() + parent.attach_mock(create_autospec(foo, name='bar'), 'child') + parent.child(1, 2) + + self.assertRaises(TypeError, parent.child, 1) + self.assertEqual(parent.child.mock_calls, [call.child(1, 2)]) + self.assertIn('mock.child', repr(parent.child.mock)) + + + def test_isinstance_under_settrace(self): + # bpo-36593 : __class__ is not set for a class that has __class__ + # property defined when it's used with sys.settrace(trace) set. + # Delete the module to force reimport with tracing function set + # restore the old reference later since there are other tests that are + # dependent on unittest.mock.patch. In testpatch.PatchTest + # test_patch_dict_test_prefix and test_patch_test_prefix not restoring + # causes the objects patched to go out of sync + + old_patch = unittest.mock.patch + + # Directly using __setattr__ on unittest.mock causes current imported + # reference to be updated. Use a lambda so that during cleanup the + # re-imported new reference is updated. + self.addCleanup(lambda patch: setattr(unittest.mock, 'patch', patch), + old_patch) + + with patch.dict('sys.modules'): + del sys.modules['unittest.mock'] + + # This trace will stop coverage being measured ;-) + def trace(frame, event, arg): # pragma: no cover + return trace + + self.addCleanup(sys.settrace, sys.gettrace()) + sys.settrace(trace) + + from unittest.mock import ( + Mock, MagicMock, NonCallableMock, NonCallableMagicMock + ) + + mocks = [ + Mock, MagicMock, NonCallableMock, NonCallableMagicMock, AsyncMock + ] + + for mock in mocks: + obj = mock(spec=Something) + self.assertIsInstance(obj, Something) + + def test_bool_not_called_when_passing_spec_arg(self): + class Something: + def __init__(self): + self.obj_with_bool_func = unittest.mock.MagicMock() + + obj = Something() + with unittest.mock.patch.object(obj, 'obj_with_bool_func', spec=object): pass + + self.assertEqual(obj.obj_with_bool_func.__bool__.call_count, 0) + + def test_misspelled_arguments(self): + class Foo(): + one = 'one' + # patch, patch.object and create_autospec need to check for misspelled + # arguments explicitly and throw a RuntimError if found. + with self.assertRaises(RuntimeError): + with patch(f'{__name__}.Something.meth', autospect=True): pass + with self.assertRaises(RuntimeError): + with patch.object(Foo, 'one', autospect=True): pass + with self.assertRaises(RuntimeError): + with patch(f'{__name__}.Something.meth', auto_spec=True): pass + with self.assertRaises(RuntimeError): + with patch.object(Foo, 'one', auto_spec=True): pass + with self.assertRaises(RuntimeError): + with patch(f'{__name__}.Something.meth', set_spec=True): pass + with self.assertRaises(RuntimeError): + with patch.object(Foo, 'one', set_spec=True): pass + with self.assertRaises(RuntimeError): + m = create_autospec(Foo, set_spec=True) + # patch.multiple, on the other hand, should flag misspelled arguments + # through an AttributeError, when trying to find the keys from kwargs + # as attributes on the target. + with self.assertRaises(AttributeError): + with patch.multiple( + f'{__name__}.Something', meth=DEFAULT, autospect=True): pass + with self.assertRaises(AttributeError): + with patch.multiple( + f'{__name__}.Something', meth=DEFAULT, auto_spec=True): pass + with self.assertRaises(AttributeError): + with patch.multiple( + f'{__name__}.Something', meth=DEFAULT, set_spec=True): pass + + with patch(f'{__name__}.Something.meth', unsafe=True, autospect=True): + pass + with patch.object(Foo, 'one', unsafe=True, autospect=True): pass + with patch(f'{__name__}.Something.meth', unsafe=True, auto_spec=True): + pass + with patch.object(Foo, 'one', unsafe=True, auto_spec=True): pass + with patch(f'{__name__}.Something.meth', unsafe=True, set_spec=True): + pass + with patch.object(Foo, 'one', unsafe=True, set_spec=True): pass + m = create_autospec(Foo, set_spec=True, unsafe=True) + with patch.multiple( + f'{__name__}.Typos', autospect=True, set_spec=True, auto_spec=True): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/testmock/testpatch.py b/Lib/test/test_unittest/testmock/testpatch.py new file mode 100644 index 0000000..93ec0ca --- /dev/null +++ b/Lib/test/test_unittest/testmock/testpatch.py @@ -0,0 +1,1953 @@ +# Copyright (C) 2007-2012 Michael Foord & the mock team +# E-mail: fuzzyman AT voidspace DOT org DOT uk +# http://www.voidspace.org.uk/python/mock/ + +import os +import sys +from collections import OrderedDict + +import unittest +from test.test_unittest.testmock import support +from test.test_unittest.testmock.support import SomeClass, is_instance + +from test.test_importlib.util import uncache +from unittest.mock import ( + NonCallableMock, CallableMixin, sentinel, + MagicMock, Mock, NonCallableMagicMock, patch, _patch, + DEFAULT, call, _get_target +) + + +builtin_string = 'builtins' + +PTModule = sys.modules[__name__] +MODNAME = '%s.PTModule' % __name__ + + +def _get_proxy(obj, get_only=True): + class Proxy(object): + def __getattr__(self, name): + return getattr(obj, name) + if not get_only: + def __setattr__(self, name, value): + setattr(obj, name, value) + def __delattr__(self, name): + delattr(obj, name) + Proxy.__setattr__ = __setattr__ + Proxy.__delattr__ = __delattr__ + return Proxy() + + +# for use in the test +something = sentinel.Something +something_else = sentinel.SomethingElse + + +class Foo(object): + def __init__(self, a): pass + def f(self, a): pass + def g(self): pass + foo = 'bar' + + @staticmethod + def static_method(): pass + + @classmethod + def class_method(cls): pass + + class Bar(object): + def a(self): pass + +foo_name = '%s.Foo' % __name__ + + +def function(a, b=Foo): pass + + +class Container(object): + def __init__(self): + self.values = {} + + def __getitem__(self, name): + return self.values[name] + + def __setitem__(self, name, value): + self.values[name] = value + + def __delitem__(self, name): + del self.values[name] + + def __iter__(self): + return iter(self.values) + + + +class PatchTest(unittest.TestCase): + + def assertNotCallable(self, obj, magic=True): + MockClass = NonCallableMagicMock + if not magic: + MockClass = NonCallableMock + + self.assertRaises(TypeError, obj) + self.assertTrue(is_instance(obj, MockClass)) + self.assertFalse(is_instance(obj, CallableMixin)) + + + def test_single_patchobject(self): + class Something(object): + attribute = sentinel.Original + + @patch.object(Something, 'attribute', sentinel.Patched) + def test(): + self.assertEqual(Something.attribute, sentinel.Patched, "unpatched") + + test() + self.assertEqual(Something.attribute, sentinel.Original, + "patch not restored") + + def test_patchobject_with_string_as_target(self): + msg = "'Something' must be the actual object to be patched, not a str" + with self.assertRaisesRegex(TypeError, msg): + patch.object('Something', 'do_something') + + def test_patchobject_with_none(self): + class Something(object): + attribute = sentinel.Original + + @patch.object(Something, 'attribute', None) + def test(): + self.assertIsNone(Something.attribute, "unpatched") + + test() + self.assertEqual(Something.attribute, sentinel.Original, + "patch not restored") + + + def test_multiple_patchobject(self): + class Something(object): + attribute = sentinel.Original + next_attribute = sentinel.Original2 + + @patch.object(Something, 'attribute', sentinel.Patched) + @patch.object(Something, 'next_attribute', sentinel.Patched2) + def test(): + self.assertEqual(Something.attribute, sentinel.Patched, + "unpatched") + self.assertEqual(Something.next_attribute, sentinel.Patched2, + "unpatched") + + test() + self.assertEqual(Something.attribute, sentinel.Original, + "patch not restored") + self.assertEqual(Something.next_attribute, sentinel.Original2, + "patch not restored") + + + def test_object_lookup_is_quite_lazy(self): + global something + original = something + @patch('%s.something' % __name__, sentinel.Something2) + def test(): + pass + + try: + something = sentinel.replacement_value + test() + self.assertEqual(something, sentinel.replacement_value) + finally: + something = original + + + def test_patch(self): + @patch('%s.something' % __name__, sentinel.Something2) + def test(): + self.assertEqual(PTModule.something, sentinel.Something2, + "unpatched") + + test() + self.assertEqual(PTModule.something, sentinel.Something, + "patch not restored") + + @patch('%s.something' % __name__, sentinel.Something2) + @patch('%s.something_else' % __name__, sentinel.SomethingElse) + def test(): + self.assertEqual(PTModule.something, sentinel.Something2, + "unpatched") + self.assertEqual(PTModule.something_else, sentinel.SomethingElse, + "unpatched") + + self.assertEqual(PTModule.something, sentinel.Something, + "patch not restored") + self.assertEqual(PTModule.something_else, sentinel.SomethingElse, + "patch not restored") + + # Test the patching and restoring works a second time + test() + + self.assertEqual(PTModule.something, sentinel.Something, + "patch not restored") + self.assertEqual(PTModule.something_else, sentinel.SomethingElse, + "patch not restored") + + mock = Mock() + mock.return_value = sentinel.Handle + @patch('%s.open' % builtin_string, mock) + def test(): + self.assertEqual(open('filename', 'r'), sentinel.Handle, + "open not patched") + test() + test() + + self.assertNotEqual(open, mock, "patch not restored") + + + def test_patch_class_attribute(self): + @patch('%s.SomeClass.class_attribute' % __name__, + sentinel.ClassAttribute) + def test(): + self.assertEqual(PTModule.SomeClass.class_attribute, + sentinel.ClassAttribute, "unpatched") + test() + + self.assertIsNone(PTModule.SomeClass.class_attribute, + "patch not restored") + + + def test_patchobject_with_default_mock(self): + class Test(object): + something = sentinel.Original + something2 = sentinel.Original2 + + @patch.object(Test, 'something') + def test(mock): + self.assertEqual(mock, Test.something, + "Mock not passed into test function") + self.assertIsInstance(mock, MagicMock, + "patch with two arguments did not create a mock") + + test() + + @patch.object(Test, 'something') + @patch.object(Test, 'something2') + def test(this1, this2, mock1, mock2): + self.assertEqual(this1, sentinel.this1, + "Patched function didn't receive initial argument") + self.assertEqual(this2, sentinel.this2, + "Patched function didn't receive second argument") + self.assertEqual(mock1, Test.something2, + "Mock not passed into test function") + self.assertEqual(mock2, Test.something, + "Second Mock not passed into test function") + self.assertIsInstance(mock2, MagicMock, + "patch with two arguments did not create a mock") + self.assertIsInstance(mock2, MagicMock, + "patch with two arguments did not create a mock") + + # A hack to test that new mocks are passed the second time + self.assertNotEqual(outerMock1, mock1, "unexpected value for mock1") + self.assertNotEqual(outerMock2, mock2, "unexpected value for mock1") + return mock1, mock2 + + outerMock1 = outerMock2 = None + outerMock1, outerMock2 = test(sentinel.this1, sentinel.this2) + + # Test that executing a second time creates new mocks + test(sentinel.this1, sentinel.this2) + + + def test_patch_with_spec(self): + @patch('%s.SomeClass' % __name__, spec=SomeClass) + def test(MockSomeClass): + self.assertEqual(SomeClass, MockSomeClass) + self.assertTrue(is_instance(SomeClass.wibble, MagicMock)) + self.assertRaises(AttributeError, lambda: SomeClass.not_wibble) + + test() + + + def test_patchobject_with_spec(self): + @patch.object(SomeClass, 'class_attribute', spec=SomeClass) + def test(MockAttribute): + self.assertEqual(SomeClass.class_attribute, MockAttribute) + self.assertTrue(is_instance(SomeClass.class_attribute.wibble, + MagicMock)) + self.assertRaises(AttributeError, + lambda: SomeClass.class_attribute.not_wibble) + + test() + + + def test_patch_with_spec_as_list(self): + @patch('%s.SomeClass' % __name__, spec=['wibble']) + def test(MockSomeClass): + self.assertEqual(SomeClass, MockSomeClass) + self.assertTrue(is_instance(SomeClass.wibble, MagicMock)) + self.assertRaises(AttributeError, lambda: SomeClass.not_wibble) + + test() + + + def test_patchobject_with_spec_as_list(self): + @patch.object(SomeClass, 'class_attribute', spec=['wibble']) + def test(MockAttribute): + self.assertEqual(SomeClass.class_attribute, MockAttribute) + self.assertTrue(is_instance(SomeClass.class_attribute.wibble, + MagicMock)) + self.assertRaises(AttributeError, + lambda: SomeClass.class_attribute.not_wibble) + + test() + + + def test_nested_patch_with_spec_as_list(self): + # regression test for nested decorators + @patch('%s.open' % builtin_string) + @patch('%s.SomeClass' % __name__, spec=['wibble']) + def test(MockSomeClass, MockOpen): + self.assertEqual(SomeClass, MockSomeClass) + self.assertTrue(is_instance(SomeClass.wibble, MagicMock)) + self.assertRaises(AttributeError, lambda: SomeClass.not_wibble) + test() + + + def test_patch_with_spec_as_boolean(self): + @patch('%s.SomeClass' % __name__, spec=True) + def test(MockSomeClass): + self.assertEqual(SomeClass, MockSomeClass) + # Should not raise attribute error + MockSomeClass.wibble + + self.assertRaises(AttributeError, lambda: MockSomeClass.not_wibble) + + test() + + + def test_patch_object_with_spec_as_boolean(self): + @patch.object(PTModule, 'SomeClass', spec=True) + def test(MockSomeClass): + self.assertEqual(SomeClass, MockSomeClass) + # Should not raise attribute error + MockSomeClass.wibble + + self.assertRaises(AttributeError, lambda: MockSomeClass.not_wibble) + + test() + + + def test_patch_class_acts_with_spec_is_inherited(self): + @patch('%s.SomeClass' % __name__, spec=True) + def test(MockSomeClass): + self.assertTrue(is_instance(MockSomeClass, MagicMock)) + instance = MockSomeClass() + self.assertNotCallable(instance) + # Should not raise attribute error + instance.wibble + + self.assertRaises(AttributeError, lambda: instance.not_wibble) + + test() + + + def test_patch_with_create_mocks_non_existent_attributes(self): + @patch('%s.frooble' % builtin_string, sentinel.Frooble, create=True) + def test(): + self.assertEqual(frooble, sentinel.Frooble) + + test() + self.assertRaises(NameError, lambda: frooble) + + + def test_patchobject_with_create_mocks_non_existent_attributes(self): + @patch.object(SomeClass, 'frooble', sentinel.Frooble, create=True) + def test(): + self.assertEqual(SomeClass.frooble, sentinel.Frooble) + + test() + self.assertFalse(hasattr(SomeClass, 'frooble')) + + + def test_patch_wont_create_by_default(self): + with self.assertRaises(AttributeError): + @patch('%s.frooble' % builtin_string, sentinel.Frooble) + def test(): pass + + test() + self.assertRaises(NameError, lambda: frooble) + + + def test_patchobject_wont_create_by_default(self): + with self.assertRaises(AttributeError): + @patch.object(SomeClass, 'ord', sentinel.Frooble) + def test(): pass + test() + self.assertFalse(hasattr(SomeClass, 'ord')) + + + def test_patch_builtins_without_create(self): + @patch(__name__+'.ord') + def test_ord(mock_ord): + mock_ord.return_value = 101 + return ord('c') + + @patch(__name__+'.open') + def test_open(mock_open): + m = mock_open.return_value + m.read.return_value = 'abcd' + + fobj = open('doesnotexists.txt') + data = fobj.read() + fobj.close() + return data + + self.assertEqual(test_ord(), 101) + self.assertEqual(test_open(), 'abcd') + + + def test_patch_with_static_methods(self): + class Foo(object): + @staticmethod + def woot(): + return sentinel.Static + + @patch.object(Foo, 'woot', staticmethod(lambda: sentinel.Patched)) + def anonymous(): + self.assertEqual(Foo.woot(), sentinel.Patched) + anonymous() + + self.assertEqual(Foo.woot(), sentinel.Static) + + + def test_patch_local(self): + foo = sentinel.Foo + @patch.object(sentinel, 'Foo', 'Foo') + def anonymous(): + self.assertEqual(sentinel.Foo, 'Foo') + anonymous() + + self.assertEqual(sentinel.Foo, foo) + + + def test_patch_slots(self): + class Foo(object): + __slots__ = ('Foo',) + + foo = Foo() + foo.Foo = sentinel.Foo + + @patch.object(foo, 'Foo', 'Foo') + def anonymous(): + self.assertEqual(foo.Foo, 'Foo') + anonymous() + + self.assertEqual(foo.Foo, sentinel.Foo) + + + def test_patchobject_class_decorator(self): + class Something(object): + attribute = sentinel.Original + + class Foo(object): + def test_method(other_self): + self.assertEqual(Something.attribute, sentinel.Patched, + "unpatched") + def not_test_method(other_self): + self.assertEqual(Something.attribute, sentinel.Original, + "non-test method patched") + + Foo = patch.object(Something, 'attribute', sentinel.Patched)(Foo) + + f = Foo() + f.test_method() + f.not_test_method() + + self.assertEqual(Something.attribute, sentinel.Original, + "patch not restored") + + + def test_patch_class_decorator(self): + class Something(object): + attribute = sentinel.Original + + class Foo(object): + + test_class_attr = 'whatever' + + def test_method(other_self, mock_something): + self.assertEqual(PTModule.something, mock_something, + "unpatched") + def not_test_method(other_self): + self.assertEqual(PTModule.something, sentinel.Something, + "non-test method patched") + Foo = patch('%s.something' % __name__)(Foo) + + f = Foo() + f.test_method() + f.not_test_method() + + self.assertEqual(Something.attribute, sentinel.Original, + "patch not restored") + self.assertEqual(PTModule.something, sentinel.Something, + "patch not restored") + + + def test_patchobject_twice(self): + class Something(object): + attribute = sentinel.Original + next_attribute = sentinel.Original2 + + @patch.object(Something, 'attribute', sentinel.Patched) + @patch.object(Something, 'attribute', sentinel.Patched) + def test(): + self.assertEqual(Something.attribute, sentinel.Patched, "unpatched") + + test() + + self.assertEqual(Something.attribute, sentinel.Original, + "patch not restored") + + + def test_patch_dict(self): + foo = {'initial': object(), 'other': 'something'} + original = foo.copy() + + @patch.dict(foo) + def test(): + foo['a'] = 3 + del foo['initial'] + foo['other'] = 'something else' + + test() + + self.assertEqual(foo, original) + + @patch.dict(foo, {'a': 'b'}) + def test(): + self.assertEqual(len(foo), 3) + self.assertEqual(foo['a'], 'b') + + test() + + self.assertEqual(foo, original) + + @patch.dict(foo, [('a', 'b')]) + def test(): + self.assertEqual(len(foo), 3) + self.assertEqual(foo['a'], 'b') + + test() + + self.assertEqual(foo, original) + + + def test_patch_dict_with_container_object(self): + foo = Container() + foo['initial'] = object() + foo['other'] = 'something' + + original = foo.values.copy() + + @patch.dict(foo) + def test(): + foo['a'] = 3 + del foo['initial'] + foo['other'] = 'something else' + + test() + + self.assertEqual(foo.values, original) + + @patch.dict(foo, {'a': 'b'}) + def test(): + self.assertEqual(len(foo.values), 3) + self.assertEqual(foo['a'], 'b') + + test() + + self.assertEqual(foo.values, original) + + + def test_patch_dict_with_clear(self): + foo = {'initial': object(), 'other': 'something'} + original = foo.copy() + + @patch.dict(foo, clear=True) + def test(): + self.assertEqual(foo, {}) + foo['a'] = 3 + foo['other'] = 'something else' + + test() + + self.assertEqual(foo, original) + + @patch.dict(foo, {'a': 'b'}, clear=True) + def test(): + self.assertEqual(foo, {'a': 'b'}) + + test() + + self.assertEqual(foo, original) + + @patch.dict(foo, [('a', 'b')], clear=True) + def test(): + self.assertEqual(foo, {'a': 'b'}) + + test() + + self.assertEqual(foo, original) + + + def test_patch_dict_with_container_object_and_clear(self): + foo = Container() + foo['initial'] = object() + foo['other'] = 'something' + + original = foo.values.copy() + + @patch.dict(foo, clear=True) + def test(): + self.assertEqual(foo.values, {}) + foo['a'] = 3 + foo['other'] = 'something else' + + test() + + self.assertEqual(foo.values, original) + + @patch.dict(foo, {'a': 'b'}, clear=True) + def test(): + self.assertEqual(foo.values, {'a': 'b'}) + + test() + + self.assertEqual(foo.values, original) + + + def test_patch_dict_as_context_manager(self): + foo = {'a': 'b'} + with patch.dict(foo, a='c') as patched: + self.assertEqual(patched, {'a': 'c'}) + self.assertEqual(foo, {'a': 'b'}) + + + def test_name_preserved(self): + foo = {} + + @patch('%s.SomeClass' % __name__, object()) + @patch('%s.SomeClass' % __name__, object(), autospec=True) + @patch.object(SomeClass, object()) + @patch.dict(foo) + def some_name(): pass + + self.assertEqual(some_name.__name__, 'some_name') + + + def test_patch_with_exception(self): + foo = {} + + @patch.dict(foo, {'a': 'b'}) + def test(): + raise NameError('Konrad') + + with self.assertRaises(NameError): + test() + + self.assertEqual(foo, {}) + + + def test_patch_dict_with_string(self): + @patch.dict('os.environ', {'konrad_delong': 'some value'}) + def test(): + self.assertIn('konrad_delong', os.environ) + + test() + + + def test_patch_dict_decorator_resolution(self): + # bpo-35512: Ensure that patch with a string target resolves to + # the new dictionary during function call + original = support.target.copy() + + @patch.dict('test.test_unittest.testmock.support.target', {'bar': 'BAR'}) + def test(): + self.assertEqual(support.target, {'foo': 'BAZ', 'bar': 'BAR'}) + + try: + support.target = {'foo': 'BAZ'} + test() + self.assertEqual(support.target, {'foo': 'BAZ'}) + finally: + support.target = original + + + def test_patch_spec_set(self): + @patch('%s.SomeClass' % __name__, spec=SomeClass, spec_set=True) + def test(MockClass): + MockClass.z = 'foo' + + self.assertRaises(AttributeError, test) + + @patch.object(support, 'SomeClass', spec=SomeClass, spec_set=True) + def test(MockClass): + MockClass.z = 'foo' + + self.assertRaises(AttributeError, test) + @patch('%s.SomeClass' % __name__, spec_set=True) + def test(MockClass): + MockClass.z = 'foo' + + self.assertRaises(AttributeError, test) + + @patch.object(support, 'SomeClass', spec_set=True) + def test(MockClass): + MockClass.z = 'foo' + + self.assertRaises(AttributeError, test) + + + def test_spec_set_inherit(self): + @patch('%s.SomeClass' % __name__, spec_set=True) + def test(MockClass): + instance = MockClass() + instance.z = 'foo' + + self.assertRaises(AttributeError, test) + + + def test_patch_start_stop(self): + original = something + patcher = patch('%s.something' % __name__) + self.assertIs(something, original) + mock = patcher.start() + try: + self.assertIsNot(mock, original) + self.assertIs(something, mock) + finally: + patcher.stop() + self.assertIs(something, original) + + + def test_stop_without_start(self): + # bpo-36366: calling stop without start will return None. + patcher = patch(foo_name, 'bar', 3) + self.assertIsNone(patcher.stop()) + + + def test_stop_idempotent(self): + # bpo-36366: calling stop on an already stopped patch will return None. + patcher = patch(foo_name, 'bar', 3) + + patcher.start() + patcher.stop() + self.assertIsNone(patcher.stop()) + + + def test_patchobject_start_stop(self): + original = something + patcher = patch.object(PTModule, 'something', 'foo') + self.assertIs(something, original) + replaced = patcher.start() + try: + self.assertEqual(replaced, 'foo') + self.assertIs(something, replaced) + finally: + patcher.stop() + self.assertIs(something, original) + + + def test_patch_dict_start_stop(self): + d = {'foo': 'bar'} + original = d.copy() + patcher = patch.dict(d, [('spam', 'eggs')], clear=True) + self.assertEqual(d, original) + + patcher.start() + try: + self.assertEqual(d, {'spam': 'eggs'}) + finally: + patcher.stop() + self.assertEqual(d, original) + + + def test_patch_dict_stop_without_start(self): + d = {'foo': 'bar'} + original = d.copy() + patcher = patch.dict(d, [('spam', 'eggs')], clear=True) + self.assertFalse(patcher.stop()) + self.assertEqual(d, original) + + + def test_patch_dict_class_decorator(self): + this = self + d = {'spam': 'eggs'} + original = d.copy() + + class Test(object): + def test_first(self): + this.assertEqual(d, {'foo': 'bar'}) + def test_second(self): + this.assertEqual(d, {'foo': 'bar'}) + + Test = patch.dict(d, {'foo': 'bar'}, clear=True)(Test) + self.assertEqual(d, original) + + test = Test() + + test.test_first() + self.assertEqual(d, original) + + test.test_second() + self.assertEqual(d, original) + + test = Test() + + test.test_first() + self.assertEqual(d, original) + + test.test_second() + self.assertEqual(d, original) + + + def test_get_only_proxy(self): + class Something(object): + foo = 'foo' + class SomethingElse: + foo = 'foo' + + for thing in Something, SomethingElse, Something(), SomethingElse: + proxy = _get_proxy(thing) + + @patch.object(proxy, 'foo', 'bar') + def test(): + self.assertEqual(proxy.foo, 'bar') + test() + self.assertEqual(proxy.foo, 'foo') + self.assertEqual(thing.foo, 'foo') + self.assertNotIn('foo', proxy.__dict__) + + + def test_get_set_delete_proxy(self): + class Something(object): + foo = 'foo' + class SomethingElse: + foo = 'foo' + + for thing in Something, SomethingElse, Something(), SomethingElse: + proxy = _get_proxy(Something, get_only=False) + + @patch.object(proxy, 'foo', 'bar') + def test(): + self.assertEqual(proxy.foo, 'bar') + test() + self.assertEqual(proxy.foo, 'foo') + self.assertEqual(thing.foo, 'foo') + self.assertNotIn('foo', proxy.__dict__) + + + def test_patch_keyword_args(self): + kwargs = {'side_effect': KeyError, 'foo.bar.return_value': 33, + 'foo': MagicMock()} + + patcher = patch(foo_name, **kwargs) + mock = patcher.start() + patcher.stop() + + self.assertRaises(KeyError, mock) + self.assertEqual(mock.foo.bar(), 33) + self.assertIsInstance(mock.foo, MagicMock) + + + def test_patch_object_keyword_args(self): + kwargs = {'side_effect': KeyError, 'foo.bar.return_value': 33, + 'foo': MagicMock()} + + patcher = patch.object(Foo, 'f', **kwargs) + mock = patcher.start() + patcher.stop() + + self.assertRaises(KeyError, mock) + self.assertEqual(mock.foo.bar(), 33) + self.assertIsInstance(mock.foo, MagicMock) + + + def test_patch_dict_keyword_args(self): + original = {'foo': 'bar'} + copy = original.copy() + + patcher = patch.dict(original, foo=3, bar=4, baz=5) + patcher.start() + + try: + self.assertEqual(original, dict(foo=3, bar=4, baz=5)) + finally: + patcher.stop() + + self.assertEqual(original, copy) + + + def test_autospec(self): + class Boo(object): + def __init__(self, a): pass + def f(self, a): pass + def g(self): pass + foo = 'bar' + + class Bar(object): + def a(self): pass + + def _test(mock): + mock(1) + mock.assert_called_with(1) + self.assertRaises(TypeError, mock) + + def _test2(mock): + mock.f(1) + mock.f.assert_called_with(1) + self.assertRaises(TypeError, mock.f) + + mock.g() + mock.g.assert_called_with() + self.assertRaises(TypeError, mock.g, 1) + + self.assertRaises(AttributeError, getattr, mock, 'h') + + mock.foo.lower() + mock.foo.lower.assert_called_with() + self.assertRaises(AttributeError, getattr, mock.foo, 'bar') + + mock.Bar() + mock.Bar.assert_called_with() + + mock.Bar.a() + mock.Bar.a.assert_called_with() + self.assertRaises(TypeError, mock.Bar.a, 1) + + mock.Bar().a() + mock.Bar().a.assert_called_with() + self.assertRaises(TypeError, mock.Bar().a, 1) + + self.assertRaises(AttributeError, getattr, mock.Bar, 'b') + self.assertRaises(AttributeError, getattr, mock.Bar(), 'b') + + def function(mock): + _test(mock) + _test2(mock) + _test2(mock(1)) + self.assertIs(mock, Foo) + return mock + + test = patch(foo_name, autospec=True)(function) + + mock = test() + self.assertIsNot(Foo, mock) + # test patching a second time works + test() + + module = sys.modules[__name__] + test = patch.object(module, 'Foo', autospec=True)(function) + + mock = test() + self.assertIsNot(Foo, mock) + # test patching a second time works + test() + + + def test_autospec_function(self): + @patch('%s.function' % __name__, autospec=True) + def test(mock): + function.assert_not_called() + self.assertRaises(AssertionError, function.assert_called) + self.assertRaises(AssertionError, function.assert_called_once) + function(1) + self.assertRaises(AssertionError, function.assert_not_called) + function.assert_called_with(1) + function.assert_called() + function.assert_called_once() + function(2, 3) + function.assert_called_with(2, 3) + + self.assertRaises(TypeError, function) + self.assertRaises(AttributeError, getattr, function, 'foo') + + test() + + + def test_autospec_keywords(self): + @patch('%s.function' % __name__, autospec=True, + return_value=3) + def test(mock_function): + #self.assertEqual(function.abc, 'foo') + return function(1, 2) + + result = test() + self.assertEqual(result, 3) + + + def test_autospec_staticmethod(self): + with patch('%s.Foo.static_method' % __name__, autospec=True) as method: + Foo.static_method() + method.assert_called_once_with() + + + def test_autospec_classmethod(self): + with patch('%s.Foo.class_method' % __name__, autospec=True) as method: + Foo.class_method() + method.assert_called_once_with() + + + def test_autospec_with_new(self): + patcher = patch('%s.function' % __name__, new=3, autospec=True) + self.assertRaises(TypeError, patcher.start) + + module = sys.modules[__name__] + patcher = patch.object(module, 'function', new=3, autospec=True) + self.assertRaises(TypeError, patcher.start) + + + def test_autospec_with_object(self): + class Bar(Foo): + extra = [] + + patcher = patch(foo_name, autospec=Bar) + mock = patcher.start() + try: + self.assertIsInstance(mock, Bar) + self.assertIsInstance(mock.extra, list) + finally: + patcher.stop() + + + def test_autospec_inherits(self): + FooClass = Foo + patcher = patch(foo_name, autospec=True) + mock = patcher.start() + try: + self.assertIsInstance(mock, FooClass) + self.assertIsInstance(mock(3), FooClass) + finally: + patcher.stop() + + + def test_autospec_name(self): + patcher = patch(foo_name, autospec=True) + mock = patcher.start() + + try: + self.assertIn(" name='Foo'", repr(mock)) + self.assertIn(" name='Foo.f'", repr(mock.f)) + self.assertIn(" name='Foo()'", repr(mock(None))) + self.assertIn(" name='Foo().f'", repr(mock(None).f)) + finally: + patcher.stop() + + + def test_tracebacks(self): + @patch.object(Foo, 'f', object()) + def test(): + raise AssertionError + try: + test() + except: + err = sys.exc_info() + + result = unittest.TextTestResult(None, None, 0) + traceback = result._exc_info_to_string(err, self) + self.assertIn('raise AssertionError', traceback) + + + def test_new_callable_patch(self): + patcher = patch(foo_name, new_callable=NonCallableMagicMock) + + m1 = patcher.start() + patcher.stop() + m2 = patcher.start() + patcher.stop() + + self.assertIsNot(m1, m2) + for mock in m1, m2: + self.assertNotCallable(m1) + + + def test_new_callable_patch_object(self): + patcher = patch.object(Foo, 'f', new_callable=NonCallableMagicMock) + + m1 = patcher.start() + patcher.stop() + m2 = patcher.start() + patcher.stop() + + self.assertIsNot(m1, m2) + for mock in m1, m2: + self.assertNotCallable(m1) + + + def test_new_callable_keyword_arguments(self): + class Bar(object): + kwargs = None + def __init__(self, **kwargs): + Bar.kwargs = kwargs + + patcher = patch(foo_name, new_callable=Bar, arg1=1, arg2=2) + m = patcher.start() + try: + self.assertIs(type(m), Bar) + self.assertEqual(Bar.kwargs, dict(arg1=1, arg2=2)) + finally: + patcher.stop() + + + def test_new_callable_spec(self): + class Bar(object): + kwargs = None + def __init__(self, **kwargs): + Bar.kwargs = kwargs + + patcher = patch(foo_name, new_callable=Bar, spec=Bar) + patcher.start() + try: + self.assertEqual(Bar.kwargs, dict(spec=Bar)) + finally: + patcher.stop() + + patcher = patch(foo_name, new_callable=Bar, spec_set=Bar) + patcher.start() + try: + self.assertEqual(Bar.kwargs, dict(spec_set=Bar)) + finally: + patcher.stop() + + + def test_new_callable_create(self): + non_existent_attr = '%s.weeeee' % foo_name + p = patch(non_existent_attr, new_callable=NonCallableMock) + self.assertRaises(AttributeError, p.start) + + p = patch(non_existent_attr, new_callable=NonCallableMock, + create=True) + m = p.start() + try: + self.assertNotCallable(m, magic=False) + finally: + p.stop() + + + def test_new_callable_incompatible_with_new(self): + self.assertRaises( + ValueError, patch, foo_name, new=object(), new_callable=MagicMock + ) + self.assertRaises( + ValueError, patch.object, Foo, 'f', new=object(), + new_callable=MagicMock + ) + + + def test_new_callable_incompatible_with_autospec(self): + self.assertRaises( + ValueError, patch, foo_name, new_callable=MagicMock, + autospec=True + ) + self.assertRaises( + ValueError, patch.object, Foo, 'f', new_callable=MagicMock, + autospec=True + ) + + + def test_new_callable_inherit_for_mocks(self): + class MockSub(Mock): + pass + + MockClasses = ( + NonCallableMock, NonCallableMagicMock, MagicMock, Mock, MockSub + ) + for Klass in MockClasses: + for arg in 'spec', 'spec_set': + kwargs = {arg: True} + p = patch(foo_name, new_callable=Klass, **kwargs) + m = p.start() + try: + instance = m.return_value + self.assertRaises(AttributeError, getattr, instance, 'x') + finally: + p.stop() + + + def test_new_callable_inherit_non_mock(self): + class NotAMock(object): + def __init__(self, spec): + self.spec = spec + + p = patch(foo_name, new_callable=NotAMock, spec=True) + m = p.start() + try: + self.assertTrue(is_instance(m, NotAMock)) + self.assertRaises(AttributeError, getattr, m, 'return_value') + finally: + p.stop() + + self.assertEqual(m.spec, Foo) + + + def test_new_callable_class_decorating(self): + test = self + original = Foo + class SomeTest(object): + + def _test(self, mock_foo): + test.assertIsNot(Foo, original) + test.assertIs(Foo, mock_foo) + test.assertIsInstance(Foo, SomeClass) + + def test_two(self, mock_foo): + self._test(mock_foo) + def test_one(self, mock_foo): + self._test(mock_foo) + + SomeTest = patch(foo_name, new_callable=SomeClass)(SomeTest) + SomeTest().test_one() + SomeTest().test_two() + self.assertIs(Foo, original) + + + def test_patch_multiple(self): + original_foo = Foo + original_f = Foo.f + original_g = Foo.g + + patcher1 = patch.multiple(foo_name, f=1, g=2) + patcher2 = patch.multiple(Foo, f=1, g=2) + + for patcher in patcher1, patcher2: + patcher.start() + try: + self.assertIs(Foo, original_foo) + self.assertEqual(Foo.f, 1) + self.assertEqual(Foo.g, 2) + finally: + patcher.stop() + + self.assertIs(Foo, original_foo) + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + @patch.multiple(foo_name, f=3, g=4) + def test(): + self.assertIs(Foo, original_foo) + self.assertEqual(Foo.f, 3) + self.assertEqual(Foo.g, 4) + + test() + + + def test_patch_multiple_no_kwargs(self): + self.assertRaises(ValueError, patch.multiple, foo_name) + self.assertRaises(ValueError, patch.multiple, Foo) + + + def test_patch_multiple_create_mocks(self): + original_foo = Foo + original_f = Foo.f + original_g = Foo.g + + @patch.multiple(foo_name, f=DEFAULT, g=3, foo=DEFAULT) + def test(f, foo): + self.assertIs(Foo, original_foo) + self.assertIs(Foo.f, f) + self.assertEqual(Foo.g, 3) + self.assertIs(Foo.foo, foo) + self.assertTrue(is_instance(f, MagicMock)) + self.assertTrue(is_instance(foo, MagicMock)) + + test() + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + def test_patch_multiple_create_mocks_different_order(self): + original_f = Foo.f + original_g = Foo.g + + patcher = patch.object(Foo, 'f', 3) + patcher.attribute_name = 'f' + + other = patch.object(Foo, 'g', DEFAULT) + other.attribute_name = 'g' + patcher.additional_patchers = [other] + + @patcher + def test(g): + self.assertIs(Foo.g, g) + self.assertEqual(Foo.f, 3) + + test() + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + def test_patch_multiple_stacked_decorators(self): + original_foo = Foo + original_f = Foo.f + original_g = Foo.g + + @patch.multiple(foo_name, f=DEFAULT) + @patch.multiple(foo_name, foo=DEFAULT) + @patch(foo_name + '.g') + def test1(g, **kwargs): + _test(g, **kwargs) + + @patch.multiple(foo_name, f=DEFAULT) + @patch(foo_name + '.g') + @patch.multiple(foo_name, foo=DEFAULT) + def test2(g, **kwargs): + _test(g, **kwargs) + + @patch(foo_name + '.g') + @patch.multiple(foo_name, f=DEFAULT) + @patch.multiple(foo_name, foo=DEFAULT) + def test3(g, **kwargs): + _test(g, **kwargs) + + def _test(g, **kwargs): + f = kwargs.pop('f') + foo = kwargs.pop('foo') + self.assertFalse(kwargs) + + self.assertIs(Foo, original_foo) + self.assertIs(Foo.f, f) + self.assertIs(Foo.g, g) + self.assertIs(Foo.foo, foo) + self.assertTrue(is_instance(f, MagicMock)) + self.assertTrue(is_instance(g, MagicMock)) + self.assertTrue(is_instance(foo, MagicMock)) + + test1() + test2() + test3() + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + def test_patch_multiple_create_mocks_patcher(self): + original_foo = Foo + original_f = Foo.f + original_g = Foo.g + + patcher = patch.multiple(foo_name, f=DEFAULT, g=3, foo=DEFAULT) + + result = patcher.start() + try: + f = result['f'] + foo = result['foo'] + self.assertEqual(set(result), set(['f', 'foo'])) + + self.assertIs(Foo, original_foo) + self.assertIs(Foo.f, f) + self.assertIs(Foo.foo, foo) + self.assertTrue(is_instance(f, MagicMock)) + self.assertTrue(is_instance(foo, MagicMock)) + finally: + patcher.stop() + + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + def test_patch_multiple_decorating_class(self): + test = self + original_foo = Foo + original_f = Foo.f + original_g = Foo.g + + class SomeTest(object): + + def _test(self, f, foo): + test.assertIs(Foo, original_foo) + test.assertIs(Foo.f, f) + test.assertEqual(Foo.g, 3) + test.assertIs(Foo.foo, foo) + test.assertTrue(is_instance(f, MagicMock)) + test.assertTrue(is_instance(foo, MagicMock)) + + def test_two(self, f, foo): + self._test(f, foo) + def test_one(self, f, foo): + self._test(f, foo) + + SomeTest = patch.multiple( + foo_name, f=DEFAULT, g=3, foo=DEFAULT + )(SomeTest) + + thing = SomeTest() + thing.test_one() + thing.test_two() + + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + def test_patch_multiple_create(self): + patcher = patch.multiple(Foo, blam='blam') + self.assertRaises(AttributeError, patcher.start) + + patcher = patch.multiple(Foo, blam='blam', create=True) + patcher.start() + try: + self.assertEqual(Foo.blam, 'blam') + finally: + patcher.stop() + + self.assertFalse(hasattr(Foo, 'blam')) + + + def test_patch_multiple_spec_set(self): + # if spec_set works then we can assume that spec and autospec also + # work as the underlying machinery is the same + patcher = patch.multiple(Foo, foo=DEFAULT, spec_set=['a', 'b']) + result = patcher.start() + try: + self.assertEqual(Foo.foo, result['foo']) + Foo.foo.a(1) + Foo.foo.b(2) + Foo.foo.a.assert_called_with(1) + Foo.foo.b.assert_called_with(2) + self.assertRaises(AttributeError, setattr, Foo.foo, 'c', None) + finally: + patcher.stop() + + + def test_patch_multiple_new_callable(self): + class Thing(object): + pass + + patcher = patch.multiple( + Foo, f=DEFAULT, g=DEFAULT, new_callable=Thing + ) + result = patcher.start() + try: + self.assertIs(Foo.f, result['f']) + self.assertIs(Foo.g, result['g']) + self.assertIsInstance(Foo.f, Thing) + self.assertIsInstance(Foo.g, Thing) + self.assertIsNot(Foo.f, Foo.g) + finally: + patcher.stop() + + + def test_nested_patch_failure(self): + original_f = Foo.f + original_g = Foo.g + + @patch.object(Foo, 'g', 1) + @patch.object(Foo, 'missing', 1) + @patch.object(Foo, 'f', 1) + def thing1(): pass + + @patch.object(Foo, 'missing', 1) + @patch.object(Foo, 'g', 1) + @patch.object(Foo, 'f', 1) + def thing2(): pass + + @patch.object(Foo, 'g', 1) + @patch.object(Foo, 'f', 1) + @patch.object(Foo, 'missing', 1) + def thing3(): pass + + for func in thing1, thing2, thing3: + self.assertRaises(AttributeError, func) + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + def test_new_callable_failure(self): + original_f = Foo.f + original_g = Foo.g + original_foo = Foo.foo + + def crasher(): + raise NameError('crasher') + + @patch.object(Foo, 'g', 1) + @patch.object(Foo, 'foo', new_callable=crasher) + @patch.object(Foo, 'f', 1) + def thing1(): pass + + @patch.object(Foo, 'foo', new_callable=crasher) + @patch.object(Foo, 'g', 1) + @patch.object(Foo, 'f', 1) + def thing2(): pass + + @patch.object(Foo, 'g', 1) + @patch.object(Foo, 'f', 1) + @patch.object(Foo, 'foo', new_callable=crasher) + def thing3(): pass + + for func in thing1, thing2, thing3: + self.assertRaises(NameError, func) + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + self.assertEqual(Foo.foo, original_foo) + + + def test_patch_multiple_failure(self): + original_f = Foo.f + original_g = Foo.g + + patcher = patch.object(Foo, 'f', 1) + patcher.attribute_name = 'f' + + good = patch.object(Foo, 'g', 1) + good.attribute_name = 'g' + + bad = patch.object(Foo, 'missing', 1) + bad.attribute_name = 'missing' + + for additionals in [good, bad], [bad, good]: + patcher.additional_patchers = additionals + + @patcher + def func(): pass + + self.assertRaises(AttributeError, func) + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + + + def test_patch_multiple_new_callable_failure(self): + original_f = Foo.f + original_g = Foo.g + original_foo = Foo.foo + + def crasher(): + raise NameError('crasher') + + patcher = patch.object(Foo, 'f', 1) + patcher.attribute_name = 'f' + + good = patch.object(Foo, 'g', 1) + good.attribute_name = 'g' + + bad = patch.object(Foo, 'foo', new_callable=crasher) + bad.attribute_name = 'foo' + + for additionals in [good, bad], [bad, good]: + patcher.additional_patchers = additionals + + @patcher + def func(): pass + + self.assertRaises(NameError, func) + self.assertEqual(Foo.f, original_f) + self.assertEqual(Foo.g, original_g) + self.assertEqual(Foo.foo, original_foo) + + + def test_patch_multiple_string_subclasses(self): + Foo = type('Foo', (str,), {'fish': 'tasty'}) + foo = Foo() + @patch.multiple(foo, fish='nearly gone') + def test(): + self.assertEqual(foo.fish, 'nearly gone') + + test() + self.assertEqual(foo.fish, 'tasty') + + + @patch('unittest.mock.patch.TEST_PREFIX', 'foo') + def test_patch_test_prefix(self): + class Foo(object): + thing = 'original' + + def foo_one(self): + return self.thing + def foo_two(self): + return self.thing + def test_one(self): + return self.thing + def test_two(self): + return self.thing + + Foo = patch.object(Foo, 'thing', 'changed')(Foo) + + foo = Foo() + self.assertEqual(foo.foo_one(), 'changed') + self.assertEqual(foo.foo_two(), 'changed') + self.assertEqual(foo.test_one(), 'original') + self.assertEqual(foo.test_two(), 'original') + + + @patch('unittest.mock.patch.TEST_PREFIX', 'bar') + def test_patch_dict_test_prefix(self): + class Foo(object): + def bar_one(self): + return dict(the_dict) + def bar_two(self): + return dict(the_dict) + def test_one(self): + return dict(the_dict) + def test_two(self): + return dict(the_dict) + + the_dict = {'key': 'original'} + Foo = patch.dict(the_dict, key='changed')(Foo) + + foo =Foo() + self.assertEqual(foo.bar_one(), {'key': 'changed'}) + self.assertEqual(foo.bar_two(), {'key': 'changed'}) + self.assertEqual(foo.test_one(), {'key': 'original'}) + self.assertEqual(foo.test_two(), {'key': 'original'}) + + + def test_patch_with_spec_mock_repr(self): + for arg in ('spec', 'autospec', 'spec_set'): + p = patch('%s.SomeClass' % __name__, **{arg: True}) + m = p.start() + try: + self.assertIn(" name='SomeClass'", repr(m)) + self.assertIn(" name='SomeClass.class_attribute'", + repr(m.class_attribute)) + self.assertIn(" name='SomeClass()'", repr(m())) + self.assertIn(" name='SomeClass().class_attribute'", + repr(m().class_attribute)) + finally: + p.stop() + + + def test_patch_nested_autospec_repr(self): + with patch('test.test_unittest.testmock.support', autospec=True) as m: + self.assertIn(" name='support.SomeClass.wibble()'", + repr(m.SomeClass.wibble())) + self.assertIn(" name='support.SomeClass().wibble()'", + repr(m.SomeClass().wibble())) + + + + def test_mock_calls_with_patch(self): + for arg in ('spec', 'autospec', 'spec_set'): + p = patch('%s.SomeClass' % __name__, **{arg: True}) + m = p.start() + try: + m.wibble() + + kalls = [call.wibble()] + self.assertEqual(m.mock_calls, kalls) + self.assertEqual(m.method_calls, kalls) + self.assertEqual(m.wibble.mock_calls, [call()]) + + result = m() + kalls.append(call()) + self.assertEqual(m.mock_calls, kalls) + + result.wibble() + kalls.append(call().wibble()) + self.assertEqual(m.mock_calls, kalls) + + self.assertEqual(result.mock_calls, [call.wibble()]) + self.assertEqual(result.wibble.mock_calls, [call()]) + self.assertEqual(result.method_calls, [call.wibble()]) + finally: + p.stop() + + + def test_patch_imports_lazily(self): + p1 = patch('squizz.squozz') + self.assertRaises(ImportError, p1.start) + + with uncache('squizz'): + squizz = Mock() + sys.modules['squizz'] = squizz + + squizz.squozz = 6 + p1 = patch('squizz.squozz') + squizz.squozz = 3 + p1.start() + p1.stop() + self.assertEqual(squizz.squozz, 3) + + def test_patch_propagates_exc_on_exit(self): + class holder: + exc_info = None, None, None + + class custom_patch(_patch): + def __exit__(self, etype=None, val=None, tb=None): + _patch.__exit__(self, etype, val, tb) + holder.exc_info = etype, val, tb + stop = __exit__ + + def with_custom_patch(target): + getter, attribute = _get_target(target) + return custom_patch( + getter, attribute, DEFAULT, None, False, None, + None, None, {} + ) + + @with_custom_patch('squizz.squozz') + def test(mock): + raise RuntimeError + + with uncache('squizz'): + squizz = Mock() + sys.modules['squizz'] = squizz + + self.assertRaises(RuntimeError, test) + + self.assertIs(holder.exc_info[0], RuntimeError) + self.assertIsNotNone(holder.exc_info[1], + 'exception value not propagated') + self.assertIsNotNone(holder.exc_info[2], + 'exception traceback not propagated') + + + def test_create_and_specs(self): + for kwarg in ('spec', 'spec_set', 'autospec'): + p = patch('%s.doesnotexist' % __name__, create=True, + **{kwarg: True}) + self.assertRaises(TypeError, p.start) + self.assertRaises(NameError, lambda: doesnotexist) + + # check that spec with create is innocuous if the original exists + p = patch(MODNAME, create=True, **{kwarg: True}) + p.start() + p.stop() + + + def test_multiple_specs(self): + original = PTModule + for kwarg in ('spec', 'spec_set'): + p = patch(MODNAME, autospec=0, **{kwarg: 0}) + self.assertRaises(TypeError, p.start) + self.assertIs(PTModule, original) + + for kwarg in ('spec', 'autospec'): + p = patch(MODNAME, spec_set=0, **{kwarg: 0}) + self.assertRaises(TypeError, p.start) + self.assertIs(PTModule, original) + + for kwarg in ('spec_set', 'autospec'): + p = patch(MODNAME, spec=0, **{kwarg: 0}) + self.assertRaises(TypeError, p.start) + self.assertIs(PTModule, original) + + + def test_specs_false_instead_of_none(self): + p = patch(MODNAME, spec=False, spec_set=False, autospec=False) + mock = p.start() + try: + # no spec should have been set, so attribute access should not fail + mock.does_not_exist + mock.does_not_exist = 3 + finally: + p.stop() + + + def test_falsey_spec(self): + for kwarg in ('spec', 'autospec', 'spec_set'): + p = patch(MODNAME, **{kwarg: 0}) + m = p.start() + try: + self.assertRaises(AttributeError, getattr, m, 'doesnotexit') + finally: + p.stop() + + + def test_spec_set_true(self): + for kwarg in ('spec', 'autospec'): + p = patch(MODNAME, spec_set=True, **{kwarg: True}) + m = p.start() + try: + self.assertRaises(AttributeError, setattr, m, + 'doesnotexist', 'something') + self.assertRaises(AttributeError, getattr, m, 'doesnotexist') + finally: + p.stop() + + + def test_callable_spec_as_list(self): + spec = ('__call__',) + p = patch(MODNAME, spec=spec) + m = p.start() + try: + self.assertTrue(callable(m)) + finally: + p.stop() + + + def test_not_callable_spec_as_list(self): + spec = ('foo', 'bar') + p = patch(MODNAME, spec=spec) + m = p.start() + try: + self.assertFalse(callable(m)) + finally: + p.stop() + + + def test_patch_stopall(self): + unlink = os.unlink + chdir = os.chdir + path = os.path + patch('os.unlink', something).start() + patch('os.chdir', something_else).start() + + @patch('os.path') + def patched(mock_path): + patch.stopall() + self.assertIs(os.path, mock_path) + self.assertIs(os.unlink, unlink) + self.assertIs(os.chdir, chdir) + + patched() + self.assertIs(os.path, path) + + def test_stopall_lifo(self): + stopped = [] + class thing(object): + one = two = three = None + + def get_patch(attribute): + class mypatch(_patch): + def stop(self): + stopped.append(attribute) + return super(mypatch, self).stop() + return mypatch(lambda: thing, attribute, None, None, + False, None, None, None, {}) + [get_patch(val).start() for val in ("one", "two", "three")] + patch.stopall() + + self.assertEqual(stopped, ["three", "two", "one"]) + + def test_patch_dict_stopall(self): + dic1 = {} + dic2 = {1: 'a'} + dic3 = {1: 'A', 2: 'B'} + origdic1 = dic1.copy() + origdic2 = dic2.copy() + origdic3 = dic3.copy() + patch.dict(dic1, {1: 'I', 2: 'II'}).start() + patch.dict(dic2, {2: 'b'}).start() + + @patch.dict(dic3) + def patched(): + del dic3[1] + + patched() + self.assertNotEqual(dic1, origdic1) + self.assertNotEqual(dic2, origdic2) + self.assertEqual(dic3, origdic3) + + patch.stopall() + + self.assertEqual(dic1, origdic1) + self.assertEqual(dic2, origdic2) + self.assertEqual(dic3, origdic3) + + + def test_patch_and_patch_dict_stopall(self): + original_unlink = os.unlink + original_chdir = os.chdir + dic1 = {} + dic2 = {1: 'A', 2: 'B'} + origdic1 = dic1.copy() + origdic2 = dic2.copy() + + patch('os.unlink', something).start() + patch('os.chdir', something_else).start() + patch.dict(dic1, {1: 'I', 2: 'II'}).start() + patch.dict(dic2).start() + del dic2[1] + + self.assertIsNot(os.unlink, original_unlink) + self.assertIsNot(os.chdir, original_chdir) + self.assertNotEqual(dic1, origdic1) + self.assertNotEqual(dic2, origdic2) + patch.stopall() + self.assertIs(os.unlink, original_unlink) + self.assertIs(os.chdir, original_chdir) + self.assertEqual(dic1, origdic1) + self.assertEqual(dic2, origdic2) + + + def test_special_attrs(self): + def foo(x=0): + """TEST""" + return x + with patch.object(foo, '__defaults__', (1, )): + self.assertEqual(foo(), 1) + self.assertEqual(foo(), 0) + + orig_doc = foo.__doc__ + with patch.object(foo, '__doc__', "FUN"): + self.assertEqual(foo.__doc__, "FUN") + self.assertEqual(foo.__doc__, orig_doc) + + with patch.object(foo, '__module__', "testpatch2"): + self.assertEqual(foo.__module__, "testpatch2") + self.assertEqual(foo.__module__, 'test.test_unittest.testmock.testpatch') + + with patch.object(foo, '__annotations__', dict([('s', 1, )])): + self.assertEqual(foo.__annotations__, dict([('s', 1, )])) + self.assertEqual(foo.__annotations__, dict()) + + def foo(*a, x=0): + return x + with patch.object(foo, '__kwdefaults__', dict([('x', 1, )])): + self.assertEqual(foo(), 1) + self.assertEqual(foo(), 0) + + def test_patch_orderdict(self): + foo = OrderedDict() + foo['a'] = object() + foo['b'] = 'python' + + original = foo.copy() + update_values = list(zip('cdefghijklmnopqrstuvwxyz', range(26))) + patched_values = list(foo.items()) + update_values + + with patch.dict(foo, OrderedDict(update_values)): + self.assertEqual(list(foo.items()), patched_values) + + self.assertEqual(foo, original) + + with patch.dict(foo, update_values): + self.assertEqual(list(foo.items()), patched_values) + + self.assertEqual(foo, original) + + def test_dotted_but_module_not_loaded(self): + # This exercises the AttributeError branch of _dot_lookup. + + # make sure it's there + import test.test_unittest.testmock.support + # now make sure it's not: + with patch.dict('sys.modules'): + del sys.modules['test.test_unittest.testmock.support'] + del sys.modules['test.test_unittest.testmock'] + del sys.modules['test.test_unittest'] + del sys.modules['unittest'] + + # now make sure we can patch based on a dotted path: + @patch('test.test_unittest.testmock.support.X') + def test(mock): + pass + test() + + + def test_invalid_target(self): + class Foo: + pass + + for target in ['', 12, Foo()]: + with self.subTest(target=target): + with self.assertRaises(TypeError): + patch(target) + + + def test_cant_set_kwargs_when_passing_a_mock(self): + @patch('test.test_unittest.testmock.support.X', new=object(), x=1) + def test(): pass + with self.assertRaises(TypeError): + test() + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/testmock/testsealable.py b/Lib/test/test_unittest/testmock/testsealable.py new file mode 100644 index 0000000..daba2b4 --- /dev/null +++ b/Lib/test/test_unittest/testmock/testsealable.py @@ -0,0 +1,237 @@ +import unittest +from unittest import mock + + +class SampleObject: + + def method_sample1(self): pass + + def method_sample2(self): pass + + +class TestSealable(unittest.TestCase): + + def test_attributes_return_more_mocks_by_default(self): + m = mock.Mock() + + self.assertIsInstance(m.test, mock.Mock) + self.assertIsInstance(m.test(), mock.Mock) + self.assertIsInstance(m.test().test2(), mock.Mock) + + def test_new_attributes_cannot_be_accessed_on_seal(self): + m = mock.Mock() + + mock.seal(m) + with self.assertRaises(AttributeError): + m.test + with self.assertRaises(AttributeError): + m() + + def test_new_attributes_cannot_be_set_on_seal(self): + m = mock.Mock() + + mock.seal(m) + with self.assertRaises(AttributeError): + m.test = 1 + + def test_existing_attributes_can_be_set_on_seal(self): + m = mock.Mock() + m.test.test2 = 1 + + mock.seal(m) + m.test.test2 = 2 + self.assertEqual(m.test.test2, 2) + + def test_new_attributes_cannot_be_set_on_child_of_seal(self): + m = mock.Mock() + m.test.test2 = 1 + + mock.seal(m) + with self.assertRaises(AttributeError): + m.test.test3 = 1 + + def test_existing_attributes_allowed_after_seal(self): + m = mock.Mock() + + m.test.return_value = 3 + + mock.seal(m) + self.assertEqual(m.test(), 3) + + def test_initialized_attributes_allowed_after_seal(self): + m = mock.Mock(test_value=1) + + mock.seal(m) + self.assertEqual(m.test_value, 1) + + def test_call_on_sealed_mock_fails(self): + m = mock.Mock() + + mock.seal(m) + with self.assertRaises(AttributeError): + m() + + def test_call_on_defined_sealed_mock_succeeds(self): + m = mock.Mock(return_value=5) + + mock.seal(m) + self.assertEqual(m(), 5) + + def test_seals_recurse_on_added_attributes(self): + m = mock.Mock() + + m.test1.test2().test3 = 4 + + mock.seal(m) + self.assertEqual(m.test1.test2().test3, 4) + with self.assertRaises(AttributeError): + m.test1.test2().test4 + with self.assertRaises(AttributeError): + m.test1.test3 + + def test_seals_recurse_on_magic_methods(self): + m = mock.MagicMock() + + m.test1.test2["a"].test3 = 4 + m.test1.test3[2:5].test3 = 4 + + mock.seal(m) + self.assertEqual(m.test1.test2["a"].test3, 4) + self.assertEqual(m.test1.test2[2:5].test3, 4) + with self.assertRaises(AttributeError): + m.test1.test2["a"].test4 + with self.assertRaises(AttributeError): + m.test1.test3[2:5].test4 + + def test_seals_dont_recurse_on_manual_attributes(self): + m = mock.Mock(name="root_mock") + + m.test1.test2 = mock.Mock(name="not_sealed") + m.test1.test2.test3 = 4 + + mock.seal(m) + self.assertEqual(m.test1.test2.test3, 4) + m.test1.test2.test4 # Does not raise + m.test1.test2.test4 = 1 # Does not raise + + def test_integration_with_spec_att_definition(self): + """You are not restricted when using mock with spec""" + m = mock.Mock(SampleObject) + + m.attr_sample1 = 1 + m.attr_sample3 = 3 + + mock.seal(m) + self.assertEqual(m.attr_sample1, 1) + self.assertEqual(m.attr_sample3, 3) + with self.assertRaises(AttributeError): + m.attr_sample2 + + def test_integration_with_spec_method_definition(self): + """You need to define the methods, even if they are in the spec""" + m = mock.Mock(SampleObject) + + m.method_sample1.return_value = 1 + + mock.seal(m) + self.assertEqual(m.method_sample1(), 1) + with self.assertRaises(AttributeError): + m.method_sample2() + + def test_integration_with_spec_method_definition_respects_spec(self): + """You cannot define methods out of the spec""" + m = mock.Mock(SampleObject) + + with self.assertRaises(AttributeError): + m.method_sample3.return_value = 3 + + def test_sealed_exception_has_attribute_name(self): + m = mock.Mock() + + mock.seal(m) + with self.assertRaises(AttributeError) as cm: + m.SECRETE_name + self.assertIn("SECRETE_name", str(cm.exception)) + + def test_attribute_chain_is_maintained(self): + m = mock.Mock(name="mock_name") + m.test1.test2.test3.test4 + + mock.seal(m) + with self.assertRaises(AttributeError) as cm: + m.test1.test2.test3.test4.boom + self.assertIn("mock_name.test1.test2.test3.test4.boom", str(cm.exception)) + + def test_call_chain_is_maintained(self): + m = mock.Mock() + m.test1().test2.test3().test4 + + mock.seal(m) + with self.assertRaises(AttributeError) as cm: + m.test1().test2.test3().test4() + self.assertIn("mock.test1().test2.test3().test4", str(cm.exception)) + + def test_seal_with_autospec(self): + # https://bugs.python.org/issue45156 + class Foo: + foo = 0 + def bar1(self): + return 1 + def bar2(self): + return 2 + + class Baz: + baz = 3 + def ban(self): + return 4 + + for spec_set in (True, False): + with self.subTest(spec_set=spec_set): + foo = mock.create_autospec(Foo, spec_set=spec_set) + foo.bar1.return_value = 'a' + foo.Baz.ban.return_value = 'b' + + mock.seal(foo) + + self.assertIsInstance(foo.foo, mock.NonCallableMagicMock) + self.assertIsInstance(foo.bar1, mock.MagicMock) + self.assertIsInstance(foo.bar2, mock.MagicMock) + self.assertIsInstance(foo.Baz, mock.MagicMock) + self.assertIsInstance(foo.Baz.baz, mock.NonCallableMagicMock) + self.assertIsInstance(foo.Baz.ban, mock.MagicMock) + + self.assertEqual(foo.bar1(), 'a') + foo.bar1.return_value = 'new_a' + self.assertEqual(foo.bar1(), 'new_a') + self.assertEqual(foo.Baz.ban(), 'b') + foo.Baz.ban.return_value = 'new_b' + self.assertEqual(foo.Baz.ban(), 'new_b') + + with self.assertRaises(TypeError): + foo.foo() + with self.assertRaises(AttributeError): + foo.bar = 1 + with self.assertRaises(AttributeError): + foo.bar2() + + foo.bar2.return_value = 'bar2' + self.assertEqual(foo.bar2(), 'bar2') + + with self.assertRaises(AttributeError): + foo.missing_attr + with self.assertRaises(AttributeError): + foo.missing_attr = 1 + with self.assertRaises(AttributeError): + foo.missing_method() + with self.assertRaises(TypeError): + foo.Baz.baz() + with self.assertRaises(AttributeError): + foo.Baz.missing_attr + with self.assertRaises(AttributeError): + foo.Baz.missing_attr = 1 + with self.assertRaises(AttributeError): + foo.Baz.missing_method() + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/test_unittest/testmock/testsentinel.py b/Lib/test/test_unittest/testmock/testsentinel.py new file mode 100644 index 0000000..de53509 --- /dev/null +++ b/Lib/test/test_unittest/testmock/testsentinel.py @@ -0,0 +1,41 @@ +import unittest +import copy +import pickle +from unittest.mock import sentinel, DEFAULT + + +class SentinelTest(unittest.TestCase): + + def testSentinels(self): + self.assertEqual(sentinel.whatever, sentinel.whatever, + 'sentinel not stored') + self.assertNotEqual(sentinel.whatever, sentinel.whateverelse, + 'sentinel should be unique') + + + def testSentinelName(self): + self.assertEqual(str(sentinel.whatever), 'sentinel.whatever', + 'sentinel name incorrect') + + + def testDEFAULT(self): + self.assertIs(DEFAULT, sentinel.DEFAULT) + + def testBases(self): + # If this doesn't raise an AttributeError then help(mock) is broken + self.assertRaises(AttributeError, lambda: sentinel.__bases__) + + def testPickle(self): + for proto in range(pickle.HIGHEST_PROTOCOL+1): + with self.subTest(protocol=proto): + pickled = pickle.dumps(sentinel.whatever, proto) + unpickled = pickle.loads(pickled) + self.assertIs(unpickled, sentinel.whatever) + + def testCopy(self): + self.assertIs(copy.copy(sentinel.whatever), sentinel.whatever) + self.assertIs(copy.deepcopy(sentinel.whatever), sentinel.whatever) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_unittest/testmock/testwith.py b/Lib/test/test_unittest/testmock/testwith.py new file mode 100644 index 0000000..8dc8eb1 --- /dev/null +++ b/Lib/test/test_unittest/testmock/testwith.py @@ -0,0 +1,347 @@ +import unittest +from warnings import catch_warnings + +from test.test_unittest.testmock.support import is_instance +from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call + + + +something = sentinel.Something +something_else = sentinel.SomethingElse + + +class SampleException(Exception): pass + + +class WithTest(unittest.TestCase): + + def test_with_statement(self): + with patch('%s.something' % __name__, sentinel.Something2): + self.assertEqual(something, sentinel.Something2, "unpatched") + self.assertEqual(something, sentinel.Something) + + + def test_with_statement_exception(self): + with self.assertRaises(SampleException): + with patch('%s.something' % __name__, sentinel.Something2): + self.assertEqual(something, sentinel.Something2, "unpatched") + raise SampleException() + self.assertEqual(something, sentinel.Something) + + + def test_with_statement_as(self): + with patch('%s.something' % __name__) as mock_something: + self.assertEqual(something, mock_something, "unpatched") + self.assertTrue(is_instance(mock_something, MagicMock), + "patching wrong type") + self.assertEqual(something, sentinel.Something) + + + def test_patch_object_with_statement(self): + class Foo(object): + something = 'foo' + original = Foo.something + with patch.object(Foo, 'something'): + self.assertNotEqual(Foo.something, original, "unpatched") + self.assertEqual(Foo.something, original) + + + def test_with_statement_nested(self): + with catch_warnings(record=True): + with patch('%s.something' % __name__) as mock_something, patch('%s.something_else' % __name__) as mock_something_else: + self.assertEqual(something, mock_something, "unpatched") + self.assertEqual(something_else, mock_something_else, + "unpatched") + + self.assertEqual(something, sentinel.Something) + self.assertEqual(something_else, sentinel.SomethingElse) + + + def test_with_statement_specified(self): + with patch('%s.something' % __name__, sentinel.Patched) as mock_something: + self.assertEqual(something, mock_something, "unpatched") + self.assertEqual(mock_something, sentinel.Patched, "wrong patch") + self.assertEqual(something, sentinel.Something) + + + def testContextManagerMocking(self): + mock = Mock() + mock.__enter__ = Mock() + mock.__exit__ = Mock() + mock.__exit__.return_value = False + + with mock as m: + self.assertEqual(m, mock.__enter__.return_value) + mock.__enter__.assert_called_with() + mock.__exit__.assert_called_with(None, None, None) + + + def test_context_manager_with_magic_mock(self): + mock = MagicMock() + + with self.assertRaises(TypeError): + with mock: + 'foo' + 3 + mock.__enter__.assert_called_with() + self.assertTrue(mock.__exit__.called) + + + def test_with_statement_same_attribute(self): + with patch('%s.something' % __name__, sentinel.Patched) as mock_something: + self.assertEqual(something, mock_something, "unpatched") + + with patch('%s.something' % __name__) as mock_again: + self.assertEqual(something, mock_again, "unpatched") + + self.assertEqual(something, mock_something, + "restored with wrong instance") + + self.assertEqual(something, sentinel.Something, "not restored") + + + def test_with_statement_imbricated(self): + with patch('%s.something' % __name__) as mock_something: + self.assertEqual(something, mock_something, "unpatched") + + with patch('%s.something_else' % __name__) as mock_something_else: + self.assertEqual(something_else, mock_something_else, + "unpatched") + + self.assertEqual(something, sentinel.Something) + self.assertEqual(something_else, sentinel.SomethingElse) + + + def test_dict_context_manager(self): + foo = {} + with patch.dict(foo, {'a': 'b'}): + self.assertEqual(foo, {'a': 'b'}) + self.assertEqual(foo, {}) + + with self.assertRaises(NameError): + with patch.dict(foo, {'a': 'b'}): + self.assertEqual(foo, {'a': 'b'}) + raise NameError('Konrad') + + self.assertEqual(foo, {}) + + def test_double_patch_instance_method(self): + class C: + def f(self): pass + + c = C() + + with patch.object(c, 'f') as patch1: + with patch.object(c, 'f') as patch2: + c.f() + self.assertEqual(patch2.call_count, 1) + self.assertEqual(patch1.call_count, 0) + c.f() + self.assertEqual(patch1.call_count, 1) + + +class TestMockOpen(unittest.TestCase): + + def test_mock_open(self): + mock = mock_open() + with patch('%s.open' % __name__, mock, create=True) as patched: + self.assertIs(patched, mock) + open('foo') + + mock.assert_called_once_with('foo') + + + def test_mock_open_context_manager(self): + mock = mock_open() + handle = mock.return_value + with patch('%s.open' % __name__, mock, create=True): + with open('foo') as f: + f.read() + + expected_calls = [call('foo'), call().__enter__(), call().read(), + call().__exit__(None, None, None)] + self.assertEqual(mock.mock_calls, expected_calls) + self.assertIs(f, handle) + + def test_mock_open_context_manager_multiple_times(self): + mock = mock_open() + with patch('%s.open' % __name__, mock, create=True): + with open('foo') as f: + f.read() + with open('bar') as f: + f.read() + + expected_calls = [ + call('foo'), call().__enter__(), call().read(), + call().__exit__(None, None, None), + call('bar'), call().__enter__(), call().read(), + call().__exit__(None, None, None)] + self.assertEqual(mock.mock_calls, expected_calls) + + def test_explicit_mock(self): + mock = MagicMock() + mock_open(mock) + + with patch('%s.open' % __name__, mock, create=True) as patched: + self.assertIs(patched, mock) + open('foo') + + mock.assert_called_once_with('foo') + + + def test_read_data(self): + mock = mock_open(read_data='foo') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + result = h.read() + + self.assertEqual(result, 'foo') + + + def test_readline_data(self): + # Check that readline will return all the lines from the fake file + # And that once fully consumed, readline will return an empty string. + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + line1 = h.readline() + line2 = h.readline() + line3 = h.readline() + self.assertEqual(line1, 'foo\n') + self.assertEqual(line2, 'bar\n') + self.assertEqual(line3, 'baz\n') + self.assertEqual(h.readline(), '') + + # Check that we properly emulate a file that doesn't end in a newline + mock = mock_open(read_data='foo') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + result = h.readline() + self.assertEqual(result, 'foo') + self.assertEqual(h.readline(), '') + + + def test_dunder_iter_data(self): + # Check that dunder_iter will return all the lines from the fake file. + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + lines = [l for l in h] + self.assertEqual(lines[0], 'foo\n') + self.assertEqual(lines[1], 'bar\n') + self.assertEqual(lines[2], 'baz\n') + self.assertEqual(h.readline(), '') + with self.assertRaises(StopIteration): + next(h) + + def test_next_data(self): + # Check that next will correctly return the next available + # line and plays well with the dunder_iter part. + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + line1 = next(h) + line2 = next(h) + lines = [l for l in h] + self.assertEqual(line1, 'foo\n') + self.assertEqual(line2, 'bar\n') + self.assertEqual(lines[0], 'baz\n') + self.assertEqual(h.readline(), '') + + def test_readlines_data(self): + # Test that emulating a file that ends in a newline character works + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + result = h.readlines() + self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n']) + + # Test that files without a final newline will also be correctly + # emulated + mock = mock_open(read_data='foo\nbar\nbaz') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + result = h.readlines() + + self.assertEqual(result, ['foo\n', 'bar\n', 'baz']) + + + def test_read_bytes(self): + mock = mock_open(read_data=b'\xc6') + with patch('%s.open' % __name__, mock, create=True): + with open('abc', 'rb') as f: + result = f.read() + self.assertEqual(result, b'\xc6') + + + def test_readline_bytes(self): + m = mock_open(read_data=b'abc\ndef\nghi\n') + with patch('%s.open' % __name__, m, create=True): + with open('abc', 'rb') as f: + line1 = f.readline() + line2 = f.readline() + line3 = f.readline() + self.assertEqual(line1, b'abc\n') + self.assertEqual(line2, b'def\n') + self.assertEqual(line3, b'ghi\n') + + + def test_readlines_bytes(self): + m = mock_open(read_data=b'abc\ndef\nghi\n') + with patch('%s.open' % __name__, m, create=True): + with open('abc', 'rb') as f: + result = f.readlines() + self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n']) + + + def test_mock_open_read_with_argument(self): + # At one point calling read with an argument was broken + # for mocks returned by mock_open + some_data = 'foo\nbar\nbaz' + mock = mock_open(read_data=some_data) + self.assertEqual(mock().read(10), some_data[:10]) + self.assertEqual(mock().read(10), some_data[:10]) + + f = mock() + self.assertEqual(f.read(10), some_data[:10]) + self.assertEqual(f.read(10), some_data[10:]) + + + def test_interleaved_reads(self): + # Test that calling read, readline, and readlines pulls data + # sequentially from the data we preload with + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + line1 = h.readline() + rest = h.readlines() + self.assertEqual(line1, 'foo\n') + self.assertEqual(rest, ['bar\n', 'baz\n']) + + mock = mock_open(read_data='foo\nbar\nbaz\n') + with patch('%s.open' % __name__, mock, create=True): + h = open('bar') + line1 = h.readline() + rest = h.read() + self.assertEqual(line1, 'foo\n') + self.assertEqual(rest, 'bar\nbaz\n') + + + def test_overriding_return_values(self): + mock = mock_open(read_data='foo') + handle = mock() + + handle.read.return_value = 'bar' + handle.readline.return_value = 'bar' + handle.readlines.return_value = ['bar'] + + self.assertEqual(handle.read(), 'bar') + self.assertEqual(handle.readline(), 'bar') + self.assertEqual(handle.readlines(), ['bar']) + + # call repeatedly to check that a StopIteration is not propagated + self.assertEqual(handle.readline(), 'bar') + self.assertEqual(handle.readline(), 'bar') + + +if __name__ == '__main__': + unittest.main() |