Skip to content

Commit

Permalink
Use Enum to make mypy understand that unlimited is a singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
mthuurne committed Mar 13, 2023
1 parent 6128803 commit b6a26bf
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 22 deletions.
3 changes: 1 addition & 2 deletions src/retroasm/asm_formatter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping
from typing import cast

from .asm_directives import DataDirective, OriginDirective, StringDirective
from .expression import IntLiteral
Expand Down Expand Up @@ -41,7 +40,7 @@ def hex_value(self, value: int, width: int) -> str:

def hex_range(self, start: int, end: Width, width: int) -> str:
start_str = self.hex_value(start, width)
end_str = "" if end is unlimited else self.hex_value(cast(int, end), width)
end_str = "" if end is unlimited else self.hex_value(end, width)
return f"{start_str}-{end_str}"

def _string_literal(self, string: bytes) -> Iterator[str]:
Expand Down
2 changes: 1 addition & 1 deletion src/retroasm/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def compute_mask(cls, exprs: Iterable[Expression]) -> int:
# Compute bit mask for maximum combined value.
cmb_mask = -1 << cmb_start
if cmb_value is not unlimited:
cmb_mask &= (1 << cast(int, cmb_value).bit_length()) - 1
cmb_mask &= (1 << cmb_value.bit_length()) - 1
result |= cmb_mask
return result

Expand Down
2 changes: 1 addition & 1 deletion src/retroasm/expression_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _convert_reference_slice(
if ref_width is unlimited:
end_expr: Expression | None = None
else:
end_expr = IntLiteral(cast(int, ref_width))
end_expr = IntLiteral(ref_width)
else:
end_expr = build_expression(end_node, namespace)
width_expr: Expression | None = (
Expand Down
4 changes: 2 additions & 2 deletions src/retroasm/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def __str__(self) -> str:
end = (
""
if width is unlimited
else str(AddOperator(offset, IntLiteral(cast(int, width))))
else str(AddOperator(offset, IntLiteral(width)))
)
return f"{self._bits}[{start}:{end}]"

Expand Down Expand Up @@ -460,7 +460,7 @@ def decode_int(encoded: Expression, typ: IntType) -> Expression:
if typ.signed:
width = typ.width
if width is not unlimited:
return SignExtension(encoded, cast(int, width))
return SignExtension(encoded, width)
return encoded


Expand Down
40 changes: 24 additions & 16 deletions src/retroasm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@

from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import NoReturn, TypeAlias, cast
from enum import Enum
from typing import NoReturn, TypeAlias

from .utils import Singleton, Unique
from .utils import Unique


class DoesNotExist:
"""Used in type annotations when no type is allowed to match."""


class Unlimited(metaclass=Singleton):
class Unlimited(Enum):
"""
Width value for arbitrary-width integer types.
Compares as infinity: larger than any integer.
"""

__slots__ = ()
instance = "unlimited"
"""Singleton instance."""

def __repr__(self) -> str:
return "unlimited"
Expand All @@ -31,41 +33,47 @@ def __lt__(self, other: int | Unlimited) -> bool:
elif isinstance(other, int):
return False
else:
return NotImplemented
return NotImplemented # type: ignore[unreachable]

def __le__(self, other: int | Unlimited) -> bool:
if self is other:
return True
elif isinstance(other, int):
return False
else:
return NotImplemented
return NotImplemented # type: ignore[unreachable]

def __gt__(self, other: int | Unlimited) -> bool:
if self is other:
return False
elif isinstance(other, int):
return True
else:
return NotImplemented
return NotImplemented # type: ignore[unreachable]

def __ge__(self, other: int | Unlimited) -> bool:
if self is other:
return True
elif isinstance(other, int):
return True
else:
return NotImplemented
return NotImplemented # type: ignore[unreachable]

def __add__(self, other: int | Unlimited) -> Unlimited:
if self is other:
return self
elif isinstance(other, int):
return self
else:
return NotImplemented
return NotImplemented # type: ignore[unreachable]

__radd__ = __add__
def __radd__(self, other: int | Unlimited) -> Unlimited:
if self is other:
return self
elif isinstance(other, int):
return self
else:
return NotImplemented # type: ignore[unreachable]

def __sub__(self, other: int) -> Unlimited:
if isinstance(other, int):
Expand All @@ -77,13 +85,13 @@ def __rsub__(self, other: DoesNotExist) -> NoReturn:
raise ArithmeticError('Cannot subtract "unlimited"')


unlimited = Unlimited()
unlimited = Unlimited.instance

Width: TypeAlias = int | Unlimited


def mask_for_width(width: Width) -> int:
return -1 if width is unlimited else (1 << cast(int, width)) - 1
return -1 if width is unlimited else (1 << width) - 1


def width_for_mask(mask: int) -> Width:
Expand Down Expand Up @@ -238,12 +246,12 @@ def __str__(self) -> str:
if self._width is unlimited:
return "int"
else:
return f"s{cast(int, self._width):d}"
return f"s{self._width:d}"
else:
if self._width is unlimited:
return "uint"
else:
return f"u{cast(int, self._width):d}"
return f"u{self._width:d}"

def check_range(self, value: int) -> None:
"""
Expand All @@ -258,13 +266,13 @@ def check_range(self, value: int) -> None:
elif width == 0:
if value == 0:
return
elif -1 << (cast(int, width) - 1) <= value < 1 << (cast(int, width) - 1):
elif -1 << (width - 1) <= value < 1 << (width - 1):
return
else:
if width is unlimited:
if value >= 0:
return
elif 0 <= value < 1 << cast(int, width):
elif 0 <= value < 1 << width:
return
raise ValueError(f"value {value:d} does not fit in type {self}")

Expand Down

0 comments on commit b6a26bf

Please sign in to comment.