Skip to content

Commit

Permalink
[feat] App Circuit Utils for Keccak Coprocessor (#141)
Browse files Browse the repository at this point in the history
* Add keccak coprocessor encoding for VarLenBytesVec/FixLenBytesVec

* Fix naming/nits

* Fix nit
  • Loading branch information
nyunyunyunyu authored Sep 9, 2023
1 parent 9377d90 commit 54044c9
Show file tree
Hide file tree
Showing 10 changed files with 556 additions and 14 deletions.
46 changes: 46 additions & 0 deletions halo2-base/src/poseidon/hasher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ impl<F: ScalarField, const RATE: usize> PoseidonCompactInput<F, RATE> {
}
}

/// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk.
#[derive(Clone, Debug)]
pub struct PoseidonCompactChunkInput<F: ScalarField, const RATE: usize> {
// Inputs of a chunk. All witnesses will be absorbed.
inputs: Vec<[AssignedValue<F>; RATE]>,
// is_final = 1 triggers squeeze.
is_final: SafeBool<F>,
}

impl<F: ScalarField, const RATE: usize> PoseidonCompactChunkInput<F, RATE> {
/// Create a new PoseidonCompactInput.
pub fn new(inputs: Vec<[AssignedValue<F>; RATE]>, is_final: SafeBool<F>) -> Self {
Self { inputs, is_final }
}
}

/// 1 logical row of compact output for Poseidon hasher.
#[derive(Copy, Clone, Debug, Getters)]
pub struct PoseidonCompactOutput<F: ScalarField> {
Expand Down Expand Up @@ -232,6 +248,36 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RAT
}
outputs
}

/// Constrains and returns hashes of chunk inputs in a compact format. Length of `chunk_inputs` should be determined at compile time.
pub fn hash_compact_chunk_inputs(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
chunk_inputs: &[PoseidonCompactChunkInput<F, RATE>],
) -> Vec<PoseidonCompactOutput<F>>
where
F: BigPrimeField,
{
let zero_witness = ctx.load_zero();
let mut outputs = Vec::with_capacity(chunk_inputs.len());
let mut state = self.init_state().clone();
for chunk_input in chunk_inputs {
let is_final = chunk_input.is_final;
for absorb in &chunk_input.inputs {
state.permutation(ctx, range.gate(), absorb, None, &self.spec);
}
// Because the length of each absorb is always RATE. An extra permutation is needed for squeeze.
let mut output_state = state.clone();
output_state.permutation(ctx, range.gate(), &[], None, &self.spec);
let hash =
range.gate().select(ctx, output_state.s[1], zero_witness, *is_final.as_ref());
outputs.push(PoseidonCompactOutput { hash, is_final });
// Reset state to init_state if this is the end of a logical input.
state.select(ctx, range.gate(), is_final, self.init_state());
}
outputs
}
}

/// Poseidon sponge. This is stateful.
Expand Down
123 changes: 122 additions & 1 deletion halo2-base/src/poseidon/hasher/tests/hasher.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use crate::{
gates::{range::RangeInstructions, RangeChip},
halo2_proofs::halo2curves::bn256::Fr,
poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonHasher},
poseidon::hasher::{
spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput,
PoseidonHasher,
},
safe_types::SafeTypeChip,
utils::{testing::base_test, ScalarField},
Context,
};
use halo2_proofs_axiom::arithmetic::Field;
use itertools::Itertools;
use pse_poseidon::Poseidon;
use rand::Rng;

Expand Down Expand Up @@ -111,6 +115,61 @@ fn hasher_compact_inputs_compatiblity_verification<
}
}

// check if the results from hasher and native sponge are same for hash_compact_input.
fn hasher_compact_chunk_inputs_compatiblity_verification<
const T: usize,
const RATE: usize,
const R_F: usize,
const R_P: usize,
>(
payloads: Vec<(Payload<Fr>, bool)>,
ctx: &mut Context<Fr>,
range: &RangeChip<Fr>,
) {
// Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0.
let spec = OptimizedPoseidonSpec::<Fr, T, RATE>::new::<R_F, R_P, 0>();
let mut hasher = PoseidonHasher::<Fr, T, RATE>::new(spec);
hasher.initialize_consts(ctx, range.gate());

let mut native_results = Vec::with_capacity(payloads.len());
let mut chunk_inputs = Vec::<PoseidonCompactChunkInput<Fr, RATE>>::new();
let true_witness = SafeTypeChip::unsafe_to_bool(ctx.load_constant(Fr::ONE));
let false_witness = SafeTypeChip::unsafe_to_bool(ctx.load_zero());

// Construct native Poseidon sponge.
let mut native_sponge = Poseidon::<Fr, T, RATE>::new(R_F, R_P);
for (payload, is_final) in payloads {
assert!(payload.values.len() == payload.len);
assert!(payload.values.len() % RATE == 0);
let inputs = ctx.assign_witnesses(payload.values.clone());

let is_final_witness = if is_final { true_witness } else { false_witness };
chunk_inputs.push(PoseidonCompactChunkInput {
inputs: inputs.chunks(RATE).map(|c| c.try_into().unwrap()).collect_vec(),
is_final: is_final_witness,
});
native_sponge.update(&payload.values);
if is_final {
let native_result = native_sponge.squeeze();
native_results.push(native_result);
native_sponge = Poseidon::<Fr, T, RATE>::new(R_F, R_P);
}
}
let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range, &chunk_inputs);
assert_eq!(chunk_inputs.len(), compact_outputs.len());
let mut output_offset = 0;
for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) {
// into() doesn't work if ! is in the beginning in the bool expression...
let is_final_input = chunk_input.is_final.as_ref().value();
let is_final_output = compact_output.is_final().as_ref().value();
assert_eq!(is_final_input, is_final_output);
if is_final_output == &Fr::ONE {
assert_eq!(native_results[output_offset], *compact_output.hash().value());
output_offset += 1;
}
}
}

fn random_payload<F: ScalarField>(max_len: usize, len: usize, max_value: usize) -> Payload<F> {
assert!(len <= max_len);
let mut rng = rand::thread_rng();
Expand Down Expand Up @@ -235,3 +294,65 @@ fn test_poseidon_hasher_compact_inputs_with_prover() {
});
}
}

#[test]
fn test_poseidon_hasher_compact_chunk_inputs() {
{
const T: usize = 3;
const RATE: usize = 2;
let payloads = vec![
(random_payload(RATE * 5, RATE * 5, usize::MAX), true),
(random_payload(RATE, RATE, usize::MAX), false),
(random_payload(RATE * 2, RATE * 2, usize::MAX), true),
(random_payload(RATE * 3, RATE * 3, usize::MAX), true),
];
base_test().k(12).run(|ctx, range| {
hasher_compact_chunk_inputs_compatiblity_verification::<T, RATE, 8, 57>(
payloads, ctx, range,
);
});
}
{
const T: usize = 3;
const RATE: usize = 2;
let payloads = vec![
(random_payload(0, 0, usize::MAX), true),
(random_payload(0, 0, usize::MAX), false),
(random_payload(0, 0, usize::MAX), false),
];
base_test().k(12).run(|ctx, range| {
hasher_compact_chunk_inputs_compatiblity_verification::<T, RATE, 8, 57>(
payloads, ctx, range,
);
});
}
}

#[test]
fn test_poseidon_hasher_compact_chunk_inputs_with_prover() {
{
const T: usize = 3;
const RATE: usize = 2;
let params = [
(RATE, false),
(RATE * 2, false),
(RATE * 5, false),
(RATE * 2, true),
(RATE * 5, true),
];
let init_payloads = params
.iter()
.map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final))
.collect::<Vec<_>>();
let logic_payloads = params
.iter()
.map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final))
.collect::<Vec<_>>();
base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| {
let ctx = pool.main();
hasher_compact_chunk_inputs_compatiblity_verification::<T, RATE, 8, 57>(
input, ctx, range,
);
});
}
}
58 changes: 56 additions & 2 deletions halo2-base/src/safe_types/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ impl<F: ScalarField, const MAX_LEN: usize> VarLenBytes<F, MAX_LEN> {
padded.into_iter().map(|b| SafeByte(b)).collect::<Vec<_>>().try_into().unwrap(),
)
}

/// Return a copy of the byte array with 0 padding ensured.
pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
Self::new(bytes.try_into().unwrap(), self.len)
}
}

/// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
Expand Down Expand Up @@ -93,7 +99,13 @@ impl<F: ScalarField> VarLenBytesVec<F> {
gate: &impl GateInstructions<F>,
) -> FixLenBytesVec<F> {
let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len());
padded.into_iter().map(|b| SafeByte(b)).collect()
FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len())
}

/// Return a copy of the byte array with 0 padding ensured.
pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
Self::new(bytes, self.len, self.max_len())
}
}

Expand All @@ -117,6 +129,27 @@ impl<F: ScalarField, const LEN: usize> FixLenBytes<F, LEN> {
}
}

/// Represents a fixed length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
#[derive(Debug, Clone, Getters)]
pub struct FixLenBytesVec<F: ScalarField> {
/// The byte array
#[getset(get = "pub")]
bytes: Vec<SafeByte<F>>,
}

impl<F: ScalarField> FixLenBytesVec<F> {
// FixLenBytes can be only created by SafeChip.
pub(super) fn new(bytes: Vec<SafeByte<F>>, len: usize) -> Self {
assert_eq!(bytes.len(), len, "bytes length doesn't match");
Self { bytes }
}

/// Returns the length of the byte array.
pub fn len(&self) -> usize {
self.bytes.len()
}
}

impl<F: ScalarField, const TOTAL_BITS: usize> From<SafeType<F, 1, TOTAL_BITS>>
for FixLenBytes<F, { SafeType::<F, 1, TOTAL_BITS>::VALUE_LENGTH }>
{
Expand All @@ -138,7 +171,7 @@ impl<F: ScalarField, const TOTAL_BITS: usize>

/// Represents a fixed length byte array in circuit as a vector, where length must be fixed.
/// Not encouraged to use because `LEN` cannot be verified at compile time.
pub type FixLenBytesVec<F> = Vec<SafeByte<F>>;
// pub type FixLenBytesVec<F> = Vec<SafeByte<F>>;

/// Takes a fixed length array `arr` and returns a length `out_len` array equal to
/// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and
Expand Down Expand Up @@ -172,3 +205,24 @@ pub fn left_pad_var_array_to_fixed<F: ScalarField>(
}
padded
}

fn ensure_0_padding<F: ScalarField>(
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
bytes: &[SafeByte<F>],
len: AssignedValue<F>,
) -> Vec<SafeByte<F>> {
let max_len = bytes.len();
// Generate a mask array where a[i] = i < len for i = 0..max_len.
let idx = gate.dec(ctx, len);
let len_indicator = gate.idx_to_indicator(ctx, idx, max_len);
// inputs_mask[i] = sum(len_indicator[i..])
let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
mask.reverse();

bytes
.iter()
.zip(mask.iter())
.map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask)))
.collect_vec()
}
31 changes: 30 additions & 1 deletion halo2-base/src/safe_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
FixLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)))
}

/// Unsafe method that directly converts `inputs` to [`FixLenBytesVec`] **without any checks**.
/// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
pub fn unsafe_to_fix_len_bytes_vec(
inputs: RawAssignedValues<F>,
len: usize,
) -> FixLenBytesVec<F> {
FixLenBytesVec::<F>::new(
inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(),
len,
)
}

/// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes.
///
/// * ctx: Circuit [Context]<F> to assign witnesses to.
Expand All @@ -249,7 +261,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
/// * ctx: Circuit [Context]<F> to assign witnesses to.
/// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding.
/// * len: [AssignedValue]<F> witness representing the variable length of the byte array. Constrained to be `<= max_len`.
/// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain.
/// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic.
pub fn raw_to_var_len_bytes_vec(
&self,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -278,6 +290,23 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
FixLenBytes::<F, LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)))
}

/// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec.
///
/// * ctx: Circuit [Context]<F> to assign witnesses to.
/// * inputs: Slice representing the byte array.
/// * len: length of the byte array. We enforce this to be provided explictly to make sure length of `inputs` is determinstic.
pub fn raw_to_fix_len_bytes_vec(
&self,
ctx: &mut Context<F>,
inputs: RawAssignedValues<F>,
len: usize,
) -> FixLenBytesVec<F> {
FixLenBytesVec::<F>::new(
inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(),
len,
)
}

fn add_bytes_constraints(
&self,
ctx: &mut Context<F>,
Expand Down
29 changes: 27 additions & 2 deletions halo2-base/src/safe_types/tests/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn left_pad_var_len_bytes(mut bytes: Vec<u8>, max_len: usize) -> Vec<u8> {
let len = ctx.load_witness(Fr::from(len as u64));
let bytes = safe.raw_to_var_len_bytes_vec(ctx, bytes, len, max_len);
let padded = bytes.left_pad_to_fixed(ctx, range.gate());
padded.iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect()
padded.bytes().iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect()
})
}

Expand Down Expand Up @@ -132,7 +132,7 @@ fn neg_var_len_bytes_vec_len_less_than_max_len() {

// Circuit Satisfied for valid inputs
#[test]
fn pos_fix_len_bytes_vec() {
fn pos_fix_len_bytes() {
base_test().k(10).lookup_bits(8).run(|ctx, range| {
let safe = SafeTypeChip::new(range);
let fake_bytes = ctx.assign_witnesses(
Expand All @@ -142,6 +142,31 @@ fn pos_fix_len_bytes_vec() {
});
}

// Assert inputs.len() == len
#[test]
#[should_panic]
fn neg_fix_len_bytes_vec() {
base_test().k(10).lookup_bits(8).run(|ctx, range| {
let safe = SafeTypeChip::new(range);
let fake_bytes = ctx.assign_witnesses(
vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::<Vec<_>>(),
);
safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 5);
});
}

// Circuit Satisfied for valid inputs
#[test]
fn pos_fix_len_bytes_vec() {
base_test().k(10).lookup_bits(8).run(|ctx, range| {
let safe = SafeTypeChip::new(range);
let fake_bytes = ctx.assign_witnesses(
vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::<Vec<_>>(),
);
safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 4);
});
}

// =========== Prover ===========
#[test]
fn pos_prover_satisfied() {
Expand Down
Loading

0 comments on commit 54044c9

Please sign in to comment.