diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 83bfc0bb..a50958d0 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -24,7 +24,7 @@ jobs:
           cache-on-failure: true
 
       - name: Run test
-        run: cargo test --all --features test -- --nocapture
+        run: cargo test --all --all-features -- --nocapture
 
   lint:
diff --git a/.gitignore b/.gitignore
index ebb68914..0175c775 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,6 @@
 .DS_Store
 /target
 
-fixture
+testdata
 
 Cargo.lock
diff --git a/Cargo.toml b/Cargo.toml
index e6ac1d49..de037c4c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,57 +1,54 @@
 [package]
-name = "plonk-verifier"
+name = "plonk_verifier"
 version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-ff = "0.12.0"
-group = "0.12.0"
 itertools = "0.10.3"
 lazy_static = "1.4.0"
-num-bigint = "0.4"
-num-traits = "0.2"
+num-bigint = "0.4.3"
+num-integer = "0.1.45"
+num-traits = "0.2.15"
 rand = "0.8"
 rand_chacha = "0.3.1"
-halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.2.1", package = "halo2curves" }
-
-# halo2
-blake2b_simd = { version = "1.0.0", optional = true }
-halo2_proofs = { version = "0.2.0", optional = true }
-halo2_wrong = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "halo2wrong", optional = true }
-halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "ecc", optional = true }
-halo2_wrong_maingate = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "maingate", optional = true }
-halo2_wrong_transcript = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "transcript", optional = true }
-poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", branch = "padding", optional = true }
-
-# evm
+halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" }
+
+# system_halo2
+halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22", optional = true }
+
+# loader_evm
 ethereum_types = { package = "ethereum-types", version = "0.13.1", default-features = false, features = ["std"], optional = true }
-foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundry-evm", rev = "93ee742d", optional = true }
-crossterm = { version = "0.22.1", optional = true }
-tui = { version = "0.16.0", default-features = false, features = ["crossterm"], optional = true }
 sha3 = { version = "0.10.1", optional = true }
 
+# loader_halo2
+halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true }
+poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true }
+
 [dev-dependencies]
 paste = "1.0.7"
 
+# system_halo2
+halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" }
+
+# loader_evm
+foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundry-evm", rev = "6b1ee60e" }
+crossterm = { version = "0.22.1" }
+tui = { version = "0.16.0", default-features = false, features = ["crossterm"] }
+
 [features]
-default = ["halo2", "evm"]
-test = ["halo2", "evm"]
+default = ["loader_evm", "loader_halo2", "system_halo2"]
 
-halo2 = ["dep:blake2b_simd", "dep:halo2_proofs", "dep:halo2_wrong", "dep:halo2_wrong_ecc", "dep:halo2_wrong_maingate", "dep:halo2_wrong_transcript", "dep:poseidon"]
"dep:halo2_wrong_transcript", "dep:poseidon"] -evm = ["dep:foundry_evm", "dep:crossterm", "dep:tui", "dep:ethereum_types", "dep:sha3"] -sanity-check = [] +loader_evm = ["dep:ethereum_types", "dep:sha3"] +loader_halo2 = ["dep:halo2_proofs", "dep:halo2_wrong_ecc", "dep:poseidon"] -[patch.crates-io] -halo2_proofs = { git = "https://github.com/han0110/halo2", branch = "experiment", package = "halo2_proofs" } +system_halo2 = ["dep:halo2_proofs"] -[patch."https://github.com/privacy-scaling-explorations/halo2"] -halo2_proofs = { git = "https://github.com/han0110/halo2", branch = "experiment", package = "halo2_proofs" } +sanity_check = [] -[patch."https://github.com/privacy-scaling-explorations/halo2curves"] -halo2_curves = { git = "https://github.com//privacy-scaling-explorations/halo2curves", tag = "0.2.1", package = "halo2curves" } +[[example]] +name = "evm-verifier" +required-features = ["loader_evm", "system_halo2"] -[patch."https://github.com/privacy-scaling-explorations/halo2wrong"] -halo2_wrong = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "halo2wrong" } -halo2_wrong_ecc = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "ecc" } -halo2_wrong_maingate = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "maingate" } -halo2_wrong_transcript = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "transcript" } +[[example]] +name = "evm-verifier-with-accumulator" +required-features = ["loader_halo2", "loader_evm", "system_halo2"] diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs new file mode 100644 index 00000000..69def21e --- /dev/null +++ b/examples/evm-verifier-with-accumulator.rs @@ -0,0 +1,619 @@ +use ethereum_types::Address; +use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; +use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_proofs::{ + dev::MockProver, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::{ + commitment::{Params, ParamsProver}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + VerificationStrategy, + }, + transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::{ + evm::{encode_calldata, EvmLoader}, + native::NativeLoader, + }, + pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::{self, PlonkVerifier}, +}; +use rand::rngs::OsRng; +use std::{io::Cursor, rc::Rc}; + +const LIMBS: usize = 4; +const BITS: usize = 68; + +type Pcs = Kzg; +type As = KzgAs; +type Plonk = verifier::Plonk>; + +mod application { + use halo2_curves::bn256::Fr; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + use rand::RngCore; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| 
meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { + a, + b, + c, + q_a, + q_b, + q_c, + q_ab, + constant, + instance, + } + } + } + + #[derive(Clone, Default)] + pub struct StandardPlonk(Fr); + + impl StandardPlonk { + pub fn rand(mut rng: R) -> Self { + Self(Fr::from(rng.next_u32() as u64)) + } + + pub fn num_instance() -> Vec { + vec![1] + } + + pub fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5)))?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(|| "", column, 1, || Value::known(Fr::from(idx)))?; + } + + let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + + Ok(()) + }, + ) + } + } +} + +mod aggregation { + use super::{As, Plonk, BITS, LIMBS}; + use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{self, Circuit, ConstraintSystem}, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, + }; + use halo2_wrong_ecc::{ + integer::rns::Rns, + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, + RangeInstructions, RegionCtx, + }, + EccConfig, + }; + use itertools::Itertools; + use plonk_verifier::{ + loader::{self, native::NativeLoader}, + pcs::{ + kzg::{KzgAccumulator, KzgSuccinctVerifyingKey}, + AccumulationScheme, AccumulationSchemeProver, + }, + system, + util::arithmetic::{fe_to_limbs, FieldExt}, + verifier::PlonkVerifier, + Protocol, + }; + use rand::rngs::OsRng; + use std::{iter, rc::Rc}; + + const T: usize = 5; + const RATE: usize = 4; + const R_F: usize = 8; + const R_P: usize = 60; + + type Svk = KzgSuccinctVerifyingKey; + type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + pub type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + + pub struct Snark { + protocol: Protocol, + instances: Vec>, + proof: Vec, + } + + impl Snark { + pub fn new(protocol: Protocol, 
instances: Vec>, proof: Vec) -> Self { + Self { + protocol, + instances, + proof, + } + } + } + + impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } + } + + #[derive(Clone)] + pub struct SnarkWitness { + protocol: Protocol, + instances: Vec>>, + proof: Value>, + } + + impl SnarkWitness { + fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } + } + + pub fn aggregate<'a>( + svk: &Svk, + loader: &Rc>, + snarks: &[SnarkWitness], + as_proof: Value<&'_ [u8]>, + ) -> KzgAccumulator>> { + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec() + }; + + let accumulators = snarks + .iter() + .flat_map(|snark| { + let instances = assign_instances(&snark.instances); + let mut transcript = + PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = + Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + }) + .collect_vec(); + + let acccumulator = { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = + As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() + }; + + acccumulator + } + + #[derive(Clone)] + pub struct AggregationConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, + } + + impl AggregationConfig { + pub fn configure( + meta: &mut ConstraintSystem, + composition_bits: Vec, + overflow_bits: Vec, + ) -> Self { + let main_gate_config = MainGate::::configure(meta); + let range_config = + RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); + AggregationConfig { + main_gate_config, + range_config, + } + } + + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip(&self) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } + } + + #[derive(Clone)] + pub struct AggregationCircuit { + svk: Svk, + snarks: Vec, + instances: Vec, + as_proof: Value>, + } + + impl AggregationCircuit { + pub fn new(params: &ParamsKZG, snarks: impl IntoIterator) -> Self { + let svk = params.get_g()[0].into(); + let snarks = snarks.into_iter().collect_vec(); + + let accumulators = snarks + .iter() + .flat_map(|snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, 
OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = [lhs.x, lhs.y, rhs.x, rhs.y] + .map(fe_to_limbs::<_, _, LIMBS, BITS>) + .concat(); + + Self { + svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_proof: Value::known(as_proof), + } + } + + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + + pub fn num_instance() -> Vec { + vec![4 * LIMBS] + } + + pub fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + pub fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + } + + impl Circuit for AggregationCircuit { + type Config = AggregationConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self + .snarks + .iter() + .map(SnarkWitness::without_witnesses) + .collect(), + instances: Vec::new(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + AggregationConfig::configure( + meta, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let (lhs, rhs) = layouter.assign_region( + || "", + |region| { + let ctx = RegionCtx::new(region, 0); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let KzgAccumulator { lhs, rhs } = + aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); + + Ok((lhs.assigned(), rhs.assigned())) + }, + )?; + + for (limb, row) in iter::empty() + .chain(lhs.x().limbs()) + .chain(lhs.y().limbs()) + .chain(rhs.x().limbs()) + .chain(rhs.y().limbs()) + .zip(0..) 
+ { + main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + } + + Ok(()) + } + } +} + +fn gen_srs(k: u32) -> ParamsKZG { + ParamsKZG::::setup(k, OsRng) +} + +fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() +} + +fn gen_proof< + C: Circuit, + E: EncodedChallenge, + TR: TranscriptReadBuffer>, G1Affine, E>, + TW: TranscriptWriterBuffer, G1Affine, E>, +>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, +) -> Vec { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + + let instances = instances + .iter() + .map(|instances| instances.as_slice()) + .collect_vec(); + let proof = { + let mut transcript = TW::init(Vec::new()); + create_proof::, ProverGWC<_>, _, _, TW, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TR::init(Cursor::new(proof.clone())); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, TR, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} + +fn gen_application_snark(params: &ParamsKZG) -> aggregation::Snark { + let circuit = application::StandardPlonk::rand(OsRng); + + let pk = gen_pk(params, &circuit); + let protocol = compile( + params, + pk.get_vk(), + Config::kzg().with_num_instance(application::StandardPlonk::num_instance()), + ); + + let proof = gen_proof::< + _, + _, + aggregation::PoseidonTranscript, + aggregation::PoseidonTranscript, + >(params, &pk, circuit.clone(), circuit.instances()); + aggregation::Snark::new(protocol, circuit.instances(), proof) +} + +fn gen_aggregation_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: Vec, + accumulator_indices: Vec<(usize, usize)>, +) -> Vec { + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg() + .with_num_instance(num_instance.clone()) + .with_accumulator_indices(accumulator_indices), + ); + + let loader = EvmLoader::new::(); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + + let instances = transcript.load_instances(num_instance); + let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + loader.deployment_code() +} + +fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default() + .with_gas_limit(u64::MAX.into()) + .build(Backend::new(MultiFork::new().0, None)); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm + .deploy(caller, deployment_code.into(), 0.into(), None) + .unwrap() + .address; + let result = evm + .call_raw(caller, verifier, calldata.into(), 0.into()) + .unwrap(); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +fn main() { + let params = gen_srs(22); + let params_app = { + let mut params = params.clone(); + params.downsize(8); + params + }; + + let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); + let agg_circuit = aggregation::AggregationCircuit::new(¶ms, snarks); + let pk = gen_pk(¶ms, 
&agg_circuit); + let deployment_code = gen_aggregation_evm_verifier( + ¶ms, + pk.get_vk(), + aggregation::AggregationCircuit::num_instance(), + aggregation::AggregationCircuit::accumulator_indices(), + ); + + let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( + ¶ms, + &pk, + agg_circuit.clone(), + agg_circuit.instances(), + ); + evm_verify(deployment_code, agg_circuit.instances(), proof); +} diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs new file mode 100644 index 00000000..b51a9a30 --- /dev/null +++ b/examples/evm-verifier.rs @@ -0,0 +1,260 @@ +use ethereum_types::Address; +use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; +use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Advice, Circuit, Column, + ConstraintSystem, Error, Fixed, Instance, ProvingKey, VerifyingKey, + }, + poly::{ + commitment::{Params, ParamsProver}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + Rotation, VerificationStrategy, + }, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::evm::{encode_calldata, EvmLoader}, + pcs::kzg::{Gwc19, Kzg}, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::{self, PlonkVerifier}, +}; +use rand::{rngs::OsRng, RngCore}; +use std::rc::Rc; + +type Plonk = verifier::Plonk>; + +#[derive(Clone, Copy)] +struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, +} + +impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { + a, + b, + c, + q_a, + q_b, + q_c, + q_ab, + constant, + instance, + } + } +} + +#[derive(Clone, Default)] +struct StandardPlonk(Fr); + +impl StandardPlonk { + fn rand(mut rng: R) -> Self { + Self(Fr::from(rng.next_u32() as u64)) + } + + fn num_instance() -> Vec { + vec![1] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0]] + } +} + +impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + 
region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5)))?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(|| "", column, 1, || Value::known(Fr::from(idx)))?; + } + + let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + + Ok(()) + }, + ) + } +} + +fn gen_srs(k: u32) -> ParamsKZG { + ParamsKZG::::setup(k, OsRng) +} + +fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() +} + +fn gen_proof>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, +) -> Vec { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + + let instances = instances + .iter() + .map(|instances| instances.as_slice()) + .collect_vec(); + let proof = { + let mut transcript = TranscriptWriterBuffer::<_, G1Affine, _>::init(Vec::new()); + create_proof::, ProverGWC<_>, _, _, EvmTranscript<_, _, _, _>, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TranscriptReadBuffer::<_, G1Affine, _>::init(proof.as_slice()); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, EvmTranscript<_, _, _, _>, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} + +fn gen_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: Vec, +) -> Vec { + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg().with_num_instance(num_instance.clone()), + ); + + let loader = EvmLoader::new::(); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + + let instances = transcript.load_instances(num_instance); + let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + loader.deployment_code() +} + +fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default() + .with_gas_limit(u64::MAX.into()) + .build(Backend::new(MultiFork::new().0, None)); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm + .deploy(caller, deployment_code.into(), 0.into(), None) + .unwrap() + .address; + let result = evm + .call_raw(caller, verifier, calldata.into(), 0.into()) + .unwrap(); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +fn main() { + let params = gen_srs(8); + + let circuit = StandardPlonk::rand(OsRng); + let pk = gen_pk(¶ms, &circuit); + let deployment_code = gen_evm_verifier(¶ms, pk.get_vk(), StandardPlonk::num_instance()); + + let proof = gen_proof(¶ms, &pk, circuit.clone(), circuit.instances()); + evm_verify(deployment_code, circuit.instances(), proof); +} diff --git a/rust-toolchain b/rust-toolchain index db84486f..7cc6ef41 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2022-06-01 \ 
No newline at end of file +1.63.0 \ No newline at end of file diff --git a/src/cost.rs b/src/cost.rs new file mode 100644 index 00000000..b085aed8 --- /dev/null +++ b/src/cost.rs @@ -0,0 +1,44 @@ +use std::ops::Add; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Cost { + pub num_instance: usize, + pub num_commitment: usize, + pub num_evaluation: usize, + pub num_msm: usize, +} + +impl Cost { + pub fn new( + num_instance: usize, + num_commitment: usize, + num_evaluation: usize, + num_msm: usize, + ) -> Self { + Self { + num_instance, + num_commitment, + num_evaluation, + num_msm, + } + } +} + +impl Add for Cost { + type Output = Cost; + + fn add(self, rhs: Cost) -> Self::Output { + Cost::new( + self.num_instance + rhs.num_instance, + self.num_commitment + rhs.num_commitment, + self.num_evaluation + rhs.num_evaluation, + self.num_msm + rhs.num_msm, + ) + } +} + +pub trait CostEstimation { + type Input; + + fn estimate_cost(input: &Self::Input) -> Cost; +} diff --git a/src/lib.rs b/src/lib.rs index b193cc5c..4e8da5fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,19 +1,38 @@ -#![feature(int_log)] -#![feature(int_roundings)] -#![feature(assert_matches)] #![allow(clippy::type_complexity)] #![allow(clippy::too_many_arguments)] #![allow(clippy::upper_case_acronyms)] +pub mod cost; pub mod loader; -pub mod protocol; -pub mod scheme; +pub mod pcs; +pub mod system; pub mod util; +pub mod verifier; #[derive(Clone, Debug)] pub enum Error { InvalidInstances, - MissingQuery(util::Query), - MissingChallenge(usize), + InvalidLinearization, + InvalidQuery(util::protocol::Query), + InvalidChallenge(usize), + AssertionFailure(String), Transcript(std::io::ErrorKind, String), } + +#[derive(Clone, Debug)] +pub struct Protocol { + // Common description + pub domain: util::arithmetic::Domain, + pub preprocessed: Vec, + pub num_instance: Vec, + pub num_witness: Vec, + pub num_challenge: Vec, + pub evaluations: Vec, + pub queries: Vec, + pub quotient: util::protocol::QuotientPolynomial, + // Minor customization + pub transcript_initial_state: Option, + pub instance_committing_key: Option>, + pub linearization: Option, + pub accumulator_indices: Vec>, +} diff --git a/src/loader.rs b/src/loader.rs index ebca0ec9..8c39bae0 100644 --- a/src/loader.rs +++ b/src/loader.rs @@ -1,15 +1,21 @@ -use crate::util::{Curve, FieldOps, GroupOps, Itertools, PrimeField}; +use crate::{ + util::{ + arithmetic::{CurveAffine, FieldOps, PrimeField}, + Itertools, + }, + Error, +}; use std::{fmt::Debug, iter}; pub mod native; -#[cfg(feature = "evm")] +#[cfg(feature = "loader_evm")] pub mod evm; -#[cfg(feature = "halo2")] +#[cfg(feature = "loader_halo2")] pub mod halo2; -pub trait LoadedEcPoint: Clone + Debug + GroupOps + PartialEq { +pub trait LoadedEcPoint: Clone + Debug + PartialEq { type Loader: Loader; fn loader(&self) -> &Self::Loader; @@ -24,67 +30,11 @@ pub trait LoadedEcPoint: Clone + Debug + GroupOps + PartialEq { ) -> Self; } -pub trait LoadedScalar: Clone + Debug + FieldOps { +pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { type Loader: ScalarLoader; fn loader(&self) -> &Self::Loader; - fn sum_with_coeff_and_constant(values: &[(F, Self)], constant: &F) -> Self { - assert!(!values.is_empty()); - - let loader = values.first().unwrap().1.loader(); - iter::empty() - .chain(if *constant == F::zero() { - None - } else { - Some(loader.load_const(constant)) - }) - .chain( - values - .iter() - .map(|(coeff, value)| loader.load_const(coeff) * value), - ) - .reduce(|acc, term| acc + term) - .unwrap() - } - - fn 
sum_products_with_coeff_and_constant(values: &[(F, Self, Self)], constant: &F) -> Self { - assert!(!values.is_empty()); - - let loader = values.first().unwrap().1.loader(); - iter::empty() - .chain(if *constant == F::zero() { - None - } else { - Some(loader.load_const(constant)) - }) - .chain( - values - .iter() - .map(|(coeff, lhs, rhs)| loader.load_const(coeff) * lhs * rhs), - ) - .reduce(|acc, term| acc + term) - .unwrap() - } - - fn sum_with_coeff(values: &[(F, Self)]) -> Self { - Self::sum_with_coeff_and_constant(values, &F::zero()) - } - - fn sum_with_const(values: &[Self], constant: &F) -> Self { - Self::sum_with_coeff_and_constant( - &values - .iter() - .map(|value| (F::one(), value.clone())) - .collect_vec(), - constant, - ) - } - - fn sum(values: &[Self]) -> Self { - Self::sum_with_const(values, &F::zero()) - } - fn square(&self) -> Self { self.clone() * self } @@ -133,7 +83,7 @@ pub trait LoadedScalar: Clone + Debug + FieldOps { } } -pub trait EcPointLoader { +pub trait EcPointLoader { type LoadedEcPoint: LoadedEcPoint; fn ec_point_load_const(&self, value: &C) -> Self::LoadedEcPoint; @@ -145,6 +95,13 @@ pub trait EcPointLoader { fn ec_point_load_one(&self) -> Self::LoadedEcPoint { self.ec_point_load_const(&C::generator()) } + + fn ec_point_assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedEcPoint, + rhs: &Self::LoadedEcPoint, + ) -> Result<(), Error>; } pub trait ScalarLoader { @@ -159,9 +116,121 @@ pub trait ScalarLoader { fn load_one(&self) -> Self::LoadedScalar { self.load_const(&F::one()) } + + fn assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedScalar, + rhs: &Self::LoadedScalar, + ) -> Result<(), Error>; + + fn sum_with_coeff_and_const( + &self, + values: &[(F, &Self::LoadedScalar)], + constant: F, + ) -> Self::LoadedScalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let loader = values.first().unwrap().1.loader(); + iter::empty() + .chain(if constant == F::zero() { + None + } else { + Some(loader.load_const(&constant)) + }) + .chain(values.iter().map(|&(coeff, value)| { + if coeff == F::one() { + value.clone() + } else { + loader.load_const(&coeff) * value + } + })) + .reduce(|acc, term| acc + term) + .unwrap() + } + + fn sum_products_with_coeff_and_const( + &self, + values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], + constant: F, + ) -> Self::LoadedScalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let loader = values.first().unwrap().1.loader(); + iter::empty() + .chain(if constant == F::zero() { + None + } else { + Some(loader.load_const(&constant)) + }) + .chain(values.iter().map(|&(coeff, lhs, rhs)| { + if coeff == F::one() { + lhs.clone() * rhs + } else { + loader.load_const(&coeff) * lhs * rhs + } + })) + .reduce(|acc, term| acc + term) + .unwrap() + } + + fn sum_with_coeff(&self, values: &[(F, &Self::LoadedScalar)]) -> Self::LoadedScalar { + self.sum_with_coeff_and_const(values, F::zero()) + } + + fn sum_with_const(&self, values: &[&Self::LoadedScalar], constant: F) -> Self::LoadedScalar { + self.sum_with_coeff_and_const( + &values.iter().map(|&value| (F::one(), value)).collect_vec(), + constant, + ) + } + + fn sum(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { + self.sum_with_const(values, F::zero()) + } + + fn sum_products_with_coeff( + &self, + values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], + ) -> Self::LoadedScalar { + self.sum_products_with_coeff_and_const(values, F::zero()) + } + + fn sum_products_with_const( + &self, + values: 
&[(&Self::LoadedScalar, &Self::LoadedScalar)], + constant: F, + ) -> Self::LoadedScalar { + self.sum_products_with_coeff_and_const( + &values + .iter() + .map(|&(lhs, rhs)| (F::one(), lhs, rhs)) + .collect_vec(), + constant, + ) + } + + fn sum_products( + &self, + values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], + ) -> Self::LoadedScalar { + self.sum_products_with_const(values, F::zero()) + } + + fn product(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { + values + .iter() + .fold(self.load_one(), |acc, value| acc * *value) + } } -pub trait Loader: EcPointLoader + ScalarLoader + Clone { +pub trait Loader: + EcPointLoader + ScalarLoader + Clone + Debug +{ fn start_cost_metering(&self, _: &str) {} fn end_cost_metering(&self) {} diff --git a/src/loader/evm.rs b/src/loader/evm.rs index e0754532..7a07670c 100644 --- a/src/loader/evm.rs +++ b/src/loader/evm.rs @@ -1,70 +1,14 @@ -use crate::{ - scheme::kzg::Cost, - util::{Itertools, PrimeField}, -}; -use ethereum_types::U256; -use std::iter; - -mod accumulation; mod code; -mod loader; -mod transcript; +pub(crate) mod loader; +mod util; #[cfg(test)] mod test; -pub use loader::EvmLoader; -pub use transcript::EvmTranscript; +pub use loader::{EcPoint, EvmLoader, Scalar}; +pub use util::{encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, MemoryChunk}; + +pub use ethereum_types::U256; #[cfg(test)] pub use test::execute; - -// Assert F::Repr in little-endian -pub fn field_to_u256(f: &F) -> U256 -where - F: PrimeField, -{ - U256::from_little_endian(f.to_repr().as_ref()) -} - -pub fn u256_to_field(value: U256) -> F -where - F: PrimeField, -{ - let value = value % (field_to_u256(&-F::one()) + 1u64); - let mut repr = F::Repr::default(); - value.to_little_endian(repr.as_mut()); - F::from_repr(repr).unwrap() -} - -pub fn modulus() -> U256 -where - F: PrimeField, -{ - U256::from_little_endian((-F::one()).to_repr().as_ref()) + 1 -} - -pub fn encode_calldata(instances: Vec>, proof: Vec) -> Vec -where - F: PrimeField, -{ - iter::empty() - .chain( - instances - .into_iter() - .flatten() - .flat_map(|value| value.to_repr().as_ref().iter().rev().cloned().collect_vec()), - ) - .chain(proof) - .collect() -} - -pub fn estimate_gas(cost: Cost) -> usize { - let proof_size = cost.num_commitment * 64 + (cost.num_evaluation + cost.num_statement) * 32; - - let intrinsic_cost = 21000; - let calldata_cost = (proof_size as f64 * 15.25).ceil() as usize; - let ec_operation_cost = 113100 + (cost.num_msm - 2) * 6350; - - intrinsic_cost + calldata_cost + ec_operation_cost -} diff --git a/src/loader/evm/accumulation.rs b/src/loader/evm/accumulation.rs deleted file mode 100644 index cfe4be6f..00000000 --- a/src/loader/evm/accumulation.rs +++ /dev/null @@ -1,98 +0,0 @@ -use crate::{ - loader::evm::loader::{EvmLoader, Scalar}, - protocol::Protocol, - scheme::kzg::{AccumulationStrategy, Accumulator, SameCurveAccumulation, MSM}, - util::{Curve, Itertools, PrimeCurveAffine, PrimeField, Transcript, UncompressedEncoding}, - Error, -}; -use ethereum_types::U256; -use halo2_curves::{ - bn256::{G1Affine, G2Affine, G1}, - CurveAffine, -}; -use std::{ops::Neg, rc::Rc}; - -impl SameCurveAccumulation, LIMBS, BITS> { - pub fn code(self, g1: G1Affine, g2: G2Affine, s_g2: G2Affine) -> Vec { - let (lhs, rhs) = self.accumulator.unwrap().evaluate(g1.to_curve()); - let loader = lhs.loader(); - - let [g2, minus_s_g2] = [g2, s_g2.neg()].map(|ec_point| { - let coordinates = ec_point.coordinates().unwrap(); - let x = coordinates.x().to_repr(); - let y = 
coordinates.y().to_repr(); - ( - U256::from_little_endian(&x.as_ref()[32..]), - U256::from_little_endian(&x.as_ref()[..32]), - U256::from_little_endian(&y.as_ref()[32..]), - U256::from_little_endian(&y.as_ref()[..32]), - ) - }); - loader.pairing(&lhs, g2, &rhs, minus_s_g2); - - loader.code() - } -} - -impl - AccumulationStrategy, T, P> - for SameCurveAccumulation, LIMBS, BITS> -where - C::Scalar: PrimeField, - C: UncompressedEncoding, - T: Transcript>, -{ - type Output = (); - - fn extract_accumulator( - &self, - protocol: &Protocol, - loader: &Rc, - transcript: &mut T, - statements: &[Vec], - ) -> Option>> { - let accumulator_indices = protocol.accumulator_indices.as_ref()?; - - let num_statements = statements - .iter() - .map(|statements| statements.len()) - .collect_vec(); - - let challenges = transcript.squeeze_n_challenges(accumulator_indices.len()); - let accumulators = accumulator_indices - .iter() - .map(|indices| { - assert_eq!(indices.len(), 4 * LIMBS); - assert!(indices - .iter() - .enumerate() - .all(|(idx, index)| indices[0] == (index.0, index.1 - idx))); - let offset = - (num_statements[..indices[0].0].iter().sum::() + indices[0].1) * 0x20; - let lhs = loader.calldataload_ec_point_from_limbs::(offset); - let rhs = loader.calldataload_ec_point_from_limbs::(offset + 0x100); - Accumulator::new(MSM::base(lhs), MSM::base(rhs)) - }) - .collect_vec(); - - Some(Accumulator::random_linear_combine( - challenges.into_iter().zip(accumulators), - )) - } - - fn process( - &mut self, - _: &Rc, - transcript: &mut T, - _: P, - accumulator: Accumulator>, - ) -> Result { - self.accumulator = Some(match self.accumulator.take() { - Some(curr_accumulator) => { - accumulator + curr_accumulator * &transcript.squeeze_challenge() - } - None => accumulator, - }); - Ok(()) - } -} diff --git a/src/loader/evm/code.rs b/src/loader/evm/code.rs index 38069401..80dd5c71 100644 --- a/src/loader/evm/code.rs +++ b/src/loader/evm/code.rs @@ -1,6 +1,6 @@ use crate::util::Itertools; use ethereum_types::U256; -use foundry_evm::{revm::opcode::*, HashMap}; +use std::{collections::HashMap, iter}; pub enum Precompiled { BigModExp = 0x05, @@ -36,6 +36,37 @@ impl Code { code } + pub fn deployment(code: Vec) -> Vec { + let code_len = code.len(); + assert_ne!(code_len, 0); + + iter::empty() + .chain([ + PUSH1 + 1, + (code_len >> 8) as u8, + (code_len & 0xff) as u8, + PUSH1, + 14, + PUSH1, + 0, + CODECOPY, + ]) + .chain([ + PUSH1 + 1, + (code_len >> 8) as u8, + (code_len & 0xff) as u8, + PUSH1, + 0, + RETURN, + ]) + .chain(code) + .collect() + } + + pub fn stack_len(&self) -> usize { + self.stack_len + } + pub fn len(&self) -> usize { self.code.len() } @@ -180,3 +211,85 @@ impl_opcodes!( revert -> (REVERT, -2) selfdestruct -> (SELFDESTRUCT, -1) ); + +const STOP: u8 = 0x00; +const ADD: u8 = 0x01; +const MUL: u8 = 0x02; +const SUB: u8 = 0x03; +const DIV: u8 = 0x04; +const SDIV: u8 = 0x05; +const MOD: u8 = 0x06; +const SMOD: u8 = 0x07; +const ADDMOD: u8 = 0x08; +const MULMOD: u8 = 0x09; +const EXP: u8 = 0x0A; +const SIGNEXTEND: u8 = 0x0B; +const LT: u8 = 0x10; +const GT: u8 = 0x11; +const SLT: u8 = 0x12; +const SGT: u8 = 0x13; +const EQ: u8 = 0x14; +const ISZERO: u8 = 0x15; +const AND: u8 = 0x16; +const OR: u8 = 0x17; +const XOR: u8 = 0x18; +const NOT: u8 = 0x19; +const BYTE: u8 = 0x1A; +const SHL: u8 = 0x1B; +const SHR: u8 = 0x1C; +const SAR: u8 = 0x1D; +const SHA3: u8 = 0x20; +const ADDRESS: u8 = 0x30; +const BALANCE: u8 = 0x31; +const ORIGIN: u8 = 0x32; +const CALLER: u8 = 0x33; +const CALLVALUE: u8 = 0x34; +const 
CALLDATALOAD: u8 = 0x35; +const CALLDATASIZE: u8 = 0x36; +const CALLDATACOPY: u8 = 0x37; +const CODESIZE: u8 = 0x38; +const CODECOPY: u8 = 0x39; +const GASPRICE: u8 = 0x3A; +const EXTCODESIZE: u8 = 0x3B; +const EXTCODECOPY: u8 = 0x3C; +const RETURNDATASIZE: u8 = 0x3D; +const RETURNDATACOPY: u8 = 0x3E; +const EXTCODEHASH: u8 = 0x3F; +const BLOCKHASH: u8 = 0x40; +const COINBASE: u8 = 0x41; +const TIMESTAMP: u8 = 0x42; +const NUMBER: u8 = 0x43; +const DIFFICULTY: u8 = 0x44; +const GASLIMIT: u8 = 0x45; +const CHAINID: u8 = 0x46; +const SELFBALANCE: u8 = 0x47; +const BASEFEE: u8 = 0x48; +const POP: u8 = 0x50; +const MLOAD: u8 = 0x51; +const MSTORE: u8 = 0x52; +const MSTORE8: u8 = 0x53; +const SLOAD: u8 = 0x54; +const SSTORE: u8 = 0x55; +const JUMP: u8 = 0x56; +const JUMPI: u8 = 0x57; +const PC: u8 = 0x58; +const MSIZE: u8 = 0x59; +const GAS: u8 = 0x5A; +const JUMPDEST: u8 = 0x5B; +const PUSH1: u8 = 0x60; +const DUP1: u8 = 0x80; +const SWAP1: u8 = 0x90; +const LOG0: u8 = 0xA0; +const LOG1: u8 = 0xA1; +const LOG2: u8 = 0xA2; +const LOG3: u8 = 0xA3; +const LOG4: u8 = 0xA4; +const CREATE: u8 = 0xF0; +const CALL: u8 = 0xF1; +const CALLCODE: u8 = 0xF2; +const RETURN: u8 = 0xF3; +const DELEGATECALL: u8 = 0xF4; +const CREATE2: u8 = 0xF5; +const STATICCALL: u8 = 0xFA; +const REVERT: u8 = 0xFD; +const SELFDESTRUCT: u8 = 0xFF; diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 258ffdd0..06ab7fd8 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -1,24 +1,49 @@ use crate::{ loader::evm::{ code::{Code, Precompiled}, - modulus, + fe_to_u256, modulus, }, - loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{Curve, FieldOps, Itertools, PrimeField, UncompressedEncoding}, + loader::{evm::u256_to_fe, EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, + util::{ + arithmetic::{CurveAffine, FieldOps, PrimeField}, + Itertools, + }, + Error, }; use ethereum_types::{U256, U512}; use std::{ cell::RefCell, + collections::HashMap, fmt::{self, Debug}, iter, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + ops::{Add, AddAssign, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign}, rc::Rc, }; -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub enum Value { Constant(T), Memory(usize), + Negated(Box>), + Sum(Box>, Box>), + Product(Box>, Box>), +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.identifier() == other.identifier() + } +} + +impl Value { + fn identifier(&self) -> String { + match &self { + Value::Constant(_) | Value::Memory(_) => format!("{:?}", self), + Value::Negated(value) => format!("-({:?})", value), + Value::Sum(lhs, rhs) => format!("({:?} + {:?})", lhs, rhs), + Value::Product(lhs, rhs) => format!("({:?} * {:?})", lhs, rhs), + } + } } #[derive(Clone, Debug)] @@ -27,6 +52,7 @@ pub struct EvmLoader { scalar_modulus: U256, code: RefCell, ptr: RefCell, + cache: RefCell>, #[cfg(test)] gas_metering_ids: RefCell>, } @@ -46,13 +72,18 @@ impl EvmLoader { base_modulus, scalar_modulus, code: RefCell::new(code), - ptr: RefCell::new(0), + ptr: Default::default(), + cache: Default::default(), #[cfg(test)] gas_metering_ids: RefCell::new(Vec::new()), }) } - pub fn code(self: &Rc) -> Vec { + pub fn deployment_code(self: &Rc) -> Vec { + Code::deployment(self.runtime_code()) + } + + pub fn runtime_code(self: &Rc) -> Vec { let mut code = self.code.borrow().clone(); let dst = code.len() + 9; code.push(dst) @@ -72,7 +103,41 @@ impl EvmLoader { ptr } - fn scalar(self: &Rc, value: 
Value) -> Scalar { + pub(crate) fn scalar_modulus(&self) -> U256 { + self.scalar_modulus + } + + pub(crate) fn ptr(&self) -> usize { + *self.ptr.borrow() + } + + pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { + self.code.borrow_mut() + } + + pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { + let value = if matches!( + value, + Value::Constant(_) | Value::Memory(_) | Value::Negated(_) + ) { + value + } else { + let identifier = value.identifier(); + let some_ptr = self.cache.borrow().get(&identifier).cloned(); + let ptr = if let Some(ptr) = some_ptr { + ptr + } else { + self.push(&Scalar { + loader: self.clone(), + value, + }); + let ptr = self.allocate(0x20); + self.code.borrow_mut().push(ptr).mstore(); + self.cache.borrow_mut().insert(identifier, ptr); + ptr + }; + Value::Memory(ptr) + }; Scalar { loader: self.clone(), value, @@ -87,13 +152,29 @@ impl EvmLoader { } fn push(self: &Rc, scalar: &Scalar) { - match scalar.value { + match scalar.value.clone() { Value::Constant(constant) => { self.code.borrow_mut().push(constant); } Value::Memory(ptr) => { self.code.borrow_mut().push(ptr).mload(); } + Value::Negated(value) => { + self.push(&self.scalar(*value)); + self.code.borrow_mut().push(self.scalar_modulus).sub(); + } + Value::Sum(lhs, rhs) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(*lhs)); + self.push(&self.scalar(*rhs)); + self.code.borrow_mut().addmod(); + } + Value::Product(lhs, rhs) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(*lhs)); + self.push(&self.scalar(*rhs)); + self.code.borrow_mut().mulmod(); + } } } @@ -139,46 +220,36 @@ impl EvmLoader { self.ec_point(Value::Memory(ptr)) } - pub fn calldataload_ec_point_from_limbs( + pub fn ec_point_from_limbs( self: &Rc, - offset: usize, + x_limbs: [Scalar; LIMBS], + y_limbs: [Scalar; LIMBS], ) -> EcPoint { let ptr = self.allocate(0x40); - for (ptr, offset) in [(ptr, offset), (ptr + 0x20, offset + LIMBS * 0x20)] { - for idx in 0..LIMBS { - if idx == 0 { - self.code - .borrow_mut() - // [..., success] - .push(offset) - // [..., success, x_limb_0_ptr] - .calldataload(); - // [..., success, x_limb_0] - } else { + for (ptr, limbs) in [(ptr, x_limbs), (ptr + 0x20, y_limbs)] { + for (idx, limb) in limbs.into_iter().enumerate() { + self.push(&limb); + // [..., success, acc] + if idx > 0 { self.code .borrow_mut() - // [..., success, x_acc] - .push(offset + idx * 0x20) - // [..., success, x_acc, x_limb_i_ptr] - .calldataload() - // [..., success, x_acc, x_limb_i] .push(idx * BITS) - // [..., success, x_acc, x_limb_i, shift] + // [..., success, acc, limb_i, shift] .shl() - // [..., success, x_acc, x_limb_i << shift] + // [..., success, acc, limb_i << shift] .add(); - // [..., success, x_acc] + // [..., success, acc] } } self.code .borrow_mut() - // [..., success, x] + // [..., success, coordinate] .dup(0) - // [..., success, x, x] + // [..., success, coordinate, coordinate] .push(ptr) - // [..., success, x, x, x_ptr] + // [..., success, coordinate, coordinate, ptr] .mstore(); - // [..., success, x] + // [..., success, coordinate] } // [..., success, x, y] self.validate_ec_point(); @@ -258,60 +329,21 @@ impl EvmLoader { .and(); } - pub fn squeeze_challenge(self: &Rc, ptr: usize, len: usize) -> (usize, Scalar) { - assert!(len > 0 && len % 0x20 == 0); - - let (ptr, len) = if len == 0x20 { - let ptr = if ptr + len != *self.ptr.borrow() { - (ptr..ptr + len) - .step_by(0x20) - .map(|ptr| self.dup_scalar(&self.scalar(Value::Memory(ptr)))) - .collect_vec() - .first() - 
.unwrap() - .ptr() - } else { - ptr - }; - self.code.borrow_mut().push(1).push(ptr + 0x20).mstore8(); - (ptr, len + 1) - } else { - (ptr, len) - }; - - let challenge_ptr = self.allocate(0x20); + pub fn keccak256(self: &Rc, ptr: usize, len: usize) -> usize { let hash_ptr = self.allocate(0x20); - self.code .borrow_mut() - .push(self.scalar_modulus) .push(len) .push(ptr) .keccak256() - .dup(0) .push(hash_ptr) - .mstore() - .r#mod() - .push(challenge_ptr) .mstore(); - - (hash_ptr, self.scalar(Value::Memory(challenge_ptr))) + hash_ptr } pub fn copy_scalar(self: &Rc, scalar: &Scalar, ptr: usize) { - match scalar.value { - Value::Constant(constant) => { - self.code.borrow_mut().push(constant).push(ptr).mstore(); - } - Value::Memory(src_ptr) => { - self.code - .borrow_mut() - .push(src_ptr) - .mload() - .push(ptr) - .mstore(); - } - } + self.push(scalar); + self.code.borrow_mut().push(ptr).mstore(); } pub fn dup_scalar(self: &Rc, scalar: &Scalar) -> Scalar { @@ -320,7 +352,7 @@ impl EvmLoader { self.scalar(Value::Memory(ptr)) } - fn dup_ec_point(self: &Rc, value: &EcPoint) -> EcPoint { + pub fn dup_ec_point(self: &Rc, value: &EcPoint) -> EcPoint { let ptr = self.allocate(0x40); match value.value { Value::Constant((x, y)) => { @@ -345,6 +377,9 @@ impl EvmLoader { .push(ptr + 0x20) .mstore(); } + Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => { + unreachable!() + } } self.ec_point(Value::Memory(ptr)) } @@ -390,14 +425,6 @@ impl EvmLoader { self.ec_point(Value::Memory(rd_ptr)) } - fn ec_point_sub(self: &Rc, _: &EcPoint, _: &EcPoint) -> EcPoint { - unreachable!() - } - - fn ec_point_neg(self: &Rc, _: &EcPoint) -> EcPoint { - unreachable!() - } - fn ec_point_scalar_mul(self: &Rc, ec_point: &EcPoint, scalar: &Scalar) -> EcPoint { let rd_ptr = self.dup_ec_point(ec_point).ptr(); self.dup_scalar(scalar); @@ -449,19 +476,15 @@ impl EvmLoader { } fn add(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { - if let (Value::Constant(lhs), Value::Constant(rhs)) = (lhs.value, rhs.value) { + if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { let out = (U512::from(lhs) + U512::from(rhs)) % U512::from(self.scalar_modulus); return self.scalar(Value::Constant(out.try_into().unwrap())); } - let ptr = self.allocate(0x20); - - self.code.borrow_mut().push(self.scalar_modulus); - self.push(rhs); - self.push(lhs); - self.code.borrow_mut().addmod().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Sum( + Box::new(lhs.value.clone()), + Box::new(rhs.value.clone()), + )) } fn sub(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { @@ -469,31 +492,22 @@ impl EvmLoader { return self.add(lhs, &self.neg(rhs)); } - let ptr = self.allocate(0x20); - - self.code.borrow_mut().push(self.scalar_modulus); - self.push(rhs); - self.code.borrow_mut().push(self.scalar_modulus).sub(); - self.push(lhs); - self.code.borrow_mut().addmod().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Sum( + Box::new(lhs.value.clone()), + Box::new(Value::Negated(Box::new(rhs.value.clone()))), + )) } fn mul(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { - if let (Value::Constant(lhs), Value::Constant(rhs)) = (lhs.value, rhs.value) { + if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { let out = (U512::from(lhs) * U512::from(rhs)) % U512::from(self.scalar_modulus); return self.scalar(Value::Constant(out.try_into().unwrap())); } - let ptr = self.allocate(0x20); - - self.code.borrow_mut().push(self.scalar_modulus); - self.push(rhs); - 
self.push(lhs); - self.code.borrow_mut().mulmod().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Product( + Box::new(lhs.value.clone()), + Box::new(rhs.value.clone()), + )) } fn neg(self: &Rc, scalar: &Scalar) -> Scalar { @@ -501,17 +515,7 @@ impl EvmLoader { return self.scalar(Value::Constant(self.scalar_modulus - constant)); } - let ptr = self.allocate(0x20); - - self.push(scalar); - self.code - .borrow_mut() - .push(self.scalar_modulus) - .sub() - .push(ptr) - .mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Negated(Box::new(scalar.value.clone()))) } } @@ -552,19 +556,15 @@ pub struct EcPoint { } impl EcPoint { - pub(super) fn loader(&self) -> &Rc { + pub(crate) fn loader(&self) -> &Rc { &self.loader } - pub fn value(&self) -> Value<(U256, U256)> { - self.value - } - - pub fn is_const(&self) -> bool { - matches!(self.value, Value::Constant(_)) + pub(crate) fn value(&self) -> Value<(U256, U256)> { + self.value.clone() } - pub fn ptr(&self) -> usize { + pub(crate) fn ptr(&self) -> usize { match self.value { Value::Memory(ptr) => ptr, _ => unreachable!(), @@ -580,70 +580,6 @@ impl Debug for EcPoint { } } -impl Add for EcPoint { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - self.loader.ec_point_add(&self, &rhs) - } -} - -impl Sub for EcPoint { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - self.loader.ec_point_sub(&self, &rhs) - } -} - -impl Neg for EcPoint { - type Output = Self; - - fn neg(self) -> Self { - self.loader.ec_point_neg(&self) - } -} - -impl<'a> Add<&'a Self> for EcPoint { - type Output = Self; - - fn add(self, rhs: &'a Self) -> Self { - self.loader.ec_point_add(&self, rhs) - } -} - -impl<'a> Sub<&'a Self> for EcPoint { - type Output = Self; - - fn sub(self, rhs: &'a Self) -> Self { - self.loader.ec_point_sub(&self, rhs) - } -} - -impl AddAssign for EcPoint { - fn add_assign(&mut self, rhs: Self) { - *self = self.loader.ec_point_add(self, &rhs); - } -} - -impl SubAssign for EcPoint { - fn sub_assign(&mut self, rhs: Self) { - *self = self.loader.ec_point_sub(self, &rhs); - } -} - -impl<'a> AddAssign<&'a Self> for EcPoint { - fn add_assign(&mut self, rhs: &'a Self) { - *self = self.loader.ec_point_add(self, rhs); - } -} - -impl<'a> SubAssign<&'a Self> for EcPoint { - fn sub_assign(&mut self, rhs: &'a Self) { - *self = self.loader.ec_point_sub(self, rhs); - } -} - impl PartialEq for EcPoint { fn eq(&self, other: &Self) -> bool { self.value == other.value @@ -652,8 +588,8 @@ impl PartialEq for EcPoint { impl LoadedEcPoint for EcPoint where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, + C: CurveAffine, + C::ScalarExt: PrimeField, { type Loader = Rc; @@ -668,7 +604,7 @@ where Value::Constant(constant) if constant == U256::one() => ec_point, _ => ec_point.loader.ec_point_scalar_mul(&ec_point, &scalar), }) - .reduce(|acc, ec_point| acc + ec_point) + .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) .unwrap() } } @@ -680,18 +616,27 @@ pub struct Scalar { } impl Scalar { - pub fn value(&self) -> Value { - self.value + pub(crate) fn loader(&self) -> &Rc { + &self.loader } - pub fn is_const(&self) -> bool { + pub(crate) fn value(&self) -> Value { + self.value.clone() + } + + pub(crate) fn is_const(&self) -> bool { matches!(self.value, Value::Constant(_)) } - pub fn ptr(&self) -> usize { + pub(crate) fn ptr(&self) -> usize { match self.value { Value::Memory(ptr) => ptr, - _ => unreachable!(), + _ => *self + .loader + .cache + .borrow() + .get(&self.value.identifier()) + .unwrap(), 
} } } @@ -879,34 +824,164 @@ impl> LoadedScalar for Scalar { impl EcPointLoader for Rc where - C: Curve + UncompressedEncoding, + C: CurveAffine, C::Scalar: PrimeField, { type LoadedEcPoint = EcPoint; fn ec_point_load_const(&self, value: &C) -> EcPoint { - let bytes = value.to_uncompressed(); - let (x, y) = ( - U256::from_little_endian(&bytes[..32]), - U256::from_little_endian(&bytes[32..]), - ); + let coordinates = value.coordinates().unwrap(); + let [x, y] = [coordinates.x(), coordinates.y()] + .map(|coordinate| U256::from_little_endian(coordinate.to_repr().as_ref())); self.ec_point(Value::Constant((x, y))) } + + fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> { + unimplemented!() + } } impl> ScalarLoader for Rc { type LoadedScalar = Scalar; fn load_const(&self, value: &F) -> Scalar { - self.scalar(Value::Constant(U256::from_little_endian( - value.to_repr().as_slice(), - ))) + self.scalar(Value::Constant(fe_to_u256(*value))) + } + + fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) -> Result<(), Error> { + unimplemented!() + } + + fn sum_with_coeff_and_const(&self, values: &[(F, &Scalar)], constant: F) -> Scalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let push_addend = |(coeff, value): &(F, &Scalar)| { + assert_ne!(*coeff, F::zero()); + match (*coeff == F::one(), &value.value) { + (true, _) => { + self.push(value); + } + (false, Value::Constant(value)) => { + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*value), + )))); + } + (false, _) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + self.push(value); + self.code.borrow_mut().mulmod(); + } + } + }; + + let mut values = values.iter(); + if constant == F::zero() { + push_addend(values.next().unwrap()); + } else { + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); + } + + let chunk_size = 16 - self.code.borrow().stack_len(); + for values in &values.chunks(chunk_size) { + let values = values.into_iter().collect_vec(); + + self.code.borrow_mut().push(self.scalar_modulus); + for _ in 1..chunk_size.min(values.len()) { + self.code.borrow_mut().dup(0); + } + self.code.borrow_mut().swap(chunk_size.min(values.len())); + + for value in values { + push_addend(value); + self.code.borrow_mut().addmod(); + } + } + + let ptr = self.allocate(0x20); + self.code.borrow_mut().push(ptr).mstore(); + + self.scalar(Value::Memory(ptr)) + } + + fn sum_products_with_coeff_and_const( + &self, + values: &[(F, &Scalar, &Scalar)], + constant: F, + ) -> Scalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let push_addend = |(coeff, lhs, rhs): &(F, &Scalar, &Scalar)| { + assert_ne!(*coeff, F::zero()); + match (*coeff == F::one(), &lhs.value, &rhs.value) { + (_, Value::Constant(lhs), Value::Constant(rhs)) => { + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*lhs) * u256_to_fe::(*rhs), + )))); + } + (_, value @ Value::Memory(_), Value::Constant(constant)) + | (_, Value::Constant(constant), value @ Value::Memory(_)) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*constant), + )))); + self.push(&self.scalar(value.clone())); + self.code.borrow_mut().mulmod(); + } + (true, _, _) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(lhs); + self.push(rhs); + self.code.borrow_mut().mulmod(); + } + (false, _, _) => { + 
self.code.borrow_mut().push(self.scalar_modulus).dup(0); + self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + self.push(lhs); + self.code.borrow_mut().mulmod(); + self.push(rhs); + self.code.borrow_mut().mulmod(); + } + } + }; + + let mut values = values.iter(); + if constant == F::zero() { + push_addend(values.next().unwrap()); + } else { + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); + } + + let chunk_size = 16 - self.code.borrow().stack_len(); + for values in &values.chunks(chunk_size) { + let values = values.into_iter().collect_vec(); + + self.code.borrow_mut().push(self.scalar_modulus); + for _ in 1..chunk_size.min(values.len()) { + self.code.borrow_mut().dup(0); + } + self.code.borrow_mut().swap(chunk_size.min(values.len())); + + for value in values { + push_addend(value); + self.code.borrow_mut().addmod(); + } + } + + let ptr = self.allocate(0x20); + self.code.borrow_mut().push(ptr).mstore(); + + self.scalar(Value::Memory(ptr)) } } impl Loader for Rc where - C: Curve + UncompressedEncoding, + C: CurveAffine, C::Scalar: PrimeField, { #[cfg(test)] diff --git a/src/loader/evm/test.rs b/src/loader/evm/test.rs index d4ae4984..e204c1b8 100644 --- a/src/loader/evm/test.rs +++ b/src/loader/evm/test.rs @@ -9,12 +9,6 @@ use std::env::var_os; mod tui; -fn small_address(lsb: u8) -> Address { - let mut address = Address::zero(); - *address.0.last_mut().unwrap() = lsb; - address -} - fn debug() -> bool { matches!( var_os("DEBUG"), @@ -23,9 +17,15 @@ fn debug() -> bool { } pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { + assert!( + code.len() <= 0x6000, + "Contract size {} exceeds the limit 24576", + code.len() + ); + let debug = debug(); - let caller = small_address(0xfe); - let callee = small_address(0xff); + let caller = Address::from_low_u64_be(0xfe); + let callee = Address::from_low_u64_be(0xff); let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) @@ -52,5 +52,5 @@ pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { Tui::new(result.debug.unwrap().flatten(0), 0).start(); } - (!result.reverted, result.gas, costs) + (!result.reverted, result.gas_used, costs) } diff --git a/src/loader/evm/test/tui.rs b/src/loader/evm/test/tui.rs index 72866d19..fcaef36c 100644 --- a/src/loader/evm/test/tui.rs +++ b/src/loader/evm/test/tui.rs @@ -60,7 +60,7 @@ impl Tui { .expect("unable to execute disable mouse capture"); println!("{e}"); })); - let tick_rate = Duration::from_millis(200); + let tick_rate = Duration::from_millis(60); let (tx, rx) = mpsc::channel(); thread::spawn(move || { diff --git a/src/loader/evm/transcript.rs b/src/loader/evm/transcript.rs deleted file mode 100644 index a5ff68d2..00000000 --- a/src/loader/evm/transcript.rs +++ /dev/null @@ -1,229 +0,0 @@ -use crate::{ - loader::{ - evm::{ - loader::{EcPoint, EvmLoader, Scalar, Value}, - u256_to_field, - }, - native::NativeLoader, - Loader, - }, - util::{Curve, Group, Itertools, PrimeField, Transcript, TranscriptRead, UncompressedEncoding}, - Error, -}; -use ethereum_types::U256; -use sha3::{Digest, Keccak256}; -use std::{ - io::{self, Read, Write}, - marker::PhantomData, - rc::Rc, -}; - -pub struct MemoryChunk { - ptr: usize, - len: usize, -} - -impl MemoryChunk { - fn new(ptr: usize) -> Self { - Self { ptr, len: 0x20 } - } - - fn reset(&mut self, ptr: usize) { - self.ptr = ptr; - self.len = 0x20; - } - - fn include(&self, ptr: usize, size: usize) -> bool { - let range = self.ptr..=self.ptr + self.len; - range.contains(&ptr) && range.contains(&(ptr + size)) - } - - 
fn extend(&mut self, ptr: usize, size: usize) { - if !self.include(ptr, size) { - assert_eq!(self.ptr + self.len, ptr); - self.len += size; - } - } -} - -pub struct EvmTranscript, S, B> { - loader: L, - stream: S, - buf: B, - _marker: PhantomData, -} - -impl EvmTranscript, usize, MemoryChunk> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - pub fn new(loader: Rc) -> Self { - let ptr = loader.allocate(0x20); - assert_eq!(ptr, 0); - Self { - loader, - stream: 0, - buf: MemoryChunk::new(ptr), - _marker: PhantomData, - } - } -} - -impl Transcript> for EvmTranscript, usize, MemoryChunk> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn squeeze_challenge(&mut self) -> Scalar { - let (ptr, scalar) = self.loader.squeeze_challenge(self.buf.ptr, self.buf.len); - self.buf.reset(ptr); - scalar - } - - fn common_ec_point(&mut self, ec_point: &EcPoint) -> Result<(), Error> { - if let Value::Memory(ptr) = ec_point.value() { - self.buf.extend(ptr, 0x40); - } else { - unreachable!() - } - Ok(()) - } - - fn common_scalar(&mut self, scalar: &Scalar) -> Result<(), Error> { - match scalar.value() { - Value::Constant(_) if self.buf.ptr == 0 => { - self.loader.copy_scalar(scalar, self.buf.ptr); - } - Value::Memory(ptr) => { - self.buf.extend(ptr, 0x20); - } - _ => unreachable!(), - } - Ok(()) - } -} - -impl TranscriptRead> for EvmTranscript, usize, MemoryChunk> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn read_scalar(&mut self) -> Result { - let scalar = self.loader.calldataload_scalar(self.stream); - self.stream += 0x20; - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result { - let ec_point = self.loader.calldataload_ec_point(self.stream); - self.stream += 0x40; - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl EvmTranscript> -where - C: Curve, -{ - pub fn new(stream: S) -> Self { - Self { - loader: NativeLoader, - stream, - buf: Vec::new(), - _marker: PhantomData, - } - } -} - -impl Transcript for EvmTranscript> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn squeeze_challenge(&mut self) -> C::Scalar { - let data = self - .buf - .iter() - .cloned() - .chain(if self.buf.len() == 0x20 { - Some(1) - } else { - None - }) - .collect_vec(); - let hash: [u8; 32] = Keccak256::digest(data).into(); - self.buf = hash.to_vec(); - u256_to_field(U256::from_big_endian(hash.as_slice())) - } - - fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { - let uncopressed = ec_point.to_uncompressed(); - self.buf.extend(uncopressed[..32].iter().rev().cloned()); - self.buf.extend(uncopressed[32..].iter().rev().cloned()); - - Ok(()) - } - - fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { - self.buf.extend(scalar.to_repr().as_ref().iter().rev()); - - Ok(()) - } -} - -impl TranscriptRead for EvmTranscript> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, - S: Read, -{ - fn read_scalar(&mut self) -> Result { - let mut data = [0; 32]; - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - data.reverse(); - let scalar = ::Scalar::from_repr_vartime(data).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid scalar encoding in proof".to_string(), - ) - })?; - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result { - let mut data = [0; 64]; - self.stream - .read_exact(data.as_mut()) - .map_err(|err| 
Error::Transcript(err.kind(), err.to_string()))?; - data.as_mut_slice()[..32].reverse(); - data.as_mut_slice()[32..].reverse(); - let ec_point = C::from_uncompressed(data).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid elliptic curve point encoding in proof".to_string(), - ) - })?; - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl EvmTranscript> -where - C: Curve, - S: Write, -{ - pub fn stream_mut(&mut self) -> &mut S { - &mut self.stream - } - - pub fn finalize(self) -> S { - self.stream - } -} diff --git a/src/loader/evm/util.rs b/src/loader/evm/util.rs new file mode 100644 index 00000000..0d9698bd --- /dev/null +++ b/src/loader/evm/util.rs @@ -0,0 +1,92 @@ +use crate::{ + cost::Cost, + util::{arithmetic::PrimeField, Itertools}, +}; +use ethereum_types::U256; +use std::iter; + +pub struct MemoryChunk { + ptr: usize, + len: usize, +} + +impl MemoryChunk { + pub fn new(ptr: usize) -> Self { + Self { ptr, len: 0 } + } + + pub fn ptr(&self) -> usize { + self.ptr + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn end(&self) -> usize { + self.ptr + self.len + } + + pub fn reset(&mut self, ptr: usize) { + self.ptr = ptr; + self.len = 0; + } + + pub fn extend(&mut self, size: usize) { + self.len += size; + } +} + +// Assume fields implements traits in crate `ff` always have little-endian representation. +pub fn fe_to_u256(f: F) -> U256 +where + F: PrimeField, +{ + U256::from_little_endian(f.to_repr().as_ref()) +} + +pub fn u256_to_fe(value: U256) -> F +where + F: PrimeField, +{ + let value = value % modulus::(); + let mut repr = F::Repr::default(); + value.to_little_endian(repr.as_mut()); + F::from_repr(repr).unwrap() +} + +pub fn modulus() -> U256 +where + F: PrimeField, +{ + U256::from_little_endian((-F::one()).to_repr().as_ref()) + 1 +} + +pub fn encode_calldata(instances: &[Vec], proof: &[u8]) -> Vec +where + F: PrimeField, +{ + iter::empty() + .chain( + instances + .iter() + .flatten() + .flat_map(|value| value.to_repr().as_ref().iter().rev().cloned().collect_vec()), + ) + .chain(proof.iter().cloned()) + .collect() +} + +pub fn estimate_gas(cost: Cost) -> usize { + let proof_size = cost.num_commitment * 64 + (cost.num_evaluation + cost.num_instance) * 32; + + let intrinsic_cost = 21000; + let calldata_cost = (proof_size as f64 * 15.25).ceil() as usize; + let ec_operation_cost = 113100 + (cost.num_msm - 2) * 6350; + + intrinsic_cost + calldata_cost + ec_operation_cost +} diff --git a/src/loader/halo2.rs b/src/loader/halo2.rs index 3d227a8a..9eae74e3 100644 --- a/src/loader/halo2.rs +++ b/src/loader/halo2.rs @@ -1,6 +1,29 @@ -mod accumulation; -mod loader; -mod transcript; +pub(crate) mod loader; +mod shim; -pub use loader::Halo2Loader; -pub use transcript::PoseidonTranscript; +#[cfg(test)] +pub(crate) mod test; + +pub use loader::{EcPoint, Halo2Loader, Scalar}; +pub use shim::{Context, EccInstructions, IntegerInstructions}; +pub use util::Valuetools; + +pub use halo2_wrong_ecc; + +mod util { + use halo2_proofs::circuit::Value; + + pub trait Valuetools: Iterator> { + fn fold_zipped(self, init: B, mut f: F) -> Value + where + Self: Sized, + F: FnMut(B, V) -> B, + { + self.into_iter().fold(Value::known(init), |acc, value| { + acc.zip(value).map(|(acc, value)| f(acc, value)) + }) + } + } + + impl>> Valuetools for I {} +} diff --git a/src/loader/halo2/accumulation.rs b/src/loader/halo2/accumulation.rs deleted file mode 100644 index cc78ab83..00000000 --- 
a/src/loader/halo2/accumulation.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::{ - loader::{ - halo2::loader::{Halo2Loader, Scalar}, - LoadedEcPoint, - }, - protocol::Protocol, - scheme::kzg::{AccumulationStrategy, Accumulator, SameCurveAccumulation, MSM}, - util::{Itertools, Transcript}, - Error, -}; -use halo2_curves::CurveAffine; -use halo2_wrong_ecc::AssignedPoint; -use std::rc::Rc; - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - SameCurveAccumulation>, LIMBS, BITS> -{ - pub fn finalize( - self, - g1: C, - ) -> ( - AssignedPoint, - AssignedPoint, - ) { - let (lhs, rhs) = self.accumulator.unwrap().evaluate(g1.to_curve()); - let loader = lhs.loader(); - ( - loader.ec_point_nomalize(&lhs.assigned()), - loader.ec_point_nomalize(&rhs.assigned()), - ) - } -} - -impl<'a, 'b, C, T, P, const LIMBS: usize, const BITS: usize> - AccumulationStrategy>, T, P> - for SameCurveAccumulation>, LIMBS, BITS> -where - C: CurveAffine, - T: Transcript>>, -{ - type Output = (); - - fn extract_accumulator( - &self, - protocol: &Protocol, - loader: &Rc>, - transcript: &mut T, - statements: &[Vec>], - ) -> Option>>> { - let accumulator_indices = protocol.accumulator_indices.as_ref()?; - - let challenges = transcript.squeeze_n_challenges(accumulator_indices.len()); - let accumulators = accumulator_indices - .iter() - .map(|indices| { - assert_eq!(indices.len(), 4 * LIMBS); - let assinged = indices - .iter() - .map(|index| statements[index.0][index.1].assigned()) - .collect_vec(); - let lhs = loader.assign_ec_point_from_limbs( - assinged[..LIMBS].to_vec().try_into().unwrap(), - assinged[LIMBS..2 * LIMBS].to_vec().try_into().unwrap(), - ); - let rhs = loader.assign_ec_point_from_limbs( - assinged[2 * LIMBS..3 * LIMBS].to_vec().try_into().unwrap(), - assinged[3 * LIMBS..].to_vec().try_into().unwrap(), - ); - Accumulator::new(MSM::base(lhs), MSM::base(rhs)) - }) - .collect_vec(); - - Some(Accumulator::random_linear_combine( - challenges.into_iter().zip(accumulators), - )) - } - - fn process( - &mut self, - _: &Rc>, - transcript: &mut T, - _: P, - accumulator: Accumulator>>, - ) -> Result { - self.accumulator = Some(match self.accumulator.take() { - Some(curr_accumulator) => { - accumulator + curr_accumulator * &transcript.squeeze_challenge() - } - None => accumulator, - }); - Ok(()) - } -} diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index e811debb..0d288ce7 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -1,184 +1,133 @@ use crate::{ - loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{Curve, Field, FieldOps, Group, Itertools}, -}; -use halo2_curves::CurveAffine; -use halo2_proofs::circuit; -use halo2_wrong_ecc::{ - integer::{ - rns::{Integer, Rns}, - IntegerInstructions, Range, + loader::{ + halo2::shim::{EccInstructions, IntegerInstructions}, + EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader, }, - maingate::{ - AssignedValue, CombinationOptionCommon, MainGate, MainGateInstructions, RegionCtx, Term, + util::{ + arithmetic::{CurveAffine, Field, FieldOps}, + Itertools, }, - AssignedPoint, BaseFieldEccChip, EccConfig, }; -use rand::rngs::OsRng; +use halo2_proofs::circuit; use std::{ - cell::RefCell, + cell::{Ref, RefCell, RefMut}, + collections::btree_map::{BTreeMap, Entry}, fmt::{self, Debug}, iter, - ops::{Add, AddAssign, Deref, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign}, + marker::PhantomData, + ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign}, rc::Rc, }; -const 
WINDOW_SIZE: usize = 3; - -#[derive(Clone, Debug)] -pub enum Value { - Constant(T), - Assigned(L), -} - -pub struct Halo2Loader<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> { - rns: Rc>, - ecc_chip: RefCell>, - main_gate: MainGate, - ctx: RefCell>, +#[derive(Debug)] +pub struct Halo2Loader<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + ecc_chip: RefCell, + ctx: RefCell, + num_scalar: RefCell, num_ec_point: RefCell, + const_ec_point: RefCell>>, + _marker: PhantomData, #[cfg(test)] row_meterings: RefCell>, } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - Halo2Loader<'a, 'b, C, LIMBS, BITS> -{ - pub fn new(ecc_config: EccConfig, ctx: RegionCtx<'a, 'b, C::Scalar>) -> Rc { - let ecc_chip = BaseFieldEccChip::new(ecc_config); - let main_gate = ecc_chip.main_gate(); +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { + pub fn new(ecc_chip: EccChip, ctx: EccChip::Context) -> Rc { Rc::new(Self { - rns: Rc::new(Rns::construct()), ecc_chip: RefCell::new(ecc_chip), - main_gate, ctx: RefCell::new(ctx), - num_ec_point: RefCell::new(0), + num_scalar: RefCell::default(), + num_ec_point: RefCell::default(), + const_ec_point: RefCell::default(), #[cfg(test)] - row_meterings: RefCell::new(Vec::new()), + row_meterings: RefCell::default(), + _marker: PhantomData, }) } - pub fn rns(&self) -> Rc> { - self.rns.clone() + pub fn into_ctx(self) -> EccChip::Context { + self.ctx.into_inner() } - pub fn ecc_chip(&self) -> impl Deref> + '_ { + pub fn ecc_chip(&self) -> Ref<'_, EccChip> { self.ecc_chip.borrow() } - pub(super) fn ctx_mut(&self) -> impl DerefMut> + '_ { + pub fn scalar_chip(&self) -> Ref<'_, EccChip::ScalarChip> { + Ref::map(self.ecc_chip(), |ecc_chip| ecc_chip.scalar_chip()) + } + + pub fn ctx(&self) -> Ref<'_, EccChip::Context> { + self.ctx.borrow() + } + + pub(crate) fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { self.ctx.borrow_mut() } - pub fn assign_const_scalar( - self: &Rc, - scalar: C::Scalar, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + pub fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { let assigned = self - .main_gate - .assign_constant(&mut self.ctx_mut(), scalar) + .scalar_chip() + .assign_constant(&mut self.ctx_mut(), constant) .unwrap(); self.scalar(Value::Assigned(assigned)) } pub fn assign_scalar( self: &Rc, - scalar: circuit::Value, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + scalar: circuit::Value, + ) -> Scalar<'a, C, EccChip> { let assigned = self - .main_gate - .assign_value(&mut self.ctx_mut(), scalar) + .scalar_chip() + .assign_integer(&mut self.ctx_mut(), scalar) .unwrap(); self.scalar(Value::Assigned(assigned)) } - pub fn scalar( + pub(crate) fn scalar( self: &Rc, - value: Value>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + value: Value, + ) -> Scalar<'a, C, EccChip> { + let index = *self.num_scalar.borrow(); + *self.num_scalar.borrow_mut() += 1; Scalar { loader: self.clone(), + index, value, } } - pub fn assign_const_ec_point(self: &Rc, ec_point: C) -> EcPoint<'a, 'b, C, LIMBS, BITS> { - let assigned = self - .ecc_chip - .borrow() - .assign_constant(&mut self.ctx_mut(), ec_point) - .unwrap(); - self.ec_point(assigned) + pub fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { + let coordinates = constant.coordinates().unwrap(); + match self + .const_ec_point + .borrow_mut() + .entry((*coordinates.x(), *coordinates.y())) + { + Entry::Occupied(entry) => entry.get().clone(), + Entry::Vacant(entry) => { + let assigned = self + .ecc_chip() 
+ .assign_point(&mut self.ctx_mut(), circuit::Value::known(constant)) + .unwrap(); + let ec_point = self.ec_point(assigned); + entry.insert(ec_point).clone() + } + } } pub fn assign_ec_point( self: &Rc, ec_point: circuit::Value, - ) -> EcPoint<'a, 'b, C, LIMBS, BITS> { + ) -> EcPoint<'a, C, EccChip> { let assigned = self - .ecc_chip - .borrow() + .ecc_chip() .assign_point(&mut self.ctx_mut(), ec_point) .unwrap(); self.ec_point(assigned) } - pub fn assign_ec_point_from_limbs( - self: &Rc, - x_limbs: [AssignedValue; LIMBS], - y_limbs: [AssignedValue; LIMBS], - ) -> EcPoint<'a, 'b, C, LIMBS, BITS> { - let [x, y] = [&x_limbs, &y_limbs] - .map(|limbs| { - limbs.iter().enumerate().fold( - circuit::Value::known([C::Scalar::zero(); LIMBS]), - |acc, (idx, limb)| { - acc.zip(limb.value()).map(|(mut acc, limb)| { - acc[idx] = *limb; - acc - }) - }, - ) - }) - .map(|limbs| { - self.ecc_chip - .borrow() - .integer_chip() - .assign_integer( - &mut self.ctx_mut(), - limbs - .map(|limbs| Integer::from_limbs(&limbs, self.rns.clone())) - .into(), - Range::Remainder, - ) - .unwrap() - }); - - let ec_point = AssignedPoint::new(x, y); - self.ecc_chip() - .assert_is_on_curve(&mut self.ctx_mut(), &ec_point) - .unwrap(); - - for (src, dst) in x_limbs.iter().chain(y_limbs.iter()).zip( - ec_point - .get_x() - .limbs() - .iter() - .chain(ec_point.get_y().limbs().iter()), - ) { - self.ctx - .borrow_mut() - .constrain_equal(src.cell(), dst.as_ref().cell()) - .unwrap(); - } - - self.ec_point(ec_point) - } - - pub fn ec_point( - self: &Rc, - assigned: AssignedPoint, - ) -> EcPoint<'a, 'b, C, LIMBS, BITS> { + fn ec_point(self: &Rc, assigned: EccChip::AssignedEcPoint) -> EcPoint<'a, C, EccChip> { let index = *self.num_ec_point.borrow(); *self.num_ec_point.borrow_mut() += 1; EcPoint { @@ -188,71 +137,66 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> } } - pub fn ec_point_nomalize( - self: &Rc, - assigned: &AssignedPoint, - ) -> AssignedPoint { - self.ecc_chip() - .normalize(&mut self.ctx_mut(), assigned) - .unwrap() - } - fn add( self: &Rc, - lhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - rhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { let output = match (&lhs.value, &rhs.value) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs + rhs), (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => { - MainGateInstructions::add_constant( - &self.main_gate, + | (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - assigned, + &[(C::Scalar::one(), assigned.clone())], *constant, ) .map(Value::Assigned) - .unwrap() - } - (Value::Assigned(lhs), Value::Assigned(rhs)) => { - MainGateInstructions::add(&self.main_gate, &mut self.ctx_mut(), lhs, rhs) - .map(Value::Assigned) - .unwrap() - } + .unwrap(), + (Value::Assigned(lhs), Value::Assigned(rhs)) => self + .scalar_chip() + .sum_with_coeff_and_const( + &mut self.ctx_mut(), + &[ + (C::Scalar::one(), lhs.clone()), + (C::Scalar::one(), rhs.clone()), + ], + C::Scalar::zero(), + ) + .map(Value::Assigned) + .unwrap(), }; self.scalar(output) } fn sub( self: &Rc, - lhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - rhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { let output = match 
(&lhs.value, &rhs.value) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs - rhs), - (Value::Constant(constant), Value::Assigned(assigned)) => { - MainGateInstructions::neg_with_constant( - &self.main_gate, + (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - assigned, + &[(-C::Scalar::one(), assigned.clone())], *constant, ) .map(Value::Assigned) - .unwrap() - } - (Value::Assigned(assigned), Value::Constant(constant)) => { - MainGateInstructions::add_constant( - &self.main_gate, + .unwrap(), + (Value::Assigned(assigned), Value::Constant(constant)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - assigned, - constant.neg(), + &[(C::Scalar::one(), assigned.clone())], + -*constant, ) .map(Value::Assigned) - .unwrap() - } + .unwrap(), (Value::Assigned(lhs), Value::Assigned(rhs)) => { - MainGateInstructions::sub(&self.main_gate, &mut self.ctx_mut(), lhs, rhs) + IntegerInstructions::sub(self.scalar_chip().deref(), &mut self.ctx_mut(), lhs, rhs) .map(Value::Assigned) .unwrap() } @@ -262,89 +206,78 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> fn mul( self: &Rc, - lhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - rhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { let output = match (&lhs.value, &rhs.value) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs * rhs), (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => { - MainGateInstructions::apply( - &self.main_gate, + | (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - [ - Term::Assigned(assigned, *constant), - Term::unassigned_to_sub( - assigned.value().map(|assigned| *assigned * constant), - ), - ], + &[(*constant, assigned.clone())], C::Scalar::zero(), - CombinationOptionCommon::OneLinerAdd.into(), ) - .map(|mut assigned| Value::Assigned(assigned.swap_remove(1))) - .unwrap() - } - (Value::Assigned(lhs), Value::Assigned(rhs)) => { - MainGateInstructions::mul(&self.main_gate, &mut self.ctx_mut(), lhs, rhs) - .map(Value::Assigned) - .unwrap() - } + .map(Value::Assigned) + .unwrap(), + (Value::Assigned(lhs), Value::Assigned(rhs)) => self + .scalar_chip() + .sum_products_with_coeff_and_const( + &mut self.ctx_mut(), + &[(C::Scalar::one(), lhs.clone(), rhs.clone())], + C::Scalar::zero(), + ) + .map(Value::Assigned) + .unwrap(), }; self.scalar(output) } - fn neg( - self: &Rc, - scalar: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + fn neg(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { let output = match &scalar.value { Value::Constant(constant) => Value::Constant(constant.neg()), - Value::Assigned(assigned) => MainGateInstructions::neg_with_constant( - &self.main_gate, - &mut self.ctx_mut(), - assigned, - C::Scalar::zero(), - ) - .map(Value::Assigned) - .unwrap(), + Value::Assigned(assigned) => { + IntegerInstructions::neg(self.scalar_chip().deref(), &mut self.ctx_mut(), assigned) + .map(Value::Assigned) + .unwrap() + } }; self.scalar(output) } - fn invert( - self: &Rc, - scalar: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + fn invert(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { let output = match &scalar.value { 
Value::Constant(constant) => Value::Constant(Field::invert(constant).unwrap()), - Value::Assigned(assigned) => { - let (inv, non_invertable) = - MainGateInstructions::invert(&self.main_gate, &mut self.ctx_mut(), assigned) - .unwrap(); - self.main_gate - .assert_zero(&mut self.ctx_mut(), &non_invertable) - .unwrap(); - Value::Assigned(inv) - } + Value::Assigned(assigned) => Value::Assigned( + IntegerInstructions::invert( + self.scalar_chip().deref(), + &mut self.ctx_mut(), + assigned, + ) + .unwrap(), + ), }; self.scalar(output) } } #[cfg(test)] -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - Halo2Loader<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { fn start_row_metering(self: &Rc, identifier: &str) { + use crate::loader::halo2::shim::Context; + self.row_meterings .borrow_mut() - .push((identifier.to_string(), *self.ctx.borrow().offset)) + .push((identifier.to_string(), self.ctx().offset())) } fn end_row_metering(self: &Rc) { + use crate::loader::halo2::shim::Context; + let mut row_meterings = self.row_meterings.borrow_mut(); let (_, row) = row_meterings.last_mut().unwrap(); - *row = *self.ctx.borrow().offset - *row; + *row = self.ctx().offset() - *row; } pub fn print_row_metering(self: &Rc) { @@ -354,14 +287,25 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> } } +#[derive(Clone, Debug)] +pub enum Value { + Constant(T), + Assigned(L), +} + #[derive(Clone)] -pub struct Scalar<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> { - loader: Rc>, - value: Value>, +pub struct Scalar<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + loader: Rc>, + index: usize, + value: Value, } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Scalar<'a, 'b, C, LIMBS, BITS> { - pub fn assigned(&self) -> AssignedValue { +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> { + pub fn loader(&self) -> &Rc> { + &self.loader + } + + pub(crate) fn assigned(&self) -> EccChip::AssignedScalar { match &self.value { Value::Constant(constant) => self.loader.assign_const_scalar(*constant).assigned(), Value::Assigned(assigned) => assigned.clone(), @@ -369,19 +313,23 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Scalar<'a, ' } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> LoadedScalar - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for Scalar<'a, C, EccChip> { + fn eq(&self, other: &Self) -> bool { + self.index == other.index + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedScalar + for Scalar<'a, C, EccChip> { - type Loader = Rc>; + type Loader = Rc>; fn loader(&self) -> &Self::Loader { &self.loader } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for Scalar<'a, C, EccChip> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Scalar") .field("value", &self.value) @@ -389,166 +337,146 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> FieldOps - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> FieldOps for Scalar<'a, C, EccChip> { fn invert(&self) -> Option { - Some((&self.loader).invert(self)) + 
Some(self.loader.invert(self)) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add for Scalar<'a, C, EccChip> { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - (&self.loader).add(&self, &rhs) + Halo2Loader::add(&self.loader, &self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub for Scalar<'a, C, EccChip> { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - (&self.loader).sub(&self, &rhs) + Halo2Loader::sub(&self.loader, &self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Mul - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul for Scalar<'a, C, EccChip> { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - (&self.loader).mul(&self, &rhs) + Halo2Loader::mul(&self.loader, &self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Neg - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Neg for Scalar<'a, C, EccChip> { type Output = Self; fn neg(self) -> Self::Output { - (&self.loader).neg(&self) + Halo2Loader::neg(&self.loader, &self) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add<&'b Self> + for Scalar<'a, C, EccChip> { type Output = Self; - fn add(self, rhs: &'c Self) -> Self::Output { - (&self.loader).add(&self, rhs) + fn add(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::add(&self.loader, &self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub<&'b Self> + for Scalar<'a, C, EccChip> { type Output = Self; - fn sub(self, rhs: &'c Self) -> Self::Output { - (&self.loader).sub(&self, rhs) + fn sub(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::sub(&self.loader, &self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Mul<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul<&'b Self> + for Scalar<'a, C, EccChip> { type Output = Self; - fn mul(self, rhs: &'c Self) -> Self::Output { - (&self.loader).mul(&self, rhs) + fn mul(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::mul(&self.loader, &self, rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign for Scalar<'a, C, EccChip> { fn add_assign(&mut self, rhs: Self) { - *self = (&self.loader).add(self, &rhs) + *self = Halo2Loader::add(&self.loader, self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign for Scalar<'a, C, EccChip> { fn sub_assign(&mut self, rhs: Self) { - *self = (&self.loader).sub(self, &rhs) + *self = Halo2Loader::sub(&self.loader, self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> MulAssign - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: 
CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign for Scalar<'a, C, EccChip> { fn mul_assign(&mut self, rhs: Self) { - *self = (&self.loader).mul(self, &rhs) + *self = Halo2Loader::mul(&self.loader, self, &rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign<&'b Self> + for Scalar<'a, C, EccChip> { - fn add_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).add(self, rhs) + fn add_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::add(&self.loader, self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign<&'b Self> + for Scalar<'a, C, EccChip> { - fn sub_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).sub(self, rhs) + fn sub_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::sub(&self.loader, self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> MulAssign<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self> + for Scalar<'a, C, EccChip> { - fn mul_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).mul(self, rhs) + fn mul_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::mul(&self.loader, self, rhs) } } #[derive(Clone)] -pub struct EcPoint<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> { - loader: Rc>, +pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + loader: Rc>, index: usize, - assigned: AssignedPoint, + assigned: EccChip::AssignedEcPoint, } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - EcPoint<'a, 'b, C, LIMBS, BITS> -{ - pub fn assigned(&self) -> AssignedPoint { +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { + pub fn assigned(&self) -> EccChip::AssignedEcPoint { self.assigned.clone() } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> PartialEq - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for EcPoint<'a, C, EccChip> { fn eq(&self, other: &Self) -> bool { self.index == other.index } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> LoadedEcPoint - for EcPoint<'a, 'b, C, LIMBS, BITS> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedEcPoint + for EcPoint<'a, C, EccChip> { - type Loader = Rc>; + type Loader = Rc>; fn loader(&self) -> &Self::Loader { &self.loader } fn multi_scalar_multiplication( - pairs: impl IntoIterator, Self)>, + pairs: impl IntoIterator, Self)>, ) -> Self { let pairs = pairs.into_iter().collect_vec(); let loader = &pairs[0].0.loader; @@ -570,42 +498,37 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> LoadedEcPoin .chain(if scaled.is_empty() { None } else { - let aux_generator = ::CurveExt::random(OsRng).to_affine(); - loader - .ecc_chip - .borrow_mut() - .assign_aux_generator( - &mut loader.ctx.borrow_mut(), - circuit::Value::known(aux_generator), - ) - .unwrap(); - loader - .ecc_chip - .borrow_mut() - .assign_aux(&mut loader.ctx.borrow_mut(), WINDOW_SIZE, scaled.len()) - .unwrap(); Some( loader .ecc_chip - .borrow() - .mul_batch_1d_horizontal(&mut loader.ctx.borrow_mut(), scaled, WINDOW_SIZE) + .borrow_mut() + .multi_scalar_multiplication(&mut 
loader.ctx_mut(), scaled) .unwrap(), ) }) .chain(non_scaled) .reduce(|acc, ec_point| { - (loader.ecc_chip().deref()) - .add(&mut loader.ctx.borrow_mut(), &acc, &ec_point) + EccInstructions::add( + loader.ecc_chip().deref(), + &mut loader.ctx_mut(), + &acc, + &ec_point, + ) + .unwrap() + }) + .map(|output| { + loader + .ecc_chip() + .normalize(&mut loader.ctx_mut(), &output) .unwrap() }) .unwrap(); + loader.ec_point(output) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for EcPoint<'a, C, EccChip> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("EcPoint") .field("index", &self.index) @@ -614,110 +537,82 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add - for EcPoint<'a, 'b, C, LIMBS, BITS> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> ScalarLoader + for Rc> { - type Output = Self; - - fn add(self, _: Self) -> Self::Output { - todo!() - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn sub(self, _: Self) -> Self::Output { - todo!() - } -} + type LoadedScalar = Scalar<'a, C, EccChip>; -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Neg - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn neg(self) -> Self::Output { - todo!() - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn add(self, rhs: &'c Self) -> Self::Output { - self + rhs.clone() - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn sub(self, rhs: &'c Self) -> Self::Output { - self - rhs.clone() - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn add_assign(&mut self, rhs: Self) { - *self = self.clone() + rhs - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn sub_assign(&mut self, rhs: Self) { - *self = self.clone() - rhs - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn add_assign(&mut self, rhs: &'c Self) { - *self = self.clone() + rhs - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn sub_assign(&mut self, rhs: &'c Self) { - *self = self.clone() - rhs - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> ScalarLoader - for Rc> -{ - type LoadedScalar = Scalar<'a, 'b, C, LIMBS, BITS>; - - fn load_const(&self, value: &C::Scalar) -> Scalar<'a, 'b, C, LIMBS, BITS> { + fn load_const(&self, value: &C::Scalar) -> Scalar<'a, C, EccChip> { self.scalar(Value::Constant(*value)) } -} -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> EcPointLoader - for Rc> -{ - type LoadedEcPoint = EcPoint<'a, 'b, C, LIMBS, BITS>; - - fn ec_point_load_const(&self, ec_point: &C::CurveExt) -> EcPoint<'a, 'b, C, LIMBS, BITS> { - self.assign_const_ec_point(ec_point.to_affine()) + fn assert_eq( + &self, + annotation: &str, + lhs: 
&Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Result<(), crate::Error> { + self.scalar_chip() + .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + } + + fn sum_with_coeff_and_const( + &self, + values: &[(C::Scalar, &Scalar<'a, C, EccChip>)], + constant: C::Scalar, + ) -> Scalar<'a, C, EccChip> { + let values = values + .iter() + .map(|(coeff, value)| (*coeff, value.assigned())) + .collect_vec(); + self.scalar(Value::Assigned( + self.scalar_chip() + .sum_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) + .unwrap(), + )) + } + + fn sum_products_with_coeff_and_const( + &self, + values: &[(C::Scalar, &Scalar<'a, C, EccChip>, &Scalar<'a, C, EccChip>)], + constant: C::Scalar, + ) -> Scalar<'a, C, EccChip> { + let values = values + .iter() + .map(|(coeff, lhs, rhs)| (*coeff, lhs.assigned(), rhs.assigned())) + .collect_vec(); + self.scalar(Value::Assigned( + self.scalar_chip() + .sum_products_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) + .unwrap(), + )) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader + for Rc> +{ + type LoadedEcPoint = EcPoint<'a, C, EccChip>; + + fn ec_point_load_const(&self, ec_point: &C) -> EcPoint<'a, C, EccChip> { + self.assign_const_ec_point(*ec_point) + } + + fn ec_point_assert_eq( + &self, + annotation: &str, + lhs: &EcPoint<'a, C, EccChip>, + rhs: &EcPoint<'a, C, EccChip>, + ) -> Result<(), crate::Error> { + self.ecc_chip() + .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Loader - for Rc> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Loader + for Rc> { #[cfg(test)] fn start_cost_metering(&self, identifier: &str) { diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs new file mode 100644 index 00000000..67f06cc9 --- /dev/null +++ b/src/loader/halo2/shim.rs @@ -0,0 +1,403 @@ +use crate::util::arithmetic::{CurveAffine, FieldExt}; +use halo2_proofs::{ + circuit::{Cell, Value}, + plonk::Error, +}; +use std::fmt::Debug; + +pub trait Context: Debug { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error>; + + fn offset(&self) -> usize; +} + +pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { + type Context: Context; + type Integer: Clone + Debug; + type AssignedInteger: Clone + Debug; + + fn integer(&self, fe: F) -> Self::Integer; + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result; + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result; + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F::Scalar, Self::AssignedInteger)], + constant: F::Scalar, + ) -> Result; + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F::Scalar, Self::AssignedInteger, Self::AssignedInteger)], + constant: F::Scalar, + ) -> Result; + + fn sub( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result; + + fn neg( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result; + + fn invert( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result; + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result<(), Error>; +} + 
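// NOTE (editor): `sum_with_coeff_and_const` and `sum_products_with_coeff_and_const` are the
// only non-obvious members of the trait above. Both back ends in this patch (the EVM loader's
// mulmod/addmod code generation and the `MainGate` shim further below) constrain, respectively,
//
//     out = coeff_0 * value_0 + ... + coeff_n * value_n + constant
//     out = coeff_0 * lhs_0 * rhs_0 + ... + coeff_n * lhs_n * rhs_n + constant
//
// What follows is a minimal native sketch of that intended semantics, written against the
// `ff::Field` trait (re-exported by `halo2_curves`). It is not part of the patch; the
// `_native` function names are hypothetical and exist only to illustrate the formulas.
fn sum_with_coeff_and_const_native<F: ff::Field>(values: &[(F, F)], constant: F) -> F {
    // Fold the weighted addends onto the constant term.
    values
        .iter()
        .fold(constant, |acc, (coeff, value)| acc + *coeff * value)
}

fn sum_products_with_coeff_and_const_native<F: ff::Field>(values: &[(F, F, F)], constant: F) -> F {
    // Same shape, but each addend is a weighted product of two operands.
    values
        .iter()
        .fold(constant, |acc, (coeff, lhs, rhs)| acc + *coeff * lhs * rhs)
}
// For example, over `halo2_curves::bn256::Fr`, sum_with_coeff_and_const_native(&[(a, x), (b, y)], c)
// evaluates a*x + b*y + c; the in-circuit and EVM implementations constrain or emit bytecode for
// the same expression instead of computing it natively.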
+pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { + type Context: Context; + type ScalarChip: IntegerInstructions< + 'a, + C::Scalar, + Context = Self::Context, + Integer = Self::Scalar, + AssignedInteger = Self::AssignedScalar, + >; + type AssignedEcPoint: Clone + Debug; + type Scalar: Clone + Debug; + type AssignedScalar: Clone + Debug; + + fn scalar_chip(&self) -> &Self::ScalarChip; + + fn assign_constant( + &self, + ctx: &mut Self::Context, + point: C, + ) -> Result; + + fn assign_point( + &self, + ctx: &mut Self::Context, + point: Value, + ) -> Result; + + fn add( + &self, + ctx: &mut Self::Context, + p0: &Self::AssignedEcPoint, + p1: &Self::AssignedEcPoint, + ) -> Result; + + fn multi_scalar_multiplication( + &mut self, + ctx: &mut Self::Context, + pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + ) -> Result; + + fn normalize( + &self, + ctx: &mut Self::Context, + point: &Self::AssignedEcPoint, + ) -> Result; + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedEcPoint, + b: &Self::AssignedEcPoint, + ) -> Result<(), Error>; +} + +mod halo2_wrong { + use crate::{ + loader::halo2::{Context, EccInstructions, IntegerInstructions}, + util::{ + arithmetic::{CurveAffine, FieldExt, Group}, + Itertools, + }, + }; + use halo2_proofs::{ + circuit::{AssignedCell, Cell, Value}, + plonk::Error, + }; + use halo2_wrong_ecc::{ + integer::rns::Common, + maingate::{ + CombinationOption, CombinationOptionCommon, MainGate, MainGateInstructions, RegionCtx, + Term, + }, + AssignedPoint, BaseFieldEccChip, + }; + use rand::rngs::OsRng; + + impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { + self.constrain_equal(lhs, rhs) + } + + fn offset(&self) -> usize { + self.offset() + } + } + + impl<'a, F: FieldExt> IntegerInstructions<'a, F> for MainGate { + type Context = RegionCtx<'a, F>; + type Integer = F; + type AssignedInteger = AssignedCell; + + fn integer(&self, scalar: F) -> Self::Integer { + scalar + } + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result { + self.assign_value(ctx, integer) + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result { + MainGateInstructions::assign_constant(self, ctx, integer) + } + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F, Self::AssignedInteger)], + constant: F, + ) -> Result { + self.compose( + ctx, + &values + .iter() + .map(|(coeff, assigned)| Term::Assigned(assigned, *coeff)) + .collect_vec(), + constant, + ) + } + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F, Self::AssignedInteger, Self::AssignedInteger)], + constant: F, + ) -> Result { + match values.len() { + 0 => MainGateInstructions::assign_constant(self, ctx, constant), + 1 => { + let (scalar, lhs, rhs) = &values[0]; + let output = lhs + .value() + .zip(rhs.value()) + .map(|(lhs, rhs)| *scalar * lhs * rhs + constant); + + Ok(self + .apply( + ctx, + [ + Term::Zero, + Term::Zero, + Term::assigned_to_mul(lhs), + Term::assigned_to_mul(rhs), + Term::unassigned_to_sub(output), + ], + constant, + CombinationOption::OneLinerDoubleMul(*scalar), + )? 
+ .swap_remove(4)) + } + _ => { + let (scalar, lhs, rhs) = &values[0]; + self.apply( + ctx, + [Term::assigned_to_mul(lhs), Term::assigned_to_mul(rhs)], + constant, + CombinationOptionCommon::CombineToNextScaleMul(-F::one(), *scalar).into(), + )?; + let acc = + Value::known(*scalar) * lhs.value() * rhs.value() + Value::known(constant); + let output = values.iter().skip(1).fold( + Ok::<_, Error>(acc), + |acc, (scalar, lhs, rhs)| { + acc.and_then(|acc| { + self.apply( + ctx, + [ + Term::assigned_to_mul(lhs), + Term::assigned_to_mul(rhs), + Term::Zero, + Term::Zero, + Term::Unassigned(acc, F::one()), + ], + F::zero(), + CombinationOptionCommon::CombineToNextScaleMul( + -F::one(), + *scalar, + ) + .into(), + )?; + Ok(acc + Value::known(*scalar) * lhs.value() * rhs.value()) + }) + }, + )?; + self.apply( + ctx, + [ + Term::Zero, + Term::Zero, + Term::Zero, + Term::Zero, + Term::Unassigned(output, F::zero()), + ], + F::zero(), + CombinationOptionCommon::OneLinerAdd.into(), + ) + .map(|mut outputs| outputs.swap_remove(4)) + } + } + } + + fn sub( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::sub(self, ctx, a, b) + } + + fn neg( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::neg_with_constant(self, ctx, a, F::zero()) + } + + fn invert( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::invert_unsafe(self, ctx, a) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result<(), Error> { + let mut eq = true; + a.value().zip(b.value()).map(|(lhs, rhs)| { + eq &= lhs == rhs; + }); + MainGateInstructions::assert_equal(self, ctx, a, b) + .and(eq.then_some(()).ok_or(Error::Synthesis)) + } + } + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EccInstructions<'a, C> + for BaseFieldEccChip + { + type Context = RegionCtx<'a, C::Scalar>; + type ScalarChip = MainGate; + type AssignedEcPoint = AssignedPoint; + type Scalar = C::Scalar; + type AssignedScalar = AssignedCell; + + fn scalar_chip(&self) -> &Self::ScalarChip { + self.main_gate() + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + point: C, + ) -> Result { + self.assign_constant(ctx, point) + } + + fn assign_point( + &self, + ctx: &mut Self::Context, + point: Value, + ) -> Result { + self.assign_point(ctx, point) + } + + fn add( + &self, + ctx: &mut Self::Context, + p0: &Self::AssignedEcPoint, + p1: &Self::AssignedEcPoint, + ) -> Result { + self.add(ctx, p0, p1) + } + + fn multi_scalar_multiplication( + &mut self, + ctx: &mut Self::Context, + pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + ) -> Result { + const WINDOW_SIZE: usize = 3; + match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { + Err(_) => { + if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { + let aux_generator = Value::known(C::Curve::random(OsRng).into()); + self.assign_aux_generator(ctx, aux_generator)?; + self.assign_aux(ctx, WINDOW_SIZE, pairs.len())?; + } + self.mul_batch_1d_horizontal(ctx, pairs, WINDOW_SIZE) + } + result => result, + } + } + + fn normalize( + &self, + ctx: &mut Self::Context, + point: &Self::AssignedEcPoint, + ) -> Result { + self.normalize(ctx, point) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedEcPoint, + b: &Self::AssignedEcPoint, + ) -> Result<(), Error> { + let mut eq = true; + [(a.x(), 
b.x()), (a.y(), b.y())].map(|(lhs, rhs)| { + lhs.integer().zip(rhs.integer()).map(|(lhs, rhs)| { + eq &= lhs.value() == rhs.value(); + }); + }); + self.assert_equal(ctx, a, b) + .and(eq.then_some(()).ok_or(Error::Synthesis)) + } + } +} diff --git a/src/loader/halo2/test.rs b/src/loader/halo2/test.rs new file mode 100644 index 00000000..08551fe0 --- /dev/null +++ b/src/loader/halo2/test.rs @@ -0,0 +1,66 @@ +use crate::{ + util::{arithmetic::CurveAffine, Itertools}, + Protocol, +}; +use halo2_proofs::circuit::Value; + +pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, +} + +impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + assert_eq!( + protocol.num_instance, + instances + .iter() + .map(|instances| instances.len()) + .collect_vec() + ); + Snark { + protocol, + instances, + proof, + } + } +} + +pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, +} + +impl From> for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } +} + +impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } +} diff --git a/src/loader/halo2/transcript.rs b/src/loader/halo2/transcript.rs deleted file mode 100644 index 5bfe57ba..00000000 --- a/src/loader/halo2/transcript.rs +++ /dev/null @@ -1,363 +0,0 @@ -use crate::{ - loader::{ - halo2::loader::{EcPoint, Halo2Loader, Scalar, Value}, - native::NativeLoader, - }, - util::{Curve, GroupEncoding, PrimeField, Transcript, TranscriptRead}, - Error, -}; -use halo2_curves::{Coordinates, CurveAffine}; -use halo2_proofs::circuit; -use halo2_wrong_ecc::integer::rns::{Common, Integer, Rns}; -use halo2_wrong_transcript::{NativeRepresentation, PointRepresentation, TranscriptChip}; -use poseidon::{Poseidon, Spec}; -use std::{ - io::{self, Read, Write}, - marker::PhantomData, - rc::Rc, -}; - -pub struct PoseidonTranscript< - C: CurveAffine, - L, - S, - B, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, -> { - loader: L, - stream: S, - buf: B, - rns: Rc>, - _marker: PhantomData<(C, E)>, -} - -impl< - 'a, - 'b, - C: CurveAffine, - R: Read, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > - PoseidonTranscript< - C, - Rc>, - circuit::Value, - TranscriptChip, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - pub fn new( - loader: &Rc>, - stream: circuit::Value, - ) -> Self { - let transcript_chip = TranscriptChip::new( - &mut loader.ctx_mut(), - &Spec::new(R_F, R_P), - loader.ecc_chip().clone(), - ) - .unwrap(); - Self { - loader: loader.clone(), - stream, - buf: transcript_chip, - rns: Rc::new(Rns::::construct()), - _marker: PhantomData, - } - } -} - -impl< - 'a, - 'b, - C: CurveAffine, - R: Read, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript>> - for 
PoseidonTranscript< - C, - Rc>, - circuit::Value, - TranscriptChip, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn squeeze_challenge(&mut self) -> Scalar<'a, 'b, C, LIMBS, BITS> { - let assigned = self.buf.squeeze(&mut self.loader.ctx_mut()).unwrap(); - self.loader.scalar(Value::Assigned(assigned)) - } - - fn common_scalar(&mut self, scalar: &Scalar<'a, 'b, C, LIMBS, BITS>) -> Result<(), Error> { - self.buf.write_scalar(&scalar.assigned()); - Ok(()) - } - - fn common_ec_point(&mut self, ec_point: &EcPoint<'a, 'b, C, LIMBS, BITS>) -> Result<(), Error> { - self.buf - .write_point(&mut self.loader.ctx_mut(), &ec_point.assigned()) - .unwrap(); - Ok(()) - } -} - -impl< - 'a, - 'b, - C: CurveAffine, - R: Read, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead>> - for PoseidonTranscript< - C, - Rc>, - circuit::Value, - TranscriptChip, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn read_scalar(&mut self) -> Result, Error> { - let scalar = self.stream.as_mut().and_then(|stream| { - let mut data = ::Repr::default(); - if stream.read_exact(data.as_mut()).is_err() { - return circuit::Value::unknown(); - } - Option::::from(C::Scalar::from_repr(data)) - .map(circuit::Value::known) - .unwrap_or_else(circuit::Value::unknown) - }); - let scalar = self.loader.assign_scalar(scalar); - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result, Error> { - let ec_point = self.stream.as_mut().and_then(|stream| { - let mut compressed = C::Repr::default(); - if stream.read_exact(compressed.as_mut()).is_err() { - return circuit::Value::unknown(); - } - Option::::from(C::from_bytes(&compressed)) - .map(circuit::Value::known) - .unwrap_or_else(circuit::Value::unknown) - }); - let ec_point = self.loader.assign_ec_point(ec_point); - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl< - C: CurveAffine, - S, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > - PoseidonTranscript< - C, - NativeLoader, - S, - Poseidon, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - pub fn new(stream: S) -> Self { - Self { - loader: NativeLoader, - stream, - buf: Poseidon::new(R_F, R_P), - rns: Rc::new(Rns::::construct()), - _marker: PhantomData, - } - } -} - -impl< - C: CurveAffine, - S, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript - for PoseidonTranscript< - C, - NativeLoader, - S, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn squeeze_challenge(&mut self) -> C::Scalar { - self.buf.squeeze() - } - - fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { - self.buf.update(&[*scalar]); - Ok(()) - } - - fn common_ec_point(&mut self, ec_point: &C::CurveExt) -> Result<(), Error> { - let coords: Coordinates = - Option::from(ec_point.to_affine().coordinates()).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Cannot write points at infinity to the transcript".to_string(), - ) - })?; - let x = Integer::from_fe(*coords.x(), self.rns.clone()); - let y = Integer::from_fe(*coords.y(), self.rns.clone()); - self.buf.update(&[x.native(), y.native()]); - Ok(()) - } -} - -impl< - C: CurveAffine, - R: Read, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const 
R_F: usize, - const R_P: usize, - > TranscriptRead - for PoseidonTranscript< - C, - NativeLoader, - R, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn read_scalar(&mut self) -> Result { - let mut data = ::Repr::default(); - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid scalar encoding in proof".to_string(), - ) - })?; - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result { - let mut data = C::Repr::default(); - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - let ec_point = Option::::from( - ::from_bytes(&data).map(|ec_point| ec_point.to_curve()), - ) - .ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid elliptic curve point encoding in proof".to_string(), - ) - })?; - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl< - C: CurveAffine, - W: Write, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > - PoseidonTranscript< - C, - NativeLoader, - W, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - pub fn stream_mut(&mut self) -> &mut W { - &mut self.stream - } - - pub fn finalize(self) -> W { - self.stream - } -} diff --git a/src/loader/native.rs b/src/loader/native.rs index d0ee9fb5..6bf9f7c4 100644 --- a/src/loader/native.rs +++ b/src/loader/native.rs @@ -1,4 +1,85 @@ -mod accumulation; -mod loader; +use crate::{ + loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, + util::arithmetic::{Curve, CurveAffine, FieldOps, PrimeField}, + Error, +}; +use lazy_static::lazy_static; +use std::fmt::Debug; -pub use loader::NativeLoader; +lazy_static! 
{ + pub static ref LOADER: NativeLoader = NativeLoader; +} + +#[derive(Clone, Debug)] +pub struct NativeLoader; + +impl LoadedEcPoint for C { + type Loader = NativeLoader; + + fn loader(&self) -> &NativeLoader { + &LOADER + } + + fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { + pairs + .into_iter() + .map(|(scalar, base)| base * scalar) + .reduce(|acc, value| acc + value) + .unwrap() + .to_affine() + } +} + +impl FieldOps for F { + fn invert(&self) -> Option { + self.invert().into() + } +} + +impl LoadedScalar for F { + type Loader = NativeLoader; + + fn loader(&self) -> &NativeLoader { + &LOADER + } +} + +impl EcPointLoader for NativeLoader { + type LoadedEcPoint = C; + + fn ec_point_load_const(&self, value: &C) -> Self::LoadedEcPoint { + *value + } + + fn ec_point_assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedEcPoint, + rhs: &Self::LoadedEcPoint, + ) -> Result<(), Error> { + lhs.eq(rhs) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + } +} + +impl ScalarLoader for NativeLoader { + type LoadedScalar = F; + + fn load_const(&self, value: &F) -> Self::LoadedScalar { + *value + } + + fn assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedScalar, + rhs: &Self::LoadedScalar, + ) -> Result<(), Error> { + lhs.eq(rhs) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + } +} + +impl Loader for NativeLoader {} diff --git a/src/loader/native/accumulation.rs b/src/loader/native/accumulation.rs deleted file mode 100644 index d0b67dc9..00000000 --- a/src/loader/native/accumulation.rs +++ /dev/null @@ -1,111 +0,0 @@ -use crate::{ - loader::native::NativeLoader, - protocol::Protocol, - scheme::kzg::{AccumulationStrategy, Accumulator, SameCurveAccumulation, MSM}, - util::{fe_from_limbs, Curve, Group, Itertools, PrimeCurveAffine, Transcript}, - Error, -}; -use halo2_curves::{ - pairing::{MillerLoopResult, MultiMillerLoop}, - CurveAffine, CurveExt, -}; - -impl - SameCurveAccumulation -{ - pub fn finalize(self, g1: C) -> (C, C) { - self.accumulator.unwrap().evaluate(g1) - } -} - -impl - SameCurveAccumulation -{ - pub fn decide>( - self, - g1: M::G1Affine, - g2: M::G2Affine, - s_g2: M::G2Affine, - ) -> bool { - let (lhs, rhs) = self.finalize(g1.to_curve()); - - let g2 = M::G2Prepared::from(g2); - let minus_s_g2 = M::G2Prepared::from(-s_g2); - - let terms = [(&lhs.into(), &g2), (&rhs.into(), &minus_s_g2)]; - M::multi_miller_loop(&terms) - .final_exponentiation() - .is_identity() - .into() - } -} - -impl AccumulationStrategy - for SameCurveAccumulation -where - C: CurveExt, - T: Transcript, -{ - type Output = P; - - fn extract_accumulator( - &self, - protocol: &Protocol, - _: &NativeLoader, - transcript: &mut T, - statements: &[Vec], - ) -> Option> { - let accumulator_indices = protocol.accumulator_indices.as_ref()?; - - let challenges = transcript.squeeze_n_challenges(accumulator_indices.len()); - let accumulators = accumulator_indices - .iter() - .map(|indices| { - assert_eq!(indices.len(), 4 * LIMBS); - let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = indices - .chunks(4) - .into_iter() - .map(|indices| { - fe_from_limbs::<_, _, LIMBS, BITS>( - indices - .iter() - .map(|index| statements[index.0][index.1]) - .collect_vec() - .try_into() - .unwrap(), - ) - }) - .collect_vec() - .try_into() - .unwrap(); - let lhs = ::from_xy(lhs_x, lhs_y) - .unwrap() - .to_curve(); - let rhs = ::from_xy(rhs_x, rhs_y) - .unwrap() - .to_curve(); - Accumulator::new(MSM::base(lhs), MSM::base(rhs)) - }) - .collect_vec(); - - 
Some(Accumulator::random_linear_combine( - challenges.into_iter().zip(accumulators), - )) - } - - fn process( - &mut self, - _: &NativeLoader, - transcript: &mut T, - proof: P, - accumulator: Accumulator, - ) -> Result { - self.accumulator = Some(match self.accumulator.take() { - Some(curr_accumulator) => { - accumulator + curr_accumulator * &transcript.squeeze_challenge() - } - None => accumulator, - }); - Ok(proof) - } -} diff --git a/src/loader/native/loader.rs b/src/loader/native/loader.rs deleted file mode 100644 index 12bb475e..00000000 --- a/src/loader/native/loader.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::{ - loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{Curve, FieldOps, PrimeField}, -}; -use lazy_static::lazy_static; -use std::fmt::Debug; - -lazy_static! { - static ref LOADER: NativeLoader = NativeLoader; -} - -impl LoadedEcPoint for C { - type Loader = NativeLoader; - - fn loader(&self) -> &NativeLoader { - &LOADER - } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, base)| base * scalar) - .reduce(|acc, value| acc + value) - .unwrap() - } -} - -impl FieldOps for F { - fn invert(&self) -> Option { - self.invert().into() - } -} - -impl LoadedScalar for F { - type Loader = NativeLoader; - - fn loader(&self) -> &NativeLoader { - &LOADER - } -} - -#[derive(Clone, Debug)] -pub struct NativeLoader; - -impl EcPointLoader for NativeLoader { - type LoadedEcPoint = C; - - fn ec_point_load_const(&self, value: &C) -> Self::LoadedEcPoint { - *value - } -} - -impl ScalarLoader for NativeLoader { - type LoadedScalar = F; - - fn load_const(&self, value: &F) -> Self::LoadedScalar { - *value - } -} - -impl Loader for NativeLoader {} diff --git a/src/pcs.rs b/src/pcs.rs new file mode 100644 index 00000000..65804895 --- /dev/null +++ b/src/pcs.rs @@ -0,0 +1,138 @@ +use crate::{ + loader::{native::NativeLoader, Loader}, + util::{ + arithmetic::{CurveAffine, PrimeField}, + msm::Msm, + transcript::{TranscriptRead, TranscriptWrite}, + }, + Error, +}; +use rand::Rng; +use std::fmt::Debug; + +pub mod kzg; + +pub trait PolynomialCommitmentScheme: Clone + Debug +where + C: CurveAffine, + L: Loader, +{ + type Accumulator: Clone + Debug; +} + +#[derive(Clone, Debug)] +pub struct Query { + pub poly: usize, + pub shift: F, + pub eval: T, +} + +impl Query { + pub fn with_evaluation(self, eval: T) -> Query { + Query { + poly: self.poly, + shift: self.shift, + eval, + } + } +} + +pub trait MultiOpenScheme: PolynomialCommitmentScheme +where + C: CurveAffine, + L: Loader, +{ + type SuccinctVerifyingKey: Clone + Debug; + type Proof: Clone + Debug; + + fn read_proof( + svk: &Self::SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead; + + fn succinct_verify( + svk: &Self::SuccinctVerifyingKey, + commitments: &[Msm], + point: &L::LoadedScalar, + queries: &[Query], + proof: &Self::Proof, + ) -> Result; +} + +pub trait Decider: PolynomialCommitmentScheme +where + C: CurveAffine, + L: Loader, +{ + type DecidingKey: Clone + Debug; + type Output: Clone + Debug; + + fn decide(dk: &Self::DecidingKey, accumulator: Self::Accumulator) -> Self::Output; + + fn decide_all(dk: &Self::DecidingKey, accumulators: Vec) -> Self::Output; +} + +pub trait AccumulationScheme: Clone + Debug +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme, +{ + type VerifyingKey: Clone + Debug; + type Proof: Clone + Debug; + + fn read_proof( + vk: &Self::VerifyingKey, + instances: 
&[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead; + + fn verify( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + proof: &Self::Proof, + ) -> Result; +} + +pub trait AccumulationSchemeProver: AccumulationScheme +where + C: CurveAffine, + PCS: PolynomialCommitmentScheme, +{ + type ProvingKey: Clone + Debug; + + fn create_proof( + pk: &Self::ProvingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + rng: R, + ) -> Result + where + T: TranscriptWrite, + R: Rng; +} + +pub trait AccumulatorEncoding: Clone + Debug +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme, +{ + fn from_repr(repr: Vec) -> Result; +} + +impl AccumulatorEncoding for () +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme, +{ + fn from_repr(_: Vec) -> Result { + unimplemented!() + } +} diff --git a/src/pcs/kzg.rs b/src/pcs/kzg.rs new file mode 100644 index 00000000..9f10bd44 --- /dev/null +++ b/src/pcs/kzg.rs @@ -0,0 +1,45 @@ +use crate::{ + loader::Loader, + pcs::PolynomialCommitmentScheme, + util::arithmetic::{CurveAffine, MultiMillerLoop}, +}; +use std::{fmt::Debug, marker::PhantomData}; + +mod accumulation; +mod accumulator; +mod decider; +mod multiopen; + +pub use accumulation::{KzgAs, KzgAsProvingKey, KzgAsVerifyingKey}; +pub use accumulator::{KzgAccumulator, LimbsEncoding}; +pub use decider::KzgDecidingKey; +pub use multiopen::{Bdfg21, Bdfg21Proof, Gwc19, Gwc19Proof}; + +#[derive(Clone, Debug)] +pub struct Kzg(PhantomData<(M, MOS)>); + +impl PolynomialCommitmentScheme for Kzg +where + M: MultiMillerLoop, + L: Loader, + MOS: Clone + Debug, +{ + type Accumulator = KzgAccumulator; +} + +#[derive(Clone, Copy, Debug)] +pub struct KzgSuccinctVerifyingKey { + pub g: C, +} + +impl KzgSuccinctVerifyingKey { + pub fn new(g: C) -> Self { + Self { g } + } +} + +impl From for KzgSuccinctVerifyingKey { + fn from(g: C) -> KzgSuccinctVerifyingKey { + KzgSuccinctVerifyingKey::new(g) + } +} diff --git a/src/pcs/kzg/accumulation.rs b/src/pcs/kzg/accumulation.rs new file mode 100644 index 00000000..cd13fd00 --- /dev/null +++ b/src/pcs/kzg/accumulation.rs @@ -0,0 +1,196 @@ +use crate::{ + loader::{native::NativeLoader, LoadedScalar, Loader}, + pcs::{ + kzg::KzgAccumulator, AccumulationScheme, AccumulationSchemeProver, + PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{Curve, CurveAffine, Field}, + msm::Msm, + transcript::{TranscriptRead, TranscriptWrite}, + }, + Error, +}; +use rand::Rng; +use std::marker::PhantomData; + +#[derive(Clone, Debug)] +pub struct KzgAs(PhantomData); + +impl AccumulationScheme for KzgAs +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + type VerifyingKey = KzgAsVerifyingKey; + type Proof = KzgAsProof; + + fn read_proof( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + KzgAsProof::read(vk, instances, transcript) + } + + fn verify( + _: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + proof: &Self::Proof, + ) -> Result { + let (lhs, rhs) = instances + .iter() + .cloned() + .map(|accumulator| (accumulator.lhs, accumulator.rhs)) + .chain(proof.blind.clone()) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let powers_of_r = proof.r.powers(lhs.len()); + let [lhs, rhs] = [lhs, rhs].map(|msms| { + msms.into_iter() + .zip(powers_of_r.iter()) + .map(|(msm, r)| Msm::::base(msm) * r) + .sum::>() + .evaluate(None) + }); + + Ok(KzgAccumulator::new(lhs, rhs)) + } +} + +#[derive(Clone, Copy, Debug, Default)] 
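+/// Proving key of the KZG accumulation scheme. `Some((g, s_g))` enables
+/// zero knowledge: `create_proof` samples a random scalar `s`, writes the
+/// blinding pair `(s_g * s, g * s)` to the transcript, and folds it together
+/// with the input accumulators using powers of the squeezed challenge `r`.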
+pub struct KzgAsProvingKey(pub Option<(C, C)>); + +impl KzgAsProvingKey { + pub fn new(g: Option<(C, C)>) -> Self { + Self(g) + } + + pub fn zk(&self) -> bool { + self.0.is_some() + } + + pub fn vk(&self) -> KzgAsVerifyingKey { + KzgAsVerifyingKey(self.zk()) + } +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct KzgAsVerifyingKey(bool); + +impl KzgAsVerifyingKey { + pub fn zk(&self) -> bool { + self.0 + } +} + +#[derive(Clone, Debug)] +pub struct KzgAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + blind: Option<(L::LoadedEcPoint, L::LoadedEcPoint)>, + r: L::LoadedScalar, + _marker: PhantomData, +} + +impl KzgAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + fn read( + vk: &KzgAsVerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + assert!(!instances.is_empty()); + + for accumulator in instances { + transcript.common_ec_point(&accumulator.lhs)?; + transcript.common_ec_point(&accumulator.rhs)?; + } + + let blind = vk + .zk() + .then(|| Ok((transcript.read_ec_point()?, transcript.read_ec_point()?))) + .transpose()?; + + let r = transcript.squeeze_challenge(); + + Ok(Self { + blind, + r, + _marker: PhantomData, + }) + } +} + +impl AccumulationSchemeProver for KzgAs +where + C: CurveAffine, + PCS: PolynomialCommitmentScheme>, +{ + type ProvingKey = KzgAsProvingKey; + + fn create_proof( + pk: &Self::ProvingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + rng: R, + ) -> Result + where + T: TranscriptWrite, + R: Rng, + { + assert!(!instances.is_empty()); + + for accumulator in instances { + transcript.common_ec_point(&accumulator.lhs)?; + transcript.common_ec_point(&accumulator.rhs)?; + } + + let blind = pk + .zk() + .then(|| { + let s = C::Scalar::random(rng); + let (g, s_g) = pk.0.unwrap(); + let lhs = (s_g * s).to_affine(); + let rhs = (g * s).to_affine(); + transcript.write_ec_point(lhs)?; + transcript.write_ec_point(rhs)?; + Ok((lhs, rhs)) + }) + .transpose()?; + + let r = transcript.squeeze_challenge(); + + let (lhs, rhs) = instances + .iter() + .cloned() + .map(|accumulator| (accumulator.lhs, accumulator.rhs)) + .chain(blind) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let powers_of_r = r.powers(lhs.len()); + let [lhs, rhs] = [lhs, rhs].map(|msms| { + msms.into_iter() + .zip(powers_of_r.iter()) + .map(|(msm, power_of_r)| Msm::::base(msm) * power_of_r) + .sum::>() + .evaluate(None) + }); + + Ok(KzgAccumulator::new(lhs, rhs)) + } +} diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs new file mode 100644 index 00000000..17c7bf83 --- /dev/null +++ b/src/pcs/kzg/accumulator.rs @@ -0,0 +1,208 @@ +use crate::{loader::Loader, util::arithmetic::CurveAffine}; +use std::fmt::Debug; + +#[derive(Clone, Debug)] +pub struct KzgAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub lhs: L::LoadedEcPoint, + pub rhs: L::LoadedEcPoint, +} + +impl KzgAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub fn new(lhs: L::LoadedEcPoint, rhs: L::LoadedEcPoint) -> Self { + Self { lhs, rhs } + } +} + +/// `AccumulatorEncoding` that encodes `Accumulator` into limbs. +/// +/// Since in circuit everything are in scalar field, but `Accumulator` might contain base field elements, so we split them into limbs. +/// The const generic `LIMBS` and `BITS` respectively represents how many limbs +/// a base field element are split into and how many bits each limbs could have. 
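+///
+/// As an illustration (these concrete numbers are not fixed by this type):
+/// with `LIMBS = 4` and `BITS = 68`, each base-field coordinate `x` is
+/// carried as four little-endian limbs `x_0, .., x_3` of at most 68 bits,
+/// i.e. `x = x_0 + x_1 * 2^68 + x_2 * 2^136 + x_3 * 2^204`, and a whole
+/// accumulator `(lhs, rhs)` occupies `4 * LIMBS` scalars: `LIMBS` limbs for
+/// each of the x and y coordinates of the two points.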
+#[derive(Clone, Debug)] +pub struct LimbsEncoding; + +mod native { + use crate::{ + loader::native::NativeLoader, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{fe_from_limbs, CurveAffine}, + Itertools, + }, + Error, + }; + + impl AccumulatorEncoding + for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + NativeLoader, + Accumulator = KzgAccumulator, + >, + { + fn from_repr(limbs: Vec) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs + .chunks(LIMBS) + .into_iter() + .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + .collect_vec() + .try_into() + .unwrap(); + let accumulator = KzgAccumulator::new( + C::from_xy(lhs_x, lhs_y).unwrap(), + C::from_xy(rhs_x, rhs_y).unwrap(), + ); + + Ok(accumulator) + } + } +} + +#[cfg(feature = "loader_evm")] +mod evm { + use crate::{ + loader::evm::{EvmLoader, Scalar}, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{CurveAffine, PrimeField}, + Itertools, + }, + Error, + }; + use std::rc::Rc; + + impl AccumulatorEncoding, PCS> + for LimbsEncoding + where + C: CurveAffine, + C::Scalar: PrimeField, + PCS: PolynomialCommitmentScheme< + C, + Rc, + Accumulator = KzgAccumulator>, + >, + { + fn from_repr(limbs: Vec) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let loader = limbs[0].loader(); + + let [lhs_x, lhs_y, rhs_x, rhs_y]: [[_; LIMBS]; 4] = limbs + .chunks(LIMBS) + .into_iter() + .map(|limbs| limbs.to_vec().try_into().unwrap()) + .collect_vec() + .try_into() + .unwrap(); + let accumulator = KzgAccumulator::new( + loader.ec_point_from_limbs::(lhs_x, lhs_y), + loader.ec_point_from_limbs::(rhs_x, rhs_y), + ); + + Ok(accumulator) + } + } +} + +#[cfg(feature = "loader_halo2")] +mod halo2 { + use crate::{ + loader::halo2::{Context, EccInstructions, Halo2Loader, Scalar, Valuetools}, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{fe_from_limbs, CurveAffine}, + Itertools, + }, + Error, + }; + use halo2_proofs::circuit::Value; + use halo2_wrong_ecc::{maingate::AssignedValue, AssignedPoint}; + use std::{iter, rc::Rc}; + + fn ec_point_from_assigned_limbs( + limbs: &[AssignedValue], + ) -> Value { + assert_eq!(limbs.len(), 2 * LIMBS); + + let [x, y] = [&limbs[..LIMBS], &limbs[LIMBS..]].map(|limbs| { + limbs + .iter() + .map(|assigned| assigned.value()) + .fold_zipped(Vec::new(), |mut acc, limb| { + acc.push(*limb); + acc + }) + .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + }); + + x.zip(y).map(|(x, y)| C::from_xy(x, y).unwrap()) + } + + impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> + AccumulatorEncoding>, PCS> for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + Rc>, + Accumulator = KzgAccumulator>>, + >, + EccChip: EccInstructions< + 'a, + C, + AssignedEcPoint = AssignedPoint<::Base, C::Scalar, LIMBS, BITS>, + AssignedScalar = AssignedValue, + >, + { + fn from_repr(limbs: Vec>) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let loader = limbs[0].loader(); + + let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); + let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( + |assigned_limbs| { + let ec_point = ec_point_from_assigned_limbs::<_, LIMBS, BITS>(assigned_limbs); 
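+                    // Assign the point recovered from its limbs as a fresh witness; the
+                    // loop after this closure constrains the original limb cells to equal
+                    // the limbs of this newly assigned point.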
+ loader.assign_ec_point(ec_point) + }, + ); + + for (src, dst) in assigned_limbs.iter().zip( + iter::empty() + .chain(lhs.assigned().x().limbs()) + .chain(lhs.assigned().y().limbs()) + .chain(rhs.assigned().x().limbs()) + .chain(rhs.assigned().y().limbs()), + ) { + loader + .ctx_mut() + .constrain_equal(src.cell(), dst.as_ref().cell()) + .unwrap(); + } + let accumulator = KzgAccumulator::new(lhs, rhs); + + Ok(accumulator) + } + } +} diff --git a/src/pcs/kzg/decider.rs b/src/pcs/kzg/decider.rs new file mode 100644 index 00000000..b6957883 --- /dev/null +++ b/src/pcs/kzg/decider.rs @@ -0,0 +1,162 @@ +use crate::util::arithmetic::MultiMillerLoop; +use std::marker::PhantomData; + +#[derive(Debug, Clone, Copy)] +pub struct KzgDecidingKey { + pub g2: M::G2Affine, + pub s_g2: M::G2Affine, + _marker: PhantomData, +} + +impl KzgDecidingKey { + pub fn new(g2: M::G2Affine, s_g2: M::G2Affine) -> Self { + Self { + g2, + s_g2, + _marker: PhantomData, + } + } +} + +impl From<(M::G2Affine, M::G2Affine)> for KzgDecidingKey { + fn from((g2, s_g2): (M::G2Affine, M::G2Affine)) -> KzgDecidingKey { + KzgDecidingKey::new(g2, s_g2) + } +} + +mod native { + use crate::{ + loader::native::NativeLoader, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgDecidingKey}, + Decider, + }, + util::arithmetic::{Group, MillerLoopResult, MultiMillerLoop}, + }; + use std::fmt::Debug; + + impl Decider for Kzg + where + M: MultiMillerLoop, + MOS: Clone + Debug, + { + type DecidingKey = KzgDecidingKey; + type Output = bool; + + fn decide( + dk: &Self::DecidingKey, + KzgAccumulator { lhs, rhs }: KzgAccumulator, + ) -> bool { + let terms = [(&lhs, &dk.g2.into()), (&rhs, &(-dk.s_g2).into())]; + M::multi_miller_loop(&terms) + .final_exponentiation() + .is_identity() + .into() + } + + fn decide_all( + dk: &Self::DecidingKey, + accumulators: Vec>, + ) -> bool { + !accumulators + .into_iter() + .any(|accumulator| !Self::decide(dk, accumulator)) + } + } +} + +#[cfg(feature = "loader_evm")] +mod evm { + use crate::{ + loader::{ + evm::{loader::Value, EvmLoader}, + LoadedScalar, + }, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgDecidingKey}, + Decider, + }, + util::{ + arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, + msm::Msm, + }, + }; + use ethereum_types::U256; + use std::{fmt::Debug, rc::Rc}; + + impl Decider> for Kzg + where + M: MultiMillerLoop, + M::Scalar: PrimeField, + MOS: Clone + Debug, + { + type DecidingKey = KzgDecidingKey; + type Output = (); + + fn decide( + dk: &Self::DecidingKey, + KzgAccumulator { lhs, rhs }: KzgAccumulator>, + ) { + let loader = lhs.loader(); + let [g2, minus_s_g2] = [dk.g2, -dk.s_g2].map(|ec_point| { + let coordinates = ec_point.coordinates().unwrap(); + let x = coordinates.x().to_repr(); + let y = coordinates.y().to_repr(); + ( + U256::from_little_endian(&x.as_ref()[32..]), + U256::from_little_endian(&x.as_ref()[..32]), + U256::from_little_endian(&y.as_ref()[32..]), + U256::from_little_endian(&y.as_ref()[..32]), + ) + }); + loader.pairing(&lhs, g2, &rhs, minus_s_g2); + } + + fn decide_all( + dk: &Self::DecidingKey, + mut accumulators: Vec>>, + ) { + assert!(!accumulators.is_empty()); + + let accumulator = if accumulators.len() == 1 { + accumulators.pop().unwrap() + } else { + let loader = accumulators[0].lhs.loader(); + let (lhs, rhs) = accumulators + .iter() + .map(|KzgAccumulator { lhs, rhs }| { + let [lhs, rhs] = [&lhs, &rhs].map(|ec_point| loader.dup_ec_point(ec_point)); + (lhs, rhs) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let hash_ptr = loader.keccak256(lhs[0].ptr(), lhs.len() * 0x80); + let 
challenge_ptr = loader.allocate(0x20); + loader + .code_mut() + .push(loader.scalar_modulus()) + .push(hash_ptr) + .mload() + .r#mod() + .push(challenge_ptr) + .mstore(); + let challenge = loader.scalar(Value::Memory(challenge_ptr)); + + let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); + let [lhs, rhs] = [lhs, rhs].map(|msms| { + msms.into_iter() + .zip(powers_of_challenge.iter()) + .map(|(msm, power_of_challenge)| { + Msm::>::base(msm) * power_of_challenge + }) + .sum::>() + .evaluate(None) + }); + + KzgAccumulator::new(lhs, rhs) + }; + + Self::decide(dk, accumulator) + } + } +} diff --git a/src/pcs/kzg/multiopen.rs b/src/pcs/kzg/multiopen.rs new file mode 100644 index 00000000..d3e50e62 --- /dev/null +++ b/src/pcs/kzg/multiopen.rs @@ -0,0 +1,5 @@ +mod bdfg21; +mod gwc19; + +pub use bdfg21::{Bdfg21, Bdfg21Proof}; +pub use gwc19::{Gwc19, Gwc19Proof}; diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/src/pcs/kzg/multiopen/bdfg21.rs new file mode 100644 index 00000000..287700d7 --- /dev/null +++ b/src/pcs/kzg/multiopen/bdfg21.rs @@ -0,0 +1,381 @@ +use crate::{ + cost::{Cost, CostEstimation}, + loader::{LoadedScalar, Loader, ScalarLoader}, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgSuccinctVerifyingKey}, + MultiOpenScheme, Query, + }, + util::{ + arithmetic::{ilog2, CurveAffine, FieldExt, Fraction, MultiMillerLoop}, + msm::Msm, + transcript::TranscriptRead, + Itertools, + }, + Error, +}; +use std::{ + collections::{BTreeMap, BTreeSet}, + marker::PhantomData, +}; + +#[derive(Clone, Debug)] +pub struct Bdfg21; + +impl MultiOpenScheme for Kzg +where + M: MultiMillerLoop, + L: Loader, +{ + type SuccinctVerifyingKey = KzgSuccinctVerifyingKey; + type Proof = Bdfg21Proof; + + fn read_proof( + _: &KzgSuccinctVerifyingKey, + _: &[Query], + transcript: &mut T, + ) -> Result, Error> + where + T: TranscriptRead, + { + Bdfg21Proof::read(transcript) + } + + fn succinct_verify( + svk: &KzgSuccinctVerifyingKey, + commitments: &[Msm], + z: &L::LoadedScalar, + queries: &[Query], + proof: &Bdfg21Proof, + ) -> Result { + let f = { + let sets = query_sets(queries); + let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); + + let powers_of_mu = proof + .mu + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let msms = sets + .iter() + .zip(coeffs.iter()) + .map(|(set, coeff)| set.msm(coeff, commitments, &powers_of_mu)); + + msms.zip(proof.gamma.powers(sets.len()).into_iter()) + .map(|(msm, power_of_gamma)| msm * &power_of_gamma) + .sum::>() + - Msm::base(proof.w.clone()) * &coeffs[0].z_s + }; + + let rhs = Msm::base(proof.w_prime.clone()); + let lhs = f + rhs.clone() * &proof.z_prime; + + Ok(KzgAccumulator::new( + lhs.evaluate(Some(svk.g)), + rhs.evaluate(Some(svk.g)), + )) + } +} + +#[derive(Clone, Debug)] +pub struct Bdfg21Proof +where + C: CurveAffine, + L: Loader, +{ + mu: L::LoadedScalar, + gamma: L::LoadedScalar, + w: L::LoadedEcPoint, + z_prime: L::LoadedScalar, + w_prime: L::LoadedEcPoint, +} + +impl Bdfg21Proof +where + C: CurveAffine, + L: Loader, +{ + fn read>(transcript: &mut T) -> Result { + let mu = transcript.squeeze_challenge(); + let gamma = transcript.squeeze_challenge(); + let w = transcript.read_ec_point()?; + let z_prime = transcript.squeeze_challenge(); + let w_prime = transcript.read_ec_point()?; + Ok(Bdfg21Proof { + mu, + gamma, + w, + z_prime, + w_prime, + }) + } +} + +fn query_sets(queries: &[Query]) -> Vec> { + let poly_shifts = queries.iter().fold( + Vec::<(usize, Vec, Vec<&T>)>::new(), + |mut poly_shifts, query| { + if let Some(pos) = poly_shifts + 
.iter() + .position(|(poly, _, _)| *poly == query.poly) + { + let (_, shifts, evals) = &mut poly_shifts[pos]; + if !shifts.contains(&query.shift) { + shifts.push(query.shift); + evals.push(&query.eval); + } + } else { + poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + } + poly_shifts + }, + ); + + poly_shifts.into_iter().fold( + Vec::>::new(), + |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx].clone() + }) + .collect(), + ); + } + } else { + let set = QuerySet { + shifts, + polys: vec![poly], + evals: vec![evals.into_iter().cloned().collect()], + }; + sets.push(set); + } + sets + }, + ) +} + +fn query_set_coeffs>( + sets: &[QuerySet], + z: &T, + z_prime: &T, +) -> Vec> { + let loader = z.loader(); + + let superset = sets + .iter() + .flat_map(|set| set.shifts.clone()) + .sorted() + .dedup(); + + let size = 2.max( + ilog2((sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two()) + 1, + ); + let powers_of_z = z.powers(size); + let z_prime_minus_z_shift_i = BTreeMap::from_iter(superset.map(|shift| { + ( + shift, + z_prime.clone() - z.clone() * loader.load_const(&shift), + ) + })); + + let mut z_s_1 = None; + let mut coeffs = sets + .iter() + .map(|set| { + let coeff = QuerySetCoeff::new( + &set.shifts, + &powers_of_z, + z_prime, + &z_prime_minus_z_shift_i, + &z_s_1, + ); + if z_s_1.is_none() { + z_s_1 = Some(coeff.z_s.clone()); + }; + coeff + }) + .collect_vec(); + + T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + coeffs.iter_mut().for_each(QuerySetCoeff::evaluate); + + coeffs +} + +#[derive(Clone, Debug)] +struct QuerySet { + shifts: Vec, + polys: Vec, + evals: Vec>, +} + +impl> QuerySet { + fn msm>( + &self, + coeff: &QuerySetCoeff, + commitments: &[Msm], + powers_of_mu: &[T], + ) -> Msm { + self.polys + .iter() + .zip(self.evals.iter()) + .zip(powers_of_mu.iter()) + .map(|((poly, evals), power_of_mu)| { + let loader = power_of_mu.loader(); + let commitment = coeff + .commitment_coeff + .as_ref() + .map(|commitment_coeff| { + commitments[*poly].clone() * commitment_coeff.evaluated() + }) + .unwrap_or_else(|| commitments[*poly].clone()); + let r_eval = loader.sum_products( + &coeff + .eval_coeffs + .iter() + .zip(evals.iter()) + .map(|(coeff, eval)| (coeff.evaluated(), eval)) + .collect_vec(), + ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated(); + (commitment - Msm::constant(r_eval)) * power_of_mu + }) + .sum() + } +} + +#[derive(Clone, Debug)] +struct QuerySetCoeff { + z_s: T, + eval_coeffs: Vec>, + commitment_coeff: Option>, + r_eval_coeff: Option>, + _marker: PhantomData, +} + +impl QuerySetCoeff +where + F: FieldExt, + T: LoadedScalar, +{ + fn new( + shifts: &[F], + powers_of_z: &[T], + z_prime: &T, + z_prime_minus_z_shift_i: &BTreeMap, + z_s_1: &Option, + ) -> Self { + let loader = z_prime.loader(); + + let normalized_ell_primes = shifts + .iter() + .enumerate() + .map(|(j, shift_j)| { + shifts + .iter() + .enumerate() + .filter(|&(i, _)| i != j) + .map(|(_, shift_i)| (*shift_j - shift_i)) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| F::one()) + }) + .collect_vec(); + + let z = 
&powers_of_z[1].clone(); + let z_pow_k_minus_one = { + let k_minus_one = shifts.len() - 1; + powers_of_z + .iter() + .enumerate() + .skip(1) + .filter_map(|(i, power_of_z)| { + (k_minus_one & (1 << i) == 1).then(|| power_of_z.clone()) + }) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| loader.load_one()) + }; + + let barycentric_weights = shifts + .iter() + .zip(normalized_ell_primes.iter()) + .map(|(shift, normalized_ell_prime)| { + loader.sum_products_with_coeff(&[ + (*normalized_ell_prime, &z_pow_k_minus_one, z_prime), + (-(*normalized_ell_prime * shift), &z_pow_k_minus_one, z), + ]) + }) + .map(Fraction::one_over) + .collect_vec(); + + let z_s = loader.product( + &shifts + .iter() + .map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()) + .collect_vec(), + ); + let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); + + Self { + z_s, + eval_coeffs: barycentric_weights, + commitment_coeff: z_s_1_over_z_s, + r_eval_coeff: None, + _marker: PhantomData, + } + } + + fn denoms(&mut self) -> impl IntoIterator { + if self.eval_coeffs.first().unwrap().denom().is_some() { + return self + .eval_coeffs + .iter_mut() + .chain(self.commitment_coeff.as_mut()) + .filter_map(Fraction::denom_mut) + .collect_vec(); + } + + if self.r_eval_coeff.is_none() { + let loader = self.z_s.loader(); + self.eval_coeffs + .iter_mut() + .chain(self.commitment_coeff.as_mut()) + .for_each(Fraction::evaluate); + let barycentric_weights_sum = loader.sum( + &self + .eval_coeffs + .iter() + .map(Fraction::evaluated) + .collect_vec(), + ); + self.r_eval_coeff = Some(match self.commitment_coeff.clone() { + Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), + None => Fraction::one_over(barycentric_weights_sum), + }); + return vec![self.r_eval_coeff.as_mut().unwrap().denom_mut().unwrap()]; + } + + unreachable!() + } + + fn evaluate(&mut self) { + self.r_eval_coeff.as_mut().unwrap().evaluate(); + } +} + +impl CostEstimation for Kzg +where + M: MultiMillerLoop, +{ + type Input = Vec>; + + fn estimate_cost(_: &Vec>) -> Cost { + Cost::new(0, 2, 0, 2) + } +} diff --git a/src/pcs/kzg/multiopen/gwc19.rs b/src/pcs/kzg/multiopen/gwc19.rs new file mode 100644 index 00000000..121fce8a --- /dev/null +++ b/src/pcs/kzg/multiopen/gwc19.rs @@ -0,0 +1,167 @@ +use crate::{ + cost::{Cost, CostEstimation}, + loader::{LoadedScalar, Loader}, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgSuccinctVerifyingKey}, + MultiOpenScheme, Query, + }, + util::{ + arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, + msm::Msm, + transcript::TranscriptRead, + Itertools, + }, + Error, +}; + +#[derive(Clone, Debug)] +pub struct Gwc19; + +impl MultiOpenScheme for Kzg +where + M: MultiMillerLoop, + L: Loader, +{ + type SuccinctVerifyingKey = KzgSuccinctVerifyingKey; + type Proof = Gwc19Proof; + + fn read_proof( + _: &Self::SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + Gwc19Proof::read(queries, transcript) + } + + fn succinct_verify( + svk: &Self::SuccinctVerifyingKey, + commitments: &[Msm], + z: &L::LoadedScalar, + queries: &[Query], + proof: &Self::Proof, + ) -> Result { + let sets = query_sets(queries); + let powers_of_u = &proof.u.powers(sets.len()); + let f = { + let powers_of_v = proof + .v + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + sets.iter() + .map(|set| set.msm(commitments, &powers_of_v)) + .zip(powers_of_u.iter()) + .map(|(msm, power_of_u)| msm * power_of_u) + .sum::>() + }; + let z_omegas = sets 
+ .iter() + .map(|set| z.clone() * &z.loader().load_const(&set.shift)); + + let rhs = proof + .ws + .iter() + .zip(powers_of_u.iter()) + .map(|(w, power_of_u)| Msm::base(w.clone()) * power_of_u) + .collect_vec(); + let lhs = f + rhs + .iter() + .zip(z_omegas) + .map(|(uw, z_omega)| uw.clone() * &z_omega) + .sum(); + + Ok(KzgAccumulator::new( + lhs.evaluate(Some(svk.g)), + rhs.into_iter().sum::>().evaluate(Some(svk.g)), + )) + } +} + +#[derive(Clone, Debug)] +pub struct Gwc19Proof +where + C: CurveAffine, + L: Loader, +{ + v: L::LoadedScalar, + ws: Vec, + u: L::LoadedScalar, +} + +impl Gwc19Proof +where + C: CurveAffine, + L: Loader, +{ + fn read(queries: &[Query], transcript: &mut T) -> Result + where + T: TranscriptRead, + { + let v = transcript.squeeze_challenge(); + let ws = transcript.read_n_ec_points(query_sets(queries).len())?; + let u = transcript.squeeze_challenge(); + Ok(Gwc19Proof { v, ws, u }) + } +} + +struct QuerySet { + shift: F, + polys: Vec, + evals: Vec, +} + +impl QuerySet +where + F: PrimeField, + T: Clone, +{ + fn msm>( + &self, + commitments: &[Msm], + powers_of_v: &[L::LoadedScalar], + ) -> Msm { + self.polys + .iter() + .zip(self.evals.iter()) + .map(|(poly, eval)| { + let commitment = commitments[*poly].clone(); + commitment - Msm::constant(eval.clone()) + }) + .zip(powers_of_v.iter()) + .map(|(msm, power_of_v)| msm * power_of_v) + .sum() + } +} + +fn query_sets(queries: &[Query]) -> Vec> +where + F: PrimeField, + T: Clone + PartialEq, +{ + queries.iter().fold(Vec::new(), |mut sets, query| { + if let Some(pos) = sets.iter().position(|set| set.shift == query.shift) { + sets[pos].polys.push(query.poly); + sets[pos].evals.push(query.eval.clone()); + } else { + sets.push(QuerySet { + shift: query.shift, + polys: vec![query.poly], + evals: vec![query.eval.clone()], + }); + } + sets + }) +} + +impl CostEstimation for Kzg +where + M: MultiMillerLoop, +{ + type Input = Vec>; + + fn estimate_cost(queries: &Vec>) -> Cost { + let num_w = query_sets(queries).len(); + Cost::new(0, num_w, 0, num_w) + } +} diff --git a/src/protocol.rs b/src/protocol.rs deleted file mode 100644 index af591b78..00000000 --- a/src/protocol.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::util::{Curve, Domain, Expression, Group, Itertools, Query}; - -#[cfg(feature = "halo2")] -pub mod halo2; - -#[derive(Clone, Debug)] -pub struct Protocol { - pub zk: bool, - pub domain: Domain, - pub preprocessed: Vec, - pub num_statement: Vec, - pub num_auxiliary: Vec, - pub num_challenge: Vec, - pub evaluations: Vec, - pub queries: Vec, - pub relations: Vec>, - pub transcript_initial_state: C::Scalar, - pub accumulator_indices: Option>>, -} - -impl Protocol { - pub fn vanishing_poly(&self) -> usize { - self.preprocessed.len() - + self.num_statement.len() - + self.num_auxiliary.iter().sum::() - } -} - -pub struct Snark { - pub protocol: Protocol, - pub statements: Vec::Scalar>>, - pub proof: Vec, -} - -impl Snark { - pub fn new( - protocol: Protocol, - statements: Vec::Scalar>>, - proof: Vec, - ) -> Self { - assert_eq!( - protocol.num_statement, - statements - .iter() - .map(|statements| statements.len()) - .collect_vec() - ); - Snark { - protocol, - statements, - proof, - } - } -} diff --git a/src/protocol/halo2/test.rs b/src/protocol/halo2/test.rs deleted file mode 100644 index cfbcefef..00000000 --- a/src/protocol/halo2/test.rs +++ /dev/null @@ -1,176 +0,0 @@ -use crate::{ - protocol::halo2::{compile, Config}, - scheme::kzg::{Cost, CostEstimation, PlonkAccumulationScheme}, - util::{CommonPolynomial, Expression, 
Query}, -}; -use halo2_curves::bn256::{Bn256, Fr, G1}; -use halo2_proofs::{ - arithmetic::FieldExt, - dev::MockProver, - plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey}, - poly::{ - commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier}, - kzg::commitment::KZGCommitmentScheme, - Rotation, VerificationStrategy, - }, - transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -use rand_chacha::{ - rand_core::{RngCore, SeedableRng}, - ChaCha20Rng, -}; -use std::assert_matches::assert_matches; - -mod circuit; -mod kzg; - -pub use circuit::{ - maingate::{ - MainGateWithPlookup, MainGateWithPlookupConfig, MainGateWithRange, MainGateWithRangeConfig, - }, - standard::StandardPlonk, -}; - -pub fn create_proof_checked<'a, S, C, P, V, VS, TW, TR, EC, R, const ZK: bool>( - params: &'a S::ParamsProver, - pk: &ProvingKey, - circuits: &[C], - instances: &[&[&[S::Scalar]]], - mut rng: R, -) -> Vec -where - S: CommitmentScheme, - S::ParamsVerifier: 'a, - C: Circuit, - P: Prover<'a, S>, - V: Verifier<'a, S>, - VS: VerificationStrategy<'a, S, V, Output = VS>, - TW: TranscriptWriterBuffer, S::Curve, EC>, - TR: TranscriptReadBuffer<&'static [u8], S::Curve, EC>, - EC: EncodedChallenge, - R: RngCore, -{ - for (circuit, instances) in circuits.iter().zip(instances.iter()) { - MockProver::run::<_, ZK>( - params.k(), - circuit, - instances.iter().map(|instance| instance.to_vec()).collect(), - ) - .unwrap() - .assert_satisfied(); - } - - let proof = { - let mut transcript = TW::init(Vec::new()); - create_proof::( - params, - pk, - circuits, - instances, - &mut rng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - - let accept = { - let params = params.verifier_params(); - let strategy = VS::new(params); - let mut transcript = TR::init(Box::leak(Box::new(proof.clone()))); - verify_proof::<_, _, _, _, _, ZK>(params, pk.get_vk(), strategy, instances, &mut transcript) - .unwrap() - .finalize() - }; - assert!(accept); - - proof -} - -#[test] -fn test_compile_standard_plonk() { - let circuit = StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())); - - let params = kzg::read_or_create_srs::(9); - let vk = keygen_vk::, _, false>(¶ms, &circuit).unwrap(); - let pk = keygen_pk::, _, false>(¶ms, vk, &circuit).unwrap(); - - let protocol = compile::( - pk.get_vk(), - Config { - zk: false, - query_instance: false, - num_instance: vec![1], - num_proof: 1, - accumulator_indices: None, - }, - ); - - let [q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, instance, a, b, c, z] = - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12].map(|poly| Query::new(poly, Rotation::cur())); - let z_w = Query::new(12, Rotation::next()); - let t = Query::new(13, Rotation::cur()); - - assert_eq!(protocol.preprocessed.len(), 8); - assert_eq!(protocol.num_statement, vec![1]); - assert_eq!(protocol.num_auxiliary, vec![3, 0, 1]); - assert_eq!(protocol.num_challenge, vec![1, 2, 0]); - assert_eq!( - protocol.evaluations, - vec![a, b, c, q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, z, z_w] - ); - assert_eq!( - protocol.queries, - vec![a, b, c, z, z_w, q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, t] - ); - assert_eq!( - format!("{:?}", protocol.relations), - format!("{:?}", { - let [q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, instance, a, b, c, z, z_w, beta, gamma, l_0, identity, one, k_1, k_2] = - &[ - Expression::Polynomial(q_a), - Expression::Polynomial(q_b), - Expression::Polynomial(q_c), - Expression::Polynomial(q_ab), 
- Expression::Polynomial(constant), - Expression::Polynomial(sigma_a), - Expression::Polynomial(sigma_b), - Expression::Polynomial(sigma_c), - Expression::Polynomial(instance), - Expression::Polynomial(a), - Expression::Polynomial(b), - Expression::Polynomial(c), - Expression::Polynomial(z), - Expression::Polynomial(z_w), - Expression::Challenge(1), // beta - Expression::Challenge(2), // gamma - Expression::CommonPolynomial(CommonPolynomial::Lagrange(0)), // l_0 - Expression::CommonPolynomial(CommonPolynomial::Identity), // identity - Expression::Constant(Fr::one()), // one - Expression::Constant(Fr::DELTA), // k_1 - Expression::Constant(Fr::DELTA * Fr::DELTA), // k_2 - ]; - - vec![ - q_a * a + q_b * b + q_c * c + q_ab * a * b + constant + instance, - l_0 * (one - z), - z_w * ((a + beta * sigma_a + gamma) - * (b + beta * sigma_b + gamma) - * (c + beta * sigma_c + gamma)) - - z * ((a + beta * one * identity + gamma) - * (b + beta * k_1 * identity + gamma) - * (c + beta * k_2 * identity + gamma)), - ] - }) - ); - - assert_matches!( - PlonkAccumulationScheme::estimate_cost(&protocol), - Cost { - num_commitment: 9, - num_evaluation: 13, - num_msm: 20, - .. - } - ); -} diff --git a/src/protocol/halo2/test/circuit/maingate.rs b/src/protocol/halo2/test/circuit/maingate.rs deleted file mode 100644 index b03a3ff4..00000000 --- a/src/protocol/halo2/test/circuit/maingate.rs +++ /dev/null @@ -1,385 +0,0 @@ -use crate::{protocol::halo2::test::circuit::plookup::PlookupConfig, util::Itertools}; -use halo2_proofs::{ - arithmetic::FieldExt, - circuit::{floor_planner::V1, Chip, Layouter, Value}, - plonk::{Any, Circuit, Column, ConstraintSystem, Error, Fixed}, - poly::Rotation, -}; -use halo2_wrong_ecc::{ - maingate::{ - decompose, AssignedValue, MainGate, MainGateConfig, MainGateInstructions, RangeChip, - RangeConfig, RangeInstructions, RegionCtx, Term, - }, - EccConfig, -}; -use rand::RngCore; -use std::{collections::BTreeMap, iter}; - -#[derive(Clone)] -pub struct MainGateWithRangeConfig { - main_gate_config: MainGateConfig, - range_config: RangeConfig, -} - -impl MainGateWithRangeConfig { - pub fn configure( - meta: &mut ConstraintSystem, - composition_bits: Vec, - overflow_bits: Vec, - ) -> Self { - let main_gate_config = MainGate::::configure(meta); - let range_config = - RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); - MainGateWithRangeConfig { - main_gate_config, - range_config, - } - } - - pub fn ecc_config(&self) -> EccConfig { - EccConfig::new(self.range_config.clone(), self.main_gate_config.clone()) - } - - pub fn load_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - let range_chip = RangeChip::::new(self.range_config.clone()); - range_chip.load_table(layouter)?; - Ok(()) - } -} - -#[derive(Clone, Default)] -pub struct MainGateWithRange(Vec); - -impl MainGateWithRange { - pub fn new(inner: Vec) -> Self { - Self(inner) - } - - pub fn rand(mut rng: R) -> Self { - Self::new(vec![F::from(rng.next_u32() as u64)]) - } - - pub fn instances(&self) -> Vec> { - vec![self.0.clone()] - } -} - -impl Circuit for MainGateWithRange { - type Config = MainGateWithRangeConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self(vec![F::zero()]) - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - MainGateWithRangeConfig::configure(meta, vec![8], vec![4, 7]) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let main_gate = 
MainGate::new(config.main_gate_config); - let range_chip = RangeChip::new(config.range_config); - range_chip.load_table(&mut layouter)?; - - let a = layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - let mut ctx = RegionCtx::new(&mut region, &mut offset); - range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; - range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; - let a = range_chip.assign(&mut ctx, Value::known(self.0[0]), 8, 68)?; - let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; - let cond = main_gate.assign_bit(&mut ctx, Value::known(F::one()))?; - main_gate.select(&mut ctx, &a, &b, &cond)?; - - Ok(a) - }, - )?; - - main_gate.expose_public(layouter, a, 0)?; - - Ok(()) - } -} - -#[derive(Clone, Debug)] -pub struct PlookupRangeConfig { - main_gate_config: MainGateConfig, - plookup_config: PlookupConfig, - table: [Column; 2], - q_limb: [Column; 2], - q_overflow: [Column; 2], - bits: BTreeMap, -} - -pub struct PlookupRangeChip { - n: usize, - config: PlookupRangeConfig, - main_gate: MainGate, -} - -impl PlookupRangeChip { - pub fn new(config: PlookupRangeConfig, n: usize) -> Self { - let main_gate = MainGate::new(config.main_gate_config.clone()); - Self { - n, - config, - main_gate, - } - } - - pub fn configure( - meta: &mut ConstraintSystem, - main_gate_config: MainGateConfig, - bits: impl IntoIterator, - ) -> PlookupRangeConfig { - let table = [(); 2].map(|_| meta.fixed_column()); - let q_limb = [(); 2].map(|_| meta.fixed_column()); - let q_overflow = [(); 2].map(|_| meta.fixed_column()); - let plookup_config = PlookupConfig::configure( - meta, - |meta| { - let [a, b, c, d, _] = main_gate_config.advices(); - let limbs = [a, b, c, d].map(|column| meta.query_advice(column, Rotation::cur())); - let overflow = meta.query_advice(a, Rotation::cur()); - let q_limb = q_limb.map(|column| meta.query_fixed(column, Rotation::cur())); - let q_overflow = q_overflow.map(|column| meta.query_fixed(column, Rotation::cur())); - iter::empty() - .chain(limbs.into_iter().zip(iter::repeat(q_limb))) - .chain(Some((overflow, q_overflow))) - .map(|(value, [selector, tag])| [tag, selector * value]) - .collect() - }, - table.map(Column::::from), - None, - None, - None, - None, - ); - let bits = bits - .into_iter() - .sorted() - .dedup() - .enumerate() - .map(|(tag, bit)| (bit, tag)) - .collect(); - PlookupRangeConfig { - main_gate_config, - plookup_config, - table, - q_limb, - q_overflow, - bits, - } - } - - pub fn assign_inner(&self, layouter: impl Layouter, n: usize) -> Result<(), Error> { - self.config.plookup_config.assign(layouter, n) - } -} - -impl Chip for PlookupRangeChip { - type Config = PlookupRangeConfig; - - type Loaded = (); - - fn config(&self) -> &Self::Config { - &self.config - } - - fn loaded(&self) -> &Self::Loaded { - &() - } -} - -impl RangeInstructions for PlookupRangeChip { - fn assign( - &self, - ctx: &mut RegionCtx<'_, '_, F>, - value: Value, - limb_bit: usize, - bit: usize, - ) -> Result, Error> { - let (assigned, _) = self.decompose(ctx, value, limb_bit, bit)?; - Ok(assigned) - } - - fn decompose( - &self, - ctx: &mut RegionCtx<'_, '_, F>, - value: Value, - limb_bit: usize, - bit: usize, - ) -> Result<(AssignedValue, Vec>), Error> { - let (num_limbs, overflow) = (bit / limb_bit, bit % limb_bit); - - let num_limbs = num_limbs + if overflow > 0 { 1 } else { 0 }; - let terms = value - .map(|value| decompose(value, num_limbs, limb_bit)) - .transpose_vec(num_limbs) - .into_iter() - 
.zip((0..num_limbs).map(|i| F::from(2).pow(&[(limb_bit * i) as u64, 0, 0, 0]))) - .map(|(limb, base)| Term::Unassigned(limb, base)) - .collect_vec(); - - self.main_gate - .decompose(ctx, &terms, F::zero(), |ctx, is_last| { - ctx.assign_fixed(|| "", self.config.q_limb[0], F::one())?; - ctx.assign_fixed( - || "", - self.config.q_limb[1], - F::from(*self.config.bits.get(&limb_bit).unwrap() as u64), - )?; - if is_last && overflow != 0 { - ctx.assign_fixed(|| "", self.config.q_overflow[0], F::one())?; - ctx.assign_fixed( - || "", - self.config.q_overflow[1], - F::from(*self.config.bits.get(&limb_bit).unwrap() as u64), - )?; - } - Ok(()) - }) - } - - fn load_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - - for (bit, tag) in self.config.bits.iter() { - let tag = F::from(*tag as u64); - let table_values: Vec = (0..1 << bit).map(|e| F::from(e)).collect(); - for value in table_values.iter() { - region.assign_fixed( - || "table tag", - self.config.table[0], - offset, - || Value::known(tag), - )?; - region.assign_fixed( - || "table value", - self.config.table[1], - offset, - || Value::known(*value), - )?; - offset += 1; - } - } - - for offset in offset..self.n { - region.assign_fixed( - || "table tag", - self.config.table[0], - offset, - || Value::known(F::zero()), - )?; - region.assign_fixed( - || "table value", - self.config.table[1], - offset, - || Value::known(F::zero()), - )?; - } - - Ok(()) - }, - )?; - - Ok(()) - } -} - -#[derive(Clone)] -pub struct MainGateWithPlookupConfig { - main_gate_config: MainGateConfig, - plookup_range_config: PlookupRangeConfig, -} - -impl MainGateWithPlookupConfig { - pub fn configure( - meta: &mut ConstraintSystem, - bits: impl IntoIterator, - ) -> Self { - let main_gate_config = MainGate::configure(meta); - let plookup_range_config = - PlookupRangeChip::configure(meta, main_gate_config.clone(), bits); - - assert_eq!(meta.degree::(), 3); - - MainGateWithPlookupConfig { - main_gate_config, - plookup_range_config, - } - } -} - -#[derive(Clone, Default)] -pub struct MainGateWithPlookup { - n: usize, - inner: Vec, -} - -impl MainGateWithPlookup { - pub fn new(k: u32, inner: Vec) -> Self { - Self { n: 1 << k, inner } - } - - pub fn instances(&self) -> Vec> { - vec![self.inner.clone()] - } -} - -impl Circuit for MainGateWithPlookup { - type Config = MainGateWithPlookupConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - n: self.n, - inner: vec![F::zero()], - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - MainGateWithPlookupConfig::configure(meta, [1, 7, 8]) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let main_gate = MainGate::::new(config.main_gate_config.clone()); - let range_chip = PlookupRangeChip::new(config.plookup_range_config, self.n); - - range_chip.load_table(&mut layouter)?; - range_chip.assign_inner(layouter.namespace(|| ""), self.n)?; - - let a = layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - let mut ctx = RegionCtx::new(&mut region, &mut offset); - range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; - range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; - let a = range_chip.assign(&mut ctx, Value::known(self.inner[0]), 8, 68)?; - let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; - let cond = main_gate.assign_bit(&mut ctx, 
Value::known(F::one()))?; - main_gate.select(&mut ctx, &a, &b, &cond)?; - - Ok(a) - }, - )?; - - main_gate.expose_public(layouter, a, 0)?; - - Ok(()) - } -} diff --git a/src/protocol/halo2/test/circuit/plookup.rs b/src/protocol/halo2/test/circuit/plookup.rs deleted file mode 100644 index 4e05e076..00000000 --- a/src/protocol/halo2/test/circuit/plookup.rs +++ /dev/null @@ -1,947 +0,0 @@ -use crate::util::{BatchInvert, EitherOrBoth, Field, Itertools}; -use halo2_proofs::{ - arithmetic::FieldExt, - circuit::{Layouter, Value}, - plonk::{ - Advice, Any, Challenge, Column, ConstraintSystem, Error, Expression, FirstPhase, - SecondPhase, Selector, ThirdPhase, VirtualCells, - }, - poly::Rotation, -}; -use std::{collections::BTreeMap, convert::TryFrom, iter, ops::Mul}; - -fn query( - meta: &mut ConstraintSystem, - query_fn: impl FnOnce(&mut VirtualCells<'_, F>) -> T, -) -> T { - let mut tmp = None; - meta.create_gate("", |meta| { - tmp = Some(query_fn(meta)); - Some(Expression::Constant(F::zero())) - }); - tmp.unwrap() -} - -fn first_fit_packing(cap: usize, weights: Vec) -> Vec> { - let mut bins = Vec::<(usize, Vec)>::new(); - - weights.into_iter().enumerate().for_each(|(idx, weight)| { - for (remaining, indices) in bins.iter_mut() { - if *remaining >= weight { - *remaining -= weight; - indices.push(idx); - return; - } - } - bins.push((cap - weight, vec![idx])); - }); - - bins.into_iter().map(|(_, indices)| indices).collect() -} - -fn max_advice_phase(expression: &Expression) -> u8 { - expression.evaluate( - &|_| 0, - &|_| 0, - &|_| 0, - &|query| query.phase(), - &|_| 0, - &|_| 0, - &|a| a, - &|a, b| a.max(b), - &|a, b| a.max(b), - &|a, _| a, - ) -} - -fn min_challenge_phase(expression: &Expression) -> Option { - expression.evaluate( - &|_| None, - &|_| None, - &|_| None, - &|_| None, - &|_| None, - &|challenge| Some(challenge.phase()), - &|a| a, - &|a, b| match (a, b) { - (Some(a), Some(b)) => Some(a.min(b)), - (Some(phase), None) | (None, Some(phase)) => Some(phase), - (None, None) => None, - }, - &|a, b| match (a, b) { - (Some(a), Some(b)) => Some(a.min(b)), - (Some(phase), None) | (None, Some(phase)) => Some(phase), - (None, None) => None, - }, - &|a, _| a, - ) -} - -fn advice_column_in(meta: &mut ConstraintSystem, phase: u8) -> Column { - match phase { - 0 => meta.advice_column_in(FirstPhase), - 1 => meta.advice_column_in(SecondPhase), - 2 => meta.advice_column_in(ThirdPhase), - _ => unreachable!(), - } -} - -fn challenge_usable_after(meta: &mut ConstraintSystem, phase: u8) -> Challenge { - match phase { - 0 => meta.challenge_usable_after(FirstPhase), - 1 => meta.challenge_usable_after(SecondPhase), - 2 => meta.challenge_usable_after(ThirdPhase), - _ => unreachable!(), - } -} - -#[derive(Clone, Debug)] -pub struct ShuffleConfig { - l_0: Selector, - zs: Vec>, - gamma: Option, - lhs: Vec>, - rhs: Vec>, -} - -impl ShuffleConfig { - pub fn configure( - meta: &mut ConstraintSystem, - lhs: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec>, - rhs: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec>, - l_0: Option, - ) -> Self { - let (lhs, rhs, gamma) = { - let (lhs, rhs) = query(meta, |meta| { - let (lhs, rhs) = (lhs(meta), rhs(meta)); - assert_eq!(lhs.len(), rhs.len()); - (lhs, rhs) - }); - let phase = iter::empty() - .chain(lhs.iter()) - .chain(rhs.iter()) - .map(max_advice_phase) - .max() - .unwrap(); - - let gamma = challenge_usable_after(meta, phase); - - (lhs, rhs, gamma) - }; - let lhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let gamma = meta.query_challenge(gamma); - 
lhs.into_iter().zip(iter::repeat(gamma)).collect() - }; - let rhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let gamma = meta.query_challenge(gamma); - rhs.into_iter().zip(iter::repeat(gamma)).collect() - }; - let mut config = Self::configure_with_gamma( - meta, - lhs_with_gamma, - rhs_with_gamma, - |_| None, - |_| None, - l_0, - ); - config.gamma = Some(gamma); - config - } - - pub fn configure_with_gamma( - meta: &mut ConstraintSystem, - lhs_with_gamma: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec<(Expression, Expression)>, - rhs_with_gamma: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec<(Expression, Expression)>, - lhs_coeff: impl FnOnce(&mut VirtualCells<'_, F>) -> Option>, - rhs_coeff: impl FnOnce(&mut VirtualCells<'_, F>) -> Option>, - l_0: Option, - ) -> Self { - if ZK { - todo!() - } - - let (lhs_with_gamma, rhs_with_gamma, lhs_coeff, rhs_coeff) = query(meta, |meta| { - let lhs_with_gamma = lhs_with_gamma(meta); - let rhs_with_gamma = rhs_with_gamma(meta); - let lhs_coeff = lhs_coeff(meta); - let rhs_coeff = rhs_coeff(meta); - assert_eq!(lhs_with_gamma.len(), rhs_with_gamma.len()); - - (lhs_with_gamma, rhs_with_gamma, lhs_coeff, rhs_coeff) - }); - - let gamma_phase = iter::empty() - .chain(lhs_with_gamma.iter()) - .chain(rhs_with_gamma.iter()) - .map(|(value, _)| max_advice_phase(value)) - .max() - .unwrap(); - let z_phase = gamma_phase + 1; - assert!(!lhs_with_gamma - .iter() - .any(|(_, gamma)| gamma.degree() != 0 - || min_challenge_phase(gamma).unwrap() < gamma_phase)); - assert!(!rhs_with_gamma - .iter() - .any(|(_, gamma)| gamma.degree() != 0 - || min_challenge_phase(gamma).unwrap() < gamma_phase)); - - let [lhs_bins, rhs_bins] = [&lhs_with_gamma, &rhs_with_gamma].map(|value_with_gamma| { - first_fit_packing( - meta.degree::() - 1, - value_with_gamma - .iter() - .map(|(value, _)| value.degree()) - .collect(), - ) - }); - let num_z = lhs_bins.len().max(rhs_bins.len()); - - let l_0 = l_0.unwrap_or_else(|| meta.selector()); - let zs = iter::repeat_with(|| advice_column_in(meta, z_phase)) - .take(num_z) - .collect_vec(); - - let collect_contribution = |value_with_gamma: Vec<(Expression, Expression)>, - coeff: Option>, - bins: &[Vec]| { - let mut contribution = bins - .iter() - .map(|bin| { - bin.iter() - .map(|idx| value_with_gamma[*idx].clone()) - .map(|(value, gamma)| value + gamma) - .reduce(|acc, expr| acc * expr) - .unwrap() - }) - .collect_vec(); - - if let Some(coeff) = coeff { - contribution[0] = coeff * contribution[0].clone(); - } - - contribution - }; - let lhs = collect_contribution(lhs_with_gamma, lhs_coeff, &lhs_bins); - let rhs = collect_contribution(rhs_with_gamma, rhs_coeff, &rhs_bins); - - meta.create_gate("Shuffle", |meta| { - let l_0 = meta.query_selector(l_0); - let zs = iter::empty() - .chain(zs.iter().cloned().zip(iter::repeat(Rotation::cur()))) - .chain(Some((zs[0], Rotation::next()))) - .map(|(z, at)| meta.query_advice(z, at)) - .collect_vec(); - - let one = Expression::Constant(F::one()); - let z_0 = zs[0].clone(); - - iter::once(l_0 * (one - z_0)).chain( - lhs.clone() - .into_iter() - .zip_longest(rhs.clone()) - .zip(zs.clone().into_iter().zip(zs.into_iter().skip(1))) - .map(|(pair, (z_i, z_j))| match pair { - EitherOrBoth::Left(lhs) => z_i * lhs - z_j, - EitherOrBoth::Right(rhs) => z_i - z_j * rhs, - EitherOrBoth::Both(lhs, rhs) => z_i * lhs - z_j * rhs, - }), - ) - }); - - ShuffleConfig { - l_0, - zs, - gamma: None, - lhs, - rhs, - } - } - - pub fn assign(&self, mut layouter: impl Layouter, n: usize) -> Result<(), Error> { - if ZK { - todo!() - } 
- - let lhs = self - .lhs - .iter() - .map(|expression| layouter.evaluate_committed(expression)) - .fold(Value::known(Vec::new()), |acc, evaluated| { - acc.zip(evaluated).map(|(mut acc, evaluated)| { - acc.extend(evaluated); - acc - }) - }); - let rhs = self - .rhs - .iter() - .map(|expression| layouter.evaluate_committed(expression)) - .fold(Value::known(Vec::new()), |acc, evaluated| { - acc.zip(evaluated).map(|(mut acc, evaluated)| { - acc.extend(evaluated); - acc - }) - }); - - let z = lhs - .zip(rhs) - .map(|(lhs, mut rhs)| { - rhs.iter_mut().batch_invert(); - - let products = lhs - .into_iter() - .zip_longest(rhs) - .map(|pair| match pair { - EitherOrBoth::Left(value) | EitherOrBoth::Right(value) => value, - EitherOrBoth::Both(lhs, rhs) => lhs * rhs, - }) - .collect_vec(); - - let mut z = vec![F::one()]; - for i in 0..n { - for j in (i..).step_by(n).take(self.zs.len()) { - z.push(products[j] * z.last().unwrap()); - } - } - - let _last = z.pop().unwrap(); - #[cfg(feature = "sanity-check")] - assert_eq!(_last, F::one()); - - z - }) - .transpose_vec(self.zs.len() * n); - - layouter.assign_region( - || "zs", - |mut region| { - self.l_0.enable(&mut region, 0)?; - - let mut z = z.iter(); - for offset in 0..n { - for column in self.zs.iter() { - region.assign_advice(|| "", *column, offset, || *z.next().unwrap())?; - } - } - - Ok(()) - }, - ) - } -} - -fn binomial_coeffs(n: usize) -> Vec { - debug_assert!(n > 0); - - match n { - 1 => vec![1], - _ => { - let last_row = binomial_coeffs(n - 1); - iter::once(0) - .chain(last_row.iter().cloned()) - .zip(last_row.iter().cloned().chain(iter::once(0))) - .map(|(n, m)| n + m) - .collect() - } - } -} - -fn powers>(one: T, base: T) -> impl Iterator { - iter::successors(Some(one), move |power| Some(base.clone() * power.clone())) -} - -fn ordered_multiset(inputs: &[Vec], table: &[F]) -> Vec { - let mut input_counts = inputs - .iter() - .flatten() - .fold(BTreeMap::new(), |mut map, value| { - map.entry(value) - .and_modify(|count| *count += 1) - .or_insert(1); - map - }); - - let mut ordered = Vec::with_capacity((inputs.len() + 1) * inputs[0].len()); - for (count, value) in table.iter().dedup_with_count() { - let count = input_counts - .remove(value) - .map(|input_count| input_count + count) - .unwrap_or(count); - ordered.extend(iter::repeat(*value).take(count)); - } - - #[cfg(feature = "sanity-check")] - { - assert_eq!(input_counts.len(), 0); - assert_eq!(ordered.len(), ordered.capacity()); - } - - ordered.extend(iter::repeat(*ordered.last().unwrap()).take(ordered.capacity() - ordered.len())); - - ordered -} - -#[allow(dead_code)] -#[derive(Clone, Debug)] -pub struct PlookupConfig { - shuffle: ShuffleConfig, - compressed_inputs: Vec>, - compressed_table: Expression, - mixes: Vec>, - theta: Option, - beta: Challenge, - gamma: Challenge, -} - -impl PlookupConfig { - pub fn configure( - meta: &mut ConstraintSystem, - inputs: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec<[Expression; W]>, - table: [Column; W], - l_0: Option, - theta: Option, - beta: Option, - gamma: Option, - ) -> Self { - if ZK { - todo!() - } - - let inputs = query(meta, inputs); - let t = inputs.len(); - let theta_phase = iter::empty() - .chain(inputs.iter().flatten()) - .map(max_advice_phase) - .chain(table.iter().map(|column| { - Column::::try_from(*column) - .map(|column| column.column_type().phase()) - .unwrap_or_default() - })) - .max() - .unwrap(); - let mixes_phase = theta_phase + 1; - - let theta = if W > 1 { - Some(match theta { - Some(theta) => { - assert!(theta.phase() >= 
theta_phase); - theta - } - None => challenge_usable_after(meta, theta_phase), - }) - } else { - assert!(theta.is_none()); - None - }; - let mixes = iter::repeat_with(|| advice_column_in(meta, mixes_phase)) - .take(t + 1) - .collect_vec(); - let [beta, gamma] = [beta, gamma].map(|challenge| match challenge { - Some(challenge) => { - assert!(challenge.phase() >= mixes_phase); - challenge - } - None => challenge_usable_after(meta, mixes_phase), - }); - assert_ne!(theta, Some(beta)); - assert_ne!(theta, Some(gamma)); - assert_ne!(beta, gamma); - - let (compressed_inputs, compressed_table, compressed_table_w) = query(meta, |meta| { - let [table, table_w] = [Rotation::cur(), Rotation::next()] - .map(|at| table.map(|column| meta.query_any(column, at))); - let theta = theta.map(|theta| meta.query_challenge(theta)); - - let compressed_inputs = inputs - .iter() - .map(|input| { - input - .iter() - .cloned() - .reduce(|acc, expr| acc * theta.clone().unwrap() + expr) - .unwrap() - }) - .collect_vec(); - let compressed_table = table - .iter() - .cloned() - .reduce(|acc, expr| acc * theta.clone().unwrap() + expr) - .unwrap(); - let compressed_table_w = table_w - .iter() - .cloned() - .reduce(|acc, expr| acc * theta.clone().unwrap() + expr) - .unwrap(); - - (compressed_inputs, compressed_table, compressed_table_w) - }); - let lhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let [beta, gamma] = [beta, gamma].map(|challenge| meta.query_challenge(challenge)); - let one = Expression::Constant(F::one()); - let gamma_prime = (one + beta.clone()) * gamma.clone(); - - let values = compressed_inputs.clone().into_iter().chain(Some( - compressed_table.clone() + compressed_table_w.clone() * beta, - )); - let gammas = iter::empty() - .chain(iter::repeat(gamma).take(t)) - .chain(Some(gamma_prime)); - values.zip(gammas).collect() - }; - let rhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let mixes = iter::empty() - .chain(mixes.iter().cloned().zip(iter::repeat(Rotation::cur()))) - .chain(Some((mixes[0], Rotation::next()))) - .map(|(column, at)| meta.query_advice(column, at)) - .collect_vec(); - let [beta, gamma] = [beta, gamma].map(|challenge| meta.query_challenge(challenge)); - let one = Expression::Constant(F::one()); - let gamma_prime = (one + beta.clone()) * gamma; - - let values = mixes - .iter() - .cloned() - .zip(mixes.iter().skip(1).cloned()) - .zip(iter::repeat(beta)) - .map(|((mix_i, mix_j), beta)| mix_i + mix_j * beta); - let gammas = iter::repeat(gamma_prime).take(t + 1); - values.zip(gammas).collect() - }; - let lhs_coeff = |meta: &mut VirtualCells<'_, F>| { - let beta = meta.query_challenge(beta); - let one = Expression::Constant(F::one()); - binomial_coeffs(t + 1) - .into_iter() - .zip(powers(one, beta)) - .map(|(coeff, power_of_beta)| Expression::Constant(F::from(coeff)) * power_of_beta) - .reduce(|acc, expr| acc + expr) - }; - let shuffle = ShuffleConfig::configure_with_gamma( - meta, - lhs_with_gamma, - rhs_with_gamma, - lhs_coeff, - |_| None, - l_0, - ); - - Self { - shuffle, - compressed_inputs, - compressed_table, - mixes, - theta, - beta, - gamma, - } - } - - pub fn assign(&self, mut layouter: impl Layouter, n: usize) -> Result<(), Error> { - if ZK { - todo!() - } - - let compressed_inputs = self - .compressed_inputs - .iter() - .map(|expression| layouter.evaluate_committed(expression)) - .fold(Value::known(Vec::new()), |acc, compressed_input| { - acc.zip(compressed_input) - .map(|(mut acc, compressed_input)| { - acc.push(compressed_input); - acc - }) - }); - let compressed_table = 
layouter.evaluate_committed(&self.compressed_table); - - let mix = compressed_inputs - .zip(compressed_table.as_ref()) - .map(|(compressed_inputs, compressed_table)| { - ordered_multiset(&compressed_inputs, compressed_table) - }) - .transpose_vec(self.mixes.len() * n); - - layouter.assign_region( - || "mixes", - |mut region| { - let mut mix = mix.iter(); - for offset in 0..n { - for column in self.mixes.iter() { - region.assign_advice(|| "", *column, offset, || *mix.next().unwrap())?; - } - } - - Ok(()) - }, - )?; - - self.shuffle.assign(layouter.namespace(|| "Shuffle"), n)?; - - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::{PlookupConfig, ShuffleConfig}; - use crate::util::Itertools; - use halo2_curves::{bn256::Fr, FieldExt}; - use halo2_proofs::{ - circuit::{floor_planner::V1, Layouter, Value}, - dev::{metadata::Constraint, FailureLocation, MockProver, VerifyFailure}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed}, - poly::Rotation, - }; - use rand::{rngs::OsRng, RngCore}; - use std::{iter, mem}; - - fn shuffled( - mut values: [Vec; T], - mut rng: R, - ) -> [Vec; T] { - let n = values[0].len(); - let mut swap = |lhs: usize, rhs: usize| { - let tmp = mem::take(&mut values[lhs / n][lhs % n]); - values[lhs / n][lhs % n] = mem::replace(&mut values[rhs / n][rhs % n], tmp); - }; - - for row in (1..n * T).rev() { - let rand_row = (rng.next_u32() as usize) % row; - swap(row, rand_row); - } - - values - } - - #[derive(Clone)] - pub struct Shuffler { - n: usize, - lhs: Value<[Vec; T]>, - rhs: Value<[Vec; T]>, - } - - impl Shuffler { - pub fn rand(k: u32, mut rng: R) -> Self { - let n = 1 << k; - let lhs = [(); T].map(|_| { - let rng = &mut rng; - iter::repeat_with(|| F::random(&mut *rng)) - .take(n) - .collect_vec() - }); - let rhs = shuffled( - lhs.iter() - .map(|lhs| lhs.iter().map(F::square).collect()) - .collect_vec() - .try_into() - .unwrap(), - rng, - ); - Self { - n, - lhs: Value::known(lhs), - rhs: Value::known(rhs), - } - } - } - - impl Circuit for Shuffler { - type Config = ( - [Column; T], - [Column; T], - ShuffleConfig, - ); - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - n: self.n, - lhs: Value::unknown(), - rhs: Value::unknown(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let lhs = [(); T].map(|_| meta.advice_column()); - let rhs = [(); T].map(|_| meta.advice_column()); - let shuffle = ShuffleConfig::configure( - meta, - |meta| { - lhs.map(|column| { - let lhs = meta.query_advice(column, Rotation::cur()); - lhs.clone() * lhs - }) - .to_vec() - }, - |meta| { - rhs.map(|column| meta.query_advice(column, Rotation::cur())) - .to_vec() - }, - None, - ); - - (lhs, rhs, shuffle) - } - - fn synthesize( - &self, - (lhs, rhs, shuffle): Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - for (idx, column) in lhs.into_iter().enumerate() { - let values = self.lhs.as_ref().map(|lhs| lhs[idx].clone()); - for (offset, value) in - values.clone().transpose_vec(self.n).into_iter().enumerate() - { - region.assign_advice(|| "", column, offset, || value)?; - } - } - for (idx, column) in rhs.into_iter().enumerate() { - let values = self.rhs.as_ref().map(|rhs| rhs[idx].clone()); - for (offset, value) in - values.clone().transpose_vec(self.n).into_iter().enumerate() - { - region.assign_advice(|| "", column, offset, || value)?; - } - } - Ok(()) - }, - )?; - shuffle.assign(layouter.namespace(|| "Shuffle"), self.n)?; - - Ok(()) - } - } - - 
#[derive(Clone)] - pub struct Plookuper { - n: usize, - inputs: Value<[Vec<[F; W]>; T]>, - table: Vec<[F; W]>, - } - - impl Plookuper { - pub fn rand(k: u32, mut rng: R) -> Self { - let n = 1 << k; - let m = rng.next_u32() as usize % n; - let mut table = iter::repeat_with(|| [(); W].map(|_| F::random(&mut rng))) - .take(m) - .collect_vec(); - table.extend( - iter::repeat( - table - .first() - .cloned() - .unwrap_or_else(|| [(); W].map(|_| F::random(&mut rng))), - ) - .take(n - m), - ); - let inputs = [(); T].map(|_| { - iter::repeat_with(|| table[rng.next_u32() as usize % n]) - .take(n) - .collect() - }); - Self { - n, - inputs: Value::known(inputs), - table, - } - } - } - - impl Circuit - for Plookuper - { - type Config = ( - [[Column; W]; T], - [Column; W], - PlookupConfig, - ); - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - n: self.n, - inputs: Value::unknown(), - table: self.table.clone(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let inputs = [(); T].map(|_| [(); W].map(|_| meta.advice_column())); - let table = [(); W].map(|_| meta.fixed_column()); - let plookup = PlookupConfig::configure( - meta, - |meta| { - inputs - .iter() - .map(|input| input.map(|column| meta.query_advice(column, Rotation::cur()))) - .collect() - }, - table.map(|fixed| fixed.into()), - None, - None, - None, - None, - ); - - (inputs, table, plookup) - } - - fn synthesize( - &self, - (inputs, table, plookup): Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - for (offset, value) in self.table.iter().enumerate() { - for (column, value) in table.iter().zip(value.iter()) { - region.assign_fixed(|| "", *column, offset, || Value::known(*value))?; - } - } - Ok(()) - }, - )?; - layouter.assign_region( - || "", - |mut region| { - for (idx, columns) in inputs.iter().enumerate() { - let values = self.inputs.as_ref().map(|inputs| inputs[idx].clone()); - for (offset, value) in values.transpose_vec(self.n).into_iter().enumerate() - { - for (column, value) in columns.iter().zip(value.transpose_array()) { - region.assign_advice(|| "", *column, offset, || value)?; - } - } - } - Ok(()) - }, - )?; - plookup.assign(layouter.namespace(|| "Plookup"), self.n)?; - Ok(()) - } - } - - #[allow(dead_code)] - fn assert_constraint_not_satisfied( - result: Result<(), Vec>, - failures: Vec<(Constraint, FailureLocation)>, - ) { - match result { - Err(expected) => { - assert_eq!( - expected - .into_iter() - .map(|failure| match failure { - VerifyFailure::ConstraintNotSatisfied { - constraint, - location, - .. 
- } => (constraint, location), - _ => panic!("MockProver::verify has unexpected failure"), - }) - .collect_vec(), - failures - ) - } - Ok(_) => { - panic!("MockProver::verify unexpectedly succeeds") - } - } - } - - #[test] - fn test_shuffle() { - const T: usize = 9; - const ZK: bool = false; - - let k = 9; - let circuit = Shuffler::::rand(k, OsRng); - - let mut cs = ConstraintSystem::default(); - Shuffler::::configure(&mut cs); - assert_eq!(cs.degree::(), 3); - - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .assert_satisfied(); - - #[cfg(not(feature = "sanity-check"))] - { - let n = 1 << k; - let mut circuit = circuit; - circuit.lhs = mem::take(&mut circuit.lhs).map(|mut value| { - value[0][0] += Fr::one(); - value - }); - assert_constraint_not_satisfied( - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .verify(), - vec![( - ( - (2, "Shuffle").into(), - (T * 2).div_ceil(cs.degree::() - 1), - "", - ) - .into(), - FailureLocation::InRegion { - region: (0, "").into(), - offset: n - 1, - }, - )], - ); - } - } - - #[test] - fn test_plookup() { - const W: usize = 2; - const T: usize = 5; - const ZK: bool = false; - - let k = 9; - let circuit = Plookuper::::rand(k, OsRng); - - let mut cs = ConstraintSystem::default(); - Plookuper::::configure(&mut cs); - assert_eq!(cs.degree::(), 3); - - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .assert_satisfied(); - - #[cfg(not(feature = "sanity-check"))] - { - let n = 1 << k; - let mut circuit = circuit; - circuit.inputs = mem::take(&mut circuit.inputs).map(|mut inputs| { - inputs[0][0][0] += Fr::one(); - inputs - }); - assert_constraint_not_satisfied( - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .verify(), - vec![( - ( - (3, "Shuffle").into(), - (T + 1).div_ceil(cs.degree::() - 1), - "", - ) - .into(), - FailureLocation::InRegion { - region: (0, "").into(), - offset: n - 1, - }, - )], - ); - } - } -} diff --git a/src/protocol/halo2/test/kzg.rs b/src/protocol/halo2/test/kzg.rs deleted file mode 100644 index 771e1a70..00000000 --- a/src/protocol/halo2/test/kzg.rs +++ /dev/null @@ -1,232 +0,0 @@ -use crate::{ - protocol::halo2::test::{MainGateWithPlookup, MainGateWithRange}, - util::fe_to_limbs, -}; -use halo2_curves::{pairing::Engine, CurveAffine}; -use halo2_proofs::poly::{ - commitment::{CommitmentScheme, Params, ParamsProver}, - kzg::commitment::{KZGCommitmentScheme, ParamsKZG}, -}; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{fmt::Debug, fs}; - -mod halo2; -mod native; - -#[cfg(feature = "evm")] -mod evm; - -pub const LIMBS: usize = 4; -pub const BITS: usize = 68; - -pub fn read_or_create_srs(k: u32) -> ParamsKZG { - const DIR: &str = "./src/protocol/halo2/test/kzg/fixture"; - let path = format!("{}/k-{}.srs", DIR, k); - match fs::File::open(path.as_str()) { - Ok(mut file) => ParamsKZG::::read(&mut file).unwrap(), - Err(_) => { - fs::create_dir_all(DIR).unwrap(); - let params = - KZGCommitmentScheme::::new_params(k, ChaCha20Rng::from_seed(Default::default())); - let mut file = fs::File::create(path.as_str()).unwrap(); - params.write(&mut file).unwrap(); - params - } - } -} - -pub fn main_gate_with_range_with_mock_kzg_accumulator( -) -> MainGateWithRange { - let g = read_or_create_srs::(3).get_g(); - let [g1, s_g1] = [g[0], g[1]].map(|point| point.coordinates().unwrap()); - MainGateWithRange::new( - [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] - .iter() - .cloned() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .collect(), - ) -} - -pub fn 
main_gate_with_plookup_with_mock_kzg_accumulator( - k: u32, -) -> MainGateWithPlookup { - let g = read_or_create_srs::(3).get_g(); - let [g1, s_g1] = [g[0], g[1]].map(|point| point.coordinates().unwrap()); - MainGateWithPlookup::new( - k, - [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] - .iter() - .cloned() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .collect(), - ) -} - -#[macro_export] -macro_rules! halo2_kzg_config { - ($zk:expr, $num_proof:expr) => { - $crate::protocol::halo2::Config { - zk: $zk, - query_instance: false, - num_instance: Vec::new(), - num_proof: $num_proof, - accumulator_indices: None, - } - }; - ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { - $crate::protocol::halo2::Config { - zk: $zk, - query_instance: false, - num_instance: Vec::new(), - num_proof: $num_proof, - accumulator_indices: Some($accumulator_indices), - } - }; -} - -#[macro_export] -macro_rules! halo2_kzg_prepare { - ($k:expr, $config:expr, $create_circuit:expr) => {{ - use $crate::{ - protocol::halo2::{compile, test::kzg::read_or_create_srs}, - util::{GroupEncoding, Itertools}, - }; - use halo2_curves::bn256::{Bn256, G1}; - use halo2_proofs::{ - plonk::{keygen_pk, keygen_vk}, - poly::kzg::commitment::KZGCommitmentScheme, - }; - use std::{iter}; - - let circuits = iter::repeat_with(|| $create_circuit) - .take($config.num_proof) - .collect_vec(); - - let params = read_or_create_srs::($k); - let pk = if $config.zk { - let vk = keygen_vk::, _, true>(¶ms, &circuits[0]).unwrap(); - let pk = keygen_pk::, _, true>(¶ms, vk, &circuits[0]).unwrap(); - pk - } else { - let vk = keygen_vk::, _, false>(¶ms, &circuits[0]).unwrap(); - let pk = keygen_pk::, _, false>(¶ms, vk, &circuits[0]).unwrap(); - pk - }; - - let mut config = $config; - config.num_instance = circuits[0].instances().iter().map(|instances| instances.len()).collect(); - let protocol = compile::(pk.get_vk(), config); - assert_eq!( - protocol.preprocessed.len(), - protocol.preprocessed - .iter() - .map(|ec_point| <[u8; 32]>::try_from(ec_point.to_bytes().as_ref().to_vec()).unwrap()) - .unique() - .count() - ); - - (params, pk, protocol, circuits) - }}; -} - -#[macro_export] -macro_rules! halo2_kzg_create_snark { - ($params:expr, $pk:expr, $protocol:expr, $circuits:expr, $prover:ty, $verifier:ty, $verification_strategy:ty, $transcript_read:ty, $transcript_write:ty, $encoded_challenge:ty) => {{ - use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; - use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - use $crate::{ - collect_slice, - protocol::{halo2::test::create_proof_checked, Snark}, - util::Itertools, - }; - - let instances = $circuits - .iter() - .map(|circuit| circuit.instances()) - .collect_vec(); - let proof = { - collect_slice!(instances, 2); - #[allow(clippy::needless_borrow)] - if $protocol.zk { - create_proof_checked::< - KZGCommitmentScheme<_>, - _, - $prover, - $verifier, - $verification_strategy, - $transcript_read, - $transcript_write, - $encoded_challenge, - _, - true, - >( - $params, - $pk, - $circuits, - &instances, - &mut ChaCha20Rng::from_seed(Default::default()), - ) - } else { - create_proof_checked::< - KZGCommitmentScheme<_>, - _, - $prover, - $verifier, - $verification_strategy, - $transcript_read, - $transcript_write, - $encoded_challenge, - _, - false, - >( - $params, - $pk, - $circuits, - &instances, - &mut ChaCha20Rng::from_seed(Default::default()), - ) - } - }; - - Snark::new( - $protocol.clone(), - instances.into_iter().flatten().collect_vec(), - proof, - ) - }}; -} - -#[macro_export] -macro_rules! 
halo2_kzg_native_accumulate { - ($protocol:expr, $statements:expr, $scheme:ty, $transcript:expr, $stretagy:expr) => {{ - use $crate::{loader::native::NativeLoader, scheme::kzg::AccumulationScheme}; - - <$scheme>::accumulate( - $protocol, - &NativeLoader, - $statements, - $transcript, - $stretagy, - ) - .unwrap(); - }}; -} - -#[macro_export] -macro_rules! halo2_kzg_native_verify { - ($params:ident, $protocol:expr, $statements:expr, $scheme:ty, $transcript:expr) => {{ - use halo2_curves::bn256::Bn256; - use halo2_proofs::poly::commitment::ParamsProver; - use $crate::{ - halo2_kzg_native_accumulate, - protocol::halo2::test::kzg::{BITS, LIMBS}, - scheme::kzg::SameCurveAccumulation, - }; - - let mut stretagy = SameCurveAccumulation::<_, _, LIMBS, BITS>::default(); - halo2_kzg_native_accumulate!($protocol, $statements, $scheme, $transcript, &mut stretagy); - - assert!(stretagy.decide::($params.get_g()[0], $params.g2(), $params.s_g2())); - }}; -} diff --git a/src/protocol/halo2/test/kzg/evm.rs b/src/protocol/halo2/test/kzg/evm.rs deleted file mode 100644 index 6caf0e68..00000000 --- a/src/protocol/halo2/test/kzg/evm.rs +++ /dev/null @@ -1,168 +0,0 @@ -use crate::{ - halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_evm_verify, halo2_kzg_native_verify, - halo2_kzg_prepare, - loader::evm::EvmTranscript, - protocol::halo2::{ - test::{ - kzg::{ - halo2::Accumulation, main_gate_with_plookup_with_mock_kzg_accumulator, - main_gate_with_range_with_mock_kzg_accumulator, LIMBS, - }, - StandardPlonk, - }, - util::evm::ChallengeEvm, - }, - scheme::kzg::PlonkAccumulationScheme, -}; -use halo2_proofs::poly::kzg::{ - multiopen::{ProverGWC, VerifierGWC}, - strategy::AccumulatorStrategy, -}; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - -#[macro_export] -macro_rules! halo2_kzg_evm_verify { - ($params:expr, $protocol:expr, $statements:expr, $proof:expr, $scheme:ty) => {{ - use halo2_curves::bn256::{Fq, Fr}; - use halo2_proofs::poly::commitment::ParamsProver; - use std::{iter, rc::Rc}; - use $crate::{ - loader::evm::{encode_calldata, execute, EvmLoader, EvmTranscript}, - protocol::halo2::test::kzg::{BITS, LIMBS}, - scheme::kzg::{AccumulationScheme, SameCurveAccumulation}, - util::{Itertools, TranscriptRead}, - }; - - let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); - let statements = $statements - .iter() - .map(|instance| { - iter::repeat_with(|| transcript.read_scalar().unwrap()) - .take(instance.len()) - .collect_vec() - }) - .collect_vec(); - let mut strategy = SameCurveAccumulation::<_, _, LIMBS, BITS>::default(); - <$scheme>::accumulate( - $protocol, - &loader, - statements, - &mut transcript, - &mut strategy, - ) - .unwrap(); - let code = strategy.code($params.get_g()[0], $params.g2(), $params.s_g2()); - let (accept, total_cost, costs) = execute(code, encode_calldata($statements, $proof)); - loader.print_gas_metering(costs); - println!("Total: {}", total_cost); - assert!(accept); - }}; -} - -macro_rules! test { - (@ #[$($attr:meta),*], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - paste! 
{ - $(#[$attr])* - fn []() { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - $k, - $config, - $create_circuit - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverGWC<_>, - VerifierGWC<_>, - AccumulatorStrategy<_>, - EvmTranscript<_, _, _, _>, - EvmTranscript<_, _, _, _>, - ChallengeEvm<_> - ); - halo2_kzg_native_verify!( - params, - &snark.protocol, - snark.statements.clone(), - PlonkAccumulationScheme, - &mut EvmTranscript::<_, NativeLoader, _, _>::new(snark.proof.as_slice()) - ); - halo2_kzg_evm_verify!( - params, - &snark.protocol, - snark.statements, - snark.proof, - PlonkAccumulationScheme - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test], $name, $k, $config, $create_circuit); - }; - (#[ignore = $reason:literal], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test, ignore = $reason], $name, $k, $config, $create_circuit); - }; -} - -test!( - zk_standard_plonk_rand, - 9, - halo2_kzg_config!(true, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) -); -test!( - zk_main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -test!( - #[ignore = "cause it requires 16GB memory to run"], - zk_accumulation_two_snark, - 21, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark(true) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - zk_accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark_with_accumulator(true) -); -test!( - standard_plonk_rand, - 9, - halo2_kzg_config!(false, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) -); -test!( - main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -test!( - main_gate_with_plookup_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_plookup_with_mock_kzg_accumulator::(9) -); -test!( - #[ignore = "cause it requires 16GB memory to run"], - accumulation_two_snark, - 21, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark(false) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark_with_accumulator(false) -); diff --git a/src/protocol/halo2/test/kzg/halo2.rs b/src/protocol/halo2/test/kzg/halo2.rs deleted file mode 100644 index cd8e884b..00000000 --- a/src/protocol/halo2/test/kzg/halo2.rs +++ /dev/null @@ -1,380 +0,0 @@ -use crate::{ - collect_slice, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_accumulate, - halo2_kzg_native_verify, halo2_kzg_prepare, - loader::{halo2, native::NativeLoader}, - protocol::{ - halo2::{ - test::{ - kzg::{BITS, LIMBS}, - MainGateWithRange, MainGateWithRangeConfig, StandardPlonk, - }, - util::halo2::ChallengeScalar, - }, - Protocol, Snark, - }, - scheme::kzg::{self, AccumulationScheme, ShplonkAccumulationScheme}, - util::{fe_to_limbs, Curve, Group, Itertools, PrimeCurveAffine}, -}; -use halo2_curves::bn256::{Fr, G1Affine, G1}; -use 
halo2_proofs::{ - circuit::{floor_planner::V1, Layouter, Value}, - plonk, - plonk::Circuit, - poly::{ - commitment::ParamsProver, - kzg::{ - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::AccumulatorStrategy, - }, - }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, -}; -use halo2_wrong_ecc::{self, maingate::RegionCtx}; -use halo2_wrong_transcript::NativeRepresentation; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::rc::Rc; - -const T: usize = 5; -const RATE: usize = 4; -const R_F: usize = 8; -const R_P: usize = 57; - -type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; -type Halo2Loader<'a, 'b, C> = halo2::Halo2Loader<'a, 'b, C, LIMBS, BITS>; -type PoseidonTranscript = - halo2::PoseidonTranscript; -type SameCurveAccumulation = kzg::SameCurveAccumulation; - -pub struct SnarkWitness { - protocol: Protocol, - statements: Vec::Scalar>>>, - proof: Value>, -} - -impl From> for SnarkWitness { - fn from(snark: Snark) -> Self { - Self { - protocol: snark.protocol, - statements: snark - .statements - .into_iter() - .map(|statements| statements.into_iter().map(Value::known).collect_vec()) - .collect(), - proof: Value::known(snark.proof), - } - } -} - -impl SnarkWitness { - pub fn without_witnesses(&self) -> Self { - SnarkWitness { - protocol: self.protocol.clone(), - statements: self - .statements - .iter() - .map(|statements| vec![Value::unknown(); statements.len()]) - .collect(), - proof: Value::unknown(), - } - } -} - -pub fn accumulate<'a, 'b>( - loader: &Rc>, - stretagy: &mut SameCurveAccumulation>>, - snark: &SnarkWitness, -) -> Result<(), plonk::Error> { - let mut transcript = PoseidonTranscript::<_, Rc>, _, _>::new( - loader, - snark.proof.as_ref().map(|proof| proof.as_slice()), - ); - let statements = snark - .statements - .iter() - .map(|statements| { - statements - .iter() - .map(|statement| loader.assign_scalar(*statement)) - .collect_vec() - }) - .collect_vec(); - ShplonkAccumulationScheme::accumulate( - &snark.protocol, - loader, - statements, - &mut transcript, - stretagy, - ) - .map_err(|_| plonk::Error::Synthesis)?; - Ok(()) -} - -pub struct Accumulation { - g1: G1Affine, - snarks: Vec>, - instances: Vec, -} - -impl Accumulation { - pub fn accumulator_indices() -> Vec<(usize, usize)> { - (0..4 * LIMBS).map(|idx| (0, idx)).collect() - } - - pub fn two_snark(zk: bool) -> Self { - const K: u32 = 9; - - let (params, snark1) = { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(zk, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - PoseidonTranscript<_, _, _, _>, - PoseidonTranscript<_, _, _, _>, - ChallengeScalar<_> - ); - (params, snark) - }; - let snark2 = { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(zk, 1), - MainGateWithRange::<_>::rand(ChaCha20Rng::from_seed(Default::default())) - ); - halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - PoseidonTranscript<_, _, _, _>, - PoseidonTranscript<_, _, _, _>, - ChallengeScalar<_> - ) - }; - - let mut strategy = SameCurveAccumulation::::default(); - halo2_kzg_native_accumulate!( - &snark1.protocol, - snark1.statements.clone(), - ShplonkAccumulationScheme, - &mut PoseidonTranscript::::init(snark1.proof.as_slice()), - &mut 
strategy - ); - halo2_kzg_native_accumulate!( - &snark2.protocol, - snark2.statements.clone(), - ShplonkAccumulationScheme, - &mut PoseidonTranscript::::init(snark2.proof.as_slice()), - &mut strategy - ); - - let g1 = params.get_g()[0]; - let accumulator = strategy.finalize(g1.to_curve()); - let instances = [ - accumulator.0.to_affine().x, - accumulator.0.to_affine().y, - accumulator.1.to_affine().x, - accumulator.1.to_affine().y, - ] - .map(fe_to_limbs::<_, _, LIMBS, BITS>) - .concat(); - - Self { - g1, - snarks: vec![snark1.into(), snark2.into()], - instances, - } - } - - pub fn two_snark_with_accumulator(zk: bool) -> Self { - const K: u32 = 21; - - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(zk, 2, Self::accumulator_indices()), - Self::two_snark(zk) - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - PoseidonTranscript<_, _, _, _>, - PoseidonTranscript<_, _, _, _>, - ChallengeScalar<_> - ); - - let mut strategy = SameCurveAccumulation::::default(); - halo2_kzg_native_accumulate!( - &snark.protocol, - snark.statements.clone(), - ShplonkAccumulationScheme, - &mut PoseidonTranscript::::init(snark.proof.as_slice()), - &mut strategy - ); - - let g1 = params.get_g()[0]; - let accumulator = strategy.finalize(g1.to_curve()); - let instances = [ - accumulator.0.to_affine().x, - accumulator.0.to_affine().y, - accumulator.1.to_affine().x, - accumulator.1.to_affine().y, - ] - .map(fe_to_limbs::<_, _, LIMBS, BITS>) - .concat(); - - Self { - g1, - snarks: vec![snark.into()], - instances, - } - } - - pub fn instances(&self) -> Vec> { - vec![self.instances.clone()] - } -} - -impl Circuit for Accumulation { - type Config = MainGateWithRangeConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - g1: self.g1, - snarks: self - .snarks - .iter() - .map(SnarkWitness::without_witnesses) - .collect(), - instances: Vec::new(), - } - } - - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - MainGateWithRangeConfig::configure::( - meta, - vec![BITS / LIMBS], - BaseFieldEccChip::::rns().overflow_lengths(), - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), plonk::Error> { - config.load_table(&mut layouter)?; - - let (lhs, rhs) = layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - let ctx = RegionCtx::new(&mut region, &mut offset); - - let loader = Halo2Loader::::new(config.ecc_config(), ctx); - let mut stretagy = SameCurveAccumulation::default(); - for snark in self.snarks.iter() { - accumulate(&loader, &mut stretagy, snark)?; - } - let (lhs, rhs) = stretagy.finalize(self.g1); - - loader.print_row_metering(); - println!("Total: {}", offset); - - Ok((lhs, rhs)) - }, - )?; - - let ecc_chip = BaseFieldEccChip::::new(config.ecc_config()); - ecc_chip.expose_public(layouter.namespace(|| ""), lhs, 0)?; - ecc_chip.expose_public(layouter.namespace(|| ""), rhs, 2 * LIMBS)?; - - Ok(()) - } -} - -macro_rules! test { - (@ #[$($attr:meta),*], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - paste! 
{ - $(#[$attr])* - fn []() { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - $k, - $config, - $create_circuit - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - Blake2bWrite<_, _, _>, - Blake2bRead<_, _, _>, - Challenge255<_> - ); - halo2_kzg_native_verify!( - params, - &snark.protocol, - snark.statements, - ShplonkAccumulationScheme, - &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test], $name, $k, $config, $create_circuit); - }; - (#[ignore = $reason:literal], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test, ignore = $reason], $name, $k, $config, $create_circuit); - }; -} - -test!( - #[ignore = "cause it requires 16GB memory to run"], - zk_accumulation_two_snark, - 21, - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark(true) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - zk_accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark_with_accumulator(true) -); -test!( - #[ignore = "cause it requires 16GB memory to run"], - accumulation_two_snark, - 21, - halo2_kzg_config!(false, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark(false) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(false, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark_with_accumulator(false) -); diff --git a/src/protocol/halo2/util.rs b/src/protocol/halo2/util.rs deleted file mode 100644 index 437d59cb..00000000 --- a/src/protocol/halo2/util.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::{ - loader::native::NativeLoader, - util::{ - Curve, Itertools, PrimeCurveAffine, PrimeField, Transcript, TranscriptRead, - UncompressedEncoding, - }, - Error, -}; -use halo2_proofs::{ - arithmetic::{CurveAffine, CurveExt}, - transcript::{Blake2bRead, Challenge255}, -}; -use std::{io::Read, iter}; - -pub mod halo2; - -#[cfg(feature = "evm")] -pub mod evm; - -impl UncompressedEncoding for C -where - ::Base: PrimeField, -{ - type Uncompressed = [u8; 64]; - - fn to_uncompressed(&self) -> [u8; 64] { - let coordinates = self.to_affine().coordinates().unwrap(); - iter::empty() - .chain(coordinates.x().to_repr().as_ref()) - .chain(coordinates.y().to_repr().as_ref()) - .cloned() - .collect_vec() - .try_into() - .unwrap() - } - - fn from_uncompressed(uncompressed: [u8; 64]) -> Option { - let x = Option::from(::Base::from_repr( - uncompressed[..32].to_vec().try_into().unwrap(), - ))?; - let y = Option::from(::Base::from_repr( - uncompressed[32..].to_vec().try_into().unwrap(), - ))?; - C::AffineExt::from_xy(x, y) - .map(|ec_point| ec_point.to_curve()) - .into() - } -} - -impl Transcript - for Blake2bRead> -{ - fn squeeze_challenge(&mut self) -> C::Scalar { - *halo2_proofs::transcript::Transcript::squeeze_challenge_scalar::(self) - } - - fn common_ec_point(&mut self, ec_point: &C::CurveExt) -> Result<(), Error> { - halo2_proofs::transcript::Transcript::common_point(self, ec_point.to_affine()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string())) - } - - fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { - halo2_proofs::transcript::Transcript::common_scalar(self, *scalar) - .map_err(|err| 
Error::Transcript(err.kind(), err.to_string())) - } -} - -impl TranscriptRead - for Blake2bRead> -{ - fn read_scalar(&mut self) -> Result { - halo2_proofs::transcript::TranscriptRead::read_scalar(self) - .map_err(|err| Error::Transcript(err.kind(), err.to_string())) - } - - fn read_ec_point(&mut self) -> Result { - halo2_proofs::transcript::TranscriptRead::read_point(self) - .map(|ec_point| ec_point.to_curve()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string())) - } -} diff --git a/src/protocol/halo2/util/evm.rs b/src/protocol/halo2/util/evm.rs deleted file mode 100644 index c8da8fe9..00000000 --- a/src/protocol/halo2/util/evm.rs +++ /dev/null @@ -1,142 +0,0 @@ -use crate::{ - loader::{ - evm::{u256_to_field, EvmTranscript}, - native::NativeLoader, - }, - util::{self, Curve, PrimeField, UncompressedEncoding}, - Error, -}; -use ethereum_types::U256; -use halo2_curves::{Coordinates, CurveAffine}; -use halo2_proofs::transcript::{ - EncodedChallenge, Transcript, TranscriptRead, TranscriptReadBuffer, TranscriptWrite, - TranscriptWriterBuffer, -}; -use std::io::{self, Read, Write}; - -pub struct ChallengeEvm(C::Scalar) -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField; - -impl EncodedChallenge for ChallengeEvm -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - type Input = [u8; 32]; - - fn new(challenge_input: &[u8; 32]) -> Self { - ChallengeEvm(u256_to_field(U256::from_big_endian(challenge_input))) - } - - fn get_scalar(&self) -> C::Scalar { - self.0 - } -} - -impl Transcript> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn squeeze_challenge(&mut self) -> ChallengeEvm { - ChallengeEvm(util::Transcript::squeeze_challenge(self)) - } - - fn common_point(&mut self, ec_point: C) -> io::Result<()> { - match util::Transcript::common_ec_point(self, &ec_point.to_curve()) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } - - fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - match util::Transcript::common_scalar(self, &scalar) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } -} - -impl TranscriptRead> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn read_point(&mut self) -> io::Result { - match util::TranscriptRead::read_ec_point(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value.to_affine()), - } - } - - fn read_scalar(&mut self) -> io::Result { - match util::TranscriptRead::read_scalar(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value), - } - } -} - -impl TranscriptReadBuffer> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn init(reader: R) -> Self { - Self::new(reader) - } -} - -impl TranscriptWrite> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn write_point(&mut self, ec_point: C) -> io::Result<()> { - Transcript::>::common_point(self, ec_point)?; - let coords: Coordinates = Option::from(ec_point.coordinates()).ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Cannot write points at infinity to the transcript", - ) - })?; - let mut x = coords.x().to_repr(); - let mut y = 
coords.y().to_repr(); - x.as_mut().reverse(); - y.as_mut().reverse(); - self.stream_mut().write_all(x.as_ref())?; - self.stream_mut().write_all(y.as_ref()) - } - - fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - Transcript::>::common_scalar(self, scalar)?; - let mut data = scalar.to_repr(); - data.as_mut().reverse(); - self.stream_mut().write_all(data.as_ref()) - } -} - -impl TranscriptWriterBuffer> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn init(writer: W) -> Self { - Self::new(writer) - } - - fn finalize(self) -> W { - self.finalize() - } -} diff --git a/src/protocol/halo2/util/halo2.rs b/src/protocol/halo2/util/halo2.rs deleted file mode 100644 index dfda15a8..00000000 --- a/src/protocol/halo2/util/halo2.rs +++ /dev/null @@ -1,212 +0,0 @@ -use crate::{ - loader::{halo2::PoseidonTranscript, native::NativeLoader}, - util::{self, Curve, PrimeField}, - Error, -}; -use halo2_curves::CurveAffine; -use halo2_proofs::transcript::{ - EncodedChallenge, Transcript, TranscriptRead, TranscriptReadBuffer, TranscriptWrite, - TranscriptWriterBuffer, -}; -use halo2_wrong_transcript::NativeRepresentation; -use poseidon::Poseidon; -use std::io::{self, Read, Write}; - -pub struct ChallengeScalar(C::Scalar); - -impl EncodedChallenge for ChallengeScalar { - type Input = C::Scalar; - - fn new(challenge_input: &C::Scalar) -> Self { - ChallengeScalar(*challenge_input) - } - - fn get_scalar(&self) -> C::Scalar { - self.0 - } -} - -impl< - C: CurveAffine, - S, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript> - for PoseidonTranscript< - C, - NativeLoader, - S, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn squeeze_challenge(&mut self) -> ChallengeScalar { - ChallengeScalar::new(&util::Transcript::squeeze_challenge(self)) - } - - fn common_point(&mut self, ec_point: C) -> io::Result<()> { - match util::Transcript::common_ec_point(self, &ec_point.to_curve()) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } - - fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - match util::Transcript::common_scalar(self, &scalar) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } -} - -impl< - C: CurveAffine, - R: Read, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead> - for PoseidonTranscript< - C, - NativeLoader, - R, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn read_point(&mut self) -> io::Result { - match util::TranscriptRead::read_ec_point(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value.to_affine()), - } - } - - fn read_scalar(&mut self) -> io::Result { - match util::TranscriptRead::read_scalar(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value), - } - } -} - -impl< - C: CurveAffine, - R: Read, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptReadBuffer> - for PoseidonTranscript< - C, - NativeLoader, - R, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - 
R_F, - R_P, - > -{ - fn init(reader: R) -> Self { - Self::new(reader) - } -} - -impl< - C: CurveAffine, - W: Write, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptWrite> - for PoseidonTranscript< - C, - NativeLoader, - W, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn write_point(&mut self, ec_point: C) -> io::Result<()> { - Transcript::>::common_point(self, ec_point)?; - let data = ec_point.to_bytes(); - self.stream_mut().write_all(data.as_ref()) - } - - fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - Transcript::>::common_scalar(self, scalar)?; - let data = scalar.to_repr(); - self.stream_mut().write_all(data.as_ref()) - } -} - -impl< - C: CurveAffine, - W: Write, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptWriterBuffer> - for PoseidonTranscript< - C, - NativeLoader, - W, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn init(writer: W) -> Self { - Self::new(writer) - } - - fn finalize(self) -> W { - self.finalize() - } -} diff --git a/src/scheme.rs b/src/scheme.rs deleted file mode 100644 index d883d742..00000000 --- a/src/scheme.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod kzg; diff --git a/src/scheme/kzg.rs b/src/scheme/kzg.rs deleted file mode 100644 index e1f949ab..00000000 --- a/src/scheme/kzg.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::{ - protocol::Protocol, - util::{Curve, Expression}, -}; - -mod accumulation; -mod cost; -mod msm; - -pub use accumulation::{ - plonk::PlonkAccumulationScheme, shplonk::ShplonkAccumulationScheme, AccumulationScheme, - AccumulationStrategy, Accumulator, SameCurveAccumulation, -}; -pub use cost::{Cost, CostEstimation}; -pub use msm::MSM; - -pub fn langranges( - protocol: &Protocol, - statements: &[Vec], -) -> impl IntoIterator { - protocol - .relations - .iter() - .cloned() - .sum::>() - .used_langrange() - .into_iter() - .chain( - 0..statements - .iter() - .map(|statement| statement.len()) - .max() - .unwrap_or_default() as i32, - ) -} diff --git a/src/scheme/kzg/accumulation.rs b/src/scheme/kzg/accumulation.rs deleted file mode 100644 index 6931639d..00000000 --- a/src/scheme/kzg/accumulation.rs +++ /dev/null @@ -1,171 +0,0 @@ -use crate::{ - loader::Loader, - protocol::Protocol, - scheme::kzg::msm::MSM, - util::{Curve, Transcript}, - Error, -}; -use std::ops::{Add, AddAssign, Mul, MulAssign}; - -pub mod plonk; -pub mod shplonk; - -pub trait AccumulationScheme -where - C: Curve, - L: Loader, - T: Transcript, - S: AccumulationStrategy, -{ - type Proof; - - fn accumulate( - protocol: &Protocol, - loader: &L, - statements: Vec>, - transcript: &mut T, - strategy: &mut S, - ) -> Result; -} - -pub trait AccumulationStrategy -where - C: Curve, - L: Loader, - T: Transcript, -{ - type Output; - - fn extract_accumulator( - &self, - _: &Protocol, - _: &L, - _: &mut T, - _: &[Vec], - ) -> Option> { - None - } - - fn process( - &mut self, - loader: &L, - transcript: &mut T, - proof: P, - accumulator: Accumulator, - ) -> Result; -} - -#[derive(Clone, Debug)] -pub struct Accumulator -where - C: Curve, - L: Loader, -{ - lhs: MSM, - rhs: MSM, -} - -impl Accumulator -where - C: Curve, - L: Loader, -{ - pub fn new(lhs: MSM, rhs: MSM) -> Self { - Self { lhs, rhs } - } - - pub fn scale(&mut self, scalar: &L::LoadedScalar) { - self.lhs *= scalar; - self.rhs *= scalar; - } - - pub fn 
extend(&mut self, other: Self) { - self.lhs += other.lhs; - self.rhs += other.rhs; - } - - pub fn evaluate(self, g1: C) -> (L::LoadedEcPoint, L::LoadedEcPoint) { - (self.lhs.evaluate(g1), self.rhs.evaluate(g1)) - } - - pub fn random_linear_combine( - scaled_accumulators: impl IntoIterator, - ) -> Self { - scaled_accumulators - .into_iter() - .map(|(scalar, accumulator)| accumulator * &scalar) - .reduce(|acc, scaled_accumulator| acc + scaled_accumulator) - .unwrap_or_default() - } -} - -impl Default for Accumulator -where - C: Curve, - L: Loader, -{ - fn default() -> Self { - Self { - lhs: MSM::default(), - rhs: MSM::default(), - } - } -} - -impl Add for Accumulator -where - C: Curve, - L: Loader, -{ - type Output = Self; - - fn add(mut self, rhs: Self) -> Self::Output { - self.extend(rhs); - self - } -} - -impl AddAssign for Accumulator -where - C: Curve, - L: Loader, -{ - fn add_assign(&mut self, rhs: Self) { - self.extend(rhs); - } -} - -impl Mul<&L::LoadedScalar> for Accumulator -where - C: Curve, - L: Loader, -{ - type Output = Self; - - fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { - self.scale(rhs); - self - } -} - -impl MulAssign<&L::LoadedScalar> for Accumulator -where - C: Curve, - L: Loader, -{ - fn mul_assign(&mut self, rhs: &L::LoadedScalar) { - self.scale(rhs); - } -} - -pub struct SameCurveAccumulation, const LIMBS: usize, const BITS: usize> { - pub accumulator: Option>, -} - -impl, const LIMBS: usize, const BITS: usize> Default - for SameCurveAccumulation -{ - fn default() -> Self { - Self { accumulator: None } - } -} diff --git a/src/scheme/kzg/accumulation/plonk.rs b/src/scheme/kzg/accumulation/plonk.rs deleted file mode 100644 index 10c91eed..00000000 --- a/src/scheme/kzg/accumulation/plonk.rs +++ /dev/null @@ -1,373 +0,0 @@ -use crate::{ - loader::{LoadedScalar, Loader}, - protocol::Protocol, - scheme::kzg::{ - accumulation::{AccumulationScheme, AccumulationStrategy, Accumulator}, - cost::{Cost, CostEstimation}, - langranges, - msm::MSM, - }, - util::{ - CommonPolynomial, CommonPolynomialEvaluation, Curve, Expression, Field, Itertools, Query, - Rotation, TranscriptRead, - }, - Error, -}; -use std::{collections::HashMap, iter}; - -#[derive(Default)] -pub struct PlonkAccumulationScheme; - -impl AccumulationScheme for PlonkAccumulationScheme -where - C: Curve, - L: Loader, - T: TranscriptRead, - S: AccumulationStrategy>, -{ - type Proof = PlonkProof; - - fn accumulate( - protocol: &Protocol, - loader: &L, - statements: Vec>, - transcript: &mut T, - strategy: &mut S, - ) -> Result { - transcript.common_scalar(&loader.load_const(&protocol.transcript_initial_state))?; - - let proof = PlonkProof::read(protocol, statements, transcript)?; - let old_accumulator = - strategy.extract_accumulator(protocol, loader, transcript, &proof.statements); - - let common_poly_eval = { - let mut common_poly_eval = CommonPolynomialEvaluation::new( - &protocol.domain, - loader, - langranges(protocol, &proof.statements), - &proof.z, - ); - - L::LoadedScalar::batch_invert(common_poly_eval.denoms()); - - common_poly_eval - }; - - let commitments = proof.commitments(protocol, loader, &common_poly_eval); - let evaluations = proof.evaluations(protocol, loader, &common_poly_eval)?; - - let sets = rotation_sets(protocol); - let powers_of_u = &proof.u.powers(sets.len()); - let f = { - let powers_of_v = proof - .v - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); - sets.iter() - .map(|set| set.msm(&commitments, &evaluations, &powers_of_v)) - .zip(powers_of_u.iter()) - .map(|(msm, 
power_of_u)| msm * power_of_u) - .sum::>() - }; - let z_omegas = sets.iter().map(|set| { - loader.load_const( - &protocol - .domain - .rotate_scalar(C::Scalar::one(), set.rotation), - ) * &proof.z - }); - - let rhs = proof - .ws - .iter() - .zip(powers_of_u.iter()) - .map(|(w, power_of_u)| MSM::base(w.clone()) * power_of_u) - .collect_vec(); - let lhs = f + rhs - .iter() - .zip(z_omegas) - .map(|(uw, z_omega)| uw.clone() * &z_omega) - .sum(); - - let mut accumulator = Accumulator::new(lhs, rhs.into_iter().sum()); - if let Some(old_accumulator) = old_accumulator { - accumulator += old_accumulator; - } - strategy.process(loader, transcript, proof, accumulator) - } -} - -pub struct PlonkProof> { - statements: Vec>, - auxiliaries: Vec, - challenges: Vec, - alpha: L::LoadedScalar, - quotients: Vec, - z: L::LoadedScalar, - evaluations: Vec, - v: L::LoadedScalar, - ws: Vec, - u: L::LoadedScalar, -} - -impl> PlonkProof { - fn read>( - protocol: &Protocol, - statements: Vec>, - transcript: &mut T, - ) -> Result { - if protocol.num_statement - != statements - .iter() - .map(|statements| statements.len()) - .collect_vec() - { - return Err(Error::InvalidInstances); - } - for statements in statements.iter() { - for statement in statements.iter() { - transcript.common_scalar(statement)?; - } - } - - let (auxiliaries, challenges) = { - let (auxiliaries, challenges) = protocol - .num_auxiliary - .iter() - .zip(protocol.num_challenge.iter()) - .map(|(&n, &m)| { - Ok(( - transcript.read_n_ec_points(n)?, - transcript.squeeze_n_challenges(m), - )) - }) - .collect::, Error>>()? - .into_iter() - .unzip::<_, _, Vec<_>, Vec<_>>(); - - ( - auxiliaries.into_iter().flatten().collect_vec(), - challenges.into_iter().flatten().collect_vec(), - ) - }; - - let alpha = transcript.squeeze_challenge(); - let quotients = { - let max_degree = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap(); - transcript.read_n_ec_points(max_degree - 1)? 
- }; - - let z = transcript.squeeze_challenge(); - let evaluations = transcript.read_n_scalars(protocol.evaluations.len())?; - - let v = transcript.squeeze_challenge(); - let ws = transcript.read_n_ec_points(rotation_sets(protocol).len())?; - let u = transcript.squeeze_challenge(); - - Ok(Self { - statements, - auxiliaries, - challenges, - alpha, - quotients, - z, - evaluations, - v, - ws, - u, - }) - } - - fn commitments( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> HashMap> { - iter::empty() - .chain( - protocol - .preprocessed - .iter() - .map(|value| MSM::base(loader.ec_point_load_const(value))) - .enumerate(), - ) - .chain({ - let auxiliary_offset = protocol.preprocessed.len() + protocol.num_statement.len(); - self.auxiliaries - .iter() - .cloned() - .enumerate() - .map(move |(i, auxiliary)| (auxiliary_offset + i, MSM::base(auxiliary))) - }) - .chain(iter::once(( - protocol.vanishing_poly(), - common_poly_eval - .zn() - .powers(self.quotients.len()) - .into_iter() - .zip(self.quotients.iter().cloned().map(MSM::base)) - .map(|(coeff, piece)| piece * &coeff) - .sum(), - ))) - .collect() - } - - fn evaluations( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> Result, Error> { - let statement_evaluations = self.statements.iter().map(|statements| { - L::LoadedScalar::sum( - &statements - .iter() - .enumerate() - .map(|(i, statement)| { - common_poly_eval.get(CommonPolynomial::Lagrange(i as i32)) * statement - }) - .collect_vec(), - ) - }); - let mut evaluations = HashMap::::from_iter( - iter::empty() - .chain( - statement_evaluations - .into_iter() - .enumerate() - .map(|(i, evaluation)| { - ( - Query { - poly: protocol.preprocessed.len() + i, - rotation: Rotation::cur(), - }, - evaluation, - ) - }), - ) - .chain( - protocol - .evaluations - .iter() - .cloned() - .zip(self.evaluations.iter().cloned()), - ), - ); - - let powers_of_alpha = self.alpha.powers(protocol.relations.len()); - let quotient_evaluation = L::LoadedScalar::sum( - &powers_of_alpha - .into_iter() - .rev() - .zip(protocol.relations.iter()) - .map(|(power_of_alpha, relation)| { - relation - .evaluate( - &|scalar| Ok(loader.load_const(&scalar)), - &|poly| Ok(common_poly_eval.get(poly)), - &|index| { - evaluations - .get(&index) - .cloned() - .ok_or(Error::MissingQuery(index)) - }, - &|index| { - self.challenges - .get(index) - .cloned() - .ok_or(Error::MissingChallenge(index)) - }, - &|a| a.map(|a| -a), - &|a, b| a.and_then(|a| Ok(a + b?)), - &|a, b| a.and_then(|a| Ok(a * b?)), - &|a, scalar| a.map(|a| a * loader.load_const(&scalar)), - ) - .map(|evaluation| power_of_alpha * evaluation) - }) - .collect::, Error>>()?, - ) * &common_poly_eval.zn_minus_one_inv(); - - evaluations.insert( - Query { - poly: protocol.vanishing_poly(), - rotation: Rotation::cur(), - }, - quotient_evaluation, - ); - - Ok(evaluations) - } -} - -struct RotationSet { - rotation: Rotation, - polys: Vec, -} - -impl RotationSet { - fn msm>( - &self, - commitments: &HashMap>, - evaluations: &HashMap, - powers_of_v: &[L::LoadedScalar], - ) -> MSM { - self.polys - .iter() - .map(|poly| { - let commitment = commitments.get(poly).unwrap().clone(); - let evalaution = evaluations - .get(&Query::new(*poly, self.rotation)) - .unwrap() - .clone(); - commitment - MSM::scalar(evalaution) - }) - .zip(powers_of_v.iter()) - .map(|(msm, power_of_v)| msm * power_of_v) - .sum() - } -} - -fn rotation_sets(protocol: &Protocol) -> Vec { - 
protocol.queries.iter().fold(Vec::new(), |mut sets, query| { - if let Some(pos) = sets.iter().position(|set| set.rotation == query.rotation) { - sets[pos].polys.push(query.poly) - } else { - sets.push(RotationSet { - rotation: query.rotation, - polys: vec![query.poly], - }) - } - sets - }) -} - -impl CostEstimation for PlonkAccumulationScheme { - fn estimate_cost(protocol: &Protocol) -> Cost { - let num_quotient = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap() - - 1; - let num_w = rotation_sets(protocol).len(); - let num_accumulator = protocol - .accumulator_indices - .as_ref() - .map(|accumulator_indices| accumulator_indices.len()) - .unwrap_or_default(); - - let num_statement = protocol.num_statement.iter().sum(); - let num_commitment = protocol.num_auxiliary.iter().sum::() + num_quotient + num_w; - let num_evaluation = protocol.evaluations.len(); - let num_msm = - protocol.preprocessed.len() + num_commitment + 1 + num_w + 2 * num_accumulator; - - Cost::new(num_statement, num_commitment, num_evaluation, num_msm) - } -} diff --git a/src/scheme/kzg/accumulation/shplonk.rs b/src/scheme/kzg/accumulation/shplonk.rs deleted file mode 100644 index e9c31aa8..00000000 --- a/src/scheme/kzg/accumulation/shplonk.rs +++ /dev/null @@ -1,593 +0,0 @@ -use crate::{ - loader::{LoadedScalar, Loader}, - protocol::Protocol, - scheme::kzg::{ - accumulation::{AccumulationScheme, AccumulationStrategy, Accumulator}, - cost::{Cost, CostEstimation}, - langranges, - msm::MSM, - }, - util::{ - CommonPolynomial, CommonPolynomialEvaluation, Curve, Domain, Expression, Field, Fraction, - Itertools, Query, Rotation, TranscriptRead, - }, - Error, -}; -use std::{ - collections::{BTreeSet, HashMap}, - iter, -}; - -#[derive(Default)] -pub struct ShplonkAccumulationScheme; - -impl AccumulationScheme for ShplonkAccumulationScheme -where - C: Curve, - L: Loader, - T: TranscriptRead, - S: AccumulationStrategy>, -{ - type Proof = ShplonkProof; - - fn accumulate( - protocol: &Protocol, - loader: &L, - statements: Vec>, - transcript: &mut T, - strategy: &mut S, - ) -> Result { - transcript.common_scalar(&loader.load_const(&protocol.transcript_initial_state))?; - - let proof = ShplonkProof::read(protocol, statements, transcript)?; - let old_accumulator = - strategy.extract_accumulator(protocol, loader, transcript, &proof.statements); - - let (common_poly_eval, sets) = { - let mut common_poly_eval = CommonPolynomialEvaluation::new( - &protocol.domain, - loader, - langranges(protocol, &proof.statements), - &proof.z, - ); - let mut sets = intermediate_sets(protocol, loader, &proof.z, &proof.z_prime); - - L::LoadedScalar::batch_invert( - iter::empty() - .chain(common_poly_eval.denoms()) - .chain(sets.iter_mut().flat_map(IntermediateSet::denoms)), - ); - L::LoadedScalar::batch_invert(sets.iter_mut().flat_map(IntermediateSet::denoms)); - - (common_poly_eval, sets) - }; - - let commitments = proof.commitments(protocol, loader, &common_poly_eval); - let evaluations = proof.evaluations(protocol, loader, &common_poly_eval)?; - - let f = { - let powers_of_mu = proof - .mu - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); - let msms = sets - .iter() - .map(|set| set.msm(&commitments, &evaluations, &powers_of_mu)); - - msms.zip(proof.gamma.powers(sets.len()).into_iter()) - .map(|(msm, power_of_gamma)| msm * &power_of_gamma) - .sum::>() - - MSM::base(proof.w.clone()) * &sets[0].z_s - }; - - let rhs = MSM::base(proof.w_prime.clone()); - let lhs = f + rhs.clone() * &proof.z_prime; - - let mut 
accumulator = Accumulator::new(lhs, rhs); - if let Some(old_accumulator) = old_accumulator { - accumulator += old_accumulator; - } - strategy.process(loader, transcript, proof, accumulator) - } -} - -pub struct ShplonkProof> { - statements: Vec>, - auxiliaries: Vec, - challenges: Vec, - alpha: L::LoadedScalar, - quotients: Vec, - z: L::LoadedScalar, - evaluations: Vec, - mu: L::LoadedScalar, - gamma: L::LoadedScalar, - w: L::LoadedEcPoint, - z_prime: L::LoadedScalar, - w_prime: L::LoadedEcPoint, -} - -impl> ShplonkProof { - fn read>( - protocol: &Protocol, - statements: Vec>, - transcript: &mut T, - ) -> Result { - if protocol.num_statement - != statements - .iter() - .map(|statements| statements.len()) - .collect_vec() - { - return Err(Error::InvalidInstances); - } - for statements in statements.iter() { - for statement in statements.iter() { - transcript.common_scalar(statement)?; - } - } - - let (auxiliaries, challenges) = { - let (auxiliaries, challenges) = protocol - .num_auxiliary - .iter() - .zip(protocol.num_challenge.iter()) - .map(|(&n, &m)| { - Ok(( - transcript.read_n_ec_points(n)?, - transcript.squeeze_n_challenges(m), - )) - }) - .collect::, Error>>()? - .into_iter() - .unzip::<_, _, Vec<_>, Vec<_>>(); - - ( - auxiliaries.into_iter().flatten().collect_vec(), - challenges.into_iter().flatten().collect_vec(), - ) - }; - - let alpha = transcript.squeeze_challenge(); - let quotients = { - let max_degree = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap(); - transcript.read_n_ec_points(max_degree - 1)? - }; - - let z = transcript.squeeze_challenge(); - let evaluations = transcript.read_n_scalars(protocol.evaluations.len())?; - - let mu = transcript.squeeze_challenge(); - let gamma = transcript.squeeze_challenge(); - let w = transcript.read_ec_point()?; - let z_prime = transcript.squeeze_challenge(); - let w_prime = transcript.read_ec_point()?; - - Ok(Self { - statements, - auxiliaries, - challenges, - alpha, - quotients, - z, - evaluations, - mu, - gamma, - w, - z_prime, - w_prime, - }) - } - - fn commitments( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> HashMap> { - iter::empty() - .chain( - protocol - .preprocessed - .iter() - .map(|value| MSM::base(loader.ec_point_load_const(value))) - .enumerate(), - ) - .chain({ - let auxiliary_offset = protocol.preprocessed.len() + protocol.num_statement.len(); - self.auxiliaries - .iter() - .cloned() - .enumerate() - .map(move |(i, auxiliary)| (auxiliary_offset + i, MSM::base(auxiliary))) - }) - .chain(iter::once(( - protocol.vanishing_poly(), - common_poly_eval - .zn() - .powers(self.quotients.len()) - .into_iter() - .zip(self.quotients.iter().cloned().map(MSM::base)) - .map(|(coeff, piece)| piece * &coeff) - .sum(), - ))) - .collect() - } - - fn evaluations( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> Result, Error> { - let statement_evaluations = self.statements.iter().map(|statements| { - L::LoadedScalar::sum( - &statements - .iter() - .enumerate() - .map(|(i, statement)| { - statement.clone() - * common_poly_eval.get(CommonPolynomial::Lagrange(i as i32)) - }) - .collect_vec(), - ) - }); - let mut evaluations = HashMap::::from_iter( - iter::empty() - .chain( - statement_evaluations - .into_iter() - .enumerate() - .map(|(i, evaluation)| { - ( - Query { - poly: protocol.preprocessed.len() + i, - rotation: Rotation::cur(), - }, - evaluation, - ) - }), - ) - .chain( - protocol - .evaluations - 
.iter() - .cloned() - .zip(self.evaluations.iter().cloned()), - ), - ); - - let powers_of_alpha = self.alpha.powers(protocol.relations.len()); - let quotient_evaluation = L::LoadedScalar::sum( - &powers_of_alpha - .into_iter() - .rev() - .zip(protocol.relations.iter()) - .map(|(power_of_alpha, relation)| { - relation - .evaluate( - &|scalar| Ok(loader.load_const(&scalar)), - &|poly| Ok(common_poly_eval.get(poly)), - &|index| { - evaluations - .get(&index) - .cloned() - .ok_or(Error::MissingQuery(index)) - }, - &|index| { - self.challenges - .get(index) - .cloned() - .ok_or(Error::MissingChallenge(index)) - }, - &|a| a.map(|a| -a), - &|a, b| a.and_then(|a| Ok(a + b?)), - &|a, b| a.and_then(|a| Ok(a * b?)), - &|a, scalar| a.map(|a| a * loader.load_const(&scalar)), - ) - .map(|evaluation| power_of_alpha * evaluation) - }) - .collect::, Error>>()?, - ) * &common_poly_eval.zn_minus_one_inv(); - - evaluations.insert( - Query { - poly: protocol.vanishing_poly(), - rotation: Rotation::cur(), - }, - quotient_evaluation, - ); - - Ok(evaluations) - } -} - -struct IntermediateSet> { - rotations: Vec, - polys: Vec, - z_s: L::LoadedScalar, - evaluation_coeffs: Vec>, - commitment_coeff: Option>, - remainder_coeff: Option>, -} - -impl> IntermediateSet { - fn new( - domain: &Domain, - loader: &L, - rotations: Vec, - powers_of_z: &[L::LoadedScalar], - z_prime: &L::LoadedScalar, - z_prime_minus_z_omega_i: &HashMap, - z_s_1: &Option, - ) -> Self { - let omegas = rotations - .iter() - .map(|rotation| domain.rotate_scalar(C::Scalar::one(), *rotation)) - .collect_vec(); - - let normalized_ell_primes = omegas - .iter() - .enumerate() - .map(|(j, omega_j)| { - omegas - .iter() - .enumerate() - .filter(|&(i, _)| i != j) - .fold(C::Scalar::one(), |acc, (_, omega_i)| { - acc * (*omega_j - omega_i) - }) - }) - .collect_vec(); - - let z = &powers_of_z[1].clone(); - let z_pow_k_minus_one = { - let k_minus_one = rotations.len() - 1; - powers_of_z.iter().enumerate().skip(1).fold( - loader.load_one(), - |acc, (i, power_of_z)| { - if k_minus_one & (1 << i) == 1 { - acc * power_of_z - } else { - acc - } - }, - ) - }; - - let barycentric_weights = omegas - .iter() - .zip(normalized_ell_primes.iter()) - .map(|(omega, normalized_ell_prime)| { - L::LoadedScalar::sum_products_with_coeff_and_constant( - &[ - ( - *normalized_ell_prime, - z_pow_k_minus_one.clone(), - z_prime.clone(), - ), - ( - -(*normalized_ell_prime * omega), - z_pow_k_minus_one.clone(), - z.clone(), - ), - ], - &C::Scalar::zero(), - ) - }) - .map(Fraction::one_over) - .collect_vec(); - - let z_s = rotations - .iter() - .map(|rotation| z_prime_minus_z_omega_i.get(rotation).unwrap().clone()) - .reduce(|acc, z_prime_minus_z_omega_i| acc * z_prime_minus_z_omega_i) - .unwrap(); - let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); - - Self { - rotations, - polys: Vec::new(), - z_s, - evaluation_coeffs: barycentric_weights, - commitment_coeff: z_s_1_over_z_s, - remainder_coeff: None, - } - } - - fn denoms(&mut self) -> impl IntoIterator { - if self.evaluation_coeffs.first().unwrap().denom().is_some() { - self.evaluation_coeffs - .iter_mut() - .chain(self.commitment_coeff.as_mut()) - .filter_map(Fraction::denom_mut) - .collect_vec() - } else if self.remainder_coeff.is_none() { - let barycentric_weights_sum = L::LoadedScalar::sum( - &self - .evaluation_coeffs - .iter() - .map(Fraction::evaluate) - .collect_vec(), - ); - self.remainder_coeff = Some(match self.commitment_coeff.clone() { - Some(coeff) => Fraction::new(coeff.evaluate(), 
barycentric_weights_sum), - None => Fraction::one_over(barycentric_weights_sum), - }); - vec![self.remainder_coeff.as_mut().unwrap().denom_mut().unwrap()] - } else { - unreachable!() - } - } - - fn msm( - &self, - commitments: &HashMap>, - evaluations: &HashMap, - powers_of_mu: &[L::LoadedScalar], - ) -> MSM { - self.polys - .iter() - .zip(powers_of_mu.iter()) - .map(|(poly, power_of_mu)| { - let commitment = self - .commitment_coeff - .as_ref() - .map(|commitment_coeff| { - commitments.get(poly).unwrap().clone() * &commitment_coeff.evaluate() - }) - .unwrap_or_else(|| commitments.get(poly).unwrap().clone()); - let remainder = self.remainder_coeff.as_ref().unwrap().evaluate() - * L::LoadedScalar::sum( - &self - .rotations - .iter() - .zip(self.evaluation_coeffs.iter()) - .map(|(rotation, coeff)| { - coeff.evaluate() - * evaluations - .get(&Query { - poly: *poly, - rotation: *rotation, - }) - .unwrap() - }) - .collect_vec(), - ); - (commitment - MSM::scalar(remainder)) * power_of_mu - }) - .sum() - } -} - -fn intermediate_sets>( - protocol: &Protocol, - loader: &L, - z: &L::LoadedScalar, - z_prime: &L::LoadedScalar, -) -> Vec> { - let rotations_sets = rotations_sets(protocol); - let superset = rotations_sets - .iter() - .flat_map(|set| set.rotations.clone()) - .sorted() - .dedup(); - - let size = 2.max( - (rotations_sets - .iter() - .map(|set| set.rotations.len()) - .max() - .unwrap() - - 1) - .next_power_of_two() - .log2() as usize - + 1, - ); - let powers_of_z = z.powers(size); - let z_prime_minus_z_omega_i = HashMap::from_iter( - superset - .map(|rotation| { - ( - rotation, - loader.load_const(&protocol.domain.rotate_scalar(C::Scalar::one(), rotation)), - ) - }) - .map(|(rotation, omega)| (rotation, z_prime.clone() - z.clone() * omega)), - ); - - let mut z_s_1 = None; - rotations_sets - .into_iter() - .map(|set| { - let intermetidate_set = IntermediateSet { - polys: set.polys, - ..IntermediateSet::new( - &protocol.domain, - loader, - set.rotations, - &powers_of_z, - z_prime, - &z_prime_minus_z_omega_i, - &z_s_1, - ) - }; - if z_s_1.is_none() { - z_s_1 = Some(intermetidate_set.z_s.clone()); - }; - intermetidate_set - }) - .collect() -} - -struct RotationsSet { - rotations: Vec, - polys: Vec, -} - -fn rotations_sets(protocol: &Protocol) -> Vec { - let poly_rotations = protocol.queries.iter().fold( - Vec::<(usize, Vec)>::new(), - |mut poly_rotations, query| { - if let Some(pos) = poly_rotations - .iter() - .position(|(poly, _)| *poly == query.poly) - { - let (_, rotations) = &mut poly_rotations[pos]; - if !rotations.contains(&query.rotation) { - rotations.push(query.rotation); - } - } else { - poly_rotations.push((query.poly, vec![query.rotation])); - } - poly_rotations - }, - ); - - poly_rotations - .into_iter() - .fold(Vec::::new(), |mut sets, (poly, rotations)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.rotations.iter()) == BTreeSet::from_iter(rotations.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - } - } else { - let set = RotationsSet { - rotations, - polys: vec![poly], - }; - sets.push(set); - } - sets - }) -} - -impl CostEstimation for ShplonkAccumulationScheme { - fn estimate_cost(protocol: &Protocol) -> Cost { - let num_quotient = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap() - - 1; - let num_accumulator = protocol - .accumulator_indices - .as_ref() - .map(|accumulator_indices| accumulator_indices.len()) - .unwrap_or_default(); - - let num_statement = 
protocol.num_statement.iter().sum(); - let num_commitment = protocol.num_auxiliary.iter().sum::() + num_quotient + 2; - let num_evaluation = protocol.evaluations.len(); - let num_msm = protocol.preprocessed.len() + num_commitment + 3 + 2 * num_accumulator; - - Cost::new(num_statement, num_commitment, num_evaluation, num_msm) - } -} diff --git a/src/scheme/kzg/cost.rs b/src/scheme/kzg/cost.rs deleted file mode 100644 index f83f7dd1..00000000 --- a/src/scheme/kzg/cost.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::{protocol::Protocol, util::Curve}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Cost { - pub num_statement: usize, - pub num_commitment: usize, - pub num_evaluation: usize, - pub num_msm: usize, -} - -impl Cost { - pub fn new( - num_statement: usize, - num_commitment: usize, - num_evaluation: usize, - num_msm: usize, - ) -> Self { - Self { - num_statement, - num_commitment, - num_evaluation, - num_msm, - } - } -} - -pub trait CostEstimation { - fn estimate_cost(protocol: &Protocol) -> Cost; -} diff --git a/src/scheme/kzg/msm.rs b/src/scheme/kzg/msm.rs deleted file mode 100644 index db5e4ce2..00000000 --- a/src/scheme/kzg/msm.rs +++ /dev/null @@ -1,149 +0,0 @@ -use crate::{ - loader::{LoadedEcPoint, Loader}, - util::Curve, -}; -use std::{ - default::Default, - iter::{self, Sum}, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, -}; - -#[derive(Clone, Debug)] -pub struct MSM> { - pub scalar: Option, - bases: Vec, - scalars: Vec, -} - -impl> Default for MSM { - fn default() -> Self { - Self { - scalar: None, - scalars: Vec::new(), - bases: Vec::new(), - } - } -} - -impl> MSM { - pub fn scalar(scalar: L::LoadedScalar) -> Self { - MSM { - scalar: Some(scalar), - ..Default::default() - } - } - - pub fn base(base: L::LoadedEcPoint) -> Self { - let one = base.loader().load_one(); - MSM { - scalars: vec![one], - bases: vec![base], - ..Default::default() - } - } - - pub fn evaluate(self, gen: C) -> L::LoadedEcPoint { - let gen = self - .bases - .first() - .unwrap() - .loader() - .ec_point_load_const(&gen); - L::LoadedEcPoint::multi_scalar_multiplication( - iter::empty() - .chain(self.scalar.map(|scalar| (scalar, gen))) - .chain(self.scalars.into_iter().zip(self.bases.into_iter())), - ) - } - - pub fn scale(&mut self, factor: &L::LoadedScalar) { - if let Some(scalar) = self.scalar.as_mut() { - *scalar *= factor; - } - for scalar in self.scalars.iter_mut() { - *scalar *= factor - } - } - - pub fn push(&mut self, scalar: L::LoadedScalar, base: L::LoadedEcPoint) { - if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { - self.scalars[pos] += scalar; - } else { - self.scalars.push(scalar); - self.bases.push(base); - } - } - - pub fn extend(&mut self, mut other: Self) { - match (self.scalar.as_mut(), other.scalar.as_ref()) { - (Some(lhs), Some(rhs)) => *lhs += rhs, - (None, Some(_)) => self.scalar = other.scalar.take(), - _ => {} - }; - for (scalar, base) in other.scalars.into_iter().zip(other.bases) { - self.push(scalar, base); - } - } -} - -impl> Add> for MSM { - type Output = MSM; - - fn add(mut self, rhs: MSM) -> Self::Output { - self.extend(rhs); - self - } -} - -impl> AddAssign> for MSM { - fn add_assign(&mut self, rhs: MSM) { - self.extend(rhs); - } -} - -impl> Sub> for MSM { - type Output = MSM; - - fn sub(mut self, rhs: MSM) -> Self::Output { - self.extend(-rhs); - self - } -} - -impl> SubAssign> for MSM { - fn sub_assign(&mut self, rhs: MSM) { - self.extend(-rhs); - } -} - -impl> Mul<&L::LoadedScalar> for MSM { - type Output = MSM; - - fn mul(mut 
self, rhs: &L::LoadedScalar) -> Self::Output { - self.scale(rhs); - self - } -} - -impl> MulAssign<&L::LoadedScalar> for MSM { - fn mul_assign(&mut self, rhs: &L::LoadedScalar) { - self.scale(rhs); - } -} - -impl> Neg for MSM { - type Output = MSM; - fn neg(mut self) -> MSM { - self.scalar = self.scalar.map(|scalar| -scalar); - for scalar in self.scalars.iter_mut() { - *scalar = -scalar.clone(); - } - self - } -} - -impl> Sum for MSM { - fn sum>(iter: I) -> Self { - iter.reduce(|acc, item| acc + item).unwrap_or_default() - } -} diff --git a/src/system.rs b/src/system.rs new file mode 100644 index 00000000..5d5aa99c --- /dev/null +++ b/src/system.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "system_halo2")] +pub mod halo2; diff --git a/src/protocol/halo2.rs b/src/system/halo2.rs similarity index 76% rename from src/protocol/halo2.rs rename to src/system/halo2.rs index 6c30bf5d..bf3e4091 100644 --- a/src/protocol/halo2.rs +++ b/src/system/halo2.rs @@ -1,40 +1,99 @@ use crate::{ - protocol::Protocol, - util::{CommonPolynomial, Domain, Expression, Itertools, Query, Rotation}, + util::{ + arithmetic::{root_of_unity, CurveAffine, Domain, FieldExt, Rotation}, + protocol::{ + CommonPolynomial, Expression, InstanceCommittingKey, Query, QuotientPolynomial, + }, + Itertools, + }, + Protocol, }; use halo2_proofs::{ - arithmetic::{CurveAffine, CurveExt, FieldExt}, plonk::{self, Any, ConstraintSystem, FirstPhase, SecondPhase, ThirdPhase, VerifyingKey}, - poly, + poly::{self, commitment::Params}, transcript::{EncodedChallenge, Transcript}, }; -use std::{io, iter}; +use num_integer::Integer; +use std::{io, iter, mem::size_of}; -mod util; +pub mod transcript; #[cfg(test)] -mod test; +pub(crate) mod test; +#[derive(Clone, Debug, Default)] pub struct Config { - zk: bool, - query_instance: bool, - num_instance: Vec, - num_proof: usize, - accumulator_indices: Option>, + pub zk: bool, + pub query_instance: bool, + pub num_proof: usize, + pub num_instance: Vec, + pub accumulator_indices: Option>, } -pub fn compile(vk: &VerifyingKey, config: Config) -> Protocol { +impl Config { + pub fn kzg() -> Self { + Self { + zk: true, + query_instance: false, + num_proof: 1, + ..Default::default() + } + } + + pub fn ipa() -> Self { + Self { + zk: true, + query_instance: true, + num_proof: 1, + ..Default::default() + } + } + + pub fn set_zk(mut self, zk: bool) -> Self { + self.zk = zk; + self + } + + pub fn set_query_instance(mut self, query_instance: bool) -> Self { + self.query_instance = query_instance; + self + } + + pub fn with_num_proof(mut self, num_proof: usize) -> Self { + assert!(num_proof > 0); + self.num_proof = num_proof; + self + } + + pub fn with_num_instance(mut self, num_instance: Vec) -> Self { + self.num_instance = num_instance; + self + } + + pub fn with_accumulator_indices(mut self, accumulator_indices: Vec<(usize, usize)>) -> Self { + self.accumulator_indices = Some(accumulator_indices); + self + } +} + +pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( + params: &P, + vk: &VerifyingKey, + config: Config, +) -> Protocol { + assert_eq!(vk.get_domain().k(), params.k()); + let cs = vk.cs(); let Config { zk, - num_instance, query_instance, num_proof, + num_instance, accumulator_indices, } = config; - let k = vk.get_domain().empty_lagrange().len().log2(); - let domain = Domain::new(k as usize); + let k = params.k() as usize; + let domain = Domain::new(k, root_of_unity(k)); let preprocessed = vk .fixed_commitments() @@ -66,35 +125,39 @@ pub fn compile(vk: &VerifyingKey, config: Config) -> }) 
.chain(polynomials.fixed_queries()) .chain(polynomials.permutation_fixed_queries()) - .chain(iter::once(polynomials.vanishing_query())) + .chain(iter::once(polynomials.quotient_query())) .chain(polynomials.random_query()) .collect(); - let relations = (0..num_proof) - .flat_map(|t| { - iter::empty() - .chain(polynomials.gate_relations(t)) - .chain(polynomials.permutation_relations(t)) - .chain(polynomials.lookup_relations(t)) - }) - .collect(); - let transcript_initial_state = transcript_initial_state::(vk); + let instance_committing_key = query_instance.then(|| { + instance_committing_key( + params, + polynomials + .num_instance() + .into_iter() + .max() + .unwrap_or_default(), + ) + }); + let accumulator_indices = accumulator_indices - .map(|accumulator_indices| polynomials.accumulator_indices(accumulator_indices)); + .map(|accumulator_indices| polynomials.accumulator_indices(accumulator_indices)) + .unwrap_or_default(); Protocol { - zk: config.zk, domain, preprocessed, - num_statement: polynomials.num_statement(), - num_auxiliary: polynomials.num_auxiliary(), + num_instance: polynomials.num_instance(), + num_witness: polynomials.num_witness(), num_challenge: polynomials.num_challenge(), evaluations, queries, - relations, - transcript_initial_state, + quotient: polynomials.quotient(), + transcript_initial_state: Some(transcript_initial_state), + instance_committing_key, + linearization: None, accumulator_indices, } } @@ -131,11 +194,8 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { num_instance: Vec, num_proof: usize, ) -> Self { - let degree = if zk { - cs.degree::() - } else { - cs.degree::() - }; + // TODO: Re-enable optional-zk when it's merged in pse/halo2. + let degree = if zk { cs.degree() } else { unimplemented!() }; let permutation_chunk_size = if zk || cs.permutation().get_columns().len() >= degree { degree - 2 } else { @@ -155,7 +215,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { state[*phase as usize] += 1; Some(index) }) - .collect_vec(); + .collect::>(); (num, index) }; @@ -164,8 +224,6 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { assert_eq!(num_advice.iter().sum::(), cs.num_advice_columns()); assert_eq!(num_challenge.iter().sum::(), cs.num_challenges()); - assert_eq!(cs.num_instance_columns(), num_instance.len()); - Self { cs, zk, @@ -180,11 +238,10 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { challenge_index, num_lookup_permuted: 2 * cs.lookups().len(), permutation_chunk_size, - num_permutation_z: cs - .permutation() - .get_columns() - .len() - .div_ceil(permutation_chunk_size), + num_permutation_z: Integer::div_ceil( + &cs.permutation().get_columns().len(), + &permutation_chunk_size, + ), num_lookup_z: cs.lookups().len(), } } @@ -193,14 +250,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { self.num_fixed + self.num_permutation_fixed } - fn num_statement(&self) -> Vec { + fn num_instance(&self) -> Vec { iter::repeat(self.num_instance.clone()) .take(self.num_proof) .flatten() .collect() } - fn num_auxiliary(&self) -> Vec { + fn num_witness(&self) -> Vec { iter::empty() .chain( self.num_advice @@ -222,7 +279,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .chain(num_challenge) .chain([ 2, // beta, gamma - 0, + 1, // alpha ]) .collect() } @@ -231,14 +288,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { self.num_preprocessed() } - fn auxiliary_offset(&self) -> usize { - self.instance_offset() + self.num_statement().len() + fn witness_offset(&self) -> usize { + self.instance_offset() + self.num_instance().len() } - fn cs_auxiliary_offset(&self) -> usize { - 
self.auxiliary_offset() + fn cs_witness_offset(&self) -> usize { + self.witness_offset() + self - .num_auxiliary() + .num_witness() .iter() .take(self.num_advice.len()) .sum::() @@ -260,9 +317,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { * self.num_advice[..advice.phase() as usize] .iter() .sum::(); - self.auxiliary_offset() - + phase_offset - + t * self.num_advice[advice.phase() as usize] + self.witness_offset() + phase_offset + t * self.num_advice[advice.phase() as usize] } }; Query::new(offset + column_index, rotation.into()) @@ -270,14 +325,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { fn instance_queries(&'a self, t: usize) -> impl IntoIterator + 'a { self.query_instance - .then_some( + .then(|| { self.cs .instance_queries() .iter() .map(move |(column, rotation)| { self.query(*column.column_type(), column.index(), *rotation, t) - }), - ) + }) + }) .into_iter() .flatten() } @@ -305,7 +360,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn permutation_poly(&'a self, t: usize, i: usize) -> usize { - let z_offset = self.cs_auxiliary_offset() + self.num_auxiliary()[self.num_advice.len()]; + let z_offset = self.cs_witness_offset() + self.num_witness()[self.num_advice.len()]; z_offset + t * self.num_permutation_z + i } @@ -346,9 +401,9 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn lookup_poly(&'a self, t: usize, i: usize) -> (usize, usize, usize) { - let permuted_offset = self.cs_auxiliary_offset(); + let permuted_offset = self.cs_witness_offset(); let z_offset = permuted_offset - + self.num_auxiliary()[self.num_advice.len()] + + self.num_witness()[self.num_advice.len()] + self.num_proof * self.num_permutation_z; let z = z_offset + t * self.num_lookup_z + i; let permuted_input = permuted_offset + 2 * (t * self.num_lookup_z + i); @@ -382,18 +437,20 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { }) } - fn vanishing_query(&self) -> Query { + fn quotient_query(&self) -> Query { Query::new( - self.auxiliary_offset() + self.num_auxiliary().iter().sum::(), + self.witness_offset() + self.num_witness().iter().sum::(), 0, ) } fn random_query(&self) -> Option { - self.zk.then_some(Query::new( - self.auxiliary_offset() + self.num_auxiliary().iter().sum::() - 1, - 0, - )) + self.zk.then(|| { + Query::new( + self.witness_offset() + self.num_witness().iter().sum::() - 1, + 0, + ) + }) } fn convert(&self, expression: &plonk::Expression, t: usize) -> Expression { @@ -435,7 +492,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { ) } - fn gate_relations(&'a self, t: usize) -> impl IntoIterator> + 'a { + fn gate_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { self.cs.gates().iter().flat_map(move |gate| { gate.polynomials() .iter() @@ -444,7 +501,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn rotation_last(&self) -> Rotation { - Rotation(-((self.cs.blinding_factors::() + 1) as i32)) + Rotation(-((self.cs.blinding_factors() + 1) as i32)) } fn l_last(&self) -> Expression { @@ -483,7 +540,11 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { Expression::Challenge(self.system_challenge_offset() + 2) } - fn permutation_relations(&'a self, t: usize) -> impl IntoIterator> + 'a { + fn alpha(&self) -> Expression { + Expression::Challenge(self.system_challenge_offset() + 3) + } + + fn permutation_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { let one = &Expression::Constant(F::one()); let l_0 = &Expression::::CommonPolynomial(CommonPolynomial::Lagrange(0)); let l_last = &self.l_last(); @@ -519,7 +580,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .chain(zs.first().map(|(z_0, _, _)| 
l_0 * (one - z_0))) .chain( zs.last() - .and_then(|(z_l, _, _)| self.zk.then_some(l_last * (z_l * z_l - z_l))), + .and_then(|(z_l, _, _)| self.zk.then(|| l_last * (z_l * z_l - z_l))), ) .chain(if self.zk { zs.iter() @@ -575,12 +636,11 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .collect_vec() } - fn lookup_relations(&'a self, t: usize) -> impl IntoIterator> + 'a { + fn lookup_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { let one = &Expression::Constant(F::one()); let l_0 = &Expression::::CommonPolynomial(CommonPolynomial::Lagrange(0)); let l_last = &self.l_last(); let l_active = &self.l_active(); - let theta = &self.theta(); let beta = &self.beta(); let gamma = &self.gamma(); @@ -598,15 +658,13 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .collect_vec(); let compress = |expressions: &'a [plonk::Expression]| { - expressions - .iter() - .rev() - .zip(iter::successors(Some(one.clone()), |power_of_theta| { - Some(power_of_theta * theta) - })) - .map(|(expression, power_of_theta)| power_of_theta * self.convert(expression, t)) - .reduce(|acc, expr| acc + expr) - .unwrap() + Expression::DistributePowers( + expressions + .iter() + .map(|expression| self.convert(expression, t)) + .collect(), + self.theta().into(), + ) }; self.cs @@ -619,7 +677,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { let table = compress(lookup.table_expressions()); iter::empty() .chain(Some(l_0 * (one - z))) - .chain(self.zk.then_some(l_last * (z * z - z))) + .chain(self.zk.then(|| l_last * (z * z - z))) .chain(Some(if self.zk { l_active * (z_w * (permuted_input + beta) * (permuted_table + gamma) @@ -628,7 +686,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { z_w * (permuted_input + beta) * (permuted_table + gamma) - z * (input + beta) * (table + gamma) })) - .chain(self.zk.then_some(l_0 * (permuted_input - permuted_table))) + .chain(self.zk.then(|| l_0 * (permuted_input - permuted_table))) .chain(Some(if self.zk { l_active * (permuted_input - permuted_table) @@ -642,6 +700,22 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .collect_vec() } + fn quotient(&self) -> QuotientPolynomial { + let constraints = (0..self.num_proof) + .flat_map(|t| { + iter::empty() + .chain(self.gate_constraints(t)) + .chain(self.permutation_constraints(t)) + .chain(self.lookup_constraints(t)) + }) + .collect_vec(); + let numerator = Expression::DistributePowers(constraints, self.alpha().into()); + QuotientPolynomial { + chunk_degree: 1, + numerator, + } + } + fn accumulator_indices( &self, accumulator_indices: Vec<(usize, usize)>, @@ -690,8 +764,47 @@ impl Transcript for MockTranscript } } -fn transcript_initial_state(vk: &VerifyingKey) -> C::ScalarExt { +fn transcript_initial_state(vk: &VerifyingKey) -> C::Scalar { let mut transcript = MockTranscript::default(); vk.hash_into(&mut transcript).unwrap(); transcript.0 } + +fn instance_committing_key<'a, C: CurveAffine, P: Params<'a, C>>( + params: &P, + len: usize, +) -> InstanceCommittingKey { + let buf = { + let mut buf = Vec::new(); + params.write(&mut buf).unwrap(); + buf + }; + + let repr = C::Repr::default(); + let repr_len = repr.as_ref().len(); + let offset = size_of::() + (1 << params.k()) * repr_len; + + let bases = (offset..) 
+ .step_by(repr_len) + .map(|offset| { + let mut repr = C::Repr::default(); + repr.as_mut() + .copy_from_slice(&buf[offset..offset + repr_len]); + C::from_bytes(&repr).unwrap() + }) + .take(len) + .collect(); + + let w = { + let offset = size_of::() + (2 << params.k()) * repr_len; + let mut repr = C::Repr::default(); + repr.as_mut() + .copy_from_slice(&buf[offset..offset + repr_len]); + C::from_bytes(&repr).unwrap() + }; + + InstanceCommittingKey { + bases, + constant: Some(w), + } +} diff --git a/src/system/halo2/test.rs b/src/system/halo2/test.rs new file mode 100644 index 00000000..9cd4a2fc --- /dev/null +++ b/src/system/halo2/test.rs @@ -0,0 +1,221 @@ +use crate::util::arithmetic::CurveAffine; +use halo2_proofs::{ + dev::MockProver, + plonk::{create_proof, verify_proof, Circuit, ProvingKey}, + poly::{ + commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier}, + VerificationStrategy, + }, + transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use rand_chacha::rand_core::RngCore; +use std::{fs, io::Cursor}; + +mod circuit; +mod kzg; + +pub use circuit::{ + maingate::{MainGateWithRange, MainGateWithRangeConfig}, + standard::StandardPlonk, +}; + +pub fn read_or_create_srs<'a, C: CurveAffine, P: ParamsProver<'a, C>>( + dir: &str, + k: u32, + setup: impl Fn(u32) -> P, +) -> P { + let path = format!("{}/k-{}.srs", dir, k); + match fs::File::open(path.as_str()) { + Ok(mut file) => P::read(&mut file).unwrap(), + Err(_) => { + fs::create_dir_all(dir).unwrap(); + let params = setup(k); + params.write(&mut fs::File::create(path).unwrap()).unwrap(); + params + } + } +} + +pub fn create_proof_checked<'a, S, C, P, V, VS, TW, TR, EC, R>( + params: &'a S::ParamsProver, + pk: &ProvingKey, + circuits: &[C], + instances: &[&[&[S::Scalar]]], + mut rng: R, + finalize: impl Fn(Vec, VS::Output) -> Vec, +) -> Vec +where + S: CommitmentScheme, + S::ParamsVerifier: 'a, + C: Circuit, + P: Prover<'a, S>, + V: Verifier<'a, S>, + VS: VerificationStrategy<'a, S, V>, + TW: TranscriptWriterBuffer, S::Curve, EC>, + TR: TranscriptReadBuffer>, S::Curve, EC>, + EC: EncodedChallenge, + R: RngCore, +{ + for (circuit, instances) in circuits.iter().zip(instances.iter()) { + MockProver::run( + params.k(), + circuit, + instances.iter().map(|instance| instance.to_vec()).collect(), + ) + .unwrap() + .assert_satisfied(); + } + + let proof = { + let mut transcript = TW::init(Vec::new()); + create_proof::( + params, + pk, + circuits, + instances, + &mut rng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let output = { + let params = params.verifier_params(); + let strategy = VS::new(params); + let mut transcript = TR::init(Cursor::new(proof.clone())); + verify_proof(params, pk.get_vk(), strategy, instances, &mut transcript).unwrap() + }; + + finalize(proof, output) +} + +macro_rules! halo2_prepare { + ($dir:expr, $k:expr, $setup:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_proofs::plonk::{keygen_pk, keygen_vk}; + use std::iter; + use $crate::{ + system::halo2::{compile, test::read_or_create_srs}, + util::{arithmetic::GroupEncoding, Itertools}, + }; + + let params = read_or_create_srs($dir, $k, $setup); + + let circuits = iter::repeat_with(|| $create_circuit) + .take($config.num_proof) + .collect_vec(); + + let pk = if $config.zk { + let vk = keygen_vk(¶ms, &circuits[0]).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuits[0]).unwrap(); + pk + } else { + // TODO: Re-enable optional-zk when it's merged in pse/halo2. 
+ unimplemented!() + }; + + let num_instance = circuits[0] + .instances() + .iter() + .map(|instances| instances.len()) + .collect(); + let protocol = compile( + ¶ms, + pk.get_vk(), + $config.with_num_instance(num_instance), + ); + assert_eq!( + protocol.preprocessed.len(), + protocol + .preprocessed + .iter() + .map( + |ec_point| <[u8; 32]>::try_from(ec_point.to_bytes().as_ref().to_vec()).unwrap() + ) + .unique() + .count() + ); + + (params, pk, protocol, circuits) + }}; +} + +macro_rules! halo2_create_snark { + ( + $commitment_scheme:ty, + $prover:ty, + $verifier:ty, + $verification_strategy:ty, + $transcript_read:ty, + $transcript_write:ty, + $encoded_challenge:ty, + $finalize:expr, + $params:expr, + $pk:expr, + $protocol:expr, + $circuits:expr + ) => {{ + use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + use $crate::{ + loader::halo2::test::Snark, system::halo2::test::create_proof_checked, util::Itertools, + }; + + let instances = $circuits + .iter() + .map(|circuit| circuit.instances()) + .collect_vec(); + let proof = { + #[allow(clippy::needless_borrow)] + let instances = instances + .iter() + .map(|instances| instances.iter().map(Vec::as_slice).collect_vec()) + .collect_vec(); + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + create_proof_checked::< + $commitment_scheme, + _, + $prover, + $verifier, + $verification_strategy, + $transcript_read, + $transcript_write, + $encoded_challenge, + _, + >( + $params, + $pk, + $circuits, + &instances, + &mut ChaCha20Rng::from_seed(Default::default()), + $finalize, + ) + }; + + Snark::new( + $protocol.clone(), + instances.into_iter().flatten().collect_vec(), + proof, + ) + }}; +} + +macro_rules! halo2_native_verify { + ( + $plonk_verifier:ty, + $params:expr, + $protocol:expr, + $instances:expr, + $transcript:expr, + $svk:expr, + $dk:expr + ) => {{ + use halo2_proofs::poly::commitment::ParamsProver; + use $crate::verifier::PlonkVerifier; + + let proof = + <$plonk_verifier>::read_proof($svk, $protocol, $instances, $transcript).unwrap(); + assert!(<$plonk_verifier>::verify($svk, $dk, $protocol, $instances, &proof).unwrap()) + }}; +} + +pub(crate) use {halo2_create_snark, halo2_native_verify, halo2_prepare}; diff --git a/src/protocol/halo2/test/circuit.rs b/src/system/halo2/test/circuit.rs similarity index 67% rename from src/protocol/halo2/test/circuit.rs rename to src/system/halo2/test/circuit.rs index a87cb997..87a005fb 100644 --- a/src/protocol/halo2/test/circuit.rs +++ b/src/system/halo2/test/circuit.rs @@ -1,3 +1,2 @@ pub mod maingate; -pub mod plookup; pub mod standard; diff --git a/src/system/halo2/test/circuit/maingate.rs b/src/system/halo2/test/circuit/maingate.rs new file mode 100644 index 00000000..82d63b5e --- /dev/null +++ b/src/system/halo2/test/circuit/maingate.rs @@ -0,0 +1,111 @@ +use crate::util::arithmetic::{CurveAffine, FieldExt}; +use halo2_proofs::{ + circuit::{floor_planner::V1, Layouter, Value}, + plonk::{Circuit, ConstraintSystem, Error}, +}; +use halo2_wrong_ecc::{ + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, RangeInstructions, + RegionCtx, + }, + BaseFieldEccChip, EccConfig, +}; +use rand::RngCore; + +#[derive(Clone)] +pub struct MainGateWithRangeConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, +} + +impl MainGateWithRangeConfig { + pub fn configure( + meta: &mut ConstraintSystem, + composition_bits: Vec, + overflow_bits: Vec, + ) -> Self { + let main_gate_config = MainGate::::configure(meta); + let range_config = + 
RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); + MainGateWithRangeConfig { + main_gate_config, + range_config, + } + } + + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip( + &self, + ) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } +} + +#[derive(Clone, Default)] +pub struct MainGateWithRange(Vec); + +impl MainGateWithRange { + pub fn new(inner: Vec) -> Self { + Self(inner) + } + + pub fn rand(mut rng: R) -> Self { + Self::new(vec![F::from(rng.next_u32() as u64)]) + } + + pub fn instances(&self) -> Vec> { + vec![self.0.clone()] + } +} + +impl Circuit for MainGateWithRange { + type Config = MainGateWithRangeConfig; + type FloorPlanner = V1; + + fn without_witnesses(&self) -> Self { + Self(vec![F::zero()]) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + MainGateWithRangeConfig::configure(meta, vec![8], vec![4, 7]) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + range_chip.load_table(&mut layouter)?; + + let a = layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; + range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; + let a = range_chip.assign(&mut ctx, Value::known(self.0[0]), 8, 68)?; + let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; + let cond = main_gate.assign_bit(&mut ctx, Value::known(F::one()))?; + main_gate.select(&mut ctx, &a, &b, &cond)?; + + Ok(a) + }, + )?; + + main_gate.expose_public(layouter, a, 0)?; + + Ok(()) + } +} diff --git a/src/protocol/halo2/test/circuit/standard.rs b/src/system/halo2/test/circuit/standard.rs similarity index 69% rename from src/protocol/halo2/test/circuit/standard.rs rename to src/system/halo2/test/circuit/standard.rs index 0773f360..90f30f2b 100644 --- a/src/protocol/halo2/test/circuit/standard.rs +++ b/src/system/halo2/test/circuit/standard.rs @@ -1,7 +1,7 @@ +use crate::util::arithmetic::FieldExt; use halo2_proofs::{ - arithmetic::FieldExt, circuit::{floor_planner::V1, Layouter, Value}, - plonk::{Advice, Any, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, poly::Rotation, }; use rand::RngCore; @@ -22,39 +22,29 @@ pub struct StandardPlonkConfig { impl StandardPlonkConfig { pub fn configure(meta: &mut ConstraintSystem) -> Self { - let a = meta.advice_column(); - let b = meta.advice_column(); - let c = meta.advice_column(); - - let q_a = meta.fixed_column(); - let q_b = meta.fixed_column(); - let q_c = meta.fixed_column(); - - let q_ab = meta.fixed_column(); - - let constant = meta.fixed_column(); + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); let instance = meta.instance_column(); - meta.enable_equality(a); - meta.enable_equality(b); - meta.enable_equality(c); - - meta.create_gate("", |meta| { - let [a, b, c, q_a, q_b, q_c, q_ab, constant, instance] = [ - a.into(), - b.into(), - c.into(), - q_a.into(), - q_b.into(), - q_c.into(), - q_ab.into(), - constant.into(), - 
instance.into(), - ] - .map(|column: Column| meta.query_any(column, Rotation::cur())); - - vec![q_a * a.clone() + q_b * b.clone() + q_c * c + q_ab * a * b + constant + instance] - }); + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); StandardPlonkConfig { a, diff --git a/src/system/halo2/test/kzg.rs b/src/system/halo2/test/kzg.rs new file mode 100644 index 00000000..0b071175 --- /dev/null +++ b/src/system/halo2/test/kzg.rs @@ -0,0 +1,120 @@ +use crate::{ + system::halo2::test::{read_or_create_srs, MainGateWithRange}, + util::arithmetic::{fe_to_limbs, CurveAffine, MultiMillerLoop}, +}; +use halo2_proofs::poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + +mod native; + +#[cfg(feature = "loader_evm")] +mod evm; + +#[cfg(feature = "loader_halo2")] +mod halo2; + +pub const TESTDATA_DIR: &str = "./src/system/halo2/test/kzg/testdata"; + +pub const LIMBS: usize = 4; +pub const BITS: usize = 68; + +pub fn setup(k: u32) -> ParamsKZG { + ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())) +} + +pub fn main_gate_with_range_with_mock_kzg_accumulator( +) -> MainGateWithRange { + let srs = read_or_create_srs(TESTDATA_DIR, 1, setup::); + let [g1, s_g1] = [srs.get_g()[0], srs.get_g()[1]].map(|point| point.coordinates().unwrap()); + MainGateWithRange::new( + [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] + .iter() + .cloned() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .collect(), + ) +} + +macro_rules! halo2_kzg_config { + ($zk:expr, $num_proof:expr) => { + $crate::system::halo2::Config::kzg() + .set_zk($zk) + .with_num_proof($num_proof) + }; + ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { + $crate::system::halo2::Config::kzg() + .set_zk($zk) + .with_num_proof($num_proof) + .with_accumulator_indices($accumulator_indices) + }; +} + +macro_rules! halo2_kzg_prepare { + ($k:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_curves::bn256::Bn256; + use $crate::system::halo2::test::{ + halo2_prepare, + kzg::{setup, TESTDATA_DIR}, + }; + + halo2_prepare!(TESTDATA_DIR, $k, setup::, $config, $create_circuit) + }}; +} + +macro_rules! halo2_kzg_create_snark { + ( + $prover:ty, + $verifier:ty, + $transcript_read:ty, + $transcript_write:ty, + $encoded_challenge:ty, + $params:expr, + $pk:expr, + $protocol:expr, + $circuits:expr + ) => {{ + use halo2_proofs::poly::kzg::{commitment::KZGCommitmentScheme, strategy::SingleStrategy}; + use $crate::system::halo2::test::halo2_create_snark; + + halo2_create_snark!( + KZGCommitmentScheme<_>, + $prover, + $verifier, + SingleStrategy<_>, + $transcript_read, + $transcript_write, + $encoded_challenge, + |proof, _| proof, + $params, + $pk, + $protocol, + $circuits + ) + }}; +} + +macro_rules! 
halo2_kzg_native_verify { + ( + $plonk_verifier:ty, + $params:expr, + $protocol:expr, + $instances:expr, + $transcript:expr + ) => {{ + use $crate::system::halo2::test::halo2_native_verify; + + halo2_native_verify!( + $plonk_verifier, + $params, + $protocol, + $instances, + $transcript, + &$params.get_g()[0].into(), + &($params.g2(), $params.s_g2()).into() + ) + }}; +} + +pub(crate) use { + halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, +}; diff --git a/src/system/halo2/test/kzg/evm.rs b/src/system/halo2/test/kzg/evm.rs new file mode 100644 index 00000000..4ce850c1 --- /dev/null +++ b/src/system/halo2/test/kzg/evm.rs @@ -0,0 +1,138 @@ +use crate::{ + loader::native::NativeLoader, + pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, + system::halo2::{ + test::{ + kzg::{ + self, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, + halo2_kzg_prepare, main_gate_with_range_with_mock_kzg_accumulator, BITS, LIMBS, + }, + StandardPlonk, + }, + transcript::evm::{ChallengeEvm, EvmTranscript}, + }, + verifier::Plonk, +}; +use halo2_curves::bn256::{Bn256, G1Affine}; +use halo2_proofs::poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}; +use paste::paste; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + +macro_rules! halo2_kzg_evm_verify { + ($plonk_verifier:ty, $params:expr, $protocol:expr, $instances:expr, $proof:expr) => {{ + use halo2_curves::bn256::{Bn256, Fq, Fr}; + use halo2_proofs::poly::commitment::ParamsProver; + use std::rc::Rc; + use $crate::{ + loader::evm::{encode_calldata, execute, EvmLoader}, + system::halo2::{ + test::kzg::{BITS, LIMBS}, + transcript::evm::EvmTranscript, + }, + util::Itertools, + verifier::PlonkVerifier, + }; + + let loader = EvmLoader::new::(); + let runtime_code = { + let svk = $params.get_g()[0].into(); + let dk = ($params.g2(), $params.s_g2()).into(); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let instances = transcript.load_instances( + $instances + .iter() + .map(|instances| instances.len()) + .collect_vec(), + ); + let proof = <$plonk_verifier>::read_proof(&svk, $protocol, &instances, &mut transcript) + .unwrap(); + <$plonk_verifier>::verify(&svk, &dk, $protocol, &instances, &proof).unwrap(); + + loader.runtime_code() + }; + + let (accept, total_cost, costs) = + execute(runtime_code, encode_calldata($instances, &$proof)); + + loader.print_gas_metering(costs); + println!("Total gas cost: {}", total_cost); + + assert!(accept); + }}; +} + +macro_rules! test { + (@ $(#[$attr:meta],)* $prefix:ident, $name:ident, $k:expr, $config:expr, $create_circuit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { + paste! 
{ + $(#[$attr])* + fn []() { + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + $k, + $config, + $create_circuit + ); + let snark = halo2_kzg_create_snark!( + $prover, + $verifier, + EvmTranscript, + EvmTranscript, + ChallengeEvm<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + halo2_kzg_native_verify!( + $plonk_verifier, + params, + &snark.protocol, + &snark.instances, + &mut EvmTranscript::<_, NativeLoader, _, _>::new(snark.proof.as_slice()) + ); + halo2_kzg_evm_verify!( + $plonk_verifier, + params, + &snark.protocol, + &snark.instances, + snark.proof + ); + } + } + }; + ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test], shplonk, $name, $k, $config, $create_circuit, ProverSHPLONK<_>, VerifierSHPLONK<_>, Plonk, LimbsEncoding>); + test!(@ #[test], plonk, $name, $k, $config, $create_circuit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); + }; + ($(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test] $(,#[$attr])*, plonk, $name, $k, $config, $create_circuit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); + }; +} + +test!( + zk_standard_plonk_rand, + 9, + halo2_kzg_config!(true, 1), + StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) +); +test!( + zk_main_gate_with_range_with_mock_kzg_accumulator, + 9, + halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + main_gate_with_range_with_mock_kzg_accumulator::() +); +test!( + #[cfg(feature = "loader_halo2")], + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark, + 22, + halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + kzg::halo2::Accumulation::two_snark() +); +test!( + #[cfg(feature = "loader_halo2")], + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark_with_accumulator, + 22, + halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + kzg::halo2::Accumulation::two_snark_with_accumulator() +); diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs new file mode 100644 index 00000000..1bd332a0 --- /dev/null +++ b/src/system/halo2/test/kzg/halo2.rs @@ -0,0 +1,372 @@ +use crate::{ + loader, + loader::{ + halo2::test::{Snark, SnarkWitness}, + native::NativeLoader, + }, + pcs::{ + kzg::{ + Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, + KzgSuccinctVerifyingKey, LimbsEncoding, + }, + AccumulationScheme, AccumulationSchemeProver, + }, + system::{ + self, + halo2::{ + test::{ + kzg::{ + halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, + halo2_kzg_prepare, BITS, LIMBS, + }, + MainGateWithRange, MainGateWithRangeConfig, StandardPlonk, + }, + transcript::halo2::ChallengeScalar, + }, + }, + util::{arithmetic::fe_to_limbs, Itertools}, + verifier::{self, PlonkVerifier}, +}; +use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_proofs::{ + circuit::{floor_planner::V1, Layouter, Value}, + plonk, + plonk::Circuit, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::ParamsKZG, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + }, + }, + transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, +}; +use halo2_wrong_ecc::{ + self, + integer::rns::Rns, + maingate::{MainGateInstructions, RangeInstructions, RegionCtx}, +}; +use paste::paste; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; +use std::{iter, rc::Rc}; + +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +type 
BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; +type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + +type Pcs = Kzg; +type Svk = KzgSuccinctVerifyingKey; +type As = KzgAs; +type AsPk = KzgAsProvingKey; +type AsVk = KzgAsVerifyingKey; +type Plonk = verifier::Plonk>; + +pub fn accumulate<'a>( + svk: &Svk, + loader: &Rc>, + snarks: &[SnarkWitness], + as_vk: &AsVk, + as_proof: Value<&'_ [u8]>, +) -> KzgAccumulator>> { + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec() + }; + + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let instances = assign_instances(&snark.instances); + let mut transcript = + PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = + Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + }) + .collect_vec(); + + let acccumulator = if accumulators.len() > 1 { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); + As::verify(as_vk, &accumulators, &proof).unwrap() + } else { + accumulators.pop().unwrap() + }; + + acccumulator +} + +pub struct Accumulation { + svk: Svk, + snarks: Vec>, + instances: Vec, + as_vk: AsVk, + as_proof: Value>, +} + +impl Accumulation { + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + + pub fn new( + params: &ParamsKZG, + snarks: impl IntoIterator>, + ) -> Self { + let svk = params.get_g()[0].into(); + let snarks = snarks.into_iter().collect_vec(); + + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }) + .collect_vec(); + + let as_pk = AsPk::new(Some((params.get_g()[0], params.get_g()[1]))); + let (accumulator, as_proof) = if accumulators.len() > 1 { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = As::create_proof( + &as_pk, + &accumulators, + &mut transcript, + ChaCha20Rng::from_seed(Default::default()), + ) + .unwrap(); + (accumulator, Value::known(transcript.finalize())) + } else { + (accumulators.pop().unwrap(), Value::unknown()) + }; + + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = [lhs.x, lhs.y, rhs.x, rhs.y] + .map(fe_to_limbs::<_, _, LIMBS, BITS>) + .concat(); + + Self { + svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_vk: as_pk.vk(), + as_proof, + } + } + + pub fn two_snark() -> Self { + let (params, snark1) = { + const K: u32 = 9; + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + K, + halo2_kzg_config!(true, 1), + StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) + ); + let snark = halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, + ChallengeScalar<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + (params, snark) + }; + let snark2 = { + const K: u32 = 9; + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + K, + 
halo2_kzg_config!(true, 1), + MainGateWithRange::rand(ChaCha20Rng::from_seed(Default::default())) + ); + halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, + ChallengeScalar<_>, + ¶ms, + &pk, + &protocol, + &circuits + ) + }; + Self::new(¶ms, [snark1, snark2]) + } + + pub fn two_snark_with_accumulator() -> Self { + let (params, pk, protocol, circuits) = { + const K: u32 = 22; + halo2_kzg_prepare!( + K, + halo2_kzg_config!(true, 2, Self::accumulator_indices()), + Self::two_snark() + ) + }; + let snark = halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, + ChallengeScalar<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + Self::new(¶ms, [snark]) + } + + pub fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + pub fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } +} + +impl Circuit for Accumulation { + type Config = MainGateWithRangeConfig; + type FloorPlanner = V1; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self + .snarks + .iter() + .map(SnarkWitness::without_witnesses) + .collect(), + instances: Vec::new(), + as_vk: self.as_vk, + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + MainGateWithRangeConfig::configure( + meta, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let (lhs, rhs) = layouter.assign_region( + || "", + |region| { + let ctx = RegionCtx::new(region, 0); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let KzgAccumulator { lhs, rhs } = accumulate( + &self.svk, + &loader, + &self.snarks, + &self.as_vk, + self.as_proof(), + ); + + loader.print_row_metering(); + println!("Total row cost: {}", loader.ctx().offset()); + + Ok((lhs.assigned(), rhs.assigned())) + }, + )?; + + for (limb, row) in iter::empty() + .chain(lhs.x().limbs()) + .chain(lhs.y().limbs()) + .chain(rhs.x().limbs()) + .chain(rhs.y().limbs()) + .zip(0..) + { + main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + } + + Ok(()) + } +} + +macro_rules! test { + (@ $(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + paste! 
{ + $(#[$attr])* + fn []() { + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + $k, + $config, + $create_circuit + ); + let snark = halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + Blake2bWrite<_, _, _>, + Blake2bRead<_, _, _>, + Challenge255<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + halo2_kzg_native_verify!( + Plonk, + params, + &snark.protocol, + &snark.instances, + &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) + ); + } + } + }; + ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test], $name, $k, $config, $create_circuit); + }; + ($(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test] $(,#[$attr])*, $name, $k, $config, $create_circuit); + }; +} + +test!( + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark, + 22, + halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), + Accumulation::two_snark() +); +test!( + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark_with_accumulator, + 22, + halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), + Accumulation::two_snark_with_accumulator() +); diff --git a/src/protocol/halo2/test/kzg/native.rs b/src/system/halo2/test/kzg/native.rs similarity index 50% rename from src/protocol/halo2/test/kzg/native.rs rename to src/system/halo2/test/kzg/native.rs index 4273e7d0..e52ceb38 100644 --- a/src/protocol/halo2/test/kzg/native.rs +++ b/src/system/halo2/test/kzg/native.rs @@ -1,61 +1,56 @@ use crate::{ - collect_slice, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, - halo2_kzg_prepare, - protocol::halo2::test::{ + pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, + system::halo2::test::{ kzg::{ - main_gate_with_plookup_with_mock_kzg_accumulator, - main_gate_with_range_with_mock_kzg_accumulator, LIMBS, + halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, + main_gate_with_range_with_mock_kzg_accumulator, BITS, LIMBS, }, StandardPlonk, }, - scheme::kzg::{PlonkAccumulationScheme, ShplonkAccumulationScheme}, + verifier::Plonk, }; -use halo2_curves::bn256::G1Affine; +use halo2_curves::bn256::{Bn256, G1Affine}; use halo2_proofs::{ - poly::kzg::{ - multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, - strategy::AccumulatorStrategy, - }, + poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, }; use paste::paste; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; macro_rules! test { - (@ $prefix:ident, $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $scheme:ty) => { + (@ $prefix:ident, $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { paste! 
{ #[test] - fn []() { + fn []() { let (params, pk, protocol, circuits) = halo2_kzg_prepare!( $k, $config, $create_cirucit ); let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, $prover, $verifier, - AccumulatorStrategy<_>, Blake2bWrite<_, _, _>, Blake2bRead<_, _, _>, - Challenge255<_> + Challenge255<_>, + ¶ms, + &pk, + &protocol, + &circuits ); halo2_kzg_native_verify!( + $plonk_verifier, params, &snark.protocol, - snark.statements, - $scheme, + &snark.instances, &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) ); } } }; ($name:ident, $k:expr, $config:expr, $create_cirucit:expr) => { - test!(@ shplonk, $name, $k, $config, $create_cirucit, ProverSHPLONK<_>, VerifierSHPLONK<_>, ShplonkAccumulationScheme); - test!(@ plonk, $name, $k, $config, $create_cirucit, ProverGWC<_>, VerifierGWC<_>, PlonkAccumulationScheme); + test!(@ shplonk, $name, $k, $config, $create_cirucit, ProverSHPLONK<_>, VerifierSHPLONK<_>, Plonk, LimbsEncoding>); + test!(@ plonk, $name, $k, $config, $create_cirucit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); } } @@ -63,7 +58,7 @@ test!( zk_standard_plonk_rand, 9, halo2_kzg_config!(true, 2), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) + StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) ); test!( zk_main_gate_with_range_with_mock_kzg_accumulator, @@ -71,21 +66,3 @@ test!( halo2_kzg_config!(true, 2, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), main_gate_with_range_with_mock_kzg_accumulator::() ); -test!( - standard_plonk_rand, - 9, - halo2_kzg_config!(false, 2), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) -); -test!( - main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 2, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -test!( - main_gate_with_plookup_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_plookup_with_mock_kzg_accumulator::(9) -); diff --git a/src/system/halo2/transcript.rs b/src/system/halo2/transcript.rs new file mode 100644 index 00000000..2200bbf4 --- /dev/null +++ b/src/system/halo2/transcript.rs @@ -0,0 +1,82 @@ +use crate::{ + loader::native::{self, NativeLoader}, + util::{ + arithmetic::CurveAffine, + transcript::{Transcript, TranscriptRead, TranscriptWrite}, + }, + Error, +}; +use halo2_proofs::transcript::{Blake2bRead, Blake2bWrite, Challenge255}; +use std::io::{Read, Write}; + +#[cfg(feature = "loader_evm")] +pub mod evm; + +#[cfg(feature = "loader_halo2")] +pub mod halo2; + +impl Transcript for Blake2bRead> { + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + *halo2_proofs::transcript::Transcript::squeeze_challenge_scalar::(self) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_point(self, *ec_point) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_scalar(self, *scalar) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} + +impl TranscriptRead + for Blake2bRead> +{ + fn read_scalar(&mut self) -> Result { + halo2_proofs::transcript::TranscriptRead::read_scalar(self) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn read_ec_point(&mut self) -> Result { + 
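+        // Delegate to halo2_proofs' TranscriptRead, mapping the io::Error kind and message into this crate's Error::Transcript.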
halo2_proofs::transcript::TranscriptRead::read_point(self) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} + +impl Transcript for Blake2bWrite> { + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + *halo2_proofs::transcript::Transcript::squeeze_challenge_scalar::(self) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_point(self, *ec_point) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_scalar(self, *scalar) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} + +impl TranscriptWrite for Blake2bWrite, C, Challenge255> { + fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { + halo2_proofs::transcript::TranscriptWrite::write_scalar(self, scalar) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error> { + halo2_proofs::transcript::TranscriptWrite::write_point(self, ec_point) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} diff --git a/src/system/halo2/transcript/evm.rs b/src/system/halo2/transcript/evm.rs new file mode 100644 index 00000000..77aca5cf --- /dev/null +++ b/src/system/halo2/transcript/evm.rs @@ -0,0 +1,400 @@ +use crate::{ + loader::{ + evm::{loader::Value, u256_to_fe, EcPoint, EvmLoader, MemoryChunk, Scalar}, + native::{self, NativeLoader}, + Loader, + }, + util::{ + arithmetic::{Coordinates, CurveAffine, PrimeField}, + hash::{Digest, Keccak256}, + transcript::{Transcript, TranscriptRead}, + Itertools, + }, + Error, +}; +use ethereum_types::U256; +use halo2_proofs::transcript::EncodedChallenge; +use std::{ + io::{self, Read, Write}, + iter, + marker::PhantomData, + rc::Rc, +}; +pub struct EvmTranscript, S, B> { + loader: L, + stream: S, + buf: B, + _marker: PhantomData, +} + +impl EvmTranscript, usize, MemoryChunk> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + pub fn new(loader: Rc) -> Self { + let ptr = loader.allocate(0x20); + assert_eq!(ptr, 0); + let mut buf = MemoryChunk::new(ptr); + buf.extend(0x20); + Self { + loader, + stream: 0, + buf, + _marker: PhantomData, + } + } + + pub fn load_instances(&mut self, num_instance: Vec) -> Vec> { + num_instance + .into_iter() + .map(|len| { + iter::repeat_with(|| { + let scalar = self.loader.calldataload_scalar(self.stream); + self.stream += 0x20; + scalar + }) + .take(len) + .collect_vec() + }) + .collect() + } +} + +impl Transcript> for EvmTranscript, usize, MemoryChunk> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn loader(&self) -> &Rc { + &self.loader + } + + fn squeeze_challenge(&mut self) -> Scalar { + let len = if self.buf.len() == 0x20 { + assert_eq!(self.loader.ptr(), self.buf.end()); + self.loader + .code_mut() + .push(1) + .push(self.buf.end()) + .mstore8(); + 0x21 + } else { + self.buf.len() + }; + let hash_ptr = self.loader.keccak256(self.buf.ptr(), len); + + let challenge_ptr = self.loader.allocate(0x20); + let dup_hash_ptr = self.loader.allocate(0x20); + self.loader + .code_mut() + .push(hash_ptr) + .mload() + .push(self.loader.scalar_modulus()) + .dup(1) + .r#mod() + .push(challenge_ptr) + .mstore() + .push(dup_hash_ptr) + .mstore(); + + self.buf.reset(dup_hash_ptr); + self.buf.extend(0x20); + + self.loader.scalar(Value::Memory(challenge_ptr)) + } + + fn 
common_ec_point(&mut self, ec_point: &EcPoint) -> Result<(), Error> { + if let Value::Memory(ptr) = ec_point.value() { + assert_eq!(self.buf.end(), ptr); + self.buf.extend(0x40); + } else { + unreachable!() + } + Ok(()) + } + + fn common_scalar(&mut self, scalar: &Scalar) -> Result<(), Error> { + match scalar.value() { + Value::Constant(_) if self.buf.ptr() == 0 => { + self.loader.copy_scalar(scalar, self.buf.ptr()); + } + Value::Memory(ptr) => { + assert_eq!(self.buf.end(), ptr); + self.buf.extend(0x20); + } + _ => unreachable!(), + } + Ok(()) + } +} + +impl TranscriptRead> for EvmTranscript, usize, MemoryChunk> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn read_scalar(&mut self) -> Result { + let scalar = self.loader.calldataload_scalar(self.stream); + self.stream += 0x20; + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result { + let ec_point = self.loader.calldataload_ec_point(self.stream); + self.stream += 0x40; + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl EvmTranscript> +where + C: CurveAffine, +{ + pub fn new(stream: S) -> Self { + Self { + loader: NativeLoader, + stream, + buf: Vec::new(), + _marker: PhantomData, + } + } +} + +impl Transcript for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + let data = self + .buf + .iter() + .cloned() + .chain(if self.buf.len() == 0x20 { + Some(1) + } else { + None + }) + .collect_vec(); + let hash: [u8; 32] = Keccak256::digest(data).into(); + self.buf = hash.to_vec(); + u256_to_fe(U256::from_big_endian(hash.as_slice())) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + let coordinates = + Option::>::from(ec_point.coordinates()).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Cannot write points at infinity to the transcript".to_string(), + ) + })?; + + [coordinates.x(), coordinates.y()].map(|coordinate| { + self.buf + .extend(coordinate.to_repr().as_ref().iter().rev().cloned()); + }); + + Ok(()) + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + self.buf.extend(scalar.to_repr().as_ref().iter().rev()); + + Ok(()) + } +} + +impl TranscriptRead for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, + S: Read, +{ + fn read_scalar(&mut self) -> Result { + let mut data = [0; 32]; + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + data.reverse(); + let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid scalar encoding in proof".to_string(), + ) + })?; + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result { + let [mut x, mut y] = [::Repr::default(); 2]; + for repr in [&mut x, &mut y] { + self.stream + .read_exact(repr.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + repr.as_mut().reverse(); + } + let x = Option::from(::from_repr(x)); + let y = Option::from(::from_repr(y)); + let ec_point = x + .zip(y) + .and_then(|(x, y)| Option::from(C::from_xy(x, y))) + .ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl EvmTranscript> +where + C: CurveAffine, + S: Write, +{ + pub fn stream_mut(&mut self) -> &mut S { + &mut self.stream + } + + 
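+    // Consume the transcript and hand back the underlying writer (i.e. the buffer holding the proof bytes written so far).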
pub fn finalize(self) -> S { + self.stream + } +} + +pub struct ChallengeEvm(C::Scalar) +where + C: CurveAffine, + C::Scalar: PrimeField; + +impl EncodedChallenge for ChallengeEvm +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + type Input = [u8; 32]; + + fn new(challenge_input: &[u8; 32]) -> Self { + ChallengeEvm(u256_to_fe(U256::from_big_endian(challenge_input))) + } + + fn get_scalar(&self) -> C::Scalar { + self.0 + } +} + +impl halo2_proofs::transcript::Transcript> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn squeeze_challenge(&mut self) -> ChallengeEvm { + ChallengeEvm(Transcript::squeeze_challenge(self)) + } + + fn common_point(&mut self, ec_point: C) -> io::Result<()> { + match Transcript::common_ec_point(self, &ec_point) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + match Transcript::common_scalar(self, &scalar) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } +} + +impl halo2_proofs::transcript::TranscriptRead> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn read_point(&mut self) -> io::Result { + match TranscriptRead::read_ec_point(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } + + fn read_scalar(&mut self) -> io::Result { + match TranscriptRead::read_scalar(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } +} + +impl halo2_proofs::transcript::TranscriptReadBuffer> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn init(reader: R) -> Self { + Self::new(reader) + } +} + +impl halo2_proofs::transcript::TranscriptWrite> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn write_point(&mut self, ec_point: C) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_point(self, ec_point)?; + let coords: Coordinates = Option::from(ec_point.coordinates()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Cannot write points at infinity to the transcript", + ) + })?; + let mut x = coords.x().to_repr(); + let mut y = coords.y().to_repr(); + x.as_mut().reverse(); + y.as_mut().reverse(); + self.stream_mut().write_all(x.as_ref())?; + self.stream_mut().write_all(y.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_scalar(self, scalar)?; + let mut data = scalar.to_repr(); + data.as_mut().reverse(); + self.stream_mut().write_all(data.as_ref()) + } +} + +impl halo2_proofs::transcript::TranscriptWriterBuffer> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn init(writer: W) -> Self { + Self::new(writer) + } + + fn finalize(self) -> W { + self.finalize() + } +} diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs new file mode 100644 index 00000000..43701bd7 --- /dev/null +++ b/src/system/halo2/transcript/halo2.rs @@ -0,0 +1,439 @@ +use crate::{ + loader::{ + halo2::{self, EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, + native::{self, NativeLoader}, + Loader, ScalarLoader, + }, + util::{ + arithmetic::{fe_from_big, fe_to_big, CurveAffine, FieldExt, PrimeField}, + hash::Poseidon, + 
transcript::{Transcript, TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use halo2_proofs::{ + circuit::{AssignedCell, Value}, + transcript::EncodedChallenge, +}; +use std::{ + io::{self, Read, Write}, + marker::PhantomData, + rc::Rc, +}; + +pub trait EncodeNative<'a, C: CurveAffine, N: FieldExt>: EccInstructions<'a, C> { + fn encode_native( + &self, + ctx: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result>, Error>; +} + +pub struct PoseidonTranscript< + C: CurveAffine, + L: Loader, + S, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +> { + loader: L, + stream: S, + buf: Poseidon>::LoadedScalar, T, RATE>, + _marker: PhantomData, +} + +impl< + 'a, + C: CurveAffine, + R: Read, + EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > PoseidonTranscript>, Value, T, RATE, R_F, R_P> +{ + pub fn new(loader: &Rc>, stream: Value) -> Self { + Self { + loader: loader.clone(), + stream, + buf: Poseidon::new(loader.clone(), R_F, R_P), + _marker: PhantomData, + } + } +} + +impl< + 'a, + C: CurveAffine, + R: Read, + EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > Transcript>> + for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +{ + fn loader(&self) -> &Rc> { + &self.loader + } + + fn squeeze_challenge(&mut self) -> Scalar<'a, C, EccChip> { + self.buf.squeeze() + } + + fn common_scalar(&mut self, scalar: &Scalar<'a, C, EccChip>) -> Result<(), Error> { + self.buf.update(&[scalar.clone()]); + Ok(()) + } + + fn common_ec_point(&mut self, ec_point: &EcPoint<'a, C, EccChip>) -> Result<(), Error> { + let encoded = self + .loader + .ecc_chip() + .encode_native(&mut self.loader.ctx_mut(), &ec_point.assigned()) + .map(|encoded| { + encoded + .into_iter() + .map(|encoded| self.loader.scalar(halo2::loader::Value::Assigned(encoded))) + .collect_vec() + }) + .map_err(|_| Error::Transcript(io::ErrorKind::Other, "".to_string()))?; + self.buf.update(&encoded); + Ok(()) + } +} + +impl< + 'a, + C: CurveAffine, + R: Read, + EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > TranscriptRead>> + for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +{ + fn read_scalar(&mut self) -> Result, Error> { + let scalar = self.stream.as_mut().and_then(|stream| { + let mut data = ::Repr::default(); + if stream.read_exact(data.as_mut()).is_err() { + return Value::unknown(); + } + Option::::from(C::Scalar::from_repr(data)) + .map(|scalar| Value::known(self.loader.scalar_chip().integer(scalar))) + .unwrap_or_else(Value::unknown) + }); + let scalar = self.loader.assign_scalar(scalar); + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result, Error> { + let ec_point = self.stream.as_mut().and_then(|stream| { + let mut compressed = C::Repr::default(); + if stream.read_exact(compressed.as_mut()).is_err() { + return Value::unknown(); + } + Option::::from(C::from_bytes(&compressed)) + .map(Value::known) + .unwrap_or_else(Value::unknown) + }); + let ec_point = self.loader.assign_ec_point(ec_point); + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +// + +impl + PoseidonTranscript +{ + pub fn new(stream: S) -> Self { + Self { + loader: NativeLoader, + stream, + buf: Poseidon::new(NativeLoader, R_F, R_P), + _marker: PhantomData, + } 
+ } +} + +impl + Transcript for PoseidonTranscript +{ + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + self.buf.squeeze() + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + self.buf.update(&[*scalar]); + Ok(()) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + let encoded: Vec<_> = Option::from(ec_point.coordinates().map(|coordinates| { + [coordinates.x(), coordinates.y()] + .into_iter() + .map(|fe| fe_from_big(fe_to_big(*fe))) + .collect_vec() + })) + .ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.buf.update(&encoded); + Ok(()) + } +} + +impl< + C: CurveAffine, + R: Read, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > TranscriptRead + for PoseidonTranscript +{ + fn read_scalar(&mut self) -> Result { + let mut data = ::Repr::default(); + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid scalar encoding in proof".to_string(), + ) + })?; + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result { + let mut data = C::Repr::default(); + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + let ec_point = Option::::from(C::from_bytes(&data)).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > PoseidonTranscript +{ + pub fn stream_mut(&mut self) -> &mut W { + &mut self.stream + } + + pub fn finalize(self) -> W { + self.stream + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > TranscriptWrite for PoseidonTranscript +{ + fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { + self.common_scalar(&scalar)?; + let data = scalar.to_repr(); + self.stream_mut().write_all(data.as_ref()).map_err(|err| { + Error::Transcript( + err.kind(), + "Failed to write scalar to transcript".to_string(), + ) + }) + } + + fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error> { + self.common_ec_point(&ec_point)?; + let data = ec_point.to_bytes(); + self.stream_mut().write_all(data.as_ref()).map_err(|err| { + Error::Transcript( + err.kind(), + "Failed to write elliptic curve to transcript".to_string(), + ) + }) + } +} + +pub struct ChallengeScalar(C::Scalar); + +impl EncodedChallenge for ChallengeScalar { + type Input = C::Scalar; + + fn new(challenge_input: &C::Scalar) -> Self { + ChallengeScalar(*challenge_input) + } + + fn get_scalar(&self) -> C::Scalar { + self.0 + } +} + +impl + halo2_proofs::transcript::Transcript> + for PoseidonTranscript +{ + fn squeeze_challenge(&mut self) -> ChallengeScalar { + ChallengeScalar::new(&Transcript::squeeze_challenge(self)) + } + + fn common_point(&mut self, ec_point: C) -> io::Result<()> { + match Transcript::common_ec_point(self, &ec_point) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } + + fn common_scalar(&mut self, scalar: 
C::Scalar) -> io::Result<()> { + match Transcript::common_scalar(self, &scalar) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } +} + +impl< + C: CurveAffine, + R: Read, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptRead> + for PoseidonTranscript +{ + fn read_point(&mut self) -> io::Result { + match TranscriptRead::read_ec_point(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } + + fn read_scalar(&mut self) -> io::Result { + match TranscriptRead::read_scalar(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } +} + +impl< + C: CurveAffine, + R: Read, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptReadBuffer> + for PoseidonTranscript +{ + fn init(reader: R) -> Self { + Self::new(reader) + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptWrite> + for PoseidonTranscript +{ + fn write_point(&mut self, ec_point: C) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_point( + self, ec_point, + )?; + let data = ec_point.to_bytes(); + self.stream_mut().write_all(data.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_scalar(self, scalar)?; + let data = scalar.to_repr(); + self.stream_mut().write_all(data.as_ref()) + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptWriterBuffer> + for PoseidonTranscript +{ + fn init(writer: W) -> Self { + Self::new(writer) + } + + fn finalize(self) -> W { + self.finalize() + } +} + +mod halo2_wrong { + use crate::system::halo2::transcript::halo2::EncodeNative; + use halo2_curves::CurveAffine; + use halo2_proofs::circuit::AssignedCell; + use halo2_wrong_ecc::BaseFieldEccChip; + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EncodeNative<'a, C, C::Scalar> + for BaseFieldEccChip + { + fn encode_native( + &self, + _: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result>, crate::Error> { + Ok(vec![ + ec_point.x().native().clone(), + ec_point.y().native().clone(), + ]) + } + } +} diff --git a/src/util.rs b/src/util.rs index eb38e56a..3d5d0d79 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,34 +1,7 @@ -mod arithmetic; -mod expression; -mod transcript; +pub mod arithmetic; +pub mod hash; +pub mod msm; +pub mod protocol; +pub mod transcript; -pub use arithmetic::{ - batch_invert, batch_invert_and_mul, fe_from_limbs, fe_to_limbs, BatchInvert, Curve, Domain, - Field, FieldOps, Fraction, Group, GroupEncoding, GroupOps, PrimeCurveAffine, PrimeField, - Rotation, UncompressedEncoding, -}; -pub use expression::{CommonPolynomial, CommonPolynomialEvaluation, Expression, Query}; -pub use transcript::{Transcript, TranscriptRead}; - -pub use itertools::{EitherOrBoth, Itertools}; - -#[macro_export] -macro_rules! 
collect_slice { - ($vec:ident) => { - use $crate::util::Itertools; - - let $vec = $vec.iter().map(|vec| vec.as_slice()).collect_vec(); - }; - ($vec:ident, 2) => { - use $crate::util::Itertools; - - let $vec = $vec - .iter() - .map(|vec| { - collect_slice!(vec); - vec - }) - .collect_vec(); - let $vec = $vec.iter().map(|vec| vec.as_slice()).collect_vec(); - }; -} +pub(crate) use itertools::Itertools; diff --git a/src/util/arithmetic.rs b/src/util/arithmetic.rs index 0ba1929b..b9a5c7c6 100644 --- a/src/util/arithmetic.rs +++ b/src/util/arithmetic.rs @@ -8,43 +8,34 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -pub use ff::{BatchInvert, Field, PrimeField}; -pub use group::{prime::PrimeCurveAffine, Curve, Group, GroupEncoding}; +pub use halo2_curves::{ + group::{ + ff::{BatchInvert, Field, PrimeField}, + prime::PrimeCurveAffine, + Curve, Group, GroupEncoding, + }, + pairing::MillerLoopResult, + Coordinates, CurveAffine, CurveExt, FieldExt, +}; + +pub trait MultiMillerLoop: halo2_curves::pairing::MultiMillerLoop + Debug {} + +impl MultiMillerLoop for M {} -pub trait GroupOps: +pub trait FieldOps: Sized + + Neg + Add + Sub - + Neg + + Mul + for<'a> Add<&'a Self, Output = Self> + for<'a> Sub<&'a Self, Output = Self> + + for<'a> Mul<&'a Self, Output = Self> + AddAssign + SubAssign + + MulAssign + for<'a> AddAssign<&'a Self> + for<'a> SubAssign<&'a Self> -{ -} - -impl GroupOps for T where - T: Sized - + Add - + Sub - + Neg - + for<'a> Add<&'a Self, Output = Self> - + for<'a> Sub<&'a Self, Output = Self> - + AddAssign - + SubAssign - + for<'a> AddAssign<&'a Self> - + for<'a> SubAssign<&'a Self> -{ -} - -pub trait FieldOps: - Sized - + GroupOps - + Mul - + for<'a> Mul<&'a Self, Output = Self> - + MulAssign + for<'a> MulAssign<&'a Self> { fn invert(&self) -> Option; @@ -78,12 +69,13 @@ pub fn batch_invert(values: &mut [F]) { batch_invert_and_mul(values, &F::one()) } -pub trait UncompressedEncoding: Sized { - type Uncompressed: AsRef<[u8]> + AsMut<[u8]>; +pub fn root_of_unity(k: usize) -> F { + assert!(k <= F::S as usize); - fn to_uncompressed(&self) -> Self::Uncompressed; - - fn from_uncompressed(uncompressed: Self::Uncompressed) -> Option; + iter::successors(Some(F::root_of_unity()), |acc| Some(acc.square())) + .take(F::S as usize - k + 1) + .last() + .unwrap() } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -119,15 +111,9 @@ pub struct Domain { } impl Domain { - pub fn new(k: usize) -> Self { - assert!(k <= F::S as usize); - + pub fn new(k: usize, gen: F) -> Self { let n = 1 << k; let n_inv = F::from(n as u64).invert().unwrap(); - let gen = iter::successors(Some(F::root_of_unity()), |acc| Some(acc.square())) - .take(F::S as usize - k + 1) - .last() - .unwrap(); let gen_inv = gen.invert().unwrap(); Self { @@ -149,30 +135,33 @@ impl Domain { } #[derive(Clone, Debug)] -pub struct Fraction { - numer: Option, - denom: F, +pub struct Fraction { + numer: Option, + denom: T, + eval: Option, inv: bool, } -impl Fraction { - pub fn new(numer: F, denom: F) -> Self { +impl Fraction { + pub fn new(numer: T, denom: T) -> Self { Self { numer: Some(numer), denom, + eval: None, inv: false, } } - pub fn one_over(denom: F) -> Self { + pub fn one_over(denom: T) -> Self { Self { numer: None, denom, + eval: None, inv: false, } } - pub fn denom(&self) -> Option<&F> { + pub fn denom(&self) -> Option<&T> { if !self.inv { Some(&self.denom) } else { @@ -180,7 +169,7 @@ impl Fraction { } } - pub fn denom_mut(&mut self) -> Option<&mut F> { + pub fn denom_mut(&mut self) -> 
Option<&mut T> { if !self.inv { self.inv = true; Some(&mut self.denom) @@ -190,21 +179,35 @@ impl Fraction { } } -impl Fraction { - pub fn evaluate(&self) -> F { - let denom = if self.inv { - self.denom.clone() - } else { - self.denom.invert().unwrap() - }; - self.numer - .clone() - .map(|numer| numer * &denom) - .unwrap_or(denom) +impl Fraction { + pub fn evaluate(&mut self) { + assert!(self.inv); + assert!(self.eval.is_none()); + + self.eval = Some( + self.numer + .as_ref() + .map(|numer| numer.clone() * &self.denom) + .unwrap_or_else(|| self.denom.clone()), + ); } + + pub fn evaluated(&self) -> &T { + assert!(self.inv); + + self.eval.as_ref().unwrap() + } +} + +pub fn ilog2(value: usize) -> usize { + (usize::BITS - value.leading_zeros() - 1) as usize } -pub fn big_to_fe(big: BigUint) -> F { +pub fn modulus() -> BigUint { + fe_to_big(-F::one()) + 1usize +} + +pub fn fe_from_big(big: BigUint) -> F { let bytes = big.to_bytes_le(); let mut repr = F::Repr::default(); assert!(bytes.len() <= repr.as_ref().len()); @@ -212,10 +215,18 @@ pub fn big_to_fe(big: BigUint) -> F { F::from_repr(repr).unwrap() } +pub fn fe_to_big(fe: F) -> BigUint { + BigUint::from_bytes_le(fe.to_repr().as_ref()) +} + +pub fn fe_to_fe(fe: F1) -> F2 { + fe_from_big(fe_to_big(fe) % modulus::()) +} + pub fn fe_from_limbs( limbs: [F1; LIMBS], ) -> F2 { - big_to_fe( + fe_from_big( limbs .iter() .map(|limb| BigUint::from_bytes_le(limb.to_repr().as_ref())) @@ -234,8 +245,15 @@ pub fn fe_to_limbs> shift) & &mask)) + .map(move |shift| fe_from_big((&big >> shift) & &mask)) .collect_vec() .try_into() .unwrap() } + +pub fn powers(scalar: F) -> impl Iterator +where + for<'a> F: Mul<&'a F, Output = F> + One + Clone, +{ + iter::successors(Some(F::one()), move |power| Some(scalar.clone() * power)) +} diff --git a/src/util/hash.rs b/src/util/hash.rs new file mode 100644 index 00000000..17ede0b3 --- /dev/null +++ b/src/util/hash.rs @@ -0,0 +1,6 @@ +mod poseidon; + +pub use crate::util::hash::poseidon::Poseidon; + +#[cfg(feature = "loader_evm")] +pub use sha3::{Digest, Keccak256}; diff --git a/src/util/hash/poseidon.rs b/src/util/hash/poseidon.rs new file mode 100644 index 00000000..878b69ce --- /dev/null +++ b/src/util/hash/poseidon.rs @@ -0,0 +1,178 @@ +use crate::{ + loader::{LoadedScalar, ScalarLoader}, + util::{arithmetic::FieldExt, Itertools}, +}; +use poseidon::{self, SparseMDSMatrix, Spec}; +use std::{iter, marker::PhantomData, mem}; + +struct State { + inner: [L; T], + _marker: PhantomData, +} + +impl, const T: usize, const RATE: usize> State { + fn new(inner: [L; T]) -> Self { + Self { + inner, + _marker: PhantomData, + } + } + + fn loader(&self) -> &L::Loader { + self.inner[0].loader() + } + + fn power5_with_constant(value: &L, constant: &F) -> L { + value + .loader() + .sum_products_with_const(&[(value, &value.square().square())], *constant) + } + + fn sbox_full(&mut self, constants: &[F; T]) { + for (state, constant) in self.inner.iter_mut().zip(constants.iter()) { + *state = Self::power5_with_constant(state, constant); + } + } + + fn sbox_part(&mut self, constant: &F) { + self.inner[0] = Self::power5_with_constant(&self.inner[0], constant); + } + + fn absorb_with_pre_constants(&mut self, inputs: &[L], pre_constants: &[F; T]) { + assert!(inputs.len() < T); + + self.inner[0] = self + .loader() + .sum_with_const(&[&self.inner[0]], pre_constants[0]); + self.inner + .iter_mut() + .zip(pre_constants.iter()) + .skip(1) + .zip(inputs) + .for_each(|((state, constant), input)| { + *state = state.loader().sum_with_const(&[state, input], 
*constant); + }); + self.inner + .iter_mut() + .zip(pre_constants.iter()) + .skip(1 + inputs.len()) + .enumerate() + .for_each(|(idx, (state, constant))| { + *state = state.loader().sum_with_const( + &[state], + if idx == 0 { + F::one() + constant + } else { + *constant + }, + ); + }); + } + + fn apply_mds(&mut self, mds: &[[F; T]; T]) { + self.inner = mds + .iter() + .map(|row| { + self.loader() + .sum_with_coeff(&row.iter().cloned().zip(self.inner.iter()).collect_vec()) + }) + .collect_vec() + .try_into() + .unwrap(); + } + + fn apply_sparse_mds(&mut self, mds: &SparseMDSMatrix) { + self.inner = iter::once( + self.loader().sum_with_coeff( + &mds.row() + .iter() + .cloned() + .zip(self.inner.iter()) + .collect_vec(), + ), + ) + .chain( + mds.col_hat() + .iter() + .zip(self.inner.iter().skip(1)) + .map(|(coeff, state)| { + self.loader() + .sum_with_coeff(&[(*coeff, &self.inner[0]), (F::one(), state)]) + }), + ) + .collect_vec() + .try_into() + .unwrap(); + } +} + +pub struct Poseidon { + spec: Spec, + state: State, + buf: Vec, +} + +impl, const T: usize, const RATE: usize> Poseidon { + pub fn new(loader: L::Loader, r_f: usize, r_p: usize) -> Self { + Self { + spec: Spec::new(r_f, r_p), + state: State::new( + poseidon::State::default() + .words() + .map(|state| loader.load_const(&state)), + ), + buf: Vec::new(), + } + } + + pub fn update(&mut self, elements: &[L]) { + self.buf.extend_from_slice(elements); + } + + pub fn squeeze(&mut self) -> L { + let buf = mem::take(&mut self.buf); + let exact = buf.len() % RATE == 0; + + for chunk in buf.chunks(RATE) { + self.permutation(chunk); + } + if exact { + self.permutation(&[]); + } + + self.state.inner[1].clone() + } + + fn permutation(&mut self, inputs: &[L]) { + let r_f = self.spec.r_f() / 2; + let mds = self.spec.mds_matrices().mds().rows(); + let pre_sparse_mds = self.spec.mds_matrices().pre_sparse_mds().rows(); + let sparse_matrices = self.spec.mds_matrices().sparse_matrices(); + + // First half of the full rounds + let constants = self.spec.constants().start(); + self.state.absorb_with_pre_constants(inputs, &constants[0]); + for constants in constants.iter().skip(1).take(r_f - 1) { + self.state.sbox_full(constants); + self.state.apply_mds(&mds); + } + self.state.sbox_full(constants.last().unwrap()); + self.state.apply_mds(&pre_sparse_mds); + + // Partial rounds + let constants = self.spec.constants().partial(); + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.state.sbox_part(constant); + self.state.apply_sparse_mds(sparse_mds); + } + + // Second half of the full rounds + let constants = self.spec.constants().end(); + for constants in constants.iter() { + self.state.sbox_full(constants); + self.state.apply_mds(&mds); + } + self.state.sbox_full(&[F::zero(); T]); + self.state.apply_mds(&mds); + } +} diff --git a/src/util/msm.rs b/src/util/msm.rs new file mode 100644 index 00000000..a7a3d45d --- /dev/null +++ b/src/util/msm.rs @@ -0,0 +1,203 @@ +use crate::{ + loader::{LoadedEcPoint, Loader}, + util::arithmetic::CurveAffine, +}; +use std::{ + default::Default, + iter::{self, Sum}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +#[derive(Clone, Debug)] +pub struct Msm> { + constant: Option, + scalars: Vec, + bases: Vec, +} + +impl Default for Msm +where + C: CurveAffine, + L: Loader, +{ + fn default() -> Self { + Self { + constant: None, + scalars: Vec::new(), + bases: Vec::new(), + } + } +} + +impl Msm +where + C: CurveAffine, + L: Loader, +{ + pub fn constant(constant: L::LoadedScalar) 
-> Self { + Msm { + constant: Some(constant), + ..Default::default() + } + } + + pub fn base(base: L::LoadedEcPoint) -> Self { + let one = base.loader().load_one(); + Msm { + scalars: vec![one], + bases: vec![base], + ..Default::default() + } + } + + pub(crate) fn size(&self) -> usize { + self.bases.len() + } + + pub(crate) fn split(mut self) -> (Self, Option) { + let constant = self.constant.take(); + (self, constant) + } + + pub(crate) fn try_into_constant(self) -> Option { + self.bases.is_empty().then(|| self.constant.unwrap()) + } + + pub fn evaluate(self, gen: Option) -> L::LoadedEcPoint { + let gen = gen.map(|gen| { + self.bases + .first() + .unwrap() + .loader() + .ec_point_load_const(&gen) + }); + L::LoadedEcPoint::multi_scalar_multiplication( + iter::empty() + .chain(self.constant.map(|constant| (constant, gen.unwrap()))) + .chain(self.scalars.into_iter().zip(self.bases.into_iter())), + ) + } + + pub fn scale(&mut self, factor: &L::LoadedScalar) { + if let Some(constant) = self.constant.as_mut() { + *constant *= factor; + } + for scalar in self.scalars.iter_mut() { + *scalar *= factor + } + } + + pub fn push(&mut self, scalar: L::LoadedScalar, base: L::LoadedEcPoint) { + if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { + self.scalars[pos] += scalar; + } else { + self.scalars.push(scalar); + self.bases.push(base); + } + } + + pub fn extend(&mut self, mut other: Self) { + match (self.constant.as_mut(), other.constant.as_ref()) { + (Some(lhs), Some(rhs)) => *lhs += rhs, + (None, Some(_)) => self.constant = other.constant.take(), + _ => {} + }; + for (scalar, base) in other.scalars.into_iter().zip(other.bases) { + self.push(scalar, base); + } + } +} + +impl Add> for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + + fn add(mut self, rhs: Msm) -> Self::Output { + self.extend(rhs); + self + } +} + +impl AddAssign> for Msm +where + C: CurveAffine, + L: Loader, +{ + fn add_assign(&mut self, rhs: Msm) { + self.extend(rhs); + } +} + +impl Sub> for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + + fn sub(mut self, rhs: Msm) -> Self::Output { + self.extend(-rhs); + self + } +} + +impl SubAssign> for Msm +where + C: CurveAffine, + L: Loader, +{ + fn sub_assign(&mut self, rhs: Msm) { + self.extend(-rhs); + } +} + +impl Mul<&L::LoadedScalar> for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + + fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { + self.scale(rhs); + self + } +} + +impl MulAssign<&L::LoadedScalar> for Msm +where + C: CurveAffine, + L: Loader, +{ + fn mul_assign(&mut self, rhs: &L::LoadedScalar) { + self.scale(rhs); + } +} + +impl Neg for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + fn neg(mut self) -> Msm { + self.constant = self.constant.map(|constant| -constant); + for scalar in self.scalars.iter_mut() { + *scalar = -scalar.clone(); + } + self + } +} + +impl Sum for Msm +where + C: CurveAffine, + L: Loader, +{ + fn sum>(iter: I) -> Self { + iter.reduce(|acc, item| acc + item).unwrap_or_default() + } +} diff --git a/src/util/expression.rs b/src/util/protocol.rs similarity index 65% rename from src/util/expression.rs rename to src/util/protocol.rs index 41b04ded..bae363ff 100644 --- a/src/util/expression.rs +++ b/src/util/protocol.rs @@ -1,7 +1,12 @@ use crate::{ loader::{LoadedScalar, Loader}, - util::{Curve, Domain, Field, Fraction, Itertools, Rotation}, + util::{ + arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, + Itertools, + }, }; +use 
num_integer::Integer; +use num_traits::One; use std::{ cmp::max, collections::{BTreeMap, BTreeSet}, @@ -19,10 +24,11 @@ pub enum CommonPolynomial { #[derive(Clone, Debug)] pub struct CommonPolynomialEvaluation where - C: Curve, + C: CurveAffine, L: Loader, { zn: L::LoadedScalar, + zn_minus_one: L::LoadedScalar, zn_minus_one_inv: Fraction, identity: L::LoadedScalar, lagrange: BTreeMap>, @@ -30,28 +36,29 @@ where impl CommonPolynomialEvaluation where - C: Curve, + C: CurveAffine, L: Loader, { pub fn new( domain: &Domain, - loader: &L, langranges: impl IntoIterator, z: &L::LoadedScalar, ) -> Self { + let loader = z.loader(); + let zn = z.pow_const(domain.n as u64); let langranges = langranges.into_iter().sorted().dedup().collect_vec(); let one = loader.load_one(); let zn_minus_one = zn.clone() - one; + let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); + let n_inv = loader.load_const(&domain.n_inv); let numer = zn_minus_one.clone() * n_inv; - let omegas = langranges .iter() .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) .collect_vec(); - let lagrange_evals = omegas .iter() .map(|omega| Fraction::new(numer.clone() * omega, z.clone() - omega)) @@ -59,24 +66,29 @@ where Self { zn, - zn_minus_one_inv: Fraction::one_over(zn_minus_one), + zn_minus_one, + zn_minus_one_inv, identity: z.clone(), lagrange: langranges.into_iter().zip(lagrange_evals).collect(), } } - pub fn zn(&self) -> L::LoadedScalar { - self.zn.clone() + pub fn zn(&self) -> &L::LoadedScalar { + &self.zn + } + + pub fn zn_minus_one(&self) -> &L::LoadedScalar { + &self.zn_minus_one } - pub fn zn_minus_one_inv(&self) -> L::LoadedScalar { - self.zn_minus_one_inv.evaluate() + pub fn zn_minus_one_inv(&self) -> &L::LoadedScalar { + self.zn_minus_one_inv.evaluated() } - pub fn get(&self, poly: CommonPolynomial) -> L::LoadedScalar { + pub fn get(&self, poly: CommonPolynomial) -> &L::LoadedScalar { match poly { - CommonPolynomial::Identity => self.identity.clone(), - CommonPolynomial::Lagrange(i) => self.lagrange.get(&i).unwrap().evaluate(), + CommonPolynomial::Identity => &self.identity, + CommonPolynomial::Lagrange(i) => self.lagrange.get(&i).unwrap().evaluated(), } } @@ -87,6 +99,26 @@ where .chain(iter::once(self.zn_minus_one_inv.denom_mut())) .flatten() } + + pub fn evaluate(&mut self) { + self.lagrange + .iter_mut() + .map(|(_, value)| value) + .chain(iter::once(&mut self.zn_minus_one_inv)) + .for_each(Fraction::evaluate) + } +} + +#[derive(Clone, Debug)] +pub struct QuotientPolynomial { + pub chunk_degree: usize, + pub numerator: Expression, +} + +impl QuotientPolynomial { + pub fn num_chunk(&self) -> usize { + Integer::div_ceil(&(self.numerator.degree() - 1), &self.chunk_degree) + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -114,10 +146,11 @@ pub enum Expression { Sum(Box>, Box>), Product(Box>, Box>), Scaled(Box>, F), + DistributePowers(Vec>, Box>), } impl Expression { - pub fn evaluate( + pub fn evaluate( &self, constant: &impl Fn(F) -> T, common_poly: &impl Fn(CommonPolynomial) -> T, @@ -128,83 +161,53 @@ impl Expression { product: &impl Fn(T, T) -> T, scaled: &impl Fn(T, F) -> T, ) -> T { + let evaluate = |expr: &Expression| { + expr.evaluate( + constant, + common_poly, + poly, + challenge, + negated, + sum, + product, + scaled, + ) + }; match self { Expression::Constant(scalar) => constant(scalar.clone()), Expression::CommonPolynomial(poly) => common_poly(*poly), Expression::Polynomial(query) => poly(*query), Expression::Challenge(index) => 
challenge(*index), Expression::Negated(a) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); negated(a) } Expression::Sum(a, b) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); - let b = b.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); + let b = evaluate(b); sum(a, b) } Expression::Product(a, b) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); - let b = b.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); + let b = evaluate(b); product(a, b) } Expression::Scaled(a, scalar) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); scaled(a, scalar.clone()) } + Expression::DistributePowers(exprs, scalar) => { + assert!(!exprs.is_empty()); + if exprs.len() == 1 { + return evaluate(exprs.first().unwrap()); + } + let mut exprs = exprs.iter(); + let first = evaluate(exprs.next().unwrap()); + let scalar = evaluate(scalar); + exprs.fold(first, |acc, expr| { + sum(product(acc, scalar.clone()), evaluate(expr)) + }) + } } } @@ -218,6 +221,12 @@ impl Expression { Expression::Sum(a, b) => max(a.degree(), b.degree()), Expression::Product(a, b) => a.degree() + b.degree(), Expression::Scaled(a, _) => a.degree(), + Expression::DistributePowers(a, b) => a + .iter() + .chain(Some(b.as_ref())) + .map(Self::degree) + .max() + .unwrap_or_default(), } } @@ -237,6 +246,20 @@ impl Expression { ) .unwrap_or_default() } + + pub fn used_query(&self) -> BTreeSet { + self.evaluate( + &|_| None, + &|_| None, + &|query| Some(BTreeSet::from_iter([query])), + &|_| None, + &|a| a, + &merge_left_right, + &merge_left_right, + &|a, _| a, + ) + .unwrap_or_default() + } } impl From for Expression { @@ -306,6 +329,12 @@ impl Sum for Expression { } } +impl One for Expression { + fn one() -> Self { + Expression::Constant(F::one()) + } +} + fn merge_left_right(a: Option>, b: Option>) -> Option> { match (a, b) { (Some(a), None) | (None, Some(a)) => Some(a), @@ -316,3 +345,21 @@ fn merge_left_right(a: Option>, b: Option>) -> O _ => None, } } + +#[derive(Clone, Debug)] +pub enum LinearizationStrategy { + /// Older linearization strategy of GWC19, which has linearization + /// polynomial that doesn't evaluate to 0, and requires prover to send extra + /// evaluation of it to verifier. + WithoutConstant, + /// Current linearization strategy of GWC19, which has linearization + /// polynomial that evaluate to 0 by subtracting product of vanishing and + /// quotient polynomials. 
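+    /// (On the verifier side, PlonkProof::commitments folds -(z^n - 1) * quotient into the linearization MSM for this strategy.)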
+ MinusVanishingTimesQuotient, +} + +#[derive(Clone, Debug, Default)] +pub struct InstanceCommittingKey { + pub bases: Vec, + pub constant: Option, +} diff --git a/src/util/transcript.rs b/src/util/transcript.rs index a42d5e70..3337324d 100644 --- a/src/util/transcript.rs +++ b/src/util/transcript.rs @@ -1,13 +1,15 @@ use crate::{ - loader::Loader, - {util::Curve, Error}, + loader::{native::NativeLoader, Loader}, + {util::arithmetic::CurveAffine, Error}, }; pub trait Transcript where - C: Curve, + C: CurveAffine, L: Loader, { + fn loader(&self) -> &L; + fn squeeze_challenge(&mut self) -> L::LoadedScalar; fn squeeze_n_challenges(&mut self, n: usize) -> Vec { @@ -21,7 +23,7 @@ where pub trait TranscriptRead: Transcript where - C: Curve, + C: CurveAffine, L: Loader, { fn read_scalar(&mut self) -> Result; @@ -36,3 +38,9 @@ where (0..n).map(|_| self.read_ec_point()).collect() } } + +pub trait TranscriptWrite: Transcript { + fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error>; + + fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error>; +} diff --git a/src/verifier.rs b/src/verifier.rs new file mode 100644 index 00000000..07529603 --- /dev/null +++ b/src/verifier.rs @@ -0,0 +1,51 @@ +use crate::{ + loader::Loader, + pcs::{Decider, MultiOpenScheme}, + util::{arithmetic::CurveAffine, transcript::TranscriptRead}, + Error, Protocol, +}; +use std::fmt::Debug; + +mod plonk; + +pub use plonk::{Plonk, PlonkProof}; + +pub trait PlonkVerifier +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, +{ + type Proof: Clone + Debug; + + fn read_proof( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead; + + fn succinct_verify( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + ) -> Result, Error>; + + fn verify( + svk: &MOS::SuccinctVerifyingKey, + dk: &MOS::DecidingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + ) -> Result + where + MOS: Decider, + { + let accumulators = Self::succinct_verify(svk, protocol, instances, proof)?; + let output = MOS::decide_all(dk, accumulators); + Ok(output) + } +} diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs new file mode 100644 index 00000000..9e08e585 --- /dev/null +++ b/src/verifier/plonk.rs @@ -0,0 +1,464 @@ +use crate::{ + cost::{Cost, CostEstimation}, + loader::{native::NativeLoader, LoadedScalar, Loader}, + pcs::{self, AccumulatorEncoding, MultiOpenScheme}, + util::{ + arithmetic::{CurveAffine, Field, Rotation}, + msm::Msm, + protocol::{ + CommonPolynomial::Lagrange, CommonPolynomialEvaluation, LinearizationStrategy, Query, + }, + transcript::TranscriptRead, + Itertools, + }, + verifier::PlonkVerifier, + Error, Protocol, +}; +use std::{collections::HashMap, iter, marker::PhantomData}; + +pub struct Plonk(PhantomData<(MOS, AE)>); + +impl PlonkVerifier for Plonk +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, + AE: AccumulatorEncoding, +{ + type Proof = PlonkProof; + + fn read_proof( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + PlonkProof::read::(svk, protocol, instances, transcript) + } + + fn succinct_verify( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + ) -> Result, Error> { + let common_poly_eval = { + let mut common_poly_eval = CommonPolynomialEvaluation::new( + &protocol.domain, + 
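+                    // Lagrange indices actually referenced by the constraint system and by the instance evaluations.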
langranges(protocol, instances), + &proof.z, + ); + + L::LoadedScalar::batch_invert(common_poly_eval.denoms()); + common_poly_eval.evaluate(); + + common_poly_eval + }; + + let mut evaluations = proof.evaluations(protocol, instances, &common_poly_eval)?; + let commitments = proof.commitments(protocol, &common_poly_eval, &mut evaluations)?; + let queries = proof.queries(protocol, evaluations); + + let accumulator = MOS::succinct_verify(svk, &commitments, &proof.z, &queries, &proof.pcs)?; + + let accumulators = iter::empty() + .chain(Some(accumulator)) + .chain(proof.old_accumulators.iter().cloned()) + .collect(); + + Ok(accumulators) + } +} + +#[derive(Clone, Debug)] +pub struct PlonkProof +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, +{ + pub committed_instances: Option>, + pub witnesses: Vec, + pub challenges: Vec, + pub quotients: Vec, + pub z: L::LoadedScalar, + pub evaluations: Vec, + pub pcs: MOS::Proof, + pub old_accumulators: Vec, +} + +impl PlonkProof +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, +{ + fn read( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + AE: AccumulatorEncoding, + { + let loader = transcript.loader(); + if let Some(transcript_initial_state) = &protocol.transcript_initial_state { + transcript.common_scalar(&loader.load_const(transcript_initial_state))?; + } + + if protocol.num_instance + != instances + .iter() + .map(|instances| instances.len()) + .collect_vec() + { + return Err(Error::InvalidInstances); + } + + let committed_instances = if let Some(ick) = &protocol.instance_committing_key { + let loader = transcript.loader(); + let bases = ick + .bases + .iter() + .map(|value| loader.ec_point_load_const(value)) + .collect_vec(); + let constant = ick + .constant + .as_ref() + .map(|value| loader.ec_point_load_const(value)); + + let committed_instances = instances + .iter() + .map(|instances| { + instances + .iter() + .zip(bases.iter()) + .map(|(scalar, base)| Msm::::base(base.clone()) * scalar) + .chain(constant.clone().map(|constant| Msm::base(constant))) + .sum::>() + .evaluate(None) + }) + .collect_vec(); + for committed_instance in committed_instances.iter() { + transcript.common_ec_point(committed_instance)?; + } + + Some(committed_instances) + } else { + for instances in instances.iter() { + for instance in instances.iter() { + transcript.common_scalar(instance)?; + } + } + + None + }; + + let (witnesses, challenges) = { + let (witnesses, challenges) = protocol + .num_witness + .iter() + .zip(protocol.num_challenge.iter()) + .map(|(&n, &m)| { + Ok(( + transcript.read_n_ec_points(n)?, + transcript.squeeze_n_challenges(m), + )) + }) + .collect::, Error>>()? 
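+                // Unzip the per-phase (witness commitments, challenges) pairs, then flatten both below.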
+ .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + ( + witnesses.into_iter().flatten().collect_vec(), + challenges.into_iter().flatten().collect_vec(), + ) + }; + + let quotients = transcript.read_n_ec_points(protocol.quotient.num_chunk())?; + + let z = transcript.squeeze_challenge(); + let evaluations = transcript.read_n_scalars(protocol.evaluations.len())?; + + let pcs = MOS::read_proof(svk, &Self::empty_queries(protocol), transcript)?; + + let old_accumulators = protocol + .accumulator_indices + .iter() + .map(|accumulator_indices| { + accumulator_indices + .iter() + .map(|&(i, j)| instances[i][j].clone()) + .collect() + }) + .map(AE::from_repr) + .collect::, _>>()?; + + Ok(Self { + committed_instances, + witnesses, + challenges, + quotients, + z, + evaluations, + pcs, + old_accumulators, + }) + } + + fn empty_queries(protocol: &Protocol) -> Vec> { + protocol + .queries + .iter() + .map(|query| pcs::Query { + poly: query.poly, + shift: protocol + .domain + .rotate_scalar(C::Scalar::one(), query.rotation), + eval: (), + }) + .collect() + } + + fn queries( + &self, + protocol: &Protocol, + mut evaluations: HashMap, + ) -> Vec> { + Self::empty_queries(protocol) + .into_iter() + .zip( + protocol + .queries + .iter() + .map(|query| evaluations.remove(query).unwrap()), + ) + .map(|(query, eval)| query.with_evaluation(eval)) + .collect() + } + + fn commitments( + &self, + protocol: &Protocol, + common_poly_eval: &CommonPolynomialEvaluation, + evaluations: &mut HashMap, + ) -> Result>, Error> { + let loader = common_poly_eval.zn().loader(); + let mut commitments = iter::empty() + .chain( + protocol + .preprocessed + .iter() + .map(|value| Msm::base(loader.ec_point_load_const(value))), + ) + .chain( + self.committed_instances + .clone() + .map(|committed_instances| { + committed_instances.into_iter().map(Msm::base).collect_vec() + }) + .unwrap_or_else(|| { + iter::repeat_with(Default::default) + .take(protocol.num_instance.len()) + .collect_vec() + }), + ) + .chain(self.witnesses.iter().cloned().map(Msm::base)) + .collect_vec(); + + let numerator = protocol.quotient.numerator.evaluate( + &|scalar| Ok(Msm::constant(loader.load_const(&scalar))), + &|poly| Ok(Msm::constant(common_poly_eval.get(poly).clone())), + &|query| { + evaluations + .get(&query) + .cloned() + .map(Msm::constant) + .or_else(|| { + (query.rotation == Rotation::cur()) + .then(|| commitments.get(query.poly).cloned()) + .flatten() + }) + .ok_or(Error::InvalidQuery(query)) + }, + &|index| { + self.challenges + .get(index) + .cloned() + .map(Msm::constant) + .ok_or(Error::InvalidChallenge(index)) + }, + &|a| Ok(-a?), + &|a, b| Ok(a? + b?), + &|a, b| { + let (a, b) = (a?, b?); + match (a.size(), b.size()) { + (0, _) => Ok(b * &a.try_into_constant().unwrap()), + (_, 0) => Ok(a * &b.try_into_constant().unwrap()), + (_, _) => Err(Error::InvalidLinearization), + } + }, + &|a, scalar| Ok(a? 
* &loader.load_const(&scalar)), + )?; + + let quotient_query = Query::new( + protocol.preprocessed.len() + protocol.num_instance.len() + self.witnesses.len(), + Rotation::cur(), + ); + let quotient = common_poly_eval + .zn() + .pow_const(protocol.quotient.chunk_degree as u64) + .powers(self.quotients.len()) + .into_iter() + .zip(self.quotients.iter().cloned().map(Msm::base)) + .map(|(coeff, chunk)| chunk * &coeff) + .sum::>(); + match protocol.linearization { + Some(LinearizationStrategy::WithoutConstant) => { + let linearization_query = Query::new(quotient_query.poly + 1, Rotation::cur()); + let (msm, constant) = numerator.split(); + commitments.push(quotient); + commitments.push(msm); + evaluations.insert( + quotient_query, + (constant.unwrap_or_else(|| loader.load_zero()) + + evaluations.get(&linearization_query).unwrap()) + * common_poly_eval.zn_minus_one_inv(), + ); + } + Some(LinearizationStrategy::MinusVanishingTimesQuotient) => { + let (msm, constant) = + (numerator - quotient * common_poly_eval.zn_minus_one()).split(); + commitments.push(msm); + evaluations.insert( + quotient_query, + constant.unwrap_or_else(|| loader.load_zero()), + ); + } + None => { + commitments.push(quotient); + evaluations.insert( + quotient_query, + numerator + .try_into_constant() + .ok_or(Error::InvalidLinearization)? + * common_poly_eval.zn_minus_one_inv(), + ); + } + } + + Ok(commitments) + } + + fn evaluations( + &self, + protocol: &Protocol, + instances: &[Vec], + common_poly_eval: &CommonPolynomialEvaluation, + ) -> Result, Error> { + let loader = common_poly_eval.zn().loader(); + let instance_evals = protocol.instance_committing_key.is_none().then(|| { + let offset = protocol.preprocessed.len(); + let queries = { + let range = offset..offset + protocol.num_instance.len(); + protocol + .quotient + .numerator + .used_query() + .into_iter() + .filter(move |query| range.contains(&query.poly)) + }; + queries + .map(move |query| { + let instances = instances[query.poly - offset].iter(); + let l_i_minus_r = (-query.rotation.0..) 
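+                        // Pair instance[i] with the Lagrange evaluation L_{i - rotation}(z); their inner product gives the instance column's evaluation at the queried point.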
+ .map(|i_minus_r| common_poly_eval.get(Lagrange(i_minus_r))); + let eval = loader.sum_products(&instances.zip(l_i_minus_r).collect_vec()); + (query, eval) + }) + .collect_vec() + }); + + let evals = iter::empty() + .chain(instance_evals.into_iter().flatten()) + .chain( + protocol + .evaluations + .iter() + .cloned() + .zip(self.evaluations.iter().cloned()), + ) + .collect(); + + Ok(evals) + } +} + +impl CostEstimation<(C, MOS)> for Plonk +where + C: CurveAffine, + MOS: MultiOpenScheme + CostEstimation>>, +{ + type Input = Protocol; + + fn estimate_cost(protocol: &Protocol) -> Cost { + let plonk_cost = { + let num_accumulator = protocol.accumulator_indices.len(); + let num_instance = protocol.num_instance.iter().sum(); + let num_commitment = + protocol.num_witness.iter().sum::() + protocol.quotient.num_chunk(); + let num_evaluation = protocol.evaluations.len(); + let num_msm = protocol.preprocessed.len() + num_commitment + 1 + 2 * num_accumulator; + Cost::new(num_instance, num_commitment, num_evaluation, num_msm) + }; + let pcs_cost = { + let queries = PlonkProof::::empty_queries(protocol); + MOS::estimate_cost(&queries) + }; + plonk_cost + pcs_cost + } +} + +fn langranges(protocol: &Protocol, instances: &[Vec]) -> impl IntoIterator +where + C: CurveAffine, +{ + let instance_eval_lagrange = protocol.instance_committing_key.is_none().then(|| { + let queries = { + let offset = protocol.preprocessed.len(); + let range = offset..offset + protocol.num_instance.len(); + protocol + .quotient + .numerator + .used_query() + .into_iter() + .filter(move |query| range.contains(&query.poly)) + }; + let (min_rotation, max_rotation) = queries.fold((0, 0), |(min, max), query| { + if query.rotation.0 < min { + (query.rotation.0, max) + } else if query.rotation.0 > max { + (min, query.rotation.0) + } else { + (min, max) + } + }); + let max_instance_len = instances + .iter() + .map(|instance| instance.len()) + .max() + .unwrap_or_default(); + -max_rotation..max_instance_len as i32 + min_rotation.abs() + }); + protocol + .quotient + .numerator + .used_langrange() + .into_iter() + .chain(instance_eval_lagrange.into_iter().flatten()) +}
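For orientation, a minimal sketch of how the pieces introduced above (Protocol, PlonkVerifier, MultiOpenScheme/Decider, NativeLoader, TranscriptRead) compose for native verification. The helper name verify_natively, the exact generic bounds, and the plonk_verifier:: import paths are reconstructed from src/verifier.rs and the tests in this diff and may differ in detail; treat it as a sketch under those assumptions, not the crate's canonical API.

use plonk_verifier::{
    loader::native::NativeLoader,
    pcs::{Decider, MultiOpenScheme},
    util::{arithmetic::CurveAffine, transcript::TranscriptRead},
    verifier::PlonkVerifier,
    Error, Protocol,
};

// Read a proof from a transcript, run succinct verification, then the final decider check
// (for KZG, the pairing). `Pv` is any PlonkVerifier impl, e.g. the
// Plonk<Kzg<Bn256, Bdfg21>, LimbsEncoding<LIMBS, BITS>> used by the native KZG tests above,
// and `T` a transcript such as Blake2bRead or the native PoseidonTranscript.
fn verify_natively<C, Mos, Pv, T>(
    svk: &Mos::SuccinctVerifyingKey,
    dk: &Mos::DecidingKey,
    protocol: &Protocol<C>,
    instances: &[Vec<C::Scalar>],
    transcript: &mut T,
) -> Result<Mos::Output, Error>
where
    C: CurveAffine,
    Mos: MultiOpenScheme<C, NativeLoader> + Decider<C, NativeLoader>,
    Pv: PlonkVerifier<C, NativeLoader, Mos>,
    T: TranscriptRead<C, NativeLoader>,
{
    let proof = Pv::read_proof(svk, protocol, instances, transcript)?;
    Pv::verify(svk, dk, protocol, instances, &proof)
}

This is the flow the native KZG tests exercise through halo2_kzg_native_verify!, and its succinct-verification half is what the accumulation circuit replays in-circuit via the Halo2Loader and PoseidonTranscript.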