Skip to content

Commit

Permalink
chore: minor refactor to more closely match snark-verifier
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathanpwang committed Aug 17, 2023
1 parent 11088c2 commit 1fdcded
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 30 deletions.
28 changes: 17 additions & 11 deletions halo2-base/src/poseidon/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem;

use crate::{
gates::GateInstructions,
poseidon::{spec::OptimizedPoseidonSpec, state::PoseidonState},
Expand Down Expand Up @@ -27,7 +29,7 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherChip<F, T,
pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
ctx: &mut Context<F>,
) -> Self {
let init_state = PoseidonState::<F, T, RATE>::default(ctx);
let init_state = PoseidonState::default(ctx);
let state = init_state.clone();
Self {
init_state,
Expand All @@ -37,6 +39,12 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherChip<F, T,
}
}

/// Initialize a poseidon hasher from an existing spec.
pub fn from_spec(ctx: &mut Context<F>, spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
let init_state = PoseidonState::default(ctx);
Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() }
}

/// Reset state to default and clear the buffer.
pub fn clear(&mut self) {
self.state = self.init_state.clone();
Expand All @@ -55,17 +63,13 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherChip<F, T,
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
) -> AssignedValue<F> {
let mut input_elements = vec![];
input_elements.append(&mut self.absorbing);

let mut padding_offset = 0;
let input_elements = mem::take(&mut self.absorbing);
let exact = input_elements.len() % RATE == 0;

for chunk in input_elements.chunks(RATE) {
padding_offset = RATE - chunk.len();
self.permutation(ctx, gate, chunk.to_vec());
}

if padding_offset == 0 {
if exact {
self.permutation(ctx, gate, vec![]);
}

Expand All @@ -80,25 +84,27 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherChip<F, T,
) {
let r_f = self.spec.r_f / 2;
let mds = &self.spec.mds_matrices.mds.0;
let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0;
let sparse_matrices = &self.spec.mds_matrices.sparse_matrices;

// First half of the full round
let constants = &self.spec.constants.start;
self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]);
for constants in constants.iter().skip(1).take(r_f - 1) {
self.state.sbox_full(ctx, gate, constants);
self.state.apply_mds(ctx, gate, mds);
}

let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0;
self.state.sbox_full(ctx, gate, constants.last().unwrap());
self.state.apply_mds(ctx, gate, pre_sparse_mds);

let sparse_matrices = &self.spec.mds_matrices.sparse_matrices;
// Partial rounds
let constants = &self.spec.constants.partial;
for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
self.state.sbox_part(ctx, gate, constant);
self.state.apply_sparse_mds(ctx, gate, sparse_mds);
}

// Second half of the full rounds
let constants = &self.spec.constants.end;
for constants in constants.iter() {
self.state.sbox_full(ctx, gate, constants);
Expand Down
41 changes: 22 additions & 19 deletions halo2-base/src/poseidon/state.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::iter;

use crate::{
gates::GateInstructions,
poseidon::mds::SparseMDSMatrix,
Expand Down Expand Up @@ -61,7 +63,6 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonState<F, T, RATE
pre_constants: &[F; T],
) {
assert!(inputs.len() < T);
let offset = inputs.len() + 1;

// Explanation of what's going on: before each round of the poseidon permutation,
// two things have to be added to the state: inputs (the absorbed elements) and
Expand All @@ -77,20 +78,19 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonState<F, T, RATE

// adding pre-constants and inputs to the elements for which both are available
for ((x, constant), input) in
self.s.iter_mut().skip(1).zip(pre_constants.iter().skip(1)).zip(inputs.iter())
self.s.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs.iter())
{
*x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]);
}

let offset = inputs.len() + 1;
// adding only pre-constants when no input is left
for (i, (x, constant)) in
self.s.iter_mut().skip(offset).zip(pre_constants.iter().skip(offset)).enumerate()
self.s.iter_mut().zip(pre_constants.iter()).skip(offset).enumerate()
{
*x = gate.add(
ctx,
Existing(*x),
Constant(if i == 0 { F::ONE + constant } else { *constant }),
);
*x = gate.add(ctx, *x, Constant(if i == 0 { F::ONE + constant } else { *constant }));
// the if idx == 0 { F::one() } else { F::zero() } is to pad the input with a single 1 and then 0s
// this is the padding suggested in pg 31 of https://eprint.iacr.org/2019/458.pdf and in Section 4.2 (Variable-Input-Length Hashing. The padding consists of one field element being 1, and the remaining elements being 0.)
}
}

Expand All @@ -116,16 +116,19 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonState<F, T, RATE
gate: &impl GateInstructions<F>,
mds: &SparseMDSMatrix<F, T, RATE>,
) {
let sum =
gate.inner_product(ctx, self.s.iter().copied(), mds.row.iter().map(|c| Constant(*c)));
let mut res = vec![sum];

for (e, x) in mds.col_hat.iter().zip(self.s.iter().skip(1)) {
res.push(gate.mul_add(ctx, self.s[0], Constant(*e), *x));
}

for (x, new_x) in self.s.iter_mut().zip(res) {
*x = new_x
}
self.s = iter::once(gate.inner_product(
ctx,
self.s.iter().copied(),
mds.row.iter().map(|c| Constant(*c)),
))
.chain(
mds.col_hat
.iter()
.zip(self.s.iter().skip(1))
.map(|(coeff, state)| gate.mul_add(ctx, self.s[0], Constant(*coeff), *state)),
)
.collect::<Vec<_>>()
.try_into()
.unwrap();
}
}

0 comments on commit 1fdcded

Please sign in to comment.