summaryrefslogtreecommitdiffstats
path: root/Lib/unittest
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/unittest')
-rw-r--r--Lib/unittest/case.py10
-rw-r--r--Lib/unittest/loader.py63
-rw-r--r--Lib/unittest/main.py13
-rw-r--r--Lib/unittest/suite.py3
4 files changed, 65 insertions, 24 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 48f3ef1..77ca278 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -468,7 +468,13 @@ class TestCase(object):
Note that decimal places (from zero) are usually not the same
as significant digits (measured from the most signficant digit).
+
+ If the two objects compare equal then they will automatically
+ compare almost equal.
"""
+ if first == second:
+ # shortcut for ite
+ return
if round(abs(second-first), places) != 0:
standardMsg = '%r != %r within %r places' % (first, second, places)
msg = self._formatMessage(msg, standardMsg)
@@ -481,8 +487,10 @@ class TestCase(object):
Note that decimal places (from zero) are usually not the same
as significant digits (measured from the most signficant digit).
+
+ Objects that are equal automatically fail.
"""
- if round(abs(second-first), places) == 0:
+ if (first == second) or round(abs(second-first), places) == 0:
standardMsg = '%r == %r within %r places' % (first, second, places)
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index c687b1b..68f954c 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -1,7 +1,9 @@
"""Loading unittests."""
import os
+import re
import sys
+import traceback
import types
from fnmatch import fnmatch
@@ -9,6 +11,26 @@ from fnmatch import fnmatch
from . import case, suite, util
+# what about .pyc or .pyo (etc)
+# we would need to avoid loading the same tests multiple times
+# from '.py', '.pyc' *and* '.pyo'
+VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
+
+
+def _make_failed_import_test(name, suiteClass):
+ message = 'Failed to import test module: %s' % name
+ if hasattr(traceback, 'format_exc'):
+ # Python 2.3 compatibility
+ # format_exc returns two frames of discover.py as well
+ message += '\n%s' % traceback.format_exc()
+
+ def testImportFailure(self):
+ raise ImportError(message)
+ attrs = {name: testImportFailure}
+ ModuleImportFailure = type('ModuleImportFailure', (case.TestCase,), attrs)
+ return suiteClass((ModuleImportFailure(name),))
+
+
class TestLoader(object):
"""
This class is responsible for loading tests according to various criteria
@@ -79,7 +101,7 @@ class TestLoader(object):
inst = parent(name)
# static methods follow a different path
if not isinstance(getattr(inst, name), types.FunctionType):
- return suite.TestSuite([inst])
+ return self.suiteClass([inst])
elif isinstance(obj, suite.TestSuite):
return obj
if hasattr(obj, '__call__'):
@@ -87,7 +109,7 @@ class TestLoader(object):
if isinstance(test, suite.TestSuite):
return test
elif isinstance(test, case.TestCase):
- return suite.TestSuite([test])
+ return self.suiteClass([test])
else:
raise TypeError("calling %s returned %s, not a test" %
(obj, test))
@@ -156,17 +178,17 @@ class TestLoader(object):
tests = list(self._find_tests(start_dir, pattern))
return self.suiteClass(tests)
-
- def _get_module_from_path(self, path):
- """Load a module from a path relative to the top-level directory
- of a project. Used by discovery."""
+ def _get_name_from_path(self, path):
path = os.path.splitext(os.path.normpath(path))[0]
- relpath = os.path.relpath(path, self._top_level_dir)
- assert not os.path.isabs(relpath), "Path must be within the project"
- assert not relpath.startswith('..'), "Path must be within the project"
+ _relpath = os.path.relpath(path, self._top_level_dir)
+ assert not os.path.isabs(_relpath), "Path must be within the project"
+ assert not _relpath.startswith('..'), "Path must be within the project"
+
+ name = _relpath.replace(os.path.sep, '.')
+ return name
- name = relpath.replace(os.path.sep, '.')
+ def _get_module_from_name(self, name):
__import__(name)
return sys.modules[name]
@@ -176,14 +198,20 @@ class TestLoader(object):
for path in paths:
full_path = os.path.join(start_dir, path)
- # what about __init__.pyc or pyo (etc)
- # we would need to avoid loading the same tests multiple times
- # from '.py', '.pyc' *and* '.pyo'
- if os.path.isfile(full_path) and path.lower().endswith('.py'):
+ if os.path.isfile(full_path):
+ if not VALID_MODULE_NAME.match(path):
+ # valid Python identifiers only
+ continue
+
if fnmatch(path, pattern):
# if the test file matches, load it
- module = self._get_module_from_path(full_path)
- yield self.loadTestsFromModule(module)
+ name = self._get_name_from_path(full_path)
+ try:
+ module = self._get_module_from_name(name)
+ except:
+ yield _make_failed_import_test(name, self.suiteClass)
+ else:
+ yield self.loadTestsFromModule(module)
elif os.path.isdir(full_path):
if not os.path.isfile(os.path.join(full_path, '__init__.py')):
continue
@@ -192,7 +220,8 @@ class TestLoader(object):
tests = None
if fnmatch(path, pattern):
# only check load_tests if the package directory itself matches the filter
- package = self._get_module_from_path(full_path)
+ name = self._get_name_from_path(full_path)
+ package = self._get_module_from_name(name)
load_tests = getattr(package, 'load_tests', None)
tests = self.loadTestsFromModule(package, use_load_tests=False)
diff --git a/Lib/unittest/main.py b/Lib/unittest/main.py
index 4a5c22b..e6237b0 100644
--- a/Lib/unittest/main.py
+++ b/Lib/unittest/main.py
@@ -109,9 +109,9 @@ class TestProgram(object):
if opt in ('-v','--verbose'):
self.verbosity = 2
if len(args) == 0 and self.defaultTest is None:
- self.test = self.testLoader.loadTestsFromModule(self.module)
- return
- if len(args) > 0:
+ # createTests will load tests from self.module
+ self.testNames = None
+ elif len(args) > 0:
self.testNames = args
if __name__ == '__main__':
# to support python -m unittest ...
@@ -123,8 +123,11 @@ class TestProgram(object):
self.usageExit(msg)
def createTests(self):
- self.test = self.testLoader.loadTestsFromNames(self.testNames,
- self.module)
+ if self.testNames is None:
+ self.test = self.testLoader.loadTestsFromModule(self.module)
+ else:
+ self.test = self.testLoader.loadTestsFromNames(self.testNames,
+ self.module)
def _do_discovery(self, argv, Loader=loader.TestLoader):
# handle command line args for test discovery
diff --git a/Lib/unittest/suite.py b/Lib/unittest/suite.py
index baf8414..8672aab 100644
--- a/Lib/unittest/suite.py
+++ b/Lib/unittest/suite.py
@@ -1,6 +1,7 @@
"""TestSuite"""
from . import case
+from . import util
class TestSuite(object):
@@ -17,7 +18,7 @@ class TestSuite(object):
self.addTests(tests)
def __repr__(self):
- return "<%s tests=%s>" % (_strclass(self.__class__), list(self))
+ return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))
def __eq__(self, other):
if not isinstance(other, self.__class__):