From 10ef300fbaebaad80d9e71b74cff4ad1d6c6e476 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 1 Sep 2023 11:28:40 -0600 Subject: [PATCH] feat: load `k` as witness and compute `n = 2^k` and `omega` from `k` (#30) * feat: load `k` as witness and compute `n = 2^k` and `omega` from `k` Removes need to make `omega` a public output in universal verifier. * fix: bit_length --- .../{n_as_witness.rs => k_as_witness.rs} | 0 snark-verifier-sdk/src/halo2/aggregation.rs | 61 ++++++++++--------- snark-verifier/src/verifier/plonk/protocol.rs | 28 ++++++--- 3 files changed, 52 insertions(+), 37 deletions(-) rename snark-verifier-sdk/examples/{n_as_witness.rs => k_as_witness.rs} (100%) diff --git a/snark-verifier-sdk/examples/n_as_witness.rs b/snark-verifier-sdk/examples/k_as_witness.rs similarity index 100% rename from snark-verifier-sdk/examples/n_as_witness.rs rename to snark-verifier-sdk/examples/k_as_witness.rs diff --git a/snark-verifier-sdk/src/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs index f24dc96e..690a64fc 100644 --- a/snark-verifier-sdk/src/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -42,14 +42,20 @@ pub type Svk = KzgSuccinctVerifyingKey; pub type BaseFieldEccChip<'chip> = halo2_ecc::ecc::BaseFieldEccChip<'chip, G1Affine>; pub type Halo2Loader<'chip> = loader::halo2::Halo2Loader>; +#[derive(Clone, Debug)] +pub struct PreprocessedAndDomainAsWitness { + // this is basically the vkey + pub preprocessed: Vec>, + pub k: AssignedValue, +} + +#[derive(Clone, Debug)] pub struct SnarkAggregationWitness<'a> { pub previous_instances: Vec>>, pub accumulator: KzgAccumulator>>, /// This returns the assigned `preprocessed` and `transcript_initial_state` values as a vector of assigned values, one for each aggregated snark. /// These can then be exposed as public instances. - /// - /// This is only useful if preprocessed digest is loaded as witness (i.e., `preprocessed_as_witness` is true in `aggregate`), so we set it to `None` otherwise. - pub preprocessed_digests: Option>>>, + pub preprocessed: Vec, } /// Different possible stages of universality the aggregation circuit can support @@ -60,7 +66,7 @@ pub enum VerifierUniversality { None, /// Preprocessed digest (commitments to fixed columns) is loaded as witness PreprocessedAsWitness, - /// Preprocessed as witness and number of rows in the circuit `n` loaded as witness + /// Preprocessed as witness and log_2(number of rows in the circuit) = k loaded as witness Full, } @@ -69,7 +75,7 @@ impl VerifierUniversality { self != &VerifierUniversality::None } - pub fn n_as_witness(&self) -> bool { + pub fn k_as_witness(&self) -> bool { self == &VerifierUniversality::Full } } @@ -117,7 +123,7 @@ where }; let mut previous_instances = Vec::with_capacity(snarks.len()); - let mut preprocessed_digests = Vec::with_capacity(snarks.len()); + let mut preprocessed_witnesses = Vec::with_capacity(snarks.len()); // to avoid re-loading the spec each time, we create one transcript and clear the stream let mut transcript = PoseidonTranscript::>, &[u8]>::from_spec( loader, @@ -131,11 +137,11 @@ where .flat_map(|snark: &Snark| { let protocol = if preprocessed_as_witness { // always load `domain.n` as witness if vkey is witness - snark.protocol.loaded_preprocessed_as_witness(loader, universality.n_as_witness()) + snark.protocol.loaded_preprocessed_as_witness(loader, universality.k_as_witness()) } else { snark.protocol.loaded(loader) }; - let inputs = protocol + let preprocessed = protocol .preprocessed .iter() .flat_map(|preprocessed| { @@ -148,19 +154,18 @@ where .chain( protocol.transcript_initial_state.clone().map(|scalar| scalar.into_assigned()), ) - .chain( - protocol - .domain_as_witness - .as_ref() - .map(|domain| domain.n.clone().into_assigned()), - ) // If `n` is witness, add it as part of input - .chain( - protocol - .domain_as_witness - .as_ref() - .map(|domain| domain.gen.clone().into_assigned()), - ) // If `n` is witness, add the generator of the order `n` subgroup as part of input .collect_vec(); + // Store `k` as witness. If `k` was fixed, assign it as a constant. + let k = protocol + .domain_as_witness + .as_ref() + .map(|domain| domain.k.clone().into_assigned()) + .unwrap_or_else(|| { + loader.ctx_mut().main().load_constant(Fr::from(protocol.domain.k as u64)) + }); + let preprocessed_and_k = PreprocessedAndDomainAsWitness { preprocessed, k }; + preprocessed_witnesses.push(preprocessed_and_k); + let instances = assign_instances(&snark.instances); // read the transcript and perform Fiat-Shamir @@ -179,7 +184,6 @@ where previous_instances.push( instances.into_iter().flatten().map(|scalar| scalar.into_assigned()).collect(), ); - preprocessed_digests.push(inputs); accumulator }) @@ -198,9 +202,12 @@ where } else { accumulators.pop().unwrap() }; - let preprocessed_digests = preprocessed_as_witness.then_some(preprocessed_digests); - SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digests } + SnarkAggregationWitness { + previous_instances, + accumulator, + preprocessed: preprocessed_witnesses, + } } /// Same as `FlexGateConfigParams` except we assume a single Phase and default 'Vertical' strategy. @@ -278,10 +285,8 @@ pub struct AggregationCircuit { previous_instances: Vec>>, /// This returns the assigned `preprocessed_digest` (vkey), optional `transcript_initial_state`, `domain.n` (optional), and `omega` (optional) values as a vector of assigned values, one for each aggregated snark. /// These can then be exposed as public instances. - /// - /// This is only useful if preprocessed digest is loaded as witness (i.e., `universality != None`), so we set it to `None` if `universality == None`. #[getset(get = "pub")] - preprocessed_digests: Option>>>, + preprocessed: Vec, // accumulation scheme proof, no longer used // pub as_proof: Vec, } @@ -380,7 +385,7 @@ impl AggregationCircuit { let loader = Halo2Loader::new(ecc_chip, pool); // run witness and copy constraint generation - let SnarkAggregationWitness { previous_instances, accumulator, preprocessed_digests } = + let SnarkAggregationWitness { previous_instances, accumulator, preprocessed } = aggregate::(&svk, &loader, &snarks, as_proof.as_slice(), universality); let lhs = accumulator.lhs.assigned(); let rhs = accumulator.rhs.assigned(); @@ -412,7 +417,7 @@ impl AggregationCircuit { ); // expose accumulator as public instances builder.assigned_instances[0] = accumulator; - Self { builder, previous_instances, preprocessed_digests } + Self { builder, previous_instances, preprocessed } } /// Re-expose the previous public instances of aggregated snarks again. diff --git a/snark-verifier/src/verifier/plonk/protocol.rs b/snark-verifier/src/verifier/plonk/protocol.rs index 098e7de9..9260e572 100644 --- a/snark-verifier/src/verifier/plonk/protocol.rs +++ b/snark-verifier/src/verifier/plonk/protocol.rs @@ -23,7 +23,9 @@ where C: CurveAffine, L: Loader, { - /// Number of rows in the domain + /// 2k is the number of rows in the domain + pub k: L::LoadedScalar, + /// n = 2k is the number of rows in the domain pub n: L::LoadedScalar, /// Generator of the domain pub gen: L::LoadedScalar, @@ -65,7 +67,6 @@ where serialize = "L::LoadedScalar: Serialize", deserialize = "L::LoadedScalar: Deserialize<'de>" ))] - #[serde(skip_serializing_if = "Option::is_none")] /// Optional: load `domain.n` and `domain.gen` as a witness pub domain_as_witness: Option>, @@ -176,14 +177,15 @@ mod halo2 { use crate::{ loader::{ halo2::{EccInstructions, Halo2Loader}, - LoadedScalar, + LoadedScalar, ScalarLoader, }, util::arithmetic::CurveAffine, verifier::plonk::PlonkProtocol, }; + use halo2_base::utils::bit_length; use std::rc::Rc; - use super::DomainAsWitness; + use super::{DomainAsWitness, PrimeField}; impl PlonkProtocol where @@ -195,13 +197,21 @@ mod halo2 { pub fn loaded_preprocessed_as_witness>( &self, loader: &Rc>, - load_n_as_witness: bool, + load_k_as_witness: bool, ) -> PlonkProtocol>> { - let domain_as_witness = load_n_as_witness.then(|| { - let n = loader.assign_scalar(C::Scalar::from(self.domain.n as u64)); - let gen = loader.assign_scalar(self.domain.gen); + let domain_as_witness = load_k_as_witness.then(|| { + let k = loader.assign_scalar(C::Scalar::from(self.domain.k as u64)); + // n = 2^k + let two = loader.load_const(&C::Scalar::from(2)); + let n = two.pow_var(&k, bit_length(C::Scalar::S as u64) + 1); + // gen = omega = ROOT_OF_UNITY ^ {2^{S - k}}, where ROOT_OF_UNITY is primitive 2^S root of unity + // this makes omega a 2^k root of unity + let root_of_unity = loader.load_const(&C::Scalar::ROOT_OF_UNITY); + let s = loader.load_const(&C::Scalar::from(C::Scalar::S as u64)); + let exp = two.pow_var(&(s - &k), bit_length(C::Scalar::S as u64)); // if S - k < 0, constraint on max bits will fail + let gen = root_of_unity.pow_var(&exp, C::Scalar::S as usize); // 2^{S - k} < 2^S for k > 0 let gen_inv = gen.invert().expect("subgroup generation is invertible"); - DomainAsWitness { n, gen, gen_inv } + DomainAsWitness { k, n, gen, gen_inv } }); let preprocessed = self