diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/unittest/__init__.py | 4 | ||||
-rw-r--r-- | Lib/unittest/signals.py | 19 | ||||
-rw-r--r-- | Lib/unittest/test/test_break.py | 21 |
3 files changed, 42 insertions, 2 deletions
diff --git a/Lib/unittest/__init__.py b/Lib/unittest/__init__.py index e84299e..201a3f0 100644 --- a/Lib/unittest/__init__.py +++ b/Lib/unittest/__init__.py @@ -48,7 +48,7 @@ __all__ = ['TestResult', 'TestCase', 'TestSuite', 'TextTestRunner', 'TestLoader', 'FunctionTestCase', 'main', 'defaultTestLoader', 'SkipTest', 'skip', 'skipIf', 'skipUnless', 'expectedFailure', 'TextTestResult', 'installHandler', - 'registerResult', 'removeResult'] + 'registerResult', 'removeResult', 'removeHandler'] # Expose obsolete functions for backwards compatibility __all__.extend(['getTestCaseNames', 'makeSuite', 'findTestCases']) @@ -63,7 +63,7 @@ from .loader import (TestLoader, defaultTestLoader, makeSuite, getTestCaseNames, findTestCases) from .main import TestProgram, main from .runner import TextTestRunner, TextTestResult -from .signals import installHandler, registerResult, removeResult +from .signals import installHandler, registerResult, removeResult, removeHandler # deprecated _TextTestResult = TextTestResult diff --git a/Lib/unittest/signals.py b/Lib/unittest/signals.py index 0651cf2..fc31043 100644 --- a/Lib/unittest/signals.py +++ b/Lib/unittest/signals.py @@ -1,6 +1,8 @@ import signal import weakref +from functools import wraps + __unittest = True @@ -36,3 +38,20 @@ def installHandler(): default_handler = signal.getsignal(signal.SIGINT) _interrupt_handler = _InterruptHandler(default_handler) signal.signal(signal.SIGINT, _interrupt_handler) + + +def removeHandler(method=None): + if method is not None: + @wraps(method) + def inner(*args, **kwargs): + initial = signal.getsignal(signal.SIGINT) + removeHandler() + try: + return method(*args, **kwargs) + finally: + signal.signal(signal.SIGINT, initial) + return inner + + global _interrupt_handler + if _interrupt_handler is not None: + signal.signal(signal.SIGINT, _interrupt_handler.default_handler) diff --git a/Lib/unittest/test/test_break.py b/Lib/unittest/test/test_break.py index 4f89e87..0e09dfb 100644 --- a/Lib/unittest/test/test_break.py +++ b/Lib/unittest/test/test_break.py @@ -227,3 +227,24 @@ class TestBreak(unittest.TestCase): self.assertEqual(p.result, result) self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) + + def testRemoveHandler(self): + default_handler = signal.getsignal(signal.SIGINT) + unittest.installHandler() + unittest.removeHandler() + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + # check that calling removeHandler multiple times has no ill-effect + unittest.removeHandler() + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + def testRemoveHandlerAsDecorator(self): + default_handler = signal.getsignal(signal.SIGINT) + unittest.installHandler() + + @unittest.removeHandler + def test(): + self.assertEqual(signal.getsignal(signal.SIGINT), default_handler) + + test() + self.assertNotEqual(signal.getsignal(signal.SIGINT), default_handler) |