diff options
Diffstat (limited to 'Lib/lib2to3')
-rw-r--r-- | Lib/lib2to3/fixes/fix_exitfunc.py | 70 | ||||
-rw-r--r-- | Lib/lib2to3/pgen2/tokenize.py | 19 | ||||
-rw-r--r-- | Lib/lib2to3/tests/test_fixers.py | 88 | ||||
-rw-r--r-- | Lib/lib2to3/tests/test_parser.py | 3 |
4 files changed, 170 insertions, 10 deletions
diff --git a/Lib/lib2to3/fixes/fix_exitfunc.py b/Lib/lib2to3/fixes/fix_exitfunc.py new file mode 100644 index 0000000..5203821 --- /dev/null +++ b/Lib/lib2to3/fixes/fix_exitfunc.py @@ -0,0 +1,70 @@ +""" +Convert use of sys.exitfunc to use the atexit module. +""" + +# Author: Benjamin Peterson + +from lib2to3 import pytree, fixer_base +from lib2to3.fixer_util import Name, Attr, Call, Comma, Newline, syms + + +class FixExitfunc(fixer_base.BaseFix): + + PATTERN = """ + ( + sys_import=import_name<'import' + ('sys' + | + dotted_as_names< (any ',')* 'sys' (',' any)* > + ) + > + | + expr_stmt< + power< 'sys' trailer< '.' 'exitfunc' > > + '=' func=any > + ) + """ + + def __init__(self, *args): + super(FixExitfunc, self).__init__(*args) + + def start_tree(self, tree, filename): + super(FixExitfunc, self).start_tree(tree, filename) + self.sys_import = None + + def transform(self, node, results): + # First, find a the sys import. We'll just hope it's global scope. + if "sys_import" in results: + if self.sys_import is None: + self.sys_import = results["sys_import"] + return + + func = results["func"].clone() + func.prefix = "" + register = pytree.Node(syms.power, + Attr(Name("atexit"), Name("register")) + ) + call = Call(register, [func], node.prefix) + node.replace(call) + + if self.sys_import is None: + # That's interesting. + self.warning(node, "Can't find sys import; Please add an atexit " + "import at the top of your file.") + return + + # Now add an atexit import after the sys import. + names = self.sys_import.children[1] + if names.type == syms.dotted_as_names: + names.append_child(Comma()) + names.append_child(Name("atexit", " ")) + else: + containing_stmt = self.sys_import.parent + position = containing_stmt.children.index(self.sys_import) + stmt_container = containing_stmt.parent + new_import = pytree.Node(syms.import_name, + [Name("import"), Name("atexit", " ")] + ) + new = pytree.Node(syms.simple_stmt, [new_import]) + containing_stmt.insert_child(position + 1, Newline()) + containing_stmt.insert_child(position + 2, new) diff --git a/Lib/lib2to3/pgen2/tokenize.py b/Lib/lib2to3/pgen2/tokenize.py index 7ae0280..701daf8 100644 --- a/Lib/lib2to3/pgen2/tokenize.py +++ b/Lib/lib2to3/pgen2/tokenize.py @@ -253,14 +253,16 @@ def detect_encoding(readline): in. It detects the encoding from the presence of a utf-8 bom or an encoding - cookie as specified in pep-0263. If both a bom and a cookie are present, - but disagree, a SyntaxError will be raised. If the encoding cookie is an - invalid charset, raise a SyntaxError. + cookie as specified in pep-0263. If both a bom and a cookie are present, but + disagree, a SyntaxError will be raised. If the encoding cookie is an invalid + charset, raise a SyntaxError. Note that if a utf-8 bom is found, + 'utf-8-sig' is returned. If no encoding is specified, then the default of 'utf-8' will be returned. """ bom_found = False encoding = None + default = 'utf-8' def read_or_stop(): try: return readline() @@ -287,17 +289,16 @@ def detect_encoding(readline): if codec.name != 'utf-8': # This behaviour mimics the Python interpreter raise SyntaxError('encoding problem: utf-8') - else: - # Allow it to be properly encoded and decoded. - encoding = 'utf-8-sig' + encoding += '-sig' return encoding first = read_or_stop() if first.startswith(BOM_UTF8): bom_found = True first = first[3:] + default = 'utf-8-sig' if not first: - return 'utf-8', [] + return default, [] encoding = find_cookie(first) if encoding: @@ -305,13 +306,13 @@ def detect_encoding(readline): second = read_or_stop() if not second: - return 'utf-8', [first] + return default, [first] encoding = find_cookie(second) if encoding: return encoding, [first, second] - return 'utf-8', [first, second] + return default, [first, second] def untokenize(iterable): """Transform tokens back into Python source code. diff --git a/Lib/lib2to3/tests/test_fixers.py b/Lib/lib2to3/tests/test_fixers.py index a92f14a..b28c35f 100644 --- a/Lib/lib2to3/tests/test_fixers.py +++ b/Lib/lib2to3/tests/test_fixers.py @@ -4285,3 +4285,91 @@ class Test_operator(FixerTestCase): def test_bare_sequenceIncludes(self): s = "sequenceIncludes(x, y)" self.warns_unchanged(s, "You should use operator.contains here.") + + +class Test_exitfunc(FixerTestCase): + + fixer = "exitfunc" + + def test_simple(self): + b = """ + import sys + sys.exitfunc = my_atexit + """ + a = """ + import sys + import atexit + atexit.register(my_atexit) + """ + self.check(b, a) + + def test_names_import(self): + b = """ + import sys, crumbs + sys.exitfunc = my_func + """ + a = """ + import sys, crumbs, atexit + atexit.register(my_func) + """ + self.check(b, a) + + def test_complex_expression(self): + b = """ + import sys + sys.exitfunc = do(d)/a()+complex(f=23, g=23)*expression + """ + a = """ + import sys + import atexit + atexit.register(do(d)/a()+complex(f=23, g=23)*expression) + """ + self.check(b, a) + + def test_comments(self): + b = """ + import sys # Foo + sys.exitfunc = f # Blah + """ + a = """ + import sys + import atexit # Foo + atexit.register(f) # Blah + """ + self.check(b, a) + + b = """ + import apples, sys, crumbs, larry # Pleasant comments + sys.exitfunc = func + """ + a = """ + import apples, sys, crumbs, larry, atexit # Pleasant comments + atexit.register(func) + """ + self.check(b, a) + + def test_in_a_function(self): + b = """ + import sys + def f(): + sys.exitfunc = func + """ + a = """ + import sys + import atexit + def f(): + atexit.register(func) + """ + self.check(b, a) + + def test_no_sys_import(self): + b = """sys.exitfunc = f""" + a = """atexit.register(f)""" + msg = ("Can't find sys import; Please add an atexit import at the " + "top of your file.") + self.warns(b, a, msg) + + + def test_unchanged(self): + s = """f(sys.exitfunc)""" + self.unchanged(s) diff --git a/Lib/lib2to3/tests/test_parser.py b/Lib/lib2to3/tests/test_parser.py index 15b109e9..06f3227 100644 --- a/Lib/lib2to3/tests/test_parser.py +++ b/Lib/lib2to3/tests/test_parser.py @@ -206,6 +206,7 @@ def diff(fn, result): finally: f.close() try: - return os.system("diff -u %r @" % fn) + fn = fn.replace('"', '\\"') + return os.system('diff -u "%s" @' % fn) finally: os.remove("@") |