From b85923a0feaae698c811f6c4cf6be018f3725dd5 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Wed, 2 Oct 2024 20:07:20 +0100 Subject: GH-119726: Deduplicate AArch64 trampolines within a trace (GH-123872) --- .../2024-09-19-16-57-34.gh-issue-119726.DseseK.rst | 2 + Python/jit.c | 92 ++++++++++++++++++++-- Tools/jit/_stencils.py | 81 +++++++++---------- Tools/jit/_targets.py | 7 +- Tools/jit/_writer.py | 24 ++++-- 5 files changed, 147 insertions(+), 59 deletions(-) create mode 100644 Misc/NEWS.d/next/Core_and_Builtins/2024-09-19-16-57-34.gh-issue-119726.DseseK.rst diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2024-09-19-16-57-34.gh-issue-119726.DseseK.rst b/Misc/NEWS.d/next/Core_and_Builtins/2024-09-19-16-57-34.gh-issue-119726.DseseK.rst new file mode 100644 index 0000000..c01eeff --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2024-09-19-16-57-34.gh-issue-119726.DseseK.rst @@ -0,0 +1,2 @@ +The JIT now generates more efficient code for calls to C functions resulting +in up to 0.8% memory savings and 1.5% speed improvement on AArch64. Patch by Diego Russo. diff --git a/Python/jit.c b/Python/jit.c index 3332076..234fc7d 100644 --- a/Python/jit.c +++ b/Python/jit.c @@ -3,6 +3,7 @@ #include "Python.h" #include "pycore_abstract.h" +#include "pycore_bitutils.h" #include "pycore_call.h" #include "pycore_ceval.h" #include "pycore_critical_section.h" @@ -113,6 +114,21 @@ mark_executable(unsigned char *memory, size_t size) // JIT compiler stuff: ///////////////////////////////////////////////////////// +#define SYMBOL_MASK_WORDS 4 + +typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS]; + +typedef struct { + unsigned char *mem; + symbol_mask mask; + size_t size; +} trampoline_state; + +typedef struct { + trampoline_state trampolines; + uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH]; +} jit_state; + // Warning! AArch64 requires you to get your hands dirty. These are your gloves: // value[value_start : value_start + len] @@ -390,66 +406,126 @@ patch_x86_64_32rx(unsigned char *location, uint64_t value) patch_32r(location, value); } +void patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state); + #include "jit_stencils.h" +#if defined(__aarch64__) || defined(_M_ARM64) + #define TRAMPOLINE_SIZE 16 +#else + #define TRAMPOLINE_SIZE 0 +#endif + +// Generate and patch AArch64 trampolines. The symbols to jump to are stored +// in the jit_stencils.h in the symbols_map. +void +patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state) +{ + // Masking is done modulo 32 as the mask is stored as an array of uint32_t + const uint32_t symbol_mask = 1 << (ordinal % 32); + const uint32_t trampoline_mask = state->trampolines.mask[ordinal / 32]; + assert(symbol_mask & trampoline_mask); + + // Count the number of set bits in the trampoline mask lower than ordinal, + // this gives the index into the array of trampolines. + int index = _Py_popcount32(trampoline_mask & (symbol_mask - 1)); + for (int i = 0; i < ordinal / 32; i++) { + index += _Py_popcount32(state->trampolines.mask[i]); + } + + uint32_t *p = (uint32_t*)(state->trampolines.mem + index * TRAMPOLINE_SIZE); + assert((size_t)(index + 1) * TRAMPOLINE_SIZE <= state->trampolines.size); + + uint64_t value = (uintptr_t)symbols_map[ordinal]; + + /* Generate the trampoline + 0: 58000048 ldr x8, 8 + 4: d61f0100 br x8 + 8: 00000000 // The next two words contain the 64-bit address to jump to. + c: 00000000 + */ + p[0] = 0x58000048; + p[1] = 0xD61F0100; + p[2] = value & 0xffffffff; + p[3] = value >> 32; + + patch_aarch64_26r(location, (uintptr_t)p); +} + +static void +combine_symbol_mask(const symbol_mask src, symbol_mask dest) +{ + // Calculate the union of the trampolines required by each StencilGroup + for (size_t i = 0; i < SYMBOL_MASK_WORDS; i++) { + dest[i] |= src[i]; + } +} + // Compiles executor in-place. Don't forget to call _PyJIT_Free later! int _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length) { const StencilGroup *group; // Loop once to find the total compiled size: - uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH]; size_t code_size = 0; size_t data_size = 0; + jit_state state = {}; group = &trampoline; code_size += group->code_size; data_size += group->data_size; for (size_t i = 0; i < length; i++) { const _PyUOpInstruction *instruction = &trace[i]; group = &stencil_groups[instruction->opcode]; - instruction_starts[i] = code_size; + state.instruction_starts[i] = code_size; code_size += group->code_size; data_size += group->data_size; + combine_symbol_mask(group->trampoline_mask, state.trampolines.mask); } group = &stencil_groups[_FATAL_ERROR]; code_size += group->code_size; data_size += group->data_size; + combine_symbol_mask(group->trampoline_mask, state.trampolines.mask); + // Calculate the size of the trampolines required by the whole trace + for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) { + state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE; + } // Round up to the nearest page: size_t page_size = get_page_size(); assert((page_size & (page_size - 1)) == 0); - size_t padding = page_size - ((code_size + data_size) & (page_size - 1)); - size_t total_size = code_size + data_size + padding; + size_t padding = page_size - ((code_size + data_size + state.trampolines.size) & (page_size - 1)); + size_t total_size = code_size + data_size + state.trampolines.size + padding; unsigned char *memory = jit_alloc(total_size); if (memory == NULL) { return -1; } // Update the offsets of each instruction: for (size_t i = 0; i < length; i++) { - instruction_starts[i] += (uintptr_t)memory; + state.instruction_starts[i] += (uintptr_t)memory; } // Loop again to emit the code: unsigned char *code = memory; unsigned char *data = memory + code_size; + state.trampolines.mem = memory + code_size + data_size; // Compile the trampoline, which handles converting between the native // calling convention and the calling convention used by jitted code // (which may be different for efficiency reasons). On platforms where // we don't change calling conventions, the trampoline is empty and // nothing is emitted here: group = &trampoline; - group->emit(code, data, executor, NULL, instruction_starts); + group->emit(code, data, executor, NULL, &state); code += group->code_size; data += group->data_size; assert(trace[0].opcode == _START_EXECUTOR); for (size_t i = 0; i < length; i++) { const _PyUOpInstruction *instruction = &trace[i]; group = &stencil_groups[instruction->opcode]; - group->emit(code, data, executor, instruction, instruction_starts); + group->emit(code, data, executor, instruction, &state); code += group->code_size; data += group->data_size; } // Protect against accidental buffer overrun into data: group = &stencil_groups[_FATAL_ERROR]; - group->emit(code, data, executor, NULL, instruction_starts); + group->emit(code, data, executor, NULL, &state); code += group->code_size; data += group->data_size; assert(code == memory + code_size); diff --git a/Tools/jit/_stencils.py b/Tools/jit/_stencils.py index 1c6a9ed..bbb52f3 100644 --- a/Tools/jit/_stencils.py +++ b/Tools/jit/_stencils.py @@ -2,7 +2,6 @@ import dataclasses import enum -import sys import typing import _schema @@ -103,8 +102,8 @@ _HOLE_EXPRS = { HoleValue.OPERAND_HI: "(instruction->operand >> 32)", HoleValue.OPERAND_LO: "(instruction->operand & UINT32_MAX)", HoleValue.TARGET: "instruction->target", - HoleValue.JUMP_TARGET: "instruction_starts[instruction->jump_target]", - HoleValue.ERROR_TARGET: "instruction_starts[instruction->error_target]", + HoleValue.JUMP_TARGET: "state->instruction_starts[instruction->jump_target]", + HoleValue.ERROR_TARGET: "state->instruction_starts[instruction->error_target]", HoleValue.ZERO: "", } @@ -125,6 +124,7 @@ class Hole: symbol: str | None # ...plus this addend: addend: int + need_state: bool = False func: str = dataclasses.field(init=False) # Convenience method: replace = dataclasses.replace @@ -157,10 +157,12 @@ class Hole: if value: value += " + " value += f"(uintptr_t)&{self.symbol}" - if _signed(self.addend): + if _signed(self.addend) or not value: if value: value += " + " value += f"{_signed(self.addend):#x}" + if self.need_state: + return f"{self.func}({location}, {value}, state);" return f"{self.func}({location}, {value});" @@ -175,7 +177,6 @@ class Stencil: body: bytearray = dataclasses.field(default_factory=bytearray, init=False) holes: list[Hole] = dataclasses.field(default_factory=list, init=False) disassembly: list[str] = dataclasses.field(default_factory=list, init=False) - trampolines: dict[str, int] = dataclasses.field(default_factory=dict, init=False) def pad(self, alignment: int) -> None: """Pad the stencil to the given alignment.""" @@ -184,39 +185,6 @@ class Stencil: self.disassembly.append(f"{offset:x}: {' '.join(['00'] * padding)}") self.body.extend([0] * padding) - def emit_aarch64_trampoline(self, hole: Hole, alignment: int) -> Hole: - """Even with the large code model, AArch64 Linux insists on 28-bit jumps.""" - assert hole.symbol is not None - reuse_trampoline = hole.symbol in self.trampolines - if reuse_trampoline: - # Re-use the base address of the previously created trampoline - base = self.trampolines[hole.symbol] - else: - self.pad(alignment) - base = len(self.body) - new_hole = hole.replace(addend=base, symbol=None, value=HoleValue.DATA) - - if reuse_trampoline: - return new_hole - - self.disassembly += [ - f"{base + 4 * 0:x}: 58000048 ldr x8, 8", - f"{base + 4 * 1:x}: d61f0100 br x8", - f"{base + 4 * 2:x}: 00000000", - f"{base + 4 * 2:016x}: R_AARCH64_ABS64 {hole.symbol}", - f"{base + 4 * 3:x}: 00000000", - ] - for code in [ - 0x58000048.to_bytes(4, sys.byteorder), - 0xD61F0100.to_bytes(4, sys.byteorder), - 0x00000000.to_bytes(4, sys.byteorder), - 0x00000000.to_bytes(4, sys.byteorder), - ]: - self.body.extend(code) - self.holes.append(hole.replace(offset=base + 8, kind="R_AARCH64_ABS64")) - self.trampolines[hole.symbol] = base - return new_hole - def remove_jump(self, *, alignment: int = 1) -> None: """Remove a zero-length continuation jump, if it exists.""" hole = max(self.holes, key=lambda hole: hole.offset) @@ -282,8 +250,14 @@ class StencilGroup: default_factory=dict, init=False ) _got: dict[str, int] = dataclasses.field(default_factory=dict, init=False) - - def process_relocations(self, *, alignment: int = 1) -> None: + _trampolines: set[int] = dataclasses.field(default_factory=set, init=False) + + def process_relocations( + self, + known_symbols: dict[str, int], + *, + alignment: int = 1, + ) -> None: """Fix up all GOT and internal relocations for this stencil group.""" for hole in self.code.holes.copy(): if ( @@ -291,9 +265,17 @@ class StencilGroup: in {"R_AARCH64_CALL26", "R_AARCH64_JUMP26", "ARM64_RELOC_BRANCH26"} and hole.value is HoleValue.ZERO ): - new_hole = self.data.emit_aarch64_trampoline(hole, alignment) - self.code.holes.remove(hole) - self.code.holes.append(new_hole) + hole.func = "patch_aarch64_trampoline" + hole.need_state = True + assert hole.symbol is not None + if hole.symbol in known_symbols: + ordinal = known_symbols[hole.symbol] + else: + ordinal = len(known_symbols) + known_symbols[hole.symbol] = ordinal + self._trampolines.add(ordinal) + hole.addend = ordinal + hole.symbol = None self.code.remove_jump(alignment=alignment) self.code.pad(alignment) self.data.pad(8) @@ -348,9 +330,20 @@ class StencilGroup: ) self.data.body.extend([0] * 8) + def _get_trampoline_mask(self) -> str: + bitmask: int = 0 + trampoline_mask: list[str] = [] + for ordinal in self._trampolines: + bitmask |= 1 << ordinal + while bitmask: + word = bitmask & ((1 << 32) - 1) + trampoline_mask.append(f"{word:#04x}") + bitmask >>= 32 + return "{" + ", ".join(trampoline_mask) + "}" + def as_c(self, opname: str) -> str: """Dump this hole as a StencilGroup initializer.""" - return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}}}" + return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}" def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]: diff --git a/Tools/jit/_targets.py b/Tools/jit/_targets.py index b6c0e79..5eb316e 100644 --- a/Tools/jit/_targets.py +++ b/Tools/jit/_targets.py @@ -44,6 +44,7 @@ class _Target(typing.Generic[_S, _R]): stable: bool = False debug: bool = False verbose: bool = False + known_symbols: dict[str, int] = dataclasses.field(default_factory=dict) def _compute_digest(self, out: pathlib.Path) -> str: hasher = hashlib.sha256() @@ -95,7 +96,9 @@ class _Target(typing.Generic[_S, _R]): if group.data.body: line = f"0: {str(bytes(group.data.body)).removeprefix('b')}" group.data.disassembly.append(line) - group.process_relocations(alignment=self.alignment) + group.process_relocations( + known_symbols=self.known_symbols, alignment=self.alignment + ) return group def _handle_section(self, section: _S, group: _stencils.StencilGroup) -> None: @@ -231,7 +234,7 @@ class _Target(typing.Generic[_S, _R]): if comment: file.write(f"// {comment}\n") file.write("\n") - for line in _writer.dump(stencil_groups): + for line in _writer.dump(stencil_groups, self.known_symbols): file.write(f"{line}\n") try: jit_stencils_new.replace(jit_stencils) diff --git a/Tools/jit/_writer.py b/Tools/jit/_writer.py index 9d11094..7b99d10 100644 --- a/Tools/jit/_writer.py +++ b/Tools/jit/_writer.py @@ -2,17 +2,24 @@ import itertools import typing +import math import _stencils -def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]: +def _dump_footer( + groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int] +) -> typing.Iterator[str]: + symbol_mask_size = max(math.ceil(len(symbols) / 32), 1) + yield f'static_assert(SYMBOL_MASK_WORDS >= {symbol_mask_size}, "SYMBOL_MASK_WORDS too small");' + yield "" yield "typedef struct {" yield " void (*emit)(" yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor," - yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[]);" + yield " const _PyUOpInstruction *instruction, jit_state *state);" yield " size_t code_size;" yield " size_t data_size;" + yield " symbol_mask trampoline_mask;" yield "} StencilGroup;" yield "" yield f"static const StencilGroup trampoline = {groups['trampoline'].as_c('trampoline')};" @@ -23,13 +30,18 @@ def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[s continue yield f" [{opname}] = {group.as_c(opname)}," yield "};" + yield "" + yield f"static const void * const symbols_map[{max(len(symbols), 1)}] = {{" + for symbol, ordinal in symbols.items(): + yield f" [{ordinal}] = &{symbol}," + yield "};" def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator[str]: yield "void" yield f"emit_{opname}(" yield " unsigned char *code, unsigned char *data, _PyExecutorObject *executor," - yield " const _PyUOpInstruction *instruction, uintptr_t instruction_starts[])" + yield " const _PyUOpInstruction *instruction, jit_state *state)" yield "{" for part, stencil in [("code", group.code), ("data", group.data)]: for line in stencil.disassembly: @@ -58,8 +70,10 @@ def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator yield "" -def dump(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]: +def dump( + groups: dict[str, _stencils.StencilGroup], symbols: dict[str, int] +) -> typing.Iterator[str]: """Yield a JIT compiler line-by-line as a C header file.""" for opname, group in sorted(groups.items()): yield from _dump_stencil(opname, group) - yield from _dump_footer(groups) + yield from _dump_footer(groups, symbols) -- cgit v0.12