summaryrefslogtreecommitdiffstats
path: root/Tools/cases_generator/flags.py
blob: 5241331bb97cdbaf7324c5a2dad29f0e3427fbf4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import dataclasses

from formatting import Formatter
import lexer as lx
import parsing
from typing import AbstractSet


@dataclasses.dataclass
class InstructionFlags:
    """Construct and manipulate instruction flags"""

    HAS_ARG_FLAG: bool = False
    HAS_CONST_FLAG: bool = False
    HAS_NAME_FLAG: bool = False
    HAS_JUMP_FLAG: bool = False
    HAS_FREE_FLAG: bool = False
    HAS_LOCAL_FLAG: bool = False
    HAS_EVAL_BREAK_FLAG: bool = False
    HAS_DEOPT_FLAG: bool = False
    HAS_ERROR_FLAG: bool = False

    def __post_init__(self) -> None:
        self.bitmask = {name: (1 << i) for i, name in enumerate(self.names())}

    @staticmethod
    def fromInstruction(instr: parsing.Node) -> "InstructionFlags":
        has_free = (
            variable_used(instr, "PyCell_New")
            or variable_used(instr, "PyCell_GET")
            or variable_used(instr, "PyCell_SET")
        )

        return InstructionFlags(
            HAS_ARG_FLAG=variable_used(instr, "oparg"),
            HAS_CONST_FLAG=variable_used(instr, "FRAME_CO_CONSTS"),
            HAS_NAME_FLAG=variable_used(instr, "FRAME_CO_NAMES"),
            HAS_JUMP_FLAG=variable_used(instr, "JUMPBY"),
            HAS_FREE_FLAG=has_free,
            HAS_LOCAL_FLAG=(
                variable_used(instr, "GETLOCAL") or variable_used(instr, "SETLOCAL")
            )
            and not has_free,
            HAS_EVAL_BREAK_FLAG=variable_used(instr, "CHECK_EVAL_BREAKER"),
            HAS_DEOPT_FLAG=variable_used(instr, "DEOPT_IF"),
            HAS_ERROR_FLAG=(
                variable_used(instr, "ERROR_IF")
                or variable_used(instr, "error")
                or variable_used(instr, "pop_1_error")
                or variable_used(instr, "exception_unwind")
                or variable_used(instr, "resume_with_error")
            ),
        )

    @staticmethod
    def newEmpty() -> "InstructionFlags":
        return InstructionFlags()

    def add(self, other: "InstructionFlags") -> None:
        for name, value in dataclasses.asdict(other).items():
            if value:
                setattr(self, name, value)

    def names(self, value: bool | None = None) -> list[str]:
        if value is None:
            return list(dataclasses.asdict(self).keys())
        return [n for n, v in dataclasses.asdict(self).items() if v == value]

    def bitmap(self, ignore: AbstractSet[str] = frozenset()) -> int:
        flags = 0
        assert all(hasattr(self, name) for name in ignore)
        for name in self.names():
            if getattr(self, name) and name not in ignore:
                flags |= self.bitmask[name]
        return flags

    @classmethod
    def emit_macros(cls, out: Formatter) -> None:
        flags = cls.newEmpty()
        for name, value in flags.bitmask.items():
            out.emit(f"#define {name} ({value})")

        for name, value in flags.bitmask.items():
            out.emit(
                f"#define OPCODE_{name[:-len('_FLAG')]}(OP) "
                f"(_PyOpcode_opcode_metadata[OP].flags & ({name}))"
            )


def variable_used(node: parsing.Node, name: str) -> bool:
    """Report whether *node* contains an identifier token spelled *name*.

    Only tokens whose ``kind`` is ``"IDENTIFIER"`` are considered; tokens of
    any other kind never match, even if their text equals *name*.
    """
    for tok in node.tokens:
        if tok.kind != "IDENTIFIER":
            continue
        if tok.text == name:
            return True
    return False


def variable_used_unspecialized(node: "parsing.Node", name: str) -> bool:
    """Like variable_used(), but skips #if ENABLE_SPECIALIZATION blocks.

    The true branch of ``#if ENABLE_SPECIALIZATION`` and ``#if TIER_ONE`` is
    ignored, together with any conditionals nested inside it. The matching
    ``#else`` branch is still scanned. Open ``#if``/``#ifdef``/``#ifndef``
    directives are tracked on a stack so that an inner ``#else`` or
    ``#endif`` cannot end the skipped region early.
    """
    tokens: list[lx.Token] = node.tokens
    skipped_guards = ("ENABLE_SPECIALIZATION", "TIER_ONE")
    # One entry per open conditional:
    # (this #if skips its true branch, skipping state before this #if).
    open_conditionals: list[tuple[bool, bool]] = []
    skipping = False
    for i, token in enumerate(tokens):
        if token.kind == "MACRO":
            # Directive text with whitespace removed, e.g. "# if" -> "#if".
            directive = "".join(token.text.split())
            if directive in ("#if", "#ifdef", "#ifndef"):
                is_skipped_guard = (
                    directive == "#if"
                    and i + 1 < len(tokens)
                    and tokens[i + 1].text in skipped_guards
                )
                open_conditionals.append((is_skipped_guard, skipping))
                if is_skipped_guard:
                    skipping = True
            elif directive == "#else":
                if not open_conditionals:
                    # Unbalanced #else: keep the previous behaviour.
                    skipping = False
                else:
                    is_skipped_guard, outer_skipping = open_conditionals[-1]
                    # Only a skipped guard's #else changes state; it resumes
                    # whatever the enclosing region was doing.
                    if is_skipped_guard:
                        skipping = outer_skipping
            elif directive == "#endif":
                if open_conditionals:
                    skipping = open_conditionals.pop()[1]
                else:
                    # Unbalanced #endif: keep the previous behaviour.
                    skipping = False
        elif not skipping and token.kind == "IDENTIFIER" and token.text == name:
            return True
    return False