summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_generators.py33
-rw-r--r--Python/compile.c63
2 files changed, 94 insertions, 2 deletions
diff --git a/Lib/test/test_generators.py b/Lib/test/test_generators.py
index 3dd468b..7b78b2b 100644
--- a/Lib/test/test_generators.py
+++ b/Lib/test/test_generators.py
@@ -652,6 +652,17 @@ But this is fine:
[12, 666]
>>> def f():
+... yield
+Traceback (most recent call last):
+SyntaxError: invalid syntax
+
+>>> def f():
+... if 0:
+... yield
+Traceback (most recent call last):
+SyntaxError: invalid syntax
+
+>>> def f():
... if 0:
... yield 1
>>> type(f())
@@ -704,6 +715,28 @@ But this is fine:
... yield 2
>>> type(f())
<type 'None'>
+
+>>> def f():
+... if 0:
+... return
+... if 0:
+... yield 2
+>>> type(f())
+<type 'generator'>
+
+
+>>> def f():
+... if 0:
+... lambda x: x # shouldn't trigger here
+... return # or here
+... def f(i):
+... return 2*i # or here
+... if 0:
+... return 3 # but *this* sucks (line 8)
+... if 0:
+... yield 2 # because it's a generator
+Traceback (most recent call last):
+SyntaxError: 'return' with argument inside generator (<string>, line 8)
"""
__test__ = {"tut": tutorial_tests,
diff --git a/Python/compile.c b/Python/compile.c
index e82c34c..92322fc 100644
--- a/Python/compile.c
+++ b/Python/compile.c
@@ -2876,6 +2876,45 @@ is_constant_false(struct compiling *c, node *n)
return 0;
}
+
+/* Look under n for a return stmt with an expression.
+ * This hack is used to find illegal returns under "if 0:" blocks in
+ * functions already known to be generators (as determined by the symtable
+ * pass).
+ * Return the offending return node if found, else NULL.
+ */
+static node *
+look_for_offending_return(node *n)
+{
+ int i;
+
+ for (i = 0; i < NCH(n); ++i) {
+ node *kid = CHILD(n, i);
+
+ switch (TYPE(kid)) {
+ case classdef:
+ case funcdef:
+ case lambdef:
+ /* Stuff in nested functions & classes doesn't
+ affect the code block we started in. */
+ return NULL;
+
+ case return_stmt:
+ if (NCH(kid) > 1)
+ return kid;
+ break;
+
+ default: {
+ node *bad = look_for_offending_return(kid);
+ if (bad != NULL)
+ return bad;
+ }
+ }
+ }
+
+ return NULL;
+}
+
static void
com_if_stmt(struct compiling *c, node *n)
{
@@ -2886,8 +2925,24 @@ com_if_stmt(struct compiling *c, node *n)
for (i = 0; i+3 < NCH(n); i+=4) {
int a = 0;
node *ch = CHILD(n, i+1);
- if (is_constant_false(c, ch))
+ if (is_constant_false(c, ch)) {
+ /* We're going to skip this block. However, if this
+ is a generator, we have to check the dead code
+ anyway to make sure there aren't any return stmts
+ with expressions, in the same scope. */
+ if (c->c_flags & CO_GENERATOR) {
+ node *p = look_for_offending_return(n);
+ if (p != NULL) {
+ int savelineno = c->c_lineno;
+ c->c_lineno = p->n_lineno;
+ com_error(c, PyExc_SyntaxError,
+ "'return' with argument "
+ "inside generator");
+ c->c_lineno = savelineno;
+ }
+ }
continue;
+ }
if (i > 0)
com_addoparg(c, SET_LINENO, ch->n_lineno);
com_node(c, ch);
@@ -4840,7 +4895,10 @@ symtable_add_def_o(struct symtable *st, PyObject *dict,
#define symtable_add_use(ST, NAME) symtable_add_def((ST), (NAME), USE)
-/* Look for a yield stmt under n. Return 1 if found, else 0. */
+/* Look for a yield stmt under n. Return 1 if found, else 0.
+ This hack is used to look inside "if 0:" blocks (which are normally
+ ignored) in case those are the only places a yield occurs (so that this
+ function is a generator). */
static int
look_for_yield(node *n)
{
@@ -4853,6 +4911,7 @@ look_for_yield(node *n)
case classdef:
case funcdef:
+ case lambdef:
/* Stuff in nested functions and classes can't make
the parent a generator. */
return 0;