summaryrefslogtreecommitdiffstats
path: root/Lib/unittest
diff options
context:
space:
mode:
authorJonas Haag <jonas@lophus.org>2017-11-25 15:23:52 (GMT)
committerAntoine Pitrou <pitrou@free.fr>2017-11-25 15:23:52 (GMT)
commit5b48dc638b7405fd9bde4d854bf477dfeaaddf44 (patch)
tree2e3b44b9193cc1a0e08a6e1d65dd324e76fb3ee6 /Lib/unittest
parent8d9bb11d8fcbf10cc9b1eb0a647bcf3658a4e3dd (diff)
downloadcpython-5b48dc638b7405fd9bde4d854bf477dfeaaddf44.zip
cpython-5b48dc638b7405fd9bde4d854bf477dfeaaddf44.tar.gz
cpython-5b48dc638b7405fd9bde4d854bf477dfeaaddf44.tar.bz2
bpo-32071: Add unittest -k option (#4496)
* bpo-32071: Add unittest -k option
Diffstat (limited to 'Lib/unittest')
-rw-r--r--Lib/unittest/loader.py24
-rw-r--r--Lib/unittest/main.py25
-rw-r--r--Lib/unittest/test/test_loader.py27
-rw-r--r--Lib/unittest/test/test_program.py28
4 files changed, 90 insertions, 14 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index e860deb..eb03b4a 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -8,7 +8,7 @@ import types
import functools
import warnings
-from fnmatch import fnmatch
+from fnmatch import fnmatch, fnmatchcase
from . import case, suite, util
@@ -70,6 +70,7 @@ class TestLoader(object):
"""
testMethodPrefix = 'test'
sortTestMethodsUsing = staticmethod(util.three_way_cmp)
+ testNamePatterns = None
suiteClass = suite.TestSuite
_top_level_dir = None
@@ -222,11 +223,15 @@ class TestLoader(object):
def getTestCaseNames(self, testCaseClass):
"""Return a sorted sequence of method names found within testCaseClass
"""
- def isTestMethod(attrname, testCaseClass=testCaseClass,
- prefix=self.testMethodPrefix):
- return attrname.startswith(prefix) and \
- callable(getattr(testCaseClass, attrname))
- testFnNames = list(filter(isTestMethod, dir(testCaseClass)))
+ def shouldIncludeMethod(attrname):
+ testFunc = getattr(testCaseClass, attrname)
+ isTestMethod = attrname.startswith(self.testMethodPrefix) and callable(testFunc)
+ if not isTestMethod:
+ return False
+ fullName = '%s.%s' % (testCaseClass.__module__, testFunc.__qualname__)
+ return self.testNamePatterns is None or \
+ any(fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns)
+ testFnNames = list(filter(shouldIncludeMethod, dir(testCaseClass)))
if self.sortTestMethodsUsing:
testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing))
return testFnNames
@@ -486,16 +491,17 @@ class TestLoader(object):
defaultTestLoader = TestLoader()
-def _makeLoader(prefix, sortUsing, suiteClass=None):
+def _makeLoader(prefix, sortUsing, suiteClass=None, testNamePatterns=None):
loader = TestLoader()
loader.sortTestMethodsUsing = sortUsing
loader.testMethodPrefix = prefix
+ loader.testNamePatterns = testNamePatterns
if suiteClass:
loader.suiteClass = suiteClass
return loader
-def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp):
- return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
+def getTestCaseNames(testCaseClass, prefix, sortUsing=util.three_way_cmp, testNamePatterns=None):
+ return _makeLoader(prefix, sortUsing, testNamePatterns=testNamePatterns).getTestCaseNames(testCaseClass)
def makeSuite(testCaseClass, prefix='test', sortUsing=util.three_way_cmp,
suiteClass=suite.TestSuite):
diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py
index 807604f..e62469a 100644
--- a/Lib/unittest/main.py
+++ b/Lib/unittest/main.py
@@ -46,6 +46,12 @@ def _convert_names(names):
return [_convert_name(name) for name in names]
+def _convert_select_pattern(pattern):
+ if not '*' in pattern:
+ pattern = '*%s*' % pattern
+ return pattern
+
+
class TestProgram(object):
"""A command-line program that runs a set of tests; this is primarily
for making test modules conveniently executable.
@@ -53,7 +59,7 @@ class TestProgram(object):
# defaults for testing
module=None
verbosity = 1
- failfast = catchbreak = buffer = progName = warnings = None
+ failfast = catchbreak = buffer = progName = warnings = testNamePatterns = None
_discovery_parser = None
def __init__(self, module='__main__', defaultTest=None, argv=None,
@@ -140,8 +146,13 @@ class TestProgram(object):
self.testNames = list(self.defaultTest)
self.createTests()
- def createTests(self):
- if self.testNames is None:
+ def createTests(self, from_discovery=False, Loader=None):
+ if self.testNamePatterns:
+ self.testLoader.testNamePatterns = self.testNamePatterns
+ if from_discovery:
+ loader = self.testLoader if Loader is None else Loader()
+ self.test = loader.discover(self.start, self.pattern, self.top)
+ elif self.testNames is None:
self.test = self.testLoader.loadTestsFromModule(self.module)
else:
self.test = self.testLoader.loadTestsFromNames(self.testNames,
@@ -179,6 +190,11 @@ class TestProgram(object):
action='store_true',
help='Buffer stdout and stderr during tests')
self.buffer = False
+ if self.testNamePatterns is None:
+ parser.add_argument('-k', dest='testNamePatterns',
+ action='append', type=_convert_select_pattern,
+ help='Only run tests which match the given substring')
+ self.testNamePatterns = []
return parser
@@ -225,8 +241,7 @@ class TestProgram(object):
self._initArgParsers()
self._discovery_parser.parse_args(argv, self)
- loader = self.testLoader if Loader is None else Loader()
- self.test = loader.discover(self.start, self.pattern, self.top)
+ self.createTests(from_discovery=True, Loader=Loader)
def runTests(self):
if self.catchbreak:
diff --git a/Lib/unittest/test/test_loader.py b/Lib/unittest/test/test_loader.py
index 1131a75..15b0186 100644
--- a/Lib/unittest/test/test_loader.py
+++ b/Lib/unittest/test/test_loader.py
@@ -1226,6 +1226,33 @@ class Test_TestLoader(unittest.TestCase):
names = ['test_1', 'test_2', 'test_3']
self.assertEqual(loader.getTestCaseNames(TestC), names)
+ # "Return a sorted sequence of method names found within testCaseClass"
+ #
+ # If TestLoader.testNamePatterns is set, only tests that match one of these
+ # patterns should be included.
+ def test_getTestCaseNames__testNamePatterns(self):
+ class MyTest(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foobar(self): pass
+
+ loader = unittest.TestLoader()
+
+ loader.testNamePatterns = []
+ self.assertEqual(loader.getTestCaseNames(MyTest), [])
+
+ loader.testNamePatterns = ['*1']
+ self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1'])
+
+ loader.testNamePatterns = ['*1', '*2']
+ self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1', 'test_2'])
+
+ loader.testNamePatterns = ['*My*']
+ self.assertEqual(loader.getTestCaseNames(MyTest), ['test_1', 'test_2'])
+
+ loader.testNamePatterns = ['*my*']
+ self.assertEqual(loader.getTestCaseNames(MyTest), [])
+
################################################################
### /Tests for TestLoader.getTestCaseNames()
diff --git a/Lib/unittest/test/test_program.py b/Lib/unittest/test/test_program.py
index 1cfc179..4a62ae1 100644
--- a/Lib/unittest/test/test_program.py
+++ b/Lib/unittest/test/test_program.py
@@ -2,6 +2,7 @@ import io
import os
import sys
+import subprocess
from test import support
import unittest
import unittest.test
@@ -409,6 +410,33 @@ class TestCommandLineArgs(unittest.TestCase):
# for invalid filenames should we raise a useful error rather than
# leaving the current error message (import of filename fails) in place?
+ def testParseArgsSelectedTestNames(self):
+ program = self.program
+ argv = ['progname', '-k', 'foo', '-k', 'bar', '-k', '*pat*']
+
+ program.createTests = lambda: None
+ program.parseArgs(argv)
+
+ self.assertEqual(program.testNamePatterns, ['*foo*', '*bar*', '*pat*'])
+
+ def testSelectedTestNamesFunctionalTest(self):
+ def run_unittest(args):
+ p = subprocess.Popen([sys.executable, '-m', 'unittest'] + args,
+ stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, cwd=os.path.dirname(__file__))
+ with p:
+ _, stderr = p.communicate()
+ return stderr.decode()
+
+ t = '_test_warnings'
+ self.assertIn('Ran 7 tests', run_unittest([t]))
+ self.assertIn('Ran 7 tests', run_unittest(['-k', 'TestWarnings', t]))
+ self.assertIn('Ran 7 tests', run_unittest(['discover', '-p', '*_test*', '-k', 'TestWarnings']))
+ self.assertIn('Ran 2 tests', run_unittest(['-k', 'f', t]))
+ self.assertIn('Ran 7 tests', run_unittest(['-k', 't', t]))
+ self.assertIn('Ran 3 tests', run_unittest(['-k', '*t', t]))
+ self.assertIn('Ran 7 tests', run_unittest(['-k', '*test_warnings.*Warning*', t]))
+ self.assertIn('Ran 1 test', run_unittest(['-k', '*test_warnings.*warning*', t]))
+
if __name__ == '__main__':
unittest.main()