diff options
-rw-r--r-- | Lib/test/test_unittest.py | 15 | ||||
-rw-r--r-- | Lib/unittest/loader.py | 28 |
2 files changed, 36 insertions, 7 deletions
diff --git a/Lib/test/test_unittest.py b/Lib/test/test_unittest.py index 5228e76..e094923 100644 --- a/Lib/test/test_unittest.py +++ b/Lib/test/test_unittest.py @@ -289,6 +289,21 @@ class Test_TestLoader(TestCase): suite = loader.loadTestsFromModule(m, use_load_tests=False) self.assertEquals(load_tests_args, []) + def test_loadTestsFromModule__faulty_load_tests(self): + m = types.ModuleType('m') + + def load_tests(loader, tests, pattern): + raise TypeError('some failure') + m.load_tests = load_tests + + loader = unittest.TestLoader() + suite = loader.loadTestsFromModule(m) + self.assertIsInstance(suite, unittest.TestSuite) + self.assertEqual(suite.countTestCases(), 1) + test = list(suite)[0] + + self.assertRaisesRegexp(TypeError, "some failure", test.m) + ################################################################ ### /Tests for TestLoader.loadTestsFromModule() diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index c04de06..4036454 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -33,12 +33,18 @@ def _make_failed_import_test(name, suiteClass): # Python 2.3 compatibility # format_exc returns two frames of discover.py as well message += '\n%s' % traceback.format_exc() + return _make_failed_test('ModuleImportFailure', name, suiteClass, + ImportError(message)) - def testImportFailure(self): - raise ImportError(message) - attrs = {name: testImportFailure} - ModuleImportFailure = type('ModuleImportFailure', (case.TestCase,), attrs) - return suiteClass((ModuleImportFailure(name),)) +def _make_failed_load_tests(name, exception, suiteClass): + return _make_failed_test('LoadTestsFailure', name, suiteClass, exception) + +def _make_failed_test(classname, methodname, suiteClass, exception): + def testFailure(self): + raise exception + attrs = {methodname: testFailure} + TestClass = type(classname, (case.TestCase,), attrs) + return suiteClass((TestClass(methodname),)) class TestLoader(object): @@ -73,7 +79,11 @@ class TestLoader(object): load_tests = getattr(module, 'load_tests', None) tests = self.suiteClass(tests) if use_load_tests and load_tests is not None: - return load_tests(self, tests, None) + try: + return load_tests(self, tests, None) + except Exception, e: + return _make_failed_load_tests(module.__name__, e, + self.suiteClass) return tests def loadTestsFromName(self, name, module=None): @@ -239,7 +249,11 @@ class TestLoader(object): for test in self._find_tests(full_path, pattern): yield test else: - yield load_tests(self, tests, pattern) + try: + yield load_tests(self, tests, pattern) + except Exception, e: + yield _make_failed_load_tests(package.__name__, e, + self.suiteClass) defaultTestLoader = TestLoader() |