diff options
Diffstat (limited to 'Python/ast.c')
| -rw-r--r-- | Python/ast.c | 1400 | 
1 files changed, 1206 insertions, 194 deletions
| diff --git a/Python/ast.c b/Python/ast.c index 6da33f7..76daf6f 100644 --- a/Python/ast.c +++ b/Python/ast.c @@ -132,6 +132,52 @@ validate_arguments(arguments_ty args)  }  static int +validate_constant(PyObject *value) +{ +    if (value == Py_None || value == Py_Ellipsis) +        return 1; + +    if (PyLong_CheckExact(value) +            || PyFloat_CheckExact(value) +            || PyComplex_CheckExact(value) +            || PyBool_Check(value) +            || PyUnicode_CheckExact(value) +            || PyBytes_CheckExact(value)) +        return 1; + +    if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) { +        PyObject *it; + +        it = PyObject_GetIter(value); +        if (it == NULL) +            return 0; + +        while (1) { +            PyObject *item = PyIter_Next(it); +            if (item == NULL) { +                if (PyErr_Occurred()) { +                    Py_DECREF(it); +                    return 0; +                } +                break; +            } + +            if (!validate_constant(item)) { +                Py_DECREF(it); +                Py_DECREF(item); +                return 0; +            } +            Py_DECREF(item); +        } + +        Py_DECREF(it); +        return 1; +    } + +    return 0; +} + +static int  validate_expr(expr_ty exp, expr_context_ty ctx)  {      int check_ctx = 1; @@ -240,6 +286,14 @@ validate_expr(expr_ty exp, expr_context_ty ctx)          return validate_expr(exp->v.Call.func, Load) &&              validate_exprs(exp->v.Call.args, Load, 0) &&              validate_keywords(exp->v.Call.keywords); +    case Constant_kind: +        if (!validate_constant(exp->v.Constant.value)) { +            PyErr_Format(PyExc_TypeError, +                         "got an invalid type in Constant: %s", +                         Py_TYPE(exp->v.Constant.value)->tp_name); +            return 0; +        } +        return 1;      case Num_kind: {          PyObject *n = exp->v.Num.n;          if 
(!PyLong_CheckExact(n) && !PyFloat_CheckExact(n) && @@ -257,6 +311,14 @@ validate_expr(expr_ty exp, expr_context_ty ctx)          }          return 1;      } +    case JoinedStr_kind: +        return validate_exprs(exp->v.JoinedStr.values, Load, 0); +    case FormattedValue_kind: +        if (validate_expr(exp->v.FormattedValue.value, Load) == 0) +            return 0; +        if (exp->v.FormattedValue.format_spec) +            return validate_expr(exp->v.FormattedValue.format_spec, Load); +        return 1;      case Bytes_kind: {          PyObject *b = exp->v.Bytes.s;          if (!PyBytes_CheckExact(b)) { @@ -335,6 +397,17 @@ validate_stmt(stmt_ty stmt)      case AugAssign_kind:          return validate_expr(stmt->v.AugAssign.target, Store) &&              validate_expr(stmt->v.AugAssign.value, Load); +    case AnnAssign_kind: +        if (stmt->v.AnnAssign.target->kind != Name_kind && +            stmt->v.AnnAssign.simple) { +            PyErr_SetString(PyExc_TypeError, +                            "AnnAssign with simple non-Name target"); +            return 0; +        } +        return validate_expr(stmt->v.AnnAssign.target, Store) && +               (!stmt->v.AnnAssign.value || +                validate_expr(stmt->v.AnnAssign.value, Load)) && +               validate_expr(stmt->v.AnnAssign.annotation, Load);      case For_kind:          return validate_expr(stmt->v.For.target, Store) &&              validate_expr(stmt->v.For.iter, Load) && @@ -413,8 +486,8 @@ validate_stmt(stmt_ty stmt)      case Import_kind:          return validate_nonempty_seq(stmt->v.Import.names, "names", "Import");      case ImportFrom_kind: -        if (stmt->v.ImportFrom.level < -1) { -            PyErr_SetString(PyExc_ValueError, "ImportFrom level less than -1"); +        if (stmt->v.ImportFrom.level < 0) { +            PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level");              return 0;          }          return validate_nonempty_seq(stmt->v.ImportFrom.names, 
"names", "ImportFrom"); @@ -512,8 +585,7 @@ PyAST_Validate(mod_ty mod)  /* Data structure used internally */  struct compiling { -    char *c_encoding; /* source encoding */ -    PyArena *c_arena; /* arena for allocating memeory */ +    PyArena *c_arena; /* Arena for allocating memory. */      PyObject *c_filename; /* filename */      PyObject *c_normalize; /* Normalization function from unicodedata. */      PyObject *c_normalize_args; /* Normalization argument tuple. */ @@ -535,9 +607,7 @@ static stmt_ty ast_for_for_stmt(struct compiling *, const node *, int);  static expr_ty ast_for_call(struct compiling *, const node *, expr_ty);  static PyObject *parsenumber(struct compiling *, const char *); -static PyObject *parsestr(struct compiling *, const node *n, int *bytesmode); -static PyObject *parsestrplus(struct compiling *, const node *n, -                              int *bytesmode); +static expr_ty parsestrplus(struct compiling *, const node *n);  #define COMP_GENEXP   0  #define COMP_LISTCOMP 1 @@ -701,23 +771,11 @@ PyAST_FromNodeObject(const node *n, PyCompilerFlags *flags,      c.c_arena = arena;      /* borrowed reference */      c.c_filename = filename; -    c.c_normalize = c.c_normalize_args = NULL; -    if (flags && flags->cf_flags & PyCF_SOURCE_IS_UTF8) { -        c.c_encoding = "utf-8"; -        if (TYPE(n) == encoding_decl) { -#if 0 -            ast_error(c, n, "encoding declaration in Unicode string"); -            goto out; -#endif -            n = CHILD(n, 0); -        } -    } else if (TYPE(n) == encoding_decl) { -        c.c_encoding = STR(n); +    c.c_normalize = NULL; +    c.c_normalize_args = NULL; + +    if (TYPE(n) == encoding_decl)          n = CHILD(n, 0); -    } else { -        /* PEP 3120 */ -        c.c_encoding = "utf-8"; -    }      k = 0;      switch (TYPE(n)) { @@ -864,7 +922,7 @@ get_operator(const node *n)      }  } -static const char* FORBIDDEN[] = { +static const char * const FORBIDDEN[] = {      "None",      "True",      
"False", @@ -880,8 +938,28 @@ forbidden_name(struct compiling *c, identifier name, const node *n,          ast_error(c, n, "assignment to keyword");          return 1;      } +    if (PyUnicode_CompareWithASCIIString(name, "async") == 0 || +        PyUnicode_CompareWithASCIIString(name, "await") == 0) +    { +        PyObject *message = PyUnicode_FromString( +            "'async' and 'await' will become reserved keywords" +            " in Python 3.7"); +        if (message == NULL) { +            return 1; +        } +        if (PyErr_WarnExplicitObject( +                PyExc_DeprecationWarning, +                message, +                c->c_filename, +                LINENO(n), +                NULL, +                NULL) < 0) +        { +            return 1; +        } +    }      if (full_checks) { -        const char **p; +        const char * const *p;          for (p = FORBIDDEN; *p; p++) {              if (PyUnicode_CompareWithASCIIString(name, *p) == 0) {                  ast_error(c, n, "assignment to keyword"); @@ -943,13 +1021,8 @@ set_context(struct compiling *c, expr_ty e, expr_context_ty ctx, const node *n)              s = e->v.List.elts;              break;          case Tuple_kind: -            if (asdl_seq_LEN(e->v.Tuple.elts))  { -                e->v.Tuple.ctx = ctx; -                s = e->v.Tuple.elts; -            } -            else { -                expr_name = "()"; -            } +            e->v.Tuple.ctx = ctx; +            s = e->v.Tuple.elts;              break;          case Lambda_kind:              expr_name = "lambda"; @@ -986,6 +1059,8 @@ set_context(struct compiling *c, expr_ty e, expr_context_ty ctx, const node *n)          case Num_kind:          case Str_kind:          case Bytes_kind: +        case JoinedStr_kind: +        case FormattedValue_kind:              expr_name = "literal";              break;          case NameConstant_kind: @@ -1259,16 +1334,20 @@ ast_for_arguments(struct compiling *c, const node *n)      
   and varargslist (lambda definition).         parameters: '(' [typedargslist] ')' -       typedargslist: ((tfpdef ['=' test] ',')* -           ('*' [tfpdef] (',' tfpdef ['=' test])* [',' '**' tfpdef] -           | '**' tfpdef) -           | tfpdef ['=' test] (',' tfpdef ['=' test])* [',']) +       typedargslist: (tfpdef ['=' test] (',' tfpdef ['=' test])* [',' [ +               '*' [tfpdef] (',' tfpdef ['=' test])* [',' ['**' tfpdef [',']]] +             | '**' tfpdef [',']]] +         | '*' [tfpdef] (',' tfpdef ['=' test])* [',' ['**' tfpdef [',']]] +         | '**' tfpdef [','])         tfpdef: NAME [':' test] -       varargslist: ((vfpdef ['=' test] ',')* -           ('*' [vfpdef] (',' vfpdef ['=' test])*  [',' '**' vfpdef] -           | '**' vfpdef) -           | vfpdef ['=' test] (',' vfpdef ['=' test])* [',']) +       varargslist: (vfpdef ['=' test] (',' vfpdef ['=' test])* [',' [ +               '*' [vfpdef] (',' vfpdef ['=' test])* [',' ['**' vfpdef [',']]] +             | '**' vfpdef [',']]] +         | '*' [vfpdef] (',' vfpdef ['=' test])* [',' ['**' vfpdef [',']]] +         | '**' vfpdef [','] +       )         vfpdef: NAME +      */      int i, j, k, nposargs = 0, nkwonlyargs = 0;      int nposdefaults = 0, found_default = 0; @@ -1370,7 +1449,8 @@ ast_for_arguments(struct compiling *c, const node *n)                  i += 2; /* the name and the comma */                  break;              case STAR: -                if (i+1 >= NCH(n)) { +                if (i+1 >= NCH(n) || +                    (i+2 == NCH(n) && TYPE(CHILD(n, i+1)) == COMMA)) {                      ast_error(c, CHILD(n, i),                          "named arguments must follow bare *");                      return NULL; @@ -1687,14 +1767,21 @@ static int  count_comp_fors(struct compiling *c, const node *n)  {      int n_fors = 0; +    int is_async;    count_comp_for: +    is_async = 0;      n_fors++;      REQ(n, comp_for); -    if (NCH(n) == 5) -        n = CHILD(n, 4); -    else +   
 if (TYPE(CHILD(n, 0)) == ASYNC) { +        is_async = 1; +    } +    if (NCH(n) == (5 + is_async)) { +        n = CHILD(n, 4 + is_async); +    } +    else {          return n_fors; +    }    count_comp_iter:      REQ(n, comp_iter);      n = CHILD(n, 0); @@ -1757,14 +1844,19 @@ ast_for_comprehension(struct compiling *c, const node *n)          asdl_seq *t;          expr_ty expression, first;          node *for_ch; +        int is_async = 0;          REQ(n, comp_for); -        for_ch = CHILD(n, 1); +        if (TYPE(CHILD(n, 0)) == ASYNC) { +            is_async = 1; +        } + +        for_ch = CHILD(n, 1 + is_async);          t = ast_for_exprlist(c, for_ch, Store);          if (!t)              return NULL; -        expression = ast_for_expr(c, CHILD(n, 3)); +        expression = ast_for_expr(c, CHILD(n, 3 + is_async));          if (!expression)              return NULL; @@ -1772,19 +1864,20 @@ ast_for_comprehension(struct compiling *c, const node *n)             (x for x, in ...) has 1 element in t, but still requires a Tuple. 
*/          first = (expr_ty)asdl_seq_GET(t, 0);          if (NCH(for_ch) == 1) -            comp = comprehension(first, expression, NULL, c->c_arena); +            comp = comprehension(first, expression, NULL, +                                 is_async, c->c_arena);          else -            comp = comprehension(Tuple(t, Store, first->lineno, first->col_offset, -                                     c->c_arena), -                               expression, NULL, c->c_arena); +            comp = comprehension(Tuple(t, Store, first->lineno, +                                       first->col_offset, c->c_arena), +                                 expression, NULL, is_async, c->c_arena);          if (!comp)              return NULL; -        if (NCH(n) == 5) { +        if (NCH(n) == (5 + is_async)) {              int j, n_ifs;              asdl_seq *ifs; -            n = CHILD(n, 4); +            n = CHILD(n, 4 + is_async);              n_ifs = count_comp_ifs(c, n);              if (n_ifs == -1)                  return NULL; @@ -1993,7 +2086,6 @@ ast_for_atom(struct compiling *c, const node *n)         | '...' 
| 'None' | 'True' | 'False'      */      node *ch = CHILD(n, 0); -    int bytesmode = 0;      switch (TYPE(ch)) {      case NAME: { @@ -2015,7 +2107,7 @@ ast_for_atom(struct compiling *c, const node *n)          return Name(name, Load, LINENO(n), n->n_col_offset, c->c_arena);      }      case STRING: { -        PyObject *str = parsestrplus(c, n, &bytesmode); +        expr_ty str = parsestrplus(c, n);          if (!str) {              const char *errtype = NULL;              if (PyErr_ExceptionMatches(PyExc_UnicodeError)) @@ -2032,6 +2124,7 @@ ast_for_atom(struct compiling *c, const node *n)                      PyOS_snprintf(buf, sizeof(buf), "(%s) %s", errtype, s);                      Py_DECREF(errstr);                  } else { +                    PyErr_Clear();                      PyOS_snprintf(buf, sizeof(buf), "(%s) unknown error", errtype);                  }                  ast_error(c, n, buf); @@ -2041,14 +2134,7 @@ ast_for_atom(struct compiling *c, const node *n)              }              return NULL;          } -        if (PyArena_AddPyObject(c->c_arena, str) < 0) { -            Py_DECREF(str); -            return NULL; -        } -        if (bytesmode) -            return Bytes(str, LINENO(n), n->n_col_offset, c->c_arena); -        else -            return Str(str, LINENO(n), n->n_col_offset, c->c_arena); +        return str;      }      case NUMBER: {          PyObject *pynum = parsenumber(c, STR(ch)); @@ -2805,8 +2891,9 @@ static stmt_ty  ast_for_expr_stmt(struct compiling *c, const node *n)  {      REQ(n, expr_stmt); -    /* expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) -                | ('=' (yield_expr|testlist))*) +    /* expr_stmt: testlist_star_expr (annassign | augassign (yield_expr|testlist) | +                            ('=' (yield_expr|testlist_star_expr))*) +       annassign: ':' test ['=' test]         testlist_star_expr: (test|star_expr) (',' test|star_expr)* [',']         augassign: '+=' | '-=' | '*=' | '@=' | 
'/=' | '%=' | '&=' | '|=' | '^='                  | '<<=' | '>>=' | '**=' | '//=' @@ -2858,6 +2945,76 @@ ast_for_expr_stmt(struct compiling *c, const node *n)          return AugAssign(expr1, newoperator, expr2, LINENO(n), n->n_col_offset, c->c_arena);      } +    else if (TYPE(CHILD(n, 1)) == annassign) { +        expr_ty expr1, expr2, expr3; +        node *ch = CHILD(n, 0); +        node *deep, *ann = CHILD(n, 1); +        int simple = 1; + +        /* we keep track of parens to qualify (x) as expression not name */ +        deep = ch; +        while (NCH(deep) == 1) { +            deep = CHILD(deep, 0); +        } +        if (NCH(deep) > 0 && TYPE(CHILD(deep, 0)) == LPAR) { +            simple = 0; +        } +        expr1 = ast_for_testlist(c, ch); +        if (!expr1) { +            return NULL; +        } +        switch (expr1->kind) { +            case Name_kind: +                if (forbidden_name(c, expr1->v.Name.id, n, 0)) { +                    return NULL; +                } +                expr1->v.Name.ctx = Store; +                break; +            case Attribute_kind: +                if (forbidden_name(c, expr1->v.Attribute.attr, n, 1)) { +                    return NULL; +                } +                expr1->v.Attribute.ctx = Store; +                break; +            case Subscript_kind: +                expr1->v.Subscript.ctx = Store; +                break; +            case List_kind: +                ast_error(c, ch, +                          "only single target (not list) can be annotated"); +                return NULL; +            case Tuple_kind: +                ast_error(c, ch, +                          "only single target (not tuple) can be annotated"); +                return NULL; +            default: +                ast_error(c, ch, +                          "illegal target for annotation"); +                return NULL; +        } + +        if (expr1->kind != Name_kind) { +            simple = 0; +        } +     
   ch = CHILD(ann, 1); +        expr2 = ast_for_expr(c, ch); +        if (!expr2) { +            return NULL; +        } +        if (NCH(ann) == 2) { +            return AnnAssign(expr1, expr2, NULL, simple, +                             LINENO(n), n->n_col_offset, c->c_arena); +        } +        else { +            ch = CHILD(ann, 3); +            expr3 = ast_for_expr(c, ch); +            if (!expr3) { +                return NULL; +            } +            return AnnAssign(expr1, expr2, expr3, simple, +                             LINENO(n), n->n_col_offset, c->c_arena); +        } +    }      else {          int i;          asdl_seq *targets; @@ -2992,9 +3149,6 @@ ast_for_flow_stmt(struct compiling *c, const node *n)                           "unexpected flow_stmt: %d", TYPE(ch));              return NULL;      } - -    PyErr_SetString(PyExc_SystemError, "unhandled flow statement"); -    return NULL;  }  static alias_ty @@ -3208,14 +3362,14 @@ ast_for_import_stmt(struct compiling *c, const node *n)              alias_ty import_alias = alias_for_import_name(c, n, 1);              if (!import_alias)                  return NULL; -                asdl_seq_SET(aliases, 0, import_alias); +            asdl_seq_SET(aliases, 0, import_alias);          }          else {              for (i = 0; i < NCH(n); i += 2) {                  alias_ty import_alias = alias_for_import_name(c, CHILD(n, i), 1);                  if (!import_alias)                      return NULL; -                    asdl_seq_SET(aliases, i / 2, import_alias); +                asdl_seq_SET(aliases, i / 2, import_alias);              }          }          if (mod != NULL) @@ -3881,7 +4035,7 @@ ast_for_stmt(struct compiling *c, const node *n)  }  static PyObject * -parsenumber(struct compiling *c, const char *s) +parsenumber_raw(struct compiling *c, const char *s)  {      const char *end;      long x; @@ -3924,6 +4078,31 @@ parsenumber(struct compiling *c, const char *s)  }  static PyObject * 
+parsenumber(struct compiling *c, const char *s) +{ +    char *dup, *end; +    PyObject *res = NULL; + +    assert(s != NULL); + +    if (strchr(s, '_') == NULL) { +        return parsenumber_raw(c, s); +    } +    /* Create a duplicate without underscores. */ +    dup = PyMem_Malloc(strlen(s) + 1); +    end = dup; +    for (; *s; s++) { +        if (*s != '_') { +            *end++ = *s; +        } +    } +    *end = '\0'; +    res = parsenumber_raw(c, dup); +    PyMem_Free(dup); +    return res; +} + +static PyObject *  decode_utf8(struct compiling *c, const char **sPtr, const char *end)  {      const char *s, *t; @@ -3935,84 +4114,862 @@ decode_utf8(struct compiling *c, const char **sPtr, const char *end)  }  static PyObject * -decode_unicode(struct compiling *c, const char *s, size_t len, int rawmode, const char *encoding) +decode_unicode_with_escapes(struct compiling *c, const char *s, size_t len)  {      PyObject *v, *u;      char *buf;      char *p;      const char *end; -    if (encoding == NULL) { -        u = NULL; -    } else { -        /* check for integer overflow */ -        if (len > PY_SIZE_MAX / 6) -            return NULL; -        /* "ä" (2 bytes) may become "\U000000E4" (10 bytes), or 1:5 -           "\ä" (3 bytes) may become "\u005c\U000000E4" (16 bytes), or ~1:6 */ -        u = PyBytes_FromStringAndSize((char *)NULL, len * 6); -        if (u == NULL) -            return NULL; -        p = buf = PyBytes_AsString(u); -        end = s + len; -        while (s < end) { -            if (*s == '\\') { -                *p++ = *s++; -                if (*s & 0x80) { -                    strcpy(p, "u005c"); -                    p += 5; -                } +    /* check for integer overflow */ +    if (len > SIZE_MAX / 6) +        return NULL; +    /* "ä" (2 bytes) may become "\U000000E4" (10 bytes), or 1:5 +       "\ä" (3 bytes) may become "\u005c\U000000E4" (16 bytes), or ~1:6 */ +    u = PyBytes_FromStringAndSize((char *)NULL, len * 6); +    if (u == 
NULL) +        return NULL; +    p = buf = PyBytes_AsString(u); +    end = s + len; +    while (s < end) { +        if (*s == '\\') { +            *p++ = *s++; +            if (*s & 0x80) { +                strcpy(p, "u005c"); +                p += 5;              } -            if (*s & 0x80) { /* XXX inefficient */ -                PyObject *w; -                int kind; -                void *data; -                Py_ssize_t len, i; -                w = decode_utf8(c, &s, end); -                if (w == NULL) { -                    Py_DECREF(u); -                    return NULL; +        } +        if (*s & 0x80) { /* XXX inefficient */ +            PyObject *w; +            int kind; +            void *data; +            Py_ssize_t len, i; +            w = decode_utf8(c, &s, end); +            if (w == NULL) { +                Py_DECREF(u); +                return NULL; +            } +            kind = PyUnicode_KIND(w); +            data = PyUnicode_DATA(w); +            len = PyUnicode_GET_LENGTH(w); +            for (i = 0; i < len; i++) { +                Py_UCS4 chr = PyUnicode_READ(kind, data, i); +                sprintf(p, "\\U%08x", chr); +                p += 10; +            } +            /* Should be impossible to overflow */ +            assert(p - buf <= Py_SIZE(u)); +            Py_DECREF(w); +        } else { +            *p++ = *s++; +        } +    } +    len = p - buf; +    s = buf; + +    v = PyUnicode_DecodeUnicodeEscape(s, len, NULL); +    Py_XDECREF(u); +    return v; +} + +/* Compile this expression in to an expr_ty.  Add parens around the +   expression, in order to allow leading spaces in the expression. 
*/ +static expr_ty +fstring_compile_expr(const char *expr_start, const char *expr_end, +                     struct compiling *c, const node *n) + +{ +    int all_whitespace = 1; +    int kind; +    void *data; +    PyCompilerFlags cf; +    mod_ty mod; +    char *str; +    PyObject *o; +    Py_ssize_t len; +    Py_ssize_t i; + +    assert(expr_end >= expr_start); +    assert(*(expr_start-1) == '{'); +    assert(*expr_end == '}' || *expr_end == '!' || *expr_end == ':'); + +    /* We know there are no escapes here, because backslashes are not allowed, +       and we know it's utf-8 encoded (per PEP 263).  But, in order to check +       that each char is not whitespace, we need to decode it to unicode. +       Which is unfortunate, but such is life. */ + +    /* If the substring is all whitespace, it's an error.  We need to catch +       this here, and not when we call PyParser_ASTFromString, because turning +       the expression '' in to '()' would go from being invalid to valid. */ +    /* Note that this code says an empty string is all whitespace.  That's +       important.  There's a test for it: f'{}'. */ +    o = PyUnicode_DecodeUTF8(expr_start, expr_end-expr_start, NULL); +    if (o == NULL) +        return NULL; +    len = PyUnicode_GET_LENGTH(o); +    kind = PyUnicode_KIND(o); +    data = PyUnicode_DATA(o); +    for (i = 0; i < len; i++) { +        if (!Py_UNICODE_ISSPACE(PyUnicode_READ(kind, data, i))) { +            all_whitespace = 0; +            break; +        } +    } +    Py_DECREF(o); +    if (all_whitespace) { +        ast_error(c, n, "f-string: empty expression not allowed"); +        return NULL; +    } + +    /* Reuse len to be the length of the utf-8 input string. */ +    len = expr_end - expr_start; +    /* Allocate 3 extra bytes: open paren, close paren, null byte. 
*/ +    str = PyMem_RawMalloc(len + 3); +    if (str == NULL) +        return NULL; + +    str[0] = '('; +    memcpy(str+1, expr_start, len); +    str[len+1] = ')'; +    str[len+2] = 0; + +    cf.cf_flags = PyCF_ONLY_AST; +    mod = PyParser_ASTFromString(str, "<fstring>", +                                 Py_eval_input, &cf, c->c_arena); +    PyMem_RawFree(str); +    if (!mod) +        return NULL; +    return mod->v.Expression.body; +} + +/* Return -1 on error. + +   Return 0 if we reached the end of the literal. + +   Return 1 if we haven't reached the end of the literal, but we want +   the caller to process the literal up to this point. Used for +   doubled braces. +*/ +static int +fstring_find_literal(const char **str, const char *end, int raw, +                     PyObject **literal, int recurse_lvl, +                     struct compiling *c, const node *n) +{ +    /* Get any literal string. It ends when we hit an un-doubled left +       brace (which isn't part of a unicode name escape such as +       "\N{EULER CONSTANT}"), or the end of the string. */ + +    const char *literal_start = *str; +    const char *literal_end; +    int in_named_escape = 0; +    int result = 0; + +    assert(*literal == NULL); +    for (; *str < end; (*str)++) { +        char ch = **str; +        if (!in_named_escape && ch == '{' && (*str)-literal_start >= 2 && +            *(*str-2) == '\\' && *(*str-1) == 'N') { +            in_named_escape = 1; +        } else if (in_named_escape && ch == '}') { +            in_named_escape = 0; +        } else if (ch == '{' || ch == '}') { +            /* Check for doubled braces, but only at the top level. If +               we checked at every level, then f'{0:{3}}' would fail +               with the two closing braces. 
*/ +            if (recurse_lvl == 0) { +                if (*str+1 < end && *(*str+1) == ch) { +                    /* We're going to tell the caller that the literal ends +                       here, but that they should continue scanning. But also +                       skip over the second brace when we resume scanning. */ +                    literal_end = *str+1; +                    *str += 2; +                    result = 1; +                    goto done; +                } + +                /* Where a single '{' is the start of a new expression, a +                   single '}' is not allowed. */ +                if (ch == '}') { +                    ast_error(c, n, "f-string: single '}' is not allowed"); +                    return -1;                  } -                kind = PyUnicode_KIND(w); -                data = PyUnicode_DATA(w); -                len = PyUnicode_GET_LENGTH(w); -                for (i = 0; i < len; i++) { -                    Py_UCS4 chr = PyUnicode_READ(kind, data, i); -                    sprintf(p, "\\U%08x", chr); -                    p += 10; +            } +            /* We're either at a '{', which means we're starting another +               expression; or a '}', which means we're at the end of this +               f-string (for a nested format_spec). 
*/ +            break; +        } +    } +    literal_end = *str; +    assert(*str <= end); +    assert(*str == end || **str == '{' || **str == '}'); +done: +    if (literal_start != literal_end) { +        if (raw) +            *literal = PyUnicode_DecodeUTF8Stateful(literal_start, +                                                    literal_end-literal_start, +                                                    NULL, NULL); +        else +            *literal = decode_unicode_with_escapes(c, literal_start, +                                                   literal_end-literal_start); +        if (!*literal) +            return -1; +    } +    return result; +} + +/* Forward declaration because parsing is recursive. */ +static expr_ty +fstring_parse(const char **str, const char *end, int raw, int recurse_lvl, +              struct compiling *c, const node *n); + +/* Parse the f-string at *str, ending at end.  We know *str starts an +   expression (so it must be a '{'). Returns the FormattedValue node, +   which includes the expression, conversion character, and +   format_spec expression. + +   Note that I don't do a perfect job here: I don't make sure that a +   closing brace doesn't match an opening paren, for example. It +   doesn't need to error on all invalid expressions, just correctly +   find the end of all valid ones. Any errors inside the expression +   will be caught when we parse it later. */ +static int +fstring_find_expr(const char **str, const char *end, int raw, int recurse_lvl, +                  expr_ty *expression, struct compiling *c, const node *n) +{ +    /* Return -1 on error, else 0. */ + +    const char *expr_start; +    const char *expr_end; +    expr_ty simple_expression; +    expr_ty format_spec = NULL; /* Optional format specifier. */ +    int conversion = -1; /* The conversion char. -1 if not specified. */ + +    /* 0 if we're not in a string, else the quote char we're trying to +       match (single or double quote). 
*/ +    char quote_char = 0; + +    /* If we're inside a string, 1=normal, 3=triple-quoted. */ +    int string_type = 0; + +    /* Keep track of nesting level for braces/parens/brackets in +       expressions. */ +    Py_ssize_t nested_depth = 0; + +    /* Can only nest one level deep. */ +    if (recurse_lvl >= 2) { +        ast_error(c, n, "f-string: expressions nested too deeply"); +        return -1; +    } + +    /* The first char must be a left brace, or we wouldn't have gotten +       here. Skip over it. */ +    assert(**str == '{'); +    *str += 1; + +    expr_start = *str; +    for (; *str < end; (*str)++) { +        char ch; + +        /* Loop invariants. */ +        assert(nested_depth >= 0); +        assert(*str >= expr_start && *str < end); +        if (quote_char) +            assert(string_type == 1 || string_type == 3); +        else +            assert(string_type == 0); + +        ch = **str; +        /* Nowhere inside an expression is a backslash allowed. */ +        if (ch == '\\') { +            /* Error: can't include a backslash character, inside +               parens or strings or not. */ +            ast_error(c, n, "f-string expression part " +                            "cannot include a backslash"); +            return -1; +        } +        if (quote_char) { +            /* We're inside a string. See if we're at the end. */ +            /* This code needs to implement the same non-error logic +               as tok_get from tokenizer.c, at the letter_quote +               label. To actually share that code would be a +               nightmare. But, it's unlikely to change and is small, +               so duplicate it here. Note we don't need to catch all +               of the errors, since they'll be caught when parsing the +               expression. We just need to match the non-error +               cases. Thus we can ignore \n in single-quoted strings, +               for example. Or non-terminated strings. 
*/ +            if (ch == quote_char) { +                /* Does this match the string_type (single or triple +                   quoted)? */ +                if (string_type == 3) { +                    if (*str+2 < end && *(*str+1) == ch && *(*str+2) == ch) { +                        /* We're at the end of a triple quoted string. */ +                        *str += 2; +                        string_type = 0; +                        quote_char = 0; +                        continue; +                    } +                } else { +                    /* We're at the end of a normal string. */ +                    quote_char = 0; +                    string_type = 0; +                    continue;                  } -                /* Should be impossible to overflow */ -                assert(p - buf <= Py_SIZE(u)); -                Py_DECREF(w); +            } +        } else if (ch == '\'' || ch == '"') { +            /* Is this a triple quoted string? */ +            if (*str+2 < end && *(*str+1) == ch && *(*str+2) == ch) { +                string_type = 3; +                *str += 2;              } else { -                *p++ = *s++; +                /* Start of a normal string. */ +                string_type = 1; +            } +            /* Start looking for the end of the string. */ +            quote_char = ch; +        } else if (ch == '[' || ch == '{' || ch == '(') { +            nested_depth++; +        } else if (nested_depth != 0 && +                   (ch == ']' || ch == '}' || ch == ')')) { +            nested_depth--; +        } else if (ch == '#') { +            /* Error: can't include a comment character, inside parens +               or not. */ +            ast_error(c, n, "f-string expression part cannot include '#'"); +            return -1; +        } else if (nested_depth == 0 && +                   (ch == '!' || ch == ':' || ch == '}')) { +            /* First, test for the special case of "!=". 
Since '=' is +               not an allowed conversion character, nothing is lost in +               this test. */ +            if (ch == '!' && *str+1 < end && *(*str+1) == '=') { +                /* This isn't a conversion character, just continue. */ +                continue;              } +            /* Normal way out of this loop. */ +            break; +        } else { +            /* Just consume this char and loop around. */          } -        len = p - buf; -        s = buf;      } -    if (rawmode) -        v = PyUnicode_DecodeRawUnicodeEscape(s, len, NULL); -    else -        v = PyUnicode_DecodeUnicodeEscape(s, len, NULL); -    Py_XDECREF(u); -    return v; +    expr_end = *str; +    /* If we leave this loop in a string or with mismatched parens, we +       don't care. We'll get a syntax error when compiling the +       expression. But, we can produce a better error message, so +       let's just do that.*/ +    if (quote_char) { +        ast_error(c, n, "f-string: unterminated string"); +        return -1; +    } +    if (nested_depth) { +        ast_error(c, n, "f-string: mismatched '(', '{', or '['"); +        return -1; +    } + +    if (*str >= end) +        goto unexpected_end_of_string; + +    /* Compile the expression as soon as possible, so we show errors +       related to the expression before errors related to the +       conversion or format_spec. */ +    simple_expression = fstring_compile_expr(expr_start, expr_end, c, n); +    if (!simple_expression) +        return -1; + +    /* Check for a conversion char, if present. */ +    if (**str == '!') { +        *str += 1; +        if (*str >= end) +            goto unexpected_end_of_string; + +        conversion = **str; +        *str += 1; + +        /* Validate the conversion. 
*/ +        if (!(conversion == 's' || conversion == 'r' +              || conversion == 'a')) { +            ast_error(c, n, "f-string: invalid conversion character: " +                            "expected 's', 'r', or 'a'"); +            return -1; +        } +    } + +    /* Check for the format spec, if present. */ +    if (*str >= end) +        goto unexpected_end_of_string; +    if (**str == ':') { +        *str += 1; +        if (*str >= end) +            goto unexpected_end_of_string; + +        /* Parse the format spec. */ +        format_spec = fstring_parse(str, end, raw, recurse_lvl+1, c, n); +        if (!format_spec) +            return -1; +    } + +    if (*str >= end || **str != '}') +        goto unexpected_end_of_string; + +    /* We're at a right brace. Consume it. */ +    assert(*str < end); +    assert(**str == '}'); +    *str += 1; + +    /* And now create the FormattedValue node that represents this +       entire expression with the conversion and format spec. */ +    *expression = FormattedValue(simple_expression, conversion, +                                 format_spec, LINENO(n), n->n_col_offset, +                                 c->c_arena); +    if (!*expression) +        return -1; + +    return 0; + +unexpected_end_of_string: +    ast_error(c, n, "f-string: expecting '}'"); +    return -1;  } -/* s is a Python string literal, including the bracketing quote characters, - * and r &/or b prefixes (if any), and embedded escape sequences (if any). - * parsestr parses it, and returns the decoded Python string object. - */ -static PyObject * -parsestr(struct compiling *c, const node *n, int *bytesmode) +/* Return -1 on error. + +   Return 0 if we have a literal (possible zero length) and an +   expression (zero length if at the end of the string. + +   Return 1 if we have a literal, but no expression, and we want the +   caller to call us again. This is used to deal with doubled +   braces. 
+ +   When called multiple times on the string 'a{{b{0}c', this function +   will return: + +   1. the literal 'a{' with no expression, and a return value +      of 1. Despite the fact that there's no expression, the return +      value of 1 means we're not finished yet. + +   2. the literal 'b' and the expression '0', with a return value of +      0. The fact that there's an expression means we're not finished. + +   3. literal 'c' with no expression and a return value of 0. The +      combination of the return value of 0 with no expression means +      we're finished. +*/ +static int +fstring_find_literal_and_expr(const char **str, const char *end, int raw, +                              int recurse_lvl, PyObject **literal, +                              expr_ty *expression, +                              struct compiling *c, const node *n) +{ +    int result; + +    assert(*literal == NULL && *expression == NULL); + +    /* Get any literal string. */ +    result = fstring_find_literal(str, end, raw, literal, recurse_lvl, c, n); +    if (result < 0) +        goto error; + +    assert(result == 0 || result == 1); + +    if (result == 1) +        /* We have a literal, but don't look at the expression. */ +        return 1; + +    if (*str >= end || **str == '}') +        /* We're at the end of the string or the end of a nested +           f-string: no expression. The top-level error case where we +           expect to be at the end of the string but we're at a '}' is +           handled later. */ +        return 0; + +    /* We must now be the start of an expression, on a '{'. */ +    assert(**str == '{'); + +    if (fstring_find_expr(str, end, raw, recurse_lvl, expression, c, n) < 0) +        goto error; + +    return 0; + +error: +    Py_CLEAR(*literal); +    return -1; +} + +#define EXPRLIST_N_CACHED  64 + +typedef struct { +    /* Incrementally build an array of expr_ty, so be used in an +       asdl_seq. 
Cache some small but reasonably sized number of +       expr_ty's, and then after that start dynamically allocating, +       doubling the number allocated each time. Note that the f-string +       f'{0}a{1}' contains 3 expr_ty's: 2 FormattedValue's, and one +       Str for the literal 'a'. So you add expr_ty's about twice as +       fast as you add exressions in an f-string. */ + +    Py_ssize_t allocated;  /* Number we've allocated. */ +    Py_ssize_t size;       /* Number we've used. */ +    expr_ty    *p;         /* Pointer to the memory we're actually +                              using. Will point to 'data' until we +                              start dynamically allocating. */ +    expr_ty    data[EXPRLIST_N_CACHED]; +} ExprList; + +#ifdef NDEBUG +#define ExprList_check_invariants(l) +#else +static void +ExprList_check_invariants(ExprList *l) +{ +    /* Check our invariants. Make sure this object is "live", and +       hasn't been deallocated. */ +    assert(l->size >= 0); +    assert(l->p != NULL); +    if (l->size <= EXPRLIST_N_CACHED) +        assert(l->data == l->p); +} +#endif + +static void +ExprList_Init(ExprList *l) +{ +    l->allocated = EXPRLIST_N_CACHED; +    l->size = 0; + +    /* Until we start allocating dynamically, p points to data. */ +    l->p = l->data; + +    ExprList_check_invariants(l); +} + +static int +ExprList_Append(ExprList *l, expr_ty exp) +{ +    ExprList_check_invariants(l); +    if (l->size >= l->allocated) { +        /* We need to alloc (or realloc) the memory. */ +        Py_ssize_t new_size = l->allocated * 2; + +        /* See if we've ever allocated anything dynamically. */ +        if (l->p == l->data) { +            Py_ssize_t i; +            /* We're still using the cached data. Switch to +               alloc-ing. */ +            l->p = PyMem_RawMalloc(sizeof(expr_ty) * new_size); +            if (!l->p) +                return -1; +            /* Copy the cached data into the new buffer. 
*/ +            for (i = 0; i < l->size; i++) +                l->p[i] = l->data[i]; +        } else { +            /* Just realloc. */ +            expr_ty *tmp = PyMem_RawRealloc(l->p, sizeof(expr_ty) * new_size); +            if (!tmp) { +                PyMem_RawFree(l->p); +                l->p = NULL; +                return -1; +            } +            l->p = tmp; +        } + +        l->allocated = new_size; +        assert(l->allocated == 2 * l->size); +    } + +    l->p[l->size++] = exp; + +    ExprList_check_invariants(l); +    return 0; +} + +static void +ExprList_Dealloc(ExprList *l) +{ +    ExprList_check_invariants(l); + +    /* If there's been an error, or we've never dynamically allocated, +       do nothing. */ +    if (!l->p || l->p == l->data) { +        /* Do nothing. */ +    } else { +        /* We have dynamically allocated. Free the memory. */ +        PyMem_RawFree(l->p); +    } +    l->p = NULL; +    l->size = -1; +} + +static asdl_seq * +ExprList_Finish(ExprList *l, PyArena *arena) +{ +    asdl_seq *seq; + +    ExprList_check_invariants(l); + +    /* Allocate the asdl_seq and copy the expressions in to it. */ +    seq = _Py_asdl_seq_new(l->size, arena); +    if (seq) { +        Py_ssize_t i; +        for (i = 0; i < l->size; i++) +            asdl_seq_SET(seq, i, l->p[i]); +    } +    ExprList_Dealloc(l); +    return seq; +} + +/* The FstringParser is designed to add a mix of strings and +   f-strings, and concat them together as needed. Ultimately, it +   generates an expr_ty. 
*/ +typedef struct { +    PyObject *last_str; +    ExprList expr_list; +} FstringParser; + +#ifdef NDEBUG +#define FstringParser_check_invariants(state) +#else +static void +FstringParser_check_invariants(FstringParser *state) +{ +    if (state->last_str) +        assert(PyUnicode_CheckExact(state->last_str)); +    ExprList_check_invariants(&state->expr_list); +} +#endif + +static void +FstringParser_Init(FstringParser *state) +{ +    state->last_str = NULL; +    ExprList_Init(&state->expr_list); +    FstringParser_check_invariants(state); +} + +static void +FstringParser_Dealloc(FstringParser *state) +{ +    FstringParser_check_invariants(state); + +    Py_XDECREF(state->last_str); +    ExprList_Dealloc(&state->expr_list); +} + +/* Make a Str node, but decref the PyUnicode object being added. */ +static expr_ty +make_str_node_and_del(PyObject **str, struct compiling *c, const node* n) +{ +    PyObject *s = *str; +    *str = NULL; +    assert(PyUnicode_CheckExact(s)); +    if (PyArena_AddPyObject(c->c_arena, s) < 0) { +        Py_DECREF(s); +        return NULL; +    } +    return Str(s, LINENO(n), n->n_col_offset, c->c_arena); +} + +/* Add a non-f-string (that is, a regular literal string). str is +   decref'd. */ +static int +FstringParser_ConcatAndDel(FstringParser *state, PyObject *str) +{ +    FstringParser_check_invariants(state); + +    assert(PyUnicode_CheckExact(str)); + +    if (PyUnicode_GET_LENGTH(str) == 0) { +        Py_DECREF(str); +        return 0; +    } + +    if (!state->last_str) { +        /* We didn't have a string before, so just remember this one. */ +        state->last_str = str; +    } else { +        /* Concatenate this with the previous string. */ +        PyUnicode_AppendAndDel(&state->last_str, str); +        if (!state->last_str) +            return -1; +    } +    FstringParser_check_invariants(state); +    return 0; +} + +/* Parse an f-string. The f-string is in *str to end, with no +   'f' or quotes. 
*/ +static int +FstringParser_ConcatFstring(FstringParser *state, const char **str, +                            const char *end, int raw, int recurse_lvl, +                            struct compiling *c, const node *n) +{ +    FstringParser_check_invariants(state); + +    /* Parse the f-string. */ +    while (1) { +        PyObject *literal = NULL; +        expr_ty expression = NULL; + +        /* If there's a zero length literal in front of the +           expression, literal will be NULL. If we're at the end of +           the f-string, expression will be NULL (unless result == 1, +           see below). */ +        int result = fstring_find_literal_and_expr(str, end, raw, recurse_lvl, +                                                   &literal, &expression, +                                                   c, n); +        if (result < 0) +            return -1; + +        /* Add the literal, if any. */ +        if (!literal) { +            /* Do nothing. Just leave last_str alone (and possibly +               NULL). */ +        } else if (!state->last_str) { +            state->last_str = literal; +            literal = NULL; +        } else { +            /* We have a literal, concatenate it. */ +            assert(PyUnicode_GET_LENGTH(literal) != 0); +            if (FstringParser_ConcatAndDel(state, literal) < 0) +                return -1; +            literal = NULL; +        } +        assert(!state->last_str || +               PyUnicode_GET_LENGTH(state->last_str) != 0); + +        /* We've dealt with the literal now. It can't be leaked on further +           errors. */ +        assert(literal == NULL); + +        /* See if we should just loop around to get the next literal +           and expression, while ignoring the expression this +           time. This is used for un-doubling braces, as an +           optimization. */ +        if (result == 1) +            continue; + +        if (!expression) +            /* We're done with this f-string. 
*/ +            break; + +        /* We know we have an expression. Convert any existing string +           to a Str node. */ +        if (!state->last_str) { +            /* Do nothing. No previous literal. */ +        } else { +            /* Convert the existing last_str literal to a Str node. */ +            expr_ty str = make_str_node_and_del(&state->last_str, c, n); +            if (!str || ExprList_Append(&state->expr_list, str) < 0) +                return -1; +        } + +        if (ExprList_Append(&state->expr_list, expression) < 0) +            return -1; +    } + +    /* If recurse_lvl is zero, then we must be at the end of the +       string. Otherwise, we must be at a right brace. */ + +    if (recurse_lvl == 0 && *str < end-1) { +        ast_error(c, n, "f-string: unexpected end of string"); +        return -1; +    } +    if (recurse_lvl != 0 && **str != '}') { +        ast_error(c, n, "f-string: expecting '}'"); +        return -1; +    } + +    FstringParser_check_invariants(state); +    return 0; +} + +/* Convert the partial state reflected in last_str and expr_list to an +   expr_ty. The expr_ty can be a Str, or a JoinedStr. */ +static expr_ty +FstringParser_Finish(FstringParser *state, struct compiling *c, +                     const node *n) +{ +    asdl_seq *seq; + +    FstringParser_check_invariants(state); + +    /* If we're just a constant string with no expressions, return +       that. */ +    if(state->expr_list.size == 0) { +        if (!state->last_str) { +            /* Create a zero length string. */ +            state->last_str = PyUnicode_FromStringAndSize(NULL, 0); +            if (!state->last_str) +                goto error; +        } +        return make_str_node_and_del(&state->last_str, c, n); +    } + +    /* Create a Str node out of last_str, if needed. It will be the +       last node in our expression list. 
*/ +    if (state->last_str) { +        expr_ty str = make_str_node_and_del(&state->last_str, c, n); +        if (!str || ExprList_Append(&state->expr_list, str) < 0) +            goto error; +    } +    /* This has already been freed. */ +    assert(state->last_str == NULL); + +    seq = ExprList_Finish(&state->expr_list, c->c_arena); +    if (!seq) +        goto error; + +    /* If there's only one expression, return it. Otherwise, we need +       to join them together. */ +    if (seq->size == 1) +        return seq->elements[0]; + +    return JoinedStr(seq, LINENO(n), n->n_col_offset, c->c_arena); + +error: +    FstringParser_Dealloc(state); +    return NULL; +} + +/* Given an f-string (with no 'f' or quotes) that's in *str and ends +   at end, parse it into an expr_ty.  Return NULL on error.  Adjust +   str to point past the parsed portion. */ +static expr_ty +fstring_parse(const char **str, const char *end, int raw, int recurse_lvl, +              struct compiling *c, const node *n) +{ +    FstringParser state; + +    FstringParser_Init(&state); +    if (FstringParser_ConcatFstring(&state, str, end, raw, recurse_lvl, +                                    c, n) < 0) { +        FstringParser_Dealloc(&state); +        return NULL; +    } + +    return FstringParser_Finish(&state, c, n); +} + +/* n is a Python string literal, including the bracketing quote +   characters, and r, b, u, &/or f prefixes (if any), and embedded +   escape sequences (if any). parsestr parses it, and sets *result to +   decoded Python string object.  If the string is an f-string, set +   *fstr and *fstrlen to the unparsed string object.  Return 0 if no +   errors occurred. 
+*/ +static int +parsestr(struct compiling *c, const node *n, int *bytesmode, int *rawmode, +         PyObject **result, const char **fstr, Py_ssize_t *fstrlen)  {      size_t len;      const char *s = STR(n);      int quote = Py_CHARMASK(*s); -    int rawmode = 0; -    int need_encoding; +    int fmode = 0; +    *bytesmode = 0; +    *rawmode = 0; +    *result = NULL; +    *fstr = NULL;      if (Py_ISALPHA(quote)) { -        while (!*bytesmode || !rawmode) { +        while (!*bytesmode || !*rawmode) {              if (quote == 'b' || quote == 'B') {                  quote = *++s;                  *bytesmode = 1; @@ -4022,114 +4979,169 @@ parsestr(struct compiling *c, const node *n, int *bytesmode)              }              else if (quote == 'r' || quote == 'R') {                  quote = *++s; -                rawmode = 1; +                *rawmode = 1; +            } +            else if (quote == 'f' || quote == 'F') { +                quote = *++s; +                fmode = 1;              }              else {                  break;              }          }      } +    if (fmode && *bytesmode) { +        PyErr_BadInternalCall(); +        return -1; +    }      if (quote != '\'' && quote != '\"') {          PyErr_BadInternalCall(); -        return NULL; +        return -1;      } +    /* Skip the leading quote char. */      s++;      len = strlen(s);      if (len > INT_MAX) {          PyErr_SetString(PyExc_OverflowError,                          "string to parse is too long"); -        return NULL; +        return -1;      }      if (s[--len] != quote) { +        /* Last quote char must match the first. */          PyErr_BadInternalCall(); -        return NULL; +        return -1;      }      if (len >= 4 && s[0] == quote && s[1] == quote) { +        /* A triple quoted string. We've already skipped one quote at +           the start and one at the end of the string. Now skip the +           two at the start. 
*/          s += 2;          len -= 2; +        /* And check that the last two match. */          if (s[--len] != quote || s[--len] != quote) {              PyErr_BadInternalCall(); -            return NULL; +            return -1;          }      } -    if (!*bytesmode && !rawmode) { -        return decode_unicode(c, s, len, rawmode, c->c_encoding); + +    if (fmode) { +        /* Just return the bytes. The caller will parse the resulting +           string. */ +        *fstr = s; +        *fstrlen = len; +        return 0;      } + +    /* Not an f-string. */ +    /* Avoid invoking escape decoding routines if possible. */ +    *rawmode = *rawmode || strchr(s, '\\') == NULL;      if (*bytesmode) { -        /* Disallow non-ascii characters (but not escapes) */ +        /* Disallow non-ASCII characters. */          const char *ch;          for (ch = s; *ch; ch++) {              if (Py_CHARMASK(*ch) >= 0x80) {                  ast_error(c, n, "bytes can only contain ASCII "                            "literal characters."); -                return NULL; +                return -1;              }          } +        if (*rawmode) +            *result = PyBytes_FromStringAndSize(s, len); +        else +            *result = PyBytes_DecodeEscape(s, len, NULL, /* ignored */ 0, NULL); +    } else { +        if (*rawmode) +            *result = PyUnicode_DecodeUTF8Stateful(s, len, NULL, NULL); +        else +            *result = decode_unicode_with_escapes(c, s, len);      } -    need_encoding = (!*bytesmode && c->c_encoding != NULL && -                     strcmp(c->c_encoding, "utf-8") != 0); -    if (rawmode || strchr(s, '\\') == NULL) { -        if (need_encoding) { -            PyObject *v, *u = PyUnicode_DecodeUTF8(s, len, NULL); -            if (u == NULL || !*bytesmode) -                return u; -            v = PyUnicode_AsEncodedString(u, c->c_encoding, NULL); -            Py_DECREF(u); -            return v; -        } else if (*bytesmode) { -            
return PyBytes_FromStringAndSize(s, len); -        } else if (strcmp(c->c_encoding, "utf-8") == 0) { -            return PyUnicode_FromStringAndSize(s, len); -        } else { -            return PyUnicode_DecodeLatin1(s, len, NULL); -        } -    } -    return PyBytes_DecodeEscape(s, len, NULL, 1, -                                 need_encoding ? c->c_encoding : NULL); +    return *result == NULL ? -1 : 0;  } -/* Build a Python string object out of a STRING+ atom.  This takes care of - * compile-time literal catenation, calling parsestr() on each piece, and - * pasting the intermediate results together. - */ -static PyObject * -parsestrplus(struct compiling *c, const node *n, int *bytesmode) +/* Accepts a STRING+ atom, and produces an expr_ty node. Run through +   each STRING atom, and process it as needed. For bytes, just +   concatenate them together, and the result will be a Bytes node. For +   normal strings and f-strings, concatenate them together. The result +   will be a Str node if there were no f-strings; a FormattedValue +   node if there's just an f-string (with no leading or trailing +   literals), or a JoinedStr node if there are multiple f-strings or +   any literals involved. 
*/
/* Returns NULL on error. Mixing bytes and non-bytes pieces is reported
   through ast_error. Every piece handed back by parsestr is either
   consumed (bytes concatenation / FstringParser) or released here. */
static expr_ty
parsestrplus(struct compiling *c, const node *n)
{
    int bytesmode = 0;
    PyObject *bytes_str = NULL;
    int i;

    FstringParser state;
    FstringParser_Init(&state);

    for (i = 0; i < NCH(n); i++) {
        int this_bytesmode;
        int this_rawmode;
        PyObject *s;
        const char *fstr;
        Py_ssize_t fstrlen = -1;  /* Silence a compiler warning. */

        REQ(CHILD(n, i), STRING);
        if (parsestr(c, CHILD(n, i), &this_bytesmode, &this_rawmode, &s,
                     &fstr, &fstrlen) != 0)
            goto error;

        /* Check that we're not mixing bytes with unicode. */
        if (i != 0 && bytesmode != this_bytesmode) {
            ast_error(c, n, "cannot mix bytes and nonbytes literals");
            /* s is NULL when the current piece is an f-string (parsestr
               only fills in fstr/fstrlen then), e.g. for b'' f''. */
            Py_XDECREF(s);
            goto error;
        }
        bytesmode = this_bytesmode;

        if (fstr != NULL) {
            int result;
            assert(s == NULL && !bytesmode);
            /* This is an f-string. Parse and concatenate it. */
            result = FstringParser_ConcatFstring(&state, &fstr, fstr+fstrlen,
                                                 this_rawmode, 0, c, n);
            if (result < 0)
                goto error;
        } else {
            assert(bytesmode ? PyBytes_CheckExact(s) :
                   PyUnicode_CheckExact(s));

            /* A string or byte string. */
            assert(s != NULL && fstr == NULL);
            if (bytesmode) {
                /* For bytes, concat as we go. */
                if (i == 0) {
                    /* First time, just remember this value. */
                    bytes_str = s;
                } else {
                    /* Consumes s; bytes_str becomes NULL on failure. */
                    PyBytes_ConcatAndDel(&bytes_str, s);
                    if (!bytes_str)
                        goto error;
                }
            } else {
                assert(s != NULL && fstr == NULL);
                /* This is a regular string. Concatenate it. */
                if (FstringParser_ConcatAndDel(&state, s) < 0)
                    goto error;
            }
        }
    }
    if (bytesmode) {
        /* Just return the bytes object and we're done. */
        if (PyArena_AddPyObject(c->c_arena, bytes_str) < 0)
            goto error;
        return Bytes(bytes_str, LINENO(n), n->n_col_offset, c->c_arena);
    }

    /* We're not a bytes string, bytes_str should never have been set. */
    assert(bytes_str == NULL);

    /* FstringParser_Finish cleans up the parser state itself on
       failure, so its result can be returned directly. */
    return FstringParser_Finish(&state, c, n);

error:
    Py_XDECREF(bytes_str);
    FstringParser_Dealloc(&state);
    return NULL;
}
