diff options
-rw-r--r-- | Lib/unittest/loader.py | 15 | ||||
-rw-r--r-- | Lib/unittest/test/test_discovery.py | 39 |
2 files changed, 53 insertions, 1 deletions
diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index a45dffa..76c4e11 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -178,7 +178,10 @@ class TestLoader(object): if not top_level_dir in sys.path: # all test modules must be importable from the top level directory - sys.path.append(top_level_dir) + # should we *unconditionally* put the start directory in first + # in sys.path to minimise likelihood of conflicts between installed + # modules and development versions? + sys.path.insert(0, top_level_dir) self._top_level_dir = top_level_dir is_not_importable = False @@ -251,6 +254,16 @@ class TestLoader(object): except: yield _make_failed_import_test(name, self.suiteClass) else: + mod_file = os.path.abspath(getattr(module, '__file__', full_path)) + realpath = os.path.splitext(mod_file)[0] + fullpath_noext = os.path.splitext(full_path)[0] + if realpath.lower() != fullpath_noext.lower(): + module_dir = os.path.dirname(realpath) + mod_name = os.path.splitext(os.path.basename(full_path))[0] + expected_dir = os.path.dirname(full_path) + msg = ("%r module incorrectly imported from %r. Expected %r. " + "Is this module globally installed?") + raise ImportError(msg % (mod_name, module_dir, expected_dir)) yield self.loadTestsFromModule(module) elif os.path.isdir(full_path): if not os.path.isfile(os.path.join(full_path, '__init__.py')): diff --git a/Lib/unittest/test/test_discovery.py b/Lib/unittest/test/test_discovery.py index 1b0f08b..8253580 100644 --- a/Lib/unittest/test/test_discovery.py +++ b/Lib/unittest/test/test_discovery.py @@ -294,6 +294,45 @@ class TestDiscovery(unittest.TestCase): self.assertTrue(program.failfast) self.assertTrue(program.catchbreak) + def test_detect_module_clash(self): + class Module(object): + __file__ = 'bar/foo.py' + sys.modules['foo'] = Module + full_path = os.path.abspath('foo') + original_listdir = os.listdir + original_isfile = os.path.isfile + original_isdir = os.path.isdir + + def cleanup(): + os.listdir = original_listdir + os.path.isfile = original_isfile + os.path.isdir = original_isdir + del sys.modules['foo'] + if full_path in sys.path: + sys.path.remove(full_path) + self.addCleanup(cleanup) + + def listdir(_): + return ['foo.py'] + def isfile(_): + return True + def isdir(_): + return True + os.listdir = listdir + os.path.isfile = isfile + os.path.isdir = isdir + + loader = unittest.TestLoader() + + mod_dir = os.path.abspath('bar') + expected_dir = os.path.abspath('foo') + msg = (r"^'foo' module incorrectly imported from %r\. Expected %r\. " + "Is this module globally installed\?$") % (mod_dir, expected_dir) + self.assertRaisesRegexp( + ImportError, msg, loader.discover, + start_dir='foo', pattern='foo.py' + ) + self.assertEqual(sys.path[0], full_path) if __name__ == '__main__': unittest.main() |