Skip to content

Commit

Permalink
4844: Refactor tests to use Spec and SpecHelper classes (#203)
Browse files Browse the repository at this point in the history
* clean-up: remove unused utils module

* style: fix docstring typos/formatting

* refactor: define spec parameters in their own class

* docs: show specs in test case ref section in online doc

* style: clean-up indentation

* docs: fix incorrect docstring

* refactor: don't make Spec, SpecHelpers enum classes

This avoids confusion and possible errors regarding when to use (or not use) the `.value` of the enum members; instead, just define them as normal class attributes.

* refactor: give shared fixture a more descriptive name
  • Loading branch information
danceratopz authored Jul 14, 2023
1 parent 6e31a12 commit d61b176
Show file tree
Hide file tree
Showing 13 changed files with 483 additions and 848 deletions.
18 changes: 13 additions & 5 deletions docs/gen_test_case_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
source_directory = Path("tests")
target_dir = Path("tests")
navigation_file = "navigation.md"
non_test_files_to_include = [ # __init__.py is treated separately
"spec.py",
]


def get_script_relative_path(): # noqa: D103
Expand Down Expand Up @@ -99,13 +102,13 @@ def get_script_relative_path(): # noqa: D103
"""
# $title
Documentation for test cases from [`$pytest_test_path`]($module_github_url).
Documentation for [`$pytest_test_path`]($module_github_url).
$generate_fixtures_deployed
$generate_fixtures_development
::: $package_name
options:
filters: ["^[tT]est*"]
filters: ["^[tT]est*|^Spec*"]
"""
)
)
Expand Down Expand Up @@ -297,19 +300,20 @@ def non_recursive_os_walk(top_dir):

for file in sorted(python_files):
output_file_path = Path("undefined")

if file == "__init__.py":
output_file_path = output_directory / "index.md"
nav_path = "Test Case Reference" / test_dir_relative_path
package_name = root.replace(os.sep, ".")
pytest_test_path = root
elif not file.startswith("test_"):
continue
else:
elif file.startswith("test_") or file in non_test_files_to_include:
file_no_ext = os.path.splitext(file)[0]
output_file_path = output_directory / file_no_ext / "index.md"
nav_path = "Test Case Reference" / test_dir_relative_path / file_no_ext
package_name = os.path.join(root, file_no_ext).replace(os.sep, ".")
pytest_test_path = os.path.join(root, file)
else:
continue

nav_tuple = tuple(snake_to_capitalize(part) for part in nav_path.parts)
nav_tuple = tuple(apply_name_filters(part) for part in nav_tuple)
Expand Down Expand Up @@ -340,13 +344,17 @@ def non_recursive_os_walk(top_dir):
)

if root == "tests":
# special case, the root tests/ directory
generate_fixtures_deployed = GENERATE_FIXTURES_DEPLOYED.substitute(
pytest_test_path=pytest_test_path,
additional_title=" for all forks deployed to mainnet",
)
generate_fixtures_development = GENERATE_FIXTURES_DEVELOPMENT.substitute(
pytest_test_path=pytest_test_path, fork=DEV_FORKS[0]
)
elif file in non_test_files_to_include:
generate_fixtures_deployed = ""
generate_fixtures_development = ""
elif dev_forks := [fork for fork in DEV_FORKS if fork.lower() in root.lower()]:
assert len(dev_forks) == 1
generate_fixtures_deployed = ""
Expand Down
169 changes: 35 additions & 134 deletions tests/cancun/eip4844_blobs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
Common constants, classes & functions local to EIP-4844 tests.
"""
from dataclasses import dataclass
from hashlib import sha256
from typing import List, Literal, Tuple, Union

from ethereum_test_tools import (
TestAddress,
Transaction,
YulCompiler,
add_kzg_version,
compute_create2_address,
Expand All @@ -16,113 +14,14 @@
)
from ethereum_test_tools.vm.opcode import Opcodes as Op

# Reference Spec
REF_SPEC_4844_GIT_PATH = "EIPS/eip-4844.md"
REF_SPEC_4844_VERSION = "f0eb6a364aaf5ccb43516fa2c269a54fb881ecfd"

# Constants
BLOB_COMMITMENT_VERSION_KZG = 1
BLOBHASH_GAS_COST = 3
BLS_MODULUS = 0x73EDA753299D7D483339D80809A1D80553BDA402FFFE5BFEFFFFFFFF00000001
BLS_MODULUS_BYTES = BLS_MODULUS.to_bytes(32, "big")
DATA_GAS_PER_BLOB = 2**17
DATA_GASPRICE_UPDATE_FRACTION = 3338477
BYTES_PER_FIELD_ELEMENT = 32
FIELD_ELEMENTS_PER_BLOB = 4096
FIELD_ELEMENTS_PER_BLOB_BYTES = FIELD_ELEMENTS_PER_BLOB.to_bytes(32, "big")
from .spec import Spec, SpecHelpers

INF_POINT = (0xC0 << 376).to_bytes(48, byteorder="big")
MAX_DATA_GAS_PER_BLOCK = 786432
MAX_BLOBS_PER_BLOCK = MAX_DATA_GAS_PER_BLOCK // DATA_GAS_PER_BLOB
MIN_DATA_GASPRICE = 1
POINT_EVALUATION_PRECOMPILE_ADDRESS = 10
POINT_EVALUATION_PRECOMPILE_GAS = 50_000
TARGET_DATA_GAS_PER_BLOCK = 393216
TARGET_BLOBS_PER_BLOCK = TARGET_DATA_GAS_PER_BLOCK // DATA_GAS_PER_BLOB
Z = 0x623CE31CF9759A5C8DAF3A357992F9F3DD7F9339D8998BC8E68373E54F00B75E
Z_Y_INVALID_ENDIANNESS: Literal["little", "big"] = "little"
Z_Y_VALID_ENDIANNESS: Literal["little", "big"] = "big"


# Functions
def fake_exponential(factor: int, numerator: int, denominator: int) -> int:
"""
Used to calculate the data gas cost.
"""
i = 1
output = 0
numerator_accumulator = factor * denominator
while numerator_accumulator > 0:
output += numerator_accumulator
numerator_accumulator = (numerator_accumulator * numerator) // (denominator * i)
i += 1
return output // denominator


def get_total_data_gas(tx: Transaction) -> int:
"""
Calculate the total data gas for a transaction.
"""
if tx.blob_versioned_hashes is None:
return 0
return DATA_GAS_PER_BLOB * len(tx.blob_versioned_hashes)


def get_data_gasprice(*, excess_data_gas: int) -> int:
"""
Calculate the data gas price from the excess.
"""
return fake_exponential(
MIN_DATA_GASPRICE,
excess_data_gas,
DATA_GASPRICE_UPDATE_FRACTION,
)


def get_min_excess_data_gas_for_data_gas_price(data_gas_price: int) -> int:
"""
Gets the minimum required excess data gas value to get a given data gas cost in a block
"""
current_excess_data_gas = 0
current_data_gas_price = 1
while current_data_gas_price < data_gas_price:
current_excess_data_gas += DATA_GAS_PER_BLOB
current_data_gas_price = get_data_gasprice(excess_data_gas=current_excess_data_gas)
return current_excess_data_gas


def get_min_excess_data_blobs_for_data_gas_price(data_gas_price: int) -> int:
"""
Gets the minimum required excess data blobs to get a given data gas cost in a block
"""
return get_min_excess_data_gas_for_data_gas_price(data_gas_price) // DATA_GAS_PER_BLOB


def calc_excess_data_gas(*, parent_excess_data_gas: int, parent_blobs: int) -> int:
"""
Calculate the excess data gas for a block given the parent excess data gas
and the number of blobs in the block.
"""
parent_consumed_data_gas = parent_blobs * DATA_GAS_PER_BLOB
if parent_excess_data_gas + parent_consumed_data_gas < TARGET_DATA_GAS_PER_BLOCK:
return 0
else:
return parent_excess_data_gas + parent_consumed_data_gas - TARGET_DATA_GAS_PER_BLOCK


def kzg_to_versioned_hash(
kzg_commitment: bytes | int, # 48 bytes
blob_commitment_version_kzg: bytes | int = BLOB_COMMITMENT_VERSION_KZG,
) -> bytes:
"""
Calculates the versioned hash for a given KZG commitment.
"""
if isinstance(kzg_commitment, int):
kzg_commitment = kzg_commitment.to_bytes(48, "big")
if isinstance(blob_commitment_version_kzg, int):
blob_commitment_version_kzg = blob_commitment_version_kzg.to_bytes(1, "big")
return blob_commitment_version_kzg + sha256(kzg_commitment).digest()[1:]


@dataclass(kw_only=True)
class Blob:
"""
Expand All @@ -137,7 +36,7 @@ def versioned_hash(self) -> bytes:
"""
Calculates the versioned hash for a given blob.
"""
return kzg_to_versioned_hash(self.kzg_commitment)
return Spec.kzg_to_versioned_hash(self.kzg_commitment)

@staticmethod
def blobs_to_transaction_input(
Expand All @@ -160,8 +59,8 @@ def blobs_to_transaction_input(

# Simple list of blob versioned hashes ranging from bytes32(1 to 4)
simple_blob_hashes: list[bytes] = add_kzg_version(
[(1 << x) for x in range(MAX_BLOBS_PER_BLOCK)],
BLOB_COMMITMENT_VERSION_KZG,
[(1 << x) for x in range(SpecHelpers.max_blobs_per_block())],
Spec.BLOB_COMMITMENT_VERSION_KZG,
)

# Random fixed list of blob versioned hashes
Expand All @@ -178,7 +77,7 @@ def blobs_to_transaction_input(
"0x00d78c25f8a1d6aa04d0e2e2a71cf8dfaa4239fa0f301eb57c249d1e6bfe3c3d",
"0x00c778eb1348a73b9c30c7b1d282a5f8b2c5b5a12d5c5e4a4a35f9c5f639b4a4",
],
BLOB_COMMITMENT_VERSION_KZG,
Spec.BLOB_COMMITMENT_VERSION_KZG,
)

# Blobhash index values for test_blobhash_gas_cost
Expand Down Expand Up @@ -244,29 +143,29 @@ def code(cls, context_name):
"blobhash_sstore": cls.yul_compiler(
f"""
{{
let pos := calldataload(0)
let end := calldataload(32)
for {{}} lt(pos, end) {{ pos := add(pos, 1) }}
{{
let pos := calldataload(0)
let end := calldataload(32)
for {{}} lt(pos, end) {{ pos := add(pos, 1) }}
{{
let blobhash := {blobhash_verbatim}
(hex"{Op.BLOBHASH.hex()}", pos)
sstore(pos, blobhash)
}}
let blobhash := {blobhash_verbatim}
}}
let blobhash := {blobhash_verbatim}
(hex"{Op.BLOBHASH.hex()}", end)
sstore(end, blobhash)
return(0, 0)
sstore(end, blobhash)
return(0, 0)
}}
"""
),
"blobhash_return": cls.yul_compiler(
f"""
{{
let pos := calldataload(0)
let blobhash := {blobhash_verbatim}
let pos := calldataload(0)
let blobhash := {blobhash_verbatim}
(hex"{Op.BLOBHASH.hex()}", pos)
mstore(0, blobhash)
return(0, 32)
mstore(0, blobhash)
return(0, 32)
}}
"""
),
Expand Down Expand Up @@ -351,13 +250,13 @@ def code(cls, context_name):
"initcode": cls.yul_compiler(
f"""
{{
for {{ let pos := 0 }} lt(pos, 10) {{ pos := add(pos, 1) }}
{{
for {{ let pos := 0 }} lt(pos, 10) {{ pos := add(pos, 1) }}
{{
let blobhash := {blobhash_verbatim}
(hex"{Op.BLOBHASH.hex()}", pos)
sstore(pos, blobhash)
}}
return(0, 0)
}}
return(0, 0)
}}
"""
),
Expand Down Expand Up @@ -405,16 +304,16 @@ def create_blob_hashes_list(length: int) -> list[list[bytes]]:
length: MAX_BLOBS_PER_BLOCK * length
-> [0x01, 0x02, 0x03, 0x04, ..., 0x0A, 0x0B, 0x0C, 0x0D]
Then split list into smaller chunks of MAX_BLOBS_PER_BLOCK
Then split list into smaller chunks of SpecHelpers.max_blobs_per_block()
-> [[0x01, 0x02, 0x03, 0x04], ..., [0x0a, 0x0b, 0x0c, 0x0d]]
"""
b_hashes = [
random_blob_hashes[i % len(random_blob_hashes)]
for i in range(MAX_BLOBS_PER_BLOCK * length)
for i in range(SpecHelpers.max_blobs_per_block() * length)
]
return [
b_hashes[i : i + MAX_BLOBS_PER_BLOCK]
for i in range(0, len(b_hashes), MAX_BLOBS_PER_BLOCK)
b_hashes[i : i + SpecHelpers.max_blobs_per_block()]
for i in range(0, len(b_hashes), SpecHelpers.max_blobs_per_block())
]

@staticmethod
Expand All @@ -427,7 +326,7 @@ def blobhash_sstore(index: int):
the BLOBHASH sstore.
"""
invalidity_check = Op.SSTORE(index, 0x01)
if index < 0 or index >= MAX_BLOBS_PER_BLOCK:
if index < 0 or index >= SpecHelpers.max_blobs_per_block():
return invalidity_check + Op.SSTORE(index, Op.BLOBHASH(index))
return Op.SSTORE(index, Op.BLOBHASH(index))

Expand All @@ -437,23 +336,25 @@ def generate_blobhash_bytecode(cls, scenario_name: str) -> bytes:
Returns BLOBHASH bytecode for the given scenario.
"""
scenarios = {
"single_valid": b"".join(cls.blobhash_sstore(i) for i in range(MAX_BLOBS_PER_BLOCK)),
"single_valid": b"".join(
cls.blobhash_sstore(i) for i in range(SpecHelpers.max_blobs_per_block())
),
"repeated_valid": b"".join(
b"".join(cls.blobhash_sstore(i) for _ in range(10))
for i in range(MAX_BLOBS_PER_BLOCK)
for i in range(SpecHelpers.max_blobs_per_block())
),
"valid_invalid": b"".join(
cls.blobhash_sstore(i)
+ cls.blobhash_sstore(MAX_BLOBS_PER_BLOCK)
+ cls.blobhash_sstore(SpecHelpers.max_blobs_per_block())
+ cls.blobhash_sstore(i)
for i in range(MAX_BLOBS_PER_BLOCK)
for i in range(SpecHelpers.max_blobs_per_block())
),
"varied_valid": b"".join(
cls.blobhash_sstore(i) + cls.blobhash_sstore(i + 1) + cls.blobhash_sstore(i)
for i in range(MAX_BLOBS_PER_BLOCK - 1)
for i in range(SpecHelpers.max_blobs_per_block() - 1)
),
"invalid_calls": b"".join(
cls.blobhash_sstore(i) for i in range(-5, MAX_BLOBS_PER_BLOCK + 5)
cls.blobhash_sstore(i) for i in range(-5, SpecHelpers.max_blobs_per_block() + 5)
),
}
scenario = scenarios.get(scenario_name)
Expand Down
Loading

0 comments on commit d61b176

Please sign in to comment.