From 64297d6e78b7b86e8389514a80d13c1a046a9c0d Mon Sep 17 00:00:00 2001 From: Ming Date: Fri, 14 Jul 2023 20:26:31 +0700 Subject: [PATCH] fine tune pi circuit endian (#438) --- src/zkevm_specs/pi_circuit.py | 144 +++++++++++++++++----------------- 1 file changed, 72 insertions(+), 72 deletions(-) diff --git a/src/zkevm_specs/pi_circuit.py b/src/zkevm_specs/pi_circuit.py index 7f2bab1ad..c88d4ff37 100644 --- a/src/zkevm_specs/pi_circuit.py +++ b/src/zkevm_specs/pi_circuit.py @@ -336,64 +336,64 @@ def verify_circuit( for i in range(BLOCK_LEN // 2 + 1): block_row = block_table.table[i] - lo = copy_constrains.pop(0) + lo_le = copy_constrains.pop(0)[::-1] if block_row.is_word: - hi = copy_constrains.pop(0) + hi_le = copy_constrains.pop(0)[::-1] else: - hi = bytes(0) + hi_le = bytes(0)[::-1] (lo_expr, hi_expr) = block_row.to_lo_hi() - assert lo_expr == bytes_to_fq(lo) - assert hi_expr == bytes_to_fq(hi) + assert lo_expr == bytes_to_fq(lo_le) + assert hi_expr == bytes_to_fq(hi_le) # constrain block_hash and state_root lo/hi. # TODO layout block_hash in proper table - lo = copy_constrains.pop(0) - hi = copy_constrains.pop(0) - assert public_inputs.block_hash.lo.expr() == bytes_to_fq(lo) - assert public_inputs.block_hash.hi.expr() == bytes_to_fq(hi) + lo_le = copy_constrains.pop(0)[::-1] + hi_le = copy_constrains.pop(0)[::-1] + assert public_inputs.block_hash.lo.expr() == bytes_to_fq(lo_le) + assert public_inputs.block_hash.hi.expr() == bytes_to_fq(hi_le) # TODO layout state_root in proper table - lo = copy_constrains.pop(0) - hi = copy_constrains.pop(0) - assert public_inputs.state_root.lo.expr() == bytes_to_fq(lo) - assert public_inputs.state_root.hi.expr() == bytes_to_fq(hi) + lo_le = copy_constrains.pop(0)[::-1] + hi_le = copy_constrains.pop(0)[::-1] + assert public_inputs.state_root.lo.expr() == bytes_to_fq(lo_le) + assert public_inputs.state_root.hi.expr() == bytes_to_fq(hi_le) # TODO layout state_root_prev in proper table - lo = copy_constrains.pop(0) - hi = copy_constrains.pop(0) - assert public_inputs.state_root_prev.lo.expr() == bytes_to_fq(lo) - assert public_inputs.state_root_prev.hi.expr() == bytes_to_fq(hi) + lo_le = copy_constrains.pop(0)[::-1] + hi_le = copy_constrains.pop(0)[::-1] + assert public_inputs.state_root_prev.lo.expr() == bytes_to_fq(lo_le) + assert public_inputs.state_root_prev.hi.expr() == bytes_to_fq(hi_le) # constrain tx table `id``, `index`, value lo/hi per row, and all rows equals witness rpi bytes in vertical order tx_len = TX_LEN * MAX_TXS + 1 for i in range(tx_len): tx_row: TxTableRow = tx_table.table[i] tx_id, index, value = tx_row.tx_id, tx_row.index, tx_row.value - lo = copy_constrains.pop(0) - assert tx_id == bytes_to_fq(lo) - lo = copy_constrains.pop(0) - assert index == bytes_to_fq(lo) + lo_le = copy_constrains.pop(0)[::-1] + assert tx_id == bytes_to_fq(lo_le) + lo_le = copy_constrains.pop(0)[::-1] + assert index == bytes_to_fq(lo_le) - lo = copy_constrains.pop(0) + lo_le = copy_constrains.pop(0)[::-1] if value.is_word: - hi = copy_constrains.pop(0) + hi_le = copy_constrains.pop(0)[::-1] else: - hi = bytes(0) - assert value.lo.expr() == bytes_to_fq(lo) - assert value.hi.expr() == bytes_to_fq(hi) + hi_le = bytes(0) + assert value.lo.expr() == bytes_to_fq(lo_le) + assert value.hi.expr() == bytes_to_fq(hi_le) # constrain tx calldata value lo/hi euqal to equals witness rpi bytes in vertical order calldata_len = MAX_CALLDATA_BYTES for i in range(calldata_len): value = tx_table.table[tx_len + i].value - lo = copy_constrains.pop(0) + lo_le = copy_constrains.pop(0)[::-1] if value.is_word: - hi = copy_constrains.pop(0) + hi_le = copy_constrains.pop(0)[::-1] else: - hi = bytes(0) - assert value.lo.expr() == bytes_to_fq(lo) - assert value.hi.expr() == bytes_to_fq(hi) + hi_le = bytes(0) + assert value.lo.expr() == bytes_to_fq(lo_le) + assert value.hi.expr() == bytes_to_fq(hi_le) # check gates constrains for i in range(len(rows)): @@ -475,39 +475,39 @@ def tx_table_value_column(self) -> List[WordOrValue]: def tx_raw_bytes(self, tx_id: int) -> List[bytes]: tx_raw_byte: List[bytes] = [] self.append_raw_byte_with_id_index( - tx_raw_byte, tx_id, self.nonce.to_bytes(8, "little") + tx_raw_byte, tx_id, self.nonce.to_bytes(8, "big") ) # Nonce self.append_raw_byte_with_id_index( - tx_raw_byte, tx_id, self.gas.to_bytes(8, "little") + tx_raw_byte, tx_id, self.gas.to_bytes(8, "big") ) # Gas Limit gas_price_lo, gas_price_hi = Word(self.gas_price).to_lo_hi() self.append_raw_byte_with_id_index( tx_raw_byte, tx_id, - gas_price_lo.n.to_bytes(16, "little"), - gas_price_hi.n.to_bytes(16, "little"), + gas_price_lo.n.to_bytes(16, "big"), + gas_price_hi.n.to_bytes(16, "big"), ) # GasPrice self.append_raw_byte_with_id_index( - tx_raw_byte, tx_id, self.from_addr.to_bytes(20, "little") + tx_raw_byte, tx_id, self.from_addr.to_bytes(20, "big") ) # CallerAddress self.append_raw_byte_with_id_index( - tx_raw_byte, tx_id, (self.to_addr or U160(0)).to_bytes(20, "little") + tx_raw_byte, tx_id, (self.to_addr or U160(0)).to_bytes(20, "big") ) # CalleeAddress self.append_raw_byte_with_id_index( tx_raw_byte, tx_id, - (U64(1) if self.to_addr is None else U64(0)).to_bytes(8, "little"), + (U64(1) if self.to_addr is None else U64(0)).to_bytes(8, "big"), ) # IsCreate value_lo, value_hi = Word(self.value).to_lo_hi() self.append_raw_byte_with_id_index( tx_raw_byte, tx_id, - value_lo.n.to_bytes(16, "little"), - value_hi.n.to_bytes(16, "little"), + value_lo.n.to_bytes(16, "big"), + value_hi.n.to_bytes(16, "big"), ) # Value self.append_raw_byte_with_id_index( - tx_raw_byte, tx_id, U64(len(self.data)).to_bytes(8, "little") + tx_raw_byte, tx_id, U64(len(self.data)).to_bytes(8, "big") ) # CallDataLength call_data_gas_cost = sum( [ @@ -520,14 +520,14 @@ def tx_raw_bytes(self, tx_id: int) -> List[bytes]: ] ) self.append_raw_byte_with_id_index( - tx_raw_byte, tx_id, U64(call_data_gas_cost).to_bytes(8, "little") + tx_raw_byte, tx_id, U64(call_data_gas_cost).to_bytes(8, "big") ) # CallDataCost tx_sign_hash_lo, tx_sign_hash_hi = Word(self.tx_sign_hash).to_lo_hi() self.append_raw_byte_with_id_index( tx_raw_byte, tx_id, - tx_sign_hash_lo.n.to_bytes(16, "little"), - tx_sign_hash_hi.n.to_bytes(16, "little"), + tx_sign_hash_lo.n.to_bytes(16, "big"), + tx_sign_hash_hi.n.to_bytes(16, "big"), ) # TxSignHash return tx_raw_byte @@ -538,8 +538,8 @@ def append_raw_byte_with_id_index( value_lo: bytes, value_hi: bytes = bytes(0), ): - raw_byte_value_col.append(U64(tx_id).to_bytes(8, "little")) - raw_byte_value_col.append(U64(0).to_bytes(8, "little")) + raw_byte_value_col.append(U64(tx_id).to_bytes(8, "big")) + raw_byte_value_col.append(U64(0).to_bytes(8, "big")) raw_byte_value_col.append(value_lo) if value_hi != bytes(0): raw_byte_value_col.append(value_hi) @@ -581,23 +581,23 @@ def block_table_raw_byte_values(self) -> List[bytes]: """Return the block table bytes, including first 0 row""" raw_block_value = [] - raw_block_value.append(U8(0).to_bytes(1, "little")) # offset = 0 - raw_block_value.append(self.block.coinbase.to_bytes(20, "little")) - raw_block_value.append(self.block.gas_limit.to_bytes(8, "little")) - raw_block_value.append(self.block.number.to_bytes(8, "little")) - raw_block_value.append(self.block.time.to_bytes(8, "little")) + raw_block_value.append(U8(0).to_bytes(1, "big")) # offset = 0 + raw_block_value.append(self.block.coinbase.to_bytes(20, "big")) + raw_block_value.append(self.block.gas_limit.to_bytes(8, "big")) + raw_block_value.append(self.block.number.to_bytes(8, "big")) + raw_block_value.append(self.block.time.to_bytes(8, "big")) difficulty_lo, difficulty_hi = Word(self.block.difficulty).to_lo_hi() - raw_block_value.append(difficulty_lo.n.to_bytes(16, "little")) - raw_block_value.append(difficulty_hi.n.to_bytes(16, "little")) + raw_block_value.append(difficulty_lo.n.to_bytes(16, "big")) + raw_block_value.append(difficulty_hi.n.to_bytes(16, "big")) base_fee_lo, base_fee_hi = Word(self.block.base_fee).to_lo_hi() - raw_block_value.append(base_fee_lo.n.to_bytes(16, "little")) - raw_block_value.append(base_fee_hi.n.to_bytes(16, "little")) - raw_block_value.append(self.chain_id.to_bytes(8, "little")) + raw_block_value.append(base_fee_lo.n.to_bytes(16, "big")) + raw_block_value.append(base_fee_hi.n.to_bytes(16, "big")) + raw_block_value.append(self.chain_id.to_bytes(8, "big")) assert len(self.block_hashes) == 256 for block_hash in self.block_hashes: block_hash_lo, block_hash_hi = Word(block_hash).to_lo_hi() - raw_block_value.append(block_hash_lo.n.to_bytes(16, "little")) - raw_block_value.append(block_hash_hi.n.to_bytes(16, "little")) + raw_block_value.append(block_hash_lo.n.to_bytes(16, "big")) + raw_block_value.append(block_hash_hi.n.to_bytes(16, "big")) return raw_block_value def tx_table_raw_bytes(self, MAX_TXS: int) -> List[bytes]: @@ -605,9 +605,9 @@ def tx_table_raw_bytes(self, MAX_TXS: int) -> List[bytes]: table_raw_bytes = [] assert len(self.txs) > 0 assert len(self.txs) <= MAX_TXS - table_raw_bytes.append(U64(0).to_bytes(8, "little")) # empty id - table_raw_bytes.append(U64(0).to_bytes(8, "little")) # empty index - table_raw_bytes.append(U8(0).to_bytes(1, "little")) # empty value lo + table_raw_bytes.append(U64(0).to_bytes(8, "big")) # empty id + table_raw_bytes.append(U64(0).to_bytes(8, "big")) # empty index + table_raw_bytes.append(U8(0).to_bytes(1, "big")) # empty value lo for i in range(MAX_TXS): tx = Transaction.default() if i < len(self.txs): @@ -618,15 +618,15 @@ def tx_table_raw_bytes(self, MAX_TXS: int) -> List[bytes]: def tx_table_calldata_raw_bytes(self, MAX_CALLDATA_BYTES: int) -> List[bytes]: tx_calldata_raw_bytes = [] calldata_count = 0 - for i, tx in enumerate(self.txs): - for byte_index, byte in enumerate(tx.data): - tx_calldata_raw_bytes.append(U8(byte).to_bytes(1, "little")) + for tx in self.txs: + for byte in tx.data: + tx_calldata_raw_bytes.append(U8(byte).to_bytes(1, "big")) calldata_count += 1 assert calldata_count <= MAX_CALLDATA_BYTES for _ in range(MAX_CALLDATA_BYTES - calldata_count): - tx_calldata_raw_bytes.append(U8(0).to_bytes(1, "little")) + tx_calldata_raw_bytes.append(U8(0).to_bytes(1, "big")) return tx_calldata_raw_bytes @@ -740,14 +740,14 @@ def public_data2witness( # Extra fields hash_lo, hash_hi = Word(public_data.block.hash).to_lo_hi() - rpi_byte_values.append(hash_lo.n.to_bytes(16, "little")) - rpi_byte_values.append(hash_hi.n.to_bytes(16, "little")) + rpi_byte_values.append(hash_lo.n.to_bytes(16, "big")) + rpi_byte_values.append(hash_hi.n.to_bytes(16, "big")) state_root_lo, state_root_hi = Word(public_data.block.state_root).to_lo_hi() - rpi_byte_values.append(state_root_lo.n.to_bytes(16, "little")) - rpi_byte_values.append(state_root_hi.n.to_bytes(16, "little")) + rpi_byte_values.append(state_root_lo.n.to_bytes(16, "big")) + rpi_byte_values.append(state_root_hi.n.to_bytes(16, "big")) state_root_prev_lo, state_root_prev_hi = Word(public_data.state_root_prev).to_lo_hi() - rpi_byte_values.append(state_root_prev_lo.n.to_bytes(16, "little")) - rpi_byte_values.append(state_root_prev_hi.n.to_bytes(16, "little")) + rpi_byte_values.append(state_root_prev_lo.n.to_bytes(16, "big")) + rpi_byte_values.append(state_root_prev_hi.n.to_bytes(16, "big")) assert flatten_len(rpi_byte_values) == N_BYTES_ONE + N_BYTES_BLOCK + N_BYTES_EXTRA_VALUE # Tx Table @@ -799,7 +799,7 @@ def public_data2witness( rpi_bytes = [] for value in reversed(rpi_byte_values): # acc from big endian - for byte_index, byte in enumerate(reversed(value)): + for byte_index, byte in enumerate(value): rpi_bytes.append(byte) q_rpi_byte_enable = FQ.one()