Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalized Halo2Loader #12

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ num-bigint = "0.4.3"
num-integer = "0.1.45"
num-traits = "0.2.15"
rand = "0.8"
rand_chacha = "0.3.1"
halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" }

# system_halo2
Expand All @@ -25,6 +24,7 @@ halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2
poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true }

[dev-dependencies]
rand_chacha = "0.3.1"
paste = "1.0.7"

# system_halo2
Expand Down
11 changes: 6 additions & 5 deletions examples/evm-verifier-with-accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,12 @@ mod aggregation {
let accumulators = snarks
.iter()
.flat_map(|snark| {
let protocol = snark.protocol.loaded(loader);
let instances = assign_instances(&snark.instances);
let mut transcript =
PoseidonTranscript::<Rc<Halo2Loader>, _>::new(loader, snark.proof());
let proof =
Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap();
Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap()
let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap();
Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap()
})
.collect_vec();

Expand Down Expand Up @@ -555,11 +555,12 @@ fn gen_aggregation_evm_verifier(
vk,
Config::kzg()
.with_num_instance(num_instance.clone())
.with_accumulator_indices(accumulator_indices),
.with_accumulator_indices(Some(accumulator_indices)),
);

let loader = EvmLoader::new::<Fq, Fr>();
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(loader.clone());
let protocol = protocol.loaded(&loader);
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader);

let instances = transcript.load_instances(num_instance);
let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap();
Expand Down
3 changes: 2 additions & 1 deletion examples/evm-verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ fn gen_evm_verifier(
);

let loader = EvmLoader::new::<Fq, Fr>();
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(loader.clone());
let protocol = protocol.loaded(&loader);
let mut transcript = EvmTranscript::<_, Rc<EvmLoader>, _, _>::new(&loader);

let instances = transcript.load_instances(num_instance);
let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap();
Expand Down
10 changes: 7 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,22 @@ pub enum Error {
}

#[derive(Clone, Debug)]
pub struct Protocol<C: util::arithmetic::CurveAffine> {
pub struct Protocol<C, L = loader::native::NativeLoader>
where
C: util::arithmetic::CurveAffine,
L: loader::Loader<C>,
{
// Common description
pub domain: util::arithmetic::Domain<C::Scalar>,
pub preprocessed: Vec<C>,
pub preprocessed: Vec<L::LoadedEcPoint>,
pub num_instance: Vec<usize>,
pub num_witness: Vec<usize>,
pub num_challenge: Vec<usize>,
pub evaluations: Vec<util::protocol::Query>,
pub queries: Vec<util::protocol::Query>,
pub quotient: util::protocol::QuotientPolynomial<C::Scalar>,
// Minor customization
pub transcript_initial_state: Option<C::Scalar>,
pub transcript_initial_state: Option<L::LoadedScalar>,
pub instance_committing_key: Option<util::protocol::InstanceCommittingKey<C>>,
pub linearization: Option<util::protocol::LinearizationStrategy>,
pub accumulator_indices: Vec<Vec<(usize, usize)>>,
Expand Down
33 changes: 15 additions & 18 deletions src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,6 @@ pub trait LoadedEcPoint<C: CurveAffine>: Clone + Debug + PartialEq {
type Loader: Loader<C, LoadedEcPoint = Self>;

fn loader(&self) -> &Self::Loader;

fn multi_scalar_multiplication(
pairs: impl IntoIterator<
Item = (
<Self::Loader as ScalarLoader<C::Scalar>>::LoadedScalar,
Self,
),
>,
) -> Self;
}

pub trait LoadedScalar<F: PrimeField>: Clone + Debug + PartialEq + FieldOps {
Expand All @@ -43,15 +34,6 @@ pub trait LoadedScalar<F: PrimeField>: Clone + Debug + PartialEq + FieldOps {
FieldOps::invert(self)
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Self>)
where
Self: 'a,
{
values
.into_iter()
.for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone()))
}

fn pow_const(&self, mut exp: u64) -> Self {
assert!(exp > 0);

Expand Down Expand Up @@ -102,6 +84,12 @@ pub trait EcPointLoader<C: CurveAffine> {
lhs: &Self::LoadedEcPoint,
rhs: &Self::LoadedEcPoint,
) -> Result<(), Error>;

fn multi_scalar_multiplication(
pairs: &[(Self::LoadedScalar, Self::LoadedEcPoint)],
) -> Self::LoadedEcPoint
where
Self: ScalarLoader<C::ScalarExt>;
}

pub trait ScalarLoader<F: PrimeField> {
Expand Down Expand Up @@ -226,6 +214,15 @@ pub trait ScalarLoader<F: PrimeField> {
.iter()
.fold(self.load_one(), |acc, value| acc * *value)
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Self::LoadedScalar>)
where
Self::LoadedScalar: 'a,
{
values
.into_iter()
.for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone()))
}
}

pub trait Loader<C: CurveAffine>:
Expand Down
148 changes: 75 additions & 73 deletions src/loader/evm/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -596,17 +596,6 @@ where
fn loader(&self) -> &Rc<EvmLoader> {
&self.loader
}

fn multi_scalar_multiplication(pairs: impl IntoIterator<Item = (Scalar, EcPoint)>) -> Self {
pairs
.into_iter()
.map(|(scalar, ec_point)| match scalar.value {
Value::Constant(constant) if constant == U256::one() => ec_point,
_ => ec_point.loader.ec_point_scalar_mul(&ec_point, &scalar),
})
.reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point))
.unwrap()
}
}

#[derive(Clone)]
Expand Down Expand Up @@ -759,73 +748,12 @@ impl<F: PrimeField<Repr = [u8; 0x20]>> LoadedScalar<F> for Scalar {
fn loader(&self) -> &Rc<EvmLoader> {
&self.loader
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Self>) {
let values = values.into_iter().collect_vec();
let loader = &values.first().unwrap().loader;
let products = iter::once(values[0].clone())
.chain(
iter::repeat_with(|| loader.allocate(0x20))
.map(|ptr| loader.scalar(Value::Memory(ptr)))
.take(values.len() - 1),
)
.collect_vec();

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(products.first().unwrap());
for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() {
loader.push(value);
loader.code.borrow_mut().mulmod();
if idx < values.len() - 2 {
loader.code.borrow_mut().dup(0);
}
loader.code.borrow_mut().push(product.ptr()).mstore();
}

let inv = loader.invert(products.last().unwrap());

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(&inv);
for (value, product) in values.iter().rev().zip(
products
.iter()
.rev()
.skip(1)
.map(Some)
.chain(iter::once(None)),
) {
if let Some(product) = product {
loader.push(value);
loader
.code
.borrow_mut()
.dup(2)
.dup(2)
.push(product.ptr())
.mload()
.mulmod()
.push(value.ptr())
.mstore()
.mulmod();
} else {
loader.code.borrow_mut().push(value.ptr()).mstore();
}
}
}
}

impl<C> EcPointLoader<C> for Rc<EvmLoader>
where
C: CurveAffine,
C::Scalar: PrimeField<Repr = [u8; 0x20]>,
C::ScalarExt: PrimeField<Repr = [u8; 0x20]>,
{
type LoadedEcPoint = EcPoint;

Expand All @@ -839,6 +767,19 @@ where
fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> {
unimplemented!()
}

fn multi_scalar_multiplication(
pairs: &[(<Self as ScalarLoader<C::Scalar>>::LoadedScalar, EcPoint)],
) -> EcPoint {
pairs
.iter()
.map(|(scalar, ec_point)| match scalar.value {
Value::Constant(constant) if U256::one() == constant => ec_point.clone(),
_ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar),
})
.reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point))
.unwrap()
}
}

impl<F: PrimeField<Repr = [u8; 0x20]>> ScalarLoader<F> for Rc<EvmLoader> {
Expand Down Expand Up @@ -977,6 +918,67 @@ impl<F: PrimeField<Repr = [u8; 0x20]>> ScalarLoader<F> for Rc<EvmLoader> {

self.scalar(Value::Memory(ptr))
}

fn batch_invert<'a>(values: impl IntoIterator<Item = &'a mut Scalar>) {
let values = values.into_iter().collect_vec();
let loader = &values.first().unwrap().loader;
let products = iter::once(values[0].clone())
.chain(
iter::repeat_with(|| loader.allocate(0x20))
.map(|ptr| loader.scalar(Value::Memory(ptr)))
.take(values.len() - 1),
)
.collect_vec();

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(products.first().unwrap());
for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() {
loader.push(value);
loader.code.borrow_mut().mulmod();
if idx < values.len() - 2 {
loader.code.borrow_mut().dup(0);
}
loader.code.borrow_mut().push(product.ptr()).mstore();
}

let inv = loader.invert(products.last().unwrap());

loader.code.borrow_mut().push(loader.scalar_modulus);
for _ in 2..values.len() {
loader.code.borrow_mut().dup(0);
}

loader.push(&inv);
for (value, product) in values.iter().rev().zip(
products
.iter()
.rev()
.skip(1)
.map(Some)
.chain(iter::once(None)),
) {
if let Some(product) = product {
loader.push(value);
loader
.code
.borrow_mut()
.dup(2)
.dup(2)
.push(product.ptr())
.mload()
.mulmod()
.push(value.ptr())
.mstore()
.mulmod();
} else {
loader.code.borrow_mut().push(value.ptr()).mstore();
}
}
}
}

impl<C> Loader<C> for Rc<EvmLoader>
Expand Down
42 changes: 42 additions & 0 deletions src/loader/halo2.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use crate::{util::arithmetic::CurveAffine, Protocol};
use halo2_proofs::circuit;
use std::rc::Rc;

pub(crate) mod loader;
mod shim;

Expand Down Expand Up @@ -27,3 +31,41 @@ mod util {

impl<V, I: Iterator<Item = Value<V>>> Valuetools<V> for I {}
}

impl<C> Protocol<C>
where
C: CurveAffine,
{
pub fn loaded_preprocessed_as_witness<'a, EccChip: EccInstructions<'a, C>>(
&self,
loader: &Rc<Halo2Loader<'a, C, EccChip>>,
) -> Protocol<C, Rc<Halo2Loader<'a, C, EccChip>>> {
let preprocessed = self
.preprocessed
.iter()
.map(|preprocessed| loader.assign_ec_point(circuit::Value::known(*preprocessed)))
.collect();
let transcript_initial_state =
self.transcript_initial_state
.as_ref()
.map(|transcript_initial_state| {
loader.assign_scalar(circuit::Value::known(
loader.scalar_chip().integer(*transcript_initial_state),
))
});
Protocol {
domain: self.domain.clone(),
preprocessed,
num_instance: self.num_instance.clone(),
num_witness: self.num_witness.clone(),
num_challenge: self.num_challenge.clone(),
evaluations: self.evaluations.clone(),
queries: self.queries.clone(),
quotient: self.quotient.clone(),
transcript_initial_state,
instance_committing_key: self.instance_committing_key.clone(),
linearization: self.linearization.clone(),
accumulator_indices: self.accumulator_indices.clone(),
}
}
}
Loading