diff options
Diffstat (limited to 'Lib/lib2to3/tests/test_refactor.py')
-rw-r--r-- | Lib/lib2to3/tests/test_refactor.py | 80 |
1 files changed, 45 insertions, 35 deletions
diff --git a/Lib/lib2to3/tests/test_refactor.py b/Lib/lib2to3/tests/test_refactor.py index 49fb0c0..35efe25 100644 --- a/Lib/lib2to3/tests/test_refactor.py +++ b/Lib/lib2to3/tests/test_refactor.py @@ -2,6 +2,8 @@ Unit tests for refactor.py. """ +from __future__ import with_statement + import sys import os import codecs @@ -61,42 +63,50 @@ class TestRefactoringTool(unittest.TestCase): self.assertEqual(full_names, ["myfixes.fix_" + name for name in contents]) - def test_detect_future_print(self): - run = refactor._detect_future_print - self.assertFalse(run("")) - self.assertTrue(run("from __future__ import print_function")) - self.assertFalse(run("from __future__ import generators")) - self.assertFalse(run("from __future__ import generators, feature")) - input = "from __future__ import generators, print_function" - self.assertTrue(run(input)) - input ="from __future__ import print_function, generators" - self.assertTrue(run(input)) - input = "from __future__ import (print_function,)" - self.assertTrue(run(input)) - input = "from __future__ import (generators, print_function)" - self.assertTrue(run(input)) - input = "from __future__ import (generators, nested_scopes)" - self.assertFalse(run(input)) - input = """from __future__ import generators + def test_detect_future_features(self): + run = refactor._detect_future_features + fs = frozenset + empty = fs() + self.assertEqual(run(""), empty) + self.assertEqual(run("from __future__ import print_function"), + fs(("print_function",))) + self.assertEqual(run("from __future__ import generators"), + fs(("generators",))) + self.assertEqual(run("from __future__ import generators, feature"), + fs(("generators", "feature"))) + inp = "from __future__ import generators, print_function" + self.assertEqual(run(inp), fs(("generators", "print_function"))) + inp ="from __future__ import print_function, generators" + self.assertEqual(run(inp), fs(("print_function", "generators"))) + inp = "from __future__ import (print_function,)" + self.assertEqual(run(inp), fs(("print_function",))) + inp = "from __future__ import (generators, print_function)" + self.assertEqual(run(inp), fs(("generators", "print_function"))) + inp = "from __future__ import (generators, nested_scopes)" + self.assertEqual(run(inp), fs(("generators", "nested_scopes"))) + inp = """from __future__ import generators from __future__ import print_function""" - self.assertTrue(run(input)) - self.assertFalse(run("from")) - self.assertFalse(run("from 4")) - self.assertFalse(run("from x")) - self.assertFalse(run("from x 5")) - self.assertFalse(run("from x im")) - self.assertFalse(run("from x import")) - self.assertFalse(run("from x import 4")) - input = "'docstring'\nfrom __future__ import print_function" - self.assertTrue(run(input)) - input = "'docstring'\n'somng'\nfrom __future__ import print_function" - self.assertFalse(run(input)) - input = "# comment\nfrom __future__ import print_function" - self.assertTrue(run(input)) - input = "# comment\n'doc'\nfrom __future__ import print_function" - self.assertTrue(run(input)) - input = "class x: pass\nfrom __future__ import print_function" - self.assertFalse(run(input)) + self.assertEqual(run(inp), fs(("generators", "print_function"))) + invalid = ("from", + "from 4", + "from x", + "from x 5", + "from x im", + "from x import", + "from x import 4", + ) + for inp in invalid: + self.assertEqual(run(inp), empty) + inp = "'docstring'\nfrom __future__ import print_function" + self.assertEqual(run(inp), fs(("print_function",))) + inp = "'docstring'\n'somng'\nfrom __future__ import print_function" + self.assertEqual(run(inp), empty) + inp = "# comment\nfrom __future__ import print_function" + self.assertEqual(run(inp), fs(("print_function",))) + inp = "# comment\n'doc'\nfrom __future__ import print_function" + self.assertEqual(run(inp), fs(("print_function",))) + inp = "class x: pass\nfrom __future__ import print_function" + self.assertEqual(run(inp), empty) def test_get_headnode_dict(self): class NoneFix(fixer_base.BaseFix): |