Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(contracts-communication): deduplicate versioning libraries #2402

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions packages/contracts-communication/contracts/InterchainClientV1.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
} from "./libs/InterchainTransaction.sol";
import {OptionsLib, OptionsV1} from "./libs/Options.sol";
import {TypeCasts} from "./libs/TypeCasts.sol";
import {VersionedPayloadLib} from "./libs/VersionedPayload.sol";

import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol";

Expand All @@ -26,6 +27,7 @@ import {Ownable} from "@openzeppelin/contracts/access/Ownable.sol";
contract InterchainClientV1 is Ownable, InterchainClientV1Events, IInterchainClientV1 {
using AppConfigLib for bytes;
using OptionsLib for bytes;
using VersionedPayloadLib for bytes;

/// @notice Version of the InterchainClient contract. Sent and received transactions must have the same version.
uint16 public constant CLIENT_VERSION = 1;
Expand Down Expand Up @@ -209,8 +211,11 @@ contract InterchainClientV1 is Ownable, InterchainClientV1Events, IInterchainCli
}

/// @notice Encodes the transaction data into a bytes format.
function encodeTransaction(InterchainTransaction memory icTx) external pure returns (bytes memory) {
return InterchainTransactionLib.encodeVersionedTransaction(CLIENT_VERSION, icTx);
function encodeTransaction(InterchainTransaction memory icTx) public pure returns (bytes memory) {
return VersionedPayloadLib.encodeVersionedPayload({
version: CLIENT_VERSION,
payload: InterchainTransactionLib.encodeTransaction(icTx)
});
}

// ═════════════════════════════════════════════════ INTERNAL ══════════════════════════════════════════════════════
Expand Down Expand Up @@ -245,7 +250,7 @@ contract InterchainClientV1 is Ownable, InterchainClientV1Events, IInterchainCli
options: options,
message: message
});
desc.transactionId = keccak256(InterchainTransactionLib.encodeVersionedTransaction(CLIENT_VERSION, icTx));
desc.transactionId = keccak256(encodeTransaction(icTx));
// Sanity check: nonce returned from DB should match the nonce used to construct the transaction
{
(uint256 dbNonce, uint64 entryIndex) = IInterchainDB(INTERCHAIN_DB).writeEntryWithVerification{
Expand Down Expand Up @@ -362,15 +367,15 @@ contract InterchainClientV1 is Ownable, InterchainClientV1Events, IInterchainCli
}

/// @dev Asserts that the transaction version is correct. Returns the decoded transaction for chaining purposes.
function _assertCorrectVersion(bytes calldata encodedTx)
function _assertCorrectVersion(bytes calldata versionedTx)
internal
pure
returns (InterchainTransaction memory icTx)
{
uint16 version;
(version, icTx) = InterchainTransactionLib.decodeVersionedTransaction(encodedTx);
uint16 version = versionedTx.getVersion();
if (version != CLIENT_VERSION) {
revert InterchainClientV1__InvalidTransactionVersion(version);
}
icTx = InterchainTransactionLib.decodeTransaction(versionedTx.getPayload());
}
}
11 changes: 9 additions & 2 deletions packages/contracts-communication/contracts/InterchainDB.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ import {IInterchainModule} from "./interfaces/IInterchainModule.sol";

import {InterchainBatch, InterchainBatchLib} from "./libs/InterchainBatch.sol";
import {InterchainEntry, InterchainEntryLib} from "./libs/InterchainEntry.sol";
import {VersionedPayloadLib} from "./libs/VersionedPayload.sol";

contract InterchainDB is InterchainDBEvents, IInterchainDB {
using VersionedPayloadLib for bytes;

uint16 public constant DB_VERSION = 1;

bytes32[] internal _entryValues;
Expand Down Expand Up @@ -65,10 +68,11 @@ contract InterchainDB is InterchainDBEvents, IInterchainDB {

/// @inheritdoc IInterchainDB
function verifyRemoteBatch(bytes calldata versionedBatch) external {
(uint16 dbVersion, InterchainBatch memory batch) = InterchainBatchLib.decodeVersionedBatch(versionedBatch);
uint16 dbVersion = versionedBatch.getVersion();
if (dbVersion != DB_VERSION) {
revert InterchainDB__InvalidBatchVersion(dbVersion);
}
InterchainBatch memory batch = InterchainBatchLib.decodeBatch(versionedBatch.getPayload());
if (batch.srcChainId == block.chainid) {
revert InterchainDB__SameChainId(batch.srcChainId);
}
Expand Down Expand Up @@ -217,7 +221,10 @@ contract InterchainDB is InterchainDBEvents, IInterchainDB {
fees[0] += msg.value - totalFee;
}
uint256 len = srcModules.length;
bytes memory versionedBatch = InterchainBatchLib.encodeVersionedBatch(DB_VERSION, batch);
bytes memory versionedBatch = VersionedPayloadLib.encodeVersionedPayload({
version: DB_VERSION,
payload: InterchainBatchLib.encodeBatch(batch)
});
for (uint256 i = 0; i < len; ++i) {
IInterchainModule(srcModules[i]).requestBatchVerification{value: fees[i]}(dstChainId, versionedBatch);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,46 +36,19 @@ library InterchainBatchLib {
return InterchainBatch({srcChainId: block.chainid, dbNonce: dbNonce, batchRoot: batchRoot});
}

/// @notice Decodes the versioned batch payload from memory into version and InterchainBatch struct.
/// @dev See `VersionedPayloadLib` for more details about calldata/memory locations.
/// @param versionedBatch The versioned batch payload
/// @return dbVersion The version of the InterchainDB contract that created the batch
/// @return batch The InterchainBatch struct
function decodeVersionedBatchFromMemory(bytes memory versionedBatch)
internal
view
returns (uint16 dbVersion, InterchainBatch memory batch)
{
dbVersion = versionedBatch.getVersionFromMemory();
batch = abi.decode(versionedBatch.getPayloadFromMemory(), (InterchainBatch));
/// @notice Encodes the InterchainBatch struct into a non-versioned batch payload.
function encodeBatch(InterchainBatch memory batch) internal pure returns (bytes memory) {
return abi.encode(batch);
}

/// @notice Decodes the versioned batch payload into version and InterchainBatch struct.
/// @param versionedBatch The versioned batch payload
/// @return dbVersion The version of the InterchainDB contract that created the batch
/// @return batch The InterchainBatch struct
function decodeVersionedBatch(bytes calldata versionedBatch)
internal
pure
returns (uint16 dbVersion, InterchainBatch memory batch)
{
dbVersion = versionedBatch.getVersion();
batch = abi.decode(versionedBatch.getPayload(), (InterchainBatch));
/// @notice Decodes the InterchainBatch struct from a non-versioned batch payload in calldata.
function decodeBatch(bytes calldata data) internal pure returns (InterchainBatch memory) {
return abi.decode(data, (InterchainBatch));
}

/// @notice Encodes the InterchainBatch struct into a versioned batch payload.
/// @param dbVersion The version of the InterchainDB contract that created the batch
/// @param batch The InterchainBatch struct
/// @return versionedBatch The versioned batch payload
function encodeVersionedBatch(
uint16 dbVersion,
InterchainBatch memory batch
)
internal
pure
returns (bytes memory versionedBatch)
{
versionedBatch = VersionedPayloadLib.encodeVersionedPayload(dbVersion, abi.encode(batch));
/// @notice Decodes the InterchainBatch struct from a non-versioned batch payload in memory.
function decodeBatchFromMemory(bytes memory data) internal pure returns (InterchainBatch memory) {
return abi.decode(data, (InterchainBatch));
}

/// @notice Returns the globally unique identifier of the batch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,12 @@ library InterchainTransactionLib {
});
}

function encodeVersionedTransaction(
uint16 clientVersion,
InterchainTransaction memory transaction
)
internal
pure
returns (bytes memory)
{
return VersionedPayloadLib.encodeVersionedPayload(clientVersion, abi.encode(transaction));
function encodeTransaction(InterchainTransaction memory transaction) internal pure returns (bytes memory) {
return abi.encode(transaction);
}

function decodeVersionedTransaction(bytes calldata versionedTx)
internal
pure
returns (uint16 clientVersion, InterchainTransaction memory transaction)
{
clientVersion = versionedTx.getVersion();
transaction = abi.decode(versionedTx.getPayload(), (InterchainTransaction));
function decodeTransaction(bytes calldata transaction) internal pure returns (InterchainTransaction memory) {
return abi.decode(transaction, (InterchainTransaction));
}

function payloadSize(uint256 optionsLen, uint256 messageLen) internal pure returns (uint256) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@ import {IInterchainModule} from "../interfaces/IInterchainModule.sol";

import {InterchainBatch, InterchainBatchLib} from "../libs/InterchainBatch.sol";
import {ModuleBatchLib} from "../libs/ModuleBatch.sol";
import {VersionedPayloadLib} from "../libs/VersionedPayload.sol";

import {MessageHashUtils} from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";

/// @notice Common logic for all Interchain Modules.
abstract contract InterchainModule is InterchainModuleEvents, IInterchainModule {
using VersionedPayloadLib for bytes;

address public immutable INTERCHAIN_DB;

constructor(address interchainDB) {
Expand All @@ -23,7 +26,7 @@ abstract contract InterchainModule is InterchainModuleEvents, IInterchainModule
if (msg.sender != INTERCHAIN_DB) {
revert InterchainModule__NotInterchainDB(msg.sender);
}
(, InterchainBatch memory batch) = InterchainBatchLib.decodeVersionedBatch(versionedBatch);
InterchainBatch memory batch = InterchainBatchLib.decodeBatch(versionedBatch.getPayload());
if (dstChainId == block.chainid) {
revert InterchainModule__SameChainId(block.chainid);
}
Expand Down Expand Up @@ -51,7 +54,7 @@ abstract contract InterchainModule is InterchainModuleEvents, IInterchainModule
function _verifyBatch(bytes memory encodedModuleBatch) internal {
(bytes memory versionedBatch, bytes memory moduleData) =
ModuleBatchLib.decodeVersionedModuleBatch(encodedModuleBatch);
(, InterchainBatch memory batch) = InterchainBatchLib.decodeVersionedBatchFromMemory(versionedBatch);
InterchainBatch memory batch = InterchainBatchLib.decodeBatchFromMemory(versionedBatch.getPayloadFromMemory());
if (batch.srcChainId == block.chainid) {
revert InterchainModule__SameChainId(block.chainid);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import {
} from "../contracts/libs/InterchainTransaction.sol";
import {OptionsLib} from "../contracts/libs/Options.sol";

import {InterchainTransactionLibHarness} from "./harnesses/InterchainTransactionLibHarness.sol";
import {VersionedPayloadLibHarness} from "./harnesses/VersionedPayloadLibHarness.sol";
import {ExecutionFeesMock} from "./mocks/ExecutionFeesMock.sol";
import {ExecutionServiceMock} from "./mocks/ExecutionServiceMock.sol";
import {InterchainDBMock} from "./mocks/InterchainDBMock.sol";
Expand All @@ -26,6 +28,9 @@ abstract contract InterchainClientV1BaseTest is Test, InterchainClientV1Events {
bytes32 public constant MOCK_REMOTE_CLIENT = keccak256("RemoteClient");
uint16 public constant CLIENT_VERSION = 1;

InterchainTransactionLibHarness public txLibHarness;
VersionedPayloadLibHarness public payloadLibHarness;

address public mockRemoteClientEVM = makeAddr("RemoteClientEVM");
bytes32 public mockRemoteClientEVMBytes32 = bytes32(uint256(uint160(mockRemoteClientEVM)));

Expand All @@ -47,6 +52,8 @@ abstract contract InterchainClientV1BaseTest is Test, InterchainClientV1Events {
execService = address(new ExecutionServiceMock());
icModuleA = address(new InterchainModuleMock());
icModuleB = address(new InterchainModuleMock());
txLibHarness = new InterchainTransactionLibHarness();
payloadLibHarness = new VersionedPayloadLibHarness();
}

function setExecutionFees(address executionFees) public {
Expand Down Expand Up @@ -229,8 +236,8 @@ abstract contract InterchainClientV1BaseTest is Test, InterchainClientV1Events {
assertEq(icTx.message, expected.message, "!message");
}

function getEncodedTx(InterchainTransaction memory icTx) internal pure returns (bytes memory) {
return InterchainTransactionLib.encodeVersionedTransaction(CLIENT_VERSION, icTx);
function getEncodedTx(InterchainTransaction memory icTx) internal view returns (bytes memory) {
return payloadLibHarness.encodeVersionedPayload(CLIENT_VERSION, txLibHarness.encodeTransaction(icTx));
}

// ═══════════════════════════════════════════════════ UTILS ═══════════════════════════════════════════════════════
Expand Down
Loading
Loading