diff options
author | Tian Gao <gaogaotiantian@hotmail.com> | 2023-07-18 23:20:31 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-18 23:20:31 (GMT) |
commit | 663854d73b35feeb004ae0970e45b53ca27774a1 (patch) | |
tree | 24f61e97f387c4e1168672b75633f7ca8c8d29c2 /Lib | |
parent | 505eede38d141d43e40e246319b157e3c77211d3 (diff) | |
download | cpython-663854d73b35feeb004ae0970e45b53ca27774a1.zip cpython-663854d73b35feeb004ae0970e45b53ca27774a1.tar.gz cpython-663854d73b35feeb004ae0970e45b53ca27774a1.tar.bz2 |
gh-106727: Make `inspect.getsource` smarter for class for same name definitions (#106815)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/inspect.py | 57 | ||||
-rw-r--r-- | Lib/test/inspect_fodder2.py | 20 | ||||
-rw-r--r-- | Lib/test/test_inspect.py | 5 |
3 files changed, 70 insertions, 12 deletions
diff --git a/Lib/inspect.py b/Lib/inspect.py index 15f94a1..675714d 100644 --- a/Lib/inspect.py +++ b/Lib/inspect.py @@ -1034,9 +1034,13 @@ class ClassFoundException(Exception): class _ClassFinder(ast.NodeVisitor): - def __init__(self, qualname): + def __init__(self, cls, tree, lines, qualname): self.stack = [] + self.cls = cls + self.tree = tree + self.lines = lines self.qualname = qualname + self.lineno_found = [] def visit_FunctionDef(self, node): self.stack.append(node.name) @@ -1057,11 +1061,48 @@ class _ClassFinder(ast.NodeVisitor): line_number = node.lineno # decrement by one since lines starts with indexing by zero - line_number -= 1 - raise ClassFoundException(line_number) + self.lineno_found.append((line_number - 1, node.end_lineno)) self.generic_visit(node) self.stack.pop() + def get_lineno(self): + self.visit(self.tree) + lineno_found_number = len(self.lineno_found) + if lineno_found_number == 0: + raise OSError('could not find class definition') + elif lineno_found_number == 1: + return self.lineno_found[0][0] + else: + # We have multiple candidates for the class definition. + # Now we have to guess. + + # First, let's see if there are any method definitions + for member in self.cls.__dict__.values(): + if isinstance(member, types.FunctionType): + for lineno, end_lineno in self.lineno_found: + if lineno <= member.__code__.co_firstlineno <= end_lineno: + return lineno + + class_strings = [(''.join(self.lines[lineno: end_lineno]), lineno) + for lineno, end_lineno in self.lineno_found] + + # Maybe the class has a docstring and it's unique? + if self.cls.__doc__: + ret = None + for candidate, lineno in class_strings: + if self.cls.__doc__.strip() in candidate: + if ret is None: + ret = lineno + else: + break + else: + if ret is not None: + return ret + + # We are out of ideas, just return the last one found, which is + # slightly better than previous ones + return self.lineno_found[-1][0] + def findsource(object): """Return the entire source file and starting line number for an object. @@ -1098,14 +1139,8 @@ def findsource(object): qualname = object.__qualname__ source = ''.join(lines) tree = ast.parse(source) - class_finder = _ClassFinder(qualname) - try: - class_finder.visit(tree) - except ClassFoundException as e: - line_number = e.args[0] - return lines, line_number - else: - raise OSError('could not find class definition') + class_finder = _ClassFinder(object, tree, lines, qualname) + return lines, class_finder.get_lineno() if ismethod(object): object = object.__func__ diff --git a/Lib/test/inspect_fodder2.py b/Lib/test/inspect_fodder2.py index 0346461..8639cf2 100644 --- a/Lib/test/inspect_fodder2.py +++ b/Lib/test/inspect_fodder2.py @@ -290,3 +290,23 @@ post_line_parenthesized_lambda1 = (lambda: () nested_lambda = ( lambda right: [].map( lambda length: ())) + +# line 294 +if True: + class cls296: + def f(): + pass +else: + class cls296: + def g(): + pass + +# line 304 +if False: + class cls310: + def f(): + pass +else: + class cls310: + def g(): + pass diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py index 64afeec..33a593f 100644 --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -949,7 +949,6 @@ class TestBuggyCases(GetSourceBase): self.assertSourceEqual(mod2.cls196.cls200, 198, 201) def test_class_inside_conditional(self): - self.assertSourceEqual(mod2.cls238, 238, 240) self.assertSourceEqual(mod2.cls238.cls239, 239, 240) def test_multiple_children_classes(self): @@ -975,6 +974,10 @@ class TestBuggyCases(GetSourceBase): self.assertSourceEqual(mod2.cls226, 231, 235) self.assertSourceEqual(asyncio.run(mod2.cls226().func232()), 233, 234) + def test_class_definition_same_name_diff_methods(self): + self.assertSourceEqual(mod2.cls296, 296, 298) + self.assertSourceEqual(mod2.cls310, 310, 312) + class TestNoEOL(GetSourceBase): def setUp(self): self.tempdir = TESTFN + '_dir' |