diff options
Diffstat (limited to 'Lib/unittest/loader.py')
-rw-r--r-- | Lib/unittest/loader.py | 39 |
1 files changed, 35 insertions, 4 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index f00f38d..a45dffa 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -166,27 +166,58 @@ class TestLoader(object): packages can continue discovery themselves. top_level_dir is stored so load_tests does not need to pass this argument in to loader.discover(). """ + set_implicit_top = False if top_level_dir is None and self._top_level_dir is not None: # make top_level_dir optional if called from load_tests in a package top_level_dir = self._top_level_dir elif top_level_dir is None: + set_implicit_top = True top_level_dir = start_dir - top_level_dir = os.path.abspath(os.path.normpath(top_level_dir)) - start_dir = os.path.abspath(os.path.normpath(start_dir)) + top_level_dir = os.path.abspath(top_level_dir) if not top_level_dir in sys.path: # all test modules must be importable from the top level directory sys.path.append(top_level_dir) self._top_level_dir = top_level_dir - if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')): - # what about __init__.pyc or pyo (etc) + is_not_importable = False + if os.path.isdir(os.path.abspath(start_dir)): + start_dir = os.path.abspath(start_dir) + if start_dir != top_level_dir: + is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py')) + else: + # support for discovery from dotted module names + try: + __import__(start_dir) + except ImportError: + is_not_importable = True + else: + the_module = sys.modules[start_dir] + top_part = start_dir.split('.')[0] + start_dir = os.path.abspath(os.path.dirname((the_module.__file__))) + if set_implicit_top: + self._top_level_dir = self._get_directory_containing_module(top_part) + sys.path.remove(top_level_dir) + + if is_not_importable: raise ImportError('Start directory is not importable: %r' % start_dir) tests = list(self._find_tests(start_dir, pattern)) return self.suiteClass(tests) + def _get_directory_containing_module(self, module_name): + module = sys.modules[module_name] + full_path = os.path.abspath(module.__file__) + + if os.path.basename(full_path).lower().startswith('__init__.py'): + return os.path.dirname(os.path.dirname(full_path)) + else: + # here we have been given a module rather than a package - so + # all we can do is search the *same* directory the module is in + # should an exception be raised instead + return os.path.dirname(full_path) + def _get_name_from_path(self, path): path = os.path.splitext(os.path.normpath(path))[0] |