summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorIrit Katriel <1055913+iritkatriel@users.noreply.github.com>2023-01-03 22:24:19 (GMT)
committerGitHub <noreply@github.com>2023-01-03 22:24:19 (GMT)
commitff9ac5807172499d29757fc27dee4378164db3be (patch)
tree8d6b4a6c15b1940a9bc241060552762b3a5d01cb
parent242836c3f28706e4c940f19f7583fc85936fdbbb (diff)
downloadcpython-ff9ac5807172499d29757fc27dee4378164db3be.zip
cpython-ff9ac5807172499d29757fc27dee4378164db3be.tar.gz
cpython-ff9ac5807172499d29757fc27dee4378164db3be.tar.bz2
[3.10] gh-95882: Add tests for traceback from contextlib context managers (GH-95883) (#100715)
-rw-r--r--Lib/test/test_contextlib.py50
-rw-r--r--Lib/test/test_contextlib_async.py58
2 files changed, 108 insertions, 0 deletions
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 68bd45d..bcedf17 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -87,6 +87,56 @@ class ContextManagerTestCase(unittest.TestCase):
raise ZeroDivisionError()
self.assertEqual(state, [1, 42, 999])
+ def test_contextmanager_traceback(self):
+ @contextmanager
+ def f():
+ yield
+
+ try:
+ with f():
+ 1/0
+ except ZeroDivisionError as e:
+ frames = traceback.extract_tb(e.__traceback__)
+
+ self.assertEqual(len(frames), 1)
+ self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+ self.assertEqual(frames[0].line, '1/0')
+
+ # Repeat with RuntimeError (which goes through a different code path)
+ class RuntimeErrorSubclass(RuntimeError):
+ pass
+
+ try:
+ with f():
+ raise RuntimeErrorSubclass(42)
+ except RuntimeErrorSubclass as e:
+ frames = traceback.extract_tb(e.__traceback__)
+
+ self.assertEqual(len(frames), 1)
+ self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+ self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
+
+ class StopIterationSubclass(StopIteration):
+ pass
+
+ for stop_exc in (
+ StopIteration('spam'),
+ StopIterationSubclass('spam'),
+ ):
+ with self.subTest(type=type(stop_exc)):
+ try:
+ with f():
+ raise stop_exc
+ except type(stop_exc) as e:
+ self.assertIs(e, stop_exc)
+ frames = traceback.extract_tb(e.__traceback__)
+ else:
+ self.fail(f'{stop_exc} was suppressed')
+
+ self.assertEqual(len(frames), 1)
+ self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+ self.assertEqual(frames[0].line, 'raise stop_exc')
+
def test_contextmanager_no_reraise(self):
@contextmanager
def whee():
diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py
index d44d362..3d1079c 100644
--- a/Lib/test/test_contextlib_async.py
+++ b/Lib/test/test_contextlib_async.py
@@ -4,6 +4,7 @@ from contextlib import (
AsyncExitStack, nullcontext, aclosing, contextmanager)
import functools
from test import support
+import traceback
import unittest
from test.test_contextlib import TestBaseExitStack
@@ -125,6 +126,63 @@ class AsyncContextManagerTestCase(unittest.TestCase):
self.assertEqual(state, [1, 42, 999])
@_async_test
+ async def test_contextmanager_traceback(self):
+ @asynccontextmanager
+ async def f():
+ yield
+
+ try:
+ async with f():
+ 1/0
+ except ZeroDivisionError as e:
+ frames = traceback.extract_tb(e.__traceback__)
+
+ self.assertEqual(len(frames), 1)
+ self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+ self.assertEqual(frames[0].line, '1/0')
+
+ # Repeat with RuntimeError (which goes through a different code path)
+ class RuntimeErrorSubclass(RuntimeError):
+ pass
+
+ try:
+ async with f():
+ raise RuntimeErrorSubclass(42)
+ except RuntimeErrorSubclass as e:
+ frames = traceback.extract_tb(e.__traceback__)
+
+ self.assertEqual(len(frames), 1)
+ self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+ self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
+
+ class StopIterationSubclass(StopIteration):
+ pass
+
+ class StopAsyncIterationSubclass(StopAsyncIteration):
+ pass
+
+ for stop_exc in (
+ StopIteration('spam'),
+ StopAsyncIteration('ham'),
+ StopIterationSubclass('spam'),
+ StopAsyncIterationSubclass('spam')
+ ):
+ with self.subTest(type=type(stop_exc)):
+ try:
+ async with f():
+ raise stop_exc
+ except type(stop_exc) as e:
+ self.assertIs(e, stop_exc)
+ frames = traceback.extract_tb(e.__traceback__)
+ else:
+ self.fail(f'{stop_exc} was suppressed')
+
+ self.assertEqual(len(frames), 1)
+ self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+ self.assertEqual(frames[0].line, 'raise stop_exc')
+
+
+ @_async_test
async def test_contextmanager_no_reraise(self):
@asynccontextmanager
async def whee():