Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add Poseidon Chip #114

Merged
merged 7 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion halo2-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ harness = false

[[example]]
name = "inner_product"
features = ["test-utils"]
required-features = ["test-utils"]
16 changes: 16 additions & 0 deletions halo2-base/src/gates/flex_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ pub trait GateInstructions<F: ScalarField> {
ctx.assign_region_last([a, b, Constant(F::ONE), Witness(out_val)], [0])
}

/// Constrains and returns `out = a + 1`.
///
/// * `ctx`: [Context] to add the constraints to
/// * `a`: [QuantumCell] value
fn inc(&self, ctx: &mut Context<F>, a: impl Into<QuantumCell<F>>) -> AssignedValue<F> {
self.add(ctx, a, Constant(F::ONE))
}

/// Constrains and returns `a + b * (-1) = out`.
///
/// Defines a vertical gate of form | a - b | b | 1 | a |, where (a - b) = out.
Expand All @@ -200,6 +208,14 @@ pub trait GateInstructions<F: ScalarField> {
ctx.get(-4)
}

/// Constrains and returns `out = a - 1`.
///
/// * `ctx`: [Context] to add the constraints to
/// * `a`: [QuantumCell] value
fn dec(&self, ctx: &mut Context<F>, a: impl Into<QuantumCell<F>>) -> AssignedValue<F> {
self.sub(ctx, a, Constant(F::ONE))
}

/// Constrains and returns `a - b * c = out`.
///
/// Defines a vertical gate of form | a - b * c | b | c | a |, where (a - b * c) = out.
Expand Down
12 changes: 12 additions & 0 deletions halo2-base/src/gates/tests/flex_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,24 @@ pub fn test_add(inputs: &[QuantumCell<Fr>]) -> Fr {
base_test().run_gate(|ctx, chip| *chip.add(ctx, inputs[0], inputs[1]).value())
}

#[test_case(Witness(Fr::from(10))=> Fr::from(11); "inc(): 10 -> 11")]
#[test_case(Witness(Fr::from(1))=> Fr::from(2); "inc(): 1 -> 2")]
pub fn test_inc(input: QuantumCell<Fr>) -> Fr {
base_test().run_gate(|ctx, chip| *chip.inc(ctx, input).value())
}

#[test_case(&[10, 12].map(Fr::from).map(Witness)=> -Fr::from(2) ; "sub(): 10 - 12 == -2")]
#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(0) ; "sub(): 1 - 1 == 0")]
pub fn test_sub(inputs: &[QuantumCell<Fr>]) -> Fr {
base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value())
}

#[test_case(Witness(Fr::from(10))=> Fr::from(9); "dec(): 10 -> 9")]
#[test_case(Witness(Fr::from(1))=> Fr::from(0); "dec(): 1 -> 0")]
pub fn test_dec(input: QuantumCell<Fr>) -> Fr {
base_test().run_gate(|ctx, chip| *chip.dec(ctx, input).value())
}

#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub_mul(): 1 - 1 * 1 == 0")]
pub fn test_sub_mul(inputs: &[QuantumCell<Fr>]) -> Fr {
base_test().run_gate(|ctx, chip| *chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]).value())
Expand Down
206 changes: 156 additions & 50 deletions halo2-base/src/poseidon/hasher/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use std::mem;

use crate::{
gates::GateInstructions,
poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState},
AssignedValue, Context, ScalarField,
safe_types::{RangeInstructions, SafeTypeChip},
utils::BigPrimeField,
AssignedValue, Context,
QuantumCell::Constant,
ScalarField,
};

use getset::Getters;
use num_bigint::BigUint;
use std::{cell::OnceCell, mem};

#[cfg(test)]
mod tests;

Expand All @@ -16,15 +22,142 @@ pub mod spec;
/// Module for poseidon states.
pub mod state;

/// Poseidon hasher. This is stateful.
/// Stateless Poseidon hasher.
pub struct PoseidonHasher<F: ScalarField, const T: usize, const RATE: usize> {
spec: OptimizedPoseidonSpec<F, T, RATE>,
consts: OnceCell<PoseidonHasherConsts<F, T, RATE>>,
}
#[derive(Getters)]
struct PoseidonHasherConsts<F: ScalarField, const T: usize, const RATE: usize> {
#[getset(get = "pub")]
init_state: PoseidonState<F, T, RATE>,
// hash of an empty input("").
#[getset(get = "pub")]
empty_hash: AssignedValue<F>,
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasherConsts<F, T, RATE> {
pub fn new(
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
spec: &OptimizedPoseidonSpec<F, T, RATE>,
) -> Self {
let init_state = PoseidonState::default(ctx);
let mut state = init_state.clone();
let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec);
Self { init_state, empty_hash }
}
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RATE> {
/// Create a poseidon hasher from an existing spec.
pub fn new(spec: OptimizedPoseidonSpec<F, T, RATE>) -> Self {
Self { spec, consts: OnceCell::new() }
}
/// Initialize necessary consts of hasher. Must be called before any computation.
pub fn initialize_consts(&mut self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) {
self.consts.get_or_init(|| PoseidonHasherConsts::<F, T, RATE>::new(ctx, gate, &self.spec));
}

fn empty_hash(&self) -> &AssignedValue<F> {
self.consts.get().unwrap().empty_hash()
}
fn init_state(&self) -> &PoseidonState<F, T, RATE> {
self.consts.get().unwrap().init_state()
}

/// Constrains and returns hash of a witness array with a variable length.
///
/// Assumes `len` is within [usize] and `len <= inputs.len()`.
/// * inputs: An right-padded array of [AssignedValue]. Constraints on paddings are not required.
/// * len: Length of `inputs`.
/// Return hash of `inputs`.
pub fn hash_var_len_array(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
inputs: &[AssignedValue<F>],
len: AssignedValue<F>,
) -> AssignedValue<F>
where
F: BigPrimeField,
{
jonathanpwang marked this conversation as resolved.
Show resolved Hide resolved
let max_len = inputs.len();
if max_len == 0 {
return *self.empty_hash();
};

// len <= max_len --> num_of_bits(len) <= num_of_bits(max_len)
let num_bits = (usize::BITS - max_len.leading_zeros()) as usize;
// num_perm = len // RATE + 1, len_last_chunk = len % RATE
let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits);
num_perm = range.gate().inc(ctx, num_perm);

let mut state = self.init_state().clone();
let mut result_state = state.clone();
for (i, chunk) in inputs.chunks(RATE).enumerate() {
let is_last_perm =
range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64)));
let len_chunk = range.gate().select(
ctx,
len_last_chunk,
Constant(F::from(RATE as u64)),
is_last_perm,
);

state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec);
result_state.select(
ctx,
range.gate(),
SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
&state,
);
}
if max_len % RATE == 0 {
let is_last_perm = range.gate().is_equal(
ctx,
num_perm,
Constant(F::from((max_len / RATE + 1) as u64)),
);
let len_chunk = ctx.load_zero();
state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec);
result_state.select(
ctx,
range.gate(),
SafeTypeChip::<F>::unsafe_to_bool(is_last_perm),
&state,
);
}
result_state.s[1]
}

/// Constrains and returns hash of a witness array.
///
/// * inputs: An array of [AssignedValue].
/// Return hash of `inputs`.
pub fn hash_fix_len_array(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
inputs: &[AssignedValue<F>],
) -> AssignedValue<F>
where
F: BigPrimeField,
{
let mut state = self.init_state().clone();
fix_len_array_squeeze(ctx, range.gate(), inputs, &mut state, &self.spec)
}
}

/// Poseidon sponge. This is stateful.
pub struct PoseidonSponge<F: ScalarField, const T: usize, const RATE: usize> {
init_state: PoseidonState<F, T, RATE>,
state: PoseidonState<F, T, RATE>,
spec: OptimizedPoseidonSpec<F, T, RATE>,
absorbing: Vec<AssignedValue<F>>,
}

impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RATE> {
impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonSponge<F, T, RATE> {
/// Create new Poseidon hasher.
pub fn new<const R_F: usize, const R_P: usize, const SECURE_MDS: usize>(
ctx: &mut Context<F>,
Expand Down Expand Up @@ -64,53 +197,26 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RAT
gate: &impl GateInstructions<F>,
) -> AssignedValue<F> {
let input_elements = mem::take(&mut self.absorbing);
let exact = input_elements.len() % RATE == 0;

for chunk in input_elements.chunks(RATE) {
self.permutation(ctx, gate, chunk.to_vec());
}
if exact {
self.permutation(ctx, gate, vec![]);
}

self.state.s[1]
fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec)
}
}

fn permutation(
&mut self,
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
inputs: Vec<AssignedValue<F>>,
) {
let r_f = self.spec.r_f / 2;
let mds = &self.spec.mds_matrices.mds.0;
let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0;
let sparse_matrices = &self.spec.mds_matrices.sparse_matrices;

// First half of the full round
let constants = &self.spec.constants.start;
self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]);
for constants in constants.iter().skip(1).take(r_f - 1) {
self.state.sbox_full(ctx, gate, constants);
self.state.apply_mds(ctx, gate, mds);
}
self.state.sbox_full(ctx, gate, constants.last().unwrap());
self.state.apply_mds(ctx, gate, pre_sparse_mds);

// Partial rounds
let constants = &self.spec.constants.partial;
for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) {
self.state.sbox_part(ctx, gate, constant);
self.state.apply_sparse_mds(ctx, gate, sparse_mds);
}
/// ATTETION: input_elements.len() needs to be fixed at compile time.
fn fix_len_array_squeeze<F: ScalarField, const T: usize, const RATE: usize>(
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
input_elements: &[AssignedValue<F>],
state: &mut PoseidonState<F, T, RATE>,
spec: &OptimizedPoseidonSpec<F, T, RATE>,
) -> AssignedValue<F> {
let exact = input_elements.len() % RATE == 0;

// Second half of the full rounds
let constants = &self.spec.constants.end;
for constants in constants.iter() {
self.state.sbox_full(ctx, gate, constants);
self.state.apply_mds(ctx, gate, mds);
}
self.state.sbox_full(ctx, gate, &[F::ZERO; T]);
self.state.apply_mds(ctx, gate, mds);
for chunk in input_elements.chunks(RATE) {
state.permutation(ctx, gate, chunk, None, spec);
}
if exact {
state.permutation(ctx, gate, &[], None, spec);
}

state.s[1]
}
Loading
Loading