diff options
Diffstat (limited to 'Python/ast.c')
-rw-r--r-- | Python/ast.c | 1460 |
1 files changed, 1268 insertions, 192 deletions
diff --git a/Python/ast.c b/Python/ast.c index d36a9b7..217ea14 100644 --- a/Python/ast.c +++ b/Python/ast.c @@ -132,6 +132,52 @@ validate_arguments(arguments_ty args) } static int +validate_constant(PyObject *value) +{ + if (value == Py_None || value == Py_Ellipsis) + return 1; + + if (PyLong_CheckExact(value) + || PyFloat_CheckExact(value) + || PyComplex_CheckExact(value) + || PyBool_Check(value) + || PyUnicode_CheckExact(value) + || PyBytes_CheckExact(value)) + return 1; + + if (PyTuple_CheckExact(value) || PyFrozenSet_CheckExact(value)) { + PyObject *it; + + it = PyObject_GetIter(value); + if (it == NULL) + return 0; + + while (1) { + PyObject *item = PyIter_Next(it); + if (item == NULL) { + if (PyErr_Occurred()) { + Py_DECREF(it); + return 0; + } + break; + } + + if (!validate_constant(item)) { + Py_DECREF(it); + Py_DECREF(item); + return 0; + } + Py_DECREF(item); + } + + Py_DECREF(it); + return 1; + } + + return 0; +} + +static int validate_expr(expr_ty exp, expr_context_ty ctx) { int check_ctx = 1; @@ -240,6 +286,14 @@ validate_expr(expr_ty exp, expr_context_ty ctx) return validate_expr(exp->v.Call.func, Load) && validate_exprs(exp->v.Call.args, Load, 0) && validate_keywords(exp->v.Call.keywords); + case Constant_kind: + if (!validate_constant(exp->v.Constant.value)) { + PyErr_Format(PyExc_TypeError, + "got an invalid type in Constant: %s", + Py_TYPE(exp->v.Constant.value)->tp_name); + return 0; + } + return 1; case Num_kind: { PyObject *n = exp->v.Num.n; if (!PyLong_CheckExact(n) && !PyFloat_CheckExact(n) && @@ -257,6 +311,14 @@ validate_expr(expr_ty exp, expr_context_ty ctx) } return 1; } + case JoinedStr_kind: + return validate_exprs(exp->v.JoinedStr.values, Load, 0); + case FormattedValue_kind: + if (validate_expr(exp->v.FormattedValue.value, Load) == 0) + return 0; + if (exp->v.FormattedValue.format_spec) + return validate_expr(exp->v.FormattedValue.format_spec, Load); + return 1; case Bytes_kind: { PyObject *b = exp->v.Bytes.s; if (!PyBytes_CheckExact(b)) { @@ -335,6 +397,17 @@ validate_stmt(stmt_ty stmt) case AugAssign_kind: return validate_expr(stmt->v.AugAssign.target, Store) && validate_expr(stmt->v.AugAssign.value, Load); + case AnnAssign_kind: + if (stmt->v.AnnAssign.target->kind != Name_kind && + stmt->v.AnnAssign.simple) { + PyErr_SetString(PyExc_TypeError, + "AnnAssign with simple non-Name target"); + return 0; + } + return validate_expr(stmt->v.AnnAssign.target, Store) && + (!stmt->v.AnnAssign.value || + validate_expr(stmt->v.AnnAssign.value, Load)) && + validate_expr(stmt->v.AnnAssign.annotation, Load); case For_kind: return validate_expr(stmt->v.For.target, Store) && validate_expr(stmt->v.For.iter, Load) && @@ -413,8 +486,8 @@ validate_stmt(stmt_ty stmt) case Import_kind: return validate_nonempty_seq(stmt->v.Import.names, "names", "Import"); case ImportFrom_kind: - if (stmt->v.ImportFrom.level < -1) { - PyErr_SetString(PyExc_ValueError, "ImportFrom level less than -1"); + if (stmt->v.ImportFrom.level < 0) { + PyErr_SetString(PyExc_ValueError, "Negative ImportFrom level"); return 0; } return validate_nonempty_seq(stmt->v.ImportFrom.names, "names", "ImportFrom"); @@ -512,8 +585,7 @@ PyAST_Validate(mod_ty mod) /* Data structure used internally */ struct compiling { - char *c_encoding; /* source encoding */ - PyArena *c_arena; /* arena for allocating memeory */ + PyArena *c_arena; /* Arena for allocating memory. */ PyObject *c_filename; /* filename */ PyObject *c_normalize; /* Normalization function from unicodedata. */ PyObject *c_normalize_args; /* Normalization argument tuple. */ @@ -535,9 +607,7 @@ static stmt_ty ast_for_for_stmt(struct compiling *, const node *, int); static expr_ty ast_for_call(struct compiling *, const node *, expr_ty); static PyObject *parsenumber(struct compiling *, const char *); -static PyObject *parsestr(struct compiling *, const node *n, int *bytesmode); -static PyObject *parsestrplus(struct compiling *, const node *n, - int *bytesmode); +static expr_ty parsestrplus(struct compiling *, const node *n); #define COMP_GENEXP 0 #define COMP_LISTCOMP 1 @@ -701,23 +771,11 @@ PyAST_FromNodeObject(const node *n, PyCompilerFlags *flags, c.c_arena = arena; /* borrowed reference */ c.c_filename = filename; - c.c_normalize = c.c_normalize_args = NULL; - if (flags && flags->cf_flags & PyCF_SOURCE_IS_UTF8) { - c.c_encoding = "utf-8"; - if (TYPE(n) == encoding_decl) { -#if 0 - ast_error(c, n, "encoding declaration in Unicode string"); - goto out; -#endif - n = CHILD(n, 0); - } - } else if (TYPE(n) == encoding_decl) { - c.c_encoding = STR(n); + c.c_normalize = NULL; + c.c_normalize_args = NULL; + + if (TYPE(n) == encoding_decl) n = CHILD(n, 0); - } else { - /* PEP 3120 */ - c.c_encoding = "utf-8"; - } k = 0; switch (TYPE(n)) { @@ -864,7 +922,7 @@ get_operator(const node *n) } } -static const char* FORBIDDEN[] = { +static const char * const FORBIDDEN[] = { "None", "True", "False", @@ -880,8 +938,30 @@ forbidden_name(struct compiling *c, identifier name, const node *n, ast_error(c, n, "assignment to keyword"); return 1; } + if (_PyUnicode_EqualToASCIIString(name, "async") || + _PyUnicode_EqualToASCIIString(name, "await")) + { + PyObject *message = PyUnicode_FromString( + "'async' and 'await' will become reserved keywords" + " in Python 3.7"); + int ret; + if (message == NULL) { + return 1; + } + ret = PyErr_WarnExplicitObject( + PyExc_DeprecationWarning, + message, + c->c_filename, + LINENO(n), + NULL, + NULL); + Py_DECREF(message); + if (ret < 0) { + return 1; + } + } if (full_checks) { - const char **p; + const char * const *p; for (p = FORBIDDEN; *p; p++) { if (_PyUnicode_EqualToASCIIString(name, *p)) { ast_error(c, n, "assignment to keyword"); @@ -943,13 +1023,8 @@ set_context(struct compiling *c, expr_ty e, expr_context_ty ctx, const node *n) s = e->v.List.elts; break; case Tuple_kind: - if (asdl_seq_LEN(e->v.Tuple.elts)) { - e->v.Tuple.ctx = ctx; - s = e->v.Tuple.elts; - } - else { - expr_name = "()"; - } + e->v.Tuple.ctx = ctx; + s = e->v.Tuple.elts; break; case Lambda_kind: expr_name = "lambda"; @@ -986,6 +1061,8 @@ set_context(struct compiling *c, expr_ty e, expr_context_ty ctx, const node *n) case Num_kind: case Str_kind: case Bytes_kind: + case JoinedStr_kind: + case FormattedValue_kind: expr_name = "literal"; break; case NameConstant_kind: @@ -1259,16 +1336,20 @@ ast_for_arguments(struct compiling *c, const node *n) and varargslist (lambda definition). parameters: '(' [typedargslist] ')' - typedargslist: ((tfpdef ['=' test] ',')* - ('*' [tfpdef] (',' tfpdef ['=' test])* [',' '**' tfpdef] - | '**' tfpdef) - | tfpdef ['=' test] (',' tfpdef ['=' test])* [',']) + typedargslist: (tfpdef ['=' test] (',' tfpdef ['=' test])* [',' [ + '*' [tfpdef] (',' tfpdef ['=' test])* [',' ['**' tfpdef [',']]] + | '**' tfpdef [',']]] + | '*' [tfpdef] (',' tfpdef ['=' test])* [',' ['**' tfpdef [',']]] + | '**' tfpdef [',']) tfpdef: NAME [':' test] - varargslist: ((vfpdef ['=' test] ',')* - ('*' [vfpdef] (',' vfpdef ['=' test])* [',' '**' vfpdef] - | '**' vfpdef) - | vfpdef ['=' test] (',' vfpdef ['=' test])* [',']) + varargslist: (vfpdef ['=' test] (',' vfpdef ['=' test])* [',' [ + '*' [vfpdef] (',' vfpdef ['=' test])* [',' ['**' vfpdef [',']]] + | '**' vfpdef [',']]] + | '*' [vfpdef] (',' vfpdef ['=' test])* [',' ['**' vfpdef [',']]] + | '**' vfpdef [','] + ) vfpdef: NAME + */ int i, j, k, nposargs = 0, nkwonlyargs = 0; int nposdefaults = 0, found_default = 0; @@ -1370,7 +1451,8 @@ ast_for_arguments(struct compiling *c, const node *n) i += 2; /* the name and the comma */ break; case STAR: - if (i+1 >= NCH(n)) { + if (i+1 >= NCH(n) || + (i+2 == NCH(n) && TYPE(CHILD(n, i+1)) == COMMA)) { ast_error(c, CHILD(n, i), "named arguments must follow bare *"); return NULL; @@ -1687,14 +1769,21 @@ static int count_comp_fors(struct compiling *c, const node *n) { int n_fors = 0; + int is_async; count_comp_for: + is_async = 0; n_fors++; REQ(n, comp_for); - if (NCH(n) == 5) - n = CHILD(n, 4); - else + if (TYPE(CHILD(n, 0)) == ASYNC) { + is_async = 1; + } + if (NCH(n) == (5 + is_async)) { + n = CHILD(n, 4 + is_async); + } + else { return n_fors; + } count_comp_iter: REQ(n, comp_iter); n = CHILD(n, 0); @@ -1757,14 +1846,19 @@ ast_for_comprehension(struct compiling *c, const node *n) asdl_seq *t; expr_ty expression, first; node *for_ch; + int is_async = 0; REQ(n, comp_for); - for_ch = CHILD(n, 1); + if (TYPE(CHILD(n, 0)) == ASYNC) { + is_async = 1; + } + + for_ch = CHILD(n, 1 + is_async); t = ast_for_exprlist(c, for_ch, Store); if (!t) return NULL; - expression = ast_for_expr(c, CHILD(n, 3)); + expression = ast_for_expr(c, CHILD(n, 3 + is_async)); if (!expression) return NULL; @@ -1772,19 +1866,20 @@ ast_for_comprehension(struct compiling *c, const node *n) (x for x, in ...) has 1 element in t, but still requires a Tuple. */ first = (expr_ty)asdl_seq_GET(t, 0); if (NCH(for_ch) == 1) - comp = comprehension(first, expression, NULL, c->c_arena); + comp = comprehension(first, expression, NULL, + is_async, c->c_arena); else - comp = comprehension(Tuple(t, Store, first->lineno, first->col_offset, - c->c_arena), - expression, NULL, c->c_arena); + comp = comprehension(Tuple(t, Store, first->lineno, + first->col_offset, c->c_arena), + expression, NULL, is_async, c->c_arena); if (!comp) return NULL; - if (NCH(n) == 5) { + if (NCH(n) == (5 + is_async)) { int j, n_ifs; asdl_seq *ifs; - n = CHILD(n, 4); + n = CHILD(n, 4 + is_async); n_ifs = count_comp_ifs(c, n); if (n_ifs == -1) return NULL; @@ -1993,7 +2088,6 @@ ast_for_atom(struct compiling *c, const node *n) | '...' | 'None' | 'True' | 'False' */ node *ch = CHILD(n, 0); - int bytesmode = 0; switch (TYPE(ch)) { case NAME: { @@ -2015,7 +2109,7 @@ ast_for_atom(struct compiling *c, const node *n) return Name(name, Load, LINENO(n), n->n_col_offset, c->c_arena); } case STRING: { - PyObject *str = parsestrplus(c, n, &bytesmode); + expr_ty str = parsestrplus(c, n); if (!str) { const char *errtype = NULL; if (PyErr_ExceptionMatches(PyExc_UnicodeError)) @@ -2033,6 +2127,7 @@ ast_for_atom(struct compiling *c, const node *n) if (s) { PyOS_snprintf(buf, sizeof(buf), "(%s) %s", errtype, s); } else { + PyErr_Clear(); PyOS_snprintf(buf, sizeof(buf), "(%s) unknown error", errtype); } Py_XDECREF(errstr); @@ -2043,14 +2138,7 @@ ast_for_atom(struct compiling *c, const node *n) } return NULL; } - if (PyArena_AddPyObject(c->c_arena, str) < 0) { - Py_DECREF(str); - return NULL; - } - if (bytesmode) - return Bytes(str, LINENO(n), n->n_col_offset, c->c_arena); - else - return Str(str, LINENO(n), n->n_col_offset, c->c_arena); + return str; } case NUMBER: { PyObject *pynum = parsenumber(c, STR(ch)); @@ -2807,8 +2895,9 @@ static stmt_ty ast_for_expr_stmt(struct compiling *c, const node *n) { REQ(n, expr_stmt); - /* expr_stmt: testlist_star_expr (augassign (yield_expr|testlist) - | ('=' (yield_expr|testlist))*) + /* expr_stmt: testlist_star_expr (annassign | augassign (yield_expr|testlist) | + ('=' (yield_expr|testlist_star_expr))*) + annassign: ':' test ['=' test] testlist_star_expr: (test|star_expr) (',' test|star_expr)* [','] augassign: '+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^=' | '<<=' | '>>=' | '**=' | '//=' @@ -2860,6 +2949,76 @@ ast_for_expr_stmt(struct compiling *c, const node *n) return AugAssign(expr1, newoperator, expr2, LINENO(n), n->n_col_offset, c->c_arena); } + else if (TYPE(CHILD(n, 1)) == annassign) { + expr_ty expr1, expr2, expr3; + node *ch = CHILD(n, 0); + node *deep, *ann = CHILD(n, 1); + int simple = 1; + + /* we keep track of parens to qualify (x) as expression not name */ + deep = ch; + while (NCH(deep) == 1) { + deep = CHILD(deep, 0); + } + if (NCH(deep) > 0 && TYPE(CHILD(deep, 0)) == LPAR) { + simple = 0; + } + expr1 = ast_for_testlist(c, ch); + if (!expr1) { + return NULL; + } + switch (expr1->kind) { + case Name_kind: + if (forbidden_name(c, expr1->v.Name.id, n, 0)) { + return NULL; + } + expr1->v.Name.ctx = Store; + break; + case Attribute_kind: + if (forbidden_name(c, expr1->v.Attribute.attr, n, 1)) { + return NULL; + } + expr1->v.Attribute.ctx = Store; + break; + case Subscript_kind: + expr1->v.Subscript.ctx = Store; + break; + case List_kind: + ast_error(c, ch, + "only single target (not list) can be annotated"); + return NULL; + case Tuple_kind: + ast_error(c, ch, + "only single target (not tuple) can be annotated"); + return NULL; + default: + ast_error(c, ch, + "illegal target for annotation"); + return NULL; + } + + if (expr1->kind != Name_kind) { + simple = 0; + } + ch = CHILD(ann, 1); + expr2 = ast_for_expr(c, ch); + if (!expr2) { + return NULL; + } + if (NCH(ann) == 2) { + return AnnAssign(expr1, expr2, NULL, simple, + LINENO(n), n->n_col_offset, c->c_arena); + } + else { + ch = CHILD(ann, 3); + expr3 = ast_for_expr(c, ch); + if (!expr3) { + return NULL; + } + return AnnAssign(expr1, expr2, expr3, simple, + LINENO(n), n->n_col_offset, c->c_arena); + } + } else { int i; asdl_seq *targets; @@ -2994,9 +3153,6 @@ ast_for_flow_stmt(struct compiling *c, const node *n) "unexpected flow_stmt: %d", TYPE(ch)); return NULL; } - - PyErr_SetString(PyExc_SystemError, "unhandled flow statement"); - return NULL; } static alias_ty @@ -3210,14 +3366,14 @@ ast_for_import_stmt(struct compiling *c, const node *n) alias_ty import_alias = alias_for_import_name(c, n, 1); if (!import_alias) return NULL; - asdl_seq_SET(aliases, 0, import_alias); + asdl_seq_SET(aliases, 0, import_alias); } else { for (i = 0; i < NCH(n); i += 2) { alias_ty import_alias = alias_for_import_name(c, CHILD(n, i), 1); if (!import_alias) return NULL; - asdl_seq_SET(aliases, i / 2, import_alias); + asdl_seq_SET(aliases, i / 2, import_alias); } } if (mod != NULL) @@ -3883,7 +4039,7 @@ ast_for_stmt(struct compiling *c, const node *n) } static PyObject * -parsenumber(struct compiling *c, const char *s) +parsenumber_raw(struct compiling *c, const char *s) { const char *end; long x; @@ -3926,6 +4082,31 @@ parsenumber(struct compiling *c, const char *s) } static PyObject * +parsenumber(struct compiling *c, const char *s) +{ + char *dup, *end; + PyObject *res = NULL; + + assert(s != NULL); + + if (strchr(s, '_') == NULL) { + return parsenumber_raw(c, s); + } + /* Create a duplicate without underscores. */ + dup = PyMem_Malloc(strlen(s) + 1); + end = dup; + for (; *s; s++) { + if (*s != '_') { + *end++ = *s; + } + } + *end = '\0'; + res = parsenumber_raw(c, dup); + PyMem_Free(dup); + return res; +} + +static PyObject * decode_utf8(struct compiling *c, const char **sPtr, const char *end) { const char *s, *t; @@ -3936,85 +4117,924 @@ decode_utf8(struct compiling *c, const char **sPtr, const char *end) return PyUnicode_DecodeUTF8(t, s - t, NULL); } +static int +warn_invalid_escape_sequence(struct compiling *c, const node *n, + char first_invalid_escape_char) +{ + PyObject *msg = PyUnicode_FromFormat("invalid escape sequence \\%c", + first_invalid_escape_char); + if (msg == NULL) { + return -1; + } + if (PyErr_WarnExplicitObject(PyExc_DeprecationWarning, msg, + c->c_filename, LINENO(n), + NULL, NULL) < 0 && + PyErr_ExceptionMatches(PyExc_DeprecationWarning)) + { + const char *s; + + /* Replace the DeprecationWarning exception with a SyntaxError + to get a more accurate error report */ + PyErr_Clear(); + + s = PyUnicode_AsUTF8(msg); + if (s != NULL) { + ast_error(c, n, s); + } + Py_DECREF(msg); + return -1; + } + Py_DECREF(msg); + return 0; +} + static PyObject * -decode_unicode(struct compiling *c, const char *s, size_t len, int rawmode, const char *encoding) +decode_unicode_with_escapes(struct compiling *c, const node *n, const char *s, + size_t len) { PyObject *v, *u; char *buf; char *p; const char *end; - if (encoding == NULL) { - u = NULL; - } else { - /* check for integer overflow */ - if (len > PY_SIZE_MAX / 6) + /* check for integer overflow */ + if (len > SIZE_MAX / 6) + return NULL; + /* "ä" (2 bytes) may become "\U000000E4" (10 bytes), or 1:5 + "\ä" (3 bytes) may become "\u005c\U000000E4" (16 bytes), or ~1:6 */ + u = PyBytes_FromStringAndSize((char *)NULL, len * 6); + if (u == NULL) + return NULL; + p = buf = PyBytes_AsString(u); + end = s + len; + while (s < end) { + if (*s == '\\') { + *p++ = *s++; + if (*s & 0x80) { + strcpy(p, "u005c"); + p += 5; + } + } + if (*s & 0x80) { /* XXX inefficient */ + PyObject *w; + int kind; + void *data; + Py_ssize_t len, i; + w = decode_utf8(c, &s, end); + if (w == NULL) { + Py_DECREF(u); + return NULL; + } + kind = PyUnicode_KIND(w); + data = PyUnicode_DATA(w); + len = PyUnicode_GET_LENGTH(w); + for (i = 0; i < len; i++) { + Py_UCS4 chr = PyUnicode_READ(kind, data, i); + sprintf(p, "\\U%08x", chr); + p += 10; + } + /* Should be impossible to overflow */ + assert(p - buf <= Py_SIZE(u)); + Py_DECREF(w); + } else { + *p++ = *s++; + } + } + len = p - buf; + s = buf; + + const char *first_invalid_escape; + v = _PyUnicode_DecodeUnicodeEscape(s, len, NULL, &first_invalid_escape); + + if (v != NULL && first_invalid_escape != NULL) { + if (warn_invalid_escape_sequence(c, n, *first_invalid_escape) < 0) { + /* We have not decref u before because first_invalid_escape points + inside u. */ + Py_XDECREF(u); + Py_DECREF(v); return NULL; - /* "ä" (2 bytes) may become "\U000000E4" (10 bytes), or 1:5 - "\ä" (3 bytes) may become "\u005c\U000000E4" (16 bytes), or ~1:6 */ - u = PyBytes_FromStringAndSize((char *)NULL, len * 6); - if (u == NULL) + } + } + Py_XDECREF(u); + return v; +} + +static PyObject * +decode_bytes_with_escapes(struct compiling *c, const node *n, const char *s, + size_t len) +{ + const char *first_invalid_escape; + PyObject *result = _PyBytes_DecodeEscape(s, len, NULL, 0, NULL, + &first_invalid_escape); + if (result == NULL) + return NULL; + + if (first_invalid_escape != NULL) { + if (warn_invalid_escape_sequence(c, n, *first_invalid_escape) < 0) { + Py_DECREF(result); return NULL; - p = buf = PyBytes_AsString(u); - end = s + len; - while (s < end) { - if (*s == '\\') { - *p++ = *s++; - if (*s & 0x80) { - strcpy(p, "u005c"); - p += 5; + } + } + return result; +} + +/* Compile this expression in to an expr_ty. Add parens around the + expression, in order to allow leading spaces in the expression. */ +static expr_ty +fstring_compile_expr(const char *expr_start, const char *expr_end, + struct compiling *c, const node *n) + +{ + int all_whitespace = 1; + int kind; + void *data; + PyCompilerFlags cf; + mod_ty mod; + char *str; + PyObject *o; + Py_ssize_t len; + Py_ssize_t i; + + assert(expr_end >= expr_start); + assert(*(expr_start-1) == '{'); + assert(*expr_end == '}' || *expr_end == '!' || *expr_end == ':'); + + /* We know there are no escapes here, because backslashes are not allowed, + and we know it's utf-8 encoded (per PEP 263). But, in order to check + that each char is not whitespace, we need to decode it to unicode. + Which is unfortunate, but such is life. */ + + /* If the substring is all whitespace, it's an error. We need to catch + this here, and not when we call PyParser_ASTFromString, because turning + the expression '' in to '()' would go from being invalid to valid. */ + /* Note that this code says an empty string is all whitespace. That's + important. There's a test for it: f'{}'. */ + o = PyUnicode_DecodeUTF8(expr_start, expr_end-expr_start, NULL); + if (o == NULL) + return NULL; + len = PyUnicode_GET_LENGTH(o); + kind = PyUnicode_KIND(o); + data = PyUnicode_DATA(o); + for (i = 0; i < len; i++) { + if (!Py_UNICODE_ISSPACE(PyUnicode_READ(kind, data, i))) { + all_whitespace = 0; + break; + } + } + Py_DECREF(o); + if (all_whitespace) { + ast_error(c, n, "f-string: empty expression not allowed"); + return NULL; + } + + /* Reuse len to be the length of the utf-8 input string. */ + len = expr_end - expr_start; + /* Allocate 3 extra bytes: open paren, close paren, null byte. */ + str = PyMem_RawMalloc(len + 3); + if (str == NULL) + return NULL; + + str[0] = '('; + memcpy(str+1, expr_start, len); + str[len+1] = ')'; + str[len+2] = 0; + + cf.cf_flags = PyCF_ONLY_AST; + mod = PyParser_ASTFromString(str, "<fstring>", + Py_eval_input, &cf, c->c_arena); + PyMem_RawFree(str); + if (!mod) + return NULL; + return mod->v.Expression.body; +} + +/* Return -1 on error. + + Return 0 if we reached the end of the literal. + + Return 1 if we haven't reached the end of the literal, but we want + the caller to process the literal up to this point. Used for + doubled braces. +*/ +static int +fstring_find_literal(const char **str, const char *end, int raw, + PyObject **literal, int recurse_lvl, + struct compiling *c, const node *n) +{ + /* Get any literal string. It ends when we hit an un-doubled left + brace (which isn't part of a unicode name escape such as + "\N{EULER CONSTANT}"), or the end of the string. */ + + const char *literal_start = *str; + const char *literal_end; + int in_named_escape = 0; + int result = 0; + + assert(*literal == NULL); + for (; *str < end; (*str)++) { + char ch = **str; + if (!in_named_escape && ch == '{' && (*str)-literal_start >= 2 && + *(*str-2) == '\\' && *(*str-1) == 'N') { + in_named_escape = 1; + } else if (in_named_escape && ch == '}') { + in_named_escape = 0; + } else if (ch == '{' || ch == '}') { + /* Check for doubled braces, but only at the top level. If + we checked at every level, then f'{0:{3}}' would fail + with the two closing braces. */ + if (recurse_lvl == 0) { + if (*str+1 < end && *(*str+1) == ch) { + /* We're going to tell the caller that the literal ends + here, but that they should continue scanning. But also + skip over the second brace when we resume scanning. */ + literal_end = *str+1; + *str += 2; + result = 1; + goto done; } - } - if (*s & 0x80) { /* XXX inefficient */ - PyObject *w; - int kind; - void *data; - Py_ssize_t len, i; - w = decode_utf8(c, &s, end); - if (w == NULL) { - Py_DECREF(u); - return NULL; + + /* Where a single '{' is the start of a new expression, a + single '}' is not allowed. */ + if (ch == '}') { + ast_error(c, n, "f-string: single '}' is not allowed"); + return -1; } - kind = PyUnicode_KIND(w); - data = PyUnicode_DATA(w); - len = PyUnicode_GET_LENGTH(w); - for (i = 0; i < len; i++) { - Py_UCS4 chr = PyUnicode_READ(kind, data, i); - sprintf(p, "\\U%08x", chr); - p += 10; + } + /* We're either at a '{', which means we're starting another + expression; or a '}', which means we're at the end of this + f-string (for a nested format_spec). */ + break; + } + } + literal_end = *str; + assert(*str <= end); + assert(*str == end || **str == '{' || **str == '}'); +done: + if (literal_start != literal_end) { + if (raw) + *literal = PyUnicode_DecodeUTF8Stateful(literal_start, + literal_end-literal_start, + NULL, NULL); + else + *literal = decode_unicode_with_escapes(c, n, literal_start, + literal_end-literal_start); + if (!*literal) + return -1; + } + return result; +} + +/* Forward declaration because parsing is recursive. */ +static expr_ty +fstring_parse(const char **str, const char *end, int raw, int recurse_lvl, + struct compiling *c, const node *n); + +/* Parse the f-string at *str, ending at end. We know *str starts an + expression (so it must be a '{'). Returns the FormattedValue node, + which includes the expression, conversion character, and + format_spec expression. + + Note that I don't do a perfect job here: I don't make sure that a + closing brace doesn't match an opening paren, for example. It + doesn't need to error on all invalid expressions, just correctly + find the end of all valid ones. Any errors inside the expression + will be caught when we parse it later. */ +static int +fstring_find_expr(const char **str, const char *end, int raw, int recurse_lvl, + expr_ty *expression, struct compiling *c, const node *n) +{ + /* Return -1 on error, else 0. */ + + const char *expr_start; + const char *expr_end; + expr_ty simple_expression; + expr_ty format_spec = NULL; /* Optional format specifier. */ + int conversion = -1; /* The conversion char. -1 if not specified. */ + + /* 0 if we're not in a string, else the quote char we're trying to + match (single or double quote). */ + char quote_char = 0; + + /* If we're inside a string, 1=normal, 3=triple-quoted. */ + int string_type = 0; + + /* Keep track of nesting level for braces/parens/brackets in + expressions. */ + Py_ssize_t nested_depth = 0; + + /* Can only nest one level deep. */ + if (recurse_lvl >= 2) { + ast_error(c, n, "f-string: expressions nested too deeply"); + return -1; + } + + /* The first char must be a left brace, or we wouldn't have gotten + here. Skip over it. */ + assert(**str == '{'); + *str += 1; + + expr_start = *str; + for (; *str < end; (*str)++) { + char ch; + + /* Loop invariants. */ + assert(nested_depth >= 0); + assert(*str >= expr_start && *str < end); + if (quote_char) + assert(string_type == 1 || string_type == 3); + else + assert(string_type == 0); + + ch = **str; + /* Nowhere inside an expression is a backslash allowed. */ + if (ch == '\\') { + /* Error: can't include a backslash character, inside + parens or strings or not. */ + ast_error(c, n, "f-string expression part " + "cannot include a backslash"); + return -1; + } + if (quote_char) { + /* We're inside a string. See if we're at the end. */ + /* This code needs to implement the same non-error logic + as tok_get from tokenizer.c, at the letter_quote + label. To actually share that code would be a + nightmare. But, it's unlikely to change and is small, + so duplicate it here. Note we don't need to catch all + of the errors, since they'll be caught when parsing the + expression. We just need to match the non-error + cases. Thus we can ignore \n in single-quoted strings, + for example. Or non-terminated strings. */ + if (ch == quote_char) { + /* Does this match the string_type (single or triple + quoted)? */ + if (string_type == 3) { + if (*str+2 < end && *(*str+1) == ch && *(*str+2) == ch) { + /* We're at the end of a triple quoted string. */ + *str += 2; + string_type = 0; + quote_char = 0; + continue; + } + } else { + /* We're at the end of a normal string. */ + quote_char = 0; + string_type = 0; + continue; } - /* Should be impossible to overflow */ - assert(p - buf <= Py_SIZE(u)); - Py_DECREF(w); + } + } else if (ch == '\'' || ch == '"') { + /* Is this a triple quoted string? */ + if (*str+2 < end && *(*str+1) == ch && *(*str+2) == ch) { + string_type = 3; + *str += 2; } else { - *p++ = *s++; + /* Start of a normal string. */ + string_type = 1; } + /* Start looking for the end of the string. */ + quote_char = ch; + } else if (ch == '[' || ch == '{' || ch == '(') { + nested_depth++; + } else if (nested_depth != 0 && + (ch == ']' || ch == '}' || ch == ')')) { + nested_depth--; + } else if (ch == '#') { + /* Error: can't include a comment character, inside parens + or not. */ + ast_error(c, n, "f-string expression part cannot include '#'"); + return -1; + } else if (nested_depth == 0 && + (ch == '!' || ch == ':' || ch == '}')) { + /* First, test for the special case of "!=". Since '=' is + not an allowed conversion character, nothing is lost in + this test. */ + if (ch == '!' && *str+1 < end && *(*str+1) == '=') { + /* This isn't a conversion character, just continue. */ + continue; + } + /* Normal way out of this loop. */ + break; + } else { + /* Just consume this char and loop around. */ } - len = p - buf; - s = buf; } - if (rawmode) - v = PyUnicode_DecodeRawUnicodeEscape(s, len, NULL); - else - v = PyUnicode_DecodeUnicodeEscape(s, len, NULL); - Py_XDECREF(u); - return v; + expr_end = *str; + /* If we leave this loop in a string or with mismatched parens, we + don't care. We'll get a syntax error when compiling the + expression. But, we can produce a better error message, so + let's just do that.*/ + if (quote_char) { + ast_error(c, n, "f-string: unterminated string"); + return -1; + } + if (nested_depth) { + ast_error(c, n, "f-string: mismatched '(', '{', or '['"); + return -1; + } + + if (*str >= end) + goto unexpected_end_of_string; + + /* Compile the expression as soon as possible, so we show errors + related to the expression before errors related to the + conversion or format_spec. */ + simple_expression = fstring_compile_expr(expr_start, expr_end, c, n); + if (!simple_expression) + return -1; + + /* Check for a conversion char, if present. */ + if (**str == '!') { + *str += 1; + if (*str >= end) + goto unexpected_end_of_string; + + conversion = **str; + *str += 1; + + /* Validate the conversion. */ + if (!(conversion == 's' || conversion == 'r' + || conversion == 'a')) { + ast_error(c, n, "f-string: invalid conversion character: " + "expected 's', 'r', or 'a'"); + return -1; + } + } + + /* Check for the format spec, if present. */ + if (*str >= end) + goto unexpected_end_of_string; + if (**str == ':') { + *str += 1; + if (*str >= end) + goto unexpected_end_of_string; + + /* Parse the format spec. */ + format_spec = fstring_parse(str, end, raw, recurse_lvl+1, c, n); + if (!format_spec) + return -1; + } + + if (*str >= end || **str != '}') + goto unexpected_end_of_string; + + /* We're at a right brace. Consume it. */ + assert(*str < end); + assert(**str == '}'); + *str += 1; + + /* And now create the FormattedValue node that represents this + entire expression with the conversion and format spec. */ + *expression = FormattedValue(simple_expression, conversion, + format_spec, LINENO(n), n->n_col_offset, + c->c_arena); + if (!*expression) + return -1; + + return 0; + +unexpected_end_of_string: + ast_error(c, n, "f-string: expecting '}'"); + return -1; } -/* s is a Python string literal, including the bracketing quote characters, - * and r &/or b prefixes (if any), and embedded escape sequences (if any). - * parsestr parses it, and returns the decoded Python string object. - */ -static PyObject * -parsestr(struct compiling *c, const node *n, int *bytesmode) +/* Return -1 on error. + + Return 0 if we have a literal (possible zero length) and an + expression (zero length if at the end of the string. + + Return 1 if we have a literal, but no expression, and we want the + caller to call us again. This is used to deal with doubled + braces. + + When called multiple times on the string 'a{{b{0}c', this function + will return: + + 1. the literal 'a{' with no expression, and a return value + of 1. Despite the fact that there's no expression, the return + value of 1 means we're not finished yet. + + 2. the literal 'b' and the expression '0', with a return value of + 0. The fact that there's an expression means we're not finished. + + 3. literal 'c' with no expression and a return value of 0. The + combination of the return value of 0 with no expression means + we're finished. +*/ +static int +fstring_find_literal_and_expr(const char **str, const char *end, int raw, + int recurse_lvl, PyObject **literal, + expr_ty *expression, + struct compiling *c, const node *n) +{ + int result; + + assert(*literal == NULL && *expression == NULL); + + /* Get any literal string. */ + result = fstring_find_literal(str, end, raw, literal, recurse_lvl, c, n); + if (result < 0) + goto error; + + assert(result == 0 || result == 1); + + if (result == 1) + /* We have a literal, but don't look at the expression. */ + return 1; + + if (*str >= end || **str == '}') + /* We're at the end of the string or the end of a nested + f-string: no expression. The top-level error case where we + expect to be at the end of the string but we're at a '}' is + handled later. */ + return 0; + + /* We must now be the start of an expression, on a '{'. */ + assert(**str == '{'); + + if (fstring_find_expr(str, end, raw, recurse_lvl, expression, c, n) < 0) + goto error; + + return 0; + +error: + Py_CLEAR(*literal); + return -1; +} + +#define EXPRLIST_N_CACHED 64 + +typedef struct { + /* Incrementally build an array of expr_ty, so be used in an + asdl_seq. Cache some small but reasonably sized number of + expr_ty's, and then after that start dynamically allocating, + doubling the number allocated each time. Note that the f-string + f'{0}a{1}' contains 3 expr_ty's: 2 FormattedValue's, and one + Str for the literal 'a'. So you add expr_ty's about twice as + fast as you add exressions in an f-string. */ + + Py_ssize_t allocated; /* Number we've allocated. */ + Py_ssize_t size; /* Number we've used. */ + expr_ty *p; /* Pointer to the memory we're actually + using. Will point to 'data' until we + start dynamically allocating. */ + expr_ty data[EXPRLIST_N_CACHED]; +} ExprList; + +#ifdef NDEBUG +#define ExprList_check_invariants(l) +#else +static void +ExprList_check_invariants(ExprList *l) +{ + /* Check our invariants. Make sure this object is "live", and + hasn't been deallocated. */ + assert(l->size >= 0); + assert(l->p != NULL); + if (l->size <= EXPRLIST_N_CACHED) + assert(l->data == l->p); +} +#endif + +static void +ExprList_Init(ExprList *l) +{ + l->allocated = EXPRLIST_N_CACHED; + l->size = 0; + + /* Until we start allocating dynamically, p points to data. */ + l->p = l->data; + + ExprList_check_invariants(l); +} + +static int +ExprList_Append(ExprList *l, expr_ty exp) +{ + ExprList_check_invariants(l); + if (l->size >= l->allocated) { + /* We need to alloc (or realloc) the memory. */ + Py_ssize_t new_size = l->allocated * 2; + + /* See if we've ever allocated anything dynamically. */ + if (l->p == l->data) { + Py_ssize_t i; + /* We're still using the cached data. Switch to + alloc-ing. */ + l->p = PyMem_RawMalloc(sizeof(expr_ty) * new_size); + if (!l->p) + return -1; + /* Copy the cached data into the new buffer. */ + for (i = 0; i < l->size; i++) + l->p[i] = l->data[i]; + } else { + /* Just realloc. */ + expr_ty *tmp = PyMem_RawRealloc(l->p, sizeof(expr_ty) * new_size); + if (!tmp) { + PyMem_RawFree(l->p); + l->p = NULL; + return -1; + } + l->p = tmp; + } + + l->allocated = new_size; + assert(l->allocated == 2 * l->size); + } + + l->p[l->size++] = exp; + + ExprList_check_invariants(l); + return 0; +} + +static void +ExprList_Dealloc(ExprList *l) +{ + ExprList_check_invariants(l); + + /* If there's been an error, or we've never dynamically allocated, + do nothing. */ + if (!l->p || l->p == l->data) { + /* Do nothing. */ + } else { + /* We have dynamically allocated. Free the memory. */ + PyMem_RawFree(l->p); + } + l->p = NULL; + l->size = -1; +} + +static asdl_seq * +ExprList_Finish(ExprList *l, PyArena *arena) +{ + asdl_seq *seq; + + ExprList_check_invariants(l); + + /* Allocate the asdl_seq and copy the expressions in to it. */ + seq = _Py_asdl_seq_new(l->size, arena); + if (seq) { + Py_ssize_t i; + for (i = 0; i < l->size; i++) + asdl_seq_SET(seq, i, l->p[i]); + } + ExprList_Dealloc(l); + return seq; +} + +/* The FstringParser is designed to add a mix of strings and + f-strings, and concat them together as needed. Ultimately, it + generates an expr_ty. */ +typedef struct { + PyObject *last_str; + ExprList expr_list; + int fmode; +} FstringParser; + +#ifdef NDEBUG +#define FstringParser_check_invariants(state) +#else +static void +FstringParser_check_invariants(FstringParser *state) +{ + if (state->last_str) + assert(PyUnicode_CheckExact(state->last_str)); + ExprList_check_invariants(&state->expr_list); +} +#endif + +static void +FstringParser_Init(FstringParser *state) +{ + state->last_str = NULL; + state->fmode = 0; + ExprList_Init(&state->expr_list); + FstringParser_check_invariants(state); +} + +static void +FstringParser_Dealloc(FstringParser *state) +{ + FstringParser_check_invariants(state); + + Py_XDECREF(state->last_str); + ExprList_Dealloc(&state->expr_list); +} + +/* Make a Str node, but decref the PyUnicode object being added. */ +static expr_ty +make_str_node_and_del(PyObject **str, struct compiling *c, const node* n) +{ + PyObject *s = *str; + *str = NULL; + assert(PyUnicode_CheckExact(s)); + if (PyArena_AddPyObject(c->c_arena, s) < 0) { + Py_DECREF(s); + return NULL; + } + return Str(s, LINENO(n), n->n_col_offset, c->c_arena); +} + +/* Add a non-f-string (that is, a regular literal string). str is + decref'd. */ +static int +FstringParser_ConcatAndDel(FstringParser *state, PyObject *str) +{ + FstringParser_check_invariants(state); + + assert(PyUnicode_CheckExact(str)); + + if (PyUnicode_GET_LENGTH(str) == 0) { + Py_DECREF(str); + return 0; + } + + if (!state->last_str) { + /* We didn't have a string before, so just remember this one. */ + state->last_str = str; + } else { + /* Concatenate this with the previous string. */ + PyUnicode_AppendAndDel(&state->last_str, str); + if (!state->last_str) + return -1; + } + FstringParser_check_invariants(state); + return 0; +} + +/* Parse an f-string. The f-string is in *str to end, with no + 'f' or quotes. */ +static int +FstringParser_ConcatFstring(FstringParser *state, const char **str, + const char *end, int raw, int recurse_lvl, + struct compiling *c, const node *n) +{ + FstringParser_check_invariants(state); + state->fmode = 1; + + /* Parse the f-string. */ + while (1) { + PyObject *literal = NULL; + expr_ty expression = NULL; + + /* If there's a zero length literal in front of the + expression, literal will be NULL. If we're at the end of + the f-string, expression will be NULL (unless result == 1, + see below). */ + int result = fstring_find_literal_and_expr(str, end, raw, recurse_lvl, + &literal, &expression, + c, n); + if (result < 0) + return -1; + + /* Add the literal, if any. */ + if (!literal) { + /* Do nothing. Just leave last_str alone (and possibly + NULL). */ + } else if (!state->last_str) { + state->last_str = literal; + literal = NULL; + } else { + /* We have a literal, concatenate it. */ + assert(PyUnicode_GET_LENGTH(literal) != 0); + if (FstringParser_ConcatAndDel(state, literal) < 0) + return -1; + literal = NULL; + } + assert(!state->last_str || + PyUnicode_GET_LENGTH(state->last_str) != 0); + + /* We've dealt with the literal now. It can't be leaked on further + errors. */ + assert(literal == NULL); + + /* See if we should just loop around to get the next literal + and expression, while ignoring the expression this + time. This is used for un-doubling braces, as an + optimization. */ + if (result == 1) + continue; + + if (!expression) + /* We're done with this f-string. */ + break; + + /* We know we have an expression. Convert any existing string + to a Str node. */ + if (!state->last_str) { + /* Do nothing. No previous literal. */ + } else { + /* Convert the existing last_str literal to a Str node. */ + expr_ty str = make_str_node_and_del(&state->last_str, c, n); + if (!str || ExprList_Append(&state->expr_list, str) < 0) + return -1; + } + + if (ExprList_Append(&state->expr_list, expression) < 0) + return -1; + } + + /* If recurse_lvl is zero, then we must be at the end of the + string. Otherwise, we must be at a right brace. */ + + if (recurse_lvl == 0 && *str < end-1) { + ast_error(c, n, "f-string: unexpected end of string"); + return -1; + } + if (recurse_lvl != 0 && **str != '}') { + ast_error(c, n, "f-string: expecting '}'"); + return -1; + } + + FstringParser_check_invariants(state); + return 0; +} + +/* Convert the partial state reflected in last_str and expr_list to an + expr_ty. The expr_ty can be a Str, or a JoinedStr. */ +static expr_ty +FstringParser_Finish(FstringParser *state, struct compiling *c, + const node *n) +{ + asdl_seq *seq; + + FstringParser_check_invariants(state); + + /* If we're just a constant string with no expressions, return + that. */ + if (!state->fmode) { + assert(!state->expr_list.size); + if (!state->last_str) { + /* Create a zero length string. */ + state->last_str = PyUnicode_FromStringAndSize(NULL, 0); + if (!state->last_str) + goto error; + } + return make_str_node_and_del(&state->last_str, c, n); + } + + /* Create a Str node out of last_str, if needed. It will be the + last node in our expression list. */ + if (state->last_str) { + expr_ty str = make_str_node_and_del(&state->last_str, c, n); + if (!str || ExprList_Append(&state->expr_list, str) < 0) + goto error; + } + /* This has already been freed. */ + assert(state->last_str == NULL); + + seq = ExprList_Finish(&state->expr_list, c->c_arena); + if (!seq) + goto error; + + return JoinedStr(seq, LINENO(n), n->n_col_offset, c->c_arena); + +error: + FstringParser_Dealloc(state); + return NULL; +} + +/* Given an f-string (with no 'f' or quotes) that's in *str and ends + at end, parse it into an expr_ty. Return NULL on error. Adjust + str to point past the parsed portion. */ +static expr_ty +fstring_parse(const char **str, const char *end, int raw, int recurse_lvl, + struct compiling *c, const node *n) +{ + FstringParser state; + + FstringParser_Init(&state); + if (FstringParser_ConcatFstring(&state, str, end, raw, recurse_lvl, + c, n) < 0) { + FstringParser_Dealloc(&state); + return NULL; + } + + return FstringParser_Finish(&state, c, n); +} + +/* n is a Python string literal, including the bracketing quote + characters, and r, b, u, &/or f prefixes (if any), and embedded + escape sequences (if any). parsestr parses it, and sets *result to + decoded Python string object. If the string is an f-string, set + *fstr and *fstrlen to the unparsed string object. Return 0 if no + errors occurred. +*/ +static int +parsestr(struct compiling *c, const node *n, int *bytesmode, int *rawmode, + PyObject **result, const char **fstr, Py_ssize_t *fstrlen) { size_t len; const char *s = STR(n); int quote = Py_CHARMASK(*s); - int rawmode = 0; - int need_encoding; + int fmode = 0; + *bytesmode = 0; + *rawmode = 0; + *result = NULL; + *fstr = NULL; if (Py_ISALPHA(quote)) { - while (!*bytesmode || !rawmode) { + while (!*bytesmode || !*rawmode) { if (quote == 'b' || quote == 'B') { quote = *++s; *bytesmode = 1; @@ -4024,114 +5044,170 @@ parsestr(struct compiling *c, const node *n, int *bytesmode) } else if (quote == 'r' || quote == 'R') { quote = *++s; - rawmode = 1; + *rawmode = 1; + } + else if (quote == 'f' || quote == 'F') { + quote = *++s; + fmode = 1; } else { break; } } } + if (fmode && *bytesmode) { + PyErr_BadInternalCall(); + return -1; + } if (quote != '\'' && quote != '\"') { PyErr_BadInternalCall(); - return NULL; + return -1; } + /* Skip the leading quote char. */ s++; len = strlen(s); if (len > INT_MAX) { PyErr_SetString(PyExc_OverflowError, "string to parse is too long"); - return NULL; + return -1; } if (s[--len] != quote) { + /* Last quote char must match the first. */ PyErr_BadInternalCall(); - return NULL; + return -1; } if (len >= 4 && s[0] == quote && s[1] == quote) { + /* A triple quoted string. We've already skipped one quote at + the start and one at the end of the string. Now skip the + two at the start. */ s += 2; len -= 2; + /* And check that the last two match. */ if (s[--len] != quote || s[--len] != quote) { PyErr_BadInternalCall(); - return NULL; + return -1; } } - if (!*bytesmode && !rawmode) { - return decode_unicode(c, s, len, rawmode, c->c_encoding); + + if (fmode) { + /* Just return the bytes. The caller will parse the resulting + string. */ + *fstr = s; + *fstrlen = len; + return 0; } + + /* Not an f-string. */ + /* Avoid invoking escape decoding routines if possible. */ + *rawmode = *rawmode || strchr(s, '\\') == NULL; if (*bytesmode) { - /* Disallow non-ascii characters (but not escapes) */ + /* Disallow non-ASCII characters. */ const char *ch; for (ch = s; *ch; ch++) { if (Py_CHARMASK(*ch) >= 0x80) { ast_error(c, n, "bytes can only contain ASCII " "literal characters."); - return NULL; + return -1; } } + if (*rawmode) + *result = PyBytes_FromStringAndSize(s, len); + else + *result = decode_bytes_with_escapes(c, n, s, len); + } else { + if (*rawmode) + *result = PyUnicode_DecodeUTF8Stateful(s, len, NULL, NULL); + else + *result = decode_unicode_with_escapes(c, n, s, len); } - need_encoding = (!*bytesmode && c->c_encoding != NULL && - strcmp(c->c_encoding, "utf-8") != 0); - if (rawmode || strchr(s, '\\') == NULL) { - if (need_encoding) { - PyObject *v, *u = PyUnicode_DecodeUTF8(s, len, NULL); - if (u == NULL || !*bytesmode) - return u; - v = PyUnicode_AsEncodedString(u, c->c_encoding, NULL); - Py_DECREF(u); - return v; - } else if (*bytesmode) { - return PyBytes_FromStringAndSize(s, len); - } else if (strcmp(c->c_encoding, "utf-8") == 0) { - return PyUnicode_FromStringAndSize(s, len); - } else { - return PyUnicode_DecodeLatin1(s, len, NULL); - } - } - return PyBytes_DecodeEscape(s, len, NULL, 1, - need_encoding ? c->c_encoding : NULL); + return *result == NULL ? -1 : 0; } -/* Build a Python string object out of a STRING+ atom. This takes care of - * compile-time literal catenation, calling parsestr() on each piece, and - * pasting the intermediate results together. - */ -static PyObject * -parsestrplus(struct compiling *c, const node *n, int *bytesmode) +/* Accepts a STRING+ atom, and produces an expr_ty node. Run through + each STRING atom, and process it as needed. For bytes, just + concatenate them together, and the result will be a Bytes node. For + normal strings and f-strings, concatenate them together. The result + will be a Str node if there were no f-strings; a FormattedValue + node if there's just an f-string (with no leading or trailing + literals), or a JoinedStr node if there are multiple f-strings or + any literals involved. */ +static expr_ty +parsestrplus(struct compiling *c, const node *n) { - PyObject *v; + int bytesmode = 0; + PyObject *bytes_str = NULL; int i; - REQ(CHILD(n, 0), STRING); - v = parsestr(c, CHILD(n, 0), bytesmode); - if (v != NULL) { - /* String literal concatenation */ - for (i = 1; i < NCH(n); i++) { - PyObject *s; - int subbm = 0; - s = parsestr(c, CHILD(n, i), &subbm); - if (s == NULL) - goto onError; - if (*bytesmode != subbm) { - ast_error(c, n, "cannot mix bytes and nonbytes literals"); - Py_DECREF(s); - goto onError; - } - if (PyBytes_Check(v) && PyBytes_Check(s)) { - PyBytes_ConcatAndDel(&v, s); - if (v == NULL) - goto onError; - } - else { - PyObject *temp = PyUnicode_Concat(v, s); - Py_DECREF(s); - Py_DECREF(v); - v = temp; - if (v == NULL) - goto onError; + + FstringParser state; + FstringParser_Init(&state); + + for (i = 0; i < NCH(n); i++) { + int this_bytesmode; + int this_rawmode; + PyObject *s; + const char *fstr; + Py_ssize_t fstrlen = -1; /* Silence a compiler warning. */ + + REQ(CHILD(n, i), STRING); + if (parsestr(c, CHILD(n, i), &this_bytesmode, &this_rawmode, &s, + &fstr, &fstrlen) != 0) + goto error; + + /* Check that we're not mixing bytes with unicode. */ + if (i != 0 && bytesmode != this_bytesmode) { + ast_error(c, n, "cannot mix bytes and nonbytes literals"); + /* s is NULL if the current string part is an f-string. */ + Py_XDECREF(s); + goto error; + } + bytesmode = this_bytesmode; + + if (fstr != NULL) { + int result; + assert(s == NULL && !bytesmode); + /* This is an f-string. Parse and concatenate it. */ + result = FstringParser_ConcatFstring(&state, &fstr, fstr+fstrlen, + this_rawmode, 0, c, n); + if (result < 0) + goto error; + } else { + /* A string or byte string. */ + assert(s != NULL && fstr == NULL); + + assert(bytesmode ? PyBytes_CheckExact(s) : + PyUnicode_CheckExact(s)); + + if (bytesmode) { + /* For bytes, concat as we go. */ + if (i == 0) { + /* First time, just remember this value. */ + bytes_str = s; + } else { + PyBytes_ConcatAndDel(&bytes_str, s); + if (!bytes_str) + goto error; + } + } else { + /* This is a regular string. Concatenate it. */ + if (FstringParser_ConcatAndDel(&state, s) < 0) + goto error; } } } - return v; + if (bytesmode) { + /* Just return the bytes object and we're done. */ + if (PyArena_AddPyObject(c->c_arena, bytes_str) < 0) + goto error; + return Bytes(bytes_str, LINENO(n), n->n_col_offset, c->c_arena); + } + + /* We're not a bytes string, bytes_str should never have been set. */ + assert(bytes_str == NULL); + + return FstringParser_Finish(&state, c, n); - onError: - Py_XDECREF(v); +error: + Py_XDECREF(bytes_str); + FstringParser_Dealloc(&state); return NULL; } |