summaryrefslogtreecommitdiffstats
path: root/Lib/unittest/loader.py
diff options
context:
space:
mode:
authorChristian Heimes <christian@cheimes.de>2013-11-23 14:59:07 (GMT)
committerChristian Heimes <christian@cheimes.de>2013-11-23 14:59:07 (GMT)
commit5de397e158d538ffa065974006e58547891bd955 (patch)
treedfa6715108e51cb33635bdf23210443f42898b03 /Lib/unittest/loader.py
parent4c05b472ddd4634138b6abfa857ee37761d33185 (diff)
parent2cf3917954dab65c025ea39f8b6a0298c598f9f7 (diff)
downloadcpython-5de397e158d538ffa065974006e58547891bd955.zip
cpython-5de397e158d538ffa065974006e58547891bd955.tar.gz
cpython-5de397e158d538ffa065974006e58547891bd955.tar.bz2
merge
Diffstat (limited to 'Lib/unittest/loader.py')
-rw-r--r--Lib/unittest/loader.py60
1 files changed, 51 insertions, 9 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index e872fcc..808c50e 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -61,8 +61,9 @@ class TestLoader(object):
def loadTestsFromTestCase(self, testCaseClass):
"""Return a suite of all tests cases contained in testCaseClass"""
if issubclass(testCaseClass, suite.TestSuite):
- raise TypeError("Test cases should not be derived from TestSuite." \
- " Maybe you meant to derive from TestCase?")
+ raise TypeError("Test cases should not be derived from "
+ "TestSuite. Maybe you meant to derive from "
+ "TestCase?")
testCaseNames = self.getTestCaseNames(testCaseClass)
if not testCaseNames and hasattr(testCaseClass, 'runTest'):
testCaseNames = ['runTest']
@@ -200,6 +201,8 @@ class TestLoader(object):
self._top_level_dir = top_level_dir
is_not_importable = False
+ is_namespace = False
+ tests = []
if os.path.isdir(os.path.abspath(start_dir)):
start_dir = os.path.abspath(start_dir)
if start_dir != top_level_dir:
@@ -213,15 +216,52 @@ class TestLoader(object):
else:
the_module = sys.modules[start_dir]
top_part = start_dir.split('.')[0]
- start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
+ try:
+ start_dir = os.path.abspath(
+ os.path.dirname((the_module.__file__)))
+ except AttributeError:
+ # look for namespace packages
+ try:
+ spec = the_module.__spec__
+ except AttributeError:
+ spec = None
+
+ if spec and spec.loader is None:
+ if spec.submodule_search_locations is not None:
+ is_namespace = True
+
+ for path in the_module.__path__:
+ if (not set_implicit_top and
+ not path.startswith(top_level_dir)):
+ continue
+ self._top_level_dir = \
+ (path.split(the_module.__name__
+ .replace(".", os.path.sep))[0])
+ tests.extend(self._find_tests(path,
+ pattern,
+ namespace=True))
+ elif the_module.__name__ in sys.builtin_module_names:
+ # builtin module
+ raise TypeError('Can not use builtin modules '
+ 'as dotted module names') from None
+ else:
+ raise TypeError(
+ 'don\'t know how to discover from {!r}'
+ .format(the_module)) from None
+
if set_implicit_top:
- self._top_level_dir = self._get_directory_containing_module(top_part)
- sys.path.remove(top_level_dir)
+ if not is_namespace:
+ self._top_level_dir = \
+ self._get_directory_containing_module(top_part)
+ sys.path.remove(top_level_dir)
+ else:
+ sys.path.remove(top_level_dir)
if is_not_importable:
raise ImportError('Start directory is not importable: %r' % start_dir)
- tests = list(self._find_tests(start_dir, pattern))
+ if not is_namespace:
+ tests = list(self._find_tests(start_dir, pattern))
return self.suiteClass(tests)
def _get_directory_containing_module(self, module_name):
@@ -254,7 +294,7 @@ class TestLoader(object):
# override this method to use alternative matching strategy
return fnmatch(path, pattern)
- def _find_tests(self, start_dir, pattern):
+ def _find_tests(self, start_dir, pattern, namespace=False):
"""Used by discovery. Yields test suites it loads."""
paths = sorted(os.listdir(start_dir))
@@ -287,7 +327,8 @@ class TestLoader(object):
raise ImportError(msg % (mod_name, module_dir, expected_dir))
yield self.loadTestsFromModule(module)
elif os.path.isdir(full_path):
- if not os.path.isfile(os.path.join(full_path, '__init__.py')):
+ if (not namespace and
+ not os.path.isfile(os.path.join(full_path, '__init__.py'))):
continue
load_tests = None
@@ -304,7 +345,8 @@ class TestLoader(object):
# tests loaded from package file
yield tests
# recurse into the package
- yield from self._find_tests(full_path, pattern)
+ yield from self._find_tests(full_path, pattern,
+ namespace=namespace)
else:
try:
yield load_tests(self, tests, pattern)