summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/importlib/abc.py30
-rw-r--r--Lib/test/test_importlib/test_abc.py96
2 files changed, 113 insertions, 13 deletions
diff --git a/Lib/importlib/abc.py b/Lib/importlib/abc.py
index cdcf244..b45e7d6 100644
--- a/Lib/importlib/abc.py
+++ b/Lib/importlib/abc.py
@@ -147,14 +147,18 @@ class InspectLoader(Loader):
"""
raise ImportError
- @abc.abstractmethod
def get_code(self, fullname):
- """Abstract method which when implemented should return the code object
- for the module. The fullname is a str. Returns a types.CodeType.
+ """Method which returns the code object for the module.
- Raises ImportError if the module cannot be found.
+ The fullname is a str. Returns a types.CodeType if possible, else
+ returns None if a code object does not make sense
+ (e.g. built-in module). Raises ImportError if the module cannot be
+ found.
"""
- raise ImportError
+ source = self.get_source(fullname)
+ if source is None:
+ return None
+ return self.source_to_code(source)
@abc.abstractmethod
def get_source(self, fullname):
@@ -194,6 +198,22 @@ class ExecutionLoader(InspectLoader):
"""
raise ImportError
+ def get_code(self, fullname):
+ """Method to return the code object for fullname.
+
+ Should return None if not applicable (e.g. built-in module).
+ Raise ImportError if the module cannot be found.
+ """
+ source = self.get_source(fullname)
+ if source is None:
+ return None
+ try:
+ path = self.get_filename(fullname)
+ except ImportError:
+ return self.source_to_code(source)
+ else:
+ return self.source_to_code(source, path)
+
class FileLoader(_bootstrap.FileLoader, ResourceLoader, ExecutionLoader):
diff --git a/Lib/test/test_importlib/test_abc.py b/Lib/test/test_importlib/test_abc.py
index 5d1661c..b443337 100644
--- a/Lib/test/test_importlib/test_abc.py
+++ b/Lib/test/test_importlib/test_abc.py
@@ -9,6 +9,7 @@ import marshal
import os
import sys
import unittest
+from unittest import mock
from . import util
@@ -166,9 +167,6 @@ class InspectLoaderSubclass(LoaderSubclass, abc.InspectLoader):
def is_package(self, fullname):
return super().is_package(fullname)
- def get_code(self, fullname):
- return super().get_code(fullname)
-
def get_source(self, fullname):
return super().get_source(fullname)
@@ -181,10 +179,6 @@ class InspectLoaderDefaultsTests(unittest.TestCase):
with self.assertRaises(ImportError):
self.ins.is_package('blah')
- def test_get_code(self):
- with self.assertRaises(ImportError):
- self.ins.get_code('blah')
-
def test_get_source(self):
with self.assertRaises(ImportError):
self.ins.get_source('blah')
@@ -206,7 +200,7 @@ class ExecutionLoaderDefaultsTests(unittest.TestCase):
##### InspectLoader concrete methods ###########################################
-class InspectLoaderConcreteMethodTests(unittest.TestCase):
+class InspectLoaderSourceToCodeTests(unittest.TestCase):
def source_to_module(self, data, path=None):
"""Help with source_to_code() tests."""
@@ -248,6 +242,92 @@ class InspectLoaderConcreteMethodTests(unittest.TestCase):
self.assertEqual(code.co_filename, '<string>')
+class InspectLoaderGetCodeTests(unittest.TestCase):
+
+ def test_get_code(self):
+ # Test success.
+ module = imp.new_module('blah')
+ with mock.patch.object(InspectLoaderSubclass, 'get_source') as mocked:
+ mocked.return_value = 'attr = 42'
+ loader = InspectLoaderSubclass()
+ code = loader.get_code('blah')
+ exec(code, module.__dict__)
+ self.assertEqual(module.attr, 42)
+
+ def test_get_code_source_is_None(self):
+ # If get_source() is None then this should be None.
+ with mock.patch.object(InspectLoaderSubclass, 'get_source') as mocked:
+ mocked.return_value = None
+ loader = InspectLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertIsNone(code)
+
+ def test_get_code_source_not_found(self):
+ # If there is no source then there is no code object.
+ loader = InspectLoaderSubclass()
+ with self.assertRaises(ImportError):
+ loader.get_code('blah')
+
+
+##### ExecutionLoader concrete methods #########################################
+class ExecutionLoaderGetCodeTests(unittest.TestCase):
+
+ def mock_methods(self, *, get_source=False, get_filename=False):
+ source_mock_context, filename_mock_context = None, None
+ if get_source:
+ source_mock_context = mock.patch.object(ExecutionLoaderSubclass,
+ 'get_source')
+ if get_filename:
+ filename_mock_context = mock.patch.object(ExecutionLoaderSubclass,
+ 'get_filename')
+ return source_mock_context, filename_mock_context
+
+ def test_get_code(self):
+ path = 'blah.py'
+ source_mock_context, filename_mock_context = self.mock_methods(
+ get_source=True, get_filename=True)
+ with source_mock_context as source_mock, filename_mock_context as name_mock:
+ source_mock.return_value = 'attr = 42'
+ name_mock.return_value = path
+ loader = ExecutionLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertEqual(code.co_filename, path)
+ module = imp.new_module('blah')
+ exec(code, module.__dict__)
+ self.assertEqual(module.attr, 42)
+
+ def test_get_code_source_is_None(self):
+ # If get_source() is None then this should be None.
+ source_mock_context, _ = self.mock_methods(get_source=True)
+ with source_mock_context as mocked:
+ mocked.return_value = None
+ loader = ExecutionLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertIsNone(code)
+
+ def test_get_code_source_not_found(self):
+ # If there is no source then there is no code object.
+ loader = ExecutionLoaderSubclass()
+ with self.assertRaises(ImportError):
+ loader.get_code('blah')
+
+ def test_get_code_no_path(self):
+ # If get_filename() raises ImportError then simply skip setting the path
+ # on the code object.
+ source_mock_context, filename_mock_context = self.mock_methods(
+ get_source=True, get_filename=True)
+ with source_mock_context as source_mock, filename_mock_context as name_mock:
+ source_mock.return_value = 'attr = 42'
+ name_mock.side_effect = ImportError
+ loader = ExecutionLoaderSubclass()
+ code = loader.get_code('blah')
+ self.assertEqual(code.co_filename, '<string>')
+ module = imp.new_module('blah')
+ exec(code, module.__dict__)
+ self.assertEqual(module.attr, 42)
+
+
+
##### SourceLoader concrete methods ############################################
class SourceOnlyLoaderMock(abc.SourceLoader):