diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/test/test_unittest.py | 17 | ||||
-rw-r--r-- | Lib/unittest.py | 44 |
2 files changed, 18 insertions, 43 deletions
diff --git a/Lib/test/test_unittest.py b/Lib/test/test_unittest.py index e7097cc..ea33180 100644 --- a/Lib/test/test_unittest.py +++ b/Lib/test/test_unittest.py @@ -106,7 +106,7 @@ class TestHashing(object): # List subclass we can add attributes to. class MyClassSuite(list): - def __init__(self, tests, klass): + def __init__(self, tests): super(MyClassSuite, self).__init__(tests) @@ -1271,7 +1271,7 @@ class Test_TestLoader(TestCase): tests = [Foo('test_1'), Foo('test_2')] loader = unittest.TestLoader() - loader.classSuiteClass = MyClassSuite + loader.suiteClass = list self.assertEqual(loader.loadTestsFromTestCase(Foo), tests) # It is implicit in the documentation for TestLoader.suiteClass that @@ -1284,7 +1284,7 @@ class Test_TestLoader(TestCase): def foo_bar(self): pass m.Foo = Foo - tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)] + tests = [[Foo('test_1'), Foo('test_2')]] loader = unittest.TestLoader() loader.suiteClass = list @@ -1303,7 +1303,7 @@ class Test_TestLoader(TestCase): tests = [Foo('test_1'), Foo('test_2')] loader = unittest.TestLoader() - loader.classSuiteClass = MyClassSuite + loader.suiteClass = list self.assertEqual(loader.loadTestsFromName('Foo', m), tests) # It is implicit in the documentation for TestLoader.suiteClass that @@ -1316,7 +1316,7 @@ class Test_TestLoader(TestCase): def foo_bar(self): pass m.Foo = Foo - tests = [unittest.ClassTestSuite([Foo('test_1'), Foo('test_2')], Foo)] + tests = [[Foo('test_1'), Foo('test_2')]] loader = unittest.TestLoader() loader.suiteClass = list @@ -2842,7 +2842,7 @@ class Test_TestSkipping(TestCase): def test_dont_skip(self): pass test_do_skip = Foo("test_skip") test_dont_skip = Foo("test_dont_skip") - suite = unittest.ClassTestSuite([test_do_skip, test_dont_skip], Foo) + suite = unittest.TestSuite([test_do_skip, test_dont_skip]) events = [] result = LoggingResult(events) suite.run(result) @@ -2861,9 +2861,10 @@ class Test_TestSkipping(TestCase): record.append(1) record = [] result = unittest.TestResult() - suite = unittest.ClassTestSuite([Foo("test_1")], Foo) + test = Foo("test_1") + suite = unittest.TestSuite([test]) suite.run(result) - self.assertEqual(result.skipped, [(suite, "testing")]) + self.assertEqual(result.skipped, [(test, "testing")]) self.assertEqual(record, []) def test_expected_failure(self): diff --git a/Lib/unittest.py b/Lib/unittest.py index c6d893e..cdccd8c 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -59,7 +59,7 @@ import warnings ############################################################################## # Exported classes and functions ############################################################################## -__all__ = ['TestResult', 'TestCase', 'TestSuite', 'ClassTestSuite', +__all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', 'expectedFailure'] @@ -459,6 +459,13 @@ class TestCase(object): self._result = result result.startTest(self) + if getattr(self.__class__, "__unittest_skip__", False): + # If the whole class was skipped. + try: + result.addSkip(self, self.__class__.__unittest_skip_why__) + finally: + result.stopTest(self) + return testMethod = getattr(self, self._testMethodName) try: success = False @@ -1129,37 +1136,6 @@ class TestSuite(object): test.debug() -class ClassTestSuite(TestSuite): - """ - Suite of tests derived from a single TestCase class. - """ - - def __init__(self, tests, class_collected_from): - super(ClassTestSuite, self).__init__(tests) - self.collected_from = class_collected_from - - def id(self): - module = getattr(self.collected_from, "__module__", None) - if module is not None: - return "{0}.{1}".format(module, self.collected_from.__name__) - return self.collected_from.__name__ - - def run(self, result): - if getattr(self.collected_from, "__unittest_skip__", False): - # ClassTestSuite result pretends to be a TestCase enough to be - # reported. - result.startTest(self) - try: - result.addSkip(self, self.collected_from.__unittest_skip_why__) - finally: - result.stopTest(self) - else: - result = super(ClassTestSuite, self).run(result) - return result - - shortDescription = id - - class FunctionTestCase(TestCase): """A test case that wraps a test function. @@ -1245,7 +1221,6 @@ class TestLoader(object): testMethodPrefix = 'test' sortTestMethodsUsing = staticmethod(three_way_cmp) suiteClass = TestSuite - classSuiteClass = ClassTestSuite def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" @@ -1255,8 +1230,7 @@ class TestLoader(object): testCaseNames = self.getTestCaseNames(testCaseClass) if not testCaseNames and hasattr(testCaseClass, 'runTest'): testCaseNames = ['runTest'] - suite = self.classSuiteClass(map(testCaseClass, testCaseNames), - testCaseClass) + suite = self.suiteClass(map(testCaseClass, testCaseNames)) return suite def loadTestsFromModule(self, module): |