diff options
author | Jonas Haag <jonas@lophus.org> | 2017-11-25 15:23:52 (GMT) |
---|---|---|
committer | Antoine Pitrou <pitrou@free.fr> | 2017-11-25 15:23:52 (GMT) |
commit | 5b48dc638b7405fd9bde4d854bf477dfeaaddf44 (patch) | |
tree | 2e3b44b9193cc1a0e08a6e1d65dd324e76fb3ee6 /Lib/unittest | |
parent | 8d9bb11d8fcbf10cc9b1eb0a647bcf3658a4e3dd (diff) | |
download | cpython-5b48dc638b7405fd9bde4d854bf477dfeaaddf44.zip cpython-5b48dc638b7405fd9bde4d854bf477dfeaaddf44.tar.gz cpython-5b48dc638b7405fd9bde4d854bf477dfeaaddf44.tar.bz2 |
bpo-32071: Add unittest -k option (#4496)
* bpo-32071: Add unittest -k option
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/loader.py | 24 | ||||
-rw-r--r-- | Lib/unittest/main.py | 25 | ||||
-rw-r--r-- | Lib/unittest/test/test_loader.py | 27 | ||||
-rw-r--r-- | Lib/unittest/test/test_program.py | 28 |
4 files changed, 90 insertions, 14 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index e860deb..eb03b4a 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -8,7 +8,7 @@ import types import functools import warnings -from fnmatch import fnmatch +from fnmatch import fnmatch, fnmatchcase from . import case, suite, util @@ -70,6 +70,7 @@ class TestLoader(object): """ testMethodPrefix = 'test' sortTestMethodsUsing = staticmethod(util.three_way_cmp) + testNamePatterns = None suiteClass = suite.TestSuite _top_level_dir = None @@ -222,11 +223,15 @@ class TestLoader(object): def getTestCaseNames(self, testCaseClass): """Return a sorted sequence of method names found within testCaseClass """ - def isTestMethod(attrname, testCaseClass=testCaseClass, - prefix=self.testMethodPrefix): - return attrname.startswith(prefix) and \ - callable(getattr(testCaseClass, attrname)) - testFnNames = list(filter(isTestMethod, dir(testCaseClass))) + def shouldIncludeMethod(attrname): + testFunc = getattr(testCaseClass, attrname) + isTestMethod = attrname.startswith(self.testMethodPrefix) and callable(testFunc) + if not isTestMethod: + return False + fullName = '%s.%s' % (testCaseClass.__module__, testFunc.__qualname__) + return self.testNamePatterns is None or \ + any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns) + testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass))) if self.sortTestMethodsUsing: testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing)) return testFnNames @@ -486,16 +491,17 @@ class TestLoader(object): defaultTestLoader = TestLoader() -def _makeLoader(prefix, sortUsing, suiteClass=None): +def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None): loader = TestLoader() loader.sortTestMethodsUsing = sortUsing loader.testMethodPrefix = prefix + loader.testNamePatterns = testNamePatterns if suiteClass: loader.suiteClass = suiteClass return loader -def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp): - return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass) +def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None): + return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass) def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp, suiteClass=suite.TestSuite): diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index 807604f..e62469a 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -46,6 +46,12 @@ def _convert_names(names): return [_convert_name(name) for name in names] +def _convert_select_pattern(pattern): + if not '*' in pattern: + pattern = '*%s*' % pattern + return pattern + + class TestProgram(object): """A command-line program that runs a set of tests; this is primarily for making test modules conveniently executable. @@ -53,7 +59,7 @@ class TestProgram(object): # defaults for testing module=None verbosity = 1 - failfast = catchbreak = buffer = progName = warnings = None + failfast = catchbreak = buffer = progName = warnings = testNamePatterns = None _discovery_parser = None def __init__(self, module='__main__', defaultTest=None, argv=None, @@ -140,8 +146,13 @@ class TestProgram(object): self.testNames = list(self.defaultTest) self.createTests() - def createTests(self): - if self.testNames is None: + def createTests(self, from_discovery=False, Loader=None): + if self.testNamePatterns: + self.testLoader.testNamePatterns = self.testNamePatterns + if from_discovery: + loader = self.testLoader if Loader is None else Loader() + self.test = loader.discover(self.start, self.pattern, self.top) + elif self.testNames is None: self.test = self.testLoader.loadTestsFromModule(self.module) else: self.test = self.testLoader.loadTestsFromNames(self.testNames, @@ -179,6 +190,11 @@ class TestProgram(object): action='store_true', help='Buffer stdout and stderr during tests') self.buffer = False + if self.testNamePatterns is None: + parser.add_argument('-k', dest='testNamePatterns', + action='append', type=_convert_select_pattern, + help='Only run tests which match the given substring') + self.testNamePatterns = [] return parser @@ -225,8 +241,7 @@ class TestProgram(object): self._initArgParsers() self._discovery_parser.parse_args(argv, self) - loader = self.testLoader if Loader is None else Loader() - self.test = loader.discover(self.start, self.pattern, self.top) + self.createTests(from_discovery=True, Loader=Loader) def runTests(self): if self.catchbreak: diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py index 1131a75..15b0186 100644 --- a/Lib/unittest/test/test_loader.py +++ b/Lib/unittest/test/test_loader.py @@ -1226,6 +1226,33 @@ class Test_TestLoader(unittest.TestCase): names = ['test_1', 'test_2', 'test_3'] self.assertEqual(loader.getTestCaseNames(TestC), names) + # "Return a sorted sequence of method names found within testCaseClass" + # + # If TestLoader.testNamePatterns is set, only tests that match one of these + # patterns should be included. + def test_getTestCaseNames__testNamePatterns(self): + class MyTest(unittest.TestCase): + def test_1(self): pass + def test_2(self): pass + def foobar(self): pass + + loader = unittest.TestLoader() + + loader.testNamePatterns = [] + self.assertEqual(loader.getTestCaseNames(MyTest), []) + + loader.testNamePatterns = ['*1'] + self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1']) + + loader.testNamePatterns = ['*1', '*2'] + self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1', 'test_2']) + + loader.testNamePatterns = ['*My*'] + self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1', 'test_2']) + + loader.testNamePatterns = ['*my*'] + self.assertEqual(loader.getTestCaseNames(MyTest), []) + ################################################################ ### /Tests for TestLoader.getTestCaseNames() diff --git a/Lib/unittest/test/test_program.py b/Lib/unittest/test/test_program.py index 1cfc179..4a62ae1 100644 --- a/Lib/unittest/test/test_program.py +++ b/Lib/unittest/test/test_program.py @@ -2,6 +2,7 @@ import io import os import sys +import subprocess from test import support import unittest import unittest.test @@ -409,6 +410,33 @@ class TestCommandLineArgs(unittest.TestCase): # for invalid filenames should we raise a useful error rather than # leaving the current error message (import of filename fails) in place? + def testParseArgsSelectedTestNames(self): + program = self.program + argv = ['progname', '-k', 'foo', '-k', 'bar', '-k', '*pat*'] + + program.createTests = lambda: None + program.parseArgs(argv) + + self.assertEqual(program.testNamePatterns, ['*foo*', '*bar*', '*pat*']) + + def testSelectedTestNamesFunctionalTest(self): + def run_unittest(args): + p = subprocess.Popen([sys.executable, '-m', 'unittest'] + args, + stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, cwd=os.path.dirname(__file__)) + with p: + _, stderr = p.communicate() + return stderr.decode() + + t = '_test_warnings' + self.assertIn('Ran 7 tests', run_unittest([t])) + self.assertIn('Ran 7 tests', run_unittest(['-k', 'TestWarnings', t])) + self.assertIn('Ran 7 tests', run_unittest(['discover', '-p', '*_test*', '-k', 'TestWarnings'])) + self.assertIn('Ran 2 tests', run_unittest(['-k', 'f', t])) + self.assertIn('Ran 7 tests', run_unittest(['-k', 't', t])) + self.assertIn('Ran 3 tests', run_unittest(['-k', '*t', t])) + self.assertIn('Ran 7 tests', run_unittest(['-k', '*test_warnings.*Warning*', t])) + self.assertIn('Ran 1 test', run_unittest(['-k', '*test_warnings.*warning*', t])) + if __name__ == '__main__': unittest.main() |