summaryrefslogtreecommitdiffstats
path: root/Lib/test/support/strace_helper.py
blob: 90d4b5bccb6fa3e417d71f88cd8fe412df5e1782 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import re
import sys
import textwrap
import unittest
from dataclasses import dataclass
from functools import cache
from test import support
from test.support.script_helper import run_python_until_end

# Absolute path of the strace executable that strace_python() launches.
_strace_binary = "/usr/bin/strace"
# Matches one completed syscall line such as 'write(1, "x\n", 2) = 2'.
# ``args`` is deliberately greedy: it extends to the last ")" that is followed
# by "=".  With a "[^)]*" pattern, any argument containing ")" (for example
# printed source code like "f()") stopped the match early, so the whole event
# was silently dropped from StraceResult.events().
_syscall_regex = re.compile(
    r"(?P<syscall>[^(]*)\((?P<args>.*)\)\s*[=]\s*(?P<returncode>.+)")
# Matches strace's trailing "+++ exited with N +++" line.  It is a bytes
# pattern because it is applied to the raw, undecoded stderr.
_returncode_regex = re.compile(
    br"\+\+\+ exited with (?P<returncode>\d+) \+\+\+")


@dataclass
class StraceEvent:
    """One system call parsed from a line of strace output.

    Instances are built by StraceResult.events() from the fields of
    ``_syscall_regex``.
    """
    # Name of the system call as printed by strace (e.g. "write").
    syscall: str
    # Argument text split on "," with surrounding whitespace stripped.
    # String arguments keep their quotes and escapes exactly as strace
    # printed them.
    args: list[str]
    # Unparsed text after the "=" on the line (the call's result as shown
    # by strace).
    returncode: str


@dataclass
class StraceResult:
    """Outcome of running a Python snippet under strace.

    ``strace_returncode`` is strace's own exit status and
    ``python_returncode`` is the traced interpreter's exit status.
    strace_python() sets both to ``-1`` when it could not produce a trace.
    """
    strace_returncode: int
    python_returncode: int
    # The event messages generated by strace. This is very similar to the
    # stderr strace produces, with the trailing returncode marker line
    # removed.
    event_bytes: bytes
    # stdout and stderr captured from the strace run.  strace writes its
    # trace to stderr, so ``stderr`` also contains the raw trace.
    stdout: bytes
    stderr: bytes

    def events(self):
        """Parse event_bytes data into system calls for easier processing.

        Returns a list of StraceEvent, one per line of ``event_bytes`` that
        matches ``_syscall_regex``.  Lines that do not match are skipped.

        Output from the program under inspection can be interleaved with
        strace's own output.  Bytes that are not valid UTF-8 are therefore
        backslash-escaped instead of aborting the whole parse with
        UnicodeDecodeError.
        """
        decoded_events = self.event_bytes.decode('utf-8', 'backslashreplace')
        events = []
        for line in decoded_events.splitlines():
            match = _syscall_regex.match(line)
            if match is None:
                continue
            args = [arg.strip() for arg in match["args"].split(",")]
            events.append(
                StraceEvent(match["syscall"], args, match["returncode"]))
        return events

    def sections(self):
        """Find all "MARK <X>" writes and use them to make groups of events.

        This is useful to avoid variable / overhead events, like those at
        interpreter startup or when opening a file so a test can verify just
        the small case under study.

        Returns a dict mapping section name to a list of StraceEvent.
        - Events before the first mark go under "__startup".
        - The marking ``write`` itself is not included in any section.
        - A repeated mark name keeps appending to its existing section.
        """
        current_section = "__startup"
        sections = {current_section: []}
        for event in self.events():
            is_mark = (event.syscall == 'write'
                       and len(event.args) > 2
                       and event.args[1].startswith("\"MARK "))
            if is_mark:
                # Found a new section, don't include the write in the section
                # but all events until next mark should be in that section.
                # The argument looks like "MARK name\n" (quoted, with an
                # escaped newline), so drop the prefix and the '\n"' suffix.
                current_section = event.args[1].split(
                    " ", 1)[1].removesuffix('\\n"')
                sections.setdefault(current_section, [])
            else:
                sections[current_section].append(event)

        return sections


@support.requires_subprocess()
def strace_python(code, strace_flags, check=True):
    """Run strace and return the trace.

    *code* is dedented and executed with ``-c`` in a child interpreter that
    is launched through the strace binary, with *strace_flags* appended to
    the strace command line.  When *check* is true, a nonzero exit status
    from strace or from the traced interpreter is reported through the run
    result's ``fail()``.

    Sets strace_returncode and python_returncode to `-1` on error.
    """
    run = None

    def failure(reason, details):
        # Describe the problem as a single pseudo-syscall line.  ``run`` is
        # looked up when this is called, so any output captured by then is
        # attached.
        marker = f"error({reason},details={details}) = -1".encode('utf-8')
        return StraceResult(
            strace_returncode=-1,
            python_returncode=-1,
            event_bytes=marker,
            stdout=run.out if run else b"",
            stderr=run.err if run else b"")

    try:
        run, cmd_line = run_python_until_end(
            "-c", textwrap.dedent(code),
            __run_using_command=[_strace_binary] + strace_flags)
    except OSError as exc:
        return failure("Caught OSError", exc)

    if check and run.rc:
        run.fail(cmd_line)

    # The last stderr line should carry the traced interpreter's exit status.
    # Everything before it is the event log.
    trimmed = run.err.strip()
    event_log, newline, status_line = trimmed.rpartition(b"\n")
    if not newline:
        return failure("Expected strace events and exit code line",
                       trimmed[-50:])

    status = _returncode_regex.match(status_line)
    if status is None:
        return failure("Expected to find returncode in last line.",
                       status_line[:50])

    exit_code = int(status["returncode"])
    if check and exit_code:
        run.fail(cmd_line)

    return StraceResult(strace_returncode=run.rc,
                        python_returncode=exit_code,
                        event_bytes=event_log,
                        stdout=run.out,
                        stderr=run.err)


def get_events(code, strace_flags, prelude, cleanup):
    """Trace *code* and return only the StraceEvents it generated.

    *prelude* and *cleanup* run immediately before and after *code* in the
    same interpreter.  Each piece is preceded by a ``print("MARK <name>")`` so
    the trace can be split with StraceResult.sections().  Only the "code"
    section is returned.
    """
    # Strip common indentation from each snippet (same order as before) so
    # they can be spliced into a single script at column 0.
    prelude, code, cleanup = [
        textwrap.dedent(snippet) for snippet in (prelude, code, cleanup)
    ]
    # NOTE: The flush is currently required to prevent the prints from getting
    # buffered and done all at once at exit
    to_run = f"""
print("MARK prelude", flush=True)
{prelude}
print("MARK code", flush=True)
{code}
print("MARK cleanup", flush=True)
{cleanup}
print("MARK __shutdown", flush=True)
    """
    return strace_python(to_run, strace_flags).sections()['code']


def get_syscalls(code, strace_flags, prelude="", cleanup=""):
    """Get the syscalls which a given chunk of python code generates.

    Returns the syscall names, in trace order, of the events reported by
    get_events() for *code*.
    """
    syscall_names = []
    for event in get_events(code, strace_flags,
                            prelude=prelude, cleanup=cleanup):
        syscall_names.append(event.syscall)
    return syscall_names


# Moderately expensive (spawns a subprocess), so share results when possible.
@cache
def _can_strace():
    """Return True if strace can trace a trivial Python run on this machine.

    The return-code check comes first.  If strace could not be run, or its
    output could not be parsed, strace_python() reports ``-1`` return codes
    and this returns False, so requires_strace() skips the tests.  Before
    this fix, the sanity assertion ran first.  It could then raise
    AssertionError at decoration time and error the test module instead.
    """
    res = strace_python("import sys; sys.exit(0)", [], check=False)
    if res.strace_returncode != 0 or res.python_returncode != 0:
        return False

    # strace ran successfully, so the trace must contain parseable syscalls.
    # If it does not, the parser is out of sync with strace's output format,
    # and that should fail loudly rather than silently skip.
    assert res.events(), "Should have parsed multiple calls"
    return True


def requires_strace():
    """Return a unittest decorator that skips tests unless strace is usable.

    Tests are skipped on non-Linux platforms and under the address/memory
    sanitizers.  Otherwise they are skipped only if _can_strace() reports
    that strace does not work.
    """
    if sys.platform != "linux":
        reason = "Linux only, requires strace."
    elif support.check_sanitizer(address=True, memory=True):
        reason = "LeakSanitizer does not work under ptrace (strace, gdb, etc)"
    else:
        return unittest.skipUnless(_can_strace(), "Requires working strace")
    return unittest.skip(reason)


# Public helpers exported by ``from test.support.strace_helper import *``.
__all__ = [
    "get_events",
    "get_syscalls",
    "requires_strace",
    "strace_python",
    "StraceEvent",
    "StraceResult",
]