Skip to content

Commit

Permalink
Address Brandt's feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
diegorusso committed Sep 19, 2024
1 parent 23e8697 commit 7bdf9f0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 20 deletions.
28 changes: 16 additions & 12 deletions Tools/jit/_stencils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,14 @@ class StencilGroup:
default_factory=dict, init=False
)
_got: dict[str, int] = dataclasses.field(default_factory=dict, init=False)
trampolines: set[int] = dataclasses.field(default_factory=set, init=False)

def process_relocations(self, *, alignment: int = 1) -> None:
_trampolines: set[int] = dataclasses.field(default_factory=set, init=False)

def process_relocations(
self,
known_symbols: dict[str | None, int],
*,
alignment: int = 1,
) -> None:
"""Fix up all GOT and internal relocations for this stencil group."""
for hole in self.code.holes.copy():
if (
Expand All @@ -268,7 +273,7 @@ def process_relocations(self, *, alignment: int = 1) -> None:
else:
ordinal = len(known_symbols)
known_symbols[hole.symbol] = ordinal
self.trampolines.add(ordinal)
self._trampolines.add(ordinal)
hole.addend = ordinal
hole.symbol = None
self.code.remove_jump(alignment=alignment)
Expand Down Expand Up @@ -328,18 +333,17 @@ def _emit_global_offset_table(self) -> None:
def _get_trampoline_mask(self) -> str:
bitmask: int = 0
trampoline_mask: list[str] = []
for ordinal in self.trampolines:
for ordinal in self._trampolines:
bitmask |= 1 << ordinal
if bitmask:
trampoline_mask = [
f"0x{(bitmask >> i*32) & ((1 << 32) - 1):x}"
for i in range(0, SYMBOL_MASK_SIZE)
]
return ", ".join(trampoline_mask)
while bitmask:
word = bitmask & ((1 << 32) - 1)
trampoline_mask.append(f"{word:#04x}")
bitmask >>= 32
return "{" + ", ".join(trampoline_mask) + "}"

def as_c(self, opname: str) -> str:
"""Dump this hole as a StencilGroup initializer."""
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {{{self._get_trampoline_mask()}}}}}"
return f"{{emit_{opname}, {len(self.code.body)}, {len(self.data.body)}, {self._get_trampoline_mask()}}}"


def symbol_to_value(symbol: str) -> tuple[HoleValue, str | None]:
Expand Down
7 changes: 5 additions & 2 deletions Tools/jit/_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class _Target(typing.Generic[_S, _R]):
stable: bool = False
debug: bool = False
verbose: bool = False
known_symbols: dict[str | None, int] = dataclasses.field(default_factory=dict)

def _compute_digest(self, out: pathlib.Path) -> str:
hasher = hashlib.sha256()
Expand Down Expand Up @@ -95,7 +96,9 @@ async def _parse(self, path: pathlib.Path) -> _stencils.StencilGroup:
if group.data.body:
line = f"0: {str(bytes(group.data.body)).removeprefix('b')}"
group.data.disassembly.append(line)
group.process_relocations(alignment=self.alignment)
group.process_relocations(
known_symbols=self.known_symbols, alignment=self.alignment
)
return group

def _handle_section(self, section: _S, group: _stencils.StencilGroup) -> None:
Expand Down Expand Up @@ -231,7 +234,7 @@ def build(
if comment:
file.write(f"// {comment}\n")
file.write("\n")
for line in _writer.dump(stencil_groups):
for line in _writer.dump(stencil_groups, self.known_symbols):
file.write(f"{line}\n")
try:
jit_stencils_new.replace(jit_stencils)
Expand Down
18 changes: 12 additions & 6 deletions Tools/jit/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import itertools
import typing
import math

import _stencils


def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
yield f"typedef uint32_t SymbolMask[{_stencils.SYMBOL_MASK_SIZE}];"
def _dump_footer(
groups: dict[str, _stencils.StencilGroup], symbols: dict[str | None, int]
) -> typing.Iterator[str]:
symbol_mask_size = math.ceil(len(symbols) / 32)
yield f"typedef uint32_t SymbolMask[{symbol_mask_size}];"
yield ""
yield "typedef struct {"
yield " void (*emit)("
Expand All @@ -27,8 +31,8 @@ def _dump_footer(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[s
yield f" [{opname}] = {group.as_c(opname)},"
yield "};"
yield ""
yield f"static const void * const symbols_map[{max(len(_stencils.known_symbols), 1)}] = {{"
for symbol, ordinal in _stencils.known_symbols.items():
yield f"static const void * const symbols_map[{max(len(symbols), 1)}] = {{"
for symbol, ordinal in symbols.items():
yield f" [{ordinal}] = &{symbol},"
yield "};"

Expand Down Expand Up @@ -66,8 +70,10 @@ def _dump_stencil(opname: str, group: _stencils.StencilGroup) -> typing.Iterator
yield ""


def dump(groups: dict[str, _stencils.StencilGroup]) -> typing.Iterator[str]:
def dump(
groups: dict[str, _stencils.StencilGroup], symbols: dict[str | None, int]
) -> typing.Iterator[str]:
"""Yield a JIT compiler line-by-line as a C header file."""
for opname, group in sorted(groups.items()):
yield from _dump_stencil(opname, group)
yield from _dump_footer(groups)
yield from _dump_footer(groups, symbols)

0 comments on commit 7bdf9f0

Please sign in to comment.