summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorTian Gao <gaogaotiantian@hotmail.com>2023-07-18 23:20:31 (GMT)
committerGitHub <noreply@github.com>2023-07-18 23:20:31 (GMT)
commit663854d73b35feeb004ae0970e45b53ca27774a1 (patch)
tree24f61e97f387c4e1168672b75633f7ca8c8d29c2 /Lib
parent505eede38d141d43e40e246319b157e3c77211d3 (diff)
downloadcpython-663854d73b35feeb004ae0970e45b53ca27774a1.zip
cpython-663854d73b35feeb004ae0970e45b53ca27774a1.tar.gz
cpython-663854d73b35feeb004ae0970e45b53ca27774a1.tar.bz2
gh-106727: Make `inspect.getsource` smarter for class for same name definitions (#106815)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/inspect.py57
-rw-r--r--Lib/test/inspect_fodder2.py20
-rw-r--r--Lib/test/test_inspect.py5
3 files changed, 70 insertions, 12 deletions
diff --git a/Lib/inspect.py b/Lib/inspect.py
index 15f94a1..675714d 100644
--- a/Lib/inspect.py
+++ b/Lib/inspect.py
@@ -1034,9 +1034,13 @@ class ClassFoundException(Exception):
class _ClassFinder(ast.NodeVisitor):
- def __init__(self, qualname):
+ def __init__(self, cls, tree, lines, qualname):
self.stack = []
+ self.cls = cls
+ self.tree = tree
+ self.lines = lines
self.qualname = qualname
+ self.lineno_found = []
def visit_FunctionDef(self, node):
self.stack.append(node.name)
@@ -1057,11 +1061,48 @@ class _ClassFinder(ast.NodeVisitor):
line_number = node.lineno
# decrement by one since lines starts with indexing by zero
- line_number -= 1
- raise ClassFoundException(line_number)
+ self.lineno_found.append((line_number - 1, node.end_lineno))
self.generic_visit(node)
self.stack.pop()
+ def get_lineno(self):
+ self.visit(self.tree)
+ lineno_found_number = len(self.lineno_found)
+ if lineno_found_number == 0:
+ raise OSError('could not find class definition')
+ elif lineno_found_number == 1:
+ return self.lineno_found[0][0]
+ else:
+ # We have multiple candidates for the class definition.
+ # Now we have to guess.
+
+ # First, let's see if there are any method definitions
+ for member in self.cls.__dict__.values():
+ if isinstance(member, types.FunctionType):
+ for lineno, end_lineno in self.lineno_found:
+ if lineno <= member.__code__.co_firstlineno <= end_lineno:
+ return lineno
+
+ class_strings = [(''.join(self.lines[lineno: end_lineno]), lineno)
+ for lineno, end_lineno in self.lineno_found]
+
+ # Maybe the class has a docstring and it's unique?
+ if self.cls.__doc__:
+ ret = None
+ for candidate, lineno in class_strings:
+ if self.cls.__doc__.strip() in candidate:
+ if ret is None:
+ ret = lineno
+ else:
+ break
+ else:
+ if ret is not None:
+ return ret
+
+ # We are out of ideas, just return the last one found, which is
+ # slightly better than previous ones
+ return self.lineno_found[-1][0]
+
def findsource(object):
"""Return the entire source file and starting line number for an object.
@@ -1098,14 +1139,8 @@ def findsource(object):
qualname = object.__qualname__
source = ''.join(lines)
tree = ast.parse(source)
- class_finder = _ClassFinder(qualname)
- try:
- class_finder.visit(tree)
- except ClassFoundException as e:
- line_number = e.args[0]
- return lines, line_number
- else:
- raise OSError('could not find class definition')
+ class_finder = _ClassFinder(object, tree, lines, qualname)
+ return lines, class_finder.get_lineno()
if ismethod(object):
object = object.__func__
diff --git a/Lib/test/inspect_fodder2.py b/Lib/test/inspect_fodder2.py
index 0346461..8639cf2 100644
--- a/Lib/test/inspect_fodder2.py
+++ b/Lib/test/inspect_fodder2.py
@@ -290,3 +290,23 @@ post_line_parenthesized_lambda1 = (lambda: ()
nested_lambda = (
lambda right: [].map(
lambda length: ()))
+
+# line 294
+if True:
+ class cls296:
+ def f():
+ pass
+else:
+ class cls296:
+ def g():
+ pass
+
+# line 304
+if False:
+ class cls310:
+ def f():
+ pass
+else:
+ class cls310:
+ def g():
+ pass
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index 64afeec..33a593f 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -949,7 +949,6 @@ class TestBuggyCases(GetSourceBase):
self.assertSourceEqual(mod2.cls196.cls200, 198, 201)
def test_class_inside_conditional(self):
- self.assertSourceEqual(mod2.cls238, 238, 240)
self.assertSourceEqual(mod2.cls238.cls239, 239, 240)
def test_multiple_children_classes(self):
@@ -975,6 +974,10 @@ class TestBuggyCases(GetSourceBase):
self.assertSourceEqual(mod2.cls226, 231, 235)
self.assertSourceEqual(asyncio.run(mod2.cls226().func232()), 233, 234)
+ def test_class_definition_same_name_diff_methods(self):
+ self.assertSourceEqual(mod2.cls296, 296, 298)
+ self.assertSourceEqual(mod2.cls310, 310, 312)
+
class TestNoEOL(GetSourceBase):
def setUp(self):
self.tempdir = TESTFN + '_dir'