summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorKarthikeyan Singaravelan <tir.karthi@gmail.com>2020-04-18 16:19:32 (GMT)
committerGitHub <noreply@github.com>2020-04-18 16:19:32 (GMT)
commit696136b993e11b37c4f34d729a0375e5ad544ade (patch)
tree19b1085d34db9d665b8f131e3f99bd0c2a40fcb6 /Lib
parentce578831a4e573eac422a488930100bc5380f227 (diff)
downloadcpython-696136b993e11b37c4f34d729a0375e5ad544ade.zip
cpython-696136b993e11b37c4f34d729a0375e5ad544ade.tar.gz
cpython-696136b993e11b37c4f34d729a0375e5ad544ade.tar.bz2
bpo-35113: Fix inspect.getsource to return correct source for inner classes (#10307)
* Use ast module to find class definition * Add NEWS entry * Fix class with multiple children and move decorator code to the method * Fix PR comments 1. Use node.decorator_list to select decorators 2. Remove unwanted variables in ClassVisitor 3. Simplify stack management as per review * Add test for nested functions and async calls * Fix pydoc test since comments are returned now correctly * Set event loop policy as None to fix environment related change * Refactor visit_AsyncFunctionDef and tests * Refactor to use local variables and fix tests * Add patch attribution * Use self.addCleanup for asyncio * Rename ClassVisitor to ClassFinder and fix asyncio cleanup * Return first class inside conditional in case of multiple definitions. Remove decorator for class source. * Add docstring to make the test correct * Modify NEWS entry regarding decorators * Return decorators too for bpo-15856 * Move ast and the class source code to top. Use proper Exception.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/inspect.py65
-rw-r--r--Lib/test/inspect_fodder2.py114
-rw-r--r--Lib/test/test_inspect.py40
-rw-r--r--Lib/test/test_pydoc.py1
4 files changed, 197 insertions, 23 deletions
diff --git a/Lib/inspect.py b/Lib/inspect.py
index 6f7d5cd..ad7e8cb 100644
--- a/Lib/inspect.py
+++ b/Lib/inspect.py
@@ -32,6 +32,7 @@ __author__ = ('Ka-Ping Yee <ping@lfw.org>',
'Yury Selivanov <yselivanov@sprymix.com>')
import abc
+import ast
import dis
import collections.abc
import enum
@@ -770,6 +771,42 @@ def getmodule(object, _filename=None):
if builtinobject is object:
return builtin
+
+class ClassFoundException(Exception):
+ pass
+
+
+class _ClassFinder(ast.NodeVisitor):
+
+ def __init__(self, qualname):
+ self.stack = []
+ self.qualname = qualname
+
+ def visit_FunctionDef(self, node):
+ self.stack.append(node.name)
+ self.stack.append('<locals>')
+ self.generic_visit(node)
+ self.stack.pop()
+ self.stack.pop()
+
+ visit_AsyncFunctionDef = visit_FunctionDef
+
+ def visit_ClassDef(self, node):
+ self.stack.append(node.name)
+ if self.qualname == '.'.join(self.stack):
+ # Return the decorator for the class if present
+ if node.decorator_list:
+ line_number = node.decorator_list[0].lineno
+ else:
+ line_number = node.lineno
+
+ # decrement by one since lines starts with indexing by zero
+ line_number -= 1
+ raise ClassFoundException(line_number)
+ self.generic_visit(node)
+ self.stack.pop()
+
+
def findsource(object):
"""Return the entire source file and starting line number for an object.
@@ -802,25 +839,15 @@ def findsource(object):
return lines, 0
if isclass(object):
- name = object.__name__
- pat = re.compile(r'^(\s*)class\s*' + name + r'\b')
- # make some effort to find the best matching class definition:
- # use the one with the least indentation, which is the one
- # that's most probably not inside a function definition.
- candidates = []
- for i in range(len(lines)):
- match = pat.match(lines[i])
- if match:
- # if it's at toplevel, it's already the best one
- if lines[i][0] == 'c':
- return lines, i
- # else add whitespace to candidate list
- candidates.append((match.group(1), i))
- if candidates:
- # this will sort by whitespace, and by line number,
- # less whitespace first
- candidates.sort()
- return lines, candidates[0][1]
+ qualname = object.__qualname__
+ source = ''.join(lines)
+ tree = ast.parse(source)
+ class_finder = _ClassFinder(qualname)
+ try:
+ class_finder.visit(tree)
+ except ClassFoundException as e:
+ line_number = e.args[0]
+ return lines, line_number
else:
raise OSError('could not find class definition')
diff --git a/Lib/test/inspect_fodder2.py b/Lib/test/inspect_fodder2.py
index 5a7b559..e7d4b53 100644
--- a/Lib/test/inspect_fodder2.py
+++ b/Lib/test/inspect_fodder2.py
@@ -138,18 +138,124 @@ class cls135:
never_reached1
never_reached2
-#line 141
+# line 141
+class cls142:
+ a = """
+class cls149:
+ ...
+"""
+
+# line 148
+class cls149:
+
+ def func151(self):
+ pass
+
+'''
+class cls160:
+ pass
+'''
+
+# line 159
+class cls160:
+
+ def func162(self):
+ pass
+
+# line 165
+class cls166:
+ a = '''
+ class cls175:
+ ...
+ '''
+
+# line 172
+class cls173:
+
+ class cls175:
+ pass
+
+# line 178
+class cls179:
+ pass
+
+# line 182
+class cls183:
+
+ class cls185:
+
+ def func186(self):
+ pass
+
+def class_decorator(cls):
+ return cls
+
+# line 193
+@class_decorator
+@class_decorator
+class cls196:
+
+ @class_decorator
+ @class_decorator
+ class cls200:
+ pass
+
+class cls203:
+ class cls204:
+ class cls205:
+ pass
+ class cls207:
+ class cls205:
+ pass
+
+# line 211
+def func212():
+ class cls213:
+ pass
+ return cls213
+
+# line 217
+class cls213:
+ def func219(self):
+ class cls220:
+ pass
+ return cls220
+
+# line 224
+async def func225():
+ class cls226:
+ pass
+ return cls226
+
+# line 230
+class cls226:
+ async def func232(self):
+ class cls233:
+ pass
+ return cls233
+
+if True:
+ class cls238:
+ class cls239:
+ '''if clause cls239'''
+else:
+ class cls238:
+ class cls239:
+ '''else clause 239'''
+ pass
+
+#line 247
def positional_only_arg(a, /):
pass
-#line 145
+#line 251
def all_markers(a, b, /, c, d, *, e, f):
pass
-# line 149
+# line 255
def all_markers_with_args_and_kwargs(a, b, /, c, d, *args, e, f, **kwargs):
pass
-#line 153
+#line 259
def all_markers_with_defaults(a, b=1, /, c=2, d=3, *, e=4, f=5):
pass
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index 2dc8454..98a9c0a 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -473,6 +473,7 @@ class TestRetrievingSourceCode(GetSourceBase):
def test_getcomments(self):
self.assertEqual(inspect.getcomments(mod), '# line 1\n')
self.assertEqual(inspect.getcomments(mod.StupidGit), '# line 20\n')
+ self.assertEqual(inspect.getcomments(mod2.cls160), '# line 159\n')
# If the object source file is not available, return None.
co = compile('x=1', '_non_existing_filename.py', 'exec')
self.assertIsNone(inspect.getcomments(co))
@@ -709,6 +710,45 @@ class TestBuggyCases(GetSourceBase):
def test_nested_func(self):
self.assertSourceEqual(mod2.cls135.func136, 136, 139)
+ def test_class_definition_in_multiline_string_definition(self):
+ self.assertSourceEqual(mod2.cls149, 149, 152)
+
+ def test_class_definition_in_multiline_comment(self):
+ self.assertSourceEqual(mod2.cls160, 160, 163)
+
+ def test_nested_class_definition_indented_string(self):
+ self.assertSourceEqual(mod2.cls173.cls175, 175, 176)
+
+ def test_nested_class_definition(self):
+ self.assertSourceEqual(mod2.cls183, 183, 188)
+ self.assertSourceEqual(mod2.cls183.cls185, 185, 188)
+
+ def test_class_decorator(self):
+ self.assertSourceEqual(mod2.cls196, 194, 201)
+ self.assertSourceEqual(mod2.cls196.cls200, 198, 201)
+
+ def test_class_inside_conditional(self):
+ self.assertSourceEqual(mod2.cls238, 238, 240)
+ self.assertSourceEqual(mod2.cls238.cls239, 239, 240)
+
+ def test_multiple_children_classes(self):
+ self.assertSourceEqual(mod2.cls203, 203, 209)
+ self.assertSourceEqual(mod2.cls203.cls204, 204, 206)
+ self.assertSourceEqual(mod2.cls203.cls204.cls205, 205, 206)
+ self.assertSourceEqual(mod2.cls203.cls207, 207, 209)
+ self.assertSourceEqual(mod2.cls203.cls207.cls205, 208, 209)
+
+ def test_nested_class_definition_inside_function(self):
+ self.assertSourceEqual(mod2.func212(), 213, 214)
+ self.assertSourceEqual(mod2.cls213, 218, 222)
+ self.assertSourceEqual(mod2.cls213().func219(), 220, 221)
+
+ def test_nested_class_definition_inside_async_function(self):
+ import asyncio
+ self.addCleanup(asyncio.set_event_loop_policy, None)
+ self.assertSourceEqual(asyncio.run(mod2.func225()), 226, 227)
+ self.assertSourceEqual(mod2.cls226, 231, 235)
+ self.assertSourceEqual(asyncio.run(mod2.cls226().func232()), 233, 234)
class TestNoEOL(GetSourceBase):
def setUp(self):
diff --git a/Lib/test/test_pydoc.py b/Lib/test/test_pydoc.py
index 6d358f4..ffabb7f 100644
--- a/Lib/test/test_pydoc.py
+++ b/Lib/test/test_pydoc.py
@@ -476,6 +476,7 @@ class PydocDocTest(unittest.TestCase):
def test_non_str_name(self):
# issue14638
# Treat illegal (non-str) name like no name
+
class A:
__name__ = 42
class B: