Skip to content

Commit

Permalink
feat(hint): post state key hint
Browse files Browse the repository at this point in the history
  • Loading branch information
winsvega committed Oct 16, 2024
1 parent 10860cc commit 0d03773
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 15 deletions.
49 changes: 39 additions & 10 deletions src/ethereum_test_base_types/composite_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Base composite types for Ethereum test cases.
"""
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, SupportsBytes, Type, TypeAlias
from typing import Any, ClassVar, Dict, Optional, SupportsBytes, Type, TypeAlias

from pydantic import Field, PrivateAttr, RootModel, TypeAdapter

Expand Down Expand Up @@ -92,22 +92,37 @@ class KeyValueMismatch(Exception):
key: int
want: int
got: int

def __init__(self, address: Address, key: int, want: int, got: int, *args):
hint: Optional[Dict[int, str]] = None

def __init__(
self,
address: Address,
key: int,
want: int,
got: int,
hint: Optional[Dict[int, str]] = None,
*args,
):
super().__init__(args)
self.address = address
self.key = key
self.want = want
self.got = got
self.hint = hint

def __str__(self):
"""Print exception string"""
label_str = ""
if self.address.label is not None:
label_str = f" ({self.address.label})"

key = Hash(self.key)
if self.hint is not None:
key = self.hint[self.key]

return (
f"incorrect value in address {self.address}{label_str} for "
+ f"key {Hash(self.key)}:"
+ f"key {key}:"
+ f" want {HexNumber(self.want)} (dec:{int(self.want)}),"
+ f" got {HexNumber(self.got)} (dec:{int(self.got)})"
)
Expand Down Expand Up @@ -233,7 +248,12 @@ def must_contain(self, address: Address, other: "Storage"):
address=address, key=key, want=self[key], got=other[key]
)

def must_be_equal(self, address: Address, other: "Storage | None"):
def must_be_equal(
self,
address: Address,
other: "Storage | None",
post_hint: Optional[Dict[int, str]] = None,
):
"""
Succeeds only if "self" is equal to "other" storage.
"""
Expand All @@ -243,17 +263,21 @@ def must_be_equal(self, address: Address, other: "Storage | None"):
for key in self.keys() & other.keys():
if self[key] != other[key]:
raise Storage.KeyValueMismatch(
address=address, key=key, want=self[key], got=other[key]
address=address, key=key, want=self[key], got=other[key], hint=post_hint
)

# Test keys contained in either one of the storage objects
for key in self.keys() ^ other.keys():
if key in self:
if self[key] != 0:
raise Storage.KeyValueMismatch(address=address, key=key, want=self[key], got=0)
raise Storage.KeyValueMismatch(
address=address, key=key, want=self[key], got=0, hint=post_hint
)

elif other[key] != 0:
raise Storage.KeyValueMismatch(address=address, key=key, want=0, got=other[key])
raise Storage.KeyValueMismatch(
address=address, key=key, want=0, got=other[key], hint=post_hint
)

def canary(self) -> "Storage":
"""
Expand Down Expand Up @@ -374,7 +398,12 @@ def __str__(self):
+ f"want {self.want}, got {self.got}"
)

def check_alloc(self: "Account", address: Address, account: "Account"):
def check_alloc(
self: "Account",
address: Address,
account: "Account",
post_hint: Optional[Dict[int, str]] = None,
):
"""
Checks the returned alloc against an expected account in post state.
Raises exception on failure.
Expand Down Expand Up @@ -404,7 +433,7 @@ def check_alloc(self: "Account", address: Address, account: "Account"):
)

if "storage" in self.model_fields_set:
self.storage.must_be_equal(address=address, other=account.storage)
self.storage.must_be_equal(address=address, other=account.storage, post_hint=post_hint)

def __bool__(self: "Account") -> bool:
"""
Expand Down
6 changes: 4 additions & 2 deletions src/ethereum_test_specs/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class StateTest(BaseTest):
pre: Alloc
post: Alloc
tx: Transaction
post_hint: Optional[Dict[int, str]] = None
engine_api_error_code: Optional[EngineAPIError] = None
blockchain_test_header_verify: Optional[Header] = None
blockchain_test_rlp_modifier: Optional[Header] = None
Expand Down Expand Up @@ -117,6 +118,7 @@ def make_state_test_fixture(
t8n: TransitionTool,
fork: Fork,
eips: Optional[List[int]] = None,
post_hint: Optional[Dict[int, str]] = None,
) -> Fixture:
"""
Create a fixture from the state test definition.
Expand Down Expand Up @@ -146,7 +148,7 @@ def make_state_test_fixture(
)

try:
self.post.verify_post_alloc(transition_tool_output.alloc)
self.post.verify_post_alloc(transition_tool_output.alloc, post_hint)
except Exception as e:
print_traces(t8n.get_traces())
raise e
Expand Down Expand Up @@ -183,7 +185,7 @@ def generate(
request=request, t8n=t8n, fork=fork, fixture_format=fixture_format, eips=eips
)
elif fixture_format == StateFixture:
return self.make_state_test_fixture(t8n, fork, eips)
return self.make_state_test_fixture(t8n, fork, eips, self.post_hint)

raise Exception(f"Unknown fixture format: {fixture_format}")

Expand Down
6 changes: 3 additions & 3 deletions src/ethereum_test_types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass
from functools import cached_property
from typing import Any, ClassVar, Dict, Generic, List, Literal, Sequence, Tuple
from typing import Any, ClassVar, Dict, Generic, List, Literal, Optional, Sequence, Tuple

from coincurve.keys import PrivateKey, PublicKey
from ethereum import rlp as eth_rlp
Expand Down Expand Up @@ -268,7 +268,7 @@ def state_root(self) -> bytes:
)
return state_root(state)

def verify_post_alloc(self, got_alloc: "Alloc"):
def verify_post_alloc(self, got_alloc: "Alloc", post_hint: Optional[Dict[int, str]] = None):
"""
Verify that the allocation matches the expected post in the test.
Raises exception on unexpected values.
Expand All @@ -284,7 +284,7 @@ def verify_post_alloc(self, got_alloc: "Alloc"):
got_account = got_alloc.root[address]
assert isinstance(got_account, Account)
assert isinstance(account, Account)
account.check_alloc(address, got_account)
account.check_alloc(address, got_account, post_hint)
else:
raise Alloc.MissingAccount(address)

Expand Down

0 comments on commit 0d03773

Please sign in to comment.