Skip to content

Commit

Permalink
Generalized Halo2Loader (#12)
Browse files Browse the repository at this point in the history
* feat: generalize `Protocol` for further usage

* feat: add `EccInstruction::{fixed_base_msm,variable_base_msm,sum_with_const}`

* chore: move `rand_chacha` as dev dependency
  • Loading branch information
han0110 authored Oct 28, 2022
1 parent 916b29f commit 2cd8b9d
Show file tree
Hide file tree
Showing 22 changed files with 423 additions and 252 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ num-bigint = "0.4.3"
num-integer = "0.1.45"
num-traits = "0.2.15"
rand = "0.8"
rand_chacha = "0.3.1"
halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" }

# system_halo2
Expand All @@ -25,6 +24,7 @@ halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2
poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true }

[dev-dependencies]
rand_chacha = "0.3.1"
paste = "1.0.7"

# system_halo2
Expand Down
11 changes: 6 additions & 5 deletions examples/evm-verifier-with-accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,12 @@ mod aggregation {
let accumulators = snarks
.iter()
.flat_map(|snark| {
let protocol = snark.protocol.loaded(loader);
let instances = assign_instances(&snark.instances);
let mut transcript =
PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof =
Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap();
Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap()
let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap();
Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap()
})
.collect_vec();

Expand Down Expand Up @@ -555,11 +555,12 @@ fn gen_aggregation_evm_verifier(
vk,
Config::kzg()
.with_num_instance(num_instance.clone())
.with_accumulator_indices(accumulator_indices),
.with_accumulator_indices(Some(accumulator_indices)),
);

let loader = EvmLoader::new::<Fq, Fr>();
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(loader.clone());
let protocol = protocol.loaded(&loader);
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader);

let instances = transcript.load_instances(num_instance);
let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion examples/evm-verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ fn gen_evm_verifier(
);

let loader = EvmLoader::new::<Fq, Fr>();
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(loader.clone());
let protocol = protocol.loaded(&loader);
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader);

let instances = transcript.load_instances(num_instance);
let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap();
Expand Down
10 changes: 7 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@ pub enum Error {
}

#[derive(Clone, Debug)]
pub struct Protocol<C: util::arithmetic::CurveAffine> {
pub struct Protocol<C, L = loader::native::NativeLoader>
where
C: util::arithmetic::CurveAffine,
L: loader::Loader<C>,
{
// Common description
pub domain: util::arithmetic::Domain<C::Scalar>,
pub preprocessed: Vec<C>,
pub preprocessed: Vec<L::LoadedEcPoint>,
pub num_instance: Vec<usize>,
pub num_witness: Vec<usize>,
pub num_challenge: Vec<usize>,
pub evaluations: Vec<util::protocol::Query>,
pub queries: Vec<util::protocol::Query>,
pub quotient: util::protocol::QuotientPolynomial<C::Scalar>,
// Minor customization
pub transcript_initial_state: Option<C::Scalar>,
pub transcript_initial_state: Option<L::LoadedScalar>,
pub instance_committing_key: Option<util::protocol::InstanceCommittingKey<C>>,
pub linearization: Option<util::protocol::LinearizationStrategy>,
pub accumulator_indices: Vec<Vec<(usize, usize)>>,
Expand Down
33 changes: 15 additions & 18 deletions src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,6 @@ pub trait LoadedEcPoint<C: CurveAffine>: Clone + Debug + PartialEq {
type Loader: Loader<C, LoadedEcPoint = Self>;

fn loader(&self) -> &Self::Loader;

fn multi_scalar_multiplication(
pairs: impl IntoIterator<
Item = (
<Self::Loader as ScalarLoader<C::Scalar>>::LoadedScalar,
Self,
),
>,
) -> Self;
}

pub trait LoadedScalar<F: PrimeField>: Clone + Debug + PartialEq + FieldOps {
Expand All @@ -43,15 +34,6 @@ pub trait LoadedScalar<F: PrimeField>: Clone + Debug + PartialEq + FieldOps {
FieldOps::invert(self)
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Self>)
where
Self: 'a,
{
values
.into_iter()
.for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone()))
}

fn pow_const(&self, mut exp: u64) -> Self {
assert!(exp > 0);

Expand Down Expand Up @@ -102,6 +84,12 @@ pub trait EcPointLoader<C: CurveAffine> {
lhs: &Self::LoadedEcPoint,
rhs: &Self::LoadedEcPoint,
) -> Result<(), Error>;

fn multi_scalar_multiplication(
pairs: &[(Self::LoadedScalar, Self::LoadedEcPoint)],
) -> Self::LoadedEcPoint
where
Self: ScalarLoader<C::ScalarExt>;
}

pub trait ScalarLoader<F: PrimeField> {
Expand Down Expand Up @@ -226,6 +214,15 @@ pub trait ScalarLoader<F: PrimeField> {
.iter()
.fold(self.load_one(), |acc, value| acc * *value)
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Self::LoadedScalar>)
where
Self::LoadedScalar: 'a,
{
values
.into_iter()
.for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone()))
}
}

pub trait Loader<C: CurveAffine>:
Expand Down
148 changes: 75 additions & 73 deletions src/loader/evm/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,17 +596,6 @@ where
fn loader(&self) -> &Rc<EvmLoader> {
&self.loader
}

fn multi_scalar_multiplication(pairs: impl IntoIterator<Item = (Scalar, EcPoint)>) -> Self {
pairs
.into_iter()
.map(|(scalar, ec_point)| match scalar.value {
Value::Constant(constant) if constant == U256::one() => ec_point,
_ => ec_point.loader.ec_point_scalar_mul(&ec_point, &scalar),
})
.reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point))
.unwrap()
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -759,73 +748,12 @@ impl<F: PrimeField<Repr = [u8; 0x20]>> LoadedScalar<F> for Scalar {
fn loader(&self) -> &Rc<EvmLoader> {
&self.loader
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Self>) {
let values = values.into_iter().collect_vec();
let loader = &values.first().unwrap().loader;
let products = iter::once(values[0].clone())
.chain(
iter::repeat_with(|| loader.allocate(0x20))
.map(|ptr| loader.scalar(Value::Memory(ptr)))
.take(values.len() - 1),
)
.collect_vec();

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(products.first().unwrap());
for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() {
loader.push(value);
loader.code.borrow_mut().mulmod();
if idx < values.len() - 2 {
loader.code.borrow_mut().dup(0);
}
loader.code.borrow_mut().push(product.ptr()).mstore();
}

let inv = loader.invert(products.last().unwrap());

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(&inv);
for (value, product) in values.iter().rev().zip(
products
.iter()
.rev()
.skip(1)
.map(Some)
.chain(iter::once(None)),
) {
if let Some(product) = product {
loader.push(value);
loader
.code
.borrow_mut()
.dup(2)
.dup(2)
.push(product.ptr())
.mload()
.mulmod()
.push(value.ptr())
.mstore()
.mulmod();
} else {
loader.code.borrow_mut().push(value.ptr()).mstore();
}
}
}
}

impl<C> EcPointLoader<C> for Rc<EvmLoader>
where
C: CurveAffine,
C::Scalar: PrimeField<Repr = [u8; 0x20]>,
C::ScalarExt: PrimeField<Repr = [u8; 0x20]>,
{
type LoadedEcPoint = EcPoint;

Expand All @@ -839,6 +767,19 @@ where
fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> {
unimplemented!()
}

fn multi_scalar_multiplication(
pairs: &[(<Self as ScalarLoader<C::Scalar>>::LoadedScalar, EcPoint)],
) -> EcPoint {
pairs
.iter()
.map(|(scalar, ec_point)| match scalar.value {
Value::Constant(constant) if U256::one() == constant => ec_point.clone(),
_ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar),
})
.reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point))
.unwrap()
}
}

impl<F: PrimeField<Repr = [u8; 0x20]>> ScalarLoader<F> for Rc<EvmLoader> {
Expand Down Expand Up @@ -977,6 +918,67 @@ impl<F: PrimeField<Repr = [u8; 0x20]>> ScalarLoader<F> for Rc<EvmLoader> {

self.scalar(Value::Memory(ptr))
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Scalar>) {
let values = values.into_iter().collect_vec();
let loader = &values.first().unwrap().loader;
let products = iter::once(values[0].clone())
.chain(
iter::repeat_with(|| loader.allocate(0x20))
.map(|ptr| loader.scalar(Value::Memory(ptr)))
.take(values.len() - 1),
)
.collect_vec();

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(products.first().unwrap());
for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() {
loader.push(value);
loader.code.borrow_mut().mulmod();
if idx < values.len() - 2 {
loader.code.borrow_mut().dup(0);
}
loader.code.borrow_mut().push(product.ptr()).mstore();
}

let inv = loader.invert(products.last().unwrap());

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(&inv);
for (value, product) in values.iter().rev().zip(
products
.iter()
.rev()
.skip(1)
.map(Some)
.chain(iter::once(None)),
) {
if let Some(product) = product {
loader.push(value);
loader
.code
.borrow_mut()
.dup(2)
.dup(2)
.push(product.ptr())
.mload()
.mulmod()
.push(value.ptr())
.mstore()
.mulmod();
} else {
loader.code.borrow_mut().push(value.ptr()).mstore();
}
}
}
}

impl<C> Loader<C> for Rc<EvmLoader>
Expand Down
42 changes: 42 additions & 0 deletions src/loader/halo2.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use crate::{util::arithmetic::CurveAffine, Protocol};
use halo2_proofs::circuit;
use std::rc::Rc;

pub(crate) mod loader;
mod shim;

Expand Down Expand Up @@ -27,3 +31,41 @@ mod util {

impl<V, I: Iterator<Item = Value<V>>> Valuetools<V> for I {}
}

impl<C> Protocol<C>
where
C: CurveAffine,
{
pub fn loaded_preprocessed_as_witness<'a, EccChip: EccInstructions<'a, C>>(
&self,
loader: &Rc<Halo2Loader<'a, C, EccChip>>,
) -> Protocol<C, Rc<Halo2Loader<'a, C, EccChip>>> {
let preprocessed = self
.preprocessed
.iter()
.map(|preprocessed| loader.assign_ec_point(circuit::Value::known(*preprocessed)))
.collect();
let transcript_initial_state =
self.transcript_initial_state
.as_ref()
.map(|transcript_initial_state| {
loader.assign_scalar(circuit::Value::known(
loader.scalar_chip().integer(*transcript_initial_state),
))
});
Protocol {
domain: self.domain.clone(),
preprocessed,
num_instance: self.num_instance.clone(),
num_witness: self.num_witness.clone(),
num_challenge: self.num_challenge.clone(),
evaluations: self.evaluations.clone(),
queries: self.queries.clone(),
quotient: self.quotient.clone(),
transcript_initial_state,
instance_committing_key: self.instance_committing_key.clone(),
linearization: self.linearization.clone(),
accumulator_indices: self.accumulator_indices.clone(),
}
}
}
Loading

0 comments on commit 2cd8b9d

Please sign in to comment.