diff options
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/case.py | 13 | ||||
-rw-r--r-- | Lib/unittest/test/test_case.py | 56 |
2 files changed, 69 insertions, 0 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index befad61..7701ad3 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -119,6 +119,10 @@ def expectedFailure(test_item): test_item.__unittest_expecting_failure__ = True return test_item +def _is_subtype(expected, basetype): + if isinstance(expected, tuple): + return all(_is_subtype(e, basetype) for e in expected) + return isinstance(expected, type) and issubclass(expected, basetype) class _BaseTestCaseContext: @@ -148,6 +152,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext): If args is not empty, call a callable passing positional and keyword arguments. """ + if not _is_subtype(self.expected, self._base_type): + raise TypeError('%s() arg 1 must be %s' % + (name, self._base_type_str)) if args and args[0] is None: warnings.warn("callable is None", DeprecationWarning, 3) @@ -172,6 +179,9 @@ class _AssertRaisesBaseContext(_BaseTestCaseContext): class _AssertRaisesContext(_AssertRaisesBaseContext): """A context manager used to implement TestCase.assertRaises* methods.""" + _base_type = BaseException + _base_type_str = 'an exception type or tuple of exception types' + def __enter__(self): return self @@ -206,6 +216,9 @@ class _AssertRaisesContext(_AssertRaisesBaseContext): class _AssertWarnsContext(_AssertRaisesBaseContext): """A context manager used to implement TestCase.assertWarns* methods.""" + _base_type = Warning + _base_type_str = 'a warning type or tuple of warning types' + def __enter__(self): # The __warningregistry__'s need to be in a pristine state for tests # to work properly. diff --git a/Lib/unittest/test/test_case.py b/Lib/unittest/test/test_case.py index a05cc64..ada733b 100644 --- a/Lib/unittest/test/test_case.py +++ b/Lib/unittest/test/test_case.py @@ -1185,6 +1185,18 @@ test case with self.assertRaises(ExceptionMock): self.assertRaises(ValueError, Stub) + def testAssertRaisesNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertRaises() + with self.assertRaises(TypeError): + self.assertRaises(1) + with self.assertRaises(TypeError): + self.assertRaises(object) + with self.assertRaises(TypeError): + self.assertRaises((ValueError, 1)) + with self.assertRaises(TypeError): + self.assertRaises((ValueError, object)) + def testAssertRaisesRegex(self): class ExceptionMock(Exception): pass @@ -1258,6 +1270,20 @@ test case self.assertIsInstance(e, ExceptionMock) self.assertEqual(e.args[0], v) + def testAssertRaisesRegexNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertRaisesRegex() + with self.assertRaises(TypeError): + self.assertRaisesRegex(ValueError) + with self.assertRaises(TypeError): + self.assertRaisesRegex(1, 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex(object, 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex((ValueError, 1), 'expect') + with self.assertRaises(TypeError): + self.assertRaisesRegex((ValueError, object), 'expect') + def testAssertWarnsCallable(self): def _runtime_warn(): warnings.warn("foo", RuntimeWarning) @@ -1336,6 +1362,20 @@ test case with self.assertWarns(DeprecationWarning): _runtime_warn() + def testAssertWarnsNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertWarns() + with self.assertRaises(TypeError): + self.assertWarns(1) + with self.assertRaises(TypeError): + self.assertWarns(object) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, 1)) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, object)) + with self.assertRaises(TypeError): + self.assertWarns((UserWarning, Exception)) + def testAssertWarnsRegexCallable(self): def _runtime_warn(msg): warnings.warn(msg, RuntimeWarning) @@ -1414,6 +1454,22 @@ test case with self.assertWarnsRegex(RuntimeWarning, "o+"): _runtime_warn("barz") + def testAssertWarnsRegexNoExceptionType(self): + with self.assertRaises(TypeError): + self.assertWarnsRegex() + with self.assertRaises(TypeError): + self.assertWarnsRegex(UserWarning) + with self.assertRaises(TypeError): + self.assertWarnsRegex(1, 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex(object, 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, 1), 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, object), 'expect') + with self.assertRaises(TypeError): + self.assertWarnsRegex((UserWarning, Exception), 'expect') + @contextlib.contextmanager def assertNoStderr(self): with captured_stderr() as buf: |