summaryrefslogtreecommitdiffstats
path: root/Lib/lib2to3
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/lib2to3')
-rw-r--r--Lib/lib2to3/fixes/fix_exitfunc.py70
-rw-r--r--Lib/lib2to3/pgen2/tokenize.py19
-rw-r--r--Lib/lib2to3/tests/test_fixers.py88
-rw-r--r--Lib/lib2to3/tests/test_parser.py3
4 files changed, 170 insertions, 10 deletions
diff --git a/Lib/lib2to3/fixes/fix_exitfunc.py b/Lib/lib2to3/fixes/fix_exitfunc.py
new file mode 100644
index 0000000..5203821
--- /dev/null
+++ b/Lib/lib2to3/fixes/fix_exitfunc.py
@@ -0,0 +1,70 @@
+"""
+Convert use of sys.exitfunc to use the atexit module.
+"""
+
+# Author: Benjamin Peterson
+
+from lib2to3 import pytree, fixer_base
+from lib2to3.fixer_util import Name, Attr, Call, Comma, Newline, syms
+
+
+class FixExitfunc(fixer_base.BaseFix):
+
+ PATTERN = """
+ (
+ sys_import=import_name<'import'
+ ('sys'
+ |
+ dotted_as_names< (any ',')* 'sys' (',' any)* >
+ )
+ >
+ |
+ expr_stmt<
+ power< 'sys' trailer< '.' 'exitfunc' > >
+ '=' func=any >
+ )
+ """
+
+ def __init__(self, *args):
+ super(FixExitfunc, self).__init__(*args)
+
+ def start_tree(self, tree, filename):
+ super(FixExitfunc, self).start_tree(tree, filename)
+ self.sys_import = None
+
+ def transform(self, node, results):
+ # First, find a the sys import. We'll just hope it's global scope.
+ if "sys_import" in results:
+ if self.sys_import is None:
+ self.sys_import = results["sys_import"]
+ return
+
+ func = results["func"].clone()
+ func.prefix = ""
+ register = pytree.Node(syms.power,
+ Attr(Name("atexit"), Name("register"))
+ )
+ call = Call(register, [func], node.prefix)
+ node.replace(call)
+
+ if self.sys_import is None:
+ # That's interesting.
+ self.warning(node, "Can't find sys import; Please add an atexit "
+ "import at the top of your file.")
+ return
+
+ # Now add an atexit import after the sys import.
+ names = self.sys_import.children[1]
+ if names.type == syms.dotted_as_names:
+ names.append_child(Comma())
+ names.append_child(Name("atexit", " "))
+ else:
+ containing_stmt = self.sys_import.parent
+ position = containing_stmt.children.index(self.sys_import)
+ stmt_container = containing_stmt.parent
+ new_import = pytree.Node(syms.import_name,
+ [Name("import"), Name("atexit", " ")]
+ )
+ new = pytree.Node(syms.simple_stmt, [new_import])
+ containing_stmt.insert_child(position + 1, Newline())
+ containing_stmt.insert_child(position + 2, new)
diff --git a/Lib/lib2to3/pgen2/tokenize.py b/Lib/lib2to3/pgen2/tokenize.py
index 7ae0280..701daf8 100644
--- a/Lib/lib2to3/pgen2/tokenize.py
+++ b/Lib/lib2to3/pgen2/tokenize.py
@@ -253,14 +253,16 @@ def detect_encoding(readline):
in.
It detects the encoding from the presence of a utf-8 bom or an encoding
- cookie as specified in pep-0263. If both a bom and a cookie are present,
- but disagree, a SyntaxError will be raised. If the encoding cookie is an
- invalid charset, raise a SyntaxError.
+ cookie as specified in pep-0263. If both a bom and a cookie are present, but
+ disagree, a SyntaxError will be raised. If the encoding cookie is an invalid
+ charset, raise a SyntaxError. Note that if a utf-8 bom is found,
+ 'utf-8-sig' is returned.
If no encoding is specified, then the default of 'utf-8' will be returned.
"""
bom_found = False
encoding = None
+ default = 'utf-8'
def read_or_stop():
try:
return readline()
@@ -287,17 +289,16 @@ def detect_encoding(readline):
if codec.name != 'utf-8':
# This behaviour mimics the Python interpreter
raise SyntaxError('encoding problem: utf-8')
- else:
- # Allow it to be properly encoded and decoded.
- encoding = 'utf-8-sig'
+ encoding += '-sig'
return encoding
first = read_or_stop()
if first.startswith(BOM_UTF8):
bom_found = True
first = first[3:]
+ default = 'utf-8-sig'
if not first:
- return 'utf-8', []
+ return default, []
encoding = find_cookie(first)
if encoding:
@@ -305,13 +306,13 @@ def detect_encoding(readline):
second = read_or_stop()
if not second:
- return 'utf-8', [first]
+ return default, [first]
encoding = find_cookie(second)
if encoding:
return encoding, [first, second]
- return 'utf-8', [first, second]
+ return default, [first, second]
def untokenize(iterable):
"""Transform tokens back into Python source code.
diff --git a/Lib/lib2to3/tests/test_fixers.py b/Lib/lib2to3/tests/test_fixers.py
index a92f14a..b28c35f 100644
--- a/Lib/lib2to3/tests/test_fixers.py
+++ b/Lib/lib2to3/tests/test_fixers.py
@@ -4285,3 +4285,91 @@ class Test_operator(FixerTestCase):
def test_bare_sequenceIncludes(self):
s = "sequenceIncludes(x, y)"
self.warns_unchanged(s, "You should use operator.contains here.")
+
+
+class Test_exitfunc(FixerTestCase):
+
+ fixer = "exitfunc"
+
+ def test_simple(self):
+ b = """
+ import sys
+ sys.exitfunc = my_atexit
+ """
+ a = """
+ import sys
+ import atexit
+ atexit.register(my_atexit)
+ """
+ self.check(b, a)
+
+ def test_names_import(self):
+ b = """
+ import sys, crumbs
+ sys.exitfunc = my_func
+ """
+ a = """
+ import sys, crumbs, atexit
+ atexit.register(my_func)
+ """
+ self.check(b, a)
+
+ def test_complex_expression(self):
+ b = """
+ import sys
+ sys.exitfunc = do(d)/a()+complex(f=23, g=23)*expression
+ """
+ a = """
+ import sys
+ import atexit
+ atexit.register(do(d)/a()+complex(f=23, g=23)*expression)
+ """
+ self.check(b, a)
+
+ def test_comments(self):
+ b = """
+ import sys # Foo
+ sys.exitfunc = f # Blah
+ """
+ a = """
+ import sys
+ import atexit # Foo
+ atexit.register(f) # Blah
+ """
+ self.check(b, a)
+
+ b = """
+ import apples, sys, crumbs, larry # Pleasant comments
+ sys.exitfunc = func
+ """
+ a = """
+ import apples, sys, crumbs, larry, atexit # Pleasant comments
+ atexit.register(func)
+ """
+ self.check(b, a)
+
+ def test_in_a_function(self):
+ b = """
+ import sys
+ def f():
+ sys.exitfunc = func
+ """
+ a = """
+ import sys
+ import atexit
+ def f():
+ atexit.register(func)
+ """
+ self.check(b, a)
+
+ def test_no_sys_import(self):
+ b = """sys.exitfunc = f"""
+ a = """atexit.register(f)"""
+ msg = ("Can't find sys import; Please add an atexit import at the "
+ "top of your file.")
+ self.warns(b, a, msg)
+
+
+ def test_unchanged(self):
+ s = """f(sys.exitfunc)"""
+ self.unchanged(s)
diff --git a/Lib/lib2to3/tests/test_parser.py b/Lib/lib2to3/tests/test_parser.py
index 15b109e9..06f3227 100644
--- a/Lib/lib2to3/tests/test_parser.py
+++ b/Lib/lib2to3/tests/test_parser.py
@@ -206,6 +206,7 @@ def diff(fn, result):
finally:
f.close()
try:
- return os.system("diff -u %r @" % fn)
+ fn = fn.replace('"', '\\"')
+ return os.system('diff -u "%s" @' % fn)
finally:
os.remove("@")