diff options
Diffstat (limited to 'Lib/lib2to3/fixes')
-rw-r--r-- | Lib/lib2to3/fixes/fix_dict.py | 6 | ||||
-rw-r--r-- | Lib/lib2to3/fixes/fix_xrange.py | 45 | ||||
-rw-r--r-- | Lib/lib2to3/fixes/util.py | 4 |
3 files changed, 50 insertions, 5 deletions
diff --git a/Lib/lib2to3/fixes/fix_dict.py b/Lib/lib2to3/fixes/fix_dict.py index f76ceb4..c14a819 100644 --- a/Lib/lib2to3/fixes/fix_dict.py +++ b/Lib/lib2to3/fixes/fix_dict.py @@ -29,10 +29,10 @@ from .. import patcomp from ..pgen2 import token from . import basefix from .util import Name, Call, LParen, RParen, ArgList, Dot, set +from . import util -exempt = set(["sorted", "list", "set", "any", "all", "tuple", "sum"]) -iter_exempt = exempt | set(["iter"]) +iter_exempt = util.consuming_calls | set(["iter"]) class FixDict(basefix.BaseFix): @@ -92,7 +92,7 @@ class FixDict(basefix.BaseFix): return results["func"].value in iter_exempt else: # list(d.keys()) -> list(d.keys()), etc. - return results["func"].value in exempt + return results["func"].value in util.consuming_calls if not isiter: return False # for ... in d.iterkeys() -> for ... in d.keys(), etc. diff --git a/Lib/lib2to3/fixes/fix_xrange.py b/Lib/lib2to3/fixes/fix_xrange.py index 410e601..2e4040e 100644 --- a/Lib/lib2to3/fixes/fix_xrange.py +++ b/Lib/lib2to3/fixes/fix_xrange.py @@ -5,14 +5,55 @@ # Local imports from .import basefix -from .util import Name +from .util import Name, Call, consuming_calls +from .. import patcomp + class FixXrange(basefix.BaseFix): PATTERN = """ - power< name='xrange' trailer< '(' [any] ')' > > + power< (name='range'|name='xrange') trailer< '(' [any] ')' > any* > """ def transform(self, node, results): name = results["name"] + if name.value == "xrange": + return self.transform_xrange(node, results) + elif name.value == "range": + return self.transform_range(node, results) + else: + raise ValueError(repr(name)) + + def transform_xrange(self, node, results): + name = results["name"] name.replace(Name("range", prefix=name.get_prefix())) + + def transform_range(self, node, results): + if not self.in_special_context(node): + arg = node.clone() + arg.set_prefix("") + call = Call(Name("list"), [arg]) + call.set_prefix(node.get_prefix()) + return call + return node + + P1 = "power< func=NAME trailer< '(' node=any ')' > any* >" + p1 = patcomp.compile_pattern(P1) + + P2 = """for_stmt< 'for' any 'in' node=any ':' any* > + | comp_for< 'for' any 'in' node=any any* > + | comparison< any 'in' node=any any*> + """ + p2 = patcomp.compile_pattern(P2) + + def in_special_context(self, node): + if node.parent is None: + return False + results = {} + if (node.parent.parent is not None and + self.p1.match(node.parent.parent, results) and + results["node"] is node): + # list(d.keys()) -> list(d.keys()), etc. + return results["func"].value in consuming_calls + # for ... in d.iterkeys() -> for ... in d.keys(), etc. + return self.p2.match(node.parent, results) and results["node"] is node diff --git a/Lib/lib2to3/fixes/util.py b/Lib/lib2to3/fixes/util.py index b48aeb3..c977237 100644 --- a/Lib/lib2to3/fixes/util.py +++ b/Lib/lib2to3/fixes/util.py @@ -182,6 +182,10 @@ except NameError: ### Misc ########################################################### + +consuming_calls = set(["sorted", "list", "set", "any", "all", "tuple", "sum", + "min", "max"]) + def attr_chain(obj, attr): """Follow an attribute chain. |