summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/unittest/loader.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py
index f0cc157..360a41e 100644
--- a/Lib/unittest/loader.py
+++ b/Lib/unittest/loader.py
@@ -192,7 +192,7 @@ class TestLoader(object):
top_part = start_dir.split('.')[0]
start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
if set_implicit_top:
- self._top_level_dir = os.path.abspath(os.path.dirname(os.path.dirname(sys.modules[top_part].__file__)))
+ self._top_level_dir = self._get_directory_containing_module(top_part)
sys.path.remove(top_level_dir)
if is_not_importable:
@@ -201,6 +201,18 @@ class TestLoader(object):
tests = list(self._find_tests(start_dir, pattern))
return self.suiteClass(tests)
+ def _get_directory_containing_module(self, module_name):
+ module = sys.modules[module_name]
+ full_path = os.path.abspath(module.__file__)
+
+ if os.path.basename(full_path).lower().startswith('__init__.py'):
+ return os.path.dirname(os.path.dirname(full_path))
+ else:
+ # here we have been given a module rather than a package - so
+ # all we can do is search the *same* directory the module is in
+ # should an exception be raised instead
+ return os.path.dirname(full_path)
+
def _get_name_from_path(self, path):
path = os.path.splitext(os.path.normpath(path))[0]