Skip to content

Commit

Permalink
fix end_common
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez committed Jun 12, 2024
1 parent 1e912c8 commit b11f704
Showing 1 changed file with 15 additions and 55 deletions.
70 changes: 15 additions & 55 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,6 @@ where
DB: DatabaseExt,
{
ecx: &'a mut EvmContext<DB>,
prank_origin: Option<Address>,
single_call_prank: bool,
original_origin: Option<Address>,
single_call_broadcast: bool,
depth: u64,
expected_revert: Option<ExpectedRevert>,
outcome_result: InstructionResult,
outcome_output: Bytes,
address: Option<Address>,
Expand Down Expand Up @@ -497,31 +491,31 @@ impl Cheatcodes {
F: FnMut(InstructionResult, Bytes, Option<Address>) -> CommonEndOutcome,
{
// Clean up pranks
if let Some(prank_origin) = params.prank_origin {
if params.ecx.inner.journaled_state.depth() == params.depth {
params.ecx.inner.env.tx.caller = prank_origin;
if let Some(prank) = &self.prank {
if params.ecx.inner.journaled_state.depth() == prank.depth {
params.ecx.inner.env.tx.caller = prank.prank_origin;

// Clean single-call prank once we have returned to the original depth
if params.single_call_prank {
if prank.single_call {
std::mem::take(&mut self.prank);
}
}
}

// Clean up broadcasts
if let Some(original_origin) = params.original_origin {
if params.ecx.inner.journaled_state.depth() == params.depth {
params.ecx.inner.env.tx.caller = original_origin;
if let Some(broadcast) = &self.broadcast {
if params.ecx.inner.journaled_state.depth() == broadcast.depth {
params.ecx.inner.env.tx.caller = broadcast.original_origin;

// Clean single-call broadcast once we have returned to the original depth
if params.single_call_broadcast {
if broadcast.single_call {
std::mem::take(&mut self.broadcast);
}
}
}

// Handle expected reverts
if let Some(expected_revert) = &params.expected_revert {
if let Some(expected_revert) = &self.expected_revert {
if params.ecx.inner.journaled_state.depth() <= expected_revert.depth &&
matches!(expected_revert.kind, ExpectedRevertKind::Default)
{
Expand All @@ -533,24 +527,13 @@ impl Cheatcodes {
params.outcome_output.clone(),
) {
Ok((new_address, retdata)) => {
params.outcome_result = InstructionResult::Return;
params.outcome_output = retdata;
params.address = new_address;
create_outcome_fn(
params.outcome_result,
params.outcome_output.clone(),
params.address,
)
}
Err(err) => {
params.outcome_result = InstructionResult::Revert;
params.outcome_output = err.abi_encode().into();
create_outcome_fn(
params.outcome_result,
params.outcome_output.clone(),
params.address,
)
create_outcome_fn(InstructionResult::Return, retdata.clone(), new_address)
}
Err(err) => create_outcome_fn(
InstructionResult::Revert,
Bytes::from(err.abi_encode()),
params.address,
),
};
}
}
Expand Down Expand Up @@ -606,7 +589,6 @@ impl Cheatcodes {
}
}
}

create_outcome_fn(params.outcome_result, params.outcome_output.clone(), params.address)
}
}
Expand Down Expand Up @@ -1659,19 +1641,8 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
_call: &CreateInputs,
outcome: CreateOutcome,
) -> CreateOutcome {
let depth = ecx.inner.journaled_state.depth();

let mut params = EndParams {
ecx,
prank_origin: self.prank.as_ref().map(|prank| prank.prank_origin),
single_call_prank: self.prank.as_ref().map_or(false, |prank| prank.single_call),
original_origin: self.broadcast.as_ref().map(|broadcast| broadcast.original_origin),
single_call_broadcast: self
.broadcast
.as_ref()
.map_or(false, |broadcast| broadcast.single_call),
depth,
expected_revert: self.expected_revert.clone(),
outcome_result: outcome.result.result,
outcome_output: outcome.result.output.clone(),
address: outcome.address,
Expand Down Expand Up @@ -1754,19 +1725,8 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
call: &EOFCreateInput,
outcome: EOFCreateOutcome,
) -> EOFCreateOutcome {
let depth = ecx.inner.journaled_state.depth();

let mut params = EndParams {
ecx,
prank_origin: self.prank.as_ref().map(|prank| prank.prank_origin),
single_call_prank: self.prank.as_ref().map_or(false, |prank| prank.single_call),
original_origin: self.broadcast.as_ref().map(|broadcast| broadcast.original_origin),
single_call_broadcast: self
.broadcast
.as_ref()
.map_or(false, |broadcast| broadcast.single_call),
depth,
expected_revert: self.expected_revert.clone(),
outcome_result: outcome.result.result,
outcome_output: outcome.result.output.clone(),
address: Some(outcome.address),
Expand Down

0 comments on commit b11f704

Please sign in to comment.