diff options
author | Nick Coghlan <ncoghlan@gmail.com> | 2012-06-23 09:39:55 (GMT) |
---|---|---|
committer | Nick Coghlan <ncoghlan@gmail.com> | 2012-06-23 09:39:55 (GMT) |
commit | 2f92e54507cd4a76978133c10bf32cbdef90cd37 (patch) | |
tree | 6a027fec78404e9cce88c8ef892f6a7b1b1d423b /Lib/test | |
parent | 970da4549ab78ad6f89dd0cce0d2defc771e1c72 (diff) | |
download | cpython-2f92e54507cd4a76978133c10bf32cbdef90cd37.zip cpython-2f92e54507cd4a76978133c10bf32cbdef90cd37.tar.gz cpython-2f92e54507cd4a76978133c10bf32cbdef90cd37.tar.bz2 |
Close #13062: Add inspect.getclosurevars to simplify testing stateful closures
Diffstat (limited to 'Lib/test')
-rw-r--r-- | Lib/test/test_inspect.py | 101 |
1 files changed, 100 insertions, 1 deletions
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py index 53c947f..8327721 100644 --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -665,6 +665,105 @@ class TestClassesAndFunctions(unittest.TestCase): self.assertIn(('f', b.f), inspect.getmembers(b, inspect.ismethod)) +_global_ref = object() +class TestGetClosureVars(unittest.TestCase): + + def test_name_resolution(self): + # Basic test of the 4 different resolution mechanisms + def f(nonlocal_ref): + def g(local_ref): + print(local_ref, nonlocal_ref, _global_ref, unbound_ref) + return g + _arg = object() + nonlocal_vars = {"nonlocal_ref": _arg} + global_vars = {"_global_ref": _global_ref} + builtin_vars = {"print": print} + unbound_names = {"unbound_ref"} + expected = inspect.ClosureVars(nonlocal_vars, global_vars, + builtin_vars, unbound_names) + self.assertEqual(inspect.getclosurevars(f(_arg)), expected) + + def test_generator_closure(self): + def f(nonlocal_ref): + def g(local_ref): + print(local_ref, nonlocal_ref, _global_ref, unbound_ref) + yield + return g + _arg = object() + nonlocal_vars = {"nonlocal_ref": _arg} + global_vars = {"_global_ref": _global_ref} + builtin_vars = {"print": print} + unbound_names = {"unbound_ref"} + expected = inspect.ClosureVars(nonlocal_vars, global_vars, + builtin_vars, unbound_names) + self.assertEqual(inspect.getclosurevars(f(_arg)), expected) + + def test_method_closure(self): + class C: + def f(self, nonlocal_ref): + def g(local_ref): + print(local_ref, nonlocal_ref, _global_ref, unbound_ref) + return g + _arg = object() + nonlocal_vars = {"nonlocal_ref": _arg} + global_vars = {"_global_ref": _global_ref} + builtin_vars = {"print": print} + unbound_names = {"unbound_ref"} + expected = inspect.ClosureVars(nonlocal_vars, global_vars, + builtin_vars, unbound_names) + self.assertEqual(inspect.getclosurevars(C().f(_arg)), expected) + + def test_nonlocal_vars(self): + # More complex tests of nonlocal resolution + def _nonlocal_vars(f): + return inspect.getclosurevars(f).nonlocals + + def make_adder(x): + def add(y): + return x + y + return add + + def curry(func, arg1): + return lambda arg2: func(arg1, arg2) + + def less_than(a, b): + return a < b + + # The infamous Y combinator. + def Y(le): + def g(f): + return le(lambda x: f(f)(x)) + Y.g_ref = g + return g(g) + + def check_y_combinator(func): + self.assertEqual(_nonlocal_vars(func), {'f': Y.g_ref}) + + inc = make_adder(1) + add_two = make_adder(2) + greater_than_five = curry(less_than, 5) + + self.assertEqual(_nonlocal_vars(inc), {'x': 1}) + self.assertEqual(_nonlocal_vars(add_two), {'x': 2}) + self.assertEqual(_nonlocal_vars(greater_than_five), + {'arg1': 5, 'func': less_than}) + self.assertEqual(_nonlocal_vars((lambda x: lambda y: x + y)(3)), + {'x': 3}) + Y(check_y_combinator) + + def test_getclosurevars_empty(self): + def foo(): pass + _empty = inspect.ClosureVars({}, {}, {}, set()) + self.assertEqual(inspect.getclosurevars(lambda: True), _empty) + self.assertEqual(inspect.getclosurevars(foo), _empty) + + def test_getclosurevars_error(self): + class T: pass + self.assertRaises(TypeError, inspect.getclosurevars, 1) + self.assertRaises(TypeError, inspect.getclosurevars, list) + self.assertRaises(TypeError, inspect.getclosurevars, {}) + + class TestGetcallargsFunctions(unittest.TestCase): def assertEqualCallArgs(self, func, call_params_string, locs=None): @@ -2100,7 +2199,7 @@ def test_main(): TestGetcallargsFunctions, TestGetcallargsMethods, TestGetcallargsUnboundMethods, TestGetattrStatic, TestGetGeneratorState, TestNoEOL, TestSignatureObject, TestSignatureBind, TestParameterObject, - TestBoundArguments + TestBoundArguments, TestGetClosureVars ) if __name__ == "__main__": |