From ca97f5ab5413238da061ab076ccfb5895cb56ae4 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Thu, 12 Nov 2020 16:43:16 -0800 Subject: [PATCH] feature: use BV class as python value for Cryptol bit sequences (#116) * feature: use BV class as python value for Cryptol bit sequences * fix: correct subtraction bug and add test * chore: fix typos, delete misc commented out code * chore: use Sphinx-style docstrings * test: add cryptol python unit tests to CI * chore: rename module bv to bitvector * chore: rename test for cryptol.bitvector * feature: allow BV creation from size/value or a BitVector * feature: add widen method to BV * install python deps for CI * chore: fix typo in widen docstrings * feature: with_bits method for replacing segments of BVs * chore: tweak all python tests to use cabal, fix mypy test failures * feature: support BV keyword args for construction * try tweaking the mypy version in requirements.txt... --- .github/workflows/ci.yml | 6 + python/cryptol/__init__.py | 14 +- python/cryptol/bitvector.py | 478 ++++++++++++++++++ python/cryptol/cryptoltypes.py | 12 +- python/cryptol/test/__init__.py | 2 + python/cryptol/test/test_bitvector.py | 446 ++++++++++++++++ python/hs-test/Main.hs | 10 +- python/requirements.txt | 2 +- .../src/Test/Tasty/HUnit/ScriptExit.hs | 9 + 9 files changed, 971 insertions(+), 8 deletions(-) create mode 100644 python/cryptol/bitvector.py create mode 100644 python/cryptol/test/__init__.py create mode 100644 python/cryptol/test/test_bitvector.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2275856..f8359ac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,6 +40,10 @@ jobs: run: cabal update # Build macaw-base dependencies and crucible separately just so later # steps are less verbose and major dependency failures are separate. + - name: Install python dependencies + working-directory: ./python + run: | + if [ -f requirements.txt ]; then pip3 install -r requirements.txt; fi - name: Build run: | cabal build all @@ -47,3 +51,5 @@ jobs: run: cabal test argo - name: Cabal file-echo-api tests run: cabal test file-echo-api + - name: Python unit tests + run: cabal test python diff --git a/python/cryptol/__init__.py b/python/cryptol/__init__.py index dd9d0a8..89db6d2 100644 --- a/python/cryptol/__init__.py +++ b/python/cryptol/__init__.py @@ -13,7 +13,7 @@ from argo.interaction import HasProtocolState from argo.connection import DynamicSocketProcess, ServerConnection, ServerProcess, StdIOProcess from . import cryptoltypes - +from cryptol.bitvector import BV __all__ = ['cryptoltypes'] @@ -42,6 +42,7 @@ def fail_with(x : Exception) -> NoReturn: def from_cryptol_arg(val : Any) -> Any: + """Return the canonical Python value for a Cryptol JSON value.""" if isinstance(val, bool): return val elif isinstance(val, int): @@ -58,13 +59,18 @@ def from_cryptol_arg(val : Any) -> Any: return [from_cryptol_arg(v) for v in val['data']] elif tag == 'bits': enc = val['encoding'] + size = val['width'] if enc == 'base64': - data = base64.b64decode(val['data'].encode('ascii')) + n = int.from_bytes( + base64.b64decode(val['data'].encode('ascii')), + byteorder='big') elif enc == 'hex': - data = bytes.fromhex(extend_hex(val['data'])) + n = int.from_bytes( + bytes.fromhex(extend_hex(val['data'])), + byteorder='big') else: raise ValueError("Unknown encoding " + str(enc)) - return data + return BV(size, n) else: raise ValueError("Unknown expression tag " + tag) else: diff --git a/python/cryptol/bitvector.py b/python/cryptol/bitvector.py new file mode 100644 index 0000000..058ac2a --- /dev/null +++ b/python/cryptol/bitvector.py @@ -0,0 +1,478 @@ + +from functools import reduce +from typing import Any, List, Union, Optional, overload +from BitVector import BitVector #type: ignore + + +class BV: + """A class representing a cryptol bit vector (i.e., a sequence of bits). + + ``BV(size : int, value : int)`` will create a ``BV`` of length ``size`` and bits corresponding + to the unsigned integer representation of ``value`` (N.B., ``0 <= size <= value <= 2 ** size - 1`` + must evaluate to ``True`` or an error will be raised). + + N.B., the ``size`` and ``value`` arguments can be passed positionally or by name: + + ``BV(8,0xff) == BV(size=8, value=0xff) == BV(value=0xff, size=8)`` + + ``BV(bv : BitVector)`` will create an equivalent ``BV`` to the given ``BitVector`` value. + """ + __size : int + __value : int + + def __init__(self, size : Union[int, BitVector], value : Optional[int] = None) -> None: + """Initialize a ``BV`` from a ``BitVector`` or from size and value nonnegative integers.""" + if value is not None: + if not isinstance(size, int) or size < 0: + raise ValueError(f'`size` parameter to BV must be a nonnegative integer but was given {size!r}.') + self.__size = size + if not isinstance(value, int): + raise ValueError(f'{value!r} is not an integer value to initilize a bit vector of size {self.__size!r} with.') + self.__value = value + elif not isinstance(size, BitVector): + raise ValueError(f'BV can only be created from a single value when that value is a BitVector, but got {size!r}') + else: + self.__size = len(size) + self.__value = int(size) + if self.__value < 0 or self.__value.bit_length() > self.__size: + raise ValueError(f'{self.__value!r} is not representable as an unsigned integer with {self.__size!r} bits.') + + def hex(self) -> str: + """Return the (padded) hexadecimal string for the unsigned integer this ``BV`` represents. + + Note: padding is determined by ``self.size()``, rounding up a single digit + for widths not evenly divisible by 4.""" + hex_str_width = 2 + (self.__size // 4) + (0 if (self.__size % 4 == 0) else 1) + return format(self.__value, f'#0{hex_str_width!r}x') + + def __repr__(self) -> str: + return f"BV({self.__size!r}, {self.hex()})" + + @overload + def __getitem__(self, key : int) -> bool: + pass + @overload + def __getitem__(self, key : slice) -> 'BV': + pass + def __getitem__(self, key : Union[int, slice]) -> Union[bool, 'BV']: + """``BV`` indexing and slicing. + + :param key: If ``key`` is an integer, ``True`` is returned if the corresponding bit + is non-zero, else ``False`` is returned. If ``key`` is a ``slice`` (i.e., ``[high:low]``) + it specifies a sub-``BV`` of ``self`` corresponding to the bits from + index ``low`` up until (but not including) index ``high``. + + Examples: + + ``BV(8,0b00000010)[0] == False`` + + ``BV(8,0b00000010)[1] == True`` + + ``BV(8,0b00000010)[4:0] == BV(4,0b0010)`` + """ + if isinstance(key, int): + if key < 0 or key >= self.__size: + raise ValueError(f'{key!r} is not a valid index for {self!r}') + else: + return (self.__value & (1 << key)) != 0 + if isinstance(key, slice): + high = key.start + low = key.stop + if not isinstance(low, int): raise ValueError(f'Expected BV slice to use non-negative integer indices, but got low index of {low!r}.') + if low < 0 and low > self.__size: raise ValueError(f'Expected BV slice low index to be >= 0 and <= the BV size (i.e., {self.__size!r}) but got {low!r}.') + if not isinstance(high, int): raise ValueError(f'Expected BV slice to use non-negative integer indices, but got high index of {high!r}.') + if low > high: raise ValueError(f'BV slice low index {low!r} is larger than the high index {high!r}.') + if high > self.__size: raise ValueError(f'BV slice high index {high!r} is larger than the BV size (i.e., {self.__size!r}).') + if key.step: raise ValueError(f'BV slicing expects a step of None, but found {key.step!r}') + new_sz = high - low + return BV(new_sz, (self.__value >> low) & ((2 ** new_sz) - 1)) + else: + raise ValueError(f'{key!r} is not a valid BV index or slice.') + + def size(self) -> int: + """Size of the ``BV`` (i.e., the available "bit width").""" + return self.__size + + def widen(self, n : int) -> 'BV': + """Returns a "widened" version of ``self``, i.e. ``BV(self.size() + n, self.value())``. + + Args: + n (int): How many bits wider the returned ``BV`` should be than ``self`` (must be nonnegative). + """ + if not isinstance(n, int) or n < 0: #type: ignore + raise ValueError(f'``widen`` expects a nonnegative integer, but got {n!r}') + else: + return BV(self.__size + n, self.__value) + + def value(self) -> int: + """The unsigned integer interpretation of the ``self``.""" + return self.__value + + def __concat_single(self, other : 'BV') -> 'BV': + if isinstance(other, BV): + return BV(self.__size + other.__size, (self.__value << other.__size) + other.__value) + else: + raise ValueError(f'Cannot concat BV with {other!r}') + + def concat(self, *others : 'BV') -> 'BV': + """Concatenate the given ``BV``s to the right of ``self``. + + :param others: The BVs to concatenate onto the right side of ``self`` in order. + + Returns: + BV: a bit vector with the bits from ``self`` on the left and the bits from + ``others`` in order on the right. + """ + return reduce(lambda acc, b: acc.__concat_single(b), others, self) + + @staticmethod + def join(*bs : 'BV') -> 'BV': + """Concatenate the given ``BV``s in order. + + :param bs: The ``BV``s to concatenate in order. + + Returns: + BV: A bit vector with the bits from ``others`` in order. + """ + return reduce(lambda acc, b: acc.__concat_single(b), bs, BV(0,0)) + + def zero(self) -> 'BV': + """The zero bit vector for ``self``'s size (i.e., ``BV(self.size(), 0)``).""" + return BV(self.size() ,0) + + def to_int(self) -> int: + """Return the unsigned integer the ``BV`` represents (equivalent to ``self.value()``).""" + return self.__value + + def to_signed_int(self) -> int: + """Return the signed (i.e., two's complement) integer the ``BV`` represents.""" + if not self.msb(): + n = self.__value + else: + n = 0 - ((2 ** self.__size) - self.__value) + return n + + def msb(self) -> bool: + """Returns ``True`` if the most significant bit is 1, else returns ``False``.""" + if self.__size == 0: + raise ValueError("0-length BVs have no most significant bit.") + else: + return self[self.__size - 1] + + def lsb(self) -> bool: + """Returns ``True`` if the least significant bit is 1, else returns ``False``.""" + if self.__size == 0: + raise ValueError("0-length BVs have no least significant bit.") + else: + return self[0] + + + def __eq__(self, other : Any) -> bool: + """Returns ``True`` if ``other`` is also a ``BV`` of the same size and value, else returns ``False``.""" + if isinstance(other, BV): + return self.__size == other.__size and self.__value == other.__value + else: + return False + + def __index__(self) -> int: + """Equivalent to ``self.value()``.""" + return self.__value + + def __len__(self) -> int: + """Equivalent to ``self.size()``.""" + return self.__size + + def __bytes__(self) -> bytes: + """Returns the ``bytes`` value equivalent to ``self.value()``.""" + byte_len = (self.__size // 8) + (0 if self.__size % 8 == 0 else 1) + return self.__value.to_bytes(byte_len, 'big') + + + def split(self, size : int) -> List['BV']: + """Split ``self`` into a list of ``BV``s of length ``size``. + + :param size: Size of segments to partition ``self`` into (must evently divide ``self.size()``). + """ + if not isinstance(size, int) or size <= 0: #type: ignore + raise ValueError(f'`size` argument to splits must be a positive integer, got {size!r}') + if not self.size() % size == 0: + raise ValueError(f'{self!r} is not divisible into equal parts of size {size!r}') + mask = (1 << size) - 1 + return [BV(size, (self.__value >> (i * size)) & mask) + for i in range(self.size() // size - 1, -1, -1)] + + + def popcount(self) -> int: + """Return the number of bits set to ``1`` in ``self``.""" + return bin(self).count("1") + + @staticmethod + def from_bytes(bs : bytes, *, size : Optional[int] =None, byteorder : str ='big') -> 'BV': + """Convert the given bytes ``bs`` into a ``BV``. + + :param bs: Bytes to convert to a ``BV``. + :param size, optional: Desired ``BV``'s size (must be large enough to represent ``bs``). The + default (i.e., ``None``) will result in a ``BV`` of size ``len(bs) * 8``. + :param byteorder, optional: Byte ordering ``bs`` should be interpreted as, defaults to + ``'big'``, ``little`` being the other acceptable value. Equivalent to the ``byteorder`` + parameter from Python's ``int.from_bytes``.""" + + if not isinstance(bs, bytes): + raise ValueError("from_bytes given not bytes value: {bs!r}") + + if not byteorder == 'little' and not byteorder == 'big': + raise ValueError("from_bytes given not bytes value: {bs!r}") + + if size == None: + return BV(len(bs) * 8, int.from_bytes(bs, byteorder=byteorder)) + elif isinstance(size, int) and size >= len(bs) * 8: + return BV(size, int.from_bytes(bs, byteorder=byteorder)) + else: + raise ValueError(f'from_bytes given invalid bit size {size!r} for bytes {bs!r}') + + def with_bit(self, index : int, set_bit : bool) -> 'BV': + """Return a ``BV`` identical to ``self`` but with the bit at ``index`` set to + ``1`` if ``set_bit == True``, else ``0``.""" + if index < 0 or index >= self.__size: + raise ValueError(f'{index!r} is not a valid bit index for {self!r}') + if set_bit: + mask = (1 << index) + return BV(self.__size, self.__value | mask) + else: + mask = (2 ** self.__size - 1) ^ (1 << index) + return BV(self.__size, self.__value & mask) + + + def with_bits(self, low : int, bits : 'BV') -> 'BV': + """Return a ``BV`` identical to ``self`` but with the bits from ``low`` to + ``low + bits.size() - 1`` replaced by the bits from ``bits``.""" + if not isinstance(low, int) or low < 0 or low >= self.__size: # type: ignore + raise ValueError(f'{low!r} is not a valid low bit index for {self!r}') + elif not isinstance(bits, BV): + raise ValueError(f'Expected a BV but got {bits!r}') + elif low + bits.size() > self.__size: + raise ValueError(f'{bits!r} does not fit within {self!r} when starting from low bit index {low!r}.') + else: + wider = self.size() - (low + bits.size()) + mask = (BV(bits.size(), 2 ** bits.size() - 1) << low).widen(wider) + return (self & ~mask) | (bits << low).widen(wider) + + def to_bytes(self) -> bytes: + """Convert the given ``BV`` into a python native ``bytes`` value. + + Note: equivalent to bytes(_).""" + + return self.__bytes__() + + def __mod_if_overflow(self, value : int) -> int: + n : int = value if value.bit_length() <= self.__size \ + else (value % (2 ** self.__size)) + return n + + def __add__(self, other : Union[int, 'BV']) -> 'BV': + """Addition bewteen ``BV``s of equal size or bewteen a ``BV`` and a nonnegative + integer whose value is expressible with the ``BV`` parameter's size.""" + if isinstance(other, BV): + if self.__size == other.__size: + return BV( + self.__size, + self.__mod_if_overflow(self.__value + other.__value)) + else: + raise ValueError(self.__unequal_len_op_error_msg("+", other)) + elif isinstance(other, int): + return BV( + self.__size, + self.__mod_if_overflow(self.__value + other)) + else: + raise ValueError(f'Cannot add {self!r} with {other!r}.') + + def __radd__(self, other : int) -> 'BV': + if isinstance(other, int): + return BV(self.__size, self.__mod_if_overflow(self.__value + other)) + else: + raise ValueError(f'Cannot add {self!r} with {other!r}.') + + def __and__(self, other : Union['BV', int]) -> 'BV': + """Bitwise 'logical and' bewteen ``BV``s of equal size or bewteen a ``BV`` and a nonnegative + integer whose value is expressible with the ``BV`` parameter's size.""" + if isinstance(other, BV): + if self.__size == other.__size: + return BV(self.__size, self.__value & other.__value) + else: + raise ValueError(self.__unequal_len_op_error_msg("&", other)) + elif isinstance(other, int): + return BV(self.__size, self.__value & other) + else: + raise ValueError(f'Cannot bitwise and {self!r} with value {other!r}.') + + def __rand__(self, other : int) -> 'BV': + if isinstance(other, int): + return BV(self.__size, self.__value & other) + else: + raise ValueError(f'Cannot bitwise and {self!r} with value {other!r}.') + + def __or__(self, other : Union['BV', int]) -> 'BV': + """Bitwise 'logical or' bewteen ``BV``s of equal size or bewteen a ``BV`` and a nonnegative + integer whose value is expressible with the ``BV`` parameter's size.""" + if isinstance(other, BV): + if self.__size == other.__size: + return BV(self.__size, self.__value | other.__value) + else: + raise ValueError(self.__unequal_len_op_error_msg("|", other)) + elif isinstance(other, int): + return BV(self.__size, self.__value | other) + else: + raise ValueError(f'Cannot bitwise or {self!r} with value {other!r}.') + + def __ror__(self, other : int) -> 'BV': + if isinstance(other, int): + return BV(self.__size, self.__value | other) + else: + raise ValueError(f'Cannot bitwise or {self!r} with value {other!r}.') + + def __xor__(self, other : Union['BV', int]) -> 'BV': + """Bitwise 'logical xor' bewteen ``BV``s of equal size or bewteen a ``BV`` and a nonnegative + integer whose value is expressible with the ``BV`` parameter's size.""" + if isinstance(other, BV): + if self.__size == other.__size: + return BV(self.__size, self.__value ^ other.__value) + else: + raise ValueError(self.__unequal_len_op_error_msg("^", other)) + elif isinstance(other, int): + return BV(self.__size, self.__value ^ other) + else: + raise ValueError(f'Cannot bitwise xor {self!r} with value {other!r}.') + + def __rxor__(self, other : int) -> 'BV': + if isinstance(other, int): + return BV(self.__size, self.__value ^ other) + else: + raise ValueError(f'Cannot bitwise xor {self!r} with value {other!r}.') + + + def __invert__(self) -> 'BV': + """Returns the bitwise inversion of ``self``.""" + return BV(self.__size, (1 << self.__size) - 1 - self.__value) + + @staticmethod + def __from_signed_int(size: int, val : int) -> 'BV': + excl_max = 2 ** size + if (size == 0): + return BV(0,0) + elif val >= 0: + return BV(size, val % excl_max) + else: + return BV(size, ((excl_max - 1) & ~(abs(val + 1))) % excl_max) + + @staticmethod + def from_signed_int(size: int, value : int) -> 'BV': + """Returns the ``BV`` corrsponding to the ``self.size()``-bit two's complement representation of ``value``. + + :param size: Bit width of desired ``BV``. + :param value: Integer returned ``BV`` is derived from (must be in range + ``-(2 ** (size - 1))`` to ``(2 ** (size - 1) - 1)`` inclusively). + """ + if size == 0: + raise ValueError("There are no two's complement 0-bit vectors.") + max_val = 2 ** (size - 1) - 1 + min_val = -(2 ** (size - 1)) + if value < min_val or value > max_val: + raise ValueError(f'{value!r} is not in range [{min_val!r},{max_val!r}].') + else: + return BV.__from_signed_int(size, value) + + def __sub__(self, other : Union[int, 'BV']) -> 'BV': + """Subtraction bewteen ``BV``s of equal size or bewteen a ``BV`` and a nonnegative + integer whose value is expressible with the ``BV`` parameter's size.""" + if isinstance(other, BV): + if self.__size == other.__size: + if self.__size == 0: + return self + else: + return BV.__from_signed_int( + self.__size, + self.to_signed_int() - other.to_signed_int()) + else: + raise ValueError(self.__unequal_len_op_error_msg("-", other)) + elif isinstance(other, int): + self.__check_int_size(other) + if self.__size == 0: + return self + else: + return BV.__from_signed_int( + self.__size, + self.to_signed_int() - other) + else: + raise ValueError(f'Cannot subtract {other!r} from {self!r}.') + + def __rsub__(self, other : int) -> 'BV': + if isinstance(other, int): + self.__check_int_size(other) + if self.__size == 0: + return self + else: + return BV.__from_signed_int( + self.__size, + other - self.to_signed_int()) + else: + raise ValueError(f'Cannot subtract {self!r} from {other!r}.') + + + def __mul__(self, other: Union[int, 'BV']) -> 'BV': + """Multiplication bewteen ``BV``s of equal size or bewteen a ``BV`` and a nonnegative + integer whose value is expressible with the ``BV`` parameter's size.""" + if isinstance(other, BV): + if self.__size == other.__size: + return BV( + self.__size, + self.__mod_if_overflow(self.__value * other.__value)) + else: + raise ValueError(self.__unequal_len_op_error_msg("*", other)) + elif isinstance(other, int): + self.__check_int_size(other) + return BV.__from_signed_int( + self.__size, + self.__mod_if_overflow(self.__value * other)) + else: + raise ValueError(f'Cannot multiply {self!r} and {other!r}.') + + def __rmul__(self, other : int) -> 'BV': + return self.__mul__(other) + + + def __lshift__(self, other : Union[int, 'BV']) -> 'BV': + """Returns the bitwise left shift of ``self``. + + :param other: Nonnegative amount to left shift ``self`` by (resulting + ``BV``'s size is ``self.size() + int(other)``)). + """ + if isinstance(other, int) or isinstance(other, BV): + n = int(other) + if n < 0: + raise ValueError(f'Cannot left shift a negative amount (i.e, {n!r}).') + return BV(self.__size + n, self.__value << n) + else: + raise ValueError(f'Shift must be specified with an integer or BV, but got {other!r}.') + + + def __rshift__(self, other : Union[int, 'BV']) -> 'BV': + """Returns the bitwise right shift of ``self``. + + :param other: Nonnegative amount to right shift ``self`` by (resulting + ``BV``'s size is ``max(0, self.size() - int(other))``). + """ + if isinstance(other, int) or isinstance(other, BV): + n = int(other) + if n < 0: + raise ValueError(f'Cannot right shift a negative amount (i.e, {n!r}).') + return BV(max(0, self.__size - n), self.__value >> n) + else: + raise ValueError(f'Shift must be specified with an integer or BV, but got {other!r}.') + + def __check_int_size(self, val : int) -> None: + if val >= (2 ** self.__size) or val < 0: + raise ValueError(f'{val!r} is not a valid unsigned {self.__size!r}-bit value.') + + + def __unequal_len_op_error_msg(self, op : str, other : 'BV') -> str: + return f'Operator `{op}` cannot be called on BV of unequal length {self!r} and {other!r}.' diff --git a/python/cryptol/cryptoltypes.py b/python/cryptol/cryptoltypes.py index cd4acf3..ee7ccbb 100644 --- a/python/cryptol/cryptoltypes.py +++ b/python/cryptol/cryptoltypes.py @@ -4,6 +4,7 @@ import base64 from math import ceil import BitVector #type: ignore +from cryptol.bitvector import BV from typing import Any, Dict, Iterable, List, NoReturn, Optional, TypeVar, Union @@ -157,6 +158,11 @@ def convert(self, val : Any) -> Any: 'encoding': 'base64', 'width': val.length(), # N.B. original length, not padded 'data': base64.b64encode(n.to_bytes(byte_width,'big')).decode('ascii')} + elif isinstance(val, BV): + return {'expression': 'bits', + 'encoding': 'hex', + 'width': val.size(), # N.B. original length, not padded + 'data': val.hex()[2:]} else: raise TypeError("Unsupported value: " + str(val)) @@ -200,6 +206,8 @@ def convert(self, val : Any) -> Any: 'data': base64.b64encode(val).decode('ascii')} elif isinstance(val, BitVector.BitVector): return CryptolType.convert(self, val) + elif isinstance(val, BV): + return CryptolType.convert(self, val) else: raise ValueError(f"Not supported as bitvector: {val!r}") @@ -343,7 +351,7 @@ def __repr__(self) -> str: class Log2(CryptolType): def __init__(self, operand : CryptolType) -> None: - self.right = operand + self.operand = operand def __str__(self) -> str: return f"(lg2 {self.operand})" @@ -353,7 +361,7 @@ def __repr__(self) -> str: class Width(CryptolType): def __init__(self, operand : CryptolType) -> None: - self.right = operand + self.operand = operand def __str__(self) -> str: return f"(width {self.operand})" diff --git a/python/cryptol/test/__init__.py b/python/cryptol/test/__init__.py new file mode 100644 index 0000000..ed84c68 --- /dev/null +++ b/python/cryptol/test/__init__.py @@ -0,0 +1,2 @@ +# import the package +import cryptol.bitvector diff --git a/python/cryptol/test/test_bitvector.py b/python/cryptol/test/test_bitvector.py new file mode 100644 index 0000000..f954ff8 --- /dev/null +++ b/python/cryptol/test/test_bitvector.py @@ -0,0 +1,446 @@ +import unittest +import random +from cryptol.bitvector import BV +from BitVector import BitVector + + +class BVBaseTest(unittest.TestCase): + """Base class for BV test cases.""" + + def assertBVEqual(self, b, size, value): + """Assert BV `b` has the specified `size` and `value`.""" + self.assertEqual(b.size(), size) + self.assertEqual(b.value(), value) + + + def assertUnOpExpected(self, op_fn, expected_fn): + """Assert `prop` holds for any BV value.""" + for width in range(0, 129): + max_val = 2 ** width - 1 + for i in range(0, 100): + b = BV(width, random.randint(0, max_val)) + # Put `b` in the assertion so we can see its value + # on failed test cases. + self.assertEqual((b, op_fn(b)), (b, expected_fn(b))) + + def assertBinOpExpected(self, op_fn, expected_fn): + """Assert `prop` holds for any BV value.""" + for width in range(0, 129): + max_val = 2 ** width - 1 + for i in range(0, 100): + b1 = BV(width, random.randint(0, max_val)) + b2 = BV(width, random.randint(0, max_val)) + # Put `b1` and `b2` in the assertion so we can + # see its value on failed test cases. + self.assertEqual((b1,b2,op_fn(b1, b2)), (b1,b2,expected_fn(b1, b2))) + + +class BVBasicTests(BVBaseTest): + def test_constructor1(self): + b = BV(BitVector(intVal = 0, size = 8)) + self.assertBVEqual(b, 8, 0) + b = BV(BitVector(intVal = 42, size = 8)) + self.assertBVEqual(b, 8, 42) + + def test_constructor2(self): + b = BV(0,0) + self.assertBVEqual(b, 0, 0) + b = BV(value=16,size=8) + self.assertBVEqual(b, 8, 16) + b = BV(8,42) + self.assertBVEqual(b, 8, 42) + + def test_constructor_fails(self): + with self.assertRaises(ValueError): + BV(8, 256) + with self.assertRaises(ValueError): + BV(8, -1) + + def test_hex(self): + self.assertEqual(hex(BV(0,0)), "0x0") + self.assertEqual(BV(0,0).hex(), "0x0") + self.assertEqual(hex(BV(4,0)), "0x0") + self.assertEqual(BV(4,0).hex(), "0x0") + self.assertEqual(hex(BV(5,0)), "0x0") + self.assertEqual(BV(5,0).hex(), "0x00") + self.assertEqual(hex(BV(5,11)), "0xb") + self.assertEqual(BV(5,11).hex(), "0x0b") + self.assertEqual(hex(BV(8,255)), "0xff") + self.assertEqual(BV(8,255).hex(), "0xff") + self.assertEqual(hex(BV(9,255)), "0xff") + self.assertEqual(BV(9,255).hex(), "0x0ff") + + def test_repr(self): + self.assertEqual(repr(BV(0,0)), "BV(0, 0x0)") + self.assertEqual(repr(BV(9,255)), "BV(9, 0x0ff)") + + def test_int(self): + self.assertEqual(int(BV(0,0)), 0) + self.assertEqual(int(BV(9,255)), 255) + self.assertUnOpExpected( + lambda b: BV(b.size(), int(b)), + lambda b: b) + + def test_size(self): + self.assertEqual(BV(0,0).size(), 0) + self.assertEqual(BV(9,255).size(), 9) + + def test_len(self): + self.assertEqual(len(BV(0,0)), 0) + self.assertEqual(len(BV(9,255)), 9) + + def test_popcount(self): + self.assertEqual(BV(0,0).popcount(), 0) + self.assertEqual(BV(8,0).popcount(), 0) + self.assertEqual(BV(8,1).popcount(), 1) + self.assertEqual(BV(8,2).popcount(), 1) + self.assertEqual(BV(8,3).popcount(), 2) + self.assertEqual(BV(8,255).popcount(), 8) + + def test_eq(self): + self.assertEqual(BV(0,0), BV(0,0)) + self.assertEqual(BV(8,255), BV(8,255)) + self.assertTrue(BV(8,255) == BV(8,255)) + self.assertFalse(BV(8,255) == BV(8,254)) + self.assertFalse(BV(8,255) == BV(9,255)) + + def test_neq(self): + self.assertNotEqual(BV(0,0), BV(1,0)) + self.assertNotEqual(BV(0,0), 0) + self.assertNotEqual(0, BV(0,0)) + self.assertTrue(BV(0,0) != BV(1,0)) + self.assertTrue(BV(1,0) != BV(0,0)) + self.assertFalse(BV(0,0) != BV(0,0)) + self.assertTrue(BV(0,0) != 0) + self.assertTrue(0 != BV(0,0)) + + def test_widen(self): + self.assertEqual(BV(0,0).widen(8), BV(8,0)) + self.assertEqual(BV(9,255).widen(8), BV(17,255)) + + + def test_add(self): + self.assertEqual(BV(16,7) + BV(16,9), BV(16,16)) + self.assertEqual(BV(16,9) + BV(16,7), BV(16,16)) + self.assertEqual(BV(16,9) + BV(16,7) + 1, BV(16,17)) + self.assertEqual(1 + BV(16,9) + BV(16,7), BV(16,17)) + self.assertBinOpExpected( + lambda b1, b2: b1 + b2, + lambda b1, b2: BV(0,0) if b1.size() == 0 else + BV(b1.size(), (int(b1) + int(b2)) % ((2 ** b1.size() - 1) + 1))) + with self.assertRaises(ValueError): + BV(15,7) + BV(16,9) + + + def test_bitewise_and(self): + self.assertEqual(BV(0,0) & BV(0,0), BV(0,0)) + self.assertEqual(BV(8,0xff) & BV(8,0xff), BV(8,0xff)) + self.assertEqual(BV(8,0xff) & BV(8,42), BV(8,42)) + self.assertEqual(BV(16,7) & BV(16,9), BV(16,1)) + self.assertEqual(BV(16,9) & BV(16,7), BV(16,1)) + self.assertEqual(BV(16,9) & BV(16,7) & 1, BV(16,1)) + self.assertEqual(1 & BV(16,9) & BV(16,7), BV(16,1)) + self.assertUnOpExpected( + lambda b: b & 0, + lambda b: BV(b.size(), 0)) + self.assertUnOpExpected( + lambda b: b & (2 ** b.size() - 1), + lambda b: b) + self.assertBinOpExpected( + lambda b1, b2: b1 & b2, + lambda b1, b2: BV(b1.size(), int(b1) & int(b2))) + with self.assertRaises(ValueError): + BV(15,7) & BV(16,9) + + def test_bitewise_not(self): + self.assertEqual(~BV(0,0), BV(0,0)) + self.assertEqual(~BV(1,0b0), BV(1,0b1)) + self.assertEqual(~BV(8,0x0f), BV(8,0xf0)) + self.assertEqual(~BV(10,0b0001110101), BV(10,0b1110001010)) + self.assertEqual(~~BV(10,0b0001110101), BV(10,0b0001110101)) + self.assertUnOpExpected( + lambda b: ~~b, + lambda b: b) + self.assertUnOpExpected( + lambda b: ~b & b, + lambda b: BV(b.size(), 0)) + + + def test_positional_index(self): + self.assertFalse(BV(16,0b10)[0]) + self.assertTrue(BV(16,0b10)[1]) + self.assertFalse(BV(16,0b10)[3]) + self.assertFalse(BV(8,0b10)[7]) + with self.assertRaises(ValueError): + BV(8,7)["Bad Index"] + with self.assertRaises(ValueError): + BV(8,7)[-1] + with self.assertRaises(ValueError): + BV(8,7)[8] + + def test_positional_slice(self): + self.assertEqual(BV(0,0)[0:0], BV(0,0)) + self.assertEqual(BV(16,0b10)[2:0], BV(2,0b10)) + self.assertEqual(BV(16,0b10)[16:0], BV(16,0b10)) + self.assertEqual(BV(16,0b1100110011001100)[16:8], BV(8,0b11001100)) + with self.assertRaises(ValueError): + BV(0,0)[2:0] + with self.assertRaises(ValueError): + BV(8,42)[0:1] + with self.assertRaises(ValueError): + BV(8,42)[9:0] + with self.assertRaises(ValueError): + BV(8,42)[8:-1] + with self.assertRaises(ValueError): + BV(8,42)[10:10] + + def test_concat(self): + self.assertEqual(BV(0,0).concat(BV(0,0)), BV(0,0)) + self.assertEqual(BV(1,0b1).concat(BV(0,0b0)), BV(1,0b1)) + self.assertEqual(BV(0,0b0).concat(BV(1,0b1)), BV(1,0b1)) + self.assertEqual(BV(1,0b1).concat(BV(1,0b0)), BV(2,0b10)) + self.assertEqual(BV(1,0b0).concat(BV(1,0b1)), BV(2,0b01)) + self.assertEqual(BV(1,0b1).concat(BV(1,0b1)), BV(2,0b11)) + self.assertEqual(BV(5,0b11111).concat(BV(3,0b000)), BV(8,0b11111000)) + self.assertEqual(BV(0,0).concat(), BV(0,0)) + self.assertEqual(BV(0,0).concat(BV(2,0b10),BV(2,0b01)), BV(4,0b1001)) + self.assertBinOpExpected( + lambda b1, b2: b1.concat(b2)[b2.size():0], + lambda b1, b2: b2) + self.assertBinOpExpected( + lambda b1, b2: b1.concat(b2)[b1.size() + b2.size():b2.size()], + lambda b1, b2: b1) + with self.assertRaises(ValueError): + BV(8,42).concat(42) + with self.assertRaises(ValueError): + BV(8,42).concat("Oops not a BV") + + def test_join(self): + self.assertEqual(BV.join(), BV(0,0)) + self.assertEqual(BV.join(*[]), BV(0,0)) + self.assertEqual(BV.join(BV(8,42)), BV(8,42)) + self.assertEqual(BV.join(*[BV(8,42)]), BV(8,42)) + self.assertEqual(BV.join(BV(0,0), BV(2,0b10),BV(3,0b110)), BV(5,0b10110)) + self.assertEqual(BV.join(*[BV(0,0), BV(2,0b10),BV(3,0b110)]), BV(5,0b10110)) + + def test_bytes(self): + self.assertEqual(bytes(BV(0,0)), b'') + self.assertEqual(bytes(BV(1,1)), b'\x01') + self.assertEqual(bytes(BV(8,255)), b'\xff') + self.assertEqual(bytes(BV(16,255)), b'\x00\xff') + + + def test_zero(self): + self.assertEqual(BV(0,0).zero(), BV(0,0)) + self.assertEqual(BV(9,255).zero(), BV(9,0)) + + def test_msb(self): + self.assertEqual(BV(8,0).msb(), False) + self.assertEqual(BV(8,1).msb(), False) + self.assertEqual(BV(8,127).msb(), False) + self.assertEqual(BV(8,128).msb(), True) + self.assertEqual(BV(8,255).msb(), True) + with self.assertRaises(ValueError): + BV(0,0).msb() + + + def test_lsb(self): + self.assertEqual(BV(8,0).lsb(), False) + self.assertEqual(BV(8,1).lsb(), True) + self.assertEqual(BV(8,127).lsb(), True) + self.assertEqual(BV(8,128).lsb(), False) + self.assertEqual(BV(8,255).lsb(), True) + with self.assertRaises(ValueError): + BV(0,0).lsb() + + def test_from_signed_int(self): + self.assertEqual(BV.from_signed_int(8,127), BV(8,127)) + self.assertEqual(BV.from_signed_int(8,-128), BV(8,0x80)) + self.assertEqual(BV.from_signed_int(8,-1), BV(8,255)) + self.assertUnOpExpected( + lambda b: b if b.size() == 0 else BV.from_signed_int(b.size(), b.to_signed_int()), + lambda b: b) + with self.assertRaises(ValueError): + BV.from_signed_int(8,128) + with self.assertRaises(ValueError): + BV.from_signed_int(8,-129) + + def test_sub(self): + self.assertEqual(BV(0,0) - BV(0,0), BV(0,0)) + self.assertEqual(BV(0,0) - 0, BV(0,0)) + self.assertEqual(0 - BV(0,0), BV(0,0)) + self.assertEqual(BV(8,5) - 3, BV(8,2)) + self.assertEqual(5 - BV(8,3), BV(8,2)) + self.assertEqual(BV(8,3) - BV(8,3), BV(8,0)) + self.assertEqual(BV(8,3) - BV(8,4), BV(8,255)) + self.assertEqual(BV(8,3) - BV(8,255), BV(8,4)) + self.assertEqual(BV(8,255) - BV(8,3), BV(8,252)) + self.assertEqual(BV(8,3) - 255, BV(8,4)) + self.assertEqual(255 - BV(8,3), BV(8,252)) + self.assertUnOpExpected( + lambda b: b - b, + lambda b: b.zero()) + self.assertUnOpExpected( + lambda b: b - BV(b.size(), 2 ** b.size() - 1), + lambda b: b + 1) + with self.assertRaises(ValueError): + BV(9,3) - BV(8,3) + with self.assertRaises(ValueError): + 256 - BV(8,3) + with self.assertRaises(ValueError): + BV(8,3) - 256 + with self.assertRaises(ValueError): + (-1) - BV(8,3) + with self.assertRaises(ValueError): + BV(8,3) - (-1) + + + def test_mul(self): + self.assertEqual(BV(8,5) * BV(8,4), BV(8,20)) + self.assertEqual(5 * BV(8,4), BV(8,20)) + self.assertEqual(4 * BV(8,5), BV(8,20)) + self.assertEqual(100 * BV(8,5), BV(8,0xf4)) + self.assertEqual(BV(8,5) * 100, BV(8,0xf4)) + self.assertUnOpExpected( + lambda b: b * 3 if b.size() >= 3 else b.zero(), + lambda b: b + b + b if b.size() >= 3 else b.zero()) + with self.assertRaises(ValueError): + BV(9,3) * BV(8,3) + with self.assertRaises(ValueError): + 256 * BV(8,3) + with self.assertRaises(ValueError): + BV(8,3) * 256 + with self.assertRaises(ValueError): + (-1) * BV(8,3) + with self.assertRaises(ValueError): + BV(8,3) * (-1) + + def test_split(self): + self.assertEqual( + BV(8,0xff).split(1), + [BV(1,0x1), + BV(1,0x1), + BV(1,0x1), + BV(1,0x1), + BV(1,0x1), + BV(1,0x1), + BV(1,0x1), + BV(1,0x1)]) + self.assertEqual( + BV(9,0b100111000).split(3), + [BV(3,0b100), + BV(3,0b111), + BV(3,0x000)]) + self.assertEqual( + BV(64,0x0123456789abcdef).split(4), + [BV(4,0x0), + BV(4,0x1), + BV(4,0x2), + BV(4,0x3), + BV(4,0x4), + BV(4,0x5), + BV(4,0x6), + BV(4,0x7), + BV(4,0x8), + BV(4,0x9), + BV(4,0xa), + BV(4,0xb), + BV(4,0xc), + BV(4,0xd), + BV(4,0xe), + BV(4,0xf)]) + with self.assertRaises(ValueError): + BV(9,3).split("4") + with self.assertRaises(ValueError): + BV(9,3).split(4) + + + def test_from_bytes(self): + self.assertEqual(BV.from_bytes(b''), BV(0,0)) + self.assertEqual(BV.from_bytes(b'', size=64), BV(64,0)) + self.assertEqual(BV.from_bytes(b'\x00'), BV(8,0)) + self.assertEqual(BV.from_bytes(b'\x01'), BV(8,1)) + self.assertEqual(BV.from_bytes(b'\x01', size=16), BV(16,1)) + self.assertEqual(BV.from_bytes(b'\x00\x01'), BV(16,1)) + self.assertEqual(BV.from_bytes(b'\x01\x00', byteorder='little'), BV(16,1)) + self.assertEqual(BV.from_bytes(b'\x01\x00'), BV(16,0x0100)) + self.assertEqual(BV.from_bytes(b'\x01\x00', byteorder='little'), BV(16,0x0001)) + self.assertEqual(BV.from_bytes(b'\x01\x00', size=32,byteorder='little'), BV(32,0x0001)) + + def test_to_bytes(self): + self.assertEqual(BV(0,0).to_bytes() ,b'') + self.assertEqual(BV(8,0).to_bytes() ,b'\x00') + self.assertEqual(BV(8,1).to_bytes() ,b'\x01') + self.assertEqual(BV(16,1).to_bytes(), b'\x00\x01') + + + def test_bitewise_or(self): + self.assertEqual(BV(0,0) | BV(0,0), BV(0,0)) + self.assertEqual(BV(8,0xff) | BV(8,0x00), BV(8,0xff)) + self.assertEqual(BV(8,0x00) | BV(8,0xff), BV(8,0xff)) + self.assertEqual(BV(8,0x00) | 0xff, BV(8,0xff)) + self.assertEqual(0xff | BV(8,0x00), BV(8,0xff)) + self.assertEqual(BV(8,0x00) | BV(8,42), BV(8,42)) + self.assertUnOpExpected( + lambda b: b | 0, + lambda b: b) + with self.assertRaises(ValueError): + BV(15,7) | BV(16,9) + with self.assertRaises(ValueError): + BV(8,255) | 256 + with self.assertRaises(ValueError): + 256 | BV(8,9) + with self.assertRaises(ValueError): + BV(8,255) | -1 + with self.assertRaises(ValueError): + -1 | BV(8,9) + + + def test_bitewise_xor(self): + self.assertEqual(BV(0,0) ^ BV(0,0), BV(0,0)) + self.assertEqual(BV(8,0xff) ^ BV(8,0x00), BV(8,0xff)) + self.assertEqual(BV(8,0x00) ^ BV(8,0xff), BV(8,0xff)) + self.assertEqual(BV(8,0x0f) ^ BV(8,0xff), BV(8,0xf0)) + self.assertEqual(BV(8,0xf0) ^ BV(8,0xff), BV(8,0x0f)) + self.assertUnOpExpected( + lambda b: b ^ 0, + lambda b: b) + self.assertUnOpExpected( + lambda b: b ^ ~b, + lambda b: BV(b.size(), 2 ** b.size() - 1)) + with self.assertRaises(ValueError): + BV(15,7) ^ BV(16,9) + with self.assertRaises(ValueError): + BV(8,255) ^ 256 + with self.assertRaises(ValueError): + 256 ^ BV(8,9) + with self.assertRaises(ValueError): + BV(8,255) ^ -1 + with self.assertRaises(ValueError): + -1 ^ BV(8,9) + + def test_with_bit(self): + self.assertEqual(BV(1,0).with_bit(0,True), BV(1,1)) + self.assertEqual(BV(1,1).with_bit(0,False), BV(1,0)) + self.assertEqual(BV(8,0b11001100).with_bit(0,True), BV(8,0b11001101)) + self.assertEqual(BV(8,0b11001100).with_bit(3,False), BV(8,0b11000100)) + self.assertEqual(BV(8,0b11001100).with_bit(7,False), BV(8,0b01001100)) + with self.assertRaises(ValueError): + BV(8,0b11001100).with_bit(8,False) + with self.assertRaises(ValueError): + BV(8,0b11001100).with_bit(-1,False) + + def test_with_bits(self): + self.assertEqual(BV(1,0b0).with_bits(0,BV(1,0b0)), BV(1,0b0)) + self.assertEqual(BV(1,0b0).with_bits(0,BV(1,0b1)), BV(1,0b1)) + self.assertEqual(BV(1,0b1).with_bits(0,BV(1,0b0)), BV(1,0b0)) + self.assertEqual(BV(8,0b11010101).with_bits(3,BV(3,0b101)), BV(8,0b11101101)) + self.assertEqual(BV(8,0b11010101).with_bits(5,BV(3,0b101)), BV(8,0b10110101)) + with self.assertRaises(ValueError): + BV(8,0b11000101).with_bits(-1,BV(3,0b111)) + with self.assertRaises(ValueError): + BV(8,0b11000101).with_bits(0,"bad") + with self.assertRaises(ValueError): + BV(8,0b11000101).with_bits(0,BV(9,0)) + with self.assertRaises(ValueError): + BV(8,0b11000101).with_bits(1,BV(8,0)) diff --git a/python/hs-test/Main.hs b/python/hs-test/Main.hs index 8adf59c..e46a6b3 100644 --- a/python/hs-test/Main.hs +++ b/python/hs-test/Main.hs @@ -8,6 +8,7 @@ import Control.Monad import Test.Tasty import Test.Tasty.HUnit import Test.Tasty.HUnit.ScriptExit +import System.Exit import Paths_argo_python import Argo.PythonBindings @@ -30,5 +31,12 @@ main = tests <- makeScriptTests dir [mypy] pure (testGroup ("Typechecking: " <> name) tests) + -- Have python discover and run unit tests + (unitExitCode, unitStdOut, unitStdErr) <- + runPythonUnitTests $ testLangExecutable python + defaultMain $ - testGroup "Tests for Python components" allTests + testGroup "Tests for Python components" $ + allTests ++ + [testCase "Python Unit Tests" $ + unitExitCode @?= ExitSuccess] diff --git a/python/requirements.txt b/python/requirements.txt index 1e96325..c9870b9 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,5 +1,5 @@ BitVector==3.4.9 -mypy==0.770 +mypy==0.790 requests==2.24.0 diff --git a/tasty-script-exitcode/src/Test/Tasty/HUnit/ScriptExit.hs b/tasty-script-exitcode/src/Test/Tasty/HUnit/ScriptExit.hs index bd961b7..55133c8 100644 --- a/tasty-script-exitcode/src/Test/Tasty/HUnit/ScriptExit.hs +++ b/tasty-script-exitcode/src/Test/Tasty/HUnit/ScriptExit.hs @@ -186,3 +186,12 @@ scriptTest execPath makeArgs scriptPath = "Exit code " <> show code <> ": " <> execPath <> " " <> concat (intersperse " " args) <> ":\nstdout: " <> stdout <> "\nstderr: " <> stderr + + +-- | Given the name of the pyhon executable, +-- run `python -m unittest discover` via +-- readProcessWithExitCode and return the result. +runPythonUnitTests :: String -> IO (ExitCode, String, String) +runPythonUnitTests pyExeName = + let args = ["-m", "unittest", "discover"] + in readProcessWithExitCode pyExeName args ""