diff options
Diffstat (limited to 'Lib/unittest')
-rw-r--r-- | Lib/unittest/case.py | 10 | ||||
-rw-r--r-- | Lib/unittest/loader.py | 63 | ||||
-rw-r--r-- | Lib/unittest/main.py | 13 | ||||
-rw-r--r-- | Lib/unittest/suite.py | 3 |
4 files changed, 65 insertions, 24 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 48f3ef1..77ca278 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -468,7 +468,13 @@ class TestCase(object): Note that decimal places (from zero) are usually not the same as significant digits (measured from the most signficant digit). + + If the two objects compare equal then they will automatically + compare almost equal. """ + if first == second: + # shortcut for ite + return if round(abs(second-first), places) != 0: standardMsg = '%r != %r within %r places' % (first, second, places) msg = self._formatMessage(msg, standardMsg) @@ -481,8 +487,10 @@ class TestCase(object): Note that decimal places (from zero) are usually not the same as significant digits (measured from the most signficant digit). + + Objects that are equal automatically fail. """ - if round(abs(second-first), places) == 0: + if (first == second) or round(abs(second-first), places) == 0: standardMsg = '%r == %r within %r places' % (first, second, places) msg = self._formatMessage(msg, standardMsg) raise self.failureException(msg) diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index c687b1b..68f954c 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -1,7 +1,9 @@ """Loading unittests.""" import os +import re import sys +import traceback import types from fnmatch import fnmatch @@ -9,6 +11,26 @@ from fnmatch import fnmatch from . import case, suite, util +# what about .pyc or .pyo (etc) +# we would need to avoid loading the same tests multiple times +# from '.py', '.pyc' *and* '.pyo' +VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE) + + +def _make_failed_import_test(name, suiteClass): + message = 'Failed to import test module: %s' % name + if hasattr(traceback, 'format_exc'): + # Python 2.3 compatibility + # format_exc returns two frames of discover.py as well + message += '\n%s' % traceback.format_exc() + + def testImportFailure(self): + raise ImportError(message) + attrs = {name: testImportFailure} + ModuleImportFailure = type('ModuleImportFailure', (case.TestCase,), attrs) + return suiteClass((ModuleImportFailure(name),)) + + class TestLoader(object): """ This class is responsible for loading tests according to various criteria @@ -79,7 +101,7 @@ class TestLoader(object): inst = parent(name) # static methods follow a different path if not isinstance(getattr(inst, name), types.FunctionType): - return suite.TestSuite([inst]) + return self.suiteClass([inst]) elif isinstance(obj, suite.TestSuite): return obj if hasattr(obj, '__call__'): @@ -87,7 +109,7 @@ class TestLoader(object): if isinstance(test, suite.TestSuite): return test elif isinstance(test, case.TestCase): - return suite.TestSuite([test]) + return self.suiteClass([test]) else: raise TypeError("calling %s returned %s, not a test" % (obj, test)) @@ -156,17 +178,17 @@ class TestLoader(object): tests = list(self._find_tests(start_dir, pattern)) return self.suiteClass(tests) - - def _get_module_from_path(self, path): - """Load a module from a path relative to the top-level directory - of a project. Used by discovery.""" + def _get_name_from_path(self, path): path = os.path.splitext(os.path.normpath(path))[0] - relpath = os.path.relpath(path, self._top_level_dir) - assert not os.path.isabs(relpath), "Path must be within the project" - assert not relpath.startswith('..'), "Path must be within the project" + _relpath = os.path.relpath(path, self._top_level_dir) + assert not os.path.isabs(_relpath), "Path must be within the project" + assert not _relpath.startswith('..'), "Path must be within the project" + + name = _relpath.replace(os.path.sep, '.') + return name - name = relpath.replace(os.path.sep, '.') + def _get_module_from_name(self, name): __import__(name) return sys.modules[name] @@ -176,14 +198,20 @@ class TestLoader(object): for path in paths: full_path = os.path.join(start_dir, path) - # what about __init__.pyc or pyo (etc) - # we would need to avoid loading the same tests multiple times - # from '.py', '.pyc' *and* '.pyo' - if os.path.isfile(full_path) and path.lower().endswith('.py'): + if os.path.isfile(full_path): + if not VALID_MODULE_NAME.match(path): + # valid Python identifiers only + continue + if fnmatch(path, pattern): # if the test file matches, load it - module = self._get_module_from_path(full_path) - yield self.loadTestsFromModule(module) + name = self._get_name_from_path(full_path) + try: + module = self._get_module_from_name(name) + except: + yield _make_failed_import_test(name, self.suiteClass) + else: + yield self.loadTestsFromModule(module) elif os.path.isdir(full_path): if not os.path.isfile(os.path.join(full_path, '__init__.py')): continue @@ -192,7 +220,8 @@ class TestLoader(object): tests = None if fnmatch(path, pattern): # only check load_tests if the package directory itself matches the filter - package = self._get_module_from_path(full_path) + name = self._get_name_from_path(full_path) + package = self._get_module_from_name(name) load_tests = getattr(package, 'load_tests', None) tests = self.loadTestsFromModule(package, use_load_tests=False) diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py index 4a5c22b..e6237b0 100644 --- a/Lib/unittest/main.py +++ b/Lib/unittest/main.py @@ -109,9 +109,9 @@ class TestProgram(object): if opt in ('-v','--verbose'): self.verbosity = 2 if len(args) == 0 and self.defaultTest is None: - self.test = self.testLoader.loadTestsFromModule(self.module) - return - if len(args) > 0: + # createTests will load tests from self.module + self.testNames = None + elif len(args) > 0: self.testNames = args if __name__ == '__main__': # to support python -m unittest ... @@ -123,8 +123,11 @@ class TestProgram(object): self.usageExit(msg) def createTests(self): - self.test = self.testLoader.loadTestsFromNames(self.testNames, - self.module) + if self.testNames is None: + self.test = self.testLoader.loadTestsFromModule(self.module) + else: + self.test = self.testLoader.loadTestsFromNames(self.testNames, + self.module) def _do_discovery(self, argv, Loader=loader.TestLoader): # handle command line args for test discovery diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py index baf8414..8672aab 100644 --- a/Lib/unittest/suite.py +++ b/Lib/unittest/suite.py @@ -1,6 +1,7 @@ """TestSuite""" from . import case +from . import util class TestSuite(object): @@ -17,7 +18,7 @@ class TestSuite(object): self.addTests(tests) def __repr__(self): - return "<%s tests=%s>" % (_strclass(self.__class__), list(self)) + return "<%s tests=%s>" % (util.strclass(self.__class__), list(self)) def __eq__(self, other): if not isinstance(other, self.__class__): |