author     Guido van Rossum <guido@python.org>  2022-11-16 03:59:19 (GMT)
committer  GitHub <noreply@github.com>          2022-11-16 03:59:19 (GMT)
commit     e37744f289af00c6f6eba83f7abfb932b63de9e0 (patch)
tree       442f19bd21cf3c6f5abcf733282baa710625ee7c
parent     4636df9febc6b8ae977ee8515749f189dfff7aab (diff)
GH-98831: Implement basic cache effects (#99313)
-rw-r--r--  Python/bytecodes.c                       |  69
-rw-r--r--  Python/generated_cases.c.h               |  35
-rw-r--r--  Tools/cases_generator/generate_cases.py  | 102
-rw-r--r--  Tools/cases_generator/parser.py          | 124
4 files changed, 202 insertions, 128 deletions
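
What this patch does: the instruction-definition DSL in Python/bytecodes.c can now declare inline-cache entries directly in an instruction's stack-effect signature, written as name/size (size in 16-bit code units), e.g. inst(BINARY_OP_MULTIPLY_INT, (left, right, unused/1 -- prod)). Families gain a size argument and brace-enclosed members, and the generator uses the declared cache size to advance next_instr itself, replacing the hand-written JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP) calls. The sketch below is a simplified, self-contained illustration of the two effect kinds the patch introduces; it is not the real Tools/cases_generator parser.

    from dataclasses import dataclass

    @dataclass
    class StackEffect:
        name: str

    @dataclass
    class CacheEffect:
        name: str
        size: int  # number of 16-bit code units in the inline cache

    def parse_input(item: str):
        # "left" -> stack input; "unused/1" -> cache entry of one code unit
        if "/" in item:
            name, size = item.split("/", 1)
            return CacheEffect(name, int(size))
        return StackEffect(item)

    header = "left, right, unused/1"
    effects = [parse_input(part.strip()) for part in header.split(",")]
    print(effects)
    # [StackEffect(name='left'), StackEffect(name='right'), CacheEffect(name='unused', size=1)]
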
diff --git a/Python/bytecodes.c b/Python/bytecodes.c
index 69ee741..1575b53 100644
--- a/Python/bytecodes.c
+++ b/Python/bytecodes.c
@@ -76,13 +76,9 @@ do { \
#define NAME_ERROR_MSG \
"name '%.200s' is not defined"
-typedef struct {
- PyObject *kwnames;
-} CallShape;
-
// Dummy variables for stack effects.
static PyObject *value, *value1, *value2, *left, *right, *res, *sum, *prod, *sub;
-static PyObject *container, *start, *stop, *v;
+static PyObject *container, *start, *stop, *v, *lhs, *rhs;
static PyObject *
dummy_func(
@@ -101,6 +97,8 @@ dummy_func(
binaryfunc binary_ops[]
)
{
+ _PyInterpreterFrame entry_frame;
+
switch (opcode) {
// BEGIN BYTECODES //
@@ -193,7 +191,21 @@ dummy_func(
ERROR_IF(res == NULL, error);
}
- inst(BINARY_OP_MULTIPLY_INT, (left, right -- prod)) {
+ family(binary_op, INLINE_CACHE_ENTRIES_BINARY_OP) = {
+ BINARY_OP,
+ BINARY_OP_ADD_FLOAT,
+ BINARY_OP_ADD_INT,
+ BINARY_OP_ADD_UNICODE,
+ BINARY_OP_GENERIC,
+ // BINARY_OP_INPLACE_ADD_UNICODE, // This is an odd duck.
+ BINARY_OP_MULTIPLY_FLOAT,
+ BINARY_OP_MULTIPLY_INT,
+ BINARY_OP_SUBTRACT_FLOAT,
+ BINARY_OP_SUBTRACT_INT,
+ };
+
+
+ inst(BINARY_OP_MULTIPLY_INT, (left, right, unused/1 -- prod)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@@ -202,10 +214,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(prod == NULL, error);
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
- inst(BINARY_OP_MULTIPLY_FLOAT, (left, right -- prod)) {
+ inst(BINARY_OP_MULTIPLY_FLOAT, (left, right, unused/1 -- prod)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@@ -216,10 +227,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(prod == NULL, error);
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
- inst(BINARY_OP_SUBTRACT_INT, (left, right -- sub)) {
+ inst(BINARY_OP_SUBTRACT_INT, (left, right, unused/1 -- sub)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyLong_CheckExact(right), BINARY_OP);
@@ -228,10 +238,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(sub == NULL, error);
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
- inst(BINARY_OP_SUBTRACT_FLOAT, (left, right -- sub)) {
+ inst(BINARY_OP_SUBTRACT_FLOAT, (left, right, unused/1 -- sub)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(!PyFloat_CheckExact(right), BINARY_OP);
@@ -241,10 +250,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(sub == NULL, error);
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
- inst(BINARY_OP_ADD_UNICODE, (left, right -- res)) {
+ inst(BINARY_OP_ADD_UNICODE, (left, right, unused/1 -- res)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyUnicode_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@@ -253,7 +261,6 @@ dummy_func(
_Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
_Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
ERROR_IF(res == NULL, error);
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
// This is a subtle one. It's a super-instruction for
@@ -292,7 +299,7 @@ dummy_func(
JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP + 1);
}
- inst(BINARY_OP_ADD_FLOAT, (left, right -- sum)) {
+ inst(BINARY_OP_ADD_FLOAT, (left, right, unused/1 -- sum)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyFloat_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@@ -303,10 +310,9 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
ERROR_IF(sum == NULL, error);
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
- inst(BINARY_OP_ADD_INT, (left, right -- sum)) {
+ inst(BINARY_OP_ADD_INT, (left, right, unused/1 -- sum)) {
assert(cframe.use_tracing == 0);
DEOPT_IF(!PyLong_CheckExact(left), BINARY_OP);
DEOPT_IF(Py_TYPE(right) != Py_TYPE(left), BINARY_OP);
@@ -315,7 +321,6 @@ dummy_func(
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
ERROR_IF(sum == NULL, error);
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
}
inst(BINARY_SUBSCR, (container, sub -- res)) {
@@ -3691,30 +3696,21 @@ dummy_func(
PUSH(Py_NewRef(peek));
}
- // stack effect: (__0 -- )
- inst(BINARY_OP_GENERIC) {
- PyObject *rhs = POP();
- PyObject *lhs = TOP();
+ inst(BINARY_OP_GENERIC, (lhs, rhs, unused/1 -- res)) {
assert(0 <= oparg);
assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
assert(binary_ops[oparg]);
- PyObject *res = binary_ops[oparg](lhs, rhs);
+ res = binary_ops[oparg](lhs, rhs);
Py_DECREF(lhs);
Py_DECREF(rhs);
- SET_TOP(res);
- if (res == NULL) {
- goto error;
- }
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
+ ERROR_IF(res == NULL, error);
}
- // stack effect: (__0 -- )
- inst(BINARY_OP) {
+ // This always dispatches, so the result is unused.
+ inst(BINARY_OP, (lhs, rhs, unused/1 -- unused)) {
_PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
assert(cframe.use_tracing == 0);
- PyObject *lhs = SECOND();
- PyObject *rhs = TOP();
next_instr--;
_Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
DISPATCH_SAME_OPARG();
@@ -3761,13 +3757,8 @@ dummy_func(
;
}
-// Families go below this point //
+// Future families go below this point //
-family(binary_op) = {
- BINARY_OP, BINARY_OP_ADD_FLOAT,
- BINARY_OP_ADD_INT, BINARY_OP_ADD_UNICODE, BINARY_OP_GENERIC, BINARY_OP_INPLACE_ADD_UNICODE,
- BINARY_OP_MULTIPLY_FLOAT, BINARY_OP_MULTIPLY_INT, BINARY_OP_SUBTRACT_FLOAT,
- BINARY_OP_SUBTRACT_INT };
family(binary_subscr) = {
BINARY_SUBSCR, BINARY_SUBSCR_DICT,
BINARY_SUBSCR_GETITEM, BINARY_SUBSCR_LIST_INT, BINARY_SUBSCR_TUPLE_INT };
diff --git a/Python/generated_cases.c.h b/Python/generated_cases.c.h
index 552d0e6..b8bc66b 100644
--- a/Python/generated_cases.c.h
+++ b/Python/generated_cases.c.h
@@ -145,9 +145,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (prod == NULL) goto pop_2_error;
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, prod);
+ next_instr += 1;
DISPATCH();
}
@@ -165,9 +165,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (prod == NULL) goto pop_2_error;
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, prod);
+ next_instr += 1;
DISPATCH();
}
@@ -183,9 +183,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (sub == NULL) goto pop_2_error;
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sub);
+ next_instr += 1;
DISPATCH();
}
@@ -202,9 +202,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (sub == NULL) goto pop_2_error;
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sub);
+ next_instr += 1;
DISPATCH();
}
@@ -220,9 +220,9 @@
_Py_DECREF_SPECIALIZED(left, _PyUnicode_ExactDealloc);
_Py_DECREF_SPECIALIZED(right, _PyUnicode_ExactDealloc);
if (res == NULL) goto pop_2_error;
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, res);
+ next_instr += 1;
DISPATCH();
}
@@ -274,9 +274,9 @@
_Py_DECREF_SPECIALIZED(right, _PyFloat_ExactDealloc);
_Py_DECREF_SPECIALIZED(left, _PyFloat_ExactDealloc);
if (sum == NULL) goto pop_2_error;
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sum);
+ next_instr += 1;
DISPATCH();
}
@@ -292,9 +292,9 @@
_Py_DECREF_SPECIALIZED(right, (destructor)PyObject_Free);
_Py_DECREF_SPECIALIZED(left, (destructor)PyObject_Free);
if (sum == NULL) goto pop_2_error;
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
STACK_SHRINK(1);
POKE(1, sum);
+ next_instr += 1;
DISPATCH();
}
@@ -3703,29 +3703,30 @@
TARGET(BINARY_OP_GENERIC) {
PREDICTED(BINARY_OP_GENERIC);
- PyObject *rhs = POP();
- PyObject *lhs = TOP();
+ PyObject *rhs = PEEK(1);
+ PyObject *lhs = PEEK(2);
+ PyObject *res;
assert(0 <= oparg);
assert((unsigned)oparg < Py_ARRAY_LENGTH(binary_ops));
assert(binary_ops[oparg]);
- PyObject *res = binary_ops[oparg](lhs, rhs);
+ res = binary_ops[oparg](lhs, rhs);
Py_DECREF(lhs);
Py_DECREF(rhs);
- SET_TOP(res);
- if (res == NULL) {
- goto error;
- }
- JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP);
+ if (res == NULL) goto pop_2_error;
+ STACK_SHRINK(1);
+ POKE(1, res);
+ next_instr += 1;
DISPATCH();
}
TARGET(BINARY_OP) {
PREDICTED(BINARY_OP);
+ assert(INLINE_CACHE_ENTRIES_BINARY_OP == 1);
+ PyObject *rhs = PEEK(1);
+ PyObject *lhs = PEEK(2);
_PyBinaryOpCache *cache = (_PyBinaryOpCache *)next_instr;
if (ADAPTIVE_COUNTER_IS_ZERO(cache->counter)) {
assert(cframe.use_tracing == 0);
- PyObject *lhs = SECOND();
- PyObject *rhs = TOP();
next_instr--;
_Py_Specialize_BinaryOp(lhs, rhs, next_instr, oparg, &GETLOCAL(0));
DISPATCH_SAME_OPARG();
diff --git a/Tools/cases_generator/generate_cases.py b/Tools/cases_generator/generate_cases.py
index b4f5f8f..d016531 100644
--- a/Tools/cases_generator/generate_cases.py
+++ b/Tools/cases_generator/generate_cases.py
@@ -18,7 +18,6 @@ RE_PREDICTED = r"(?s)(?:PREDICT\(|GO_TO_INSTRUCTION\(|DEOPT_IF\(.*?,\s*)(\w+)\);
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("-i", "--input", type=str, default="Python/bytecodes.c")
arg_parser.add_argument("-o", "--output", type=str, default="Python/generated_cases.c.h")
-arg_parser.add_argument("-c", "--compare", action="store_true")
arg_parser.add_argument("-q", "--quiet", action="store_true")
@@ -40,7 +39,6 @@ def parse_cases(
families: list[parser.Family] = []
while not psr.eof():
if inst := psr.inst_def():
- assert inst.block
instrs.append(inst)
elif sup := psr.super_def():
supers.append(sup)
@@ -69,17 +67,45 @@ def always_exits(block: parser.Block) -> bool:
return line.startswith(("goto ", "return ", "DISPATCH", "GO_TO_", "Py_UNREACHABLE()"))
-def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0):
- assert instr.block
+def find_cache_size(instr: InstDef, families: list[parser.Family]) -> str | None:
+ for family in families:
+ if instr.name == family.members[0]:
+ return family.size
+
+
+def write_instr(
+ instr: InstDef, predictions: set[str], indent: str, f: TextIO, dedent: int = 0, cache_size: str | None = None
+) -> int:
+ # Returns cache offset
if dedent < 0:
indent += " " * -dedent
+ # Separate stack inputs from cache inputs
+ input_names: set[str] = set()
+ stack: list[parser.StackEffect] = []
+ cache: list[parser.CacheEffect] = []
+ for input in instr.inputs:
+ if isinstance(input, parser.StackEffect):
+ stack.append(input)
+ input_names.add(input.name)
+ else:
+ assert isinstance(input, parser.CacheEffect), input
+ cache.append(input)
+ outputs = instr.outputs
+ cache_offset = 0
+ for ceffect in cache:
+ if ceffect.name != "unused":
+ bits = ceffect.size * 16
+ f.write(f"{indent} PyObject *{ceffect.name} = read{bits}(next_instr + {cache_offset});\n")
+ cache_offset += ceffect.size
+ if cache_size:
+ f.write(f"{indent} assert({cache_size} == {cache_offset});\n")
# TODO: Is it better to count forward or backward?
- for i, input in enumerate(reversed(instr.inputs), 1):
- f.write(f"{indent} PyObject *{input} = PEEK({i});\n")
+ for i, effect in enumerate(reversed(stack), 1):
+ if effect.name != "unused":
+ f.write(f"{indent} PyObject *{effect.name} = PEEK({i});\n")
for output in instr.outputs:
- if output not in instr.inputs:
- f.write(f"{indent} PyObject *{output};\n")
- assert instr.block is not None
+ if output.name not in input_names and output.name != "unused":
+ f.write(f"{indent} PyObject *{output.name};\n")
blocklines = instr.block.to_text(dedent=dedent).splitlines(True)
# Remove blank lines from ends
while blocklines and not blocklines[0].strip():
@@ -95,7 +121,7 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
while blocklines and not blocklines[-1].strip():
blocklines.pop()
# Write the body
- ninputs = len(instr.inputs or ())
+ ninputs = len(stack)
for line in blocklines:
if m := re.match(r"(\s*)ERROR_IF\((.+), (\w+)\);\s*$", line):
space, cond, label = m.groups()
@@ -107,46 +133,56 @@ def write_instr(instr: InstDef, predictions: set[str], indent: str, f: TextIO, d
f.write(f"{space}if ({cond}) goto {label};\n")
else:
f.write(line)
- noutputs = len(instr.outputs or ())
+ if always_exits(instr.block):
+ # None of the rest matters
+ return cache_offset
+ # Stack effect
+ noutputs = len(outputs)
diff = noutputs - ninputs
if diff > 0:
f.write(f"{indent} STACK_GROW({diff});\n")
elif diff < 0:
f.write(f"{indent} STACK_SHRINK({-diff});\n")
- for i, output in enumerate(reversed(instr.outputs or ()), 1):
- if output not in (instr.inputs or ()):
- f.write(f"{indent} POKE({i}, {output});\n")
- assert instr.block
-
-def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
+ for i, output in enumerate(reversed(outputs), 1):
+ if output.name not in input_names and output.name != "unused":
+ f.write(f"{indent} POKE({i}, {output.name});\n")
+ # Cache effect
+ if cache_offset:
+ f.write(f"{indent} next_instr += {cache_offset};\n")
+ return cache_offset
+
+
+def write_cases(
+ f: TextIO, instrs: list[InstDef], supers: list[parser.Super], families: list[parser.Family]
+) -> dict[str, tuple[int, int, int]]:
predictions: set[str] = set()
for instr in instrs:
- assert isinstance(instr, InstDef)
- assert instr.block is not None
for target in re.findall(RE_PREDICTED, instr.block.text):
predictions.add(target)
indent = " "
f.write(f"// This file is generated by {os.path.relpath(__file__)}\n")
f.write(f"// Do not edit!\n")
instr_index: dict[str, InstDef] = {}
+ effects_table: dict[str, tuple[int, int, int]] = {} # name -> (ninputs, noutputs, cache_offset)
for instr in instrs:
instr_index[instr.name] = instr
f.write(f"\n{indent}TARGET({instr.name}) {{\n")
if instr.name in predictions:
f.write(f"{indent} PREDICTED({instr.name});\n")
- write_instr(instr, predictions, indent, f)
- assert instr.block
+ cache_offset = write_instr(
+ instr, predictions, indent, f,
+ cache_size=find_cache_size(instr, families)
+ )
+ effects_table[instr.name] = len(instr.inputs), len(instr.outputs), cache_offset
if not always_exits(instr.block):
f.write(f"{indent} DISPATCH();\n")
# Write trailing '}'
f.write(f"{indent}}}\n")
for sup in supers:
- assert isinstance(sup, parser.Super)
components = [instr_index[name] for name in sup.ops]
f.write(f"\n{indent}TARGET({sup.name}) {{\n")
for i, instr in enumerate(components):
- assert instr.block
if i > 0:
f.write(f"{indent} NEXTOPARG();\n")
f.write(f"{indent} next_instr++;\n")
@@ -156,6 +192,8 @@ def write_cases(f: TextIO, instrs: list[InstDef], supers: list[parser.Super]):
f.write(f"{indent} DISPATCH();\n")
f.write(f"{indent}}}\n")
+ return effects_table
+
def main():
args = arg_parser.parse_args()
@@ -176,12 +214,28 @@ def main():
file=sys.stderr,
)
with eopen(args.output, "w") as f:
- write_cases(f, instrs, supers)
+ effects_table = write_cases(f, instrs, supers, families)
if not args.quiet:
print(
f"Wrote {ninstrs + nsupers} instructions to {args.output}",
file=sys.stderr,
)
+ # Check that families have consistent effects
+ errors = 0
+ for family in families:
+ head = effects_table[family.members[0]]
+ for member in family.members:
+ if effects_table[member] != head:
+ errors += 1
+ print(
+ f"Family {family.name!r} has inconsistent effects (inputs, outputs, cache units):",
+ file=sys.stderr,
+ )
+ print(
+ f" {family.members[0]} = {head}; {member} = {effects_table[member]}",
+ )
+ if errors:
+ sys.exit(1)
if __name__ == "__main__":
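
Why the new consistency check in main(): the specializing interpreter rewrites a family head such as BINARY_OP into any of its members in place, so every member must pop and push the same number of values and consume the same number of cache units. A minimal, self-contained sketch of that idea follows; the names (Family, effects_table, check_families) only loosely mirror the script and are illustrative, not the exact generate_cases.py code.

    import sys
    from dataclasses import dataclass

    @dataclass
    class Family:
        name: str
        size: str           # macro naming the cache size, e.g. "INLINE_CACHE_ENTRIES_BINARY_OP"
        members: list[str]

    def check_families(families: list[Family],
                       effects_table: dict[str, tuple[int, int, int]]) -> int:
        """Count families whose members disagree with the head's
        (ninputs, noutputs, cache units) triple."""
        errors = 0
        for family in families:
            head = effects_table[family.members[0]]
            for member in family.members[1:]:
                if effects_table[member] != head:
                    errors += 1
                    print(f"Family {family.name!r}: {family.members[0]} = {head}, "
                          f"but {member} = {effects_table[member]}", file=sys.stderr)
        return errors

    # Two members with identical geometry: pop 2 values, push 1, skip 1 cache unit.
    effects = {"BINARY_OP": (2, 1, 1), "BINARY_OP_ADD_INT": (2, 1, 1)}
    binary_op = Family("binary_op", "INLINE_CACHE_ENTRIES_BINARY_OP",
                       ["BINARY_OP", "BINARY_OP_ADD_INT"])
    assert check_families([binary_op], effects) == 0
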
diff --git a/Tools/cases_generator/parser.py b/Tools/cases_generator/parser.py
index 9e95cdb..1f85531 100644
--- a/Tools/cases_generator/parser.py
+++ b/Tools/cases_generator/parser.py
@@ -57,10 +57,27 @@ class Block(Node):
@dataclass
+class Effect(Node):
+ pass
+
+
+@dataclass
+class StackEffect(Effect):
+ name: str
+ # TODO: type, condition
+
+
+@dataclass
+class CacheEffect(Effect):
+ name: str
+ size: int
+
+
+@dataclass
class InstHeader(Node):
name: str
- inputs: list[str]
- outputs: list[str]
+ inputs: list[Effect]
+ outputs: list[Effect]
@dataclass
@@ -69,16 +86,17 @@ class InstDef(Node):
block: Block
@property
- def name(self):
+ def name(self) -> str:
return self.header.name
@property
- def inputs(self):
+ def inputs(self) -> list[Effect]:
return self.header.inputs
@property
- def outputs(self):
- return self.header.outputs
+ def outputs(self) -> list[StackEffect]:
+ # This is always true
+ return [x for x in self.header.outputs if isinstance(x, StackEffect)]
@dataclass
@@ -90,6 +108,7 @@ class Super(Node):
@dataclass
class Family(Node):
name: str
+ size: str # Variable giving the cache size in code units
members: list[str]
@@ -123,18 +142,16 @@ class Parser(PLexer):
return InstHeader(name, [], [])
return None
- def check_overlaps(self, inp: list[str], outp: list[str]):
+ def check_overlaps(self, inp: list[Effect], outp: list[Effect]):
for i, name in enumerate(inp):
- try:
- j = outp.index(name)
- except ValueError:
- continue
- else:
- if i != j:
- raise self.make_syntax_error(
- f"Input {name!r} at pos {i} repeated in output at different pos {j}")
+ for j, name2 in enumerate(outp):
+ if name == name2:
+ if i != j:
+ raise self.make_syntax_error(
+ f"Input {name!r} at pos {i} repeated in output at different pos {j}")
+ break
- def stack_effect(self) -> tuple[list[str], list[str]]:
+ def stack_effect(self) -> tuple[list[Effect], list[Effect]]:
# '(' [inputs] '--' [outputs] ')'
if self.expect(lx.LPAREN):
inp = self.inputs() or []
@@ -144,8 +161,8 @@ class Parser(PLexer):
return inp, outp
raise self.make_syntax_error("Expected stack effect")
- def inputs(self) -> list[str] | None:
- # input (, input)*
+ def inputs(self) -> list[Effect] | None:
+ # input (',' input)*
here = self.getpos()
if inp := self.input():
near = self.getpos()
@@ -157,27 +174,25 @@ class Parser(PLexer):
self.setpos(here)
return None
- def input(self) -> str | None:
- # IDENTIFIER
+ @contextual
+ def input(self) -> Effect | None:
+ # IDENTIFIER '/' INTEGER (CacheEffect)
+ # IDENTIFIER (StackEffect)
if (tkn := self.expect(lx.IDENTIFIER)):
- if self.expect(lx.LBRACKET):
- if arg := self.expect(lx.IDENTIFIER):
- if self.expect(lx.RBRACKET):
- return f"{tkn.text}[{arg.text}]"
- if self.expect(lx.TIMES):
- if num := self.expect(lx.NUMBER):
- if self.expect(lx.RBRACKET):
- return f"{tkn.text}[{arg.text}*{num.text}]"
- raise self.make_syntax_error("Expected argument in brackets", tkn)
-
- return tkn.text
- if self.expect(lx.CONDOP):
- while self.expect(lx.CONDOP):
- pass
- return "??"
- return None
+ if self.expect(lx.DIVIDE):
+ if num := self.expect(lx.NUMBER):
+ try:
+ size = int(num.text)
+ except ValueError:
+ raise self.make_syntax_error(
+ f"Expected integer, got {num.text!r}")
+ else:
+ return CacheEffect(tkn.text, size)
+ raise self.make_syntax_error("Expected integer")
+ else:
+ return StackEffect(tkn.text)
- def outputs(self) -> list[str] | None:
+ def outputs(self) -> list[Effect] | None:
# output (, output)*
here = self.getpos()
if outp := self.output():
@@ -190,8 +205,10 @@ class Parser(PLexer):
self.setpos(here)
return None
- def output(self) -> str | None:
- return self.input() # TODO: They're not quite the same.
+ @contextual
+ def output(self) -> Effect | None:
+ if (tkn := self.expect(lx.IDENTIFIER)):
+ return StackEffect(tkn.text)
@contextual
def super_def(self) -> Super | None:
@@ -216,24 +233,35 @@ class Parser(PLexer):
@contextual
def family_def(self) -> Family | None:
if (tkn := self.expect(lx.IDENTIFIER)) and tkn.text == "family":
+ size = None
if self.expect(lx.LPAREN):
if (tkn := self.expect(lx.IDENTIFIER)):
+ if self.expect(lx.COMMA):
+ if not (size := self.expect(lx.IDENTIFIER)):
+ raise self.make_syntax_error(
+ "Expected identifier")
if self.expect(lx.RPAREN):
if self.expect(lx.EQUALS):
+ if not self.expect(lx.LBRACE):
+ raise self.make_syntax_error("Expected {")
if members := self.members():
- if self.expect(lx.SEMI):
- return Family(tkn.text, members)
+ if self.expect(lx.RBRACE) and self.expect(lx.SEMI):
+ return Family(tkn.text, size.text if size else "", members)
return None
def members(self) -> list[str] | None:
here = self.getpos()
if tkn := self.expect(lx.IDENTIFIER):
- near = self.getpos()
- if self.expect(lx.COMMA):
- if rest := self.members():
- return [tkn.text] + rest
- self.setpos(near)
- return [tkn.text]
+ members = [tkn.text]
+ while self.expect(lx.COMMA):
+ if tkn := self.expect(lx.IDENTIFIER):
+ members.append(tkn.text)
+ else:
+ break
+ peek = self.peek()
+ if not peek or peek.kind != lx.RBRACE:
+ raise self.make_syntax_error("Expected comma or right paren")
+ return members
self.setpos(here)
return None
@@ -274,5 +302,5 @@ if __name__ == "__main__":
filename = None
src = "if (x) { x.foo; // comment\n}"
parser = Parser(src, filename)
- x = parser.inst_def()
+ x = parser.inst_def() or parser.super_def() or parser.family_def()
print(x)
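
Putting the parser and generator pieces together: each named cache effect becomes a load from the inline cache at its accumulated offset, unused entries are merely skipped, and the total size is what lets the generator emit next_instr += N in place of the JUMPBY(INLINE_CACHE_ENTRIES_BINARY_OP) calls removed from bytecodes.c. The helper below is a hedged sketch with an assumed name (emit_cache_lines), not the real write_instr.

    from dataclasses import dataclass

    @dataclass
    class CacheEffect:
        name: str
        size: int  # 16-bit code units

    def emit_cache_lines(cache: list[CacheEffect]) -> tuple[list[str], int]:
        """Return the C lines a generator could emit for one instruction's
        cache effects, plus the total number of code units to skip."""
        lines, offset = [], 0
        for eff in cache:
            if eff.name != "unused":
                bits = eff.size * 16
                lines.append(f"PyObject *{eff.name} = read{bits}(next_instr + {offset});")
            offset += eff.size
        if offset:
            lines.append(f"next_instr += {offset};")
        return lines, offset

    print(emit_cache_lines([CacheEffect("unused", 1)]))
    # (['next_instr += 1;'], 1) -- the BINARY_OP_* cases above only need the skip,
    # which is exactly the new tail visible in generated_cases.c.h.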