diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8a3eb1f5..63c4fdc7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,52 +2,64 @@ name: Tests on: push: - branches: ["main", "community-edition"] + branches: ["main"] pull_request: - branches: ["main", "community-edition"] + branches: ["main", "develop", "community-edition"] env: CARGO_TERM_COLOR: always jobs: build: - runs-on: ubuntu-latest-m + runs-on: ubuntu-latest-64core-256ram steps: - uses: actions/checkout@v3 - name: Build run: cargo build --verbose - name: Run halo2-base tests - working-directory: 'halo2-base' + working-directory: "halo2-base" run: | - cargo test -- --test-threads=1 - - name: Run poseidon tests - working-directory: 'hashes/poseidon' + cargo test + - name: Run halo2-ecc tests (mock prover) + working-directory: "halo2-ecc" run: | - cargo test test_poseidon_compatibility - - name: Run halo2-ecc tests MockProver - working-directory: 'halo2-ecc' + cargo test --lib -- --skip bench + - name: Run halo2-ecc tests (real prover) + working-directory: "halo2-ecc" run: | - cargo test -- --test-threads=1 test_fp - cargo test -- test_ecc - cargo test -- test_secp - cargo test -- test_ecdsa - cargo test -- test_ec_add - cargo test -- test_fixed - cargo test -- test_msm - cargo test -- test_sm - cargo test -- test_fb - cargo test -- test_pairing - cargo test -- test_bls_signature - - name: Run halo2-ecc tests real prover - working-directory: 'halo2-ecc' - run: | - cargo test --release -- test_fp_assert_eq - cargo test --release -- --nocapture bench_secp256k1_ecdsa - cargo test --release -- --nocapture bench_ec_add mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config - cargo test --release -- --nocapture bench_fixed_base_msm mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config - cargo test --release -- --nocapture bench_msm mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config + mv configs/secp256k1/bench_ecdsa.t.config configs/secp256k1/bench_ecdsa.config + cargo test --release -- --nocapture bench_secp256k1_ecdsa + cargo test --release -- --nocapture bench_fixed_base_msm + cargo test --release -- --nocapture bench_msm cargo test --release -- --nocapture bench_pairing + - name: Run zkevm tests + working-directory: "hashes/zkevm" + run: | + cargo test packed_multi_keccak_prover::k_14 + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + override: false + components: rustfmt, clippy + + - uses: Swatinem/rust-cache@v1 + with: + cache-on-failure: true + + - name: Run fmt + run: cargo fmt --all -- --check + + - name: Run clippy + run: cargo clippy --all -- -D warnings diff --git a/.gitignore b/.gitignore index 65983083..d2a5b639 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,10 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + +# Local IDE configs +.idea/ +.vscode/ ======= /target @@ -15,3 +19,7 @@ Cargo.lock /halo2_ecc/src/bn254/data/ /halo2_ecc/src/secp256k1/data/ + +/halo2_ecc/params/ +/halo2_ecc/results/ +/halo2_base/params/ diff --git a/Cargo.toml b/Cargo.toml index 68b2cfe2..b2d3ab72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,10 @@ [workspace] -members = [ - "halo2-base", - "halo2-ecc", - "hashes/zkevm-keccak", - "hashes/poseidon" -] +members = ["halo2-base", "halo2-ecc", "hashes/zkevm"] +resolver = "2" [profile.dev] opt-level = 3 -debug = 1 # change to 0 or 2 for more or less debug info +debug = 2 # change to 0 or 2 for more or less debug info overflow-checks = true incremental = true @@ -29,7 +25,7 @@ codegen-units = 16 opt-level = 3 debug = false debug-assertions = false -lto = "fat" +lto = "fat" # `codegen-units = 1` can lead to WORSE performance - always bench to find best profile for your machine! # codegen-units = 1 panic = "unwind" @@ -40,7 +36,6 @@ incremental = false inherits = "release" debug = true -# patch so snark-verifier uses this crate's halo2-base [patch."https://github.com/axiom-crypto/halo2-lib.git"] -halo2-base = { path = "./halo2-base" } -halo2-ecc = { path = "./halo2-ecc" } +halo2-base = { path = "../halo2-lib/halo2-base" } +halo2-ecc = { path = "../halo2-lib/halo2-ecc" } diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 5cf222b2..941c6b64 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -1,61 +1,74 @@ [package] -name = "halo2-base" -version = "0.3.0" -edition = "2021" +name="halo2-base" +version="0.4.0" +edition="2021" [dependencies] -itertools = "=0.10" -num-bigint = { version = "=0.4", features = ["rand"] } -num-integer = "=0.1" -num-traits = "=0.2" -rand_chacha = "=0.3" -rustc-hash = "=1.1" -ff = "=0.12" -rayon = "=1.7" -serde = { version = "=1.0", features = ["derive"] } -serde_json = "=1.0" -log = "=0.4" +itertools="0.11" +num-bigint={ version="0.4", features=["rand"] } +num-integer="0.1" +num-traits="0.2" +rand_chacha="0.3" +rustc-hash="1.1" +rayon="1.7" +serde={ version="1.0", features=["derive"] } +serde_json="1.0" +log="0.4" +getset="0.1.2" +ark-std={ version="0.3.0", features=["print-trace"], optional=true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", rev = "7db51d3", package = "halo2_proofs", optional = true } +halo2_proofs_axiom={ git="https://github.com/axiom-crypto/halo2.git", package="halo2_proofs", optional=true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02", optional = true } +halo2_proofs={ git="https://github.com/privacy-scaling-explorations/halo2.git", rev="7a21656", optional=true } +# This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). +# We forked it to upgrade to ff v0.13 and removed the circuit module +poseidon-rs={ git="https://github.com/axiom-crypto/poseidon-circuit.git", rev="1aee4a1" } # plotting circuit layout -plotters = { version = "0.3.0", optional = true } -tabbycat = { version = "0.1", features = ["attributes"], optional = true } +plotters={ version="0.3.0", optional=true } +tabbycat={ version="0.1", features=["attributes"], optional=true } # test-utils -rand = { version = "0.8", optional = true } +rand={ version="0.8", optional=true } [dev-dependencies] -ark-std = { version = "=0.3.0", features = ["print-trace"] } -rand = "=0.8" -pprof = { version = "=0.11", features = ["criterion", "flamegraph"] } -criterion = "=0.4" -criterion-macro = "=0.4" -test-case = "=3.1.0" -proptest = "=1.1.0" +ark-std={ version="0.3.0", features=["print-trace"] } +rand="0.8" +pprof={ version="0.11", features=["criterion", "flamegraph"] } +criterion="0.4" +criterion-macro="0.4" +test-case="3.1.0" +test-log="0.2.12" +env_logger="0.10.0" +proptest="1.1.0" +# native poseidon for testing +pse-poseidon={ git="https://github.com/axiom-crypto/pse-poseidon.git" } # memory allocation [target.'cfg(not(target_env = "msvc"))'.dependencies] -jemallocator = { version = "=0.5", optional = true } +jemallocator={ version="=0.5", optional=true } -mimalloc = { version = "=0.1", default-features = false, optional = true } +mimalloc={ version="=0.1", default-features=false, optional=true } [features] -default = ["halo2-axiom", "display"] -dev-graph = ["halo2_proofs?/dev-graph", "halo2_proofs_axiom?/dev-graph", "plotters"] -halo2-pse = ["halo2_proofs"] -halo2-axiom = ["halo2_proofs_axiom"] -display = [] -profile = ["halo2_proofs_axiom?/profile"] -test-utils = ["dep:rand"] +default=["halo2-axiom", "display", "test-utils"] +asm=["halo2_proofs_axiom?/asm"] +dev-graph=["halo2_proofs?/dev-graph", "halo2_proofs_axiom?/dev-graph", "plotters"] +halo2-pse=["halo2_proofs/circuit-params"] +halo2-axiom=["halo2_proofs_axiom"] +display=[] +profile=["halo2_proofs_axiom?/profile"] +test-utils=["dep:rand", "ark-std"] [[bench]] -name = "mul" -harness = false +name="mul" +harness=false [[bench]] -name = "inner_product" -harness = false +name="inner_product" +harness=false + +[[example]] +name="inner_product" +required-features=["test-utils"] diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index 9454faa3..45f503b9 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -1,28 +1,17 @@ -#![allow(unused_imports)] -#![allow(unused_variables)] -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; -use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::*, dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; +use halo2_base::utils::testing::gen_proof; use halo2_base::utils::ScalarField; -use halo2_base::{ - Context, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; use rand::rngs::OsRng; -use std::marker::PhantomData; use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; @@ -47,20 +36,20 @@ fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) fn bench(c: &mut Criterion) { let k = 19u32; // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(k as usize); inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - builder.config(k as usize, Some(20)); - let circuit = GateCircuitBuilder::mock(builder); + let config_params = builder.calculate_params(Some(20)); // check the circuit is correct just in case - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + MockProver::run(k, &builder, vec![]).unwrap().assert_satisfied(); let params = ParamsKZG::::setup(k, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let vk = keygen_vk(¶ms, &builder).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &builder).expect("pk should not fail"); - let break_points = circuit.break_points.take(); - drop(circuit); + let break_points = builder.break_points(); + drop(builder); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); @@ -69,22 +58,12 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk), |bencher, &(params, pk)| { bencher.iter(|| { - let mut builder = GateThreadBuilder::new(true); + let mut builder = + RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); inner_prod_bench(builder.main(0), a, b); - let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, builder); }) }, ); diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index 16687e08..ee239abd 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,15 +1,12 @@ -use ff::Field; -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr}, + halo2curves::ff::Field, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverGWC, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; +use halo2_base::utils::testing::gen_proof; use halo2_base::utils::ScalarField; use halo2_base::Context; use rand::rngs::OsRng; @@ -34,16 +31,16 @@ fn mul_bench(ctx: &mut Context, inputs: [F; 2]) { fn bench(c: &mut Criterion) { // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(K as usize); mul_bench(builder.main(0), [Fr::zero(); 2]); - builder.config(K as usize, Some(9)); - let circuit = GateCircuitBuilder::keygen(builder); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::::setup(K, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let vk = keygen_vk(¶ms, &builder).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &builder).expect("pk should not fail"); - let break_points = circuit.break_points.take(); + let break_points = builder.break_points(); let a = Fr::random(OsRng); let b = Fr::random(OsRng); @@ -53,21 +50,12 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk, [a, b]), |bencher, &(params, pk, inputs)| { bencher.iter(|| { - let mut builder = GateThreadBuilder::new(true); + let mut builder = + RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); // do the computation mul_bench(builder.main(0), inputs); - let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .unwrap(); + gen_proof(params, pk, builder); }) }, ); diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs index 8572817e..c1413211 100644 --- a/halo2-base/examples/inner_product.rs +++ b/halo2-base/examples/inner_product.rs @@ -1,95 +1,39 @@ -#![allow(unused_imports)] -#![allow(unused_variables)] -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; -use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; -use halo2_base::halo2_proofs::{ - arithmetic::Field, - circuit::*, - dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::*, - poly::kzg::multiopen::VerifierSHPLONK, - poly::kzg::strategy::SingleStrategy, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bRead, TranscriptReadBuffer}, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, -}; +#![cfg(feature = "test-utils")] +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; +use halo2_base::gates::RangeInstructions; +use halo2_base::halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}; +use halo2_base::utils::testing::base_test; use halo2_base::utils::ScalarField; -use halo2_base::{ - Context, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; use rand::rngs::OsRng; -use std::marker::PhantomData; - -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; - -use pprof::criterion::{Output, PProfProfiler}; -// Thanks to the example provided by @jebbow in his article -// https://www.jibbow.com/posts/criterion-flamegraphs/ const K: u32 = 19; -fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) { +fn inner_prod_bench( + ctx: &mut Context, + gate: &GateChip, + a: Vec, + b: Vec, +) { assert_eq!(a.len(), b.len()); let a = ctx.assign_witnesses(a); let b = ctx.assign_witnesses(b); - let chip = GateChip::default(); for _ in 0..(1 << K) / 16 - 10 { - chip.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); + gate.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); } } fn main() { - let k = 10u32; - // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); - inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - builder.config(k as usize, Some(20)); - let circuit = GateCircuitBuilder::mock(builder); - - // check the circuit is correct just in case - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); - - let params = ParamsKZG::::setup(k, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - - let break_points = circuit.break_points.take(); - - let mut builder = GateThreadBuilder::new(true); - let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); - let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); - inner_prod_bench(builder.main(0), a, b); - let circuit = GateCircuitBuilder::prover(builder, break_points); - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - - let strategy = SingleStrategy::new(¶ms); - let proof = transcript.finalize(); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - _, - >(¶ms, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); + base_test().k(12).bench_builder( + (vec![Fr::ZERO; 5], vec![Fr::ZERO; 5]), + ( + (0..5).map(|_| Fr::random(OsRng)).collect_vec(), + (0..5).map(|_| Fr::random(OsRng)).collect_vec(), + ), + |pool, range, (a, b)| { + inner_prod_bench(pool.main(), range.gate(), a, b); + }, + ); } diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder.rs deleted file mode 100644 index 22c2ce93..00000000 --- a/halo2-base/src/gates/builder.rs +++ /dev/null @@ -1,796 +0,0 @@ -use super::{ - flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, - range::{RangeConfig, RangeStrategy}, -}; -use crate::{ - halo2_proofs::{ - circuit::{self, Layouter, Region, SimpleFloorPlanner, Value}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance, Selector}, - }, - utils::ScalarField, - AssignedValue, Context, SKIP_FIRST_PASS, -}; -use serde::{Deserialize, Serialize}; -use std::{ - cell::RefCell, - collections::{HashMap, HashSet}, - env::{set_var, var}, -}; - -mod parallelize; -pub use parallelize::*; - -/// Vector of thread advice column break points -pub type ThreadBreakPoints = Vec; -/// Vector of vectors tracking the thread break points across different halo2 phases -pub type MultiPhaseThreadBreakPoints = Vec; - -/// Stores the cell values loaded during the Keygen phase of a halo2 proof and breakpoints for multi-threading -#[derive(Clone, Debug, Default)] -pub struct KeygenAssignments { - /// Advice assignments - pub assigned_advices: HashMap<(usize, usize), (circuit::Cell, usize)>, // (key = ContextCell, value = (circuit::Cell, row offset)) - /// Constant assignments in Fixes Assignments - pub assigned_constants: HashMap, // (key = constant, value = circuit::Cell) - /// Advice column break points for threads in each phase. - pub break_points: MultiPhaseThreadBreakPoints, -} - -/// Builds the process for gate threading -#[derive(Clone, Debug, Default)] -pub struct GateThreadBuilder { - /// Threads for each challenge phase - pub threads: [Vec>; MAX_PHASE], - /// Max number of threads - thread_count: usize, - /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. - pub witness_gen_only: bool, - /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - use_unknown: bool, -} - -impl GateThreadBuilder { - /// Creates a new [GateThreadBuilder] and spawns a main thread in phase 0. - /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. - /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. - /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). - /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. - pub fn new(witness_gen_only: bool) -> Self { - let mut threads = [(); MAX_PHASE].map(|_| vec![]); - // start with a main thread in phase 0 - threads[0].push(Context::new(witness_gen_only, 0)); - Self { threads, thread_count: 1, witness_gen_only, use_unknown: false } - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. - /// - /// Performs the witness assignment computations and then checks using normal programming logic whether the gate constraints are all satisfied. - pub fn mock() -> Self { - Self::new(false) - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. - /// - /// Performs the witness assignment computations and generates prover and verifier keys. - pub fn keygen() -> Self { - Self::new(false) - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to true. - /// - /// Performs the witness assignment computations and then runs the proving system. - pub fn prover() -> Self { - Self::new(true) - } - - /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. - /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - pub fn unknown(self, use_unknown: bool) -> Self { - Self { use_unknown, ..self } - } - - /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. - /// * `phase`: The challenge phase (as an index) of the gate thread. - pub fn main(&mut self, phase: usize) -> &mut Context { - if self.threads[phase].is_empty() { - self.new_thread(phase) - } else { - self.threads[phase].last_mut().unwrap() - } - } - - /// Returns the `witness_gen_only` flag. - pub fn witness_gen_only(&self) -> bool { - self.witness_gen_only - } - - /// Returns the `use_unknown` flag. - pub fn use_unknown(&self) -> bool { - self.use_unknown - } - - /// Returns the current number of threads in the [GateThreadBuilder]. - pub fn thread_count(&self) -> usize { - self.thread_count - } - - /// Creates a new thread id by incrementing the `thread count` - pub fn get_new_thread_id(&mut self) -> usize { - let thread_id = self.thread_count; - self.thread_count += 1; - thread_id - } - - /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. - /// * `phase`: The phase (index) of the gate thread. - pub fn new_thread(&mut self, phase: usize) -> &mut Context { - let thread_id = self.thread_count; - self.thread_count += 1; - self.threads[phase].push(Context::new(self.witness_gen_only, thread_id)); - self.threads[phase].last_mut().unwrap() - } - - /// Auto-calculates configuration parameters for the circuit - /// - /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) - /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. - pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { - let max_rows = (1 << k) - minimum_rows.unwrap_or(0); - let total_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) - .collect::>(); - // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) - // if this is too small, manual configuration will be needed - let num_advice_per_phase = total_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_lookup_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) - .collect::>(); - let num_lookup_advice_per_phase = total_lookup_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { - threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) - })) - .len(); - let num_fixed = (total_fixed + (1 << k) - 1) >> k; - - let params = FlexGateConfigParams { - strategy: GateStrategy::Vertical, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - }; - #[cfg(feature = "display")] - { - for phase in 0..MAX_PHASE { - if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { - println!( - "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", - phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], - ); - } - } - println!("Total {total_fixed} fixed cells"); - log::info!("Auto-calculated config params:\n {params:#?}"); - } - set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); - params - } - - /// Assigns all advice and fixed cells, turns on selectors, and imposes equality constraints. - /// - /// Returns the assigned advices, and constants in the form of [KeygenAssignments]. - /// - /// Assumes selector and advice columns are already allocated and of the same length. - /// - /// Note: `assign_all()` **should** be called during keygen or if using mock prover. It also works for the real prover, but there it is more optimal to use [`assign_threads_in`] instead. - /// * `config`: The [FlexGateConfig] of the circuit. - /// * `lookup_advice`: The lookup advice columns. - /// * `q_lookup`: The lookup advice selectors. - /// * `region`: The [Region] of the circuit. - /// * `assigned_advices`: The assigned advice cells. - /// * `assigned_constants`: The assigned fixed cells. - /// * `break_points`: The break points of the circuit. - pub fn assign_all( - &self, - config: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - region: &mut Region, - KeygenAssignments { - mut assigned_advices, - mut assigned_constants, - mut break_points - }: KeygenAssignments, - ) -> KeygenAssignments { - let use_unknown = self.use_unknown; - let max_rows = config.max_rows; - let mut fixed_col = 0; - let mut fixed_offset = 0; - for (phase, threads) in self.threads.iter().enumerate() { - let mut break_point = vec![]; - let mut gate_index = 0; - let mut row_offset = 0; - for ctx in threads { - let mut basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - assert_eq!(ctx.selector.len(), ctx.advice.len()); - - for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { - let column = basic_gate.value; - let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; - #[cfg(feature = "halo2-axiom")] - let cell = *region.assign_advice(column, row_offset, value).cell(); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); - - // If selector enabled and row_offset is valid add break point to Keygen Assignments, account for break point overlap, and enforce equality constraint for gate outputs. - if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { - break_point.push(row_offset); - row_offset = 0; - gate_index += 1; - - // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety - basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - let column = basic_gate.value; - - #[cfg(feature = "halo2-axiom")] - { - let ncell = region.assign_advice(column, row_offset, value); - region.constrain_equal(ncell.cell(), &cell); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let ncell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - region.constrain_equal(ncell, cell).unwrap(); - } - } - - if q { - basic_gate - .q_enable - .enable(region, row_offset) - .expect("enable selector should not fail"); - } - - row_offset += 1; - } - // Assign fixed cells - for (c, _) in ctx.constant_equality_constraints.iter() { - if assigned_constants.get(c).is_none() { - #[cfg(feature = "halo2-axiom")] - let cell = - region.assign_fixed(config.constants[fixed_col], fixed_offset, c); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_fixed( - || "", - config.constants[fixed_col], - fixed_offset, - || Value::known(*c), - ) - .unwrap() - .cell(); - assigned_constants.insert(*c, cell); - fixed_col += 1; - if fixed_col >= config.constants.len() { - fixed_col = 0; - fixed_offset += 1; - } - } - } - } - break_points.push(break_point); - } - // we constrain equality constraints in a separate loop in case context `i` contains references to context `j` for `j > i` - for (phase, threads) in self.threads.iter().enumerate() { - let mut lookup_offset = 0; - let mut lookup_col = 0; - for ctx in threads { - for (left, right) in &ctx.advice_equality_constraints { - let (left, _) = assigned_advices[&(left.context_id, left.offset)]; - let (right, _) = assigned_advices[&(right.context_id, right.offset)]; - #[cfg(feature = "halo2-axiom")] - region.constrain_equal(&left, &right); - #[cfg(not(feature = "halo2-axiom"))] - region.constrain_equal(left, right).unwrap(); - } - for (left, right) in &ctx.constant_equality_constraints { - let left = assigned_constants[left]; - let (right, _) = assigned_advices[&(right.context_id, right.offset)]; - #[cfg(feature = "halo2-axiom")] - region.constrain_equal(&left, &right); - #[cfg(not(feature = "halo2-axiom"))] - region.constrain_equal(left, right).unwrap(); - } - - for advice in &ctx.cells_to_lookup { - // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled - let cell = advice.cell.unwrap(); - let (acell, row_offset) = assigned_advices[&(cell.context_id, cell.offset)]; - if let Some(q_lookup) = q_lookup[phase] { - assert_eq!(config.basic_gates[phase].len(), 1); - q_lookup.enable(region, row_offset).unwrap(); - continue; - } - // otherwise, we copy the advice value to the special lookup_advice columns - if lookup_offset >= max_rows { - lookup_offset = 0; - lookup_col += 1; - } - let value = advice.value; - let value = if use_unknown { Value::unknown() } else { Value::known(value) }; - let column = lookup_advice[phase][lookup_col]; - - #[cfg(feature = "halo2-axiom")] - { - let bcell = region.assign_advice(column, lookup_offset, value); - region.constrain_equal(&acell, bcell.cell()); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let bcell = region - .assign_advice(|| "", column, lookup_offset, || value) - .expect("assign_advice should not fail") - .cell(); - region.constrain_equal(acell, bcell).unwrap(); - } - lookup_offset += 1; - } - } - } - KeygenAssignments { assigned_advices, assigned_constants, break_points } - } -} - -/// Assigns threads to regions of advice column. -/// -/// Uses preprocessed `break_points` to assign where to divide the advice column into a new column for each thread. -/// -/// Performs only witness generation, so should only be evoked during proving not keygen. -/// -/// Assumes that the advice columns are already assigned. -/// * `phase` - the phase of the circuit -/// * `threads` - [Vec] threads to assign -/// * `config` - immutable reference to the configuration of the circuit -/// * `lookup_advice` - Slice of lookup advice columns -/// * `region` - mutable reference to the region to assign threads to -/// * `break_points` - the preprocessed break points for the threads -pub fn assign_threads_in( - phase: usize, - threads: Vec>, - config: &FlexGateConfig, - lookup_advice: &[Column], - region: &mut Region, - break_points: ThreadBreakPoints, -) { - if config.basic_gates[phase].is_empty() { - assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); - return; - } - - let mut break_points = break_points.into_iter(); - let mut break_point = break_points.next(); - - let mut gate_index = 0; - let mut column = config.basic_gates[phase][gate_index].value; - let mut row_offset = 0; - - let mut lookup_offset = 0; - let mut lookup_advice = lookup_advice.iter(); - let mut lookup_column = lookup_advice.next(); - for ctx in threads { - // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns - if lookup_column.is_some() { - for advice in ctx.cells_to_lookup { - if lookup_offset >= config.max_rows { - lookup_offset = 0; - lookup_column = lookup_advice.next(); - } - // Assign the lookup advice values to the lookup_column - let value = advice.value; - let lookup_column = *lookup_column.unwrap(); - #[cfg(feature = "halo2-axiom")] - region.assign_advice(lookup_column, lookup_offset, Value::known(value)); - #[cfg(not(feature = "halo2-axiom"))] - region - .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) - .unwrap(); - - lookup_offset += 1; - } - } - // Assign advice values to the advice columns in each [Context] - for advice in ctx.advice { - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - - if break_point == Some(row_offset) { - break_point = break_points.next(); - row_offset = 0; - gate_index += 1; - column = config.basic_gates[phase][gate_index].value; - - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - } - - row_offset += 1; - } - } -} - -/// A Config struct defining the parameters for a FlexGate circuit. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FlexGateConfigParams { - /// The gate strategy used for the advice column of the circuit and applied at every row. - pub strategy: GateStrategy, - /// Security parameter `k` used for the keygen. - pub k: usize, - /// The number of advice columns per phase - pub num_advice_per_phase: Vec, - /// The number of advice columns that do not have lookup enabled per phase - pub num_lookup_advice_per_phase: Vec, - /// The number of fixed columns per phase - pub num_fixed: usize, -} - -/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. -#[derive(Clone, Debug)] -pub struct GateCircuitBuilder { - /// The Thread Builder for the circuit - pub builder: RefCell>, // `RefCell` is just to trick circuit `synthesize` to take ownership of the inner builder - /// Break points for threads within the circuit - pub break_points: RefCell, // `RefCell` allows the circuit to record break points in a keygen call of `synthesize` for use in later witness gen -} - -impl GateCircuitBuilder { - /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to true. - pub fn keygen(builder: GateThreadBuilder) -> Self { - Self { builder: RefCell::new(builder.unknown(true)), break_points: RefCell::new(vec![]) } - } - - /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to false. - pub fn mock(builder: GateThreadBuilder) -> Self { - Self { builder: RefCell::new(builder.unknown(false)), break_points: RefCell::new(vec![]) } - } - - /// Creates a new [GateCircuitBuilder]. - pub fn prover( - builder: GateThreadBuilder, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self { builder: RefCell::new(builder), break_points: RefCell::new(break_points) } - } - - /// Synthesizes from the [GateCircuitBuilder] by populating the advice column and assigning new threads if witness generation is performed. - pub fn sub_synthesize( - &self, - gate: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - layouter: &mut impl Layouter, - ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { - let mut first_pass = SKIP_FIRST_PASS; - let mut assigned_advices = HashMap::new(); - layouter - .assign_region( - || "GateCircuitBuilder generated circuit", - |mut region| { - if first_pass { - first_pass = false; - return Ok(()); - } - // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize - // If we are not performing witness generation only, we can skip the first pass and assign threads directly - if !self.builder.borrow().witness_gen_only { - // clone the builder so we can re-use the circuit for both vk and pk gen - let builder = self.builder.borrow().clone(); - for threads in builder.threads.iter().skip(1) { - assert!( - threads.is_empty(), - "GateCircuitBuilder only supports FirstPhase for now" - ); - } - let assignments = builder.assign_all( - gate, - lookup_advice, - q_lookup, - &mut region, - Default::default(), - ); - *self.break_points.borrow_mut() = assignments.break_points; - assigned_advices = assignments.assigned_advices; - } else { - // If we are only generating witness, we can skip the first pass and assign threads directly - let builder = self.builder.take(); - let break_points = self.break_points.take(); - for (phase, (threads, break_points)) in builder - .threads - .into_iter() - .zip(break_points.into_iter()) - .enumerate() - .take(1) - { - assign_threads_in( - phase, - threads, - gate, - lookup_advice.get(phase).unwrap_or(&vec![]), - &mut region, - break_points, - ); - } - } - Ok(()) - }, - ) - .unwrap(); - assigned_advices - } -} - -impl Circuit for GateCircuitBuilder { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the circuit without withnesses filled in. - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config]. - fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase: _, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - self.sub_synthesize(&config, &[], &[], &mut layouter); - Ok(()) - } -} - -/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. -#[derive(Clone, Debug)] -pub struct RangeCircuitBuilder(pub GateCircuitBuilder); - -impl RangeCircuitBuilder { - /// Creates an instance of the [RangeCircuitBuilder] and executes in keygen mode. - pub fn keygen(builder: GateThreadBuilder) -> Self { - Self(GateCircuitBuilder::keygen(builder)) - } - - /// Creates a mock instance of the [RangeCircuitBuilder]. - pub fn mock(builder: GateThreadBuilder) -> Self { - Self(GateCircuitBuilder::mock(builder)) - } - - /// Creates an instance of the [RangeCircuitBuilder] and executes in prover mode. - pub fn prover( - builder: GateThreadBuilder, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self(GateCircuitBuilder::prover(builder, break_points)) - } -} - -impl Circuit for RangeCircuitBuilder { - type Config = RangeConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - let strategy = match strategy { - GateStrategy::Vertical => RangeStrategy::Vertical, - }; - let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); - RangeConfig::configure( - meta, - strategy, - &num_advice_per_phase, - &num_lookup_advice_per_phase, - num_fixed, - lookup_bits, - k, - ) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // only load lookup table if we are actually doing lookups - if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !config.q_lookup.iter().all(|q| q.is_none()) - { - config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); - Ok(()) - } -} - -/// Configuration with [`RangeConfig`] and a single public instance column. -#[derive(Clone, Debug)] -pub struct RangeWithInstanceConfig { - /// The underlying range configuration - pub range: RangeConfig, - /// The public instance column - pub instance: Column, -} - -/// This is an extension of [`RangeCircuitBuilder`] that adds support for public instances (aka public inputs+outputs) -/// -/// The intended design is that a [`GateThreadBuilder`] is populated and then produces some assigned instances, which are supplied as `assigned_instances` to this struct. -/// The [`Circuit`] implementation for this struct will then expose these instances and constrain them using the Halo2 API. -#[derive(Clone, Debug)] -pub struct RangeWithInstanceCircuitBuilder { - /// The underlying circuit builder - pub circuit: RangeCircuitBuilder, - /// The assigned instances to expose publicly at the end of circuit synthesis - pub assigned_instances: Vec>, -} - -impl RangeWithInstanceCircuitBuilder { - /// See [`RangeCircuitBuilder::keygen`] - pub fn keygen( - builder: GateThreadBuilder, - assigned_instances: Vec>, - ) -> Self { - Self { circuit: RangeCircuitBuilder::keygen(builder), assigned_instances } - } - - /// See [`RangeCircuitBuilder::mock`] - pub fn mock(builder: GateThreadBuilder, assigned_instances: Vec>) -> Self { - Self { circuit: RangeCircuitBuilder::mock(builder), assigned_instances } - } - - /// See [`RangeCircuitBuilder::prover`] - pub fn prover( - builder: GateThreadBuilder, - assigned_instances: Vec>, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self { circuit: RangeCircuitBuilder::prover(builder, break_points), assigned_instances } - } - - /// Creates a new instance of the [RangeWithInstanceCircuitBuilder]. - pub fn new(circuit: RangeCircuitBuilder, assigned_instances: Vec>) -> Self { - Self { circuit, assigned_instances } - } - - /// Calls [`GateThreadBuilder::config`] - pub fn config(&self, k: u32, minimum_rows: Option) -> FlexGateConfigParams { - self.circuit.0.builder.borrow().config(k as usize, minimum_rows) - } - - /// Gets the break points of the circuit. - pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { - self.circuit.0.break_points.borrow().clone() - } - - /// Gets the number of instances. - pub fn instance_count(&self) -> usize { - self.assigned_instances.len() - } - - /// Gets the instances. - pub fn instance(&self) -> Vec { - self.assigned_instances.iter().map(|v| *v.value()).collect() - } -} - -impl Circuit for RangeWithInstanceCircuitBuilder { - type Config = RangeWithInstanceConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let range = RangeCircuitBuilder::configure(meta); - let instance = meta.instance_column(); - meta.enable_equality(instance); - RangeWithInstanceConfig { range, instance } - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // copied from RangeCircuitBuilder::synthesize but with extra logic to expose public instances - let range = config.range; - let circuit = &self.circuit.0; - // only load lookup table if we are actually doing lookups - if range.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !range.q_lookup.iter().all(|q| q.is_none()) - { - range.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - // we later `take` the builder, so we need to save this value - let witness_gen_only = circuit.builder.borrow().witness_gen_only(); - let assigned_advices = circuit.sub_synthesize( - &range.gate, - &range.lookup_advice, - &range.q_lookup, - &mut layouter, - ); - - if !witness_gen_only { - // expose public instances - let mut layouter = layouter.namespace(|| "expose"); - for (i, instance) in self.assigned_instances.iter().enumerate() { - let cell = instance.cell.unwrap(); - let (cell, _) = assigned_advices - .get(&(cell.context_id, cell.offset)) - .expect("instance not assigned"); - layouter.constrain_instance(*cell, config.instance, i); - } - } - Ok(()) - } -} - -/// Defines stage of the circuit builder. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum CircuitBuilderStage { - /// Keygen phase - Keygen, - /// Prover Circuit - Prover, - /// Mock Circuit - Mock, -} diff --git a/halo2-base/src/gates/builder/parallelize.rs b/halo2-base/src/gates/builder/parallelize.rs deleted file mode 100644 index ab9171d5..00000000 --- a/halo2-base/src/gates/builder/parallelize.rs +++ /dev/null @@ -1,38 +0,0 @@ -use itertools::Itertools; -use rayon::prelude::*; - -use crate::{utils::ScalarField, Context}; - -use super::GateThreadBuilder; - -/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`. -pub fn parallelize_in( - phase: usize, - builder: &mut GateThreadBuilder, - input: Vec, - f: FR, -) -> Vec -where - F: ScalarField, - T: Send, - R: Send, - FR: Fn(&mut Context, T) -> R + Send + Sync, -{ - let witness_gen_only = builder.witness_gen_only(); - // to prevent concurrency issues with context id, we generate all the ids first - let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec(); - let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input - .into_par_iter() - .zip(ctx_ids.into_par_iter()) - .map(|(input, ctx_id)| { - // create new context - let mut ctx = Context::new(witness_gen_only, ctx_id); - let output = f(&mut ctx, input); - (output, ctx) - }) - .unzip(); - // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused - builder.threads[phase].append(&mut ctxs); - - outputs -} diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs new file mode 100644 index 00000000..980abee9 --- /dev/null +++ b/halo2-base/src/gates/circuit/builder.rs @@ -0,0 +1,374 @@ +use std::sync::{Arc, Mutex}; + +use getset::{Getters, MutGetters, Setters}; +use itertools::Itertools; + +use crate::{ + gates::{ + circuit::CircuitBuilderStage, + flex_gate::{ + threads::{GateStatistics, MultiPhaseCoreManager, SinglePhaseCoreManager}, + MultiPhaseThreadBreakPoints, MAX_PHASE, + }, + range::RangeConfig, + RangeChip, + }, + halo2_proofs::{ + circuit::{Layouter, Region}, + plonk::{Column, Instance}, + }, + utils::ScalarField, + virtual_region::{ + copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}, + lookups::LookupAnyManager, + manager::VirtualRegionManager, + }, + AssignedValue, Context, +}; + +use super::BaseCircuitParams; + +/// Keeping the naming `RangeCircuitBuilder` for backwards compatibility. +pub type RangeCircuitBuilder = BaseCircuitBuilder; + +/// A circuit builder is a collection of virtual region managers that together assign virtual +/// regions into a single physical circuit. +/// +/// [BaseCircuitBuilder] is a circuit builder to create a circuit where the columns correspond to [PublicBaseConfig]. +/// This builder can hold multiple threads, but the [Circuit] implementation only evaluates the first phase. +/// The user will have to implement a separate [Circuit] with multi-phase witness generation logic. +/// +/// This is used to manage the virtual region corresponding to [FlexGateConfig] and (optionally) [RangeConfig]. +/// This can be used even if only using [GateChip] without [RangeChip]. +/// +/// The circuit will have `NI` public instance (aka public inputs+outputs) columns. +#[derive(Clone, Debug, Getters, MutGetters, Setters)] +pub struct BaseCircuitBuilder { + /// Virtual region for each challenge phase. These cannot be shared across threads while keeping circuit deterministic. + #[getset(get = "pub", get_mut = "pub", set = "pub")] + pub(super) core: MultiPhaseCoreManager, + /// The range lookup manager + #[getset(get = "pub", get_mut = "pub", set = "pub")] + pub(super) lookup_manager: [LookupAnyManager; MAX_PHASE], + /// Configuration parameters for the circuit shape + pub config_params: BaseCircuitParams, + /// The assigned instances to expose publicly at the end of circuit synthesis + pub assigned_instances: Vec>>, +} + +impl Default for BaseCircuitBuilder { + /// Quick start default circuit builder which can be used for MockProver, Keygen, and real prover. + /// For best performance during real proof generation, we recommend using [BaseCircuitBuilder::prover] instead. + fn default() -> Self { + Self::new(false) + } +} + +impl BaseCircuitBuilder { + /// Creates a new [BaseCircuitBuilder] with all default managers. + /// * `witness_gen_only`: + /// * If true, the builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the builder also imposes constraints (selectors, fixed columns, copy constraints). Primarily used for keygen and mock prover (but can also be used for real prover). + /// + /// By default, **no** circuit configuration parameters have been set. + /// These should be set separately using [use_params], or [use_k], [use_lookup_bits], and [config]. + /// + /// Upon construction, there are no public instances (aka all witnesses are private). + /// The intended usage is that _before_ calling `synthesize`, witness generation can be done to populate + /// assigned instances, which are supplied as `assigned_instances` to this struct. + /// The [`Circuit`] implementation for this struct will then expose these instances and constrain + /// them using the Halo2 API. + pub fn new(witness_gen_only: bool) -> Self { + let core = MultiPhaseCoreManager::new(witness_gen_only); + let lookup_manager = [(); MAX_PHASE] + .map(|_| LookupAnyManager::new(witness_gen_only, core.copy_manager.clone())); + Self { core, lookup_manager, config_params: Default::default(), assigned_instances: vec![] } + } + + /// Creates a new [MultiPhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [MultiPhaseCoreManager] is used for witness generation only. + pub fn from_stage(stage: CircuitBuilderStage) -> Self { + Self::new(stage.witness_gen_only()).unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Creates a new [BaseCircuitBuilder] with a pinned circuit configuration given by `config_params` and `break_points`. + pub fn prover( + config_params: BaseCircuitParams, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self::new(true).use_params(config_params).use_break_points(break_points) + } + + /// Sets the copy manager to the given one in all shared references. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + for lm in &mut self.lookup_manager { + lm.set_copy_manager(copy_manager.clone()); + } + self.core.set_copy_manager(copy_manager); + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + + /// Deep clone of `self`, where the underlying object of shared references in [SharedCopyConstraintManager] and [LookupAnyManager] are cloned. + pub fn deep_clone(&self) -> Self { + let cm: CopyConstraintManager = self.core.copy_manager.lock().unwrap().clone(); + let cm_ref = Arc::new(Mutex::new(cm)); + let mut clone = self.clone().use_copy_manager(cm_ref.clone()); + for lm in &mut clone.lookup_manager { + *lm = lm.deep_clone(cm_ref.clone()); + } + clone + } + + /// The log_2 size of the lookup table, if using. + pub fn lookup_bits(&self) -> Option { + self.config_params.lookup_bits + } + + /// Set lookup bits + pub fn set_lookup_bits(&mut self, lookup_bits: usize) { + self.config_params.lookup_bits = Some(lookup_bits); + } + + /// Returns new with lookup bits + pub fn use_lookup_bits(mut self, lookup_bits: usize) -> Self { + self.set_lookup_bits(lookup_bits); + self + } + + /// Sets new `k` = log2 of domain + pub fn set_k(&mut self, k: usize) { + self.config_params.k = k; + } + + /// Returns new with `k` set + pub fn use_k(mut self, k: usize) -> Self { + self.set_k(k); + self + } + + /// Set the number of instance columns. This resizes `self.assigned_instances`. + pub fn set_instance_columns(&mut self, num_instance_columns: usize) { + self.config_params.num_instance_columns = num_instance_columns; + while self.assigned_instances.len() < num_instance_columns { + self.assigned_instances.push(vec![]); + } + assert_eq!(self.assigned_instances.len(), num_instance_columns); + } + + /// Returns new with `self.assigned_instances` resized to specified number of instance columns. + pub fn use_instance_columns(mut self, num_instance_columns: usize) -> Self { + self.set_instance_columns(num_instance_columns); + self + } + + /// Set config params + pub fn set_params(&mut self, params: BaseCircuitParams) { + self.set_instance_columns(params.num_instance_columns); + self.config_params = params; + } + + /// Returns new with config params + pub fn use_params(mut self, params: BaseCircuitParams) -> Self { + self.set_params(params); + self + } + + /// The break points of the circuit. + pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { + self.core + .phase_manager + .iter() + .map(|pm| pm.break_points.borrow().as_ref().expect("break points not set").clone()) + .collect() + } + + /// Sets the break points of the circuit. + pub fn set_break_points(&mut self, break_points: MultiPhaseThreadBreakPoints) { + if break_points.is_empty() { + return; + } + self.core.touch(break_points.len() - 1); + for (pm, bp) in self.core.phase_manager.iter().zip_eq(break_points) { + *pm.break_points.borrow_mut() = Some(bp); + } + } + + /// Returns new with break points + pub fn use_break_points(mut self, break_points: MultiPhaseThreadBreakPoints) -> Self { + self.set_break_points(break_points); + self + } + + /// Returns if the circuit is only used for witness generation. + pub fn witness_gen_only(&self) -> bool { + self.core.witness_gen_only() + } + + /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(mut self, use_unknown: bool) -> Self { + self.core = self.core.unknown(use_unknown); + self + } + + /// Clears state and copies, effectively resetting the circuit builder. + pub fn clear(&mut self) { + self.core.clear(); + for lm in &mut self.lookup_manager { + lm.clear(); + } + self.assigned_instances.iter_mut().for_each(|c| c.clear()); + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + self.core.main(phase) + } + + /// Returns [SinglePhaseCoreManager] with the virtual region with all core threads in the given phase. + pub fn pool(&mut self, phase: usize) -> &mut SinglePhaseCoreManager { + self.core.phase_manager.get_mut(phase).unwrap() + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + self.core.new_thread(phase) + } + + /// Returns some statistics about the virtual region. + pub fn statistics(&self) -> RangeStatistics { + let gate = self.core.statistics(); + let total_lookup_advice_per_phase = self.total_lookup_advice_per_phase(); + RangeStatistics { gate, total_lookup_advice_per_phase } + } + + fn total_lookup_advice_per_phase(&self) -> Vec { + self.lookup_manager.iter().map(|lm| lm.total_rows()).collect() + } + + /// Auto-calculates configuration parameters for the circuit and sets them. + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + /// * `lookup_bits`: The fixed lookup table will consist of [0, 2lookup_bits) + pub fn calculate_params(&mut self, minimum_rows: Option) -> BaseCircuitParams { + let k = self.config_params.k; + let ni = self.config_params.num_instance_columns; + assert_ne!(k, 0, "k must be set"); + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let gate_params = self.core.calculate_params(k, minimum_rows); + let total_lookup_advice_per_phase = self.total_lookup_advice_per_phase(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let params = BaseCircuitParams { + k: gate_params.k, + num_advice_per_phase: gate_params.num_advice_per_phase, + num_fixed: gate_params.num_fixed, + num_lookup_advice_per_phase, + lookup_bits: self.lookup_bits(), + num_instance_columns: ni, + }; + self.config_params = params.clone(); + #[cfg(feature = "display")] + { + println!("Total range check advice cells to lookup per phase: {total_lookup_advice_per_phase:?}"); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + params + } + + /// Copies `assigned_instances` to the instance columns. Should only be called at the very end of + /// `synthesize` after virtual `assigned_instances` have been assigned to physical circuit. + pub fn assign_instances( + &self, + instance_columns: &[Column], + mut layouter: impl Layouter, + ) { + if !self.core.witness_gen_only() { + // expose public instances + for (instances, instance_col) in self.assigned_instances.iter().zip_eq(instance_columns) + { + for (i, instance) in instances.iter().enumerate() { + let cell = instance.cell.unwrap(); + let copy_manager = self.core.copy_manager.lock().unwrap(); + let cell = + copy_manager.assigned_advices.get(&cell).expect("instance not assigned"); + layouter.constrain_instance(*cell, *instance_col, i); + } + } + } + } + + /// Creates a new [RangeChip] sharing the same [LookupAnyManager]s as `self`. + pub fn range_chip(&self) -> RangeChip { + RangeChip::new( + self.config_params.lookup_bits.expect("lookup bits not set"), + self.lookup_manager.clone(), + ) + } + + /// Copies the queued cells to be range looked up in phase `phase` to special advice lookup columns + /// using [LookupAnyManager]. + /// + /// ## Special case + /// Just for [RangeConfig], we have special handling for the case where there is a single (physical) + /// advice column in [FlexGateConfig]. In this case, `RangeConfig` does not create extra lookup advice columns, + /// the single advice column has lookup enabled, and there is a selector to toggle when lookup should + /// be turned on. + pub fn assign_lookups_in_phase( + &self, + config: &RangeConfig, + region: &mut Region, + phase: usize, + ) { + let lookup_manager = self.lookup_manager.get(phase).expect("too many phases"); + if lookup_manager.total_rows() == 0 { + return; + } + if let Some(q_lookup) = config.q_lookup.get(phase).and_then(|q| *q) { + // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled + assert_eq!(config.gate.basic_gates[phase].len(), 1); + if !self.witness_gen_only() { + let cells_to_lookup = lookup_manager.cells_to_lookup.lock().unwrap(); + for advice in cells_to_lookup.iter().flat_map(|(_, advices)| advices) { + let cell = advice[0].cell.as_ref().unwrap(); + let copy_manager = self.core.copy_manager.lock().unwrap(); + let acell = copy_manager.assigned_advices[cell]; + assert_eq!( + acell.column, + config.gate.basic_gates[phase][0].value.into(), + "lookup column does not match" + ); + q_lookup.enable(region, acell.row_offset).unwrap(); + } + } + } else { + let lookup_cols = config + .lookup_advice + .get(phase) + .expect("No special lookup advice columns") + .iter() + .map(|c| [*c]) + .collect_vec(); + lookup_manager.assign_raw(&lookup_cols, region); + } + let _ = lookup_manager.assigned.set(()); + } +} + +/// Basic statistics +pub struct RangeStatistics { + /// Number of advice cells for the basic gate and total constants used + pub gate: GateStatistics, + /// Total special advice cells that need to be looked up, per phase + pub total_lookup_advice_per_phase: Vec, +} diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs new file mode 100644 index 00000000..46dec873 --- /dev/null +++ b/halo2-base/src/gates/circuit/mod.rs @@ -0,0 +1,217 @@ +use serde::{Deserialize, Serialize}; + +use crate::utils::ScalarField; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, Column, ConstraintSystem, Error, Fixed, Instance, Selector}, + }, + virtual_region::manager::VirtualRegionManager, +}; + +use self::builder::BaseCircuitBuilder; + +use super::flex_gate::{FlexGateConfig, FlexGateConfigParams}; +use super::range::RangeConfig; + +/// Module that helps auto-build circuits +pub mod builder; + +/// A struct defining the configuration parameters for a halo2-base circuit +/// - this is used to configure [BaseConfig]. +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct BaseCircuitParams { + // Keeping FlexGateConfigParams expanded for backwards compatibility + /// Specifies the number of rows in the circuit to be 2k + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of fixed columns + pub num_fixed: usize, + /// The number of bits that can be ranged checked using a special lookup table with values [0, 2lookup_bits), if using. + /// The number of special advice columns that have range lookup enabled per phase + pub num_lookup_advice_per_phase: Vec, + /// This is `None` if no lookup table is used. + pub lookup_bits: Option, + /// Number of public instance columns + #[serde(default)] + pub num_instance_columns: usize, +} + +impl BaseCircuitParams { + fn gate_params(&self) -> FlexGateConfigParams { + FlexGateConfigParams { + k: self.k, + num_advice_per_phase: self.num_advice_per_phase.clone(), + num_fixed: self.num_fixed, + } + } +} + +/// Configuration with [`BaseConfig`] with `NI` public instance columns. +#[derive(Clone, Debug)] +pub struct BaseConfig { + /// The underlying private gate/range configuration + pub base: MaybeRangeConfig, + /// The public instance column + pub instance: Vec>, +} + +/// Smart Halo2 circuit config that has different variants depending on whether you need range checks or not. +/// The difference is that to enable range checks, the Halo2 config needs to add a lookup table. +#[derive(Clone, Debug)] +pub enum MaybeRangeConfig { + /// Config for a circuit that does not use range checks + WithoutRange(FlexGateConfig), + /// Config for a circuit that does use range checks + WithRange(RangeConfig), +} + +impl BaseConfig { + /// Generates a new `BaseConfig` depending on `params`. + /// - It will generate a `RangeConfig` is `params` has `lookup_bits` not None **and** `num_lookup_advice_per_phase` are not all empty or zero (i.e., if `params` indicates that the circuit actually requires a lookup table). + /// - Otherwise it will generate a `FlexGateConfig`. + pub fn configure(meta: &mut ConstraintSystem, params: BaseCircuitParams) -> Self { + let total_lookup_advice_cols = params.num_lookup_advice_per_phase.iter().sum::(); + let base = if params.lookup_bits.is_some() && total_lookup_advice_cols != 0 { + // We only add a lookup table if lookup bits is not None + MaybeRangeConfig::WithRange(RangeConfig::configure( + meta, + params.gate_params(), + ¶ms.num_lookup_advice_per_phase, + params.lookup_bits.unwrap(), + )) + } else { + MaybeRangeConfig::WithoutRange(FlexGateConfig::configure(meta, params.gate_params())) + }; + let instance = (0..params.num_instance_columns) + .map(|_| { + let inst = meta.instance_column(); + meta.enable_equality(inst); + inst + }) + .collect(); + Self { base, instance } + } + + /// Returns the inner [`FlexGateConfig`] + pub fn gate(&self) -> &FlexGateConfig { + match &self.base { + MaybeRangeConfig::WithoutRange(config) => config, + MaybeRangeConfig::WithRange(config) => &config.gate, + } + } + + /// Returns the fixed columns for constants + pub fn constants(&self) -> &Vec> { + match &self.base { + MaybeRangeConfig::WithoutRange(config) => &config.constants, + MaybeRangeConfig::WithRange(config) => &config.gate.constants, + } + } + + /// Returns a slice of the selector column to enable lookup -- this is only in the situation where there is a single advice column of any kind -- per phase + /// Returns empty slice if there are no lookups enabled. + pub fn q_lookup(&self) -> &[Option] { + match &self.base { + MaybeRangeConfig::WithoutRange(_) => &[], + MaybeRangeConfig::WithRange(config) => &config.q_lookup, + } + } + + /// Updates the number of usable rows in the circuit. Used if you mutate [ConstraintSystem] after `BaseConfig::configure` is called. + pub fn set_usable_rows(&mut self, usable_rows: usize) { + match &mut self.base { + MaybeRangeConfig::WithoutRange(config) => config.max_rows = usable_rows, + MaybeRangeConfig::WithRange(config) => config.gate.max_rows = usable_rows, + } + } + + /// Initialization of config at very beginning of `synthesize`. + /// Loads fixed lookup table, if using. + pub fn initialize(&self, layouter: &mut impl Layouter) { + // only load lookup table if we are actually doing lookups + if let MaybeRangeConfig::WithRange(config) = &self.base { + config.load_lookup_table(layouter).expect("load lookup table should not fail"); + } + } +} + +impl Circuit for BaseCircuitBuilder { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = BaseCircuitParams; + + fn params(&self) -> Self::Params { + self.config_params.clone() + } + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using [`BaseConfigParams`] + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + BaseConfig::configure(meta, params) + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params"); + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if let MaybeRangeConfig::WithRange(config) = &config.base { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + // Only FirstPhase (phase 0) + layouter + .assign_region( + || "BaseCircuitBuilder generated circuit", + |mut region| { + let usable_rows = config.gate().max_rows; + self.core.phase_manager[0].assign_raw( + &(config.gate().basic_gates[0].clone(), usable_rows), + &mut region, + ); + // Only assign cells to lookup if we're sure we're doing range lookups + if let MaybeRangeConfig::WithRange(config) = &config.base { + self.assign_lookups_in_phase(config, &mut region, 0); + } + // Impose equality constraints + if !self.core.witness_gen_only() { + self.core.copy_manager.assign_raw(config.constants(), &mut region); + } + Ok(()) + }, + ) + .unwrap(); + + self.assign_instances(&config.instance, layouter.namespace(|| "expose")); + Ok(()) + } +} + +/// Defines stage of circuit building. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CircuitBuilderStage { + /// Keygen phase + Keygen, + /// Prover Circuit + Prover, + /// Mock Circuit + Mock, +} + +impl CircuitBuilderStage { + /// Returns true if the circuit is used for witness generation only. + pub fn witness_gen_only(&self) -> bool { + matches!(self, CircuitBuilderStage::Prover) + } +} diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate/mod.rs similarity index 77% rename from halo2-base/src/gates/flex_gate.rs rename to halo2-base/src/gates/flex_gate/mod.rs index a0447ae7..2938381b 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -10,28 +10,31 @@ use crate::{ AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness, WitnessFraction}, }; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::{ iter::{self}, marker::PhantomData, }; -/// The maximum number of phases in halo2. -pub const MAX_PHASE: usize = 3; - -/// Specifies the gate strategy for the gate chip -#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] -pub enum GateStrategy { - /// # Vertical Gate Strategy: - /// `q_0 * (a + b * c - d) = 0` - /// where - /// * a = value[0], b = value[1], c = value[2], d = value[3] - /// * q = q_enable[0] - /// * q is either 0 or 1 so this is just a simple selector - /// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. - Vertical, -} +pub mod threads; +/// Vector of thread advice column break points +pub type ThreadBreakPoints = Vec; +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +/// The maximum number of phases in halo2. +pub(super) const MAX_PHASE: usize = 3; + +/// # Vertical Gate Strategy: +/// `q_0 * (a + b * c - d) = 0` +/// where +/// * a = value[0], b = value[1], c = value[2], d = value[3] +/// * q = q_enable[0] +/// * q is either 0 or 1 so this is just a simple selector +/// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. +/// /// A configuration for a basic gate chip describing the selector, and advice column values. #[derive(Clone, Debug)] pub struct BasicGateConfig { @@ -45,13 +48,18 @@ pub struct BasicGateConfig { } impl BasicGateConfig { + /// Constructor + pub fn new(q_enable: Selector, value: Column) -> Self { + Self { q_enable, value, _marker: PhantomData } + } + /// Instantiates a new [BasicGateConfig]. /// /// Assumes `phase` is in the range [0, MAX_PHASE). /// * `meta`: [ConstraintSystem] used for the gate /// * `strategy`: The [GateStrategy] to use for the gate /// * `phase`: The phase to add the gate to - pub fn configure(meta: &mut ConstraintSystem, strategy: GateStrategy, phase: u8) -> Self { + pub fn configure(meta: &mut ConstraintSystem, phase: u8) -> Self { let value = match phase { 0 => meta.advice_column_in(FirstPhase), 1 => meta.advice_column_in(SecondPhase), @@ -62,13 +70,9 @@ impl BasicGateConfig { let q_enable = meta.selector(); - match strategy { - GateStrategy::Vertical => { - let config = Self { q_enable, value, _marker: PhantomData }; - config.create_gate(meta); - config - } - } + let config = Self { q_enable, value, _marker: PhantomData }; + config.create_gate(meta); + config } /// Wrapper for [ConstraintSystem].create_gate(name, meta) creates a gate form [q * (a + b * c - out)]. @@ -87,18 +91,25 @@ impl BasicGateConfig { } } +/// A Config struct defining the parameters for [FlexGateConfig] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct FlexGateConfigParams { + /// Specifies the number of rows in the circuit to be 2k + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of fixed columns + pub num_fixed: usize, +} + /// Defines a configuration for a flex gate chip describing the selector, and advice column values for the chip. #[derive(Clone, Debug)] pub struct FlexGateConfig { /// A [Vec] of [BasicGateConfig] that define gates for each halo2 phase. - pub basic_gates: [Vec>; MAX_PHASE], + pub basic_gates: Vec>>, /// A [Vec] of [Fixed] [Column]s for allocating constant values. pub constants: Vec>, - /// Number of advice columns for each halo2 phase. - pub num_advice: [usize; MAX_PHASE], - /// [GateStrategy] for the flex gate. - _strategy: GateStrategy, - /// Max number of rows in flex gate. + /// Max number of usable rows in the circuit. pub max_rows: usize, } @@ -111,59 +122,37 @@ impl FlexGateConfig { /// * `num_advice`: Number of [Advice] [Column]s in each phase /// * `num_fixed`: Number of [Fixed] [Column]s in each phase /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) - pub fn configure( - meta: &mut ConstraintSystem, - strategy: GateStrategy, - num_advice: &[usize], - num_fixed: usize, - // log2_ceil(# rows in circuit) - circuit_degree: usize, - ) -> Self { + pub fn configure(meta: &mut ConstraintSystem, params: FlexGateConfigParams) -> Self { // create fixed (constant) columns and enable equality constraints - let mut constants = Vec::with_capacity(num_fixed); - for _i in 0..num_fixed { + let mut constants = Vec::with_capacity(params.num_fixed); + for _i in 0..params.num_fixed { let c = meta.fixed_column(); meta.enable_equality(c); // meta.enable_constant(c); constants.push(c); } - match strategy { - GateStrategy::Vertical => { - let mut basic_gates = [(); MAX_PHASE].map(|_| vec![]); - let mut num_advice_array = [0usize; MAX_PHASE]; - for ((phase, &num_columns), gates) in - num_advice.iter().enumerate().zip(basic_gates.iter_mut()) - { - *gates = (0..num_columns) - .map(|_| BasicGateConfig::configure(meta, strategy, phase as u8)) - .collect(); - num_advice_array[phase] = num_columns; - } - Self { - basic_gates, - constants, - num_advice: num_advice_array, - _strategy: strategy, - /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created - max_rows: (1 << circuit_degree) - meta.minimum_rows(), - } - } + let mut basic_gates = vec![]; + for (phase, &num_columns) in params.num_advice_per_phase.iter().enumerate() { + let config = + (0..num_columns).map(|_| BasicGateConfig::configure(meta, phase as u8)).collect(); + basic_gates.push(config); + } + log::info!("Poisoned rows after FlexGateConfig::configure {}", meta.minimum_rows()); + Self { + basic_gates, + constants, + /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created + max_rows: (1 << params.k) - meta.minimum_rows(), } } } /// Trait that defines basic arithmetic operations for a gate. pub trait GateInstructions { - /// Returns the [GateStrategy] for the gate. - fn strategy(&self) -> GateStrategy; - /// Returns a slice of the [ScalarField] field elements 2^i for i in 0..F::NUM_BITS. fn pow_of_two(&self) -> &[F]; - /// Converts a [u64] into a scalar field element [ScalarField]. - fn get_field_element(&self, n: u64) -> F; - /// Constrains and returns `a + b * 1 = out`. /// /// Defines a vertical gate of form | a | b | 1 | a + b | where (a + b) = out. @@ -179,7 +168,15 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() + b.value(); - ctx.assign_region_last([a, b, Constant(F::one()), Witness(out_val)], [0]) + ctx.assign_region_last([a, b, Constant(F::ONE), Witness(out_val)], [0]) + } + + /// Constrains and returns `out = a + 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn inc(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.add(ctx, a, Constant(F::ONE)) } /// Constrains and returns `a + b * (-1) = out`. @@ -197,11 +194,19 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() - b.value(); - // slightly better to not have to compute -F::one() since F::one() is cached - ctx.assign_region([Witness(out_val), b, Constant(F::one()), a], [0]); + // slightly better to not have to compute -F::ONE since F::ONE is cached + ctx.assign_region([Witness(out_val), b, Constant(F::ONE), a], [0]); ctx.get(-4) } + /// Constrains and returns `out = a - 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn dec(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.sub(ctx, a, Constant(F::ONE)) + } + /// Constrains and returns `a - b * c = out`. /// /// Defines a vertical gate of form | a - b * c | b | c | a |, where (a - b * c) = out. @@ -232,7 +237,7 @@ pub trait GateInstructions { fn neg(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { let a = a.into(); let out_val = -*a.value(); - ctx.assign_region([a, Witness(out_val), Constant(F::one()), Constant(F::zero())], [0]); + ctx.assign_region([a, Witness(out_val), Constant(F::ONE), Constant(F::ZERO)], [0]); ctx.get(-3) } @@ -251,7 +256,7 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() * b.value(); - ctx.assign_region_last([Constant(F::zero()), a, b, Witness(out_val)], [0]) + ctx.assign_region_last([Constant(F::ZERO), a, b, Witness(out_val)], [0]) } /// Constrains and returns `a * b + c = out`. @@ -289,7 +294,7 @@ pub trait GateInstructions { ) -> AssignedValue { let a = a.into(); let b = b.into(); - let out_val = (F::one() - a.value()) * b.value(); + let out_val = (F::ONE - a.value()) * b.value(); ctx.assign_region_smart([Witness(out_val), a, b, b], [0], [(2, 3)], []); ctx.get(-4) } @@ -300,7 +305,7 @@ pub trait GateInstructions { /// * `ctx`: [Context] to add the constraints to /// * `x`: [QuantumCell] value to constrain fn assert_bit(&self, ctx: &mut Context, x: AssignedValue) { - ctx.assign_region([Constant(F::zero()), Existing(x), Existing(x), Existing(x)], [0]); + ctx.assign_region([Constant(F::ZERO), Existing(x), Existing(x), Existing(x)], [0]); } /// Constrains and returns a / b = 0. @@ -322,7 +327,7 @@ pub trait GateInstructions { // TODO: if really necessary, make `c` of type `Assigned` // this would require the API using `Assigned` instead of `F` everywhere, so leave as last resort let c = b.value().invert().unwrap() * a.value(); - ctx.assign_region([Constant(F::zero()), Witness(c), b, a], [0]); + ctx.assign_region([Constant(F::ZERO), Witness(c), b, a], [0]); ctx.get(-3) } @@ -332,7 +337,7 @@ pub trait GateInstructions { /// * `constant`: constant value to constrain `a` to be equal to fn assert_is_const(&self, ctx: &mut Context, a: &AssignedValue, constant: &F) { if !ctx.witness_gen_only { - ctx.constant_equality_constraints.push((*constant, a.cell.unwrap())); + ctx.copy_manager.lock().unwrap().constant_equalities.push((*constant, a.cell.unwrap())); } } @@ -351,7 +356,11 @@ pub trait GateInstructions { where QA: Into>; - /// Returns the inner product of `` and the last element of `a` now assigned, i.e. `(inner_product_, last_element_a)`. + /// Returns the inner product of `` and the last element of `a` after it has been assigned. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, where you want to avoid first assigning `a` and then copying the last element into the + /// correct cell for this computation. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] of the circuit @@ -366,6 +375,24 @@ pub trait GateInstructions { where QA: Into>; + /// Returns `(, a_assigned)`. See `inner_product` for more details. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, useful for when you want to simultaneously compute an inner product while assigning + /// private witnesses for the first time. This avoids first assigning `a` and then copying into the correct cells + /// for this computation. We do not return the assignments of `a` in `inner_product` as an optimization to avoid + /// the memory allocation of having to collect the vectors. + /// + /// Assumes 'a' and 'b' are the same length. + fn inner_product_left( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, Vec>) + where + QA: Into>; + /// Calculates and constrains the inner product. /// /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. @@ -406,7 +433,7 @@ pub trait GateInstructions { let cells = iter::once(start).chain(a.flat_map(|a| { let a = a.into(); sum += a.value(); - [a, Constant(F::one()), Witness(sum)] + [a, Constant(F::ONE), Witness(sum)] })); ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } @@ -440,7 +467,7 @@ pub trait GateInstructions { let cells = iter::once(start).chain(a.flat_map(|a| { let a = a.into(); sum += a.value(); - [a, Constant(F::one()), Witness(sum)] + [a, Constant(F::ONE), Witness(sum)] })); ctx.assign_region(cells, (0..len).map(|i| 3 * i as isize)); Box::new((0..=len).rev().map(|i| ctx.get(-1 - 3 * (i as isize)))) @@ -507,13 +534,13 @@ pub trait GateInstructions { ) -> AssignedValue { let a = a.into(); let b = b.into(); - let not_b_val = F::one() - b.value(); + let not_b_val = F::ONE - b.value(); let out_val = *a.value() + b.value() - *a.value() * b.value(); let cells = [ Witness(not_b_val), - Constant(F::one()), + Constant(F::ONE), b, - Constant(F::one()), + Constant(F::ONE), b, a, Witness(not_b_val), @@ -552,13 +579,13 @@ pub trait GateInstructions { ) -> AssignedValue { let a = a.into(); let b = b.into(); - let not_two_b_val = F::one() - F::from(2u64) * b.value(); + let not_two_b_val = F::ONE - F::from(2u64) * b.value(); let out_val = *a.value() + b.value() - F::from(2u64) * *a.value() * b.value(); let cells = [ Witness(not_two_b_val), Constant(F::from(2u64)), b, - Constant(F::one()), + Constant(F::ONE), b, a, Witness(not_two_b_val), @@ -574,7 +601,7 @@ pub trait GateInstructions { /// * `ctx`: [Context] to add the constraints to. /// * `a`: [QuantumCell] that contains a boolean value. fn not(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { - self.sub(ctx, Constant(F::one()), a) + self.sub(ctx, Constant(F::ONE), a) } /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. @@ -625,10 +652,10 @@ pub trait GateInstructions { let (inv_last_bit, last_bit) = { ctx.assign_region( [ - Witness(F::one() - bits[k - 1].value()), + Witness(F::ONE - bits[k - 1].value()), Existing(bits[k - 1]), - Constant(F::one()), - Constant(F::one()), + Constant(F::ONE), + Constant(F::ONE), ], [0], ); @@ -641,7 +668,7 @@ pub trait GateInstructions { for (idx, bit) in bits.iter().rev().enumerate().skip(1) { for old_idx in 0..(1 << idx) { // inv_prod_val = (1 - bit) * indicator[offset + old_idx] - let inv_prod_val = (F::one() - bit.value()) * indicator[offset + old_idx].value(); + let inv_prod_val = (F::ONE - bit.value()) * indicator[offset + old_idx].value(); ctx.assign_region( [ Witness(inv_prod_val), @@ -682,25 +709,25 @@ pub trait GateInstructions { // unroll `is_zero` to make sure if `idx == Witness(_)` it is replaced by `Existing(_)` in later iterations let x = idx.value(); let (is_zero, inv) = if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) + (F::ONE, Assigned::Trivial(F::ONE)) } else { - (F::zero(), Assigned::Rational(F::one(), *x)) + (F::ZERO, Assigned::Rational(F::ONE, *x)) }; let cells = [ Witness(is_zero), idx, WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), + Constant(F::ONE), + Constant(F::ZERO), idx, Witness(is_zero), - Constant(F::zero()), + Constant(F::ZERO), ]; ctx.assign_region_smart(cells, [0, 4], [(0, 6), (1, 5)], []); // note the two `idx` need to be constrained equal: (1, 5) idx = Existing(ctx.get(-3)); // replacing `idx` with Existing cell so future loop iterations constrain equality of all `idx`s ctx.get(-2) } else { - self.is_equal(ctx, idx, Constant(self.get_field_element(i as u64))) + self.is_equal(ctx, idx, Constant(F::from(i as u64))) } }) .collect() @@ -722,18 +749,17 @@ pub trait GateInstructions { where Q: Into>, { - let mut sum = F::zero(); + let mut sum = F::ZERO; let a = a.into_iter(); let (len, hi) = a.size_hint(); assert_eq!(Some(len), hi); - let cells = std::iter::once(Constant(F::zero())).chain( - a.zip(indicator.into_iter()).flat_map(|(a, ind)| { + let cells = + std::iter::once(Constant(F::ZERO)).chain(a.zip(indicator).flat_map(|(a, ind)| { let a = a.into(); sum = if ind.value().is_zero_vartime() { sum } else { *a.value() }; [a, Existing(ind), Witness(sum)] - }), - ); + })); ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } @@ -760,6 +786,35 @@ pub trait GateInstructions { self.select_by_indicator(ctx, cells, ind) } + /// `array2d` is an array of fixed length arrays. + /// Assumes: + /// * `array2d.len() == indicator.len()` + /// * `array2d[i].len() == array2d[j].len()` for all `i,j`. + /// * the values of `indicator` are boolean and that `indicator` has at most one `1` bit. + /// * the lengths of `array2d` and `indicator` are the same. + /// + /// Returns the "dot product" of `array2d` with `indicator` as a fixed length (1d) array of length `array2d[0].len()`. + fn select_array_by_indicator( + &self, + ctx: &mut Context, + array2d: &[AR], + indicator: &[AssignedValue], + ) -> Vec> + where + AR: AsRef<[AV]>, + AV: AsRef>, + { + (0..array2d[0].as_ref().len()) + .map(|j| { + self.select_by_indicator( + ctx, + array2d.iter().map(|array_i| *array_i.as_ref()[j].as_ref()), + indicator.iter().copied(), + ) + }) + .collect() + } + /// Constrains that a cell is equal to 0 and returns `1` if `a = 0`, otherwise `0`. /// /// Defines a vertical gate of form `| out | a | inv | 1 | 0 | a | out | 0 |`, where out = 1 if a = 0, otherwise out = 0. @@ -768,20 +823,20 @@ pub trait GateInstructions { fn is_zero(&self, ctx: &mut Context, a: AssignedValue) -> AssignedValue { let x = a.value(); let (is_zero, inv) = if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) + (F::ONE, Assigned::Trivial(F::ONE)) } else { - (F::zero(), Assigned::Rational(F::one(), *x)) + (F::ZERO, Assigned::Rational(F::ONE, *x)) }; let cells = [ Witness(is_zero), Existing(a), WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), + Constant(F::ONE), + Constant(F::ZERO), Existing(a), Witness(is_zero), - Constant(F::zero()), + Constant(F::ZERO), ]; ctx.assign_region_smart(cells, [0, 4], [(0, 6)], []); ctx.get(-2) @@ -813,6 +868,17 @@ pub trait GateInstructions { range_bits: usize, ) -> Vec>; + /// Constrains and computes `a``exp` where both `a, exp` are witnesses. The exponent is computed in the native field `F`. + /// + /// Constrains that `exp` has at most `max_bits` bits. + fn pow_var( + &self, + ctx: &mut Context, + a: AssignedValue, + exp: AssignedValue, + max_bits: usize, + ) -> AssignedValue; + /// Performs and constrains Lagrange interpolation on `coords` and evaluates the resulting polynomial at `x`. /// /// Given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords) - 1` polynomial such that `f(x_i) = y_i` for all `i`. @@ -850,7 +916,7 @@ pub trait GateInstructions { } // TODO: batch inversion let is_zero = self.is_zero(ctx, denom); - self.assert_is_const(ctx, &is_zero, &F::zero()); + self.assert_is_const(ctx, &is_zero, &F::ZERO); // y_i / denom let quot = self.div_unsafe(ctx, coords[i].1, denom); @@ -864,35 +930,11 @@ pub trait GateInstructions { let out = self.mul(ctx, eval.unwrap(), z); (out, z) } - - /// Bitwise right rotate a by BIT bits. BIT and NUM_BITS must be determined at compile time. - /// - /// Assumes 'a' is a NUM_BITS bit integer and NUM_BITS <= 128. - /// * `ctx`: [Context] to add the constraints to - /// * `a`: a [AssignedValue] value. - fn const_right_rotate_unsafe( - &self, - ctx: &mut Context, - a: AssignedValue, - ) -> AssignedValue; - - /// Bitwise left rotate a by BIT bits. BIT and NUM_BITS must be determined at compile time. - /// - /// Assumes 'a' is a NUM_BITS bit integer and NUM_BITS <= 128. - /// * `ctx`: [Context] to add the constraints to - /// * `a`: a [AssignedValue] value. - fn const_left_rotate_unsafe( - &self, - ctx: &mut Context, - a: AssignedValue, - ) -> AssignedValue; } /// A chip that implements the [GateInstructions] trait supporting basic arithmetic operations. #[derive(Clone, Debug)] pub struct GateChip { - /// The [GateStrategy] used when declaring gates. - strategy: GateStrategy, /// The field elements 2^i for i in 0..F::NUM_BITS. pub pow_of_two: Vec, /// To avoid Montgomery conversion in `F::from` for common small numbers, we keep a cache of field elements. @@ -901,28 +943,29 @@ pub struct GateChip { impl Default for GateChip { fn default() -> Self { - Self::new(GateStrategy::Vertical) + Self::new() } } impl GateChip { /// Returns a new [GateChip] with the given [GateStrategy]. - pub fn new(strategy: GateStrategy) -> Self { + pub fn new() -> Self { let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); let two = F::from(2); - pow_of_two.push(F::one()); + pow_of_two.push(F::ONE); pow_of_two.push(two); for _ in 2..F::NUM_BITS { pow_of_two.push(two * pow_of_two.last().unwrap()); } let field_element_cache = (0..1024).map(|i| F::from(i)).collect(); - Self { strategy, pow_of_two, field_element_cache } + Self { pow_of_two, field_element_cache } } /// Calculates and constrains the inner product of ``. + /// If the first element of `b` is `Constant(F::ONE)`, then an optimization is performed to save 3 cells. /// - /// Returns `true` if `b` start with `Constant(F::one())`, and `false` otherwise. + /// Returns `true` if `b` start with `Constant(F::ONE)`, and `false` otherwise. /// /// Assumes `a` and `b` are the same length. /// * `ctx`: [Context] of the circuit @@ -941,15 +984,15 @@ impl GateChip { let mut a = a.into_iter(); let mut b = b.into_iter().peekable(); - let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::one()); + let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::ONE); let cells = if b_starts_with_one { b.next(); let start_a = a.next().unwrap().into(); sum = *start_a.value(); iter::once(start_a) } else { - sum = F::zero(); - iter::once(Constant(F::zero())) + sum = F::ZERO; + iter::once(Constant(F::ZERO)) } .chain(a.zip(b).flat_map(|(a, b)| { let a = a.into(); @@ -967,54 +1010,16 @@ impl GateChip { }; b_starts_with_one } - - /// Bitwise right rotate a by bits. This function should never be called directly - /// because const bitwise rotation must be determined at compile time. - /// - /// Assumes 'a' is a bit integer and <= 128. - fn const_right_rotate_unsafe_internal( - &self, - ctx: &mut Context, - a: AssignedValue, - bit: usize, - num_bits: usize, - ) -> AssignedValue { - // Add a constrain a = l_witness << bit | r_wintess - let val = a.value().get_lower_128(); - let val_l = val >> bit; - let val_r = val - (val_l << bit); - let l_witness = Witness(F::from_u128(val_l)); - let r_witness = Witness(F::from_u128(val_r)); - let val_witness = self.mul_add(ctx, l_witness, Constant(self.pow_of_two()[bit]), r_witness); - ctx.constrain_equal(&a, &val_witness); - // Return (r_witness << (num_bits - bit)) | l_witness - self.mul_add(ctx, r_witness, Constant(self.pow_of_two()[num_bits - bit]), l_witness) - } } impl GateInstructions for GateChip { - /// Returns the [GateStrategy] the [GateChip]. - fn strategy(&self) -> GateStrategy { - self.strategy - } - /// Returns a slice of the [ScalarField] elements 2i for i in 0..F::NUM_BITS. fn pow_of_two(&self) -> &[F] { &self.pow_of_two } - /// Returns the the value of `n` as a [ScalarField] element. - /// * `n`: the [u64] value to convert - fn get_field_element(&self, n: u64) -> F { - let get = self.field_element_cache.get(n as usize); - if let Some(fe) = get { - *fe - } else { - F::from(n) - } - } - /// Constrains and returns the inner product of ``. + /// If the first element of `b` is `Constant(F::ONE)`, then an optimization is performed to save 3 cells. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] to add the constraints to @@ -1033,7 +1038,11 @@ impl GateInstructions for GateChip { ctx.last().unwrap() } - /// Returns the inner product of `` and returns a tuple of the last item of `a` after it is assigned and the item to its left `(left_a, last_a)`. + /// Returns the inner product of `` and the last element of `a` after it has been assigned. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, where you want to avoid first assigning `a` and then copying the last element into the + /// correct cell for this computation. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] of the circuit @@ -1065,6 +1074,46 @@ impl GateInstructions for GateChip { (ctx.last().unwrap(), a_last) } + /// Returns `(, a_assigned)`. See `inner_product` for more details. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, useful for when you want to simultaneously compute an inner product while assigning + /// private witnesses for the first time. This avoids first assigning `a` and then copying into the correct cells + /// for this computation. We do not return the assignments of `a` in `inner_product` as an optimization to avoid + /// the memory allocation of having to collect the vectors. + /// + /// We do not return `b_assigned` because if `b` starts with `Constant(F::ONE)`, the first element of `b` is not assigned. + /// + /// Assumes 'a' and 'b' are the same length. + fn inner_product_left( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, Vec>) + where + QA: Into>, + { + let a = a.into_iter().collect_vec(); + let len = a.len(); + let row_offset = ctx.advice.len(); + let b_starts_with_one = self.inner_product_simple(ctx, a, b); + let a_assigned = (0..len) + .map(|i| { + if b_starts_with_one { + if i == 0 { + ctx.get(row_offset as isize) + } else { + ctx.get((row_offset + 1 + 3 * (i - 1)) as isize) + } + } else { + ctx.get((row_offset + 1 + 3 * i) as isize) + } + }) + .collect_vec(); + (ctx.last().unwrap(), a_assigned) + } + /// Calculates and constrains the inner product. /// /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. @@ -1103,25 +1152,20 @@ impl GateInstructions for GateChip { values: impl IntoIterator, QuantumCell)>, var: QuantumCell, ) -> AssignedValue { - // TODO: optimizer - match self.strategy { - GateStrategy::Vertical => { - // Create an iterator starting with `var` and - let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::one()))) - .chain(values.into_iter().filter_map(|(c, va, vb)| { - if c == F::one() { - Some((va, vb)) - } else if c != F::zero() { - let prod = self.mul(ctx, va, vb); - Some((QuantumCell::Existing(prod), Constant(c))) - } else { - None - } - })) - .unzip(); - self.inner_product(ctx, a, b) - } - } + // Create an iterator starting with `var` and + let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::ONE))) + .chain(values.into_iter().filter_map(|(c, va, vb)| { + if c == F::ONE { + Some((va, vb)) + } else if c != F::ZERO { + let prod = self.mul(ctx, va, vb); + Some((QuantumCell::Existing(prod), Constant(c))) + } else { + None + } + })) + .unzip(); + self.inner_product(ctx, a, b) } /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. @@ -1143,24 +1187,20 @@ impl GateInstructions for GateChip { let sel = sel.into(); let diff_val = *a.value() - b.value(); let out_val = diff_val * sel.value() + b.value(); - match self.strategy { - // | a - b | 1 | b | a | - // | b | sel | a - b | out | - GateStrategy::Vertical => { - let cells = [ - Witness(diff_val), - Constant(F::one()), - b, - a, - b, - sel, - Witness(diff_val), - Witness(out_val), - ]; - ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); - ctx.last().unwrap() - } - } + // | a - b | 1 | b | a | + // | b | sel | a - b | out | + let cells = [ + Witness(diff_val), + Constant(F::ONE), + b, + a, + b, + sel, + Witness(diff_val), + Witness(out_val), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); + ctx.last().unwrap() } /// Constains and returns `a || (b && c)`, assuming `a`, `b` and `c` are boolean. @@ -1181,20 +1221,20 @@ impl GateInstructions for GateChip { let b = b.into(); let c = c.into(); let bc_val = *b.value() * c.value(); - let not_bc_val = F::one() - bc_val; - let not_a_val = *a.value() - F::one(); + let not_bc_val = F::ONE - bc_val; + let not_a_val = *a.value() - F::ONE; let out_val = bc_val + a.value() - bc_val * a.value(); let cells = [ Witness(not_bc_val), b, c, - Constant(F::one()), + Constant(F::ONE), Witness(not_a_val), Witness(not_bc_val), Witness(out_val), Witness(not_a_val), - Constant(F::one()), - Constant(F::one()), + Constant(F::ONE), + Constant(F::ONE), a, ]; ctx.assign_region_smart(cells, [0, 3, 7], [(4, 7), (0, 5)], []); @@ -1234,26 +1274,27 @@ impl GateInstructions for GateChip { bit_cells } - fn const_right_rotate_unsafe( - &self, - ctx: &mut Context, - a: AssignedValue, - ) -> AssignedValue { - if BIT == 0 { - return a; - }; - self.const_right_rotate_unsafe_internal(ctx, a, BIT, NUM_BITS) - } - - fn const_left_rotate_unsafe( + /// Constrains and computes `a^exp` where both `a, exp` are witnesses. The exponent is computed in the native field `F`. + /// + /// Constrains that `exp` has at most `max_bits` bits. + fn pow_var( &self, ctx: &mut Context, a: AssignedValue, + exp: AssignedValue, + max_bits: usize, ) -> AssignedValue { - if BIT == 0 { - return a; - }; - // left rotate by BIT == right rotate by (NUM_BITS - BIT) - self.const_right_rotate_unsafe_internal(ctx, a, NUM_BITS - BIT, NUM_BITS) + let exp_bits = self.num_to_bits(ctx, exp, max_bits); + // standard square-and-mul approach + let mut acc = ctx.load_constant(F::ONE); + for (i, bit) in exp_bits.into_iter().rev().enumerate() { + if i > 0 { + // square + acc = self.mul(ctx, acc, acc); + } + let mul = self.mul(ctx, acc, a); + acc = self.select(ctx, mul, acc, bit); + } + acc } } diff --git a/halo2-base/src/gates/flex_gate/threads/mod.rs b/halo2-base/src/gates/flex_gate/threads/mod.rs new file mode 100644 index 00000000..675f57ab --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/mod.rs @@ -0,0 +1,18 @@ +//! Module for managing the virtual region corresponding to [super::FlexGateConfig] +//! +//! In the virtual region we have virtual columns. Each virtual column is referred to as a "thread" +//! because it can be generated in a separate CPU thread. The virtual region manager will collect all +//! threads together, virtually concatenate them all together back into a single virtual column, and +//! then assign this virtual column to multiple physical Halo2 columns according to the provided configuration parameters. +//! +//! Supports multiple phases. + +/// Thread builder for multiple phases +mod multi_phase; +mod parallelize; +/// Thread builder for a single phase +pub mod single_phase; + +pub use multi_phase::{GateStatistics, MultiPhaseCoreManager}; +pub use parallelize::parallelize_core; +pub use single_phase::SinglePhaseCoreManager; diff --git a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs new file mode 100644 index 00000000..40ce5103 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs @@ -0,0 +1,162 @@ +use getset::CopyGetters; +use itertools::Itertools; + +use crate::{ + gates::{circuit::CircuitBuilderStage, flex_gate::FlexGateConfigParams}, + utils::ScalarField, + virtual_region::copy_constraints::SharedCopyConstraintManager, + Context, +}; + +use super::SinglePhaseCoreManager; + +/// Virtual region manager for [FlexGateConfig] in multiple phases. +#[derive(Clone, Debug, Default, CopyGetters)] +pub struct MultiPhaseCoreManager { + /// Virtual region for each challenge phase. These cannot be shared across threads while keeping circuit deterministic. + pub phase_manager: Vec>, + /// Global shared copy manager + pub copy_manager: SharedCopyConstraintManager, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + #[getset(get_copy = "pub")] + use_unknown: bool, +} + +impl MultiPhaseCoreManager { + /// Creates a new [MultiPhaseCoreManager] with a default [SinglePhaseCoreManager] in phase 0. + /// Creates an empty [SharedCopyConstraintManager] and sets `witness_gen_only` flag. + /// * `witness_gen_only`: If true, the [MultiPhaseCoreManager] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool) -> Self { + let copy_manager = SharedCopyConstraintManager::default(); + let phase_manager = + vec![SinglePhaseCoreManager::new(witness_gen_only, copy_manager.clone())]; + Self { phase_manager, witness_gen_only, use_unknown: false, copy_manager } + } + + /// Creates a new [MultiPhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [MultiPhaseCoreManager] is used for witness generation only. + pub fn from_stage(stage: CircuitBuilderStage) -> Self { + Self::new(stage.witness_gen_only()).unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Mutates `self` to use the given copy manager in all phases and all threads. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + for pm in &mut self.phase_manager { + pm.set_copy_manager(copy_manager.clone()); + } + self.copy_manager = copy_manager; + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + + /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(mut self, use_unknown: bool) -> Self { + self.use_unknown = use_unknown; + for pm in &mut self.phase_manager { + pm.use_unknown = use_unknown; + } + self + } + + /// Clears all threads in all phases and copy manager. + pub fn clear(&mut self) { + for pm in &mut self.phase_manager { + pm.clear(); + } + self.copy_manager.lock().unwrap().clear(); + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + self.touch(phase); + self.phase_manager[phase].main() + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + self.touch(phase); + self.phase_manager[phase].new_thread() + } + + /// Returns a mutable reference to the [SinglePhaseCoreManager] of a given `phase`. + pub fn in_phase(&mut self, phase: usize) -> &mut SinglePhaseCoreManager { + self.phase_manager.get_mut(phase).unwrap() + } + + /// Populate `self` up to Phase `phase` (inclusive) + pub(crate) fn touch(&mut self, phase: usize) { + while self.phase_manager.len() <= phase { + let _phase = self.phase_manager.len(); + let pm = SinglePhaseCoreManager::new(self.witness_gen_only, self.copy_manager.clone()) + .in_phase(_phase); + self.phase_manager.push(pm); + } + } + + /// Returns some statistics about the virtual region. + pub fn statistics(&self) -> GateStatistics { + let total_advice_per_phase = + self.phase_manager.iter().map(|pm| pm.total_advice()).collect::>(); + + let total_fixed: usize = self + .copy_manager + .lock() + .unwrap() + .constant_equalities + .iter() + .map(|(c, _)| *c) + .sorted() + .dedup() + .count(); + + GateStatistics { total_advice_per_phase, total_fixed } + } + + /// Auto-calculates configuration parameters for the circuit + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + pub fn calculate_params(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let stats = self.statistics(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = stats + .total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + let num_fixed = (stats.total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { num_advice_per_phase, num_fixed, k }; + #[cfg(feature = "display")] + { + for (phase, num_advice) in stats.total_advice_per_phase.iter().enumerate() { + println!("Gate Chip | Phase {phase}: {num_advice} advice cells",); + } + println!("Total {} fixed cells", stats.total_fixed); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + params + } +} + +/// Basic statistics +pub struct GateStatistics { + /// Total advice cell count per phase + pub total_advice_per_phase: Vec, + /// Total distinct constants used + pub total_fixed: usize, +} diff --git a/halo2-base/src/gates/flex_gate/threads/parallelize.rs b/halo2-base/src/gates/flex_gate/threads/parallelize.rs new file mode 100644 index 00000000..cc2754b0 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/parallelize.rs @@ -0,0 +1,29 @@ +use rayon::prelude::*; + +use crate::{utils::ScalarField, Context}; + +use super::SinglePhaseCoreManager; + +/// Utility function to parallelize an operation involving [`Context`]s. +pub fn parallelize_core( + builder: &mut SinglePhaseCoreManager, // leaving `builder` for historical reasons, `pool` is a better name + input: Vec, + f: FR, +) -> Vec +where + F: ScalarField, + T: Send, + R: Send, + FR: Fn(&mut Context, T) -> R + Send + Sync, +{ + // to prevent concurrency issues with context id, we generate all the ids first + let thread_count = builder.thread_count(); + let mut ctxs = + (0..input.len()).map(|i| builder.new_context(thread_count + i)).collect::>(); + let outputs: Vec<_> = + input.into_par_iter().zip(ctxs.par_iter_mut()).map(|(input, ctx)| f(ctx, input)).collect(); + // we collect the new threads to ensure they are a FIXED order, otherwise the circuit will not be deterministic + builder.threads.append(&mut ctxs); + + outputs +} diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs new file mode 100644 index 00000000..dd8b30d5 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -0,0 +1,318 @@ +use std::{any::TypeId, cell::RefCell}; + +use getset::CopyGetters; + +use crate::{ + gates::{ + circuit::CircuitBuilderStage, + flex_gate::{BasicGateConfig, ThreadBreakPoints}, + }, + utils::halo2::{raw_assign_advice, raw_constrain_equal}, + utils::ScalarField, + virtual_region::copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}, + Context, ContextCell, +}; +use crate::{ + halo2_proofs::{ + circuit::{Region, Value}, + plonk::{FirstPhase, SecondPhase, ThirdPhase}, + }, + virtual_region::manager::VirtualRegionManager, +}; + +/// Virtual region manager for [Vec] in a single challenge phase. +/// This is the core manager for [Context]s. +#[derive(Clone, Debug, Default, CopyGetters)] +pub struct SinglePhaseCoreManager { + /// Virtual columns. These cannot be shared across CPU threads while keeping the circuit deterministic. + pub threads: Vec>, + /// Global shared copy manager + pub copy_manager: SharedCopyConstraintManager, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + #[getset(get_copy = "pub")] + pub(crate) use_unknown: bool, + /// The challenge phase the virtual regions will map to. + #[getset(get_copy = "pub", set)] + pub(crate) phase: usize, + /// A very simple computation graph for the basic vertical gate. Must be provided as a "pinning" + /// when running the production prover. + pub break_points: RefCell>, +} + +impl SinglePhaseCoreManager { + /// Creates a new [GateThreadBuilder] and spawns a main thread. + /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + threads: vec![], + witness_gen_only, + use_unknown: false, + phase: 0, + copy_manager, + ..Default::default() + } + } + + /// Sets the phase to `phase` + pub fn in_phase(self, phase: usize) -> Self { + Self { phase, ..self } + } + + /// Creates a new [GateThreadBuilder] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [GateThreadBuilder] is used for witness generation only. + pub fn from_stage( + stage: CircuitBuilderStage, + copy_manager: SharedCopyConstraintManager, + ) -> Self { + Self::new(stage.witness_gen_only(), copy_manager) + .unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(self, use_unknown: bool) -> Self { + Self { use_unknown, ..self } + } + + /// Mutates `self` to use the given copy manager everywhere, including in all threads. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + self.copy_manager = copy_manager.clone(); + for ctx in &mut self.threads { + ctx.copy_manager = copy_manager.clone(); + } + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + + /// Clears all threads and copy manager + pub fn clear(&mut self) { + self.threads = vec![]; + self.copy_manager.lock().unwrap().clear(); + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + pub fn main(&mut self) -> &mut Context { + if self.threads.is_empty() { + self.new_thread() + } else { + self.threads.last_mut().unwrap() + } + } + + /// Returns the number of threads + pub fn thread_count(&self) -> usize { + self.threads.len() + } + + /// A distinct tag for this particular type of virtual manager, which is different for each phase. + pub fn type_of(&self) -> TypeId { + match self.phase { + 0 => TypeId::of::<(Self, FirstPhase)>(), + 1 => TypeId::of::<(Self, SecondPhase)>(), + 2 => TypeId::of::<(Self, ThirdPhase)>(), + _ => panic!("Unsupported phase"), + } + } + + /// Creates new context but does not append to `self.threads` + pub fn new_context(&self, context_id: usize) -> Context { + Context::new( + self.witness_gen_only, + self.phase, + self.type_of(), + context_id, + self.copy_manager.clone(), + ) + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self) -> &mut Context { + let context_id = self.thread_count(); + self.threads.push(self.new_context(context_id)); + self.threads.last_mut().unwrap() + } + + /// Returns total advice cells + pub fn total_advice(&self) -> usize { + self.threads.iter().map(|ctx| ctx.advice.len()).sum::() + } +} + +impl VirtualRegionManager for SinglePhaseCoreManager { + type Config = (Vec>, usize); // usize = usable_rows + + fn assign_raw(&self, (config, usable_rows): &Self::Config, region: &mut Region) { + if self.witness_gen_only { + let binding = self.break_points.borrow(); + let break_points = binding.as_ref().expect("break points not set"); + assign_witnesses(&self.threads, config, region, break_points); + } else { + let mut copy_manager = self.copy_manager.lock().unwrap(); + let break_points = assign_with_constraints::( + &self.threads, + config, + region, + &mut copy_manager, + *usable_rows, + self.use_unknown, + ); + let mut bp = self.break_points.borrow_mut(); + if let Some(bp) = bp.as_ref() { + assert_eq!(bp, &break_points, "break points don't match"); + } else { + *bp = Some(break_points); + } + } + } +} + +/// Assigns all virtual `threads` to the physical columns in `basic_gates` and returns the break points. +/// Also enables corresponding selectors and adds raw assigned cells to the `copy_manager`. +/// This function should be called either during proving & verifier key generation or when running MockProver. +/// +/// For proof generation, see [assign_witnesses]. +/// +/// This is generic for a "vertical" custom gate that uses a single column and `ROTATIONS` contiguous rows in that column. +/// +/// ⚠️ Right now we only support "overlaps" where you can have the gate enabled at `offset` and `offset + ROTATIONS - 1`, but not at `offset + delta` where `delta < ROTATIONS - 1`. +/// +/// # Inputs +/// - `max_rows`: The number of rows that can be used for the assignment. This is the number of rows that are not blinded for zero-knowledge. +/// - If `use_unknown` is true, then the advice columns will be assigned as unknowns. +/// +/// # Assumptions +/// - All `basic_gates` are in the same phase. +pub fn assign_with_constraints( + threads: &[Context], + basic_gates: &[BasicGateConfig], + region: &mut Region, + copy_manager: &mut CopyConstraintManager, + max_rows: usize, + use_unknown: bool, +) -> ThreadBreakPoints { + let mut break_points = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + if ctx.advice.is_empty() { + continue; + } + let mut basic_gate = basic_gates + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(feature = "halo2-axiom")] + let cell = region.assign_advice(column, row_offset, value).cell(); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + copy_manager + .assigned_advices + .insert(ContextCell::new(ctx.type_id, ctx.context_id, i), cell); + + // If selector enabled and row_offset is valid add break point, account for break point overlap, and enforce equality constraint for gate outputs. + // ⚠️ This assumes overlap is of form: gate enabled at `i - delta` and `i`, where `delta = ROTATIONS - 1`. We currently do not support `delta < ROTATIONS - 1`. + if (q && row_offset + ROTATIONS > max_rows) || row_offset >= max_rows - 1 { + break_points.push(row_offset); + row_offset = 0; + gate_index += 1; + + // safety check: make sure selector is not enabled on `i - delta` for `0 < delta < ROTATIONS - 1` + if ROTATIONS > 1 && i + 2 >= ROTATIONS { + for delta in 1..ROTATIONS - 1 { + assert!( + !ctx.selector[i - delta], + "We do not support overlaps with delta = {delta}" + ); + } + } + // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = basic_gates + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + #[cfg(feature = "halo2-axiom")] + let ncell = region.assign_advice(column, row_offset, value); + #[cfg(not(feature = "halo2-axiom"))] + let ncell = + region.assign_advice(|| "", column, row_offset, || value.map(|v| *v)).unwrap(); + raw_constrain_equal(region, ncell.cell(), cell); + } + + if q { + basic_gate + .q_enable + .enable(region, row_offset) + .expect("enable selector should not fail"); + } + + row_offset += 1; + } + } + break_points +} + +/// Assigns all virtual `threads` to the physical columns in `basic_gates` according to a precomputed "computation graph" +/// given by `break_points`. (`break_points` tells the assigner when to move to the next column.) +/// +/// This function does not impose **any** constraints. It only assigns witnesses to advice columns, and should be called +/// only during proof generation. +/// +/// # Assumptions +/// - All `basic_gates` are in the same phase. +pub fn assign_witnesses( + threads: &[Context], + basic_gates: &[BasicGateConfig], + region: &mut Region, + break_points: &ThreadBreakPoints, +) { + if basic_gates.is_empty() { + assert_eq!( + threads.iter().map(|ctx| ctx.advice.len()).sum::(), + 0, + "Trying to assign threads in a phase with no columns" + ); + return; + } + + let mut break_points = break_points.clone().into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = basic_gates[gate_index].value; + let mut row_offset = 0; + + for ctx in threads { + // Assign advice values to the advice columns in each [Context] + for advice in &ctx.advice { + raw_assign_advice(region, column, row_offset, Value::known(advice)); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = basic_gates[gate_index].value; + + raw_assign_advice(region, column, row_offset, Value::known(advice)); + } + + row_offset += 1; + } + } +} diff --git a/halo2-base/src/gates/mod.rs b/halo2-base/src/gates/mod.rs index a353a4f4..749ee834 100644 --- a/halo2-base/src/gates/mod.rs +++ b/halo2-base/src/gates/mod.rs @@ -1,5 +1,5 @@ -/// Module that helps auto-build circuits -pub mod builder; +/// Module providing tools to create a circuit using our gates +pub mod circuit; /// Module implementing our simple custom gate and common functions using it pub mod flex_gate; /// Module using a single lookup table for range checks diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range/mod.rs similarity index 76% rename from halo2-base/src/gates/range.rs rename to halo2-base/src/gates/range/mod.rs index 2592d515..79cdf155 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -1,5 +1,5 @@ use crate::{ - gates::flex_gate::{FlexGateConfig, GateInstructions, GateStrategy, MAX_PHASE}, + gates::flex_gate::{FlexGateConfig, GateInstructions, MAX_PHASE}, halo2_proofs::{ circuit::{Layouter, Value}, plonk::{ @@ -11,30 +11,19 @@ use crate::{ biguint_to_fe, bit_length, decompose_fe_to_u64_limbs, fe_to_biguint, BigPrimeField, ScalarField, }, + virtual_region::lookups::LookupAnyManager, AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness}, }; + +use super::flex_gate::{FlexGateConfigParams, GateChip}; + +use getset::Getters; use num_bigint::BigUint; use num_integer::Integer; use num_traits::One; use std::{cmp::Ordering, ops::Shl}; -use super::flex_gate::GateChip; - -/// Specifies the gate strategy for the range chip -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum RangeStrategy { - /// # Vertical Gate Strategy: - /// `q_0 * (a + b * c - d) = 0` - /// where - /// * a = value[0], b = value[1], c = value[2], d = value[3] - /// * q = q_lookup[0] - /// * q is either 0 or 1 so this is just a simple selector - /// - /// Using `a + b * c` instead of `a * b + c` allows for "chaining" of gates, i.e., the output of one gate becomes `a` in the next gate. - Vertical, // vanilla implementation with vertical basic gate(s) -} - /// Configuration for Range Chip #[derive(Clone, Debug)] pub struct RangeConfig { @@ -47,15 +36,13 @@ pub struct RangeConfig { /// * If `gate` has only 1 advice column, lookups are enabled for that column, in which case `lookup_advice` is empty /// * If `gate` has more than 1 advice column some number of user-specified `lookup_advice` columns are added /// * In this case, we don't need a selector so `q_lookup` is empty - pub lookup_advice: [Vec>; MAX_PHASE], + pub lookup_advice: Vec>>, /// Selector values for the lookup table. pub q_lookup: Vec>, /// Column for lookup table values. pub lookup: TableColumn, /// Defines the number of bits represented in the lookup table [0,2^lookup_bits). lookup_bits: usize, - /// Gate Strategy used for specifying advice values. - _strategy: RangeStrategy, } impl RangeConfig { @@ -73,33 +60,27 @@ impl RangeConfig { /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, - range_strategy: RangeStrategy, - num_advice: &[usize], + gate_params: FlexGateConfigParams, num_lookup_advice: &[usize], - num_fixed: usize, lookup_bits: usize, - // params.k() - circuit_degree: usize, ) -> Self { - assert!(lookup_bits <= 28); + assert!(lookup_bits <= F::S as usize); + // sanity check: only create lookup table if there are lookup_advice columns + assert!(!num_lookup_advice.is_empty(), "You are creating a RangeConfig but don't seem to need a lookup table, please double-check if you're using lookups correctly. Consider setting lookup_bits = None in BaseConfigParams"); + let lookup = meta.lookup_table_column(); - let gate = FlexGateConfig::configure( - meta, - match range_strategy { - RangeStrategy::Vertical => GateStrategy::Vertical, - }, - num_advice, - num_fixed, - circuit_degree, - ); + let gate = FlexGateConfig::configure(meta, gate_params.clone()); // For now, we apply the same range lookup table to each phase let mut q_lookup = Vec::new(); - let mut lookup_advice = [(); MAX_PHASE].map(|_| Vec::new()); + let mut lookup_advice = Vec::new(); for (phase, &num_columns) in num_lookup_advice.iter().enumerate() { - // if num_columns is set to 0, then we assume you do not want to perform any lookups in that phase - if num_advice[phase] == 1 && num_columns != 0 { + let num_advice = *gate_params.num_advice_per_phase.get(phase).unwrap_or(&0); + let mut columns = Vec::new(); + // If num_columns is set to 0, then we assume you do not want to perform any lookups in that phase. + // Disable this optimization in phase > 0 because you might set selectors based a cell from other columns. + if phase == 0 && num_advice == 1 && num_columns != 0 { q_lookup.push(Some(meta.complex_selector())); } else { q_lookup.push(None); @@ -111,19 +92,17 @@ impl RangeConfig { _ => panic!("Currently RangeConfig only supports {MAX_PHASE} phases"), }; meta.enable_equality(a); - lookup_advice[phase].push(a); + columns.push(a); } } + lookup_advice.push(columns); } - let mut config = - Self { lookup_advice, q_lookup, lookup, lookup_bits, gate, _strategy: range_strategy }; + let mut config = Self { lookup_advice, q_lookup, lookup, lookup_bits, gate }; + config.create_lookup(meta); - // sanity check: only create lookup table if there are lookup_advice columns - if !num_lookup_advice.is_empty() { - config.create_lookup(meta); - } - config.gate.max_rows = (1 << circuit_degree) - meta.minimum_rows(); + log::info!("Poisoned rows after RangeConfig::configure {}", meta.minimum_rows()); + config.gate.max_rows = (1 << gate_params.k) - meta.minimum_rows(); assert!( (1 << lookup_bits) <= config.gate.max_rows, "lookup table is too large for the circuit degree plus blinding factors!" @@ -189,17 +168,14 @@ pub trait RangeInstructions { /// Returns the type of gate used. fn gate(&self) -> &Self::Gate; - /// Returns the [GateStrategy] for this range. - fn strategy(&self) -> RangeStrategy; - /// Returns the number of bits the lookup table represents. fn lookup_bits(&self) -> usize; /// Checks and constrains that `a` lies in the range [0, 2range_bits). /// - /// Assumes that both `a`<= `range_bits` bits. - /// * a: [AssignedValue] value to be range checked - /// * range_bits: number of bits to represent the range + /// Inputs: + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize); /// Constrains that 'a' is less than 'b'. @@ -227,7 +203,7 @@ pub trait RangeInstructions { (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); self.range_check(ctx, a, range_bits); - self.check_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + self.check_less_than(ctx, a, Constant(F::from(b)), range_bits) } /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. @@ -275,7 +251,7 @@ pub trait RangeInstructions { (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); self.range_check(ctx, a, range_bits); - self.is_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + self.is_less_than(ctx, a, Constant(F::from(b)), range_bits) } /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is in `[0,b)`. @@ -382,7 +358,7 @@ pub trait RangeInstructions { let [div_lo, div_hi, div, rem] = [-5, -4, -2, -1].map(|i| ctx.get(i)); self.range_check(ctx, div_lo, b_num_bits); if a_num_bits <= b_num_bits { - self.gate().assert_is_const(ctx, &div_hi, &F::zero()); + self.gate().assert_is_const(ctx, &div_hi, &F::ZERO); } else { self.range_check(ctx, div_hi, a_num_bits - b_num_bits); } @@ -415,7 +391,7 @@ pub trait RangeInstructions { ) -> AssignedValue { let a_big = fe_to_biguint(a.value()); let bit_v = F::from(a_big.bit(0)); - let two = self.gate().get_field_element(2u64); + let two = F::from(2u64); let h_v = F::from_bytes_le(&(a_big >> 1usize).to_bytes_le()); ctx.assign_region([Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], [0]); @@ -428,19 +404,21 @@ pub trait RangeInstructions { } } -/// A chip that implements RangeInstructions which provides methods to constrain a field element `x` is within a range of bits. -#[derive(Clone, Debug)] +/// # RangeChip +/// This chip provides methods that rely on "range checking" that a field element `x` is within a range of bits. +/// Range checks are done using a lookup table with the numbers [0, 2lookup_bits). +#[derive(Clone, Debug, Getters)] pub struct RangeChip { - /// # RangeChip - /// Provides methods to constrain a field element `x` is within a range of bits. - /// Declares a lookup table of [0, 2lookup_bits) and constrains whether a field element appears in this table. - - /// [GateStrategy] for advice values in this chip. - strategy: RangeStrategy, /// Underlying [GateChip] for this chip. pub gate: GateChip, + /// Lookup manager for each phase, lazily initiated using the [SharedCopyConstraintManager] from the [Context] + /// that first calls it. + /// + /// The lookup manager is used to store the cells that need to be looked up in the range check lookup table. + #[getset(get = "pub")] + lookup_manager: [LookupAnyManager; MAX_PHASE], /// Defines the number of bits represented in the lookup table [0,2lookup_bits). - pub lookup_bits: usize, + lookup_bits: usize, /// [Vec] of powers of `2 ** lookup_bits` represented as [QuantumCell::Constant]. /// These are precomputed and cached as a performance optimization for later limb decompositions. We precompute up to the higher power that fits in `F`, which is `2 ** ((F::CAPACITY / lookup_bits) * lookup_bits)`. pub limb_bases: Vec>, @@ -450,105 +428,124 @@ impl RangeChip { /// Creates a new [RangeChip] with the given strategy and lookup_bits. /// * strategy: [GateStrategy] for advice values in this chip /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) - pub fn new(strategy: RangeStrategy, lookup_bits: usize) -> Self { + pub fn new(lookup_bits: usize, lookup_manager: [LookupAnyManager; MAX_PHASE]) -> Self { let limb_base = F::from(1u64 << lookup_bits); let mut running_base = limb_base; let num_bases = F::CAPACITY as usize / lookup_bits; let mut limb_bases = Vec::with_capacity(num_bases + 1); - limb_bases.extend([Constant(F::one()), Constant(running_base)]); + limb_bases.extend([Constant(F::ONE), Constant(running_base)]); for _ in 2..=num_bases { running_base *= &limb_base; limb_bases.push(Constant(running_base)); } - let gate = GateChip::new(match strategy { - RangeStrategy::Vertical => GateStrategy::Vertical, - }); + let gate = GateChip::new(); - Self { strategy, gate, lookup_bits, limb_bases } + Self { gate, lookup_bits, lookup_manager, limb_bases } } - /// Creates a new [RangeChip] with the default strategy and provided lookup_bits. - /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) - pub fn default(lookup_bits: usize) -> Self { - Self::new(RangeStrategy::Vertical, lookup_bits) - } -} - -impl RangeInstructions for RangeChip { - type Gate = GateChip; - - /// The type of Gate used in this chip. - fn gate(&self) -> &Self::Gate { - &self.gate - } - - /// Returns the [GateStrategy] for this range. - fn strategy(&self) -> RangeStrategy { - self.strategy - } - - /// Defines the number of bits represented in the lookup table [0,2lookup_bits). - fn lookup_bits(&self) -> usize { - self.lookup_bits + fn add_cell_to_lookup(&self, ctx: &Context, a: AssignedValue) { + let phase = ctx.phase(); + let manager = &self.lookup_manager[phase]; + manager.add_lookup(ctx.tag(), [a]); } /// Checks and constrains that `a` lies in the range [0, 2range_bits). /// - /// This is done by decomposing `a` into `k` limbs, where `k = ceil(range_bits / lookup_bits)`. + /// This is done by decomposing `a` into `num_limbs` limbs, where `num_limbs = ceil(range_bits / lookup_bits)`. /// Each limb is constrained to be within the range [0, 2lookup_bits). /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. /// + /// Returns the last (highest) limb. + /// + /// Inputs: /// * `a`: [AssignedValue] value to be range checked /// * `range_bits`: number of bits in the range /// * `lookup_bits`: number of bits in the lookup table /// /// # Assumptions /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` - fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + fn _range_check( + &self, + ctx: &mut Context, + a: AssignedValue, + range_bits: usize, + ) -> AssignedValue { if range_bits == 0 { - self.gate.assert_is_const(ctx, &a, &F::zero()); - return; + self.gate.assert_is_const(ctx, &a, &F::ZERO); + return a; } // the number of limbs - let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; + let num_limbs = (range_bits + self.lookup_bits - 1) / self.lookup_bits; // println!("range check {} bits {} len", range_bits, k); let rem_bits = range_bits % self.lookup_bits; - debug_assert!(self.limb_bases.len() >= k); + debug_assert!(self.limb_bases.len() >= num_limbs); - if k == 1 { - ctx.cells_to_lookup.push(a); + let last_limb = if num_limbs == 1 { + self.add_cell_to_lookup(ctx, a); + a } else { - let limbs = decompose_fe_to_u64_limbs(a.value(), k, self.lookup_bits) + let limbs = decompose_fe_to_u64_limbs(a.value(), num_limbs, self.lookup_bits) .into_iter() .map(|x| Witness(F::from(x))); let row_offset = ctx.advice.len() as isize; - let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..k].to_vec()); + let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..num_limbs].to_vec()); // the inner product above must equal `a` ctx.constrain_equal(&a, &acc); // we fetch the cells to lookup by getting the indices where `limbs` were assigned in `inner_product`. Because `limb_bases[0]` is 1, the progression of indices is 0,1,4,...,4+3*i - ctx.cells_to_lookup.push(ctx.get(row_offset)); - for i in 0..k - 1 { - ctx.cells_to_lookup.push(ctx.get(row_offset + 1 + 3 * i as isize)); + self.add_cell_to_lookup(ctx, ctx.get(row_offset)); + for i in 0..num_limbs - 1 { + self.add_cell_to_lookup(ctx, ctx.get(row_offset + 1 + 3 * i as isize)); } + ctx.get(row_offset + 1 + 3 * (num_limbs - 2) as isize) }; // additional constraints for the last limb if rem_bits != 0 match rem_bits.cmp(&1) { - // we want to check x := limbs[k-1] is boolean + // we want to check x := limbs[num_limbs-1] is boolean // we constrain x*(x-1) = 0 + x * x - x == 0 // | 0 | x | x | x | Ordering::Equal => { - self.gate.assert_bit(ctx, *ctx.cells_to_lookup.last().unwrap()); + self.gate.assert_bit(ctx, last_limb); } Ordering::Greater => { let mult_val = self.gate.pow_of_two[self.lookup_bits - rem_bits]; - let check = - self.gate.mul(ctx, *ctx.cells_to_lookup.last().unwrap(), Constant(mult_val)); - ctx.cells_to_lookup.push(check); + let check = self.gate.mul(ctx, last_limb, Constant(mult_val)); + self.add_cell_to_lookup(ctx, check); } _ => {} } + last_limb + } +} + +impl RangeInstructions for RangeChip { + type Gate = GateChip; + + /// The type of Gate used in this chip. + fn gate(&self) -> &Self::Gate { + &self.gate + } + + /// Returns the number of bits represented in the lookup table [0,2lookup_bits). + fn lookup_bits(&self) -> usize { + self.lookup_bits + } + + /// Checks and constrains that `a` lies in the range [0, 2range_bits). + /// + /// This is done by decomposing `a` into `num_limbs` limbs, where `num_limbs = ceil(range_bits / lookup_bits)`. + /// Each limb is constrained to be within the range [0, 2lookup_bits). + /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. + /// + /// Inputs: + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range + /// + /// # Assumptions + /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` + fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + self._range_check(ctx, a, range_bits); } /// Constrains that 'a' is less than 'b'. @@ -569,22 +566,20 @@ impl RangeInstructions for RangeChip { let a = a.into(); let b = b.into(); let pow_of_two = self.gate.pow_of_two[num_bits]; - let check_cell = match self.strategy { - RangeStrategy::Vertical => { - let shift_a_val = pow_of_two + a.value(); - // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | - let cells = [ - Witness(shift_a_val - b.value()), - b, - Constant(F::one()), - Witness(shift_a_val), - Constant(-pow_of_two), - Constant(F::one()), - a, - ]; - ctx.assign_region(cells, [0, 3]); - ctx.get(-7) - } + let check_cell = { + let shift_a_val = pow_of_two + a.value(); + // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | + let cells = [ + Witness(shift_a_val - b.value()), + b, + Constant(F::ONE), + Witness(shift_a_val), + Constant(-pow_of_two), + Constant(F::ONE), + a, + ]; + ctx.assign_region(cells, [0, 3]); + ctx.get(-7) }; self.range_check(ctx, check_cell, num_bits); @@ -619,28 +614,26 @@ impl RangeInstructions for RangeChip { let shift_a_val = pow_padded + a.value(); let shifted_val = shift_a_val - b.value(); - let shifted_cell = match self.strategy { - RangeStrategy::Vertical => { - ctx.assign_region( - [ - Witness(shifted_val), - b, - Constant(F::one()), - Witness(shift_a_val), - Constant(-pow_padded), - Constant(F::one()), - a, - ], - [0, 3], - ); - ctx.get(-7) - } + let shifted_cell = { + ctx.assign_region( + [ + Witness(shifted_val), + b, + Constant(F::ONE), + Witness(shift_a_val), + Constant(-pow_padded), + Constant(F::ONE), + a, + ], + [0, 3], + ); + ctx.get(-7) }; // check whether a - b + 2^padded_bits < 2^padded_bits ? // since assuming a, b < 2^padded_bits we are guaranteed a - b + 2^padded_bits < 2^{padded_bits + 1} - self.range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); - // ctx.cells_to_lookup.last() will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` - self.gate.is_zero(ctx, *ctx.cells_to_lookup.last().unwrap()) + let last_limb = self._range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); + // last_limb will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` + self.gate.is_zero(ctx, last_limb) } } diff --git a/halo2-base/src/gates/tests/README.md b/halo2-base/src/gates/tests/README.md deleted file mode 100644 index 24f34537..00000000 --- a/halo2-base/src/gates/tests/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Tests - -For tests that use `GateCircuitBuilder` or `RangeCircuitBuilder`, we currently must use environmental variables `FLEX_GATE_CONFIG` and `LOOKUP_BITS` to pass circuit configuration parameters to the `Circuit::configure` function. This is troublesome when Rust executes tests in parallel, so we to make sure all tests pass, run - -``` -cargo test -- --test-threads=1 -``` - -to force serial execution. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 2434b1d1..53cf9513 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -1,293 +1,225 @@ #![allow(clippy::type_complexity)] use super::*; -use crate::halo2_proofs::dev::MockProver; -use crate::halo2_proofs::dev::VerifyFailure; -use crate::utils::ScalarField; -use crate::QuantumCell::Witness; -use crate::{ - gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder}, - flex_gate::{GateChip, GateInstructions}, - }, - QuantumCell, -}; +use crate::utils::biguint_to_fe; +use crate::utils::testing::base_test; +use crate::QuantumCell::{Constant, Witness}; +use crate::{gates::flex_gate::GateInstructions, QuantumCell}; +use itertools::Itertools; +use num_bigint::BigUint; use test_case::test_case; -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "add(): 1 + 1 == 2")] -pub fn test_add(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.add(ctx, inputs[0], inputs[1]); - *a.value() +#[test_case(&[10, 12].map(Fr::from).map(Witness)=> Fr::from(22); "add(): 10 + 12 == 22")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(2); "add(): 1 + 1 == 2")] +pub fn test_add(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.add(ctx, inputs[0], inputs[1]).value()) } -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub(): 1 - 1 == 0")] -pub fn test_sub(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.sub(ctx, inputs[0], inputs[1]); - *a.value() +#[test_case(Witness(Fr::from(10))=> Fr::from(11); "inc(): 10 -> 11")] +#[test_case(Witness(Fr::from(1))=> Fr::from(2); "inc(): 1 -> 2")] +pub fn test_inc(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.inc(ctx, input).value()) +} + +#[test_case(&[10, 12].map(Fr::from).map(Witness)=> -Fr::from(2) ; "sub(): 10 - 12 == -2")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(0) ; "sub(): 1 - 1 == 0")] +pub fn test_sub(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Witness(Fr::from(10))=> Fr::from(9); "dec(): 10 -> 9")] +#[test_case(Witness(Fr::from(1))=> Fr::from(0); "dec(): 1 -> 0")] +pub fn test_dec(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.dec(ctx, input).value()) } #[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub_mul(): 1 - 1 * 1 == 0")] -pub fn test_sub_mul(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() +pub fn test_sub_mul(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]).value()) } #[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] -pub fn test_neg(a: QuantumCell) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.neg(ctx, a); - *a.value() +pub fn test_neg(a: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.neg(ctx, a).value()) } +#[test_case(&[10, 12].map(Fr::from).map(Witness) => Fr::from(120) ; "mul(): 10 * 12 == 120")] #[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] -pub fn test_mul(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul(ctx, inputs[0], inputs[1]); - *a.value() +pub fn test_mul(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul(ctx, inputs[0], inputs[1]).value()) } #[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] -pub fn test_mul_add(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() +pub fn test_mul_add(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]).value()) } -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "mul_not(): 1 * 1 == 0")] -pub fn test_mul_not(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul_not(ctx, inputs[0], inputs[1]); - *a.value() +#[test_case(&[0, 10].map(Fr::from).map(Witness) => Fr::from(10); "mul_not(): (1 - 0) * 10 == 10")] +#[test_case(&[1, 10].map(Fr::from).map(Witness) => Fr::from(0); "mul_not(): (1 - 1) * 10 == 0")] +pub fn test_mul_not(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul_not(ctx, inputs[0], inputs[1]).value()) } -#[test_case(Fr::from(1) => Ok(()); "assert_bit(): 1 == bit")] -pub fn test_assert_bit(input: F) -> Result<(), Vec> { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([input])[0]; - chip.assert_bit(ctx, a); - // auto-tune circuit - builder.config(6, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(6, &circuit, vec![]).unwrap().verify() +#[test_case(Fr::from(0), true; "assert_bit(0)")] +#[test_case(Fr::from(1), true; "assert_bit(1)")] +#[test_case(Fr::from(2), false; "assert_bit(2)")] +pub fn test_assert_bit(input: Fr, is_bit: bool) { + base_test().expect_satisfied(is_bit).run_gate(|ctx, chip| { + let a = ctx.load_witness(input); + chip.assert_bit(ctx, a); + }); } -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] -pub fn test_div_unsafe(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.div_unsafe(ctx, inputs[0], inputs[1]); - *a.value() +#[test_case(&[6, 2].map(Fr::from).map(Witness)=> Fr::from(3) ; "div_unsafe(): 6 / 2 == 3")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] +pub fn test_div_unsafe(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.div_unsafe(ctx, inputs[0], inputs[1]).value()) } -#[test_case(&[1, 1].map(Fr::from); "assert_is_const()")] -pub fn test_assert_is_const(inputs: &[F]) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([inputs[0]])[0]; - chip.assert_is_const(ctx, &a, &inputs[1]); - // auto-tune circuit - builder.config(6, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(6, &circuit, vec![]).unwrap().assert_satisfied() +#[test_case(&[1, 1].map(Fr::from); "assert_is_const(1,1)")] +#[test_case(&[0, 1].map(Fr::from); "assert_is_const(0,1)")] +pub fn test_assert_is_const(inputs: &[Fr]) { + base_test().expect_satisfied(inputs[0] == inputs[1]).run_gate(|ctx, chip| { + let a = ctx.load_witness(inputs[0]); + chip.assert_is_const(ctx, &a, &inputs[1]); + }); } #[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] -pub fn test_inner_product(input: (Vec>, Vec>)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product(ctx, input.0, input.1); - *a.value() +pub fn test_inner_product(input: (Vec>, Vec>)) -> Fr { + base_test().run_gate(|ctx, chip| *chip.inner_product(ctx, input.0, input.1).value()) } #[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] -pub fn test_inner_product_left_last( - input: (Vec>, Vec>), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product_left_last(ctx, input.0, input.1); - (*a.0.value(), *a.1.value()) -} - -#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => vec![Fr::one(), Fr::from(2), Fr::from(3), Fr::from(4), Fr::from(5)]; "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] -pub fn test_inner_product_with_sums( - input: (Vec>, Vec>), -) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product_with_sums(ctx, input.0, input.1); - a.into_iter().map(|x| *x.value()).collect() +pub fn test_inner_product_left_last( + input: (Vec>, Vec>), +) -> (Fr, Fr) { + base_test().run_gate(|ctx, chip| { + let a = chip.inner_product_left_last(ctx, input.0, input.1); + (*a.0.value(), *a.1.value()) + }) +} + +#[test_case([4,5,6].map(Fr::from).to_vec(), [1,2,3].map(|x| Constant(Fr::from(x))).to_vec() => (Fr::from(32), [4,5,6].map(Fr::from).to_vec()); +"inner_product_left(): <[1,2,3],[4,5,6]> Constant b starts with 1")] +#[test_case([1,2,3].map(Fr::from).to_vec(), [4,5,6].map(|x| Witness(Fr::from(x))).to_vec() => (Fr::from(32), [1,2,3].map(Fr::from).to_vec()); +"inner_product_left(): <[1,2,3],[4,5,6]> Witness")] +pub fn test_inner_product_left(a: Vec, b: Vec>) -> (Fr, Vec) { + base_test().run_gate(|ctx, chip| { + let (prod, a) = chip.inner_product_left(ctx, a.into_iter().map(Witness), b); + (*prod.value(), a.iter().map(|v| *v.value()).collect()) + }) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (1..=5).map(Fr::from).collect::>(); "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] +pub fn test_inner_product_with_sums( + input: (Vec>, Vec>), +) -> Vec { + base_test().run_gate(|ctx, chip| { + chip.inner_product_with_sums(ctx, input.0, input.1).map(|a| *a.value()).collect() + }) } #[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] -pub fn test_sum_products_with_coeff_and_var( - input: (Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.sum_products_with_coeff_and_var(ctx, input.0, input.1); - *a.value() -} - -#[test_case(&[1, 0].map(Fr::from).map(Witness) => Fr::from(1) ; "or(): 1 || 0 == 1")] -pub fn test_or(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.or(ctx, inputs[0], inputs[1]); - *a.value() +pub fn test_sum_products_with_coeff_and_var( + input: (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell), +) -> Fr { + base_test() + .run_gate(|ctx, chip| *chip.sum_products_with_coeff_and_var(ctx, input.0, input.1).value()) } +#[test_case(&[1, 0].map(Fr::from).map(Witness) => Fr::from(0) ; "and(): 1 && 0 == 0")] #[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] -pub fn test_and(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.and(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "xor(): 1 ^ 1 == 0")] -pub fn test_xor(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.xor(ctx, inputs[0], inputs[1]); - *a.value() +pub fn test_and(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.and(ctx, inputs[0], inputs[1]).value()) } -#[test_case(Witness(Fr::from(1)) => Fr::zero() ; "not(): !1 == 0")] -pub fn test_not(a: QuantumCell) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.not(ctx, a); - *a.value() +#[test_case(Witness(Fr::from(1)) => Fr::zero(); "not(): !1 == 0")] +#[test_case(Witness(Fr::from(0)) => Fr::one(); "not(): !0 == 1")] +pub fn test_not(a: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.not(ctx, a).value()) } -#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "select(): 2 ? 3 : 1 == 2")] -pub fn test_select(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.select(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() +#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2); "select(): 2 ? 3 : 1 == 2")] +pub fn test_select(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.select(ctx, inputs[0], inputs[1], inputs[2]).value()) } -#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "or_and(): 1 || 1 && 1 == 1")] -pub fn test_or_and(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.or_and(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() +#[test_case(&[0, 1, 0].map(Fr::from).map(Witness) => Fr::from(0); "or_and(): 0 || (1 && 0) == 0")] +#[test_case(&[1, 0, 1].map(Fr::from).map(Witness) => Fr::from(1); "or_and(): 1 || (0 && 1) == 1")] +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1); "or_and(): 1 || (1 && 1) == 1")] +pub fn test_or_and(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.or_and(ctx, inputs[0], inputs[1], inputs[2]).value()) } -#[test_case(Fr::zero() => vec![Fr::one(), Fr::zero()]; "bits_to_indicator(): 0 -> [1, 0]")] -pub fn test_bits_to_indicator(bits: F) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([bits])[0]; - let a = chip.bits_to_indicator(ctx, &[a]); - a.iter().map(|x| *x.value()).collect() +#[test_case(&[0,1] => [0,0,1,0].map(Fr::from).to_vec(); "bits_to_indicator(): bin\"10 -> [0, 0, 1, 0]")] +#[test_case(&[0] => [1,0].map(Fr::from).to_vec(); "bits_to_indicator(): 0 -> [1, 0]")] +pub fn test_bits_to_indicator(bits: &[u8]) -> Vec { + base_test().run_gate(|ctx, chip| { + let a = ctx.assign_witnesses(bits.iter().map(|x| Fr::from(*x as u64))); + chip.bits_to_indicator(ctx, &a).iter().map(|a| *a.value()).collect() + }) } -#[test_case((Witness(Fr::zero()), 3) => vec![Fr::one(), Fr::zero(), Fr::zero()] ; "idx_to_indicator(): 0 -> [1, 0, 0]")] -pub fn test_idx_to_indicator(input: (QuantumCell, usize)) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.0, input.1); - a.iter().map(|x| *x.value()).collect() +#[test_case(Witness(Fr::from(0)),3 => [1,0,0].map(Fr::from).to_vec(); "idx_to_indicator(): 0 -> [1, 0, 0]")] +pub fn test_idx_to_indicator(idx: QuantumCell, len: usize) -> Vec { + base_test().run_gate(|ctx, chip| { + chip.idx_to_indicator(ctx, idx, len).iter().map(|a| *a.value()).collect() + }) } -#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_by_indicator(): [0, 1, 2] -> 1")] -pub fn test_select_by_indicator(input: (Vec>, QuantumCell)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); - let a = chip.select_by_indicator(ctx, input.0, a); - *a.value() +#[test_case((0..3).map(Fr::from).map(Witness).collect(), Witness(Fr::one()) => Fr::from(1); "select_by_indicator(1): [0, 1, 2] -> 1")] +pub fn test_select_by_indicator(array: Vec>, idx: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| { + let a = chip.idx_to_indicator(ctx, idx, array.len()); + *chip.select_by_indicator(ctx, array, a).value() + }) } -#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_from_idx(): [0, 1, 2] -> 1")] -pub fn test_select_from_idx(input: (Vec>, QuantumCell)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.select_from_idx(ctx, input.0, input.1); - *a.value() +#[test_case((0..3).map(Fr::from).map(Witness).collect(), Witness(Fr::from(1)) => Fr::from(1); "select_from_idx(): [0, 1, 2] -> 1")] +pub fn test_select_from_idx(array: Vec>, idx: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.select_from_idx(ctx, array, idx).value()) } -#[test_case(Fr::zero() => Fr::from(1) ; "is_zero(): 0 -> 1")] -pub fn test_is_zero(x: F) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([x])[0]; - let a = chip.is_zero(ctx, a); - *a.value() +#[test_case(vec![vec![1,2,3], vec![4,5,6], vec![7,8,9]].into_iter().map(|a| a.into_iter().map(Fr::from).collect_vec()).collect_vec(), +Fr::from(1) => +[4,5,6].map(Fr::from).to_vec(); +"select_array_by_indicator(1): [[1,2,3], [4,5,6], [7,8,9]] -> [4,5,6]")] +pub fn test_select_array_by_indicator(array2d: Vec>, idx: Fr) -> Vec { + base_test().run_gate(|ctx, chip| { + let array2d = array2d.into_iter().map(|a| ctx.assign_witnesses(a)).collect_vec(); + let idx = ctx.load_witness(idx); + let ind = chip.idx_to_indicator(ctx, idx, array2d.len()); + chip.select_array_by_indicator(ctx, &array2d, &ind).iter().map(|a| *a.value()).collect() + }) } -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one() ; "is_equal(): 1 == 1")] -pub fn test_is_equal(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.is_equal(ctx, inputs[0], inputs[1]); - *a.value() +#[test_case(Fr::zero() => Fr::from(1); "is_zero(): 0 -> 1")] +pub fn test_is_zero(input: Fr) -> Fr { + base_test().run_gate(|ctx, chip| { + let input = ctx.load_witness(input); + *chip.is_zero(ctx, input).value() + }) } -#[test_case((Fr::from(6u64), 3) => vec![Fr::zero(), Fr::one(), Fr::one()] ; "num_to_bits(): 6")] -pub fn test_num_to_bits(input: (F, usize)) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([input.0])[0]; - let a = chip.num_to_bits(ctx, a, input.1); - a.iter().map(|x| *x.value()).collect() +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one(); "is_equal(): 1 == 1")] +pub fn test_is_equal(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.is_equal(ctx, inputs[0], inputs[1]).value()) } -#[test_case(&[0, 1, 2].map(Fr::from) => (Fr::one(), Fr::from(2)) ; "lagrange_eval(): constant fn")] -pub fn test_lagrange_eval(input: &[F]) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let input = ctx.assign_witnesses(input.iter().copied()); - let a = chip.lagrange_and_eval(ctx, &[(input[0], input[1])], input[2]); - (*a.0.value(), *a.1.value()) +#[test_case(6, 3 => [0,1,1].map(Fr::from).to_vec(); "num_to_bits(): 6")] +pub fn test_num_to_bits(num: usize, bits: usize) -> Vec { + base_test().run_gate(|ctx, chip| { + let num = ctx.load_witness(Fr::from(num as u64)); + chip.num_to_bits(ctx, num, bits).iter().map(|a| *a.value()).collect() + }) } -#[test_case(1 => Fr::one(); "inner_product_simple(): 1 -> 1")] -pub fn test_get_field_element(n: u64) -> F { - let chip = GateChip::default(); - chip.get_field_element(n) +#[test_case(Fr::from(3), BigUint::from(3u32), 4 => Fr::from(27); "pow_var(): 3^3 = 27")] +pub fn test_pow_var(a: Fr, exp: BigUint, max_bits: usize) -> Fr { + assert!(exp.bits() <= max_bits as u64); + base_test().run_gate(|ctx, chip| { + let a = ctx.load_witness(a); + let exp = ctx.load_witness(biguint_to_fe(&exp)); + *chip.pow_var(ctx, a, exp, max_bits).value() + }) } diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index a37fff49..06f32f20 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -1,13 +1,18 @@ -use crate::gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, - flex_gate::{GateChip, GateInstructions}, - range::{RangeChip, RangeInstructions}, -}; -use crate::halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; +use crate::ff::Field; +use crate::gates::flex_gate::threads::parallelize_core; +use crate::halo2_proofs::halo2curves::bn256::Fr; use crate::utils::{BigPrimeField, ScalarField}; +use crate::{ + gates::{ + flex_gate::{GateChip, GateInstructions}, + range::{RangeChip, RangeInstructions}, + }, + utils::testing::base_test, +}; use crate::{Context, QuantumCell::Constant}; -use rand::rngs::OsRng; -use rayon::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use test_log::test; fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { let [a, b, c]: [_; 3] = ctx.assign_witnesses(inputs).try_into().unwrap(); @@ -25,7 +30,7 @@ fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { // test idx_to_indicator chip.idx_to_indicator(ctx, Constant(F::from(3u64)), 4); - let bits = ctx.assign_witnesses([F::zero(), F::one()]); + let bits = ctx.assign_witnesses([F::ZERO, F::ONE]); chip.bits_to_indicator(ctx, &bits); chip.is_equal(ctx, b, a); @@ -33,47 +38,21 @@ fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { chip.is_zero(ctx, a); } -#[test] -fn test_gates() { - let k = 6; - let inputs = [10u64, 12u64, 120u64].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - gate_tests(builder.main(0), inputs); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - #[test] fn test_multithread_gates() { - let k = 6; - let inputs = [10u64, 12u64, 120u64].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - gate_tests(builder.main(0), inputs); - - let thread_ids = (0..4usize).map(|_| builder.get_new_thread_id()).collect::>(); - let new_threads = thread_ids - .into_par_iter() - .map(|id| { - let mut ctx = Context::new(builder.witness_gen_only(), id); - gate_tests(&mut ctx, [(); 3].map(|_| Fr::random(OsRng))); - ctx - }) - .collect::>(); - builder.threads[0].extend(new_threads); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + base_test().k(6).bench_builder( + vec![[Fr::ZERO; 3]; 4], + (0..4usize).map(|_| [(); 3].map(|_| Fr::random(&mut rng))).collect(), + |pool, _, inputs| { + parallelize_core(pool, inputs, |ctx, input| { + gate_tests(ctx, input); + }); + }, + ); } +/* #[cfg(feature = "dev-graph")] #[test] fn plot_gates() { @@ -91,21 +70,19 @@ fn plot_gates() { // auto-tune circuit builder.config(k, Some(9)); // create circuit - let circuit = GateCircuitBuilder::keygen(builder); + let circuit = RangeCircuitBuilder::keygen(builder); halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } +*/ fn range_tests( ctx: &mut Context, - lookup_bits: usize, + chip: &RangeChip, inputs: [F; 2], range_bits: usize, lt_bits: usize, ) { let [a, b]: [_; 2] = ctx.assign_witnesses(inputs).try_into().unwrap(); - let chip = RangeChip::default(lookup_bits); - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - chip.range_check(ctx, a, range_bits); chip.check_less_than(ctx, a, b, lt_bits); @@ -119,51 +96,32 @@ fn range_tests( #[test] fn test_range_single() { - let k = 11; - let inputs = [100, 101].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(11).lookup_bits(3).bench_builder( + [Fr::ZERO; 2], + [100, 101].map(Fr::from), + |pool, range, inputs| { + range_tests(pool.main(), range, inputs, 8, 8); + }, + ); } #[test] fn test_range_multicolumn() { - let k = 5; let inputs = [100, 101].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(5).lookup_bits(3).run(|ctx, range| { + range_tests(ctx, range, inputs, 8, 8); + }) } -#[cfg(feature = "dev-graph")] #[test] -fn plot_range() { - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Range Layout", ("sans-serif", 60)).unwrap(); - - let k = 11; - let inputs = [0, 0].map(Fr::from); - let mut builder = GateThreadBuilder::new(false); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::keygen(builder); - halo2_proofs::dev::CircuitLayout::default().render(7, &circuit, &root).unwrap(); +fn test_multithread_range() { + base_test().k(6).lookup_bits(3).unusable_rows(20).bench_builder( + vec![[Fr::ZERO; 2]; 3], + vec![[0, 1].map(Fr::from), [100, 101].map(Fr::from), [254, 255].map(Fr::from)], + |pool, range, inputs| { + parallelize_core(pool, inputs, |ctx, input| { + range_tests(ctx, range, input, 8, 8); + }); + }, + ); } diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index 33cbaa94..6d709b48 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -1,8 +1,7 @@ +use crate::ff::Field; +use crate::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use crate::{ - gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder}, - GateChip, GateInstructions, - }, + gates::{GateChip, GateInstructions}, halo2_proofs::{ halo2curves::bn256::Fr, plonk::keygen_pk, @@ -12,42 +11,43 @@ use crate::{ utils::testing::{check_proof, gen_proof}, QuantumCell::Witness, }; -use ff::Field; use itertools::Itertools; use rand::{rngs::OsRng, thread_rng, Rng}; +use test_log::test; // soundness checks for `idx_to_indicator` function fn test_idx_to_indicator_gen(k: u32, len: usize) { // first create proving and verifying key - let mut builder = GateThreadBuilder::keygen(); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(k as usize); let gate = GateChip::default(); let dummy_idx = Witness(Fr::zero()); let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); // get the offsets of the indicator cells for later 'pranking' let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = GateCircuitBuilder::keygen(builder); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, OsRng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |idx: usize, ind_witnesses: &[Fr]| { - let mut builder = GateThreadBuilder::prover(); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); let gate = GateChip::default(); let idx = Witness(Fr::from(idx as u64)); - gate.idx_to_indicator(builder.main(0), idx, len); + let ctx = builder.main(0); + gate.idx_to_indicator(ctx, idx, len); // prank the indicator cells for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { - builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + ctx.advice[*offset] = Assigned::Trivial(*witness); } - let circuit = GateCircuitBuilder::prover(builder, vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; // expected answer diff --git a/halo2-base/src/gates/tests/mod.rs b/halo2-base/src/gates/tests/mod.rs index 9bed2c6f..8e35b53e 100644 --- a/halo2-base/src/gates/tests/mod.rs +++ b/halo2-base/src/gates/tests/mod.rs @@ -1,10 +1,9 @@ use crate::halo2_proofs::halo2curves::bn256::Fr; -mod bitwise_rotate; mod flex_gate; mod general; mod idx_to_indicator; mod neg_prop; mod pos_prop; -mod range_gate; +mod range; mod utils; diff --git a/halo2-base/src/gates/tests/neg_prop.rs b/halo2-base/src/gates/tests/neg_prop.rs index db838b40..27994ac0 100644 --- a/halo2-base/src/gates/tests/neg_prop.rs +++ b/halo2-base/src/gates/tests/neg_prop.rs @@ -1,31 +1,19 @@ -use std::env::set_var; - -use ff::Field; -use itertools::Itertools; -use num_bigint::BigUint; -use proptest::{collection::vec, prelude::*}; -use rand::rngs::OsRng; - -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::{bn256::Fr, FieldExt}, - plonk::Assigned, -}; use crate::{ + ff::Field, gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, - range::{RangeChip, RangeInstructions}, - tests::{ - pos_prop::{rand_bin_witness, rand_fr, rand_witness}, - utils, - }, - GateChip, GateInstructions, + range::RangeInstructions, + tests::{pos_prop::rand_fr, utils}, + GateInstructions, }, - utils::{biguint_to_fe, bit_length, fe_to_biguint, ScalarField}, - QuantumCell, + halo2_proofs::halo2curves::bn256::Fr, + utils::{biguint_to_fe, bit_length, fe_to_biguint, testing::base_test, ScalarField}, QuantumCell::Witness, }; +use num_bigint::BigUint; +use proptest::{collection::vec, prelude::*}; +use rand::rngs::OsRng; + // Strategies for generating random witnesses prop_compose! { // length == 1 is just selecting [0] which should be covered in unit test @@ -40,8 +28,8 @@ prop_compose! { prop_compose! { fn select_strat(k_bounds: (usize, usize)) - (k in k_bounds.0..=k_bounds.1, a in rand_witness(), b in rand_witness(), sel in rand_bin_witness(), rand_output in rand_fr()) - -> (usize, QuantumCell, QuantumCell, QuantumCell, Fr) { + (k in k_bounds.0..=k_bounds.1, a in rand_fr(), b in rand_fr(), sel in any::(), rand_output in rand_fr()) + -> (usize, Fr, Fr, bool, Fr) { (k, a, b, sel, rand_output) } } @@ -49,8 +37,8 @@ prop_compose! { prop_compose! { fn select_by_indicator_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) - -> (usize, Vec>, usize, Fr) { + (k in Just(k), a in vec(rand_fr(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec, usize, Fr) { (k, a, idx, rand_output) } } @@ -58,8 +46,8 @@ prop_compose! { prop_compose! { fn select_from_idx_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), cells in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) - -> (usize, Vec>, usize, Fr) { + (k in Just(k), cells in vec(rand_fr(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec, usize, Fr) { (k, cells, idx, rand_output) } } @@ -67,8 +55,8 @@ prop_compose! { prop_compose! { fn inner_product_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in rand_fr()) - -> (usize, Vec>, Vec>, Fr) { + (k in Just(k), a in vec(rand_fr(), len), b in vec(rand_fr(), len), rand_output in rand_fr()) + -> (usize, Vec, Vec, Fr) { (k, a, b, rand_output) } } @@ -76,8 +64,8 @@ prop_compose! { prop_compose! { fn inner_product_left_last_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in (rand_fr(), rand_fr())) - -> (usize, Vec>, Vec>, (Fr, Fr)) { + (k in Just(k), a in vec(rand_fr(), len), b in vec(rand_fr(), len), rand_output in (rand_fr(), rand_fr())) + -> (usize, Vec, Vec, (Fr, Fr)) { (k, a, b, rand_output) } } @@ -121,7 +109,7 @@ fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { return false; } - let idx_val = idx.get_lower_128() as usize; + let idx_val = idx.get_lower_64() as usize; // Check that all indexes are zero except for the one at idx for (i, v) in ind_witnesses.iter().enumerate() { @@ -133,265 +121,146 @@ fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { } // verify rand_output == a if sel == 1, rand_output == b if sel == 0 -fn check_select(a: Fr, b: Fr, sel: Fr, rand_output: Fr) -> bool { - if (sel == Fr::zero() && rand_output != b) || (sel == Fr::one() && rand_output != a) { +fn check_select(a: Fr, b: Fr, sel: bool, rand_output: Fr) -> bool { + if (!sel && rand_output != b) || (sel && rand_output != a) { return false; } true } -fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset - let dummy_idx = Witness(Fr::from(idx as u64)); - let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); - // get the offsets of the indicator cells for later 'pranking' - builder.config(k, Some(9)); - let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); - // prank the indicator cells - // TODO: prank the entire advice column with random values - for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { - builder.main(0).advice[*offset] = Assigned::Trivial(*witness); - } - // Get idx and indicator from advice column - // Apply check instance function to `idx` and `ind_witnesses` - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values +fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) { + // Check soundness of witness values let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset + let dummy_idx = Witness(Fr::from(idx as u64)); + let mut indicator = gate.idx_to_indicator(ctx, dummy_idx, len); + for (advice, prank_val) in indicator.iter_mut().zip(ind_witnesses) { + advice.debug_prank(ctx, *prank_val); + } + }); } -fn neg_test_select( - k: usize, - a: QuantumCell, - b: QuantumCell, - sel: QuantumCell, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - // add select gate - let select = gate.select(builder.main(0), a, b, sel); - - // Get the offset of `select`s output for later 'pranking' - builder.config(k, Some(9)); - let select_offset = select.cell.unwrap().offset; - // Prank the output - builder.main(0).advice[select_offset] = Assigned::Trivial(rand_output); - - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of output - let is_valid_instance = check_select(*a.value(), *b.value(), *sel.value(), rand_output); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_instance, - // if the proof is invalid, ignore - Err(_) => !is_valid_instance, - } +fn neg_test_select(k: usize, a: Fr, b: Fr, sel: bool, prank_output: Fr) { + // Check soundness of output + let is_valid_instance = check_select(a, b, sel, prank_output); + base_test().k(k as u32).expect_satisfied(is_valid_instance).run_gate(|ctx, gate| { + let [a, b, sel] = [a, b, Fr::from(sel)].map(|x| ctx.load_witness(x)); + let select = gate.select(ctx, a, b, sel); + select.debug_prank(ctx, prank_output); + }) } -fn neg_test_select_by_indicator( - k: usize, - a: Vec>, - idx: usize, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let indicator = gate.idx_to_indicator(builder.main(0), Witness(Fr::from(idx as u64)), a.len()); - let a_idx = gate.select_by_indicator(builder.main(0), a.clone(), indicator); - builder.config(k, Some(9)); - - let a_idx_offset = a_idx.cell.unwrap().offset; - builder.main(0).advice[a_idx_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // retrieve the value of a[idx] and check that it is equal to rand_output - let is_valid_witness = rand_output == *a[idx].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } +fn neg_test_select_by_indicator(k: usize, a: Vec, idx: usize, prank_output: Fr) { + // retrieve the value of a[idx] and check that it is equal to rand_output + let is_valid_witness = prank_output == a[idx]; + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let indicator = gate.idx_to_indicator(ctx, Witness(Fr::from(idx as u64)), a.len()); + let a = ctx.assign_witnesses(a); + let a_idx = gate.select_by_indicator(ctx, a, indicator); + a_idx.debug_prank(ctx, prank_output); + }); } -fn neg_test_select_from_idx( - k: usize, - cells: Vec>, - idx: usize, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let idx_val = - gate.select_from_idx(builder.main(0), cells.clone(), Witness(Fr::from(idx as u64))); - builder.config(k, Some(9)); - - let idx_offset = idx_val.cell.unwrap().offset; - builder.main(0).advice[idx_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = rand_output == *cells[idx].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } +fn neg_test_select_from_idx(k: usize, cells: Vec, idx: usize, prank_output: Fr) { + // Check soundness of witness values + let is_valid_witness = prank_output == cells[idx]; + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let cells = ctx.assign_witnesses(cells); + let idx_val = gate.select_from_idx(ctx, cells, Witness(Fr::from(idx as u64))); + idx_val.debug_prank(ctx, prank_output); + }); } -fn neg_test_inner_product( - k: usize, - a: Vec>, - b: Vec>, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let inner_product = gate.inner_product(builder.main(0), a.clone(), b.clone()); - builder.config(k, Some(9)); - - let inner_product_offset = inner_product.cell.unwrap().offset; - builder.main(0).advice[inner_product_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = rand_output == utils::inner_product_ground_truth(&(a, b)); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } +fn neg_test_inner_product(k: usize, a: Vec, b: Vec, prank_output: Fr) { + let is_valid_witness = prank_output == utils::inner_product_ground_truth(&a, &b); + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let a = ctx.assign_witnesses(a); + let inner_product = gate.inner_product(ctx, a, b.into_iter().map(Witness)); + inner_product.debug_prank(ctx, prank_output); + }); } fn neg_test_inner_product_left_last( k: usize, - a: Vec>, - b: Vec>, - rand_output: (Fr, Fr), -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let inner_product = gate.inner_product_left_last(builder.main(0), a.clone(), b.clone()); - builder.config(k, Some(9)); - - let inner_product_offset = - (inner_product.0.cell.unwrap().offset, inner_product.1.cell.unwrap().offset); - // prank the output cells - builder.main(0).advice[inner_product_offset.0] = Assigned::Trivial(rand_output.0); - builder.main(0).advice[inner_product_offset.1] = Assigned::Trivial(rand_output.1); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // (inner_product_ground_truth, a[a.len()-1]) - let inner_product_ground_truth = utils::inner_product_ground_truth(&(a.clone(), b)); - let is_valid_witness = - rand_output.0 == inner_product_ground_truth && rand_output.1 == *a[a.len() - 1].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } + a: Vec, + b: Vec, + (prank_output, prank_a_last): (Fr, Fr), +) { + let is_valid_witness = prank_output == utils::inner_product_ground_truth(&a, &b) + && prank_a_last == *a.last().unwrap(); + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let a = ctx.assign_witnesses(a); + let (inner_product, a_last) = + gate.inner_product_left_last(ctx, a, b.into_iter().map(Witness)); + inner_product.debug_prank(ctx, prank_output); + a_last.debug_prank(ctx, prank_a_last); + }); } // Range Check -fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = RangeChip::default(lookup_bits); - - let a_witness = builder.main(0).load_witness(rand_a); - gate.range_check(builder.main(0), a_witness, range_bits); - - builder.config(k, Some(9)); - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values +fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) { let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; - - MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct + base_test().k(k as u32).lookup_bits(lookup_bits).expect_satisfied(correct).run(|ctx, range| { + let a_witness = ctx.load_witness(rand_a); + range.range_check(ctx, a_witness, range_bits); + }) } // TODO: expand to prank output of is_less_than_safe() -fn neg_test_is_less_than_safe( - k: usize, - b: u64, - lookup_bits: usize, - rand_a: Fr, - prank_out: bool, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = RangeChip::default(lookup_bits); - let ctx = builder.main(0); - - let a_witness = ctx.load_witness(rand_a); // cannot prank this later because this witness will be copy-constrained - let out = gate.is_less_than_safe(ctx, a_witness, b); - - let out_idx = out.cell.unwrap().offset; - ctx.advice[out_idx] = Assigned::Trivial(Fr::from(prank_out)); - - builder.config(k, Some(9)); - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // println!("rand_a: {rand_a:?}, b: {b:?}"); +fn neg_test_is_less_than_safe(k: usize, b: u64, lookup_bits: usize, rand_a: Fr, prank_out: bool) { let a_big = fe_to_biguint(&rand_a); let is_lt = a_big < BigUint::from(b); let correct = (is_lt == prank_out) && (a_big.bits() as usize <= (bit_length(b) + lookup_bits - 1) / lookup_bits * lookup_bits); // circuit should always fail if `a` doesn't pass range check - MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct + + base_test().k(k as u32).lookup_bits(lookup_bits).expect_satisfied(correct).run(|ctx, range| { + let a_witness = ctx.load_witness(rand_a); + let out = range.is_less_than_safe(ctx, a_witness, b); + out.debug_prank(ctx, Fr::from(prank_out)); + }); } proptest! { // Note setting the minimum value of k to 8 is intentional as it is the smallest value that will not cause an `out of columns` error. Should be noted that filtering by len * (number cells per iteration) < 2^k leads to the filtering of to many cases and the failure of the tests w/o any runs. #[test] fn prop_test_neg_idx_to_indicator((k, len, idx, witness_vals) in idx_to_indicator_strat((10,20),100)) { - prop_assert!(neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice())); + neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice()); } #[test] fn prop_test_neg_select((k, a, b, sel, rand_output) in select_strat((10,20))) { - prop_assert!(neg_test_select(k, a, b, sel, rand_output)); + neg_test_select(k, a, b, sel, rand_output); } #[test] fn prop_test_neg_select_by_indicator((k, a, idx, rand_output) in select_by_indicator_strat((12,20),100)) { - prop_assert!(neg_test_select_by_indicator(k, a, idx, rand_output)); + neg_test_select_by_indicator(k, a, idx, rand_output); } #[test] fn prop_test_neg_select_from_idx((k, cells, idx, rand_output) in select_from_idx_strat((10,20),100)) { - prop_assert!(neg_test_select_from_idx(k, cells, idx, rand_output)); + neg_test_select_from_idx(k, cells, idx, rand_output); } #[test] fn prop_test_neg_inner_product((k, a, b, rand_output) in inner_product_strat((10,20),100)) { - prop_assert!(neg_test_inner_product(k, a, b, rand_output)); + neg_test_inner_product(k, a, b, rand_output); } #[test] fn prop_test_neg_inner_product_left_last((k, a, b, rand_output) in inner_product_left_last_strat((10,20),100)) { - prop_assert!(neg_test_inner_product_left_last(k, a, b, rand_output)); + neg_test_inner_product_left_last(k, a, b, rand_output); } #[test] fn prop_test_neg_range_check((k, range_bits, lookup_bits, rand_a) in range_check_strat((10,23),90)) { - prop_assert!(neg_test_range_check(k, range_bits, lookup_bits, rand_a)); + neg_test_range_check(k, range_bits, lookup_bits, rand_a); } #[test] fn prop_test_neg_is_less_than_safe((k, b, lookup_bits, rand_a, out) in is_less_than_safe_strat((10,20))) { - prop_assert!(neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out)); + neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out); } } diff --git a/halo2-base/src/gates/tests/pos_prop.rs b/halo2-base/src/gates/tests/pos_prop.rs index fd79e33a..927801fe 100644 --- a/halo2-base/src/gates/tests/pos_prop.rs +++ b/halo2-base/src/gates/tests/pos_prop.rs @@ -1,19 +1,26 @@ -use crate::gates::tests::{flex_gate, range_gate, utils::*, Fr}; -use crate::utils::{bit_length, fe_to_biguint}; +use std::cmp::max; + +use crate::ff::{Field, PrimeField}; +use crate::gates::tests::{flex_gate, range, utils::*, Fr}; +use crate::utils::{biguint_to_fe, bit_length, fe_to_biguint}; use crate::{QuantumCell, QuantumCell::Witness}; + +use num_bigint::{BigUint, RandBigInt, RandomBits}; use proptest::{collection::vec, prelude::*}; -//TODO: implement Copy for rand witness and rand fr to allow for array creation -// create vec and convert to array??? -//TODO: implement arbitrary for fr using looks like you'd probably need to implement your own TestFr struct to implement Arbitrary: https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html , can probably just hack it from Fr = [u64; 4] +use rand::rngs::StdRng; +use rand::SeedableRng; + prop_compose! { - pub fn rand_fr()(val in any::()) -> Fr { - Fr::from(val) + pub fn rand_fr()(seed in any::()) -> Fr { + let rng = StdRng::seed_from_u64(seed); + Fr::random(rng) } } prop_compose! { - pub fn rand_witness()(val in any::()) -> QuantumCell { - Witness(Fr::from(val)) + pub fn rand_witness()(seed in any::()) -> QuantumCell { + let rng = StdRng::seed_from_u64(seed); + Witness(Fr::random(rng)) } } @@ -30,25 +37,33 @@ prop_compose! { } prop_compose! { - pub fn rand_fr_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> Fr { - Fr::from(val) + pub fn rand_fr_range(bits: u64)(seed in any::()) -> Fr { + let mut rng = StdRng::seed_from_u64(seed); + let n = rng.sample(RandomBits::new(bits)); + biguint_to_fe(&n) } } prop_compose! { - pub fn rand_witness_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> QuantumCell { - Witness(Fr::from(val)) + pub fn rand_witness_range(bits: u64)(x in rand_fr_range(bits)) -> QuantumCell { + Witness(x) } } -// LEsson here 0..2^range_bits fails with 'Uniform::new called with `low >= high` -// therfore to still have a range of 0..2^range_bits we need on a mod it by 2^range_bits -// note k > lookup_bits prop_compose! { - fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u32) - (range_bits in 2..=max_range_bits, k in k_lo..=k_hi) - (k in Just(k), lookup_bits in min_lookup_bits..(k-3), a in rand_fr_range(0, range_bits), - range_bits in Just(range_bits)) + fn lookup_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) + (k in k_lo..=k_hi) + (k in Just(k), lookup_bits in min_lookup_bits..k) + -> (usize, usize) { + (k, lookup_bits) + } +} +// k is in [k_lo, k_hi] +// lookup_bits is in [min_lookup_bits, k-1] +prop_compose! { + fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u64) + ((k, lookup_bits) in lookup_strat((k_lo,k_hi), min_lookup_bits), range_bits in 2..=max_range_bits) + (k in Just(k), lookup_bits in Just(lookup_bits), a in rand_fr_range(range_bits), range_bits in Just(range_bits)) -> (usize, usize, Fr, usize) { (k, lookup_bits, a, range_bits as usize) } @@ -56,25 +71,30 @@ prop_compose! { prop_compose! { fn check_less_than_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_num_bits: usize) - (num_bits in 2..max_num_bits, k in k_lo..=k_hi) - (k in Just(k), a in rand_witness_range(0, num_bits as u32), b in rand_witness_range(0, num_bits as u32), - num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k) - -> (usize, usize, QuantumCell, QuantumCell, usize) { + (num_bits in 2..max_num_bits, k in k_lo..=k_hi) + (k in Just(k), num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k, seed in any::()) + -> (usize, usize, Fr, Fr, usize) { + let mut rng = StdRng::seed_from_u64(seed); + let mut b = rng.sample(RandomBits::new(num_bits as u64)); + if b == BigUint::from(0u32) { + b = BigUint::from(1u32) + } + let a = rng.gen_biguint_below(&b); + let [a,b] = [a,b].map(|x| biguint_to_fe(&x)); (k, lookup_bits, a, b, num_bits) } } prop_compose! { fn check_less_than_safe_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) - (k in k_lo..=k_hi) - (k in Just(k), b in any::(), a in rand_fr(), lookup_bits in min_lookup_bits..k) - -> (usize, usize, Fr, u64) { + (k in k_lo..=k_hi, b in any::()) + (lookup_bits in min_lookup_bits..k, k in Just(k), a in 0..b, b in Just(b)) + -> (usize, usize, u64, u64) { (k, lookup_bits, a, b) } } proptest! { - // Flex Gate Positive Tests #[test] fn prop_test_add(input in vec(rand_witness(), 2)) { @@ -128,8 +148,7 @@ proptest! { #[test] fn prop_test_assert_bit(input in rand_fr()) { let ground_truth = input == Fr::one() || input == Fr::zero(); - let result = flex_gate::test_assert_bit(input).is_ok(); - prop_assert_eq!(result, ground_truth); + flex_gate::test_assert_bit(input, ground_truth); } // Note: due to unwrap after inversion this test will fail if the denominator is zero so we want to test for that. Therefore we do not filter for zero values. @@ -147,14 +166,18 @@ proptest! { #[test] fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { - let ground_truth = inner_product_ground_truth(&inputs); + let a = inputs.0.iter().map(|x| *x.value()).collect::>(); + let b = inputs.1.iter().map(|x| *x.value()).collect::>(); + let ground_truth = inner_product_ground_truth(&a, &b); let result = flex_gate::test_inner_product(inputs); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { - let ground_truth = inner_product_left_last_ground_truth(&inputs); + let a = inputs.0.iter().map(|x| *x.value()).collect::>(); + let b = inputs.1.iter().map(|x| *x.value()).collect::>(); + let ground_truth = inner_product_left_last_ground_truth(&a, &b); let result = flex_gate::test_inner_product_left_last(inputs); prop_assert_eq!(result, ground_truth); } @@ -205,21 +228,21 @@ proptest! { #[test] fn prop_test_idx_to_indicator(input in (rand_witness(), 1..=16_usize)) { let ground_truth = idx_to_indicator_ground_truth(input); - let result = flex_gate::test_idx_to_indicator((input.0, input.1)); + let result = flex_gate::test_idx_to_indicator(input.0, input.1); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_select_by_indicator(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { let ground_truth = select_by_indicator_ground_truth(&inputs); - let result = flex_gate::test_select_by_indicator(inputs); + let result = flex_gate::test_select_by_indicator(inputs.0, inputs.1); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_select_from_idx(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { let ground_truth = select_from_idx_ground_truth(&inputs); - let result = flex_gate::test_select_from_idx(inputs); + let result = flex_gate::test_select_from_idx(inputs.0, inputs.1); prop_assert_eq!(result, ground_truth); } @@ -248,86 +271,110 @@ proptest! { bits.push(tmp & 1); tmp /= 2; } - let result = flex_gate::test_num_to_bits((Fr::from(num), bits.len())); + let result = flex_gate::test_num_to_bits(num as usize, bits.len()); prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); } - /* #[test] - fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { + fn prop_test_pow_var(a in rand_fr(), num in any::()) { + let native_res = a.pow_vartime([num]); + let result = flex_gate::test_pow_var(a, BigUint::from(num), Fr::CAPACITY as usize); + prop_assert_eq!(result, native_res); } - */ + /* #[test] - fn prop_test_get_field_element(n in any::()) { - let ground_truth = get_field_element_ground_truth(n); - let result = flex_gate::test_get_field_element::(n); - prop_assert_eq!(result, ground_truth); + fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { } + */ // Range Check Property Tests #[test] - fn prop_test_is_less_than(a in rand_witness(), b in any::().prop_filter("not zero", |&x| x != 0), - lookup_bits in 4..=16_usize) { - let bits = std::cmp::max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); - let ground_truth = is_less_than_ground_truth((*a.value(), Fr::from(b))); - let result = range_gate::test_is_less_than(([a, Witness(Fr::from(b))], bits, lookup_bits)); + fn prop_test_is_less_than( + (k, lookup_bits)in lookup_strat((10,18),4), + bits in 1..Fr::CAPACITY as usize, + seed in any::() + ) { + // current is_less_than requires bits to not be too large + prop_assume!(((bits + lookup_bits - 1) / lookup_bits + 1) * lookup_bits <= Fr::CAPACITY as usize); + let mut rng = StdRng::seed_from_u64(seed); + let a = biguint_to_fe(&rng.sample(RandomBits::new(bits as u64))); + let b = biguint_to_fe(&rng.sample(RandomBits::new(bits as u64))); + let ground_truth = is_less_than_ground_truth((a, b)); + let result = range::test_is_less_than(k, lookup_bits, [Witness(a), Witness(b)], bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_is_less_than_safe(a in rand_fr().prop_filter("not zero", |&x| x != Fr::zero()), - b in any::().prop_filter("not zero", |&x| x != 0), - lookup_bits in 4..=16_usize) { - prop_assume!(fe_to_biguint(&a).bits() as usize <= bit_length(b)); + fn prop_test_is_less_than_safe( + (k, lookup_bits) in lookup_strat((10,18),4), + a in any::(), + b in any::(), + ) { + prop_assume!(bit_length(a) <= bit_length(b)); + let a = Fr::from(a); let ground_truth = is_less_than_ground_truth((a, Fr::from(b))); - let result = range_gate::test_is_less_than_safe((a, b, lookup_bits)); + let result = range::test_is_less_than_safe(k, lookup_bits, a, b); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_div_mod(inputs in (rand_witness().prop_filter("Non-zero num", |x| *x.value() != Fr::zero()), any::().prop_filter("Non-zero divisor", |x| *x != 0u64), 1..=16_usize)) { - let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); - let result = range_gate::test_div_mod((inputs.0, inputs.1, inputs.2)); + fn prop_test_div_mod( + a in rand_witness(), + b in any::().prop_filter("Non-zero divisor", |x| *x != 0u64) + ) { + let ground_truth = div_mod_ground_truth((*a.value(), b)); + let num_bits = max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); + prop_assume!(num_bits <= Fr::CAPACITY as usize); + let result = range::test_div_mod(a, b, num_bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_get_last_bit(input in rand_fr(), pad_bits in 0..10usize) { - let ground_truth = get_last_bit_ground_truth(input); - let bits = fe_to_biguint(&input).bits() as usize + pad_bits; - let result = range_gate::test_get_last_bit((input, bits)); + fn prop_test_get_last_bit(bits in 1..Fr::CAPACITY as usize, pad_bits in 0..10usize, seed in any::()) { + prop_assume!(bits + pad_bits <= Fr::CAPACITY as usize); + let mut rng = StdRng::seed_from_u64(seed); + let a = rng.sample(RandomBits::new(bits as u64)); + let a = biguint_to_fe(&a); + let ground_truth = get_last_bit_ground_truth(a); + let bits = bits + pad_bits; + let result = range::test_get_last_bit(a, bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_div_mod_var(inputs in (rand_witness(), any::(), 1..=16_usize, 1..=16_usize)) { - let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); - let result = range_gate::test_div_mod_var((inputs.0, Witness(Fr::from(inputs.1)), inputs.2, inputs.3)); + fn prop_test_div_mod_var(a in rand_fr(), b in any::()) { + let ground_truth = div_mod_ground_truth((a, b)); + let a_num_bits = fe_to_biguint(&a).bits() as usize; + let lookup_bits = 9; + prop_assume!((a_num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + let b_num_bits= bit_length(b); + let result = range::test_div_mod_var(Witness(a), Witness(Fr::from(b)), a_num_bits, b_num_bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,24), 3, 63)) { - prop_assert_eq!(range_gate::test_range_check(k, lookup_bits, a, range_bits), ()); + fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,22),3,253)) { + // current range check only works when range_bits isn't too big: + prop_assume!((range_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_range_check(k, lookup_bits, a, range_bits); } #[test] - fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((14,24), 3, 10)) { - prop_assume!(a.value() < b.value()); - prop_assert_eq!(range_gate::test_check_less_than(k, lookup_bits, a, b, num_bits), ()); + fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((10,18),8,253)) { + prop_assume!((num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_check_less_than(k, lookup_bits, Witness(a), Witness(b), num_bits); } #[test] - fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { - prop_assume!(a < Fr::from(b)); - prop_assert_eq!(range_gate::test_check_less_than_safe(k, lookup_bits, a, b), ()); + fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((10,18),3)) { + range::test_check_less_than_safe(k, lookup_bits, Fr::from(a), b); } #[test] - fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { - prop_assume!(a < Fr::from(b)); - prop_assert_eq!(range_gate::test_check_big_less_than_safe(k, lookup_bits, a, b), ()); + fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b, num_bits) in check_less_than_strat((18,22),8,253)) { + prop_assume!((num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_check_big_less_than_safe(k, lookup_bits, a, fe_to_biguint(&b)); } } diff --git a/halo2-base/src/gates/tests/range.rs b/halo2-base/src/gates/tests/range.rs new file mode 100644 index 00000000..d477d3f2 --- /dev/null +++ b/halo2-base/src/gates/tests/range.rs @@ -0,0 +1,108 @@ +use super::*; +use crate::utils::biguint_to_fe; +use crate::utils::testing::base_test; +use crate::QuantumCell::Witness; +use crate::{gates::range::RangeInstructions, QuantumCell}; +use num_bigint::BigUint; +use test_case::test_case; + +#[test_case(16, 10, Fr::zero(), 0; "range_check() 0 bits")] +#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] +pub fn test_range_check(k: usize, lookup_bits: usize, a_val: Fr, range_bits: usize) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a_val); + chip.range_check(ctx, a, range_bits); + }) +} + +#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] +pub fn test_check_less_than( + k: usize, + lookup_bits: usize, + a: QuantumCell, + b: QuantumCell, + num_bits: usize, +) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + chip.check_less_than(ctx, a, b, num_bits); + }) +} + +#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] +pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: u64) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + chip.check_less_than_safe(ctx, a, b); + }) +} + +#[test_case(10, 8, biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize; "check_big_less_than_safe() pos")] +pub fn test_check_big_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: BigUint) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + chip.check_big_less_than_safe(ctx, a, b) + }) +} + +#[test_case(10, 8, [6, 7].map(Fr::from).map(Witness), 3 => Fr::from(1); "is_less_than() pos")] +pub fn test_is_less_than( + k: usize, + lookup_bits: usize, + inputs: [QuantumCell; 2], + bits: usize, +) -> Fr { + base_test() + .k(k as u32) + .lookup_bits(lookup_bits) + .run(|ctx, chip| *chip.is_less_than(ctx, inputs[0], inputs[1], bits).value()) +} + +#[test_case(10, 8, Fr::from(2), 3 => Fr::from(1); "is_less_than_safe() pos")] +pub fn test_is_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: u64) -> Fr { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + let lt = chip.is_less_than_safe(ctx, a, b); + *lt.value() + }) +} + +#[test_case(10, 8, biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize => Fr::from(1); "is_big_less_than_safe() pos")] +pub fn test_is_big_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: BigUint) -> Fr { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + *chip.is_big_less_than_safe(ctx, a, b).value() + }) +} + +#[test_case(Witness(Fr::from(3)), 2, 2 => (Fr::from(1), Fr::from(1)) ; "div_mod(3, 2)")] +pub fn test_div_mod(a: QuantumCell, b: u64, num_bits: usize) -> (Fr, Fr) { + base_test().run(|ctx, chip| { + let a = chip.div_mod(ctx, a, b, num_bits); + (*a.0.value(), *a.1.value()) + }) +} + +#[test_case(Fr::from(3), 8 => Fr::one() ; "get_last_bit(): 3, 8 bits")] +#[test_case(Fr::from(3), 2 => Fr::one() ; "get_last_bit(): 3, 2 bits")] +#[test_case(Fr::from(0), 2 => Fr::zero() ; "get_last_bit(): 0")] +#[test_case(Fr::from(1), 2 => Fr::one() ; "get_last_bit(): 1")] +#[test_case(Fr::from(2), 2 => Fr::zero() ; "get_last_bit(): 2")] +pub fn test_get_last_bit(a: Fr, bits: usize) -> Fr { + base_test().run(|ctx, chip| { + let a = ctx.load_witness(a); + *chip.get_last_bit(ctx, a, bits).value() + }) +} + +#[test_case(Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3 => (Fr::one(), Fr::one()); "div_mod_var(3 ,2)")] +pub fn test_div_mod_var( + a: QuantumCell, + b: QuantumCell, + a_num_bits: usize, + b_num_bits: usize, +) -> (Fr, Fr) { + base_test().run(|ctx, chip| { + let a = chip.div_mod_var(ctx, a, b, a_num_bits, b_num_bits); + (*a.0.value(), *a.1.value()) + }) +} diff --git a/halo2-base/src/gates/tests/range_gate.rs b/halo2-base/src/gates/tests/range_gate.rs deleted file mode 100644 index cd8acf52..00000000 --- a/halo2-base/src/gates/tests/range_gate.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::env::set_var; - -use super::*; -use crate::halo2_proofs::dev::MockProver; -use crate::utils::{biguint_to_fe, ScalarField}; -use crate::QuantumCell::Witness; -use crate::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - range::{RangeChip, RangeInstructions}, - }, - utils::BigPrimeField, - QuantumCell, -}; -use num_bigint::BigUint; -use test_case::test_case; - -#[test_case(16, 10, Fr::zero(), 0; "range_check() 0 bits")] -#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] -pub fn test_range_check(k: usize, lookup_bits: usize, a_val: F, range_bits: usize) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.range_check(ctx, a, range_bits); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] -pub fn test_check_less_than( - k: usize, - lookup_bits: usize, - a: QuantumCell, - b: QuantumCell, - num_bits: usize, -) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - chip.check_less_than(ctx, a, b, num_bits); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] -pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a_val: F, b: u64) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.check_less_than_safe(ctx, a, b); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(10, 8, Fr::zero(), 1; "check_big_less_than_safe() pos")] -pub fn test_check_big_less_than_safe( - k: usize, - lookup_bits: usize, - a_val: F, - b: u64, -) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.check_big_less_than_safe(ctx, a, BigUint::from(b)); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(([0, 1].map(Fr::from).map(Witness), 3, 12) => Fr::from(1) ; "is_less_than() pos")] -pub fn test_is_less_than( - (inputs, bits, lookup_bits): ([QuantumCell; 2], usize, usize), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = chip.is_less_than(ctx, inputs[0], inputs[1], bits); - *a.value() -} - -#[test_case((Fr::zero(), 3, 3) => Fr::from(1) ; "is_less_than_safe() pos")] -pub fn test_is_less_than_safe((a, b, lookup_bits): (F, u64, usize)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.load_witness(a); - let lt = chip.is_less_than_safe(ctx, a, b); - *lt.value() -} - -#[test_case((biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize, 8) => Fr::from(1) ; "is_big_less_than_safe() pos")] -pub fn test_is_big_less_than_safe( - (a, b, lookup_bits): (F, BigUint, usize), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.load_witness(a); - let b = chip.is_big_less_than_safe(ctx, a, b); - *b.value() -} - -#[test_case((Witness(Fr::one()), 1, 2) => (Fr::one(), Fr::zero()) ; "div_mod() pos")] -pub fn test_div_mod( - inputs: (QuantumCell, u64, usize), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = chip.div_mod(ctx, inputs.0, BigUint::from(inputs.1), inputs.2); - (*a.0.value(), *a.1.value()) -} - -#[test_case((Fr::from(3), 8) => Fr::one() ; "get_last_bit(): 3, 8 bits")] -#[test_case((Fr::from(3), 2) => Fr::one() ; "get_last_bit(): 3, 2 bits")] -#[test_case((Fr::from(0), 2) => Fr::zero() ; "get_last_bit(): 0")] -#[test_case((Fr::from(1), 2) => Fr::one() ; "get_last_bit(): 1")] -#[test_case((Fr::from(2), 2) => Fr::zero() ; "get_last_bit(): 2")] -pub fn test_get_last_bit((a, bits): (F, usize)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = ctx.load_witness(a); - let b = chip.get_last_bit(ctx, a, bits); - *b.value() -} - -#[test_case((Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3) => (Fr::one(), Fr::one()) ; "div_mod_var() pos")] -pub fn test_div_mod_var( - inputs: (QuantumCell, QuantumCell, usize, usize), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = chip.div_mod_var(ctx, inputs.0, inputs.1, inputs.2, inputs.3); - (*a.0.value(), *a.1.value()) -} diff --git a/halo2-base/src/gates/tests/utils.rs b/halo2-base/src/gates/tests/utils.rs index 34e2a435..2b8eb10a 100644 --- a/halo2-base/src/gates/tests/utils.rs +++ b/halo2-base/src/gates/tests/utils.rs @@ -36,28 +36,20 @@ pub fn mul_add_ground_truth(inputs: &[QuantumCell]) -> F { } pub fn mul_not_ground_truth(inputs: &[QuantumCell]) -> F { - (F::one() - *inputs[0].value()) * *inputs[1].value() + (F::ONE - *inputs[0].value()) * *inputs[1].value() } pub fn div_unsafe_ground_truth(inputs: &[QuantumCell]) -> F { inputs[1].value().invert().unwrap() * *inputs[0].value() } -pub fn inner_product_ground_truth( - inputs: &(Vec>, Vec>), -) -> F { - inputs - .0 - .iter() - .zip(inputs.1.iter()) - .fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b.value())) -} - -pub fn inner_product_left_last_ground_truth( - inputs: &(Vec>, Vec>), -) -> (F, F) { - let product = inner_product_ground_truth(inputs); - let last = *inputs.0.last().unwrap().value(); +pub fn inner_product_ground_truth(a: &[F], b: &[F]) -> F { + a.iter().zip(b.iter()).fold(F::ZERO, |acc, (&a, &b)| acc + a * b) +} + +pub fn inner_product_left_last_ground_truth(a: &[F], b: &[F]) -> (F, F) { + let product = inner_product_ground_truth(a, b); + let last = *a.last().unwrap(); (product, last) } @@ -66,7 +58,7 @@ pub fn inner_product_with_sums_ground_truth( ) -> Vec { let (a, b) = &input; let mut result = Vec::new(); - let mut sum = F::zero(); + let mut sum = F::ZERO; // TODO: convert to fold for (ai, bi) in a.iter().zip(b) { let product = *ai.value() * *bi.value(); @@ -79,9 +71,10 @@ pub fn inner_product_with_sums_ground_truth( pub fn sum_products_with_coeff_and_var_ground_truth( input: &(Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), ) -> F { - let expected = input.0.iter().fold(F::zero(), |acc, (coeff, cell1, cell2)| { - acc + *coeff * *cell1.value() * *cell2.value() - }) + *input.1.value(); + let expected = + input.0.iter().fold(F::ZERO, |acc, (coeff, cell1, cell2)| { + acc + *coeff * *cell1.value() * *cell2.value() + }) + *input.1.value(); expected } @@ -90,7 +83,7 @@ pub fn and_ground_truth(inputs: &[QuantumCell]) -> F { } pub fn not_ground_truth(a: &QuantumCell) -> F { - F::one() - *a.value() + F::ONE - *a.value() } pub fn select_ground_truth(inputs: &[QuantumCell]) -> F { @@ -104,7 +97,7 @@ pub fn or_and_ground_truth(inputs: &[QuantumCell]) -> F { pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, usize)) -> Vec { let (idx, size) = inputs; - let mut indicator = vec![F::zero(); size]; + let mut indicator = vec![F::ZERO; size]; let mut idx_value = size + 1; for i in 0..size as u64 { if F::from(i) == *idx.value() { @@ -113,7 +106,7 @@ pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, us } } if idx_value < size { - indicator[idx_value] = F::one(); + indicator[idx_value] = F::ONE; } indicator } @@ -122,7 +115,7 @@ pub fn select_by_indicator_ground_truth( inputs: &(Vec>, QuantumCell), ) -> F { let mut idx_value = inputs.0.len() + 1; - let mut indicator = vec![F::zero(); inputs.0.len()]; + let mut indicator = vec![F::ZERO; inputs.0.len()]; for i in 0..inputs.0.len() as u64 { if F::from(i) == *inputs.1.value() { idx_value = i as usize; @@ -130,10 +123,10 @@ pub fn select_by_indicator_ground_truth( } } if idx_value < inputs.0.len() { - indicator[idx_value] = F::one(); + indicator[idx_value] = F::ONE; } // take cross product of indicator and inputs.0 - inputs.0.iter().zip(indicator.iter()).fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b)) + inputs.0.iter().zip(indicator.iter()).fold(F::ZERO, |acc, (a, b)| acc + (*a.value() * *b)) } pub fn select_from_idx_ground_truth( @@ -146,22 +139,22 @@ pub fn select_from_idx_ground_truth( return *inputs.0[i as usize].value(); } } - F::zero() + F::ZERO } pub fn is_zero_ground_truth(x: F) -> F { if x.is_zero().into() { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } pub fn is_equal_ground_truth(inputs: &[QuantumCell]) -> F { if inputs[0].value() == inputs[1].value() { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } @@ -170,17 +163,13 @@ pub fn lagrange_eval_ground_truth(inputs: &[F]) -> (F, F) { } */ -pub fn get_field_element_ground_truth(n: u64) -> F { - F::from(n) -} - // Range Chip Ground Truths pub fn is_less_than_ground_truth(inputs: (F, F)) -> F { if inputs.0 < inputs.1 { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index e5c31636..1b922913 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,14 +1,18 @@ //! Base library to build Halo2 circuits. -#![allow(incomplete_features)] #![feature(generic_const_exprs)] -#![feature(const_cmp)] #![feature(stmt_expr_attributes)] #![feature(trait_alias)] +#![feature(associated_type_defaults)] +#![allow(incomplete_features)] #![deny(clippy::perf)] #![allow(clippy::too_many_arguments)] #![warn(clippy::default_numeric_fallback)] #![warn(missing_docs)] +use std::any::TypeId; + +use getset::CopyGetters; +use itertools::Itertools; // Different memory allocator options: #[cfg(feature = "jemallocator")] use jemallocator::Jemalloc; @@ -36,15 +40,21 @@ pub use halo2_proofs; #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom as halo2_proofs; +use halo2_proofs::halo2curves::ff; use halo2_proofs::plonk::Assigned; use utils::ScalarField; +use virtual_region::copy_constraints::SharedCopyConstraintManager; /// Module that contains the main API for creating and working with circuits. +/// `gates` is misleading because we currently only use one custom gate throughout. pub mod gates; +/// Module for the Poseidon hash function. +pub mod poseidon; /// Module for SafeType which enforce value range and realted functions. pub mod safe_types; /// Utility functions for converting between different types of field elements. pub mod utils; +pub mod virtual_region; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-axiom")] @@ -94,20 +104,32 @@ impl QuantumCell { } } +/// Unique tag for a context across all virtual regions +pub type ContextTag = (TypeId, usize); + /// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ContextCell { + /// The [TypeId] of the virtual region that this cell belongs to. + pub type_id: TypeId, /// Identifier of the [Context] that this cell belongs to. pub context_id: usize, /// Relative offset of the cell within this [Context] advice column. pub offset: usize, } +impl ContextCell { + /// Creates a new [ContextCell] with the given `type_id`, `context_id`, and `offset`. + pub fn new(type_id: TypeId, context_id: usize, offset: usize) -> Self { + Self { type_id, context_id, offset } + } +} + /// Pointer containing cell value and location within [Context]. /// /// Note: Performs a copy of the value, should only be used when you are about to assign the value again elsewhere. #[derive(Clone, Copy, Debug)] -pub struct AssignedValue { +pub struct AssignedValue { /// Value of the cell. pub value: Assigned, // we don't use reference to avoid issues with lifetimes (you can't safely borrow from vector and push to it at the same time). // only needed during vkey, pkey gen to fetch the actual cell from the relevant context @@ -125,34 +147,44 @@ impl AssignedValue { _ => unreachable!(), // if trying to fetch an un-evaluated fraction, you will have to do something manual } } + + /// Debug helper function for writing negative tests. This will change the **witness** value in `ctx` corresponding to `self.offset`. + /// This assumes that `ctx` is the context that `self` lies in. + pub fn debug_prank(&self, ctx: &mut Context, prank_value: F) { + ctx.advice[self.cell.unwrap().offset] = Assigned::Trivial(prank_value); + } +} + +impl AsRef> for AssignedValue { + fn as_ref(&self) -> &AssignedValue { + self + } } /// Represents a single thread of an execution trace. /// * We keep the naming [Context] for historical reasons. -#[derive(Clone, Debug)] +/// +/// [Context] is CPU thread-local. +#[derive(Clone, Debug, CopyGetters)] pub struct Context { /// Flag to determine whether only witness generation or proving and verification key generation is being performed. /// * If witness gen is performed many operations can be skipped for optimization. + #[getset(get_copy = "pub")] witness_gen_only: bool, - + /// The challenge phase that this [Context] will map to. + #[getset(get_copy = "pub")] + phase: usize, + /// Identifier for what virtual region this context is in + #[getset(get_copy = "pub")] + type_id: TypeId, /// Identifier to reference cells from this [Context]. - pub context_id: usize, + context_id: usize, /// Single column of advice cells. pub advice: Vec>, - /// [Vec] tracking all cells that lookup is enabled for. - /// * When there is more than 1 advice column all `advice` cells will be copied to a single lookup enabled column to perform lookups. - pub cells_to_lookup: Vec>, - - /// Cell that represents the zero value as AssignedValue - pub zero_cell: Option>, - - // To save time from re-allocating new temporary vectors that get quickly dropped (e.g., for some range checks), we keep a vector with high capacity around that we `clear` before use each time - // This is NOT THREAD SAFE - // Need to use RefCell to avoid borrow rules - // Need to use Rc to borrow this and mutably borrow self at same time - // preallocated_vec_to_assign: Rc>>>, + /// Slight optimization: since zero is so commonly used, keep a reference to the zero cell. + zero_cell: Option>, // ======================================== // General principle: we don't need to optimize anything specific to `witness_gen_only == false` because it is only done during keygen @@ -161,38 +193,45 @@ pub struct Context { /// * Assumed to have the same length as `advice` pub selector: Vec, - // TODO: gates that use fixed columns as selectors? - /// A [Vec] tracking equality constraints between pairs of [Context] `advice` cells. - /// - /// Assumes both `advice` cells are in the same [Context]. - pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, - - /// A [Vec] tracking pairs equality constraints between Fixed values and [Context] `advice` cells. - /// - /// Assumes the constant and `advice` cell are in the same [Context]. - pub constant_equality_constraints: Vec<(F, ContextCell)>, + /// Global shared thread-safe manager for all copy (equality) constraints between virtual advice, constants, and raw external Halo2 cells. + pub copy_manager: SharedCopyConstraintManager, } impl Context { /// Creates a new [Context] with the given `context_id` and witness generation enabled/disabled by the `witness_gen_only` flag. /// * `witness_gen_only`: flag to determine whether public key generation or only witness generation is being performed. /// * `context_id`: identifier to reference advice cells from this [Context] later. - pub fn new(witness_gen_only: bool, context_id: usize) -> Self { + pub fn new( + witness_gen_only: bool, + phase: usize, + type_id: TypeId, + context_id: usize, + copy_manager: SharedCopyConstraintManager, + ) -> Self { Self { witness_gen_only, + phase, + type_id, context_id, advice: Vec::new(), - cells_to_lookup: Vec::new(), - zero_cell: None, selector: Vec::new(), - advice_equality_constraints: Vec::new(), - constant_equality_constraints: Vec::new(), + zero_cell: None, + copy_manager, } } - /// Returns the `witness_gen_only` flag of the [Context] - pub fn witness_gen_only(&self) -> bool { - self.witness_gen_only + /// The context id, this can be used as a tag when CPU multi-threading + pub fn id(&self) -> usize { + self.context_id + } + + /// A unique tag that should identify this context across all virtual regions and phases. + pub fn tag(&self) -> ContextTag { + (self.type_id, self.context_id) + } + + fn latest_cell(&self) -> ContextCell { + ContextCell::new(self.type_id, self.context_id, self.advice.len() - 1) } /// Pushes a [QuantumCell] to the end of the `advice` column ([Vec] of advice cells) in this [Context]. @@ -204,9 +243,12 @@ impl Context { self.advice.push(acell.value); // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + let new_cell = self.latest_cell(); + self.copy_manager + .lock() + .unwrap() + .advice_equalities + .push((new_cell, acell.cell.unwrap())); } } QuantumCell::Witness(val) => { @@ -219,9 +261,8 @@ impl Context { self.advice.push(Assigned::Trivial(c)); // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.constant_equality_constraints.push((c, new_cell)); + let new_cell = self.latest_cell(); + self.copy_manager.lock().unwrap().constant_equalities.push((c, new_cell)); } } } @@ -230,10 +271,7 @@ impl Context { /// Returns the [AssignedValue] of the last cell in the `advice` column of [Context] or [None] if `advice` is empty pub fn last(&self) -> Option> { self.advice.last().map(|v| { - let cell = (!self.witness_gen_only).then_some(ContextCell { - context_id: self.context_id, - offset: self.advice.len() - 1, - }); + let cell = (!self.witness_gen_only).then_some(self.latest_cell()); AssignedValue { value: *v, cell } }) } @@ -250,8 +288,11 @@ impl Context { offset as usize }; assert!(offset < self.advice.len()); - let cell = - (!self.witness_gen_only).then_some(ContextCell { context_id: self.context_id, offset }); + let cell = (!self.witness_gen_only).then_some(ContextCell::new( + self.type_id, + self.context_id, + offset, + )); AssignedValue { value: self.advice[offset], cell } } @@ -261,7 +302,11 @@ impl Context { /// * Assumes both cells are `advice` cells pub fn constrain_equal(&mut self, a: &AssignedValue, b: &AssignedValue) { if !self.witness_gen_only { - self.advice_equality_constraints.push((a.cell.unwrap(), b.cell.unwrap())); + self.copy_manager + .lock() + .unwrap() + .advice_equalities + .push((a.cell.unwrap(), b.cell.unwrap())); } } @@ -341,25 +386,28 @@ impl Context { if !self.witness_gen_only { // Add equality constraints between cells in the advice column. for (offset1, offset2) in equality_offsets { - self.advice_equality_constraints.push(( - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset1), - }, - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset2), - }, + self.copy_manager.lock().unwrap().advice_equalities.push(( + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset1), + ), + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset2), + ), )); } // Add equality constraints between cells in the advice column and external cells (Fixed column). for (cell, offset) in external_equality { - self.advice_equality_constraints.push(( + self.copy_manager.lock().unwrap().advice_equalities.push(( cell.unwrap(), - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset), - }, + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset), + ), )); } } @@ -377,8 +425,11 @@ impl Context { .iter() .enumerate() .map(|(i, v)| { - let cell = (!self.witness_gen_only) - .then_some(ContextCell { context_id: self.context_id, offset: row_offset + i }); + let cell = (!self.witness_gen_only).then_some(ContextCell::new( + self.type_id, + self.context_id, + row_offset + i, + )); AssignedValue { value: *v, cell } }) .collect() @@ -404,13 +455,29 @@ impl Context { self.last().unwrap() } + /// Assigns a list of constant values and returns the corresponding assigned cells. + /// * `c`: the list of constant values to be assigned + pub fn load_constants(&mut self, c: &[F]) -> Vec> { + c.iter().map(|v| self.load_constant(*v)).collect_vec() + } + /// Assigns the 0 value to a new cell or returns a previously assigned zero cell from `zero_cell`. pub fn load_zero(&mut self) -> AssignedValue { if let Some(zcell) = &self.zero_cell { return *zcell; } - let zero_cell = self.load_constant(F::zero()); + let zero_cell = self.load_constant(F::ZERO); self.zero_cell = Some(zero_cell); zero_cell } + + /// Helper function for debugging using `MockProver`. This adds a constraint that always fails. + /// The `MockProver` will print out the row, column where it fails, so it serves as a debugging "break point" + /// so you can add to your code to search for where the actual constraint failure occurs. + pub fn debug_assert_false(&mut self) { + use rand_chacha::rand_core::OsRng; + let rand1 = self.load_witness(F::random(OsRng)); + let rand2 = self.load_witness(F::random(OsRng)); + self.constrain_equal(&rand1, &rand2); + } } diff --git a/halo2-base/src/poseidon/hasher/mds.rs b/halo2-base/src/poseidon/hasher/mds.rs new file mode 100644 index 00000000..91b7d262 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/mds.rs @@ -0,0 +1,172 @@ +#![allow(clippy::needless_range_loop)] +use getset::Getters; + +use crate::ff::PrimeField; + +/// The type used to hold the MDS matrix +pub(crate) type Mds = [[F; T]; T]; + +/// `MDSMatrices` holds the MDS matrix as well as transition matrix which is +/// also called `pre_sparse_mds` and sparse matrices that enables us to reduce +/// number of multiplications in apply MDS step +#[derive(Debug, Clone, Getters)] +pub struct MDSMatrices { + /// MDS matrix + #[getset(get = "pub")] + pub(crate) mds: MDSMatrix, + /// Transition matrix + #[getset(get = "pub")] + pub(crate) pre_sparse_mds: MDSMatrix, + /// Sparse matrices + #[getset(get = "pub")] + pub(crate) sparse_matrices: Vec>, +} + +/// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear +/// layer of partial rounds instead of the original MDS +#[derive(Debug, Clone, Getters)] +pub struct SparseMDSMatrix { + /// row + #[getset(get = "pub")] + pub(crate) row: [F; T], + /// column transpose + #[getset(get = "pub")] + pub(crate) col_hat: [F; RATE], +} + +/// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon +#[derive(Clone, Debug)] +pub struct MDSMatrix(pub(crate) Mds); + +impl AsRef> for MDSMatrix { + fn as_ref(&self) -> &Mds { + &self.0 + } +} + +impl MDSMatrix { + pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { + let mut res = [F::ZERO; T]; + for i in 0..T { + for j in 0..T { + res[i] += self.0[i][j] * v[j]; + } + } + res + } + + pub(crate) fn identity() -> Mds { + let mut mds = [[F::ZERO; T]; T]; + for i in 0..T { + mds[i][i] = F::ONE; + } + mds + } + + /// Multiplies two MDS matrices. Used in sparse matrix calculations + pub(crate) fn mul(&self, other: &Self) -> Self { + let mut res = [[F::ZERO; T]; T]; + for i in 0..T { + for j in 0..T { + for k in 0..T { + res[i][j] += self.0[i][k] * other.0[k][j]; + } + } + } + Self(res) + } + + pub(crate) fn transpose(&self) -> Self { + let mut res = [[F::ZERO; T]; T]; + for i in 0..T { + for j in 0..T { + res[i][j] = self.0[j][i]; + } + } + Self(res) + } + + pub(crate) fn determinant(m: [[F; N]; N]) -> F { + let mut res = F::ONE; + let mut m = m; + for i in 0..N { + let mut pivot = i; + while m[pivot][i] == F::ZERO { + pivot += 1; + assert!(pivot < N, "matrix is not invertible"); + } + if pivot != i { + res = -res; + m.swap(pivot, i); + } + res *= m[i][i]; + let inv = m[i][i].invert().unwrap(); + for j in i + 1..N { + let factor = m[j][i] * inv; + for k in i + 1..N { + m[j][k] -= m[i][k] * factor; + } + } + } + res + } + + /// See Section B in Supplementary Material https://eprint.iacr.org/2019/458.pdf + /// Factorises an MDS matrix `M` into `M'` and `M''` where `M = M' * M''`. + /// Resulted `M''` matrices are the sparse ones while `M'` will contribute + /// to the accumulator of the process + pub(crate) fn factorise(&self) -> (Self, SparseMDSMatrix) { + assert_eq!(RATE + 1, T); + // Given `(t-1 * t-1)` MDS matrix called `hat` constructs the `t * t` matrix in + // form `[[1 | 0], [0 | m]]`, ie `hat` is the right bottom sub-matrix + let prime = |hat: Mds| -> Self { + let mut prime = Self::identity(); + for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) { + for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) { + *el_prime = *el_hat; + } + } + Self(prime) + }; + + // Given `(t-1)` sized `w_hat` vector constructs the matrix in form + // `[[m_0_0 | m_0_i], [w_hat | identity]]` + let prime_prime = |w_hat: [F; RATE]| -> Mds { + let mut prime_prime = Self::identity(); + prime_prime[0] = self.0[0]; + for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) { + row[0] = *w + } + prime_prime + }; + + let w = self.0.iter().skip(1).map(|row| row[0]).collect::>(); + // m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0 + let mut m_hat = [[F::ZERO; RATE]; RATE]; + for i in 0..RATE { + for j in 0..RATE { + m_hat[i][j] = self.0[i + 1][j + 1]; + } + } + // w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult + // we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule + let mut w_hat = [F::ZERO; RATE]; + let det = Self::determinant(m_hat); + let det_inv = Option::::from(det.invert()).expect("matrix is not invertible"); + for j in 0..RATE { + let mut m_hat_j = m_hat; + for i in 0..RATE { + m_hat_j[i][j] = w[i]; + } + w_hat[j] = Self::determinant(m_hat_j) * det_inv; + } + let m_prime = prime(m_hat); + let m_prime_prime = prime_prime(w_hat); + // row = first row of m_prime_prime.transpose() = first column of m_prime_prime + let row: [F; T] = + m_prime_prime.iter().map(|row| row[0]).collect::>().try_into().unwrap(); + // col_hat = first column of m_prime_prime.transpose() without first element = first row of m_prime_prime without first element + let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap(); + (m_prime, SparseMDSMatrix { row, col_hat }) + } +} diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs new file mode 100644 index 00000000..10a03034 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -0,0 +1,352 @@ +use crate::{ + gates::{GateInstructions, RangeInstructions}, + poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState}, + safe_types::{SafeBool, SafeTypeChip}, + utils::BigPrimeField, + AssignedValue, Context, + QuantumCell::Constant, + ScalarField, +}; + +use getset::Getters; +use num_bigint::BigUint; +use std::{cell::OnceCell, mem}; + +#[cfg(test)] +mod tests; + +/// Module for maximum distance separable matrix operations. +pub mod mds; +/// Module for poseidon specification. +pub mod spec; +/// Module for poseidon states. +pub mod state; + +/// Stateless Poseidon hasher. +#[derive(Clone, Debug)] +pub struct PoseidonHasher { + spec: OptimizedPoseidonSpec, + consts: OnceCell>, +} +#[derive(Clone, Debug, Getters)] +struct PoseidonHasherConsts { + #[getset(get = "pub")] + init_state: PoseidonState, + // hash of an empty input(""). + #[getset(get = "pub")] + empty_hash: AssignedValue, +} + +impl PoseidonHasherConsts { + pub fn new( + ctx: &mut Context, + gate: &impl GateInstructions, + spec: &OptimizedPoseidonSpec, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let mut state = init_state.clone(); + let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec); + Self { init_state, empty_hash } + } +} + +/// 1 logical row of compact input for Poseidon hasher. +#[derive(Copy, Clone, Debug)] +pub struct PoseidonCompactInput { + // Right padded inputs. No constrains on paddings. + inputs: [AssignedValue; RATE], + // is_final = 1 triggers squeeze. + is_final: SafeBool, + // Length of `inputs`. + len: AssignedValue, +} + +impl PoseidonCompactInput { + /// Create a new PoseidonCompactInput. + pub fn new( + inputs: [AssignedValue; RATE], + is_final: SafeBool, + len: AssignedValue, + ) -> Self { + Self { inputs, is_final, len } + } + + /// Add data validation constraints. + pub fn add_validation_constraints( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + ) { + range.is_less_than_safe(ctx, self.len, (RATE + 1) as u64); + // Invalid case: (!is_final && len != RATE) ==> !(is_final || len == RATE) + let is_full: AssignedValue = + range.gate().is_equal(ctx, self.len, Constant(F::from(RATE as u64))); + let invalid_cond = range.gate().or(ctx, *self.is_final.as_ref(), is_full); + range.gate().assert_is_const(ctx, &invalid_cond, &F::ZERO); + } +} + +/// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk. +#[derive(Clone, Debug)] +pub struct PoseidonCompactChunkInput { + // Inputs of a chunk. All witnesses will be absorbed. + inputs: Vec<[AssignedValue; RATE]>, + // is_final = 1 triggers squeeze. + is_final: SafeBool, +} + +impl PoseidonCompactChunkInput { + /// Create a new PoseidonCompactInput. + pub fn new(inputs: Vec<[AssignedValue; RATE]>, is_final: SafeBool) -> Self { + Self { inputs, is_final } + } +} + +/// 1 logical row of compact output for Poseidon hasher. +#[derive(Copy, Clone, Debug, Getters)] +pub struct PoseidonCompactOutput { + /// hash of 1 logical input. + #[getset(get = "pub")] + hash: AssignedValue, + /// is_final = 1 ==> this is the end of a logical input. + #[getset(get = "pub")] + is_final: SafeBool, +} + +impl PoseidonHasher { + /// Create a poseidon hasher from an existing spec. + pub fn new(spec: OptimizedPoseidonSpec) -> Self { + Self { spec, consts: OnceCell::new() } + } + /// Initialize necessary consts of hasher. Must be called before any computation. + pub fn initialize_consts(&mut self, ctx: &mut Context, gate: &impl GateInstructions) { + self.consts.get_or_init(|| PoseidonHasherConsts::::new(ctx, gate, &self.spec)); + } + + /// Clear all consts. + pub fn clear(&mut self) { + self.consts.take(); + } + + fn empty_hash(&self) -> &AssignedValue { + self.consts.get().unwrap().empty_hash() + } + fn init_state(&self) -> &PoseidonState { + self.consts.get().unwrap().init_state() + } + + /// Constrains and returns hash of a witness array with a variable length. + /// + /// Assumes `len` is within [usize] and `len <= inputs.len()`. + /// * inputs: An right-padded array of [AssignedValue]. Constraints on paddings are not required. + /// * len: Length of `inputs`. + /// Return hash of `inputs`. + pub fn hash_var_len_array( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + inputs: &[AssignedValue], + len: AssignedValue, + ) -> AssignedValue + where + F: BigPrimeField, + { + // TODO: rewrite this using hash_compact_input. + let max_len = inputs.len(); + if max_len == 0 { + return *self.empty_hash(); + }; + + // len <= max_len --> num_of_bits(len) <= num_of_bits(max_len) + let num_bits = (usize::BITS - max_len.leading_zeros()) as usize; + // num_perm = len // RATE + 1, len_last_chunk = len % RATE + let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits); + num_perm = range.gate().inc(ctx, num_perm); + + let mut state = self.init_state().clone(); + let mut result_state = state.clone(); + for (i, chunk) in inputs.chunks(RATE).enumerate() { + let is_last_perm = + range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64))); + let len_chunk = range.gate().select( + ctx, + len_last_chunk, + Constant(F::from(RATE as u64)), + is_last_perm, + ); + + state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + if max_len % RATE == 0 { + let is_last_perm = range.gate().is_equal( + ctx, + num_perm, + Constant(F::from((max_len / RATE + 1) as u64)), + ); + let len_chunk = ctx.load_zero(); + state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + result_state.s[1] + } + + /// Constrains and returns hash of a witness array. + /// + /// * inputs: An array of [AssignedValue]. + /// Return hash of `inputs`. + pub fn hash_fix_len_array( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + ) -> AssignedValue + where + F: BigPrimeField, + { + let mut state = self.init_state().clone(); + fix_len_array_squeeze(ctx, gate, inputs, &mut state, &self.spec) + } + + /// Constrains and returns hashes of inputs in a compact format. Length of `compact_inputs` should be determined at compile time. + pub fn hash_compact_input( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + compact_inputs: &[PoseidonCompactInput], + ) -> Vec> + where + F: BigPrimeField, + { + let mut outputs = Vec::with_capacity(compact_inputs.len()); + let mut state = self.init_state().clone(); + for input in compact_inputs { + // Assume this is the last row of a logical input: + // Depending on if len == RATE. + let is_full = gate.is_equal(ctx, input.len, Constant(F::from(RATE as u64))); + // Case 1: if len != RATE. + state.permutation(ctx, gate, &input.inputs, Some(input.len), &self.spec); + // Case 2: if len == RATE, an extra permuation is needed for squeeze. + let mut state_2 = state.clone(); + state_2.permutation(ctx, gate, &[], None, &self.spec); + // Select the result of case 1/2 depending on if len == RATE. + let hash = gate.select(ctx, state_2.s[1], state.s[1], is_full); + outputs.push(PoseidonCompactOutput { hash, is_final: input.is_final }); + // Reset state to init_state if this is the end of a logical input. + // TODO: skip this if this is the last row. + state.select(ctx, gate, input.is_final, self.init_state()); + } + outputs + } + + /// Constrains and returns hashes of chunk inputs in a compact format. Length of `chunk_inputs` should be determined at compile time. + pub fn hash_compact_chunk_inputs( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + chunk_inputs: &[PoseidonCompactChunkInput], + ) -> Vec> + where + F: BigPrimeField, + { + let zero_witness = ctx.load_zero(); + let mut outputs = Vec::with_capacity(chunk_inputs.len()); + let mut state = self.init_state().clone(); + for chunk_input in chunk_inputs { + let is_final = chunk_input.is_final; + for absorb in &chunk_input.inputs { + state.permutation(ctx, gate, absorb, None, &self.spec); + } + // Because the length of each absorb is always RATE. An extra permutation is needed for squeeze. + let mut output_state = state.clone(); + output_state.permutation(ctx, gate, &[], None, &self.spec); + let hash = gate.select(ctx, output_state.s[1], zero_witness, *is_final.as_ref()); + outputs.push(PoseidonCompactOutput { hash, is_final }); + // Reset state to init_state if this is the end of a logical input. + state.select(ctx, gate, is_final, self.init_state()); + } + outputs + } +} + +/// Poseidon sponge. This is stateful. +pub struct PoseidonSponge { + init_state: PoseidonState, + state: PoseidonState, + spec: OptimizedPoseidonSpec, + absorbing: Vec>, +} + +impl PoseidonSponge { + /// Create new Poseidon hasher. + pub fn new( + ctx: &mut Context, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let state = init_state.clone(); + Self { + init_state, + state, + spec: OptimizedPoseidonSpec::new::(), + absorbing: Vec::new(), + } + } + + /// Initialize a poseidon hasher from an existing spec. + pub fn from_spec(ctx: &mut Context, spec: OptimizedPoseidonSpec) -> Self { + let init_state = PoseidonState::default(ctx); + Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() } + } + + /// Reset state to default and clear the buffer. + pub fn clear(&mut self) { + self.state = self.init_state.clone(); + self.absorbing.clear(); + } + + /// Store given `elements` into buffer. + pub fn update(&mut self, elements: &[AssignedValue]) { + self.absorbing.extend_from_slice(elements); + } + + /// Consume buffer and perform permutation, then output second element of + /// state. + pub fn squeeze( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> AssignedValue { + let input_elements = mem::take(&mut self.absorbing); + fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec) + } +} + +/// ATTETION: input_elements.len() needs to be fixed at compile time. +fn fix_len_array_squeeze( + ctx: &mut Context, + gate: &impl GateInstructions, + input_elements: &[AssignedValue], + state: &mut PoseidonState, + spec: &OptimizedPoseidonSpec, +) -> AssignedValue { + let exact = input_elements.len() % RATE == 0; + + for chunk in input_elements.chunks(RATE) { + state.permutation(ctx, gate, chunk, None, spec); + } + if exact { + state.permutation(ctx, gate, &[], None, spec); + } + + state.s[1] +} diff --git a/halo2-base/src/poseidon/hasher/spec.rs b/halo2-base/src/poseidon/hasher/spec.rs new file mode 100644 index 00000000..e0a0d2c9 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/spec.rs @@ -0,0 +1,176 @@ +use crate::{ + ff::{FromUniformBytes, PrimeField}, + poseidon::hasher::mds::*, +}; + +use getset::{CopyGetters, Getters}; +use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait +use std::marker::PhantomData; + +// struct so we can use PoseidonSpec trait to generate round constants and MDS matrix +#[derive(Debug)] +pub(crate) struct Poseidon128Pow5Gen< + F: PrimeField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, +> { + _marker: PhantomData, +} + +impl< + F: PrimeField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, + > PoseidonSpec for Poseidon128Pow5Gen +{ + fn full_rounds() -> usize { + R_F + } + + fn partial_rounds() -> usize { + R_P + } + + fn sbox(val: F) -> F { + val.pow_vartime([5]) + } + + // see "Avoiding insecure matrices" in Section 2.3 of https://eprint.iacr.org/2019/458.pdf + // most Specs used in practice have SECURE_MDS = 0 + fn secure_mds() -> usize { + SECURE_MDS + } +} + +// We use the optimized Poseidon implementation described in Supplementary Material Section B of https://eprint.iacr.org/2019/458.pdf +// This involves some further computation of optimized constants and sparse MDS matrices beyond what the Scroll PoseidonSpec generates +// The implementation below is adapted from https://github.com/privacy-scaling-explorations/poseidon + +/// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in +/// permutation step. +#[derive(Debug, Clone, Getters, CopyGetters)] +pub struct OptimizedPoseidonSpec { + /// Number of full rounds + #[getset(get_copy = "pub")] + pub(crate) r_f: usize, + /// MDS matrices + #[getset(get = "pub")] + pub(crate) mds_matrices: MDSMatrices, + /// Round constants + #[getset(get = "pub")] + pub(crate) constants: OptimizedConstants, +} + +/// `OptimizedConstants` has round constants that are added each round. While +/// full rounds has T sized constants there is a single constant for each +/// partial round +#[derive(Debug, Clone, Getters)] +pub struct OptimizedConstants { + /// start + #[getset(get = "pub")] + pub(crate) start: Vec<[F; T]>, + /// partial + #[getset(get = "pub")] + pub(crate) partial: Vec, + /// end + #[getset(get = "pub")] + pub(crate) end: Vec<[F; T]>, +} + +impl OptimizedPoseidonSpec { + /// Generate new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated + pub fn new() -> Self + where + F: FromUniformBytes<64> + Ord, + { + let (round_constants, mds, mds_inv) = + Poseidon128Pow5Gen::::constants(); + let mds = MDSMatrix(mds); + let inverse_mds = MDSMatrix(mds_inv); + + let constants = + Self::calculate_optimized_constants(R_F, R_P, round_constants, &inverse_mds); + let (sparse_matrices, pre_sparse_mds) = Self::calculate_sparse_matrices(R_P, &mds); + + Self { + r_f: R_F, + constants, + mds_matrices: MDSMatrices { mds, sparse_matrices, pre_sparse_mds }, + } + } + + fn calculate_optimized_constants( + r_f: usize, + r_p: usize, + constants: Vec<[F; T]>, + inverse_mds: &MDSMatrix, + ) -> OptimizedConstants { + let (number_of_rounds, r_f_half) = (r_f + r_p, r_f / 2); + assert_eq!(constants.len(), number_of_rounds); + + // Calculate optimized constants for first half of the full rounds + let mut constants_start: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half]; + constants_start[0] = constants[0]; + for (optimized, constants) in + constants_start.iter_mut().skip(1).zip(constants.iter().skip(1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + // Calculate constants for partial rounds + let mut acc = constants[r_f_half + r_p]; + let mut constants_partial = vec![F::ZERO; r_p]; + for (optimized, constants) in constants_partial + .iter_mut() + .rev() + .zip(constants.iter().skip(r_f_half).rev().skip(r_f_half)) + { + let mut tmp = inverse_mds.mul_vector(&acc); + *optimized = tmp[0]; + + tmp[0] = F::ZERO; + for ((acc, tmp), constant) in acc.iter_mut().zip(tmp).zip(constants.iter()) { + *acc = tmp + constant + } + } + constants_start.push(inverse_mds.mul_vector(&acc)); + + // Calculate optimized constants for ending half of the full rounds + let mut constants_end: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half - 1]; + for (optimized, constants) in + constants_end.iter_mut().zip(constants.iter().skip(r_f_half + r_p + 1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + OptimizedConstants { + start: constants_start, + partial: constants_partial, + end: constants_end, + } + } + + fn calculate_sparse_matrices( + r_p: usize, + mds: &MDSMatrix, + ) -> (Vec>, MDSMatrix) { + let mds = mds.transpose(); + let mut acc = mds.clone(); + let mut sparse_matrices = (0..r_p) + .map(|_| { + let (m_prime, m_prime_prime) = acc.factorise(); + acc = mds.mul(&m_prime); + m_prime_prime + }) + .collect::>>(); + + sparse_matrices.reverse(); + (sparse_matrices, acc.transpose()) + } +} diff --git a/halo2-base/src/poseidon/hasher/state.rs b/halo2-base/src/poseidon/hasher/state.rs new file mode 100644 index 00000000..5b8fd308 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/state.rs @@ -0,0 +1,251 @@ +use std::iter; + +use itertools::Itertools; + +use crate::{ + gates::GateInstructions, + poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec}, + safe_types::SafeBool, + utils::ScalarField, + AssignedValue, Context, + QuantumCell::{Constant, Existing}, +}; + +#[derive(Clone, Debug)] +pub(crate) struct PoseidonState { + pub(crate) s: [AssignedValue; T], +} + +impl PoseidonState { + pub fn default(ctx: &mut Context) -> Self { + let mut default_state = [F::ZERO; T]; + // from Section 4.2 of https://eprint.iacr.org/2019/458.pdf + // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length. + // for our transcript use cases, o = 1 + default_state[0] = F::from_u128(1u128 << 64); + Self { s: default_state.map(|f| ctx.load_constant(f)) } + } + + /// Perform permutation on this state. + /// + /// ATTETION: inputs.len() needs to be fixed at compile time. + /// Assume len <= inputs.len(). + /// `inputs` is right padded. + /// If `len` is `None`, treat `inputs` as a fixed length array. + pub fn permutation( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + len: Option>, + spec: &OptimizedPoseidonSpec, + ) { + let r_f = spec.r_f / 2; + let mds = &spec.mds_matrices.mds.0; + let pre_sparse_mds = &spec.mds_matrices.pre_sparse_mds.0; + let sparse_matrices = &spec.mds_matrices.sparse_matrices; + + // First half of the full round + let constants = &spec.constants.start; + if let Some(len) = len { + // Note: this doesn't mean `padded_inputs` is 0 padded because there is no constraints on `inputs[len..]` + let padded_inputs: [AssignedValue; RATE] = + core::array::from_fn( + |i| if i < inputs.len() { inputs[i] } else { ctx.load_zero() }, + ); + self.absorb_var_len_with_pre_constants(ctx, gate, padded_inputs, len, &constants[0]); + } else { + self.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); + } + for constants in constants.iter().skip(1).take(r_f - 1) { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, constants.last().unwrap()); + self.apply_mds(ctx, gate, pre_sparse_mds); + + // Partial rounds + let constants = &spec.constants.partial; + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.sbox_part(ctx, gate, constant); + self.apply_sparse_mds(ctx, gate, sparse_mds); + } + + // Second half of the full rounds + let constants = &spec.constants.end; + for constants in constants.iter() { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, &[F::ZERO; T]); + self.apply_mds(ctx, gate, mds); + } + + /// Constrains and set self to a specific state if `selector` is true. + pub fn select( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + selector: SafeBool, + set_to: &Self, + ) { + for i in 0..T { + self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref()); + } + } + + fn x_power5_with_constant( + ctx: &mut Context, + gate: &impl GateInstructions, + x: AssignedValue, + constant: &F, + ) -> AssignedValue { + let x2 = gate.mul(ctx, x, x); + let x4 = gate.mul(ctx, x2, x2); + gate.mul_add(ctx, x, x4, Constant(*constant)) + } + + fn sbox_full( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + constants: &[F; T], + ) { + for (x, constant) in self.s.iter_mut().zip(constants.iter()) { + *x = Self::x_power5_with_constant(ctx, gate, *x, constant); + } + } + + fn sbox_part(&mut self, ctx: &mut Context, gate: &impl GateInstructions, constant: &F) { + let x = &mut self.s[0]; + *x = Self::x_power5_with_constant(ctx, gate, *x, constant); + } + + fn absorb_with_pre_constants( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + pre_constants: &[F; T], + ) { + assert!(inputs.len() < T); + + // Explanation of what's going on: before each round of the poseidon permutation, + // two things have to be added to the state: inputs (the absorbed elements) and + // preconstants. Imagine the state as a list of T elements, the first of which is + // the capacity: |--cap--|--el1--|--el2--|--elR--| + // - A preconstant is added to each of all T elements (which is different for each) + // - The inputs are added to all elements starting from el1 (so, not to the capacity), + // to as many elements as inputs are available. + // - To the first element for which no input is left (if any), an extra 1 is added. + + // adding preconstant to the distinguished capacity element (only one) + self.s[0] = gate.add(ctx, self.s[0], Constant(pre_constants[0])); + + // adding pre-constants and inputs to the elements for which both are available + for ((x, constant), input) in + self.s.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs.iter()) + { + *x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]); + } + + let offset = inputs.len() + 1; + // adding only pre-constants when no input is left + for (i, (x, constant)) in + self.s.iter_mut().zip(pre_constants.iter()).skip(offset).enumerate() + { + *x = gate.add(ctx, *x, Constant(if i == 0 { F::ONE + constant } else { *constant })); + // the if idx == 0 { F::one() } else { F::zero() } is to pad the input with a single 1 and then 0s + // this is the padding suggested in pg 31 of https://eprint.iacr.org/2019/458.pdf and in Section 4.2 (Variable-Input-Length Hashing. The padding consists of one field element being 1, and the remaining elements being 0.) + } + } + + /// Absorb inputs with a variable length. + /// + /// `inputs` is right padded. + fn absorb_var_len_with_pre_constants( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: [AssignedValue; RATE], + len: AssignedValue, + pre_constants: &[F; T], + ) { + // Explanation of what's going on: before each round of the poseidon permutation, + // two things have to be added to the state: inputs (the absorbed elements) and + // preconstants. Imagine the state as a list of T elements, the first of which is + // the capacity: |--cap--|--el1--|--el2--|--elR--| + // - A preconstant is added to each of all T elements (which is different for each) + // - The inputs are added to all elements starting from el1 (so, not to the capacity), + // to as many elements as inputs are available. + // - To the first element for which no input is left (if any), an extra 1 is added. + + // Adding preconstants to the current state. + for (i, pre_const) in pre_constants.iter().enumerate() { + self.s[i] = gate.add(ctx, self.s[i], Constant(*pre_const)); + } + + // Generate a mask array where a[i] = i < len for i = 0..RATE. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, RATE); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut inputs_mask = + gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + inputs_mask.reverse(); + + let padded_inputs = inputs + .iter() + .zip(inputs_mask.iter()) + .map(|(input, mask)| gate.mul(ctx, *input, *mask)) + .collect_vec(); + for i in 0..RATE { + // Add all inputs. + self.s[i + 1] = gate.add(ctx, self.s[i + 1], padded_inputs[i]); + // Add the extra 1 after inputs. + if i + 2 < T { + self.s[i + 2] = gate.add(ctx, self.s[i + 2], len_indicator[i]); + } + } + // If len == 0, inputs_mask is all 0. Then the extra 1 should be added into s[1]. + let empty_extra_one = gate.not(ctx, inputs_mask[0]); + self.s[1] = gate.add(ctx, self.s[1], empty_extra_one); + } + + fn apply_mds( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + mds: &[[F; T]; T], + ) { + let res = mds + .iter() + .map(|row| { + gate.inner_product(ctx, self.s.iter().copied(), row.iter().map(|c| Constant(*c))) + }) + .collect::>(); + + self.s = res.try_into().unwrap(); + } + + fn apply_sparse_mds( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + mds: &SparseMDSMatrix, + ) { + self.s = iter::once(gate.inner_product( + ctx, + self.s.iter().copied(), + mds.row.iter().map(|c| Constant(*c)), + )) + .chain( + mds.col_hat + .iter() + .zip(self.s.iter().skip(1)) + .map(|(coeff, state)| gate.mul_add(ctx, self.s[0], Constant(*coeff), *state)), + ) + .collect::>() + .try_into() + .unwrap(); + } +} diff --git a/halo2-base/src/poseidon/hasher/tests/compatibility.rs b/halo2-base/src/poseidon/hasher/tests/compatibility.rs new file mode 100644 index 00000000..74e40531 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/compatibility.rs @@ -0,0 +1,117 @@ +use std::{cmp::max, iter::zip}; + +use crate::{ + gates::{flex_gate::threads::SinglePhaseCoreManager, GateChip}, + halo2_proofs::halo2curves::bn256::Fr, + poseidon::hasher::PoseidonSponge, + utils::ScalarField, +}; +use pse_poseidon::Poseidon; +use rand::Rng; + +// make interleaved calls to absorb and squeeze elements and +// check that the result is the same in-circuit and natively +fn sponge_compatiblity_verification< + F: ScalarField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + // elements of F to absorb; one sublist = one absorption + mut absorptions: Vec>, + // list of amounts of elements of F that should be squeezed every time + mut squeezings: Vec, +) { + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); + let gate = GateChip::default(); + + let ctx = pool.main(); + + // constructing native and in-circuit Poseidon sponges + let mut native_sponge = Poseidon::::new(R_F, R_P); + // assuming SECURE_MDS = 0 + let mut circuit_sponge = PoseidonSponge::::new::(ctx); + + // preparing to interleave absorptions and squeezings + let n_iterations = max(absorptions.len(), squeezings.len()); + absorptions.resize(n_iterations, Vec::new()); + squeezings.resize(n_iterations, 0); + + for (absorption, squeezing) in zip(absorptions, squeezings) { + // absorb (if any elements were provided) + native_sponge.update(&absorption); + circuit_sponge.update(&ctx.assign_witnesses(absorption)); + + // squeeze (if any elements were requested) + for _ in 0..squeezing { + let native_squeezed = native_sponge.squeeze(); + let circuit_squeezed = circuit_sponge.squeeze(ctx, &gate); + + assert_eq!(native_squeezed, *circuit_squeezed.value()); + } + } + + // even if no squeezings were requested, we squeeze to verify the + // states are the same after all absorptions + let native_squeezed = native_sponge.squeeze(); + let circuit_squeezed = circuit_sponge.squeeze(ctx, &gate); + + assert_eq!(native_squeezed, *circuit_squeezed.value()); +} + +fn random_nested_list_f(len: usize, max_sub_len: usize) -> Vec> { + let mut rng = rand::thread_rng(); + let mut list = Vec::new(); + for _ in 0..len { + let len = rng.gen_range(0..=max_sub_len); + let mut sublist = Vec::new(); + + for _ in 0..len { + sublist.push(F::random(&mut rng)); + } + list.push(sublist); + } + list +} + +fn random_list_usize(len: usize, max: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut list = Vec::new(); + for _ in 0..len { + list.push(rng.gen_range(0..=max)); + } + list +} + +#[test] +fn test_sponge_compatibility_squeezing_only() { + let absorptions = Vec::new(); + let squeezings = random_list_usize(10, 7); + + sponge_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_sponge_compatibility_absorbing_only() { + let absorptions = random_nested_list_f(8, 5); + let squeezings = Vec::new(); + + sponge_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_sponge_compatibility_interleaved() { + let absorptions = random_nested_list_f(10, 5); + let squeezings = random_list_usize(7, 10); + + sponge_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_sponge_compatibility_other_params() { + let absorptions = random_nested_list_f(10, 10); + let squeezings = random_list_usize(10, 10); + + sponge_compatiblity_verification::(absorptions, squeezings); +} diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs new file mode 100644 index 00000000..043bf221 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -0,0 +1,358 @@ +use crate::{ + gates::{range::RangeInstructions, RangeChip}, + halo2_proofs::halo2curves::bn256::Fr, + poseidon::hasher::{ + spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput, + PoseidonHasher, + }, + safe_types::SafeTypeChip, + utils::{testing::base_test, ScalarField}, + Context, +}; +use halo2_proofs_axiom::arithmetic::Field; +use itertools::Itertools; +use pse_poseidon::Poseidon; +use rand::Rng; + +#[derive(Clone)] +struct Payload { + // Represent value of a right-padded witness array with a variable length + pub values: Vec, + // Length of `values`. + pub len: usize, +} + +// check if the results from hasher and native sponge are same for hash_var_len_array. +fn hasher_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec>, +) { + base_test().k(12).run(|ctx, range| { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + for payload in payloads { + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + native_sponge.update(&payload.values[..payload.len]); + let native_result = native_sponge.squeeze(); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(Fr::from(payload.len as u64)); + let hasher_result = hasher.hash_var_len_array(ctx, range, &inputs, len); + assert_eq!(native_result, *hasher_result.value()); + } + }); +} + +// check if the results from hasher and native sponge are same for hash_compact_input. +fn hasher_compact_inputs_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec>, + ctx: &mut Context, + range: &RangeChip, +) { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + let mut native_results = Vec::with_capacity(payloads.len()); + let mut compact_inputs = Vec::>::new(); + let rate_witness = ctx.load_constant(Fr::from(RATE as u64)); + let true_witness = ctx.load_constant(Fr::ONE); + let false_witness = ctx.load_zero(); + for payload in payloads { + assert!(payload.values.len() % RATE == 0); + assert!(payload.values.len() >= payload.len); + assert!(payload.values.len() == RATE || payload.values.len() - payload.len < RATE); + let num_chunk = payload.values.len() / RATE; + let last_chunk_len = RATE - (payload.values.len() - payload.len); + let inputs = ctx.assign_witnesses(payload.values.clone()); + for (chunk_idx, input_chunk) in inputs.chunks(RATE).enumerate() { + let len_witness = if chunk_idx + 1 == num_chunk { + ctx.load_witness(Fr::from(last_chunk_len as u64)) + } else { + rate_witness + }; + let is_final_witness = SafeTypeChip::unsafe_to_bool(if chunk_idx + 1 == num_chunk { + true_witness + } else { + false_witness + }); + compact_inputs.push(PoseidonCompactInput { + inputs: input_chunk.try_into().unwrap(), + len: len_witness, + is_final: is_final_witness, + }); + } + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + native_sponge.update(&payload.values[..payload.len]); + let native_result = native_sponge.squeeze(); + native_results.push(native_result); + } + let compact_outputs = hasher.hash_compact_input(ctx, range.gate(), &compact_inputs); + let mut output_offset = 0; + for (compact_output, compact_input) in compact_outputs.iter().zip(compact_inputs) { + // into() doesn't work if ! is in the beginning in the bool expression... + let is_not_final_input: bool = compact_input.is_final.as_ref().value().is_zero().into(); + let is_not_final_output: bool = compact_output.is_final().as_ref().value().is_zero().into(); + assert_eq!(is_not_final_input, is_not_final_output); + if !is_not_final_output { + assert_eq!(native_results[output_offset], *compact_output.hash().value()); + output_offset += 1; + } + } +} + +// check if the results from hasher and native sponge are same for hash_compact_input. +fn hasher_compact_chunk_inputs_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec<(Payload, bool)>, + ctx: &mut Context, + range: &RangeChip, +) { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + let mut native_results = Vec::with_capacity(payloads.len()); + let mut chunk_inputs = Vec::>::new(); + let true_witness = SafeTypeChip::unsafe_to_bool(ctx.load_constant(Fr::ONE)); + let false_witness = SafeTypeChip::unsafe_to_bool(ctx.load_zero()); + + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + for (payload, is_final) in payloads { + assert!(payload.values.len() == payload.len); + assert!(payload.values.len() % RATE == 0); + let inputs = ctx.assign_witnesses(payload.values.clone()); + + let is_final_witness = if is_final { true_witness } else { false_witness }; + chunk_inputs.push(PoseidonCompactChunkInput { + inputs: inputs.chunks(RATE).map(|c| c.try_into().unwrap()).collect_vec(), + is_final: is_final_witness, + }); + native_sponge.update(&payload.values); + if is_final { + let native_result = native_sponge.squeeze(); + native_results.push(native_result); + native_sponge = Poseidon::::new(R_F, R_P); + } + } + let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range.gate(), &chunk_inputs); + assert_eq!(chunk_inputs.len(), compact_outputs.len()); + let mut output_offset = 0; + for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) { + // into() doesn't work if ! is in the beginning in the bool expression... + let is_final_input = chunk_input.is_final.as_ref().value(); + let is_final_output = compact_output.is_final().as_ref().value(); + assert_eq!(is_final_input, is_final_output); + if is_final_output == &Fr::ONE { + assert_eq!(native_results[output_offset], *compact_output.hash().value()); + output_offset += 1; + } + } +} + +fn random_payload(max_len: usize, len: usize, max_value: usize) -> Payload { + assert!(len <= max_len); + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len } +} + +fn random_payload_without_len(max_len: usize, max_value: usize) -> Payload { + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len: rng.gen_range(0..=max_len) } +} + +#[test] +fn test_poseidon_hasher_compatiblity() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + // max_len = 0 + random_payload(0, 0, usize::MAX), + // max_len % RATE == 0 && len = 0 + random_payload(RATE * 2, 0, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2, RATE, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5, RATE * 2 + 1, usize::MAX), + // max_len % RATE == 0 && len == max_len + random_payload(RATE * 2, RATE * 2, usize::MAX), + random_payload(RATE * 5, RATE * 5, usize::MAX), + // len % RATE != 0 && len = 0 + random_payload(RATE * 2 + 1, 0, usize::MAX), + random_payload(RATE * 5 + 1, 0, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2 + 1, RATE, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5 + 1, RATE * 2 + 1, usize::MAX), + // len % RATE != 0 && len = max_len + random_payload(RATE * 2 + 1, RATE * 2 + 1, usize::MAX), + random_payload(RATE * 5 + 1, RATE * 5 + 1, usize::MAX), + ]; + hasher_compatiblity_verification::(payloads); + } +} + +#[test] +fn test_poseidon_hasher_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + const R_F: usize = 8; + const R_P: usize = 57; + + let max_lens = vec![0, RATE * 2, RATE * 5, RATE * 2 + 1, RATE * 5 + 1]; + for max_len in max_lens { + let init_input = random_payload_without_len(max_len, usize::MAX); + let logic_input = random_payload_without_len(max_len, usize::MAX); + base_test().k(12).bench_builder(init_input, logic_input, |pool, range, payload| { + let ctx = pool.main(); + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(Fr::from(payload.len as u64)); + hasher.hash_var_len_array(ctx, range, &inputs, len); + }); + } + } +} + +#[test] +fn test_poseidon_hasher_compact_inputs() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + // len == 0 + random_payload(RATE, 0, usize::MAX), + // 0 < len < max_len + random_payload(RATE * 2, RATE + 1, usize::MAX), + random_payload(RATE * 5, RATE * 4 + 1, usize::MAX), + // len == max_len + random_payload(RATE * 2, RATE * 2, usize::MAX), + random_payload(RATE * 5, RATE * 5, usize::MAX), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_inputs_compatiblity_verification::(payloads, ctx, range); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_inputs_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + let params = [ + (RATE, 0), + (RATE * 2, RATE + 1), + (RATE * 5, RATE * 4 + 1), + (RATE * 2, RATE * 2), + (RATE * 5, RATE * 5), + ]; + let init_payloads = params + .iter() + .map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX)) + .collect::>(); + let logic_payloads = params + .iter() + .map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX)) + .collect::>(); + base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| { + let ctx = pool.main(); + hasher_compact_inputs_compatiblity_verification::(input, ctx, range); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(RATE * 5, RATE * 5, usize::MAX), true), + (random_payload(RATE, RATE, usize::MAX), false), + (random_payload(RATE * 2, RATE * 2, usize::MAX), true), + (random_payload(RATE * 3, RATE * 3, usize::MAX), true), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(0, 0, usize::MAX), true), + (random_payload(0, 0, usize::MAX), false), + (random_payload(0, 0, usize::MAX), false), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + let params = [ + (RATE, false), + (RATE * 2, false), + (RATE * 5, false), + (RATE * 2, true), + (RATE * 5, true), + ]; + let init_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + let logic_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| { + let ctx = pool.main(); + hasher_compact_chunk_inputs_compatiblity_verification::( + input, ctx, range, + ); + }); + } +} diff --git a/halo2-base/src/poseidon/hasher/tests/mod.rs b/halo2-base/src/poseidon/hasher/tests/mod.rs new file mode 100644 index 00000000..a734f7d0 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/mod.rs @@ -0,0 +1,39 @@ +use super::*; +use crate::halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; + +use itertools::Itertools; + +mod compatibility; +mod hasher; +mod state; + +#[test] +fn test_mds() { + let spec = OptimizedPoseidonSpec::::new::<8, 57, 0>(); + + let mds = vec![ + vec![ + "7511745149465107256748700652201246547602992235352608707588321460060273774987", + "10370080108974718697676803824769673834027675643658433702224577712625900127200", + "19705173408229649878903981084052839426532978878058043055305024233888854471533", + ], + vec![ + "18732019378264290557468133440468564866454307626475683536618613112504878618481", + "20870176810702568768751421378473869562658540583882454726129544628203806653987", + "7266061498423634438633389053804536045105766754026813321943009179476902321146", + ], + vec![ + "9131299761947733513298312097611845208338517739621853568979632113419485819303", + "10595341252162738537912664445405114076324478519622938027420701542910180337937", + "11597556804922396090267472882856054602429588299176362916247939723151043581408", + ], + ]; + for (row1, row2) in mds.iter().zip_eq(spec.mds_matrices.mds.0.iter()) { + for (e1, e2) in row1.iter().zip_eq(row2.iter()) { + assert_eq!(Fr::from_str_vartime(e1).unwrap(), *e2); + } + } +} + +// TODO: test clear()/squeeze(). +// TODO: test constraints actually work. diff --git a/halo2-base/src/poseidon/hasher/tests/state.rs b/halo2-base/src/poseidon/hasher/tests/state.rs new file mode 100644 index 00000000..f09fb76e --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/state.rs @@ -0,0 +1,129 @@ +use super::*; +use crate::{ + gates::{flex_gate::threads::SinglePhaseCoreManager, GateChip}, + halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, +}; + +#[test] +fn test_fix_permutation_against_test_vectors() { + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); + let gate = GateChip::::default(); + let ctx = pool.main(); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} + +#[test] +fn test_var_permutation_against_test_vectors() { + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); + let gate = GateChip::::default(); + let ctx = pool.main(); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} diff --git a/halo2-base/src/poseidon/mod.rs b/halo2-base/src/poseidon/mod.rs new file mode 100644 index 00000000..896b863c --- /dev/null +++ b/halo2-base/src/poseidon/mod.rs @@ -0,0 +1,114 @@ +use crate::{ + gates::{RangeChip, RangeInstructions}, + poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, + safe_types::{FixLenBytes, VarLenBytes, VarLenBytesVec}, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; + +use itertools::Itertools; + +/// Module for Poseidon hasher +pub mod hasher; + +/// Chip for Poseidon hash. +pub struct PoseidonChip<'a, F: ScalarField, const T: usize, const RATE: usize> { + range_chip: &'a RangeChip, + hasher: PoseidonHasher, +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonChip<'a, F, T, RATE> { + /// Create a new PoseidonChip. + pub fn new( + ctx: &mut Context, + spec: OptimizedPoseidonSpec, + range_chip: &'a RangeChip, + ) -> Self { + let mut hasher = PoseidonHasher::new(spec); + hasher.initialize_consts(ctx, range_chip.gate()); + Self { range_chip, hasher } + } +} + +/// Trait for Poseidon instructions +pub trait PoseidonInstructions { + /// Return hash of a [VarLenBytes] + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [VarLenBytesVec] + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [FixLenBytes] + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonInstructions + for PoseidonChip<'a, F, T, RATE> +{ + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + *inputs_len, + ) + } + + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + &inputs.bytes().iter().map(|sb| *sb.as_ref()).collect_vec(), + *inputs_len, + ) + } + + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + self.hasher.hash_fix_len_array( + ctx, + self.range_chip.gate(), + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + ) + } +} diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs new file mode 100644 index 00000000..e1a5e03d --- /dev/null +++ b/halo2-base/src/safe_types/bytes.rs @@ -0,0 +1,238 @@ +#![allow(clippy::len_without_is_empty)] +use crate::{ + gates::GateInstructions, + utils::bit_length, + AssignedValue, Context, + QuantumCell::{Constant, Existing}, +}; + +use super::{SafeByte, SafeType, ScalarField}; + +use getset::Getters; +use itertools::Itertools; + +/// Represents a variable length byte array in circuit. +/// +/// Each element is guaranteed to be a byte, given by type [`SafeByte`]. +/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is some additional context the user must provide. +/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). +#[derive(Debug, Clone, Getters)] +pub struct VarLenBytes { + /// The byte array, right padded + #[getset(get = "pub")] + bytes: [SafeByte; MAX_LEN], + /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` + #[getset(get = "pub")] + len: AssignedValue, +} + +impl VarLenBytes { + // VarLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { + assert!( + len.value().le(&F::from(MAX_LEN as u64)), + "Invalid length which exceeds MAX_LEN {MAX_LEN}", + ); + Self { bytes, len } + } + + /// Returns the maximum length of the byte array. + pub fn max_len(&self) -> usize { + MAX_LEN + } + + /// Left pads the variable length byte array with 0s to the MAX_LEN + pub fn left_pad_to_fixed( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> FixLenBytes { + let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, MAX_LEN); + FixLenBytes::new( + padded.into_iter().map(|b| SafeByte(b)).collect::>().try_into().unwrap(), + ) + } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes.try_into().unwrap(), self.len) + } +} + +/// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. +/// +/// Each element is guaranteed to be a byte, given by type [`SafeByte`]. +/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is provided when constructing and `bytes.len()` == `MAX_LEN` is enforced. +/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). +#[derive(Debug, Clone, Getters)] +pub struct VarLenBytesVec { + /// The byte array, right padded + #[getset(get = "pub")] + bytes: Vec>, + /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` + #[getset(get = "pub")] + len: AssignedValue, +} + +impl VarLenBytesVec { + // VarLenBytesVec can be only created by SafeChip. + pub(super) fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { + assert!( + len.value().le(&F::from(max_len as u64)), + "Invalid length which exceeds MAX_LEN {}", + max_len + ); + assert_eq!(bytes.len(), max_len, "bytes is not padded correctly"); + Self { bytes, len } + } + + /// Returns the maximum length of the byte array. + pub fn max_len(&self) -> usize { + self.bytes.len() + } + + /// Left pads the variable length byte array with 0s to the MAX_LEN + pub fn left_pad_to_fixed( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> FixLenBytesVec { + let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len()); + FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len()) + } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes, self.len, self.max_len()) + } +} + +/// Represents a fixed length byte array in circuit. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytes { + /// The byte array + #[getset(get = "pub")] + bytes: [SafeByte; LEN], +} + +impl FixLenBytes { + // FixLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: [SafeByte; LEN]) -> Self { + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + LEN + } + + /// Returns inner array of [SafeByte]s. + pub fn into_bytes(self) -> [SafeByte; LEN] { + self.bytes + } +} + +/// Represents a fixed length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytesVec { + /// The byte array + #[getset(get = "pub")] + bytes: Vec>, +} + +impl FixLenBytesVec { + // FixLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: Vec>, len: usize) -> Self { + assert_eq!(bytes.len(), len, "bytes length doesn't match"); + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + self.bytes.len() + } + + /// Returns inner array of [SafeByte]s. + pub fn into_bytes(self) -> Vec> { + self.bytes + } +} + +impl From> + for FixLenBytes::VALUE_LENGTH }> +{ + fn from(bytes: SafeType) -> Self { + let bytes = bytes.value.into_iter().map(|b| SafeByte(b)).collect::>(); + Self::new(bytes.try_into().unwrap()) + } +} + +impl + From::VALUE_LENGTH }>> + for SafeType +{ + fn from(bytes: FixLenBytes::VALUE_LENGTH }>) -> Self { + let bytes = bytes.bytes.into_iter().map(|b| b.0).collect::>(); + Self::new(bytes) + } +} + +/// Represents a fixed length byte array in circuit as a vector, where length must be fixed. +/// Not encouraged to use because `LEN` cannot be verified at compile time. +// pub type FixLenBytesVec = Vec>; + +/// Takes a fixed length array `arr` and returns a length `out_len` array equal to +/// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and +/// zero pad it on the left. +/// +/// Assumes `0 < len <= max_len <= out_len`. +pub fn left_pad_var_array_to_fixed( + ctx: &mut Context, + gate: &impl GateInstructions, + arr: &[impl AsRef>], + len: AssignedValue, + out_len: usize, +) -> Vec> { + debug_assert!(arr.len() <= out_len); + debug_assert!(bit_length(out_len as u64) < F::CAPACITY as usize); + + let mut padded = arr.iter().map(|b| *b.as_ref()).collect_vec(); + padded.resize(out_len, padded[0]); + // We use a barrel shifter to shift `arr` to the right by `out_len - len` bits. + let shift = gate.sub(ctx, Constant(F::from(out_len as u64)), len); + let shift_bits = gate.num_to_bits(ctx, shift, bit_length(out_len as u64)); + for (i, shift_bit) in shift_bits.into_iter().enumerate() { + let shifted = (0..out_len) + .map(|j| if j >= (1 << i) { Existing(padded[j - (1 << i)]) } else { Constant(F::ZERO) }) + .collect_vec(); + padded = padded + .into_iter() + .zip(shifted) + .map(|(noshift, shift)| gate.select(ctx, shift, noshift, shift_bit)) + .collect_vec(); + } + padded +} + +fn ensure_0_padding( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec> { + let max_len = bytes.len(); + // Generate a mask array where a[i] = i < len for i = 0..max_len. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, max_len); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + mask.reverse(); + + bytes + .iter() + .zip(mask.iter()) + .map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask))) + .collect_vec() +} diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 63a8d526..5c016d86 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -1,13 +1,25 @@ -pub use crate::{ +use std::{ + borrow::{Borrow, BorrowMut}, + cmp::{max, min}, +}; + +use crate::{ gates::{ flex_gate::GateInstructions, range::{RangeChip, RangeInstructions}, }, utils::ScalarField, AssignedValue, Context, - QuantumCell::{self, Constant, Existing, Witness}, + QuantumCell::Witness, }; -use std::cmp::{max, min}; + +use itertools::Itertools; + +mod bytes; +mod primitives; + +pub use bytes::*; +pub use primitives::*; #[cfg(test)] pub mod tests; @@ -39,32 +51,54 @@ impl pub const BYTES_PER_ELE: usize = BYTES_PER_ELE; /// Total bits of this type. pub const TOTAL_BITS: usize = TOTAL_BITS; - /// Number of bits of each element. - pub const BITS_PER_ELE: usize = min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE); /// Number of elements of this type. pub const VALUE_LENGTH: usize = (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE); + /// Number of bits of each element. + pub fn bits_per_ele() -> usize { + min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE) + } + // new is private so Safetype can only be constructed by this crate. fn new(raw_values: RawAssignedValues) -> Self { assert!(raw_values.len() == Self::VALUE_LENGTH, "Invalid raw values length"); Self { value: raw_values } } - /// Return values in littile-endian. - pub fn value(&self) -> &RawAssignedValues { + /// Return values in little-endian. + pub fn value(&self) -> &[AssignedValue] { &self.value } } +impl AsRef<[AssignedValue]> + for SafeType +{ + fn as_ref(&self) -> &[AssignedValue] { + self.value() + } +} + +impl TryFrom>> + for SafeType +{ + type Error = String; + + fn try_from(value: Vec>) -> Result { + if value.len() * 8 != TOTAL_BITS { + return Err("Invalid length".to_owned()); + } + Ok(Self::new(value.into_iter().map(|b| b.0).collect::>())) + } +} + /// Represent TOTAL_BITS with the least number of AssignedValue. /// (2^(F::NUM_BITS) - 1) might not be a valid value for F. e.g. max value of F is a prime in [2^(F::NUM_BITS-1), 2^(F::NUM_BITS) - 1] #[allow(type_alias_bounds)] type CompactSafeType = - SafeType; + SafeType; -/// SafeType for bool. -pub type SafeBool = CompactSafeType; /// SafeType for uint8. pub type SafeUint8 = CompactSafeType; /// SafeType for uint16. @@ -75,8 +109,12 @@ pub type SafeUint32 = CompactSafeType; pub type SafeUint64 = CompactSafeType; /// SafeType for uint128. pub type SafeUint128 = CompactSafeType; +/// SafeType for uint160. +pub type SafeUint160 = CompactSafeType; /// SafeType for uint256. pub type SafeUint256 = CompactSafeType; +/// SafeType for Address. +pub type SafeAddress = SafeType; /// SafeType for bytes32. pub type SafeBytes32 = SafeType; @@ -91,7 +129,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { Self { range_chip } } - /// Convert a vector of AssignedValue(treated as little-endian) to a SafeType. + /// Convert a vector of AssignedValue (treated as little-endian) to a SafeType. /// The number of bytes of inputs must equal to the number of bytes of outputs. /// This function also add contraints that a AssignedValue in inputs must be in the range of a byte. pub fn raw_bytes_to( @@ -99,7 +137,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { ctx: &mut Context, inputs: RawAssignedValues, ) -> SafeType { - let element_bits = SafeType::::BITS_PER_ELE; + let element_bits = SafeType::::bits_per_ele(); let bits = TOTAL_BITS; assert!( inputs.len() * BITS_PER_BYTE == max(bits, BITS_PER_BYTE), @@ -127,6 +165,161 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { SafeType::::new(value) } + /// Unsafe method that directly converts `input` to [`SafeType`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeType`]. + pub fn unsafe_to_safe_type( + inputs: RawAssignedValues, + ) -> SafeType { + assert_eq!(inputs.len(), SafeType::::VALUE_LENGTH); + SafeType::::new(inputs) + } + + /// Constrains that the `input` is a boolean value (either 0 or 1) and wraps it in [`SafeBool`]. + pub fn assert_bool(&self, ctx: &mut Context, input: AssignedValue) -> SafeBool { + self.range_chip.gate().assert_bit(ctx, input); + SafeBool(input) + } + + /// Load a boolean value as witness and constrain it is either 0 or 1. + pub fn load_bool(&self, ctx: &mut Context, input: bool) -> SafeBool { + let input = ctx.load_witness(F::from(input)); + self.assert_bool(ctx, input) + } + + /// Unsafe method that directly converts `input` to [`SafeBool`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeBool`]. + pub fn unsafe_to_bool(input: AssignedValue) -> SafeBool { + SafeBool(input) + } + + /// Constrains that the `input` is a byte value and wraps it in [`SafeByte`]. + pub fn assert_byte(&self, ctx: &mut Context, input: AssignedValue) -> SafeByte { + self.range_chip.range_check(ctx, input, BITS_PER_BYTE); + SafeByte(input) + } + + /// Load a boolean value as witness and constrain it is either 0 or 1. + pub fn load_byte(&self, ctx: &mut Context, input: u8) -> SafeByte { + let input = ctx.load_witness(F::from(input as u64)); + self.assert_byte(ctx, input) + } + + /// Unsafe method that directly converts `input` to [`SafeByte`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_byte(input: AssignedValue) -> SafeByte { + SafeByte(input) + } + + /// Unsafe method that directly converts `inputs` to [`VarLenBytes`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_var_len_bytes( + inputs: [AssignedValue; MAX_LEN], + len: AssignedValue, + ) -> VarLenBytes { + VarLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input)), len) + } + + /// Unsafe method that directly converts `inputs` to [`VarLenBytesVec`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_var_len_bytes_vec( + inputs: RawAssignedValues, + len: AssignedValue, + max_len: usize, + ) -> VarLenBytesVec { + VarLenBytesVec::::new( + inputs.iter().map(|input| Self::unsafe_to_byte(*input)).collect_vec(), + len, + max_len, + ) + } + + /// Unsafe method that directly converts `inputs` to [`FixLenBytes`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_fix_len_bytes( + inputs: [AssignedValue; MAX_LEN], + ) -> FixLenBytes { + FixLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input))) + } + + /// Unsafe method that directly converts `inputs` to [`FixLenBytesVec`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_fix_len_bytes_vec( + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(), + len, + ) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= MAX_LEN`. + /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain. + pub fn raw_to_var_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; MAX_LEN], + len: AssignedValue, + ) -> VarLenBytes { + self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64 + 1); + VarLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input)), len) + } + + /// Converts a vector of AssignedValue to [VarLenBytesVec]. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding. + /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= max_len`. + /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. + pub fn raw_to_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + len: AssignedValue, + max_len: usize, + ) -> VarLenBytesVec { + self.range_chip.check_less_than_safe(ctx, len, max_len as u64 + 1); + VarLenBytesVec::::new( + inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(), + len, + max_len, + ) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytes. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * LEN: length of the byte array. + pub fn raw_to_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; LEN], + ) -> FixLenBytes { + FixLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input))) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * len: length of the byte array. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. + pub fn raw_to_fix_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(), + len, + ) + } + fn add_bytes_constraints( &self, ctx: &mut Context, @@ -141,6 +334,6 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { } } - // TODO: Add comprasion. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool + // TODO: Add comparison. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool // TODO: Add type castings. e.g. uint256 -> bytes32/uint32 -> uint64 } diff --git a/halo2-base/src/safe_types/primitives.rs b/halo2-base/src/safe_types/primitives.rs new file mode 100644 index 00000000..86726595 --- /dev/null +++ b/halo2-base/src/safe_types/primitives.rs @@ -0,0 +1,53 @@ +use super::*; +/// SafeType for bool (1 bit). +/// +/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because +/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid +/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector. +#[derive(Clone, Copy, Debug)] +pub struct SafeBool(pub(super) AssignedValue); + +/// SafeType for byte (8 bits). +/// +/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because +/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid +/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector. +#[derive(Clone, Copy, Debug)] +pub struct SafeByte(pub(super) AssignedValue); + +macro_rules! safe_primitive_impls { + ($SafePrimitive:ty) => { + impl AsRef> for $SafePrimitive { + fn as_ref(&self) -> &AssignedValue { + &self.0 + } + } + + impl AsMut> for $SafePrimitive { + fn as_mut(&mut self) -> &mut AssignedValue { + &mut self.0 + } + } + + impl Borrow> for $SafePrimitive { + fn borrow(&self) -> &AssignedValue { + &self.0 + } + } + + impl BorrowMut> for $SafePrimitive { + fn borrow_mut(&mut self) -> &mut AssignedValue { + &mut self.0 + } + } + + impl From<$SafePrimitive> for AssignedValue { + fn from(safe_primitive: $SafePrimitive) -> Self { + safe_primitive.0 + } + } + }; +} + +safe_primitive_impls!(SafeBool); +safe_primitive_impls!(SafeByte); diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs new file mode 100644 index 00000000..9c24444f --- /dev/null +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -0,0 +1,235 @@ +use crate::{ + gates::{circuit::builder::RangeCircuitBuilder, RangeInstructions}, + halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr}, + plonk::{keygen_pk, keygen_vk}, + poly::kzg::commitment::ParamsKZG, + }, + safe_types::SafeTypeChip, + utils::{ + testing::{base_test, check_proof, gen_proof}, + ScalarField, + }, + Context, +}; +use rand::rngs::OsRng; +use std::vec; +use test_case::test_case; + +// =========== Utilies =============== +fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: FM) { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + f(ctx, safe); + }); +} + +// =========== Mock Prover =========== + +// Circuit Satisfied for valid inputs +#[test] +fn pos_var_len_bytes() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes::<4>(ctx, bytes.clone().try_into().unwrap(), len); + + // check edge case len == MAX_LEN + let len = ctx.load_witness(Fr::from(4u64)); + safe.raw_to_var_len_bytes::<4>(ctx, bytes.try_into().unwrap(), len); + }); +} + +#[test_case(vec![1,2,3], 4 => vec![0,1,2,3]; "pos left pad 3 to 4")] +#[test_case(vec![1,2,3], 5 => vec![0,0,1,2,3]; "pos left pad 3 to 5")] +#[test_case(vec![1,2,3], 6 => vec![0,0,0,1,2,3]; "pos left pad 3 to 6")] +fn left_pad_var_len_bytes(mut bytes: Vec, max_len: usize) -> Vec { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let len = bytes.len(); + bytes.resize(max_len, 0); + let bytes = ctx.assign_witnesses(bytes.into_iter().map(|b| Fr::from(b as u64))); + let len = ctx.load_witness(Fr::from(len as u64)); + let bytes = safe.raw_to_var_len_bytes_vec(ctx, bytes, len, max_len); + let padded = bytes.left_pad_to_fixed(ctx, range.gate()); + padded.bytes().iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect() + }) +} + +// Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 +#[test] +#[should_panic(expected = "circuit was not satisfied")] +fn neg_var_len_bytes_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); +} + +// Checks assertion len <= max_len +#[test] +#[should_panic] +fn neg_var_len_bytes_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_var_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes_vec(ctx, bytes.clone(), len, 4); + + // check edge case len == MAX_LEN + let len = ctx.load_witness(Fr::from(4u64)); + safe.raw_to_var_len_bytes_vec(ctx, bytes, len, 4); + }); +} + +// Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 +#[test] +#[should_panic(expected = "circuit was not satisfied")] +fn neg_var_len_bytes_vec_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = fake_bytes.len(); + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +// Checks assertion len <= max_len +#[test] +#[should_panic] +fn neg_var_len_bytes_vec_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = 4; + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap()); + }); +} + +// Assert inputs.len() == len +#[test] +#[should_panic] +fn neg_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 5); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 4); + }); +} + +// =========== Prover =========== +#[test] +fn pos_prover_satisfied() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +#[test] +fn pos_diff_len_same_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 2); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +#[test] +#[should_panic] +fn neg_different_proof_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 3; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 4); + let proof_inputs = (vec![1u64, 2u64, 3u64], 3); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +// test circuit +fn var_byte_array_circuit( + k: usize, + witness_gen_only: bool, + (bytes, len): (Vec, usize), +) -> RangeCircuitBuilder { + let lookup_bits = 3; + let mut builder = + RangeCircuitBuilder::new(witness_gen_only).use_k(k).use_lookup_bits(lookup_bits); + let range = builder.range_chip(); + let safe = SafeTypeChip::new(&range); + let ctx = builder.main(0); + let len = ctx.load_witness(Fr::from(len as u64)); + let fake_bytes = ctx.assign_witnesses(bytes.into_iter().map(Fr::from).collect::>()); + safe.raw_to_var_len_bytes::(ctx, fake_bytes.try_into().unwrap(), len); + builder.calculate_params(Some(9)); + builder +} + +// Prover test +fn prover_satisfied( + keygen_inputs: (Vec, usize), + proof_inputs: (Vec, usize), +) { + let k = 11; + let rng = OsRng; + let params = ParamsKZG::::setup(k as u32, rng); + let keygen_circuit = var_byte_array_circuit::(k, false, keygen_inputs); + let vk = keygen_vk(¶ms, &keygen_circuit).unwrap(); + let pk = keygen_pk(¶ms, vk.clone(), &keygen_circuit).unwrap(); + let break_points = keygen_circuit.break_points(); + + let mut proof_circuit = var_byte_array_circuit::(k, true, proof_inputs); + proof_circuit.set_break_points(break_points); + let proof = gen_proof(¶ms, &pk, proof_circuit); + check_proof(¶ms, &vk, &proof[..], true); +} diff --git a/halo2-base/src/safe_types/tests/mod.rs b/halo2-base/src/safe_types/tests/mod.rs new file mode 100644 index 00000000..ee37540f --- /dev/null +++ b/halo2-base/src/safe_types/tests/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod bytes; +pub(crate) mod safe_type; diff --git a/halo2-base/src/safe_types/tests.rs b/halo2-base/src/safe_types/tests/safe_type.rs similarity index 89% rename from halo2-base/src/safe_types/tests.rs rename to halo2-base/src/safe_types/tests/safe_type.rs index 14480fdd..96a43800 100644 --- a/halo2-base/src/safe_types/tests.rs +++ b/halo2-base/src/safe_types/tests/safe_type.rs @@ -1,22 +1,12 @@ use crate::{ + gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, + halo2_proofs::plonk::{keygen_pk, keygen_vk, Assigned}, halo2_proofs::{halo2curves::bn256::Fr, poly::kzg::commitment::ParamsKZG}, + safe_types::*, utils::testing::{check_proof, gen_proof}, }; - -use super::*; -use crate::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - RangeChip, - }, - halo2_proofs::{ - plonk::keygen_pk, - plonk::{keygen_vk, Assigned}, - }, -}; use itertools::Itertools; use rand::rngs::OsRng; -use std::env; // soundness checks for `raw_bytes_to` function fn test_raw_bytes_to_gen( @@ -26,10 +16,11 @@ fn test_raw_bytes_to_gen( expect_satisfied: bool, ) { // first create proving and verifying key - let mut builder = GateThreadBuilder::::keygen(); let lookup_bits = 3; - env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range_chip = RangeChip::::default(lookup_bits); + let mut builder = RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen) + .use_k(k as usize) + .use_lookup_bits(lookup_bits); + let range_chip = builder.range_chip(); let safe_type_chip = SafeTypeChip::new(&range_chip); let dummy_raw_bytes = builder @@ -41,20 +32,20 @@ fn test_raw_bytes_to_gen( // get the offsets of the safe value cells for later 'pranking' let safe_value_offsets = safe_value.value().iter().map(|v| v.cell.unwrap().offset).collect::>(); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, OsRng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |inputs: &[Fr], outputs: &[Fr]| { - let mut builder = GateThreadBuilder::::prover(); - let range_chip = RangeChip::::default(lookup_bits); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); + let range_chip = builder.range_chip(); let safe_type_chip = SafeTypeChip::new(&range_chip); let assigned_raw_bytes = builder.main(0).assign_witnesses(inputs.to_vec()); @@ -64,8 +55,7 @@ fn test_raw_bytes_to_gen( for (offset, witness) in safe_value_offsets.iter().zip_eq(outputs) { builder.main(0).advice[*offset] = Assigned::::Trivial(*witness); } - let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; let pf = gen_pf(raw_bytes, outputs); check_proof(¶ms, vk, &pf, expect_satisfied); diff --git a/halo2-base/src/utils/halo2.rs b/halo2-base/src/utils/halo2.rs new file mode 100644 index 00000000..510f7d25 --- /dev/null +++ b/halo2-base/src/utils/halo2.rs @@ -0,0 +1,73 @@ +use crate::ff::Field; +use crate::halo2_proofs::{ + circuit::{AssignedCell, Cell, Region, Value}, + plonk::{Advice, Assigned, Column, Fixed}, +}; + +/// Raw (physical) assigned cell in Plonkish arithmetization. +#[cfg(feature = "halo2-axiom")] +pub type Halo2AssignedCell<'v, F> = AssignedCell<&'v Assigned, F>; +/// Raw (physical) assigned cell in Plonkish arithmetization. +#[cfg(not(feature = "halo2-axiom"))] +pub type Halo2AssignedCell<'v, F> = AssignedCell, F>; + +/// Assign advice to physical region. +#[inline(always)] +pub fn raw_assign_advice<'v, F: Field>( + region: &mut Region, + column: Column, + offset: usize, + value: Value>>, +) -> Halo2AssignedCell<'v, F> { + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice(column, offset, value) + } + #[cfg(feature = "halo2-pse")] + { + let value = value.map(|a| Into::>::into(a)); + region + .assign_advice( + || format!("assign advice {column:?} offset {offset}"), + column, + offset, + || value, + ) + .unwrap() + } +} + +/// Assign fixed to physical region. +#[inline(always)] +pub fn raw_assign_fixed( + region: &mut Region, + column: Column, + offset: usize, + value: F, +) -> Cell { + #[cfg(feature = "halo2-axiom")] + { + region.assign_fixed(column, offset, value) + } + #[cfg(feature = "halo2-pse")] + { + region + .assign_fixed( + || format!("assign fixed {column:?} offset {offset}"), + column, + offset, + || Value::known(value), + ) + .unwrap() + .cell() + } +} + +/// Constrain two physical cells to be equal. +#[inline(always)] +pub fn raw_constrain_equal(region: &mut Region, left: Cell, right: Cell) { + #[cfg(feature = "halo2-axiom")] + region.constrain_equal(left, right); + #[cfg(not(feature = "halo2-axiom"))] + region.constrain_equal(left, right).unwrap(); +} diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils/mod.rs similarity index 85% rename from halo2-base/src/utils.rs rename to halo2-base/src/utils/mod.rs index 69c1a1f9..2aaa5166 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils/mod.rs @@ -1,17 +1,52 @@ -#[cfg(feature = "halo2-pse")] -use crate::halo2_proofs::arithmetic::CurveAffine; -use crate::halo2_proofs::{arithmetic::FieldExt, circuit::Value}; use core::hash::Hash; + +use crate::ff::{FromUniformBytes, PrimeField}; +#[cfg(not(feature = "halo2-axiom"))] +use crate::halo2_proofs::arithmetic::CurveAffine; +use crate::halo2_proofs::circuit::Value; +#[cfg(feature = "halo2-axiom")] +pub use crate::halo2_proofs::halo2curves::CurveAffineExt; + use num_bigint::BigInt; use num_bigint::BigUint; use num_bigint::Sign; use num_traits::Signed; use num_traits::{One, Zero}; +/// Helper functions for raw halo2 operations to unify slight differences in API for halo2-axiom and halo2-pse +pub mod halo2; +#[cfg(any(test, feature = "test-utils"))] +pub mod testing; + +/// Helper trait to convert to and from a [BigPrimeField] by converting a list of [u64] digits +#[cfg(feature = "halo2-axiom")] +pub trait BigPrimeField: ScalarField { + /// Converts a slice of [u64] to [BigPrimeField] + /// * `val`: the slice of u64 + /// + /// # Assumptions + /// * `val` has the correct length for the implementation + /// * The integer value of `val` is already less than the modulus of `Self` + fn from_u64_digits(val: &[u64]) -> Self; +} +#[cfg(feature = "halo2-axiom")] +impl BigPrimeField for F +where + F: ScalarField + From<[u64; 4]>, // Assume [u64; 4] is little-endian. We only implement ScalarField when this is true. +{ + #[inline(always)] + fn from_u64_digits(val: &[u64]) -> Self { + debug_assert!(val.len() <= 4); + let mut raw = [0u64; 4]; + raw[..val.len()].copy_from_slice(val); + Self::from(raw) + } +} + /// Helper trait to represent a field element that can be converted into [u64] limbs. /// /// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the integer representation of the field element into multiple [u64] values e.g. `limbs`. -pub trait ScalarField: FieldExt + Hash { +pub trait ScalarField: PrimeField + FromUniformBytes<64> + From + Hash + Ord { /// Returns the base `2bit_len` little endian representation of the [ScalarField] element up to `num_limbs` number of limbs (truncates any extra limbs). /// /// Assumes `bit_len < 64`. @@ -20,7 +55,9 @@ pub trait ScalarField: FieldExt + Hash { fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec; /// Returns the little endian byte representation of the element. - fn to_bytes_le(&self) -> Vec; + fn to_bytes_le(&self) -> Vec { + self.to_repr().as_ref().to_vec() + } /// Creates a field element from a little endian byte representation. /// @@ -31,11 +68,34 @@ pub trait ScalarField: FieldExt + Hash { repr.as_mut()[..bytes.len()].copy_from_slice(bytes); Self::from_repr(repr).unwrap() } + + /// Gets the least significant 32 bits of the field element. + fn get_lower_32(&self) -> u32 { + let bytes = self.to_bytes_le(); + let mut lower_32 = 0u32; + for (i, byte) in bytes.into_iter().enumerate().take(4) { + lower_32 |= (byte as u32) << (i * 8); + } + lower_32 + } + + /// Gets the least significant 64 bits of the field element. + fn get_lower_64(&self) -> u64 { + let bytes = self.to_bytes_le(); + let mut lower_64 = 0u64; + for (i, byte) in bytes.into_iter().enumerate().take(8) { + lower_64 |= (byte as u64) << (i * 8); + } + lower_64 + } } // See below for implementations // Later: will need to separate BigPrimeField from ScalarField when Goldilocks is introduced -pub trait BigPrimeField = ScalarField; + +/// [ScalarField] that is ~256 bits long +#[cfg(feature = "halo2-pse")] +pub trait BigPrimeField = PrimeField + ScalarField; /// Converts an [Iterator] of u64 digits into `number_of_limbs` limbs of `bit_len` bits returned as a [Vec]. /// @@ -91,7 +151,7 @@ pub(crate) fn decompose_u64_digits_to_limbs( } /// Returns the number of bits needed to represent the value of `x`. -pub fn bit_length(x: u64) -> usize { +pub const fn bit_length(x: u64) -> usize { (u64::BITS - x.leading_zeros()) as usize } @@ -104,7 +164,7 @@ pub fn log2_ceil(x: u64) -> usize { /// Returns the modulus of [BigPrimeField]. pub fn modulus() -> BigUint { - fe_to_biguint(&-F::one()) + 1u64 + fe_to_biguint(&-F::ONE) + 1u64 } /// Returns the [BigPrimeField] element of 2n. @@ -290,13 +350,10 @@ pub fn compose(input: Vec, bit_len: usize) -> BigUint { input.iter().rev().fold(BigUint::zero(), |acc, val| (acc << bit_len) + val) } -#[cfg(feature = "halo2-axiom")] -pub use halo2_proofs_axiom::halo2curves::CurveAffineExt; - /// Helper trait #[cfg(feature = "halo2-pse")] pub trait CurveAffineExt: CurveAffine { - /// Unlike the `Coordinates` trait, this just returns the raw affine (X, Y) coordinantes without checking `is_on_curve` + /// Returns the raw affine (X, Y) coordinantes fn into_coordinates(self) -> (Self::Base, Self::Base) { let coordinates = self.coordinates().unwrap(); (*coordinates.x(), *coordinates.y()) @@ -306,18 +363,20 @@ pub trait CurveAffineExt: CurveAffine { impl CurveAffineExt for C {} mod scalar_field_impls { + use std::hash::Hash; + use num_bigint::BigUint; use super::{decompose_u64_digits_to_limbs, ScalarField}; - use crate::halo2_proofs::halo2curves::FieldExt; - use std::hash::Hash; + + use crate::ff::{FromUniformBytes, PrimeField}; /// We do a blanket implementation in 'community-edition' to make it easier to integrate with other crates. /// /// ASSUMING F::Repr is little-endian impl ScalarField for F where - F: FieldExt + Hash, + F: PrimeField + FromUniformBytes<64> + From + Hash + Ord, { #[inline(always)] fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { @@ -326,11 +385,6 @@ mod scalar_field_impls { let digits = uint.iter_u64_digits(); decompose_u64_digits_to_limbs(digits, num_limbs, bit_len) } - - #[inline(always)] - fn to_bytes_le(&self) -> Vec { - self.to_repr().as_ref().to_vec() - } } } @@ -401,73 +455,14 @@ pub mod fs { } } -/// Utilities for testing -#[cfg(any(test, feature = "test-utils"))] -pub mod testing { - use crate::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, - multiopen::VerifierSHPLONK, strategy::SingleStrategy, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, - }; - use rand::rngs::OsRng; - - /// helper function to generate a proof with real prover - pub fn gen_proof( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: impl Circuit, - ) -> Vec { - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255<_>, - _, - Blake2bWrite, G1Affine, _>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - transcript.finalize() - } - - /// helper function to verify a proof - pub fn check_proof( - params: &ParamsKZG, - vk: &VerifyingKey, - proof: &[u8], - expect_satisfied: bool, - ) { - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(params); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); - let res = verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, vk, strategy, &[&[]], &mut transcript); - - if expect_satisfied { - assert!(res.is_ok()); - } else { - assert!(res.is_err()); - } - } -} - #[cfg(test)] mod tests { use crate::halo2_proofs::halo2curves::bn256::Fr; use num_bigint::RandomBits; - use rand::{rngs::OsRng, Rng}; + use rand::{ + rngs::{OsRng, StdRng}, + Rng, SeedableRng, + }; use std::ops::Shl; use super::*; @@ -539,4 +534,23 @@ mod tests { fn test_log2_ceil_zero() { assert_eq!(log2_ceil(0), 0); } + + #[test] + fn test_get_lower_32() { + let mut rng = StdRng::seed_from_u64(0); + for _ in 0..10_000usize { + let e: u32 = rng.gen_range(0..u32::MAX); + assert_eq!(Fr::from(e as u64).get_lower_32(), e); + } + assert_eq!(Fr::from((1u64 << 32_i32) + 1).get_lower_32(), 1); + } + + #[test] + fn test_get_lower_64() { + let mut rng = StdRng::seed_from_u64(0); + for _ in 0..10_000usize { + let e: u64 = rng.gen_range(0..u64::MAX); + assert_eq!(Fr::from(e).get_lower_64(), e); + } + } } diff --git a/halo2-base/src/utils/testing.rs b/halo2-base/src/utils/testing.rs new file mode 100644 index 00000000..a4608df1 --- /dev/null +++ b/halo2-base/src/utils/testing.rs @@ -0,0 +1,264 @@ +//! Utilities for testing +use crate::{ + gates::{ + circuit::{builder::RangeCircuitBuilder, BaseCircuitParams, CircuitBuilderStage}, + flex_gate::threads::SinglePhaseCoreManager, + GateChip, RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey, + }, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, + multiopen::VerifierSHPLONK, strategy::SingleStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, + }, + Context, +}; +use ark_std::{end_timer, perf_trace::TimerInfo, start_timer}; +use rand::{rngs::StdRng, SeedableRng}; + +use super::fs::gen_srs; + +/// Helper function to generate a proof with real prover using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn gen_proof_with_instances( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, + instances: &[&[Fr]], +) -> Vec { + let rng = StdRng::seed_from_u64(0); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255<_>, + _, + Blake2bWrite, G1Affine, _>, + _, + >(params, pk, &[circuit], &[instances], rng, &mut transcript) + .expect("prover should not fail"); + transcript.finalize() +} + +/// For testing use only: Helper function to generate a proof **without public instances** with real prover using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, +) -> Vec { + gen_proof_with_instances(params, pk, circuit, &[]) +} + +/// Helper function to verify a proof (generated using [`gen_proof_with_instances`]) using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn check_proof_with_instances( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + instances: &[&[Fr]], + expect_satisfied: bool, +) { + let verifier_params = params.verifier_params(); + let strategy = SingleStrategy::new(params); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); + let res = verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(verifier_params, vk, strategy, &[instances], &mut transcript); + // Just FYI, because strategy is `SingleStrategy`, the output `res` is `Result<(), Error>`, so there is no need to call `res.finalize()`. + + if expect_satisfied { + res.unwrap(); + } else { + assert!(res.is_err()); + } +} + +/// For testing only: Helper function to verify a proof (generated using [`gen_proof`]) without public instances using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, +) { + check_proof_with_instances(params, vk, proof, &[], expect_satisfied); +} + +/// Helper to facilitate easier writing of tests using `RangeChip` and `RangeCircuitBuilder`. +/// By default, the [`MockProver`] is used. +/// +/// Currently this tester uses all private inputs. +pub struct BaseTester { + k: u32, + lookup_bits: Option, + expect_satisfied: bool, + unusable_rows: usize, +} + +impl Default for BaseTester { + fn default() -> Self { + Self { k: 10, lookup_bits: Some(9), expect_satisfied: true, unusable_rows: 9 } + } +} + +/// Creates a [`BaseTester`] +pub fn base_test() -> BaseTester { + BaseTester::default() +} + +impl BaseTester { + /// Changes the number of rows in the circuit to 2k. + /// By default it will also set lookup bits as large as possible, to `k - 1`. + pub fn k(mut self, k: u32) -> Self { + self.k = k; + self.lookup_bits = Some(k as usize - 1); + self + } + + /// Sets the size of the lookup table used for range checks to [0, 2lookup_bits) + pub fn lookup_bits(mut self, lookup_bits: usize) -> Self { + assert!(lookup_bits < self.k as usize, "lookup_bits must be less than k"); + self.lookup_bits = Some(lookup_bits); + self + } + + /// Specify whether you expect this test to pass or fail. Default: pass + pub fn expect_satisfied(mut self, expect_satisfied: bool) -> Self { + self.expect_satisfied = expect_satisfied; + self + } + + /// Set the number of blinding (poisoned) rows + pub fn unusable_rows(mut self, unusable_rows: usize) -> Self { + self.unusable_rows = unusable_rows; + self + } + + /// Run a mock test by providing a closure that uses a `ctx` and `RangeChip`. + /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. + pub fn run(&self, f: impl FnOnce(&mut Context, &RangeChip) -> R) -> R { + self.run_builder(|builder, range| f(builder.main(), range)) + } + + /// Run a mock test by providing a closure that uses a `ctx` and `GateChip`. + /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. + pub fn run_gate(&self, f: impl FnOnce(&mut Context, &GateChip) -> R) -> R { + self.run(|ctx, range| f(ctx, &range.gate)) + } + + /// Run a mock test by providing a closure that uses a `builder` and `RangeChip`. + pub fn run_builder( + &self, + f: impl FnOnce(&mut SinglePhaseCoreManager, &RangeChip) -> R, + ) -> R { + let mut builder = RangeCircuitBuilder::default().use_k(self.k as usize); + if let Some(lb) = self.lookup_bits { + builder.set_lookup_bits(lb) + } + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); + // run the function, mutating `builder` + let res = f(builder.pool(0), &range); + + // helper check: if your function didn't use lookups, turn lookup table "off" + let t_cells_lookup = + builder.lookup_manager().iter().map(|lm| lm.total_rows()).sum::(); + let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; + builder.config_params.lookup_bits = lookup_bits; + + // configure the circuit shape, 9 blinding rows seems enough + builder.calculate_params(Some(self.unusable_rows)); + if self.expect_satisfied { + MockProver::run(self.k, &builder, vec![]).unwrap().assert_satisfied(); + } else { + assert!(MockProver::run(self.k, &builder, vec![]).unwrap().verify().is_err()); + } + res + } + + /// Runs keygen, real prover, and verifier by providing a closure that uses a `builder` and `RangeChip`. + /// + /// Must provide `init_input` for use during key generation, which is preferably not equal to `logic_input`. + /// These are the inputs to the closure, not necessary public inputs to the circuit. + /// + /// Currently for testing, no public instances. + pub fn bench_builder( + &self, + init_input: I, + logic_input: I, + f: impl Fn(&mut SinglePhaseCoreManager, &RangeChip, I), + ) -> BenchStats { + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(self.k as usize); + if let Some(lb) = self.lookup_bits { + builder.set_lookup_bits(lb) + } + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); + // run the function, mutating `builder` + f(builder.pool(0), &range, init_input); + + // helper check: if your function didn't use lookups, turn lookup table "off" + let t_cells_lookup = + builder.lookup_manager().iter().map(|lm| lm.total_rows()).sum::(); + let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; + builder.config_params.lookup_bits = lookup_bits; + + // configure the circuit shape, 9 blinding rows seems enough + let config_params = builder.calculate_params(Some(self.unusable_rows)); + + let params = gen_srs(self.k); + let vk_time = start_timer!(|| "Generating vkey"); + let vk = keygen_vk(¶ms, &builder).unwrap(); + end_timer!(vk_time); + let pk_time = start_timer!(|| "Generating pkey"); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); + end_timer!(pk_time); + + let break_points = builder.break_points(); + drop(builder); + // create real proof + let proof_time = start_timer!(|| "Proving time"); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points); + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); + f(builder.pool(0), &range, logic_input); + let proof = gen_proof(¶ms, &pk, builder); + end_timer!(proof_time); + + let proof_size = proof.len(); + + let verify_time = start_timer!(|| "Verify time"); + check_proof(¶ms, pk.get_vk(), &proof, self.expect_satisfied); + end_timer!(verify_time); + + BenchStats { config_params, vk_time, pk_time, proof_time, proof_size, verify_time } + } +} + +/// Bench stats +pub struct BenchStats { + /// Config params + pub config_params: BaseCircuitParams, + /// Vkey gen time + pub vk_time: TimerInfo, + /// Pkey gen time + pub pk_time: TimerInfo, + /// Proving time + pub proof_time: TimerInfo, + /// Proof size in bytes + pub proof_size: usize, + /// Verify time + pub verify_time: TimerInfo, +} diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs new file mode 100644 index 00000000..d9fe6742 --- /dev/null +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -0,0 +1,173 @@ +use std::any::TypeId; +use std::collections::{BTreeMap, HashMap}; +use std::ops::DerefMut; +use std::sync::{Arc, Mutex, OnceLock}; + +use itertools::Itertools; +use rayon::slice::ParallelSliceMut; + +use crate::halo2_proofs::{ + circuit::{Cell, Region}, + plonk::{Assigned, Column, Fixed}, +}; +use crate::utils::halo2::{raw_assign_fixed, raw_constrain_equal, Halo2AssignedCell}; +use crate::AssignedValue; +use crate::{ff::Field, ContextCell}; + +use super::manager::VirtualRegionManager; + +/// Thread-safe shared global manager for all copy constraints. +pub type SharedCopyConstraintManager = Arc>>; + +/// Global manager for all copy constraints. Thread-safe. +/// +/// This will only be accessed during key generation, not proof generation, so it does not need to be optimized. +/// +/// Implements [VirtualRegionManager], which should be assigned only after all cells have been assigned +/// by other managers. +#[derive(Clone, Default, Debug)] +pub struct CopyConstraintManager { + /// A [Vec] tracking equality constraints between pairs of virtual advice cells, tagged by [ContextCell]. + /// These can be across different virtual regions. + pub advice_equalities: Vec<(ContextCell, ContextCell)>, + + /// A [Vec] tracking equality constraints between virtual advice cell and fixed values. + /// Fixed values will only be added once globally. + pub constant_equalities: Vec<(F, ContextCell)>, + + external_cell_count: usize, + + // In circuit assignments + /// Advice assignments, mapping from virtual [ContextCell] to assigned physical [Cell] + pub assigned_advices: HashMap, + /// Constant assignments, (key = constant, value = [Cell]) + pub assigned_constants: BTreeMap, + /// Flag for whether `assign_raw` has been called, for safety only. + assigned: OnceLock<()>, +} + +impl CopyConstraintManager { + /// Returns the number of distinct constants used. + pub fn num_distinct_constants(&self) -> usize { + self.constant_equalities.iter().map(|(x, _)| x).sorted().dedup().count() + } + + /// Adds external raw [Halo2AssignedCell] to `self.assigned_advices` and returns a new virtual [AssignedValue] + /// that can be used in any virtual region. No copy constraint is imposed, as the virtual cell "points" to the + /// raw assigned cell. The returned [ContextCell] will have `type_id` the `TypeId::of::()`. + pub fn load_external_assigned( + &mut self, + assigned_cell: Halo2AssignedCell, + ) -> AssignedValue { + let context_cell = self.load_external_cell(assigned_cell.cell()); + let mut value = Assigned::Trivial(F::ZERO); + assigned_cell.value().map(|v| { + #[cfg(feature = "halo2-axiom")] + { + value = **v; + } + #[cfg(not(feature = "halo2-axiom"))] + { + value = *v; + } + }); + AssignedValue { value, cell: Some(context_cell) } + } + + /// Adds external raw Halo2 cell to `self.assigned_advices` and returns a new virtual cell that can be + /// used as a tag (but will not be re-assigned). The returned [ContextCell] will have `type_id` the `TypeId::of::()`. + pub fn load_external_cell(&mut self, cell: Cell) -> ContextCell { + self.load_external_cell_impl(Some(cell)) + } + + /// Mock to load an external cell for base circuit simulation. If any mock external cell is loaded, calling [assign_raw] will panic. + pub fn mock_external_assigned(&mut self, v: F) -> AssignedValue { + let context_cell = self.load_external_cell_impl(None); + AssignedValue { value: Assigned::Trivial(v), cell: Some(context_cell) } + } + + fn load_external_cell_impl(&mut self, cell: Option) -> ContextCell { + let context_cell = ContextCell::new(TypeId::of::(), 0, self.external_cell_count); + self.external_cell_count += 1; + if let Some(cell) = cell { + self.assigned_advices.insert(context_cell, cell); + } + context_cell + } + + /// Clears state + pub fn clear(&mut self) { + self.advice_equalities.clear(); + self.constant_equalities.clear(); + self.assigned_advices.clear(); + self.assigned_constants.clear(); + self.external_cell_count = 0; + self.assigned.take(); + } +} + +impl Drop for CopyConstraintManager { + fn drop(&mut self) { + if self.assigned.get().is_some() { + return; + } + if !self.advice_equalities.is_empty() { + dbg!("WARNING: advice_equalities not empty"); + } + if !self.constant_equalities.is_empty() { + dbg!("WARNING: constant_equalities not empty"); + } + } +} + +impl VirtualRegionManager for SharedCopyConstraintManager { + // The fixed columns + type Config = Vec>; + + /// This should be the last manager to be assigned, after all other managers have assigned cells. + fn assign_raw(&self, config: &Self::Config, region: &mut Region) -> Self::Assignment { + let mut guard = self.lock().unwrap(); + let manager = guard.deref_mut(); + // sort by constant so constant assignment order is deterministic + // this is necessary because constants can be assigned by multiple CPU threads + // We further sort by ContextCell because the backend implementation of `raw_constrain_equal` (permutation argument) seems to depend on the order you specify copy constraints... + manager + .constant_equalities + .par_sort_unstable_by(|(c1, cell1), (c2, cell2)| c1.cmp(c2).then(cell1.cmp(cell2))); + // Assign fixed cells, we go left to right, then top to bottom, to avoid needing to know number of rows here + let mut fixed_col = 0; + let mut fixed_offset = 0; + for (c, _) in manager.constant_equalities.iter() { + if manager.assigned_constants.get(c).is_none() { + // this will panic if you run out of rows + let cell = raw_assign_fixed(region, config[fixed_col], fixed_offset, *c); + manager.assigned_constants.insert(*c, cell); + fixed_col += 1; + if fixed_col >= config.len() { + fixed_col = 0; + fixed_offset += 1; + } + } + } + + // Just in case: we sort by ContextCell because the backend implementation of `raw_constrain_equal` (permutation argument) seems to depend on the order you specify copy constraints... + manager.advice_equalities.par_sort_unstable(); + // Impose equality constraints between assigned advice cells + // At this point we assume all cells have been assigned by other VirtualRegionManagers + for (left, right) in &manager.advice_equalities { + let left = manager.assigned_advices.get(left).expect("virtual cell not assigned"); + let right = manager.assigned_advices.get(right).expect("virtual cell not assigned"); + raw_constrain_equal(region, *left, *right); + } + for (left, right) in &manager.constant_equalities { + let left = manager.assigned_constants[left]; + let right = manager.assigned_advices.get(right).expect("virtual cell not assigned"); + raw_constrain_equal(region, left, *right); + } + // We can't clear advice_equalities and constant_equalities because keygen_vk and keygen_pk will call this function twice + let _ = manager.assigned.set(()); + // When keygen_vk and keygen_pk are both run, you need to clear assigned constants + // so the second run still assigns constants in the pk + manager.assigned_constants.clear(); + } +} diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs new file mode 100644 index 00000000..bf82f211 --- /dev/null +++ b/halo2-base/src/virtual_region/lookups.rs @@ -0,0 +1,152 @@ +use std::collections::BTreeMap; +use std::sync::{Arc, Mutex, OnceLock}; + +use getset::{CopyGetters, Getters, Setters}; + +use crate::ff::Field; +use crate::halo2_proofs::{ + circuit::{Region, Value}, + plonk::{Advice, Column}, +}; +use crate::utils::halo2::raw_assign_advice; +use crate::{AssignedValue, ContextTag}; + +use super::copy_constraints::SharedCopyConstraintManager; +use super::manager::VirtualRegionManager; + +/// A manager that can be used for any lookup argument. This manager automates +/// the process of copying cells to designed advice columns with lookup enabled. +/// It also manages how many such advice columns are necessary. +/// +/// ## Detailed explanation +/// If we have a lookup argument that uses `ADVICE_COLS` advice columns and `TABLE_COLS` table columns, where +/// the table is either fixed or dynamic (advice), then we want to dynamically allocate chunks of `ADVICE_COLS` columns +/// that have the lookup into the table **always on** so that: +/// - every time we want to lookup [_; ADVICE_COLS] values, we copy them over to a row in the special +/// lookup-enabled advice columns. +/// - note that just for assignment, we don't need to know anything about the table itself. +/// Note: the manager does not need to know the value of `TABLE_COLS`. +/// +/// We want this manager to be CPU thread safe, while ensuring that the resulting circuit is +/// deterministic -- the order in which the cells to lookup are added matters. +/// The current solution is to tag the cells to lookup with the context id from the [Context] in which +/// it was called, and add virtual cells sequentially to buckets labelled by id. +/// The virtual cells will be assigned to physical cells sequentially by id. +/// We use a `BTreeMap` for the buckets instead of sorting to cells, to ensure that the order of the cells +/// within a bucket is deterministic. +/// The assumption is that the [Context] is thread-local. +/// +/// Cheap to clone across threads because everything is in [Arc]. +#[derive(Clone, Debug, Getters, CopyGetters, Setters)] +pub struct LookupAnyManager { + /// Shared cells to lookup, tagged by (type id, context id). + #[allow(clippy::type_complexity)] + pub cells_to_lookup: Arc; ADVICE_COLS]>>>>, + /// Global shared copy manager + #[getset(get = "pub", set = "pub")] + copy_manager: SharedCopyConstraintManager, + /// Specify whether constraints should be imposed for additional safety. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// Flag for whether `assign_raw` has been called, for safety only. + pub(crate) assigned: Arc>, +} + +impl LookupAnyManager { + /// Creates a new [LookupAnyManager] with a given copy manager. + pub fn new(witness_gen_only: bool, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + witness_gen_only, + cells_to_lookup: Default::default(), + copy_manager, + assigned: Default::default(), + } + } + + /// Add a lookup argument to the manager. + pub fn add_lookup(&self, tag: ContextTag, cells: [AssignedValue; ADVICE_COLS]) { + self.cells_to_lookup + .lock() + .unwrap() + .entry(tag) + .and_modify(|thread| thread.push(cells)) + .or_insert(vec![cells]); + } + + /// The total number of virtual rows needed to special lookups + pub fn total_rows(&self) -> usize { + self.cells_to_lookup.lock().unwrap().iter().flat_map(|(_, advices)| advices).count() + } + + /// The optimal number of `ADVICE_COLS` chunks of advice columns with lookup enabled for this + /// particular lookup argument that we should allocate. + pub fn num_advice_chunks(&self, usable_rows: usize) -> usize { + let total = self.total_rows(); + (total + usable_rows - 1) / usable_rows + } + + /// Clears state + pub fn clear(&mut self) { + self.cells_to_lookup.lock().unwrap().clear(); + self.copy_manager.lock().unwrap().clear(); + self.assigned = Arc::new(OnceLock::new()); + } + + /// Deep clone with the specified copy manager. Unsets `assigned`. + pub fn deep_clone(&self, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + witness_gen_only: self.witness_gen_only, + cells_to_lookup: Arc::new(Mutex::new(self.cells_to_lookup.lock().unwrap().clone())), + copy_manager, + assigned: Default::default(), + } + } +} + +impl Drop for LookupAnyManager { + /// Sanity checks whether the manager has assigned cells to lookup, + /// to prevent user error. + fn drop(&mut self) { + if Arc::strong_count(&self.cells_to_lookup) > 1 { + return; + } + if self.total_rows() > 0 && self.assigned.get().is_none() { + dbg!("WARNING: LookupAnyManager was not assigned!"); + } + } +} + +impl VirtualRegionManager + for LookupAnyManager +{ + type Config = Vec<[Column; ADVICE_COLS]>; + + fn assign_raw(&self, config: &Self::Config, region: &mut Region) { + let cells_to_lookup = self.cells_to_lookup.lock().unwrap(); + // Copy the cells to the config columns, going left to right, then top to bottom. + // Will panic if out of rows + let mut lookup_offset = 0; + let mut lookup_col = 0; + for advices in cells_to_lookup.iter().flat_map(|(_, advices)| advices) { + if lookup_col >= config.len() { + lookup_col = 0; + lookup_offset += 1; + } + for (advice, &column) in advices.iter().zip(config[lookup_col].iter()) { + let bcell = + raw_assign_advice(region, column, lookup_offset, Value::known(advice.value)); + if !self.witness_gen_only { + let ctx_cell = advice.cell.unwrap(); + let copy_manager = self.copy_manager.lock().unwrap(); + let acell = + copy_manager.assigned_advices.get(&ctx_cell).expect("cell not assigned"); + region.constrain_equal(*acell, bcell.cell()); + } + } + + lookup_col += 1; + } + // We cannot clear `cells_to_lookup` because keygen_vk and keygen_pk both call this function + let _ = self.assigned.set(()); + } +} diff --git a/halo2-base/src/virtual_region/manager.rs b/halo2-base/src/virtual_region/manager.rs new file mode 100644 index 00000000..4abc5875 --- /dev/null +++ b/halo2-base/src/virtual_region/manager.rs @@ -0,0 +1,16 @@ +use crate::ff::Field; +use crate::halo2_proofs::circuit::Region; + +/// A virtual region manager is responsible for managing a virtual region and assigning the +/// virtual region to a physical Halo2 region. +/// +pub trait VirtualRegionManager { + /// The Halo2 config with associated columns and gates describing the physical Halo2 region + /// that this virtual region manager is responsible for. + type Config: Clone; + /// Return type of the `assign_raw` method. Default is `()`. + type Assignment = (); + + /// Assign virtual region this is in charge of to the raw region described by `config`. + fn assign_raw(&self, config: &Self::Config, region: &mut Region) -> Self::Assignment; +} diff --git a/halo2-base/src/virtual_region/mod.rs b/halo2-base/src/virtual_region/mod.rs new file mode 100644 index 00000000..47d4bbf4 --- /dev/null +++ b/halo2-base/src/virtual_region/mod.rs @@ -0,0 +1,15 @@ +//! Trait describing the shared properties for a struct that is in charge of managing a virtual region of a circuit +//! _and_ assigning that virtual region to a "raw" Halo2 region in the "physical" circuit. +//! +//! Currently a raw region refers to a subset of columns of the circuit, and spans all rows (so it is a vertical region), +//! but this is not a requirement of the trait. + +/// Shared copy constraints across different virtual regions +pub mod copy_constraints; +/// Virtual region manager for lookup tables +pub mod lookups; +/// Virtual region manager +pub mod manager; + +#[cfg(test)] +mod tests; diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs new file mode 100644 index 00000000..66df4085 --- /dev/null +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -0,0 +1,212 @@ +use crate::halo2_proofs::{ + arithmetic::Field, + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + halo2curves::bn256::Fr, + plonk::{keygen_pk, keygen_vk, Advice, Circuit, Column, ConstraintSystem, Error}, + poly::Rotation, +}; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use test_log::test; + +use crate::{ + gates::{ + flex_gate::{threads::SinglePhaseCoreManager, FlexGateConfig, FlexGateConfigParams}, + GateChip, GateInstructions, + }, + utils::{ + fs::gen_srs, + halo2::raw_assign_advice, + testing::{check_proof, gen_proof}, + ScalarField, + }, + virtual_region::{lookups::LookupAnyManager, manager::VirtualRegionManager}, +}; + +#[derive(Clone, Debug)] +struct RAMConfig { + cpu: FlexGateConfig, + copy: Vec<[Column; 2]>, + // dynamic lookup table + memory: [Column; 2], +} + +#[derive(Clone, Default)] +struct RAMConfigParams { + cpu: FlexGateConfigParams, + copy_columns: usize, +} + +struct RAMCircuit { + // private memory input + memory: Vec, + // memory accesses + ptrs: [usize; CYCLES], + + cpu: SinglePhaseCoreManager, + ram: LookupAnyManager, + + params: RAMConfigParams, +} + +impl RAMCircuit { + fn new( + memory: Vec, + ptrs: [usize; CYCLES], + params: RAMConfigParams, + witness_gen_only: bool, + ) -> Self { + let cpu = SinglePhaseCoreManager::new(witness_gen_only, Default::default()); + let ram = LookupAnyManager::new(witness_gen_only, cpu.copy_manager.clone()); + Self { memory, ptrs, cpu, ram, params } + } + + fn compute(&mut self) { + let gate = GateChip::default(); + let ctx = self.cpu.main(); + let mut sum = ctx.load_constant(F::ZERO); + for &ptr in &self.ptrs { + let value = self.memory[ptr]; + let ptr = ctx.load_witness(F::from(ptr as u64 + 1)); + let value = ctx.load_witness(value); + self.ram.add_lookup((ctx.type_id(), ctx.id()), [ptr, value]); + sum = gate.add(ctx, sum, value); + } + } +} + +impl Circuit for RAMCircuit { + type Config = RAMConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = RAMConfigParams; + + fn params(&self) -> Self::Params { + self.params.clone() + } + + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let k = params.cpu.k; + let mut cpu = FlexGateConfig::configure(meta, params.cpu); + let copy: Vec<_> = (0..params.copy_columns) + .map(|_| { + [(); 2].map(|_| { + let advice = meta.advice_column(); + meta.enable_equality(advice); + advice + }) + }) + .collect(); + let mem = [meta.advice_column(), meta.advice_column()]; + + for copy in © { + meta.lookup_any("dynamic memory lookup table", |meta| { + let mem = mem.map(|c| meta.query_advice(c, Rotation::cur())); + let copy = copy.map(|c| meta.query_advice(c, Rotation::cur())); + vec![(copy[0].clone(), mem[0].clone()), (copy[1].clone(), mem[1].clone())] + }); + } + log::info!("Poisoned rows: {}", meta.minimum_rows()); + cpu.max_rows = (1 << k) - meta.minimum_rows(); + + RAMConfig { cpu, copy, memory: mem } + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "RAM Circuit", + |mut region| { + // Raw assign the private memory inputs + for (i, &value) in self.memory.iter().enumerate() { + // I think there will always be (0, 0) in the table so we index starting from 1 + let idx = Value::known(F::from(i as u64 + 1)); + raw_assign_advice(&mut region, config.memory[0], i, idx); + raw_assign_advice(&mut region, config.memory[1], i, Value::known(value)); + } + self.cpu.assign_raw( + &(config.cpu.basic_gates[0].clone(), config.cpu.max_rows), + &mut region, + ); + self.ram.assign_raw(&config.copy, &mut region); + self.cpu.copy_manager.assign_raw(&config.cpu.constants, &mut region); + Ok(()) + }, + ) + } +} + +#[test] +fn test_ram_mock() { + let k = 5u32; + const CYCLES: usize = 50; + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 16usize; + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let usable_rows = 2usize.pow(k) - 11; // guess + let copy_columns = CYCLES / usable_rows + 1; + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + // auto-configuration stuff + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.copy_columns = copy_columns; + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_ram_prover() { + let k = 10u32; + const CYCLES: usize = 2000; + + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 500; + + let memory = vec![Fr::ZERO; mem_len]; + let ptrs = [0; CYCLES]; + + let usable_rows = 2usize.pow(k) - 11; // guess + let copy_columns = CYCLES / usable_rows + 1; + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.copy_columns = copy_columns; + + let params = gen_srs(k); + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let circuit_params = circuit.params(); + let break_points = circuit.cpu.break_points.borrow().clone().unwrap(); + drop(circuit); + + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let mut circuit = RAMCircuit::new(memory, ptrs, circuit_params, true); + *circuit.cpu.break_points.borrow_mut() = Some(break_points); + circuit.compute(); + + let proof = gen_proof(¶ms, &pk, circuit); + check_proof(¶ms, pk.get_vk(), &proof, true); +} diff --git a/halo2-base/src/virtual_region/tests/lookups/mod.rs b/halo2-base/src/virtual_region/tests/lookups/mod.rs new file mode 100644 index 00000000..23635403 --- /dev/null +++ b/halo2-base/src/virtual_region/tests/lookups/mod.rs @@ -0,0 +1 @@ +mod memory; diff --git a/halo2-base/src/virtual_region/tests/mod.rs b/halo2-base/src/virtual_region/tests/mod.rs new file mode 100644 index 00000000..5b0a9bcb --- /dev/null +++ b/halo2-base/src/virtual_region/tests/mod.rs @@ -0,0 +1 @@ +mod lookups; diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 0c33e387..d3da45fc 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -1,51 +1,51 @@ [package] -name = "halo2-ecc" -version = "0.3.0" -edition = "2021" +name="halo2-ecc" +version="0.4.0" +edition="2021" [dependencies] -itertools = "=0.10" -num-bigint = { version = "=0.4", features = ["rand"] } -num-integer = "=0.1" -num-traits = "=0.2" -rand_core = { version = "=0.6", default-features = false, features = ["getrandom"] } -rand = "=0.8" -rand_chacha = "=0.3.1" -serde = { version = "=1.0", features = ["derive"] } -serde_json = "=1.0" -rayon = "=1.7" -test-case = "=3.1.0" - -# arithmetic -ff = "=0.12" -group = "=0.12" - -halo2-base = { path = "../halo2-base", default-features = false } +itertools="0.10" +num-bigint={ version="0.4", features=["rand"] } +num-integer="0.1" +num-traits="0.2" +rand_core={ version="0.6", default-features=false, features=["getrandom"] } +rand="0.8" +rand_chacha="0.3.1" +serde={ version="1.0", features=["derive"] } +serde_json="1.0" +rayon="1.6.1" +test-case="3.1.0" + +halo2-base={ path="../halo2-base", default-features=false } [dev-dependencies] -ark-std = { version = "=0.3.0", features = ["print-trace"] } -pprof = { version = "=0.11", features = ["criterion", "flamegraph"] } -criterion = "=0.4" -criterion-macro = "=0.4" -halo2-base = { path = "../halo2-base", default-features = false, features = ["test-utils"] } +ark-std={ version="0.3.0", features=["print-trace"] } +pprof={ version="0.11", features=["criterion", "flamegraph"] } +criterion="0.4" +criterion-macro="0.4" +halo2-base={ path="../halo2-base", default-features=false, features=["test-utils"] } +test-log="0.2.12" +env_logger="0.10.0" +pairing="0.23.0" [features] -default = ["jemallocator", "halo2-axiom", "display"] -dev-graph = ["halo2-base/dev-graph"] -display = ["halo2-base/display"] -halo2-pse = ["halo2-base/halo2-pse"] -halo2-axiom = ["halo2-base/halo2-axiom"] -jemallocator = ["halo2-base/jemallocator"] -mimalloc = ["halo2-base/mimalloc"] +default=["jemallocator", "halo2-axiom", "display"] +dev-graph=["halo2-base/dev-graph"] +display=["halo2-base/display"] +asm=["halo2-base/asm"] +halo2-pse=["halo2-base/halo2-pse"] +halo2-axiom=["halo2-base/halo2-axiom"] +jemallocator=["halo2-base/jemallocator"] +mimalloc=["halo2-base/mimalloc"] [[bench]] -name = "fp_mul" -harness = false +name="fp_mul" +harness=false [[bench]] -name = "msm" -harness = false +name="msm" +harness=false [[bench]] -name = "fixed_base_msm" -harness = false +name="fixed_base_msm" +harness=false diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index b4f3df25..1db118bb 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,21 +1,16 @@ -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, - RangeChip, -}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; -use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use halo2_base::{gates::RangeChip, utils::testing::gen_proof}; +use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -40,22 +35,19 @@ const BEST_100_CONFIG: MSMCircuitParams = const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; fn fixed_base_msm_bench( - builder: &mut GateThreadBuilder, + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let scalars_assigned = scalars - .iter() - .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) - .collect::>(); + let scalars_assigned = + scalars.iter().map(|scalar| vec![pool.main().load_witness(*scalar)]).collect::>(); - ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + ecc_chip.fixed_base_msm(pool, &bases, scalars_assigned, Fr::NUM_BITS as usize); } fn fixed_base_msm_circuit( @@ -63,31 +55,22 @@ fn fixed_base_msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = params.degree as usize; let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fixed_base_msm_bench(&mut builder, params, bases, scalars); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(params.lookup_bits), }; - end_timer!(start0); - circuit + let range = builder.range_chip(); + fixed_base_msm_bench(builder.pool(0), &range, params, bases, scalars); + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { @@ -101,12 +84,14 @@ fn bench(c: &mut Criterion) { vec![G1Affine::generator(); config.batch_size], vec![Fr::zero(); config.batch_size], None, + None, ); + let config_params = circuit.params(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); drop(circuit); let (bases, scalars): (Vec<_>, Vec<_>) = @@ -123,19 +108,11 @@ fn bench(c: &mut Criterion) { CircuitBuilderStage::Prover, bases.clone(), scalars.clone(), + Some(config_params.clone()), Some(break_points.clone()), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index 48351c45..0848ac5f 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -1,26 +1,22 @@ use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +use halo2_base::gates::{ + circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, + RangeChip, +}; use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, halo2_proofs::{ arithmetic::Field, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fq, Fr}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }, + utils::{testing::gen_proof, BigPrimeField}, Context, }; use halo2_ecc::fields::fp::FpChip; -use halo2_ecc::fields::{FieldChip, PrimeField}; +use halo2_ecc::fields::FieldChip; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -32,17 +28,15 @@ use pprof::criterion::{Output, PProfProfiler}; const K: u32 = 19; -fn fp_mul_bench( +fn fp_mul_bench( ctx: &mut Context, - lookup_bits: usize, + range: &RangeChip, limb_bits: usize, num_limbs: usize, _a: Fq, _b: Fq, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let chip = FpChip::::new(&range, limb_bits, num_limbs); + let chip = FpChip::::new(range, limb_bits, num_limbs); let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); for _ in 0..2857 { @@ -54,40 +48,36 @@ fn fp_mul_circuit( stage: CircuitBuilderStage, a: Fq, b: Fq, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = K as usize; + let lookup_bits = k - 1; let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) + } + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(lookup_bits), }; let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fp_mul_bench(builder.main(0), k - 1, 88, 3, a, b); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; + let range = builder.range_chip(); + fp_mul_bench(builder.main(0), &range, 88, 3, a, b); end_timer!(start0); - circuit + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { - let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None); + let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None, None); + let config_params = circuit.params(); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); let a = Fq::random(OsRng); let b = Fq::random(OsRng); @@ -98,19 +88,15 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk, a, b), |bencher, &(params, pk, a, b)| { bencher.iter(|| { - let circuit = - fp_mul_circuit(CircuitBuilderStage::Prover, a, b, Some(break_points.clone())); + let circuit = fp_mul_circuit( + CircuitBuilderStage::Prover, + a, + b, + Some(config_params.clone()), + Some(break_points.clone()), + ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 3a98ee38..e4668d13 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,21 +1,20 @@ use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; use halo2_base::gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, + circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, RangeChip, }; +use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; -use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use halo2_base::utils::testing::gen_proof; +use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -46,17 +45,16 @@ const BEST_100_CONFIG: MSMCircuitParams = MSMCircuitParams { const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; fn msm_bench( - builder: &mut GateThreadBuilder, + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let ctx = builder.main(0); + let ctx = pool.main(); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); let bases_assigned = bases @@ -64,13 +62,12 @@ fn msm_bench( .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) .collect::>(); - ecc_chip.variable_base_msm_in::( - builder, + ecc_chip.variable_base_msm_custom::( + pool, &bases_assigned, scalars_assigned, Fr::NUM_BITS as usize, params.clump_factor, - 0, ); } @@ -79,31 +76,24 @@ fn msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); let k = params.degree as usize; let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - msm_bench(&mut builder, params, bases, scalars); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(params.lookup_bits), }; + let range = builder.range_chip(); + msm_bench(builder.pool(0), &range, params, bases, scalars); end_timer!(start0); - circuit + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { @@ -117,12 +107,14 @@ fn bench(c: &mut Criterion) { vec![G1Affine::generator(); config.batch_size], vec![Fr::one(); config.batch_size], None, + None, ); + let config_params = circuit.params(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); drop(circuit); let (bases, scalars): (Vec<_>, Vec<_>) = @@ -139,19 +131,11 @@ fn bench(c: &mut Criterion) { CircuitBuilderStage::Prover, bases.clone(), scalars.clone(), + Some(config_params.clone()), Some(break_points.clone()), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config index 61db5d6d..fb4be34a 100644 --- a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config @@ -1,5 +1,2 @@ {"strategy":"Simple","degree":17,"num_advice":83,"num_lookup_advice":9,"num_fixed":7,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":5,"num_fixed":4,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":2,"num_fixed":2,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} \ No newline at end of file +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_msm.t.config b/halo2-ecc/configs/bn254/bench_msm.t.config index bd4c4318..f516d6cf 100644 --- a/halo2-ecc/configs/bn254/bench_msm.t.config +++ b/halo2-ecc/configs/bn254/bench_msm.t.config @@ -1,5 +1,2 @@ {"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} -{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} -{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} -{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"window_bits":4} \ No newline at end of file +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_pairing.t.config b/halo2-ecc/configs/bn254/bench_pairing.t.config index d76ebad1..ddaf65fa 100644 --- a/halo2-ecc/configs/bn254/bench_pairing.t.config +++ b/halo2-ecc/configs/bn254/bench_pairing.t.config @@ -1,5 +1 @@ -{"strategy":"Simple","degree":15,"num_advice":105,"num_lookup_advice":14,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} -{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} -{"strategy":"Simple","degree":18,"num_advice":13,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3} -{"strategy":"Simple","degree":20,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file +{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config b/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config new file mode 100644 index 00000000..33fb34d8 --- /dev/null +++ b/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":15,"num_advice":17,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/src/bigint/big_is_even.rs b/halo2-ecc/src/bigint/big_is_even.rs index 13f9e873..97fd10e9 100644 --- a/halo2-ecc/src/bigint/big_is_even.rs +++ b/halo2-ecc/src/bigint/big_is_even.rs @@ -1,7 +1,9 @@ use super::OverflowInteger; -use halo2_base::gates::GateInstructions; -use halo2_base::gates::RangeChip; -use halo2_base::{safe_types::RangeInstructions, utils::ScalarField, AssignedValue, Context}; +use halo2_base::{ + gates::{GateInstructions, RangeChip, RangeInstructions}, + utils::ScalarField, + AssignedValue, Context, +}; /// # Assumptions /// * `a` has nonzero number of limbs diff --git a/halo2-ecc/src/bigint/carry_mod.rs b/halo2-ecc/src/bigint/carry_mod.rs index a78fd32b..a9667d79 100644 --- a/halo2-ecc/src/bigint/carry_mod.rs +++ b/halo2-ecc/src/bigint/carry_mod.rs @@ -1,7 +1,7 @@ use std::{cmp::max, iter}; use halo2_base::{ - gates::{range::RangeStrategy, GateInstructions, RangeInstructions}, + gates::{GateInstructions, RangeInstructions}, utils::{decompose_bigint, BigPrimeField}, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, @@ -108,32 +108,27 @@ pub fn crt( ); // let gate_index = prod.column(); - let out_cell; - let check_cell; // perform step 2: compute prod - a + out let temp1 = *prod.value() - a_limb.value(); let check_val = temp1 + out_v; - match range.strategy() { - RangeStrategy::Vertical => { - // transpose of: - // | prod | -1 | a | prod - a | 1 | out | prod - a + out - // where prod is at relative row `offset` - ctx.assign_region( - [ - Constant(-F::one()), - Existing(a_limb), - Witness(temp1), - Constant(F::one()), - Witness(out_v), - Witness(check_val), - ], - [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call - ); - check_cell = ctx.last().unwrap(); - out_cell = ctx.get(-2); - } - } + // transpose of: + // | prod | -1 | a | prod - a | 1 | out | prod - a + out + // where prod is at relative row `offset` + ctx.assign_region( + [ + Constant(-F::ONE), + Existing(a_limb), + Witness(temp1), + Constant(F::ONE), + Witness(out_v), + Witness(check_val), + ], + [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call + ); + let check_cell = ctx.last().unwrap(); + let out_cell = ctx.get(-2); + quot_assigned.push(new_quot_cell); out_assigned.push(out_cell); check_assigned.push(check_cell); diff --git a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs index 6232cbdf..13523ba5 100644 --- a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs @@ -79,8 +79,8 @@ pub fn crt( // transpose of: // | prod | -1 | a | prod - a | let check_val = *prod.value() - a_limb.value(); - let check_cell = ctx - .assign_region_last([Constant(-F::one()), Existing(a_limb), Witness(check_val)], [-1]); + let check_cell = + ctx.assign_region_last([Constant(-F::ONE), Existing(a_limb), Witness(check_val)], [-1]); quot_assigned.push(new_quot_cell); check_assigned.push(check_cell); @@ -119,7 +119,7 @@ pub fn crt( // Check `0 + modulus * quotient - a = 0` in native field // | 0 | modulus | quotient | a | ctx.assign_region( - [Constant(F::zero()), Constant(mod_native), Existing(quot_native), Existing(a.native)], + [Constant(F::ZERO), Constant(mod_native), Existing(quot_native), Existing(a.native)], [0], ); } diff --git a/halo2-ecc/src/bigint/check_carry_to_zero.rs b/halo2-ecc/src/bigint/check_carry_to_zero.rs index fa2f5648..d445f7e5 100644 --- a/halo2-ecc/src/bigint/check_carry_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_to_zero.rs @@ -62,14 +62,14 @@ pub fn truncate( // let num_windows = (k - 1) / window + 1; // = ((k - 1) - (window - 1) + window - 1) / window + 1; let mut previous = None; - for (a_limb, carry) in a.limbs.into_iter().zip(carries.into_iter()) { + for (a_limb, carry) in a.limbs.into_iter().zip(carries) { let neg_carry_val = bigint_to_fe(&-carry); ctx.assign_region( [ Existing(a_limb), Witness(neg_carry_val), Constant(limb_base), - previous.map(Existing).unwrap_or_else(|| Constant(F::zero())), + previous.map(Existing).unwrap_or_else(|| Constant(F::ZERO)), ], [0], ); diff --git a/halo2-ecc/src/bigint/sub.rs b/halo2-ecc/src/bigint/sub.rs index 8b2263f9..c8a18433 100644 --- a/halo2-ecc/src/bigint/sub.rs +++ b/halo2-ecc/src/bigint/sub.rs @@ -46,7 +46,7 @@ pub fn assign( Existing(lt), Constant(limb_base), Witness(a_with_borrow_val), - Constant(-F::one()), + Constant(-F::ONE), Existing(bottom), Witness(out_val), ], diff --git a/halo2-ecc/src/bn254/bls_signature.rs b/halo2-ecc/src/bn254/bls_signature.rs index 0e10a090..b3cf876b 100644 --- a/halo2-ecc/src/bn254/bls_signature.rs +++ b/halo2-ecc/src/bn254/bls_signature.rs @@ -4,18 +4,18 @@ use super::pairing::PairingChip; use super::{Fp12Chip, Fp2Chip, FpChip}; use crate::ecc::EccChip; use crate::fields::FieldChip; -use crate::fields::PrimeField; use crate::halo2_proofs::halo2curves::bn256::Fq12; use crate::halo2_proofs::halo2curves::bn256::{G1Affine, G2Affine}; +use halo2_base::utils::BigPrimeField; use halo2_base::{AssignedValue, Context}; // To avoid issues with mutably borrowing twice (not allowed in Rust), we only store fp_chip and construct g2_chip and fp12_chip in scope when needed for temporary mutable borrows -pub struct BlsSignatureChip<'chip, F: PrimeField> { +pub struct BlsSignatureChip<'chip, F: BigPrimeField> { pub fp_chip: &'chip FpChip<'chip, F>, pub pairing_chip: &'chip PairingChip<'chip, F>, } -impl<'chip, F: PrimeField> BlsSignatureChip<'chip, F> { +impl<'chip, F: BigPrimeField> BlsSignatureChip<'chip, F> { pub fn new(fp_chip: &'chip FpChip, pairing_chip: &'chip PairingChip) -> Self { Self { fp_chip, pairing_chip } } diff --git a/halo2-ecc/src/bn254/final_exp.rs b/halo2-ecc/src/bn254/final_exp.rs index 7959142e..ae2ecac9 100644 --- a/halo2-ecc/src/bn254/final_exp.rs +++ b/halo2-ecc/src/bn254/final_exp.rs @@ -5,14 +5,19 @@ use crate::halo2_proofs::{ }; use crate::{ ecc::get_naf, - fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip, PrimeField}, + fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip}, +}; +use halo2_base::{ + gates::GateInstructions, + utils::{modulus, BigPrimeField}, + Context, + QuantumCell::Constant, }; -use halo2_base::{gates::GateInstructions, utils::modulus, Context, QuantumCell::Constant}; use num_bigint::BigUint; const XI_0: i64 = 9; -impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { +impl<'chip, F: BigPrimeField> Fp12Chip<'chip, F> { // computes a ** (p ** power) // only works for p = 3 (mod 4) and p = 1 (mod 6) pub fn frobenius_map( @@ -172,8 +177,8 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { // compute `g0 + 1` g0[0].truncation.limbs[0] = - fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::one())); - g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::one())); + fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::ONE)); + g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::ONE)); g0[0].truncation.max_limb_bits += 1; g0[0].value += 1usize; diff --git a/halo2-ecc/src/bn254/pairing.rs b/halo2-ecc/src/bn254/pairing.rs index 886985d4..dbd7382f 100644 --- a/halo2-ecc/src/bn254/pairing.rs +++ b/halo2-ecc/src/bn254/pairing.rs @@ -7,8 +7,9 @@ use crate::halo2_proofs::halo2curves::bn256::{ use crate::{ ecc::{EcPoint, EccChip}, fields::fp12::mul_no_carry_w6, - fields::{FieldChip, PrimeField}, + fields::FieldChip, }; +use halo2_base::utils::BigPrimeField; use halo2_base::Context; const XI_0: i64 = 9; @@ -21,7 +22,7 @@ const XI_0: i64 = 9; // line_{Psi(Q0), Psi(Q1)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals w^3 (y_1 - y_2) X + w^2 (x_2 - x_1) Y + w^5 (x_1 y_2 - x_2 y_1) =: out3 * w^3 + out2 * w^2 + out5 * w^5 where out2, out3, out5 are Fp2 points // Output is [None, None, out2, out3, None, out5] as vector of `Option`s -pub fn sparse_line_function_unequal( +pub fn sparse_line_function_unequal( fp2_chip: &Fp2Chip, ctx: &mut Context, Q: (&EcPoint>, &EcPoint>), @@ -60,7 +61,7 @@ pub fn sparse_line_function_unequal( // line_{Psi(Q), Psi(Q)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals (3x^3 - 2y^2)(XI_0 + u) + w^4 (-3 x^2 * Q.x) + w^3 (2 y * Q.y) =: out0 + out4 * w^4 + out3 * w^3 where out0, out3, out4 are Fp2 points // Output is [out0, None, None, out3, out4, None] as vector of `Option`s -pub fn sparse_line_function_equal( +pub fn sparse_line_function_equal( fp2_chip: &Fp2Chip, ctx: &mut Context, Q: &EcPoint>, @@ -95,7 +96,7 @@ pub fn sparse_line_function_equal( // multiply Fp12 point `a` with Fp12 point `b` where `b` is len 6 vector of Fp2 points, where some are `None` to represent zero. // Assumes `b` is not vector of all `None`s -pub fn sparse_fp12_multiply( +pub fn sparse_fp12_multiply( fp2_chip: &Fp2Chip, ctx: &mut Context, a: &FqPoint, @@ -162,7 +163,7 @@ pub fn sparse_fp12_multiply( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q0), Psi(Q1)}(P) as Fp12 point -pub fn fp12_multiply_with_line_unequal( +pub fn fp12_multiply_with_line_unequal( fp2_chip: &Fp2Chip, ctx: &mut Context, g: &FqPoint, @@ -179,7 +180,7 @@ pub fn fp12_multiply_with_line_unequal( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q), Psi(Q)}(P) as Fp12 point -pub fn fp12_multiply_with_line_equal( +pub fn fp12_multiply_with_line_equal( fp2_chip: &Fp2Chip, ctx: &mut Context, g: &FqPoint, @@ -208,7 +209,7 @@ pub fn fp12_multiply_with_line_equal( // - `0 <= loop_count < r` and `loop_count < p` (to avoid [loop_count]Q' = Frob_p(Q')) // - x^3 + b = 0 has no solution in Fp2, i.e., the y-coordinate of Q cannot be 0. -pub fn miller_loop_BN( +pub fn miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, Q: &EcPoint>, @@ -294,7 +295,7 @@ pub fn miller_loop_BN( // let pairs = [(a_i, b_i)], a_i in G_1, b_i in G_2 // output is Prod_i e'(a_i, b_i), where e'(a_i, b_i) is the output of `miller_loop_BN(b_i, a_i)` -pub fn multi_miller_loop_BN( +pub fn multi_miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, pairs: Vec<(&EcPoint>, &EcPoint>)>, @@ -397,7 +398,7 @@ pub fn multi_miller_loop_BN( // - coeff[1][2], coeff[1][3] as assigned cells: this is an optimization to avoid loading new constants // Output: // - (coeff[1][2] * x^p, coeff[1][3] * y^p) point in E(Fp2) -pub fn twisted_frobenius( +pub fn twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, Q: impl Into>>, @@ -423,7 +424,7 @@ pub fn twisted_frobenius( // - Q = (x, y) point in E(Fp2) // Output: // - (coeff[1][2] * x^p, coeff[1][3] * -y^p) point in E(Fp2) -pub fn neg_twisted_frobenius( +pub fn neg_twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, Q: impl Into>>, @@ -444,11 +445,11 @@ pub fn neg_twisted_frobenius( } // To avoid issues with mutably borrowing twice (not allowed in Rust), we only store fp_chip and construct g2_chip and fp12_chip in scope when needed for temporary mutable borrows -pub struct PairingChip<'chip, F: PrimeField> { +pub struct PairingChip<'chip, F: BigPrimeField> { pub fp_chip: &'chip FpChip<'chip, F>, } -impl<'chip, F: PrimeField> PairingChip<'chip, F> { +impl<'chip, F: BigPrimeField> PairingChip<'chip, F> { pub fn new(fp_chip: &'chip FpChip) -> Self { Self { fp_chip } } diff --git a/halo2-ecc/src/bn254/tests/bls_signature.rs b/halo2-ecc/src/bn254/tests/bls_signature.rs index 115ab8ed..8475f677 100644 --- a/halo2-ecc/src/bn254/tests/bls_signature.rs +++ b/halo2-ecc/src/bn254/tests/bls_signature.rs @@ -4,25 +4,18 @@ use std::{ }; use super::*; -use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; +use crate::{ + bn254::bls_signature::BlsSignatureChip, fields::FpStrategy, + halo2_proofs::halo2curves::bn256::G2Affine, +}; use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::{ - halo2curves::{ - bn256::{multi_miller_loop, G2Prepared, Gt}, - pairing::MillerLoopResult, - }, - poly::kzg::multiopen::{ProverGWC, VerifierGWC}, - }, - utils::fs::gen_srs, + gates::RangeChip, + halo2_proofs::halo2curves::bn256::{multi_miller_loop, G2Prepared, Gt}, + utils::BigPrimeField, Context, }; +extern crate pairing; +use pairing::{group::ff::Field, MillerLoopResult}; use rand_core::OsRng; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -39,8 +32,9 @@ struct BlsSignatureCircuitParams { } /// Verify e(g1, signature_agg) = e(pubkey_agg, H(m)) -fn bls_signature_test( +fn bls_signature_test( ctx: &mut Context, + range: &RangeChip, params: BlsSignatureCircuitParams, g1: G1Affine, signatures: &[G2Affine], @@ -48,23 +42,21 @@ fn bls_signature_test( msghash: G2Affine, ) { // Calculate halo2 pairing by multipairing - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let pairing_chip = PairingChip::new(&fp_chip); let bls_signature_chip = BlsSignatureChip::new(&fp_chip, &pairing_chip); let result = bls_signature_chip.bls_signature_verify(ctx, g1, signatures, pubkeys, msghash); // Calculate non-halo2 pairing by multipairing let mut signatures_g2: G2Affine = signatures[0]; - for i in 1..signatures.len() { - signatures_g2 = (signatures_g2 + signatures[i]).into(); + for sig in signatures.iter().skip(1) { + signatures_g2 = (signatures_g2 + sig).into(); } let signature_g2_prepared = G2Prepared::from(signatures_g2); let mut pubkeys_g1: G1Affine = pubkeys[0]; - for i in 1..signatures.len() { - pubkeys_g1 = (pubkeys_g1 + pubkeys[i]).into(); + for pubkey in pubkeys.iter().skip(1) { + pubkeys_g1 = (pubkeys_g1 + pubkey).into(); } let pubkey_aggregated = pubkeys_g1; @@ -77,72 +69,37 @@ fn bls_signature_test( assert_eq!(*result.value(), F::from(actual_result == Gt::identity())) } -fn random_bls_signature_circuit( - params: BlsSignatureCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - assert!(params.num_aggregation > 0, "Cannot aggregate 0 signatures!"); - - // TODO: Implement hash_to_curve(msg) for arbitrary message +#[test] +fn test_bls_signature() { + let run_path = "configs/bn254/bls_signature_circuit.config"; + let path = run_path; + let params: BlsSignatureCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + println!("num_advice: {num_advice}", num_advice = params.num_advice); + let msg_hash = G2Affine::random(OsRng); let g1 = G1Affine::generator(); - - let mut sks: Vec = Vec::new(); let mut signatures: Vec = Vec::new(); let mut pubkeys: Vec = Vec::new(); - for _ in 0..params.num_aggregation { let sk = Fr::random(OsRng); let signature = G2Affine::from(msg_hash * sk); let pubkey = G1Affine::from(G1Affine::generator() * sk); - sks.push(sk); signatures.push(signature); pubkeys.push(pubkey); } - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - bls_signature_test::(builder.main(0), params, g1, &signatures, &pubkeys, msg_hash); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit -} - -#[test] -fn test_bls_signature() { - let run_path = "configs/bn254/bls_signature_circuit.config"; - let path = run_path; - let params: BlsSignatureCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - println!("num_advice: {num_advice}", num_advice = params.num_advice); - let circuit = random_bls_signature_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + // signatures: &[G2Affine], pubkeys: &[G1Affine], msghash: G2Affine) + bls_signature_test(ctx, range, params, g1, &signatures, &pubkeys, msg_hash); + }) } #[test] fn bench_bls_signature() -> Result<(), Box> { - let rng = OsRng; let config_path = "configs/bn254/bench_bls_signature.config"; let bench_params_file = File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); @@ -160,70 +117,34 @@ fn bench_bls_signature() -> Result<(), Box> { let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - let circuit = random_bls_signature_circuit(bench_params, CircuitBuilderStage::Keygen, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = random_bls_signature_circuit( - bench_params, - CircuitBuilderStage::Prover, - Some(break_points), + let msg_hash = G2Affine::random(OsRng); + let g1 = G1Affine::generator(); + let mut signatures: Vec = Vec::new(); + let mut pubkeys: Vec = Vec::new(); + for _ in 0..bench_params.num_aggregation { + let sk = Fr::random(OsRng); + let signature = G2Affine::from(msg_hash * sk); + let pubkey = G1Affine::from(G1Affine::generator() * sk); + + signatures.push(signature); + pubkeys.push(pubkey); + } + + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + (g1, signatures.clone(), pubkeys.clone(), msg_hash), + (g1, signatures, pubkeys, msg_hash), + |pool, range, (g1, signatures, pubkeys, msg_hash)| { + bls_signature_test( + pool.main(), + range, + bench_params, + g1, + &signatures, + &pubkeys, + msg_hash, + ); + }, ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/bls_signature_bn254_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.num_aggregation - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierGWC<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); writeln!( fs_results, @@ -236,9 +157,9 @@ fn bench_bls_signature() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.num_aggregation, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index a902ce3c..1df235f1 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -4,11 +4,11 @@ use std::io::{BufRead, BufReader}; use super::*; use crate::fields::{FieldChip, FpStrategy}; +use crate::group::cofactor::CofactorCurveAffine; use crate::halo2_proofs::halo2curves::bn256::G2Affine; -use group::cofactor::CofactorCurveAffine; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; -use halo2_base::utils::fs::gen_srs; +use halo2_base::utils::testing::base_test; +use halo2_base::utils::BigPrimeField; use halo2_base::Context; use itertools::Itertools; use rand_core::OsRng; @@ -26,10 +26,13 @@ struct CircuitParams { batch_size: usize, } -fn g2_add_test(ctx: &mut Context, params: CircuitParams, _points: Vec) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); +fn g2_add_test( + ctx: &mut Context, + range: &RangeChip, + params: CircuitParams, + _points: Vec, +) { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let fp2_chip = Fp2Chip::::new(&fp_chip); let g2_chip = EccChip::new(&fp2_chip); @@ -56,12 +59,10 @@ fn test_ec_add() { let k = params.degree; let points = (0..params.batch_size).map(|_| G2Affine::random(OsRng)).collect_vec(); - let mut builder = GateThreadBuilder::::mock(); - g2_add_test(builder.main(0), params, points); - - builder.config(k as usize, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + base_test() + .k(k) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| g2_add_test(ctx, range, params, points)); } #[test] @@ -83,84 +84,13 @@ fn bench_ec_add() -> Result<(), Box> { println!("---------------------- degree = {k} ------------------------------",); let mut rng = OsRng; - let params_time = start_timer!(|| "Params construction"); - let params = gen_srs(k); - end_timer!(params_time); - - let start0 = start_timer!(|| "Witness generation for empty circuit"); - let circuit = { - let points = vec![G2Affine::generator(); bench_params.batch_size]; - let mut builder = GateThreadBuilder::::keygen(); - g2_add_test(builder.main(0), bench_params, points); - builder.config(k as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - }; - end_timer!(start0); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - - // create a proof - let points = (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(); - let proof_time = start_timer!(|| "Proving time"); - let proof_circuit = { - let mut builder = GateThreadBuilder::::prover(); - g2_add_test(builder.main(0), bench_params, points); - builder.config(k as usize, Some(20)); - RangeCircuitBuilder::prover(builder, break_points) - }; - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[proof_circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/ec_add_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); - + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + vec![G2Affine::generator(); bench_params.batch_size], + (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(), + |pool, range, points| { + g2_add_test(pool.main(), range, bench_params, points); + }, + ); writeln!( fs_results, "{},{},{},{},{},{},{},{},{:?},{},{:?}", @@ -172,9 +102,9 @@ fn bench_ec_add() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index 0283f672..28466a80 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -3,57 +3,25 @@ use std::{ io::{BufRead, BufReader}, }; -use crate::fields::{FpStrategy, PrimeField}; +use crate::ff::{Field, PrimeField}; use super::*; -#[allow(unused_imports)] -use ff::PrimeField as _; -use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::halo2curves::bn256::G1, - utils::fs::gen_srs, -}; use itertools::Itertools; -use rand_core::OsRng; - -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct FixedMSMCircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - batch_size: usize, - radix: usize, - clump_factor: usize, -} -fn fixed_base_msm_test( - builder: &mut GateThreadBuilder, +pub fn fixed_base_msm_test( + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: FixedMSMCircuitParams, bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let scalars_assigned = scalars - .iter() - .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) - .collect::>(); + let scalars_assigned = + scalars.iter().map(|scalar| vec![pool.main().load_witness(*scalar)]).collect::>(); - let msm = ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + let msm = ecc_chip.fixed_base_msm(pool, &bases, scalars_assigned, Fr::NUM_BITS as usize); let mut elts: Vec = Vec::new(); for (base, scalar) in bases.iter().zip(scalars.iter()) { @@ -67,38 +35,6 @@ fn fixed_base_msm_test( assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -fn random_fixed_base_msm_circuit( - params: FixedMSMCircuitParams, - bases: Vec, // bases are fixed in vkey so don't randomly generate - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec(); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fixed_base_msm_test(&mut builder, params, bases, scalars); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit -} - #[test] fn test_fixed_base_msm() { let path = "configs/bn254/fixed_msm_circuit.config"; @@ -107,9 +43,12 @@ fn test_fixed_base_msm() { ) .unwrap(); - let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); - let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + let bases = (0..params.batch_size).map(|_| G1Affine::random(&mut rng)).collect_vec(); + let scalars = (0..params.batch_size).map(|_| Fr::random(&mut rng)).collect_vec(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, bases, scalars); + }); } #[test] @@ -119,14 +58,11 @@ fn test_fixed_msm_minus_1() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let base = G1Affine::random(OsRng); - let k = params.degree as usize; - let mut builder = GateThreadBuilder::mock(); - fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); - - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let rng = StdRng::seed_from_u64(0); + let base = G1Affine::random(rng); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, vec![base], vec![-Fr::one()]); + }); } #[test] @@ -141,89 +77,24 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: FixedMSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; + let batch_size = bench_params.batch_size; println!("---------------------- degree = {k} ------------------------------",); - let rng = OsRng; - let params = gen_srs(k); - println!("{bench_params:?}"); - - let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); - let circuit = random_fixed_base_msm_circuit( - bench_params, - bases.clone(), - CircuitBuilderStage::Keygen, - None, - ); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = random_fixed_base_msm_circuit( - bench_params, - bases, - CircuitBuilderStage::Prover, - Some(break_points), + let bases = (0..batch_size).map(|_| G1Affine::random(&mut rng)).collect_vec(); + let scalars = (0..batch_size).map(|_| Fr::random(&mut rng)).collect_vec(); + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + (bases.clone(), scalars.clone()), + (bases, scalars), + |pool, range, (bases, scalars)| { + fixed_base_msm_test(pool, range, bench_params, bases, scalars); + }, ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/ - msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); - writeln!( fs_results, "{},{},{},{},{},{},{},{},{:?},{},{:?}", @@ -235,9 +106,9 @@ fn bench_fixed_base_msm() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index 89aea571..ac3d5c0b 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,27 +1,20 @@ #![allow(non_snake_case)] -use super::bls_signature::BlsSignatureChip; use super::pairing::PairingChip; use super::*; -use crate::{ecc::EccChip, fields::PrimeField}; +use crate::ecc::EccChip; +use crate::group::Curve; use crate::{ fields::FpStrategy, - halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, - }, + halo2_proofs::halo2curves::bn256::{pairing, Fr, G1Affine}, }; -use ark_std::{end_timer, start_timer}; -use group::Curve; use halo2_base::utils::fe_to_biguint; +use halo2_base::{ + gates::{flex_gate::threads::SinglePhaseCoreManager, RangeChip}, + halo2_proofs::halo2curves::bn256::G1, + utils::testing::base_test, +}; +use rand::rngs::StdRng; +use rand_core::SeedableRng; use serde::{Deserialize, Serialize}; use std::io::Write; @@ -34,7 +27,7 @@ pub mod msm_sum_infinity_fixed_base; pub mod pairing; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { +pub struct MSMCircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -46,3 +39,18 @@ struct MSMCircuitParams { batch_size: usize, window_bits: usize, } + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct FixedMSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + radix: usize, + clump_factor: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index cfc7d40f..22ea8ee8 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,16 +1,4 @@ -use crate::fields::FpStrategy; -use ff::{Field, PrimeField}; -use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - utils::fs::gen_srs, -}; -use rand_core::OsRng; +use crate::ff::{Field, PrimeField}; use std::{ fs::{self, File}, io::{BufRead, BufReader}, @@ -18,33 +6,17 @@ use std::{ use super::*; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - batch_size: usize, - window_bits: usize, -} - -fn msm_test( - builder: &mut GateThreadBuilder, +pub fn msm_test( + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, - window_bits: usize, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let ctx = builder.main(0); + let ctx = pool.main(); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); let bases_assigned = bases @@ -52,13 +24,12 @@ fn msm_test( .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) .collect::>(); - let msm = ecc_chip.variable_base_msm_in::( - builder, + let msm = ecc_chip.variable_base_msm_custom::( + pool, &bases_assigned, scalars_assigned, Fr::NUM_BITS as usize, - window_bits, - 0, + params.window_bits, ); let msm_answer = bases @@ -75,36 +46,8 @@ fn msm_test( assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -fn random_msm_circuit( - params: MSMCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let (bases, scalars): (Vec<_>, Vec<_>) = - (0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip(); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - msm_test(&mut builder, params, bases, scalars, params.window_bits); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit +fn random_pairs(batch_size: usize, rng: &StdRng) -> (Vec, Vec) { + (0..batch_size).map(|_| (G1Affine::random(rng.clone()), Fr::random(rng.clone()))).unzip() } #[test] @@ -114,9 +57,10 @@ fn test_msm() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - - let circuit = random_msm_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let (bases, scalars) = random_pairs(params.batch_size, &StdRng::seed_from_u64(0)); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + msm_test(pool, range, params, bases, scalars); + }); } #[test] @@ -136,72 +80,16 @@ fn bench_msm() -> Result<(), Box> { let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let rng = OsRng; - let params = gen_srs(k); - println!("{bench_params:?}"); - - let circuit = random_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_msm_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - bench_params.window_bits + let (bases, scalars) = random_pairs(bench_params.batch_size, &StdRng::seed_from_u64(0)); + let stats = + base_test().k(bench_params.degree).lookup_bits(bench_params.lookup_bits).bench_builder( + (bases.clone(), scalars.clone()), + (bases, scalars), + |pool, range, (bases, scalars)| { + msm_test(pool, range, bench_params, bases, scalars); + }, ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); writeln!( fs_results, @@ -215,9 +103,9 @@ fn bench_msm() -> Result<(), Box> { bench_params.num_limbs, bench_params.batch_size, bench_params.window_bits, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed(), )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs index 600a4931..d053d196 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -1,134 +1,40 @@ -use ff::PrimeField; -use halo2_base::gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, - RangeChip, -}; -use rand_core::OsRng; use std::fs::File; -use super::*; +use super::{msm::msm_test, *}; -fn msm_test( - builder: &mut GateThreadBuilder, - params: MSMCircuitParams, - bases: Vec, - scalars: Vec, - window_bits: usize, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let ecc_chip = EccChip::new(&fp_chip); - - let ctx = builder.main(0); - let scalars_assigned = - scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); - let bases_assigned = bases - .iter() - .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) - .collect::>(); - - let msm = ecc_chip.variable_base_msm_in::( - builder, - &bases_assigned, - scalars_assigned, - Fr::NUM_BITS as usize, - window_bits, - 0, - ); - - let msm_answer = bases - .iter() - .zip(scalars.iter()) - .map(|(base, scalar)| base * scalar) - .reduce(|a, b| a + b) - .unwrap() - .to_affine(); - - let msm_x = msm.x.value(); - let msm_y = msm.y.value(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); -} - -fn custom_msm_circuit( - params: MSMCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, - bases: Vec, - scalars: Vec, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - msm_test(&mut builder, params, bases, scalars, params.window_bits); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit -} - -#[test] -fn test_msm1() { +fn run_test(scalars: Vec, bases: Vec) { let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( + let params: MSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - params.batch_size = 3; + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + msm_test(pool, range, params, bases, scalars); + }); +} - let random_point = G1Affine::random(OsRng); +#[test] +fn test_msm1() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm2() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 3; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm3() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![ random_point, random_point, @@ -136,20 +42,11 @@ fn test_msm3() { (random_point + random_point + random_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm4() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - let generator_point = G1Affine::generator(); let bases = vec![ generator_point, @@ -158,26 +55,15 @@ fn test_msm4() { (generator_point + generator_point + generator_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm5() { - // Very similar example that does not add to infinity. It works fine. - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs index 6cf96c7f..d10d8a7c 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -1,134 +1,40 @@ -use ff::PrimeField; -use halo2_base::gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, - RangeChip, -}; -use rand_core::OsRng; use std::fs::File; -use super::*; +use super::{fixed_base_msm::fixed_base_msm_test, *}; -fn msm_test( - builder: &mut GateThreadBuilder, - params: MSMCircuitParams, - bases: Vec, - scalars: Vec, - window_bits: usize, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let ecc_chip = EccChip::new(&fp_chip); - - let ctx = builder.main(0); - let scalars_assigned = - scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); - let bases_assigned = bases; - //.iter() - //.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) - //.collect::>(); - - let msm = ecc_chip.fixed_base_msm_in::( - builder, - &bases_assigned, - scalars_assigned, - Fr::NUM_BITS as usize, - window_bits, - 0, - ); - - let msm_answer = bases_assigned - .iter() - .zip(scalars.iter()) - .map(|(base, scalar)| base * scalar) - .reduce(|a, b| a + b) - .unwrap() - .to_affine(); - - let msm_x = msm.x.value(); - let msm_y = msm.y.value(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); -} - -fn custom_msm_circuit( - params: MSMCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, - bases: Vec, - scalars: Vec, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - msm_test(&mut builder, params, bases, scalars, params.window_bits); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit -} - -#[test] -fn test_fb_msm1() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( +fn run_test(scalars: Vec, bases: Vec) { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - params.batch_size = 3; + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, bases, scalars); + }); +} - let random_point = G1Affine::random(OsRng); +#[test] +fn test_fb_msm1() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm2() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 3; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm3() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![ random_point, random_point, @@ -136,20 +42,11 @@ fn test_fb_msm3() { (random_point + random_point + random_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm4() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - let generator_point = G1Affine::generator(); let bases = vec![ generator_point, @@ -158,26 +55,15 @@ fn test_fb_msm4() { (generator_point + generator_point + generator_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm5() { - // Very similar example that does not add to infinity. It works fine. - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index d00330ee..6b192ada 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -7,18 +7,8 @@ use super::*; use crate::fields::FieldChip; use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}, - utils::fs::gen_srs, - Context, + gates::RangeChip, halo2_proofs::arithmetic::Field, utils::BigPrimeField, Context, }; -use rand_core::OsRng; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct PairingCircuitParams { @@ -32,33 +22,14 @@ struct PairingCircuitParams { num_limbs: usize, } -fn pairing_check_test( - ctx: &mut Context, - params: PairingCircuitParams, - P: G1Affine, - Q: G2Affine, - S: G1Affine, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let chip = PairingChip::new(&fp_chip); - let P_assigned = chip.load_private_g1(ctx, P); - let Q_assigned = chip.load_private_g2(ctx, Q); - let S_assigned = chip.load_private_g1(ctx, S); - let T_assigned = chip.load_private_g2(ctx, G2Affine::generator()); - chip.pairing_check(ctx, &Q_assigned, &P_assigned, &T_assigned, &S_assigned); -} - -fn pairing_test( +fn pairing_test( ctx: &mut Context, + range: &RangeChip, params: PairingCircuitParams, P: G1Affine, Q: G2Affine, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let chip = PairingChip::new(&fp_chip); let P_assigned = chip.load_private_g1(ctx, P); let Q_assigned = chip.load_private_g2(ctx, Q); @@ -73,52 +44,36 @@ fn pairing_test( ); } -fn build_setup( - params: PairingCircuitParams, - stage: CircuitBuilderStage, -) -> (usize, GateThreadBuilder) { - ( - params.degree as usize, - match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }, +#[test] +fn test_pairing() { + let path = "configs/bn254/pairing_circuit.config"; + let params: PairingCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) + .unwrap(); + let mut rng = StdRng::seed_from_u64(0); + let P = G1Affine::random(&mut rng); + let Q = G2Affine::random(&mut rng); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + pairing_test(ctx, range, params, P, Q); + }); } -fn build_circuit( - k: usize, - builder: GateThreadBuilder, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - } -} - -fn random_pairing_circuit( +fn pairing_check_test( + ctx: &mut Context, + range: &RangeChip, params: PairingCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let (k, mut builder) = build_setup(params, stage); - let P = G1Affine::random(OsRng); - let Q = G2Affine::random(OsRng); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - pairing_test::(builder.main(0), params, P, Q); - let circuit = build_circuit(k, builder, stage, break_points); - end_timer!(start0); - circuit + P: G1Affine, + Q: G2Affine, + S: G1Affine, +) { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let chip = PairingChip::new(&fp_chip); + let P_assigned = chip.load_private_g1(ctx, P); + let Q_assigned = chip.load_private_g2(ctx, Q); + let S_assigned = chip.load_private_g1(ctx, S); + let T_assigned = chip.load_private_g2(ctx, G2Affine::generator()); + chip.pairing_check(ctx, &Q_assigned, &P_assigned, &T_assigned, &S_assigned); } /* @@ -126,22 +81,22 @@ fn random_pairing_circuit( * e(H_1^α, H_2^β) = e(H_1^(α*β), H_2), where H_1 is the generator for G1 and * H_2 for G2. */ -fn random_pairing_check_circuit( - params: PairingCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let (k, mut builder) = build_setup(params, stage); - let alpha = Fr::random(OsRng); - let beta = Fr::random(OsRng); +#[test] +fn test_pairing_check() { + let path = "configs/bn254/pairing_circuit.config"; + let params: PairingCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let mut rng = StdRng::seed_from_u64(0); + let alpha = Fr::random(&mut rng); + let beta = Fr::random(&mut rng); let P = G1Affine::from(G1Affine::generator() * alpha); let Q = G2Affine::from(G2Affine::generator() * beta); let S = G1Affine::from(G1Affine::generator() * alpha * beta); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - pairing_check_test::(builder.main(0), params, P, Q, S); - let circuit = build_circuit(k, builder, stage, break_points); - end_timer!(start0); - circuit + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + pairing_check_test(ctx, range, params, P, Q, S); + }) } /* @@ -149,60 +104,27 @@ fn random_pairing_check_circuit( * e(H_1^α, H_2^β) = e(H_1^α, H_2), where H_1 is the generator for G1 and * H_2 for G2. */ -fn random_pairing_check_fail_circuit( - params: PairingCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let (k, mut builder) = build_setup(params, stage); - let alpha = Fr::random(OsRng); - let beta = Fr::random(OsRng); - let P = G1Affine::from(G1Affine::generator() * alpha); - let Q = G2Affine::from(G2Affine::generator() * beta); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - pairing_check_test::(builder.main(0), params, P, Q, P); - let circuit = build_circuit(k, builder, stage, break_points); - end_timer!(start0); - circuit -} - -#[test] -fn test_pairing() { - let path = "configs/bn254/pairing_circuit.config"; - let params: PairingCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - let circuit = random_pairing_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[test] -fn test_pairing_check() { - let path = "configs/bn254/pairing_circuit.config"; - let params: PairingCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - let circuit = random_pairing_check_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); -} - #[test] -#[should_panic] fn test_pairing_check_fail() { let path = "configs/bn254/pairing_circuit.config"; let params: PairingCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let circuit = random_pairing_check_fail_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + let alpha = Fr::random(&mut rng); + let beta = Fr::random(&mut rng); + let P = G1Affine::from(G1Affine::generator() * alpha); + let Q = G2Affine::from(G2Affine::generator() * beta); + base_test().k(params.degree).lookup_bits(params.lookup_bits).expect_satisfied(false).run( + |ctx, range| { + pairing_check_test(ctx, range, params, P, Q, P); + }, + ) } #[test] fn bench_pairing() -> Result<(), Box> { - let rng = OsRng; let config_path = "configs/bn254/bench_pairing.config"; let bench_params_file = File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); @@ -213,6 +135,7 @@ fn bench_pairing() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: PairingCircuitParams = @@ -220,66 +143,15 @@ fn bench_pairing() -> Result<(), Box> { let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - let circuit = random_pairing_circuit(bench_params, CircuitBuilderStage::Keygen, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_pairing_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/pairing_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierGWC<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); + let P = G1Affine::random(&mut rng); + let Q = G2Affine::random(&mut rng); + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + (P, Q), + (P, Q), + |pool, range, (P, Q)| { + pairing_test(pool.main(), range, bench_params, P, Q); + }, + ); writeln!( fs_results, @@ -291,9 +163,9 @@ fn bench_pairing() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index ca0b111b..c72b3974 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -1,7 +1,8 @@ +use halo2_base::utils::BigPrimeField; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; -use crate::fields::{fp::FpChip, FieldChip, PrimeField}; +use crate::fields::{fp::FpChip, FieldChip}; use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA @@ -12,7 +13,7 @@ use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // Assumes `r, s` are proper CRT integers /// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) /// `pubkey` should not be the identity point -pub fn ecdsa_verify_no_pubkey_check( +pub fn ecdsa_verify_no_pubkey_check( chip: &EccChip>, ctx: &mut Context, pubkey: EcPoint as FieldChip>::FieldPoint>, diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index 5dfba754..304cd6b8 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,9 +1,11 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; use crate::ecc::{ec_sub_strict, load_random_point}; -use crate::fields::{FieldChip, PrimeField, Selectable}; -use group::Curve; -use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; +use crate::ff::Field; +use crate::fields::{FieldChip, Selectable}; +use crate::group::Curve; +use halo2_base::gates::flex_gate::threads::{parallelize_core, SinglePhaseCoreManager}; +use halo2_base::utils::BigPrimeField; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; use rayon::prelude::*; @@ -27,12 +29,12 @@ pub fn scalar_multiply( window_bits: usize, ) -> EcPoint where - F: PrimeField, + F: BigPrimeField, C: CurveAffineExt, FC: FieldChip + Selectable, { if point.is_identity().into() { - let zero = chip.load_constant(ctx, C::Base::zero()); + let zero = chip.load_constant(ctx, C::Base::ZERO); return EcPoint::new(zero.clone(), zero); } assert!(!scalar.is_empty()); @@ -111,20 +113,19 @@ where /// * Output may be point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, window_bits: usize, - phase: usize, ) -> EcPoint where - F: PrimeField, + F: BigPrimeField, C: CurveAffineExt, FC: FieldChip + Selectable, { if points.is_empty() { - return chip.assign_constant_point(builder.main(phase), C::identity()); + return chip.assign_constant_point(builder.main(), C::identity()); } assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); assert_eq!(points.len(), scalars.len()); @@ -166,11 +167,10 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); - let ctx = builder.main(phase); + let ctx = builder.main(); let any_point = chip.load_random_point::(ctx); - let scalar_mults = parallelize_in( - phase, + let scalar_mults = parallelize_core( builder, cached_points_affine .chunks(cached_points_affine.len() / points.len()) @@ -207,7 +207,7 @@ where curr_point }, ); - let ctx = builder.main(phase); + let ctx = builder.main(); // sum `scalar_mults` but take into account possiblity of identity points let any_point2 = chip.load_random_point::(ctx); let mut acc = any_point2.clone(); diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index 0b3103d2..b410b1e0 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -1,9 +1,10 @@ #![allow(non_snake_case)] -use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; +use crate::ff::Field; +use crate::fields::{fp::FpChip, FieldChip, Selectable}; +use crate::group::{Curve, Group}; use crate::halo2_proofs::arithmetic::CurveAffine; -use group::{Curve, Group}; -use halo2_base::gates::builder::GateThreadBuilder; -use halo2_base::utils::modulus; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::utils::{modulus, BigPrimeField}; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, utils::CurveAffineExt, @@ -22,20 +23,20 @@ pub mod pippenger; // EcPoint and EccChip take in a generic `FieldChip` to implement generic elliptic curve operations on arbitrary field extensions (provided chip exists) for short Weierstrass curves (currently further assuming a4 = 0 for optimization purposes) #[derive(Debug)] -pub struct EcPoint { +pub struct EcPoint { pub x: FieldPoint, pub y: FieldPoint, _marker: PhantomData, } -impl Clone for EcPoint { +impl Clone for EcPoint { fn clone(&self) -> Self { Self { x: self.x.clone(), y: self.y.clone(), _marker: PhantomData } } } // Improve readability by allowing `&EcPoint` to be converted to `EcPoint` via cloning -impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> +impl<'a, F: BigPrimeField, FieldPoint: Clone> From<&'a EcPoint> for EcPoint { fn from(value: &'a EcPoint) -> Self { @@ -43,7 +44,7 @@ impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> } } -impl EcPoint { +impl EcPoint { pub fn new(x: FieldPoint, y: FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } @@ -59,25 +60,25 @@ impl EcPoint { /// An elliptic curve point where it is easy to compare the x-coordinate of two points #[derive(Clone, Debug)] -pub struct StrictEcPoint> { +pub struct StrictEcPoint> { pub x: FC::ReducedFieldPoint, pub y: FC::FieldPoint, _marker: PhantomData, } -impl> StrictEcPoint { +impl> StrictEcPoint { pub fn new(x: FC::ReducedFieldPoint, y: FC::FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } } -impl> From> for EcPoint { +impl> From> for EcPoint { fn from(value: StrictEcPoint) -> Self { Self::new(value.x.into(), value.y) } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a StrictEcPoint> for EcPoint { fn from(value: &'a StrictEcPoint) -> Self { @@ -88,18 +89,18 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> /// An elliptic curve point where the x-coordinate has already been constrained to be reduced or not. /// In the reduced case one can more optimally compare equality of x-coordinates. #[derive(Clone, Debug)] -pub enum ComparableEcPoint> { +pub enum ComparableEcPoint> { Strict(StrictEcPoint), NonStrict(EcPoint), } -impl> From> for ComparableEcPoint { +impl> From> for ComparableEcPoint { fn from(pt: StrictEcPoint) -> Self { Self::Strict(pt) } } -impl> From> +impl> From> for ComparableEcPoint { fn from(pt: EcPoint) -> Self { @@ -107,7 +108,7 @@ impl> From> } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a StrictEcPoint> for ComparableEcPoint { fn from(pt: &'a StrictEcPoint) -> Self { @@ -115,7 +116,7 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a EcPoint> for ComparableEcPoint { fn from(pt: &'a EcPoint) -> Self { @@ -123,7 +124,7 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> } } -impl> From> +impl> From> for EcPoint { fn from(pt: ComparableEcPoint) -> Self { @@ -150,7 +151,7 @@ impl> From> /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_add_unequal>( +pub fn ec_add_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -180,7 +181,7 @@ pub fn ec_add_unequal>( /// If `do_check = true`, then this function constrains that `P.x != Q.x`. /// Otherwise does nothing. -fn check_points_are_unequal>( +fn check_points_are_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -196,7 +197,7 @@ fn check_points_are_unequal>( ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), }); let x_is_equal = chip.is_equal_unenforced(ctx, x1, x2); - chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); } (EcPoint::from(P), EcPoint::from(Q)) } @@ -216,7 +217,7 @@ fn check_points_are_unequal>( /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_sub_unequal>( +pub fn ec_sub_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -250,7 +251,7 @@ pub fn ec_sub_unequal>( /// /// Assumptions /// # Neither P or Q is the point at infinity -pub fn ec_sub_strict>( +pub fn ec_sub_strict>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -280,7 +281,7 @@ where P = ec_select(chip, ctx, rand_pt, P, is_identity); let out = ec_sub_unequal(chip, ctx, P, Q, false); - let zero = chip.load_constant(ctx, FC::FieldType::zero()); + let zero = chip.load_constant(ctx, FC::FieldType::ZERO); ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) } @@ -299,7 +300,7 @@ where /// # Assumptions /// * `P.y != 0` /// * `P` is not the point at infinity (undefined behavior otherwise) -pub fn ec_double>( +pub fn ec_double>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -338,7 +339,7 @@ pub fn ec_double>( /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_double_and_add_unequal>( +pub fn ec_double_and_add_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -355,7 +356,7 @@ pub fn ec_double_and_add_unequal>( ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), }); let x_is_equal = chip.is_equal_unenforced(ctx, x0.clone(), x1); - chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); x_0 = Some(x0); } let P = EcPoint::from(P); @@ -376,7 +377,7 @@ pub fn ec_double_and_add_unequal>( // TODO: when can we remove this check? // constrains that x_2 != x_0 let x_is_equal = chip.is_equal_unenforced(ctx, x_0.unwrap(), x_2); - chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); } // lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) let two_y_0 = chip.scalar_mul_no_carry(ctx, &P.y, 2); @@ -399,7 +400,7 @@ pub fn ec_double_and_add_unequal>( EcPoint::new(x_res, y_res) } -pub fn ec_select( +pub fn ec_select( chip: &FC, ctx: &mut Context, P: EcPoint, @@ -416,7 +417,7 @@ where // takes the dot product of points with sel, where each is intepreted as // a _vector_ -pub fn ec_select_by_indicator( +pub fn ec_select_by_indicator( chip: &FC, ctx: &mut Context, points: &[Pt], @@ -439,7 +440,7 @@ where } // `sel` is little-endian binary -pub fn ec_select_from_bits( +pub fn ec_select_from_bits( chip: &FC, ctx: &mut Context, points: &[Pt], @@ -456,7 +457,7 @@ where } // `sel` is little-endian binary -pub fn strict_ec_select_from_bits( +pub fn strict_ec_select_from_bits( chip: &FC, ctx: &mut Context, points: &[StrictEcPoint], @@ -485,7 +486,7 @@ where /// - The curve has no points of order 2. /// - `scalar_i < 2^{max_bits} for all i` /// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` -pub fn scalar_multiply( +pub fn scalar_multiply( chip: &FC, ctx: &mut Context, P: EcPoint, @@ -588,7 +589,7 @@ where /// Checks that `P` is indeed a point on the elliptic curve `C`. pub fn check_is_on_curve(chip: &FC, ctx: &mut Context, P: &EcPoint) where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, C: CurveAffine, { @@ -603,7 +604,7 @@ where pub fn load_random_point(chip: &FC, ctx: &mut Context) -> EcPoint where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, C: CurveAffineExt, { @@ -625,7 +626,7 @@ pub fn into_strict_point( pt: EcPoint, ) -> StrictEcPoint where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, { let x = chip.enforce_less_than(ctx, pt.x); @@ -648,7 +649,7 @@ where /// * `points` are all on the curve or the point at infinity /// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) /// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point -pub fn multi_scalar_multiply( +pub fn multi_scalar_multiply( chip: &FC, ctx: &mut Context, P: &[EcPoint], @@ -812,12 +813,12 @@ pub type BaseFieldEccChip<'chip, C> = EccChip< >; #[derive(Clone, Debug)] -pub struct EccChip<'chip, F: PrimeField, FC: FieldChip> { +pub struct EccChip<'chip, F: BigPrimeField, FC: FieldChip> { pub field_chip: &'chip FC, _marker: PhantomData, } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn new(field_chip: &'chip FC) -> Self { Self { field_chip, _marker: PhantomData } } @@ -857,11 +858,11 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn assign_point(&self, ctx: &mut Context, g: C) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, + C::Base: crate::ff::PrimeField, { let pt = self.assign_point_unchecked(ctx, g); let is_on_curve = self.is_on_curve_or_infinity::(ctx, &pt); - self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::one()); + self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::ONE); pt } @@ -1010,7 +1011,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { } } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> where FC: Selectable, { @@ -1043,7 +1044,7 @@ where /// See [`pippenger::multi_exp_par`] for more details. pub fn variable_base_msm( &self, - thread_pool: &mut GateThreadBuilder, + thread_pool: &mut SinglePhaseCoreManager, P: &[EcPoint], scalars: Vec>>, max_bits: usize, @@ -1053,18 +1054,17 @@ where FC: Selectable, { // window_bits = 4 is optimal from empirical observations - self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) + self.variable_base_msm_custom::(thread_pool, P, scalars, max_bits, 4) } // TODO: add asserts to validate input assumptions described in docs - pub fn variable_base_msm_in( + pub fn variable_base_msm_custom( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, P: &[EcPoint], scalars: Vec>>, max_bits: usize, window_bits: usize, - phase: usize, ) -> EcPoint where C: CurveAffineExt, @@ -1076,7 +1076,7 @@ where if P.len() <= 25 { multi_scalar_multiply::( self.field_chip, - builder.main(phase), + builder.main(), P, scalars, max_bits, @@ -1098,13 +1098,12 @@ where scalars, max_bits, window_bits, // clump_factor := window_bits - phase, ) } } } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { /// See [`fixed_base::scalar_multiply`] for more details. // TODO: put a check in place that scalar is < modulus of C::Scalar pub fn fixed_base_scalar_mult( @@ -1132,7 +1131,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { // default for most purposes pub fn fixed_base_msm( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, @@ -1141,7 +1140,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { C: CurveAffineExt, FC: FieldChip + Selectable, { - self.fixed_base_msm_in::(builder, points, scalars, max_scalar_bits_per_cell, 4, 0) + self.fixed_base_msm_custom::(builder, points, scalars, max_scalar_bits_per_cell, 4) } // `radix = 0` means auto-calculate @@ -1149,14 +1148,13 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { /// `clump_factor = 0` means auto-calculate /// /// The user should filter out base points that are identity beforehand; we do not separately do this here - pub fn fixed_base_msm_in( + pub fn fixed_base_msm_custom( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, clump_factor: usize, - phase: usize, ) -> EcPoint where C: CurveAffineExt, @@ -1166,15 +1164,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { #[cfg(feature = "display")] println!("computing length {} fixed base msm", points.len()); - fixed_base::msm_par( - self, - builder, - points, - scalars, - max_scalar_bits_per_cell, - clump_factor, - phase, - ) + fixed_base::msm_par(self, builder, points, scalars, max_scalar_bits_per_cell, clump_factor) // Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator` // Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4 diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 934a7432..736a9f34 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -4,14 +4,14 @@ use super::{ }; use crate::{ ecc::ec_sub_strict, - fields::{FieldChip, PrimeField, Selectable}, + fields::{FieldChip, Selectable}, }; use halo2_base::{ gates::{ - builder::{parallelize_in, GateThreadBuilder}, + flex_gate::threads::{parallelize_core, SinglePhaseCoreManager}, GateInstructions, }, - utils::CurveAffineExt, + utils::{BigPrimeField, CurveAffineExt}, AssignedValue, }; @@ -216,16 +216,15 @@ where /// * `points` are all on the curve or the point at infinity /// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) /// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point -pub fn multi_exp_par( +pub fn multi_exp_par( chip: &FC, // these are the "threads" within a single Phase - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[EcPoint], scalars: Vec>>, max_scalar_bits_per_cell: usize, // radix: usize, // specialize to radix = 1 clump_factor: usize, - phase: usize, ) -> EcPoint where FC: FieldChip + Selectable + Selectable, @@ -239,7 +238,7 @@ where let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; // get a main thread - let ctx = builder.main(phase); + let ctx = builder.main(); // single-threaded computation: for scalar in scalars { for (scalar_chunk, bool_chunk) in @@ -267,10 +266,9 @@ where // now begins multi-threading // multi_prods is 2d vector of size `num_rounds` by `scalar_bits` - let multi_prods = parallelize_in( - phase, + let multi_prods = parallelize_core( builder, - points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(), + points.chunks(c).zip(any_points.iter()).enumerate().collect(), |ctx, (round, (points_clump, any_point))| { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... } @@ -306,7 +304,7 @@ where ); // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits - let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| { + let mut agg = parallelize_core(builder, (0..scalar_bits).collect(), |ctx, i| { let mut acc = multi_prods[0][i].clone(); for multi_prod in multi_prods.iter().skip(1) { let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); @@ -316,7 +314,7 @@ where }); // gets the LAST thread for single threaded work - let ctx = builder.main(phase); + let ctx = builder.main(); // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base // let any_point = (2^num_rounds - 1) * any_base // TODO: can we remove all these random point operations somehow? diff --git a/halo2-ecc/src/ecc/schnorr_signature.rs b/halo2-ecc/src/ecc/schnorr_signature.rs index aebdfeca..a124560f 100644 --- a/halo2-ecc/src/ecc/schnorr_signature.rs +++ b/halo2-ecc/src/ecc/schnorr_signature.rs @@ -1,6 +1,10 @@ use crate::bigint::{big_is_equal, ProperCrtUint}; -use crate::fields::{fp::FpChip, FieldChip, PrimeField}; -use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; +use crate::fields::{fp::FpChip, FieldChip}; +use halo2_base::{ + gates::GateInstructions, + utils::{BigPrimeField, CurveAffineExt}, + AssignedValue, Context, +}; use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; @@ -12,7 +16,7 @@ use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // this circuit applies over constraints that s > 0, msgHash > 0 cause scalar_multiply can't handle zero scalar /// `pubkey` should not be the identity point /// follow spec in https://github.com/bitcoin/bips/blob/master/bip-0340.mediawiki -pub fn schnorr_verify_no_pubkey_check( +pub fn schnorr_verify_no_pubkey_check( chip: &EccChip>, ctx: &mut Context, pubkey: EcPoint as FieldChip>::FieldPoint>, diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index 5bbc612e..02f549e3 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -1,34 +1,33 @@ #![allow(unused_assignments, unused_imports, unused_variables)] use super::*; use crate::fields::fp2::Fp2Chip; +use crate::group::Group; use crate::halo2_proofs::{ circuit::*, dev::MockProver, halo2curves::bn256::{Fq, Fr, G1Affine, G2Affine, G1, G2}, plonk::*, }; -use group::Group; -use halo2_base::gates::builder::RangeCircuitBuilder; use halo2_base::gates::RangeChip; use halo2_base::utils::bigint_to_fe; +use halo2_base::utils::testing::base_test; +use halo2_base::utils::value_to_option; use halo2_base::SKIP_FIRST_PASS; -use halo2_base::{gates::range::RangeStrategy, utils::value_to_option}; use num_bigint::{BigInt, RandBigInt}; use rand_core::OsRng; use std::marker::PhantomData; use std::ops::Neg; -fn basic_g1_tests( +fn basic_g1_tests( ctx: &mut Context, + range: &RangeChip, lookup_bits: usize, limb_bits: usize, num_limbs: usize, P: G1Affine, Q: G1Affine, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); + let fp_chip = FpChip::::new(range, limb_bits, num_limbs); let chip = EccChip::new(&fp_chip); let P_assigned = chip.load_private_unchecked(ctx, (P.x, P.y)); @@ -61,37 +60,9 @@ fn basic_g1_tests( #[test] fn test_ecc() { - let k = 23; - let P = G1Affine::random(OsRng); - let Q = G1Affine::random(OsRng); - - let mut builder = GateThreadBuilder::::mock(); - basic_g1_tests(builder.main(0), k - 1, 88, 3, P, Q); - - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_ecc() { - let k = 10; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (512, 16384)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Ecc Layout", ("sans-serif", 60)).unwrap(); - - let P = G1Affine::random(OsRng); - let Q = G1Affine::random(OsRng); - - let mut builder = GateThreadBuilder::::keygen(); - basic_g1_tests(builder.main(0), 22, 88, 3, P, Q); - - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::mock(builder); - - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + base_test().k(23).lookup_bits(22).run(|ctx, range| { + let P = G1Affine::random(OsRng); + let Q = G1Affine::random(OsRng); + basic_g1_tests(ctx, range, 22, 88, 3, P, Q); + }); } diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index 6ec428ef..579ac39a 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -1,4 +1,4 @@ -use super::{FieldChip, PrimeField, PrimeFieldChip, Selectable}; +use super::{FieldChip, PrimeFieldChip, Selectable}; use crate::bigint::{ add_no_carry, big_is_equal, big_is_even, big_is_zero, carry_mod, check_carry_mod_to_zero, mul_no_carry, scalar_mul_and_add_no_carry, scalar_mul_no_carry, select, select_by_indicator, @@ -6,7 +6,7 @@ use crate::bigint::{ }; use crate::halo2_proofs::halo2curves::CurveAffine; use halo2_base::gates::RangeChip; -use halo2_base::utils::ScalarField; +use halo2_base::utils::{BigPrimeField, ScalarField}; use halo2_base::{ gates::{range::RangeConfig, GateInstructions, RangeInstructions}, utils::{bigint_to_fe, biguint_to_fe, bit_length, decompose_biguint, fe_to_biguint, modulus}, @@ -48,7 +48,7 @@ impl From, Fp>> for ProperCrtUint { +pub struct FpChip<'range, F: BigPrimeField, Fp: BigPrimeField> { pub range: &'range RangeChip, pub limb_bits: usize, @@ -68,7 +68,7 @@ pub struct FpChip<'range, F: PrimeField, Fp: PrimeField> { _marker: PhantomData, } -impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> FpChip<'range, F, Fp> { pub fn new(range: &'range RangeChip, limb_bits: usize, num_limbs: usize) -> Self { assert!(limb_bits > 0); assert!(num_limbs > 0); @@ -81,7 +81,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { let limb_base = biguint_to_fe::(&(BigUint::one() << limb_bits)); let mut limb_bases = Vec::with_capacity(num_limbs); - limb_bases.push(F::one()); + limb_bases.push(F::ONE); while limb_bases.len() != num_limbs { limb_bases.push(limb_base * limb_bases.last().unwrap()); } @@ -121,7 +121,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { }; borrow = Some(lt); } - self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::one()); + self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::ONE); } /// Given proper CRT integer `a`, returns 1 iff `a < modulus::()` @@ -166,7 +166,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { } } -impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { fn num_limbs(&self) -> usize { self.num_limbs } @@ -178,7 +178,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, } } -impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> FieldChip for FpChip<'range, F, Fp> { const PRIME_FIELD_NUM_BITS: u32 = Fp::NUM_BITS; type UnsafeFieldPoint = CRTInteger; type FieldPoint = ProperCrtUint; @@ -267,7 +267,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F let (out_or_p, underflow) = sub::crt(self.range(), ctx, p, a.clone(), self.limb_bits, self.limb_bases[1]); // constrain underflow to equal 0 - self.gate().assert_is_const(ctx, &underflow, &F::zero()); + self.gate().assert_is_const(ctx, &underflow, &F::ZERO); let a_is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); ProperCrtUint(select::crt(self.gate(), ctx, a.0, out_or_p, a_is_zero)) @@ -435,7 +435,9 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F } } -impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> Selectable> + for FpChip<'range, F, Fp> +{ fn select( &self, ctx: &mut Context, @@ -456,7 +458,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpC } } -impl<'range, F: PrimeField, Fp: PrimeField> Selectable> +impl<'range, F: BigPrimeField, Fp: BigPrimeField> Selectable> for FpChip<'range, F, Fp> { fn select( @@ -480,7 +482,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> Selectable> } } -impl Selectable> for FC +impl Selectable> for FC where FC: Selectable, { diff --git a/halo2-ecc/src/fields/fp12.rs b/halo2-ecc/src/fields/fp12.rs index 156ca452..bdb9f790 100644 --- a/halo2-ecc/src/fields/fp12.rs +++ b/halo2-ecc/src/fields/fp12.rs @@ -1,15 +1,19 @@ use std::marker::PhantomData; -use halo2_base::{utils::modulus, AssignedValue, Context}; -use num_bigint::BigUint; - +use crate::ff::PrimeField as _; use crate::impl_field_ext_chip_common; use super::{ vector::{FieldVector, FieldVectorChip}, - FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, + FieldChip, FieldExtConstructor, PrimeFieldChip, }; +use halo2_base::{ + utils::{modulus, BigPrimeField}, + AssignedValue, Context, +}; +use num_bigint::BigUint; + /// Represent Fp12 point as FqPoint with degree = 12 /// `Fp12 = Fp2[w] / (w^6 - u - xi)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to @@ -17,17 +21,17 @@ use super::{ /// This means we store an Fp12 point as `\sum_{i = 0}^6 (a_{i0} + a_{i1} * u) * w^i` /// This is encoded in an FqPoint of degree 12 as `(a_{00}, ..., a_{50}, a_{01}, ..., a_{51})` #[derive(Clone, Copy, Debug)] -pub struct Fp12Chip<'a, F: PrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( +pub struct Fp12Chip<'a, F: BigPrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( pub FieldVectorChip<'a, F, FpChip>, PhantomData, ); impl<'a, F, FpChip, Fp12, const XI_0: i64> Fp12Chip<'a, F, FpChip, Fp12, XI_0> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, - Fp12: ff::Field, + FpChip::FieldType: BigPrimeField, + Fp12: crate::ff::Field, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { @@ -93,7 +97,7 @@ where /// /// # Assumptions /// * `a` is `Fp2` point represented as `FieldVector` with degree = 2 -pub fn mul_no_carry_w6, const XI_0: i64>( +pub fn mul_no_carry_w6, const XI_0: i64>( fp_chip: &FC, ctx: &mut Context, a: FieldVector, @@ -112,10 +116,10 @@ pub fn mul_no_carry_w6, const XI_0: i64>( impl<'a, F, FpChip, Fp12, const XI_0: i64> FieldChip for Fp12Chip<'a, F, FpChip, Fp12, XI_0> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, - Fp12: ff::Field + FieldExtConstructor, + FpChip::FieldType: BigPrimeField, + Fp12: crate::ff::Field + FieldExtConstructor, FieldVector: From>, FieldVector: From>, { diff --git a/halo2-ecc/src/fields/fp2.rs b/halo2-ecc/src/fields/fp2.rs index 55e3243a..71c5d446 100644 --- a/halo2-ecc/src/fields/fp2.rs +++ b/halo2-ecc/src/fields/fp2.rs @@ -1,29 +1,30 @@ use std::fmt::Debug; use std::marker::PhantomData; -use halo2_base::{utils::modulus, AssignedValue, Context}; -use num_bigint::BigUint; - +use crate::ff::PrimeField as _; use crate::impl_field_ext_chip_common; use super::{ vector::{FieldVector, FieldVectorChip}, - FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, + BigPrimeField, FieldChip, FieldExtConstructor, PrimeFieldChip, }; +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; /// Represent Fp2 point as `FieldVector` with degree = 2 /// `Fp2 = Fp[u] / (u^2 + 1)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp2 point as `a_0 + a_1 * u` where `a_0, a_1 in Fp` #[derive(Clone, Copy, Debug)] -pub struct Fp2Chip<'a, F: PrimeField, FpChip: FieldChip, Fp2>( +pub struct Fp2Chip<'a, F: BigPrimeField, FpChip: FieldChip, Fp2>( pub FieldVectorChip<'a, F, FpChip>, PhantomData, ); -impl<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: ff::Field> Fp2Chip<'a, F, FpChip, Fp2> +impl<'a, F: BigPrimeField, FpChip: PrimeFieldChip, Fp2: crate::ff::Field> + Fp2Chip<'a, F, FpChip, Fp2> where - FpChip::FieldType: PrimeField, + FpChip::FieldType: BigPrimeField, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { @@ -66,10 +67,10 @@ where impl<'a, F, FpChip, Fp2> FieldChip for Fp2Chip<'a, F, FpChip, Fp2> where - F: PrimeField, - FpChip::FieldType: PrimeField, + F: BigPrimeField, + FpChip::FieldType: BigPrimeField, FpChip: PrimeFieldChip, - Fp2: ff::Field + FieldExtConstructor, + Fp2: crate::ff::Field + FieldExtConstructor, FieldVector: From>, FieldVector: From>, { diff --git a/halo2-ecc/src/fields/mod.rs b/halo2-ecc/src/fields/mod.rs index 0c55affa..5b3bde39 100644 --- a/halo2-ecc/src/fields/mod.rs +++ b/halo2-ecc/src/fields/mod.rs @@ -16,13 +16,11 @@ pub mod vector; #[cfg(test)] mod tests; -pub trait PrimeField = BigPrimeField; - /// Trait for common functionality for finite field chips. /// Primarily intended to emulate a "non-native" finite field using "native" values in a prime field `F`. /// Most functions are designed for the case when the non-native field is larger than the native field, but /// the trait can still be implemented and used in other cases. -pub trait FieldChip: Clone + Send + Sync { +pub trait FieldChip: Clone + Send + Sync { const PRIME_FIELD_NUM_BITS: u32; /// A representation of a field element that is used for intermediate computations. @@ -211,7 +209,7 @@ pub trait FieldChip: Clone + Send + Sync { ) -> Self::FieldPoint { let b = b.into(); let b_is_zero = self.is_zero(ctx, b.clone()); - self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::ZERO); self.divide_unsafe(ctx, a.into(), b) } @@ -253,7 +251,7 @@ pub trait FieldChip: Clone + Send + Sync { ) -> Self::FieldPoint { let b = b.into(); let b_is_zero = self.is_zero(ctx, b.clone()); - self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::ZERO); self.neg_divide_unsafe(ctx, a.into(), b) } @@ -296,9 +294,9 @@ pub trait Selectable { } // Common functionality for prime field chips -pub trait PrimeFieldChip: FieldChip +pub trait PrimeFieldChip: FieldChip where - Self::FieldType: PrimeField, + Self::FieldType: BigPrimeField, { fn num_limbs(&self) -> usize; fn limb_mask(&self) -> &BigUint; @@ -307,7 +305,7 @@ where // helper trait so we can actually construct and read the Fp2 struct // needs to be implemented for Fp2 struct for use cases below -pub trait FieldExtConstructor { +pub trait FieldExtConstructor { fn new(c: [Fp; DEGREE]) -> Self; fn coeffs(&self) -> Vec; diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs index a8184594..c39140d0 100644 --- a/halo2-ecc/src/fields/tests/fp/assert_eq.rs +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -1,57 +1,52 @@ -use std::env::set_var; +use crate::ff::Field; +use crate::{bn254::FpChip, fields::FieldChip}; -use ff::Field; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use halo2_base::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - RangeChip, - }, halo2_proofs::{ halo2curves::bn256::Fq, plonk::keygen_pk, plonk::keygen_vk, poly::kzg::commitment::ParamsKZG, }, utils::testing::{check_proof, gen_proof}, }; - -use crate::{bn254::FpChip, fields::FieldChip}; use rand::thread_rng; // soundness checks for `` function fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let mut rng = thread_rng(); - set_var("LOOKUP_BITS", lookup_bits.to_string()); // first create proving and verifying key - let mut builder = GateThreadBuilder::keygen(); - let range = RangeChip::default(lookup_bits); + let mut builder = RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen) + .use_k(k as usize) + .use_lookup_bits(lookup_bits); + let range = builder.range_chip(); let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); let a = chip.load_private(ctx, Fq::zero()); let b = chip.load_private(ctx, Fq::zero()); chip.assert_equal(ctx, &a, &b); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, &mut rng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |a: Fq, b: Fq| { - let mut builder = GateThreadBuilder::prover(); - let range = RangeChip::default(lookup_bits); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); + let range = builder.range_chip(); let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); chip.assert_equal(ctx, &a, &b); - let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; // expected answer diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index 675aab5a..d88d6a1a 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -1,15 +1,10 @@ -use std::env::set_var; - +use crate::ff::{Field as _, PrimeField as _}; use crate::fields::fp::FpChip; -use crate::fields::{FieldChip, PrimeField}; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Fq, Fr}, -}; +use crate::fields::FieldChip; +use crate::halo2_proofs::halo2curves::bn256::{Fq, Fr}; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::RangeChip; use halo2_base::utils::biguint_to_fe; +use halo2_base::utils::testing::base_test; use halo2_base::utils::{fe_to_biguint, modulus}; use halo2_base::Context; use rand::rngs::OsRng; @@ -25,16 +20,10 @@ fn fp_chip_test( num_limbs: usize, f: impl Fn(&mut Context, &FpChip), ) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let chip = FpChip::::new(&range, limb_bits, num_limbs); - - let mut builder = GateThreadBuilder::mock(); - f(builder.main(0), &chip); - - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, range| { + let chip = FpChip::::new(range, limb_bits, num_limbs); + f(ctx, &chip); + }); } #[test] @@ -86,7 +75,7 @@ fn plot_fp() { let mut builder = GateThreadBuilder::keygen(); fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::keygen(builder); + let config_params = builder.config(k, Some(10), Some(k - 1)); + let circuit = RangeCircuitBuilder::keygen(builder, config_params); halo2_proofs::dev::CircuitLayout::default().render(k as u32, &circuit, &root).unwrap(); } diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs index 6fb631b9..dbd618c9 100644 --- a/halo2-ecc/src/fields/tests/fp12/mod.rs +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -1,37 +1,33 @@ +use crate::ff::Field as _; use crate::fields::fp::FpChip; use crate::fields::fp12::Fp12Chip; -use crate::fields::{FieldChip, PrimeField}; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Fq, Fq12, Fr}, -}; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::RangeChip; -use halo2_base::Context; +use crate::fields::FieldChip; +use crate::halo2_proofs::halo2curves::bn256::{Fq, Fq12}; +use halo2_base::utils::testing::base_test; use rand_core::OsRng; const XI_0: i64 = 9; -fn fp12_mul_test( - ctx: &mut Context, +fn fp12_mul_test( + k: u32, lookup_bits: usize, limb_bits: usize, num_limbs: usize, _a: Fq12, _b: Fq12, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); - let chip = Fp12Chip::::new(&fp_chip); - - let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); - let c = chip.mul(ctx, a, b).into(); - - assert_eq!(chip.get_assigned_value(&c), _a * _b); - for c in c.into_iter() { - assert_eq!(c.truncation.to_bigint(limb_bits), c.value); - } + base_test().k(k).lookup_bits(lookup_bits).run(|ctx, range| { + let fp_chip = FpChip::<_, Fq>::new(range, limb_bits, num_limbs); + let chip = Fp12Chip::<_, _, Fq12, XI_0>::new(&fp_chip); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b).into(); + + assert_eq!(chip.get_assigned_value(&c), _a * _b); + for c in c.into_iter() { + assert_eq!(c.truncation.to_bigint(limb_bits), c.value); + } + }); } #[test] @@ -40,34 +36,5 @@ fn test_fp12() { let a = Fq12::random(OsRng); let b = Fq12::random(OsRng); - let mut builder = GateThreadBuilder::::mock(); - fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); - - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_fp12() { - use ff::Field; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); - - let k = 23; - let a = Fq12::zero(); - let b = Fq12::zero(); - - let mut builder = GateThreadBuilder::::mock(); - fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); - - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + fp12_mul_test(k, k as usize - 1, 88, 3, a, b); } diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs index 6aea9d97..d27dc25f 100644 --- a/halo2-ecc/src/fields/vector.rs +++ b/halo2-ecc/src/fields/vector.rs @@ -1,4 +1,8 @@ -use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use halo2_base::{ + gates::GateInstructions, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; use itertools::Itertools; use std::{ marker::PhantomData, @@ -7,7 +11,7 @@ use std::{ use crate::bigint::{CRTInteger, ProperCrtUint}; -use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, Selectable}; +use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeFieldChip, Selectable}; /// A fixed length vector of `FieldPoint`s #[repr(transparent)] @@ -63,16 +67,16 @@ impl IntoIterator for FieldVector { /// Contains common functionality for vector operations that can be derived from those of the underlying `FpChip` #[derive(Clone, Copy, Debug)] -pub struct FieldVectorChip<'fp, F: PrimeField, FpChip: FieldChip> { +pub struct FieldVectorChip<'fp, F: BigPrimeField, FpChip: FieldChip> { pub fp_chip: &'fp FpChip, _f: PhantomData, } impl<'fp, F, FpChip> FieldVectorChip<'fp, F, FpChip> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, + FpChip::FieldType: BigPrimeField, { pub fn new(fp_chip: &'fp FpChip) -> Self { Self { fp_chip, _f: PhantomData } diff --git a/halo2-ecc/src/lib.rs b/halo2-ecc/src/lib.rs index 10da56bc..c4a47c15 100644 --- a/halo2-ecc/src/lib.rs +++ b/halo2-ecc/src/lib.rs @@ -1,7 +1,6 @@ #![allow(clippy::too_many_arguments)] #![allow(clippy::op_ref)] #![allow(clippy::type_complexity)] -#![feature(int_log)] #![feature(trait_alias)] pub mod bigint; @@ -13,3 +12,6 @@ pub mod secp256k1; pub use halo2_base; pub(crate) use halo2_base::halo2_proofs; +use halo2_proofs::halo2curves; +use halo2curves::ff; +use halo2curves::group; diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index af7050f9..a6dfd993 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -1,44 +1,29 @@ #![allow(non_snake_case)] +use std::fs::File; +use std::io::BufReader; +use std::io::Write; +use std::{fs, io::BufRead}; + +use super::*; use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, - dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::Fr, halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, - plonk::*, - poly::commitment::ParamsProver, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, -}; -use crate::halo2_proofs::{ - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; use crate::secp256k1::{FpChip, FqChip}; use crate::{ ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{FieldChip, PrimeField}, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + fields::FieldChip, }; use halo2_base::gates::RangeChip; -use halo2_base::utils::fs::gen_srs; -use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus, BigPrimeField}; use halo2_base::Context; -use rand_core::OsRng; use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::io::BufReader; -use std::io::Write; -use std::{fs, io::BufRead}; +use test_log::test; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { +pub struct CircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -49,86 +34,74 @@ struct CircuitParams { num_limbs: usize, } -fn ecdsa_test( +#[derive(Clone, Copy, Debug)] +pub struct ECDSAInput { + pub r: Fq, + pub s: Fq, + pub msghash: Fq, + pub pk: Secp256k1Affine, +} + +pub fn ecdsa_test( ctx: &mut Context, + range: &RangeChip, params: CircuitParams, - r: Fq, - s: Fq, - msghash: Fq, - pk: Secp256k1Affine, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - - let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); + input: ECDSAInput, +) -> F { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); + + let [m, r, s] = [input.msghash, input.r, input.s].map(|x| fq_chip.load_private(ctx, x)); let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.load_private_unchecked(ctx, (pk.x, pk.y)); + let pk = ecc_chip.load_private_unchecked(ctx, (input.pk.x, input.pk.y)); // test ECDSA let res = ecdsa_verify_no_pubkey_check::( &ecc_chip, ctx, pk, r, s, m, 4, 4, ); - assert_eq!(res.value(), &F::one()); + *res.value() } -fn random_ecdsa_circuit( - params: CircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); +pub fn random_ecdsa_input(rng: &mut StdRng) -> ECDSAInput { + let sk = ::ScalarExt::random(rng.clone()); + let pk = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msghash = ::ScalarExt::random(rng.clone()); + + let k = ::ScalarExt::random(rng); let k_inv = k.invert().unwrap(); let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); let r = biguint_to_fe::(&(x_bigint % modulus::())); - let s = k_inv * (msg_hash + (r * sk)); - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit + let s = k_inv * (msghash + (r * sk)); + + ECDSAInput { r, s, msghash, pk } } -#[test] -fn test_secp256k1_ecdsa() { +pub fn run_test(input: ECDSAInput) { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let circuit = random_ecdsa_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let res = base_test() + .k(params.degree) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| ecdsa_test(ctx, range, params, input)); + assert_eq!(res, Fr::ONE); +} + +#[test] +fn test_secp256k1_ecdsa() { + let mut rng = StdRng::seed_from_u64(0); + let input = random_ecdsa_input(&mut rng); + run_test(input); } #[test] fn bench_secp256k1_ecdsa() -> Result<(), Box> { - let mut rng = OsRng; let config_path = "configs/secp256k1/bench_ecdsa.config"; let bench_params_file = File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); @@ -138,74 +111,21 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - println!("{bench_params:?}"); - - let circuit = random_ecdsa_circuit(bench_params, CircuitBuilderStage::Keygen, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_ecdsa_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], &mut rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/ecdsa_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs + let stats = + base_test().k(k).lookup_bits(bench_params.lookup_bits).unusable_rows(20).bench_builder( + random_ecdsa_input(&mut rng), + random_ecdsa_input(&mut rng), + |pool, range, input| { + ecdsa_test(pool.main(), range, bench_params, input); + }, ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); writeln!( fs_results, @@ -217,9 +137,9 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 45e251f3..46bb6481 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -1,73 +1,17 @@ #![allow(non_snake_case)] +use crate::ff::Field as _; use crate::halo2_proofs::{ arithmetic::CurveAffine, - dev::MockProver, - halo2curves::bn256::Fr, - halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, -}; -use crate::secp256k1::{FpChip, FqChip}; -use crate::{ - ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{FieldChip, PrimeField}, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + halo2curves::secp256k1::{Fq, Secp256k1Affine}, }; -use halo2_base::gates::RangeChip; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; -use halo2_base::Context; use rand::random; -use rand_core::OsRng; -use std::fs::File; use test_case::test_case; -use super::CircuitParams; - -fn ecdsa_test( - ctx: &mut Context, - params: CircuitParams, - r: Fq, - s: Fq, - msghash: Fq, - pk: Secp256k1Affine, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - - let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); - - let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.assign_point(ctx, pk); - // test ECDSA - let res = ecdsa_verify_no_pubkey_check::( - &ecc_chip, ctx, pk, r, s, m, 4, 4, - ); - assert_eq!(res.value(), &F::one()); -} - -fn random_parameters_ecdsa() -> (Fq, Fq, Fq, Secp256k1Affine) { - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); - let k_inv = k.invert().unwrap(); - - let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); - let x = r_point.x(); - let x_bigint = fe_to_biguint(x); +use super::ecdsa::{run_test, ECDSAInput}; - let r = biguint_to_fe::(&(x_bigint % modulus::())); - let s = k_inv * (msg_hash + (r * sk)); - - (r, s, msg_hash, pubkey) -} - -fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp256k1Affine) { +fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> ECDSAInput { let sk = ::ScalarExt::from(sk); let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::from(msg_hash); @@ -82,110 +26,32 @@ fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp2 let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - (r, s, msg_hash, pubkey) -} - -fn ecdsa_circuit( - r: Fq, - s: Fq, - msg_hash: Fq, - pubkey: Secp256k1Affine, - params: CircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit + ECDSAInput { r, s, msghash: msg_hash, pk: pubkey } } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_ecdsa_msg_hash_zero() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(), 0, random::()); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(random::(), 0, random::()); + run_test(input); } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_ecdsa_private_key_zero() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[test] -fn test_ecdsa_random_valid_inputs() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(0, random::(), random::()); + run_test(input); } #[test_case(1, 1, 1; "")] fn test_ecdsa_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(sk, msg_hash, k); + run_test(input); } #[test_case(1, 1, 1; "")] fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - let s = -s; - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut input = custom_parameters_ecdsa(sk, msg_hash, k); + input.s = -input.s; + run_test(input); } diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index 5938fe3b..b83720b1 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1,28 +1,17 @@ #![allow(non_snake_case)] use std::fs::File; -use ff::Field; -use group::Curve; +use crate::ff::Field; +use crate::group::Curve; use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::{ - dev::MockProver, - halo2curves::{ - bn256::Fr, - secp256k1::{Fq, Secp256k1Affine}, - }, - }, - utils::{biguint_to_fe, fe_to_biguint, BigPrimeField}, + gates::RangeChip, + halo2_proofs::halo2curves::secp256k1::{Fq, Secp256k1Affine}, + utils::{biguint_to_fe, fe_to_biguint, testing::base_test, BigPrimeField}, Context, }; use num_bigint::BigUint; -use rand_core::OsRng; +use rand::rngs::StdRng; +use rand_core::SeedableRng; use serde::{Deserialize, Serialize}; use crate::{ @@ -33,11 +22,11 @@ use crate::{ pub mod ecdsa; pub mod ecdsa_tests; - +pub mod schnorr_signature; pub mod schnorr_signature_tests; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { +pub struct CircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -50,15 +39,14 @@ struct CircuitParams { fn sm_test( ctx: &mut Context, + range: &RangeChip, params: CircuitParams, base: Secp256k1Affine, scalar: Fq, window_bits: usize, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::>::new(&fp_chip); let s = fq_chip.load_private(ctx, scalar); @@ -80,63 +68,32 @@ fn sm_test( assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); } -fn sm_circuit( - params: CircuitParams, - stage: CircuitBuilderStage, - break_points: Option, - base: Secp256k1Affine, - scalar: Fq, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); - - sm_test(builder.main(0), params, base, scalar, 4); - - match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - } -} - -#[test] -fn test_secp_sm_random() { +fn run_test(base: Secp256k1Affine, scalar: Fq) { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let circuit = sm_circuit( - params, - CircuitBuilderStage::Mock, - None, - Secp256k1Affine::random(OsRng), - Fq::random(OsRng), - ); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + sm_test(ctx, range, params, base, scalar, 4); + }); } #[test] -fn test_secp_sm_minus_1() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); +fn test_secp_sm_random() { + let mut rng = StdRng::seed_from_u64(0); + run_test(Secp256k1Affine::random(&mut rng), Fq::random(&mut rng)); +} - let base = Secp256k1Affine::random(OsRng); +#[test] +fn test_secp_sm_minus_1() { + let rng = StdRng::seed_from_u64(0); + let base = Secp256k1Affine::random(rng); let mut s = -Fq::one(); let mut n = fe_to_biguint(&s); loop { - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(base, s); if &n % BigUint::from(2usize) == BigUint::from(0usize) { break; } @@ -147,19 +104,8 @@ fn test_secp_sm_minus_1() { #[test] fn test_secp_sm_0_1() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let base = Secp256k1Affine::random(OsRng); - let s = Fq::zero(); - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); - - let s = Fq::one(); - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let rng = StdRng::seed_from_u64(0); + let base = Secp256k1Affine::random(rng); + run_test(base, Fq::ZERO); + run_test(base, Fq::ONE); } -pub mod sm_unsafe_scalars; diff --git a/halo2-ecc/src/secp256k1/tests/schnorr_signature.rs b/halo2-ecc/src/secp256k1/tests/schnorr_signature.rs new file mode 100644 index 00000000..842a4693 --- /dev/null +++ b/halo2-ecc/src/secp256k1/tests/schnorr_signature.rs @@ -0,0 +1,151 @@ +#![allow(non_snake_case)] +use crate::halo2_proofs::{ + arithmetic::CurveAffine, + halo2curves::bn256::Fr, + halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, +}; +use crate::secp256k1::{FpChip, FqChip}; +use crate::{ + ecc::{schnorr_signature::schnorr_verify_no_pubkey_check, EccChip}, + fields::FieldChip, +}; +use halo2_base::gates::RangeChip; +use halo2_base::utils::fe_to_biguint; +use halo2_base::utils::BigPrimeField; +use halo2_base::Context; +use halo2_base::{halo2_proofs::arithmetic::Field, utils::testing::base_test}; +use num_bigint::BigUint; +use num_integer::Integer; +use rand::rngs::StdRng; +use rand_core::SeedableRng; +use std::fs::File; +use std::io::BufReader; +use std::io::Write; +use std::{fs, io::BufRead}; + +use super::CircuitParams; + +#[derive(Clone, Copy, Debug)] +pub struct SchnorrInput { + pub r: Fp, + pub s: Fq, + pub msg_hash: Fq, + pub pk: Secp256k1Affine, +} + +pub fn schnorr_signature_test( + ctx: &mut Context, + range: &RangeChip, + params: CircuitParams, + input: SchnorrInput, +) -> F { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); + + let [m, s] = [input.msg_hash, input.s].map(|x| fq_chip.load_private(ctx, x)); + let r = fp_chip.load_private(ctx, input.r); + + let ecc_chip = EccChip::>::new(&fp_chip); + let pk = ecc_chip.assign_point(ctx, input.pk); + // test schnorr signature + let res = schnorr_verify_no_pubkey_check::( + &ecc_chip, ctx, pk, r, s, m, 4, 4, + ); + *res.value() +} + +// This function mut rng internal state +pub fn random_schnorr_signature_input(rng: &mut StdRng) -> SchnorrInput { + let mut tmp = rng.clone(); + let sk = ::ScalarExt::random(&mut tmp); + let pk = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msg_hash = ::ScalarExt::random(&mut tmp); + + let mut k = ::ScalarExt::random(&mut tmp); + + let mut r_point = + Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + let mut x: &Fp = r_point.x(); + let mut y: &Fp = r_point.y(); + // make sure R.y is even + while fe_to_biguint(y).mod_floor(&BigUint::from(2u64)) != BigUint::from(0u64) { + k = ::ScalarExt::random(&mut tmp); + r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + x = r_point.x(); + y = r_point.y(); + } + + let r = *x; + let s = k + sk * msg_hash; + + // change rng internal state + *rng = tmp; + + SchnorrInput { r, s, msg_hash, pk } +} + +pub fn run_test(input: SchnorrInput) { + let path = "configs/secp256k1/schnorr_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let res = base_test() + .k(params.degree) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| schnorr_signature_test(ctx, range, params, input)); + assert_eq!(res, Fr::ONE); +} + +#[test] +fn test_secp256k1_schnorr() { + let mut rng = StdRng::seed_from_u64(0); + let input = random_schnorr_signature_input(&mut rng); + run_test(input); +} + +#[test] +fn bench_secp256k1_schnorr() -> Result<(), Box> { + let mut rng = StdRng::from_seed([0u8; 32]); + let config_path = "configs/secp256k1/bench_schnorr.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/secp256k1").unwrap(); + fs::create_dir_all("data").unwrap(); + let results_path = "results/secp256k1/schnorr_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); + writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + + let bench_params_reader = BufReader::new(bench_params_file); + for line in bench_params_reader.lines() { + let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); + + let stats = + base_test().k(k).lookup_bits(bench_params.lookup_bits).unusable_rows(20).bench_builder( + random_schnorr_signature_input(&mut rng), + random_schnorr_signature_input(&mut rng), + |pool, range, input| { + schnorr_signature_test(pool.main(), range, bench_params, input); + }, + ); + + writeln!( + fs_results, + "{},{},{},{},{},{},{},{:?},{},{:?}", + bench_params.degree, + bench_params.num_advice, + bench_params.num_lookup_advice, + bench_params.num_fixed, + bench_params.lookup_bits, + bench_params.limb_bits, + bench_params.num_limbs, + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() + )?; + } + Ok(()) +} diff --git a/halo2-ecc/src/secp256k1/tests/schnorr_signature_tests.rs b/halo2-ecc/src/secp256k1/tests/schnorr_signature_tests.rs index efcbeeea..a2d36501 100644 --- a/halo2-ecc/src/secp256k1/tests/schnorr_signature_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/schnorr_signature_tests.rs @@ -1,108 +1,13 @@ -#![allow(non_snake_case)] -use crate::halo2_proofs::{ - arithmetic::CurveAffine, - dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, - halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, - plonk::*, - poly::commitment::ParamsProver, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, -}; -use crate::halo2_proofs::{ - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -use crate::secp256k1::{FpChip, FqChip}; -use crate::{ - ecc::{schnorr_signature::schnorr_verify_no_pubkey_check, EccChip}, - fields::{FieldChip, PrimeField}, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::utils::fs::gen_srs; -use halo2_base::{ - gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, - utils::fe_to_biguint, -}; -use num_bigint::BigUint; - -use halo2_base::gates::RangeChip; -use halo2_base::Context; -use num_integer::Integer; -use rand::random; -use rand::rngs::StdRng; +use halo2_base::halo2_proofs::halo2curves::{secp256k1::Secp256k1Affine, CurveAffine}; +use rand::{random, rngs::StdRng}; use rand_core::SeedableRng; -use std::fs::File; -use std::io::BufReader; -use std::io::Write; -use std::{fs, io::BufRead}; use test_case::test_case; -use super::CircuitParams; - -fn schnorr_signature_test( - ctx: &mut Context, - params: CircuitParams, - r: Fp, - s: Fq, - msghash: Fq, - pk: Secp256k1Affine, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - - let [m, s] = [msghash, s].map(|x| fq_chip.load_private(ctx, x)); - let r = fp_chip.load_private(ctx, r); - - let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.assign_point(ctx, pk); - // test schnorr signature - let res = schnorr_verify_no_pubkey_check::( - &ecc_chip, ctx, pk, r, s, m, 4, 4, - ); - assert_eq!(res.value(), &F::one()); -} - -fn random_parameters_schnorr_signature() -> (Fp, Fq, Fq, Secp256k1Affine) { - let sk = ::ScalarExt::random(StdRng::from_seed([0u8; 32])); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); - let msg_hash = - ::ScalarExt::random(StdRng::from_seed([0u8; 32])); - - let mut k = ::ScalarExt::random(StdRng::from_seed([0u8; 32])); +use super::schnorr_signature::{random_schnorr_signature_input, run_test, SchnorrInput}; - let mut r_point = - Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); - let mut x: &Fp = r_point.x(); - let mut y: &Fp = r_point.y(); - // make sure R.y is even - while fe_to_biguint(y).mod_floor(&BigUint::from(2u64)) != BigUint::from(0u64) { - k = ::ScalarExt::random(StdRng::from_seed([0u8; 32])); - r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); - x = r_point.x(); - y = r_point.y(); - } - - let r = *x; - let s = k + sk * msg_hash; - - (r, s, msg_hash, pubkey) -} - -fn custom_parameters_schnorr_signature( - sk: u64, - msg_hash: u64, - k: u64, -) -> (Fp, Fq, Fq, Secp256k1Affine) { +fn custom_parameters_schnorr_signature(sk: u64, msg_hash: u64, k: u64) -> SchnorrInput { let sk = ::ScalarExt::from(sk); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let pk = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::from(msg_hash); let k = ::ScalarExt::from(k); @@ -113,239 +18,41 @@ fn custom_parameters_schnorr_signature( let r = *x; let s = k + sk * msg_hash; - (r, s, msg_hash, pubkey) -} - -fn schnorr_signature_circuit( - r: Fp, - s: Fq, - msg_hash: Fq, - pubkey: Secp256k1Affine, - params: CircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - schnorr_signature_test(builder.main(0), params, r, s, msg_hash, pubkey); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit + SchnorrInput { r, s, msg_hash, pk } } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_schnorr_signature_msg_hash_zero() { - let path = "configs/secp256k1/schnorr_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = - custom_parameters_schnorr_signature(random::(), 0, random::()); - - let circuit = - schnorr_signature_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_schnorr_signature(random::(), 0, random::()); + run_test(input); } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_schnorr_signature_private_key_zero() { - let path = "configs/secp256k1/schnorr_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = - custom_parameters_schnorr_signature(0, random::(), random::()); - - let circuit = - schnorr_signature_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_schnorr_signature(0, random::(), random::()); + run_test(input); } #[test_case(1, 1, 0; "")] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_schnorr_signature_k_zero(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/schnorr_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_schnorr_signature(sk, msg_hash, k); - - let circuit = - schnorr_signature_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_schnorr_signature(sk, msg_hash, k); + run_test(input); } #[test] fn test_schnorr_signature_random_valid_inputs() { - let path = "configs/secp256k1/schnorr_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); + let mut rng = StdRng::seed_from_u64(0); for _ in 0..10 { - let (r, s, msg_hash, pubkey) = random_parameters_schnorr_signature(); - - let circuit = schnorr_signature_circuit( - r, - s, - msg_hash, - pubkey, - params, - CircuitBuilderStage::Mock, - None, - ); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = random_schnorr_signature_input(&mut rng); + run_test(input); } } #[test_case(1, 1, 1; "")] fn test_schnorr_signature_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/schnorr_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_schnorr_signature(sk, msg_hash, k); - - let circuit = - schnorr_signature_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[test] -fn bench_secp256k1_schnorr() -> Result<(), Box> { - let mut rng = StdRng::from_seed([0u8; 32]); - let config_path = "configs/secp256k1/bench_schnorr.config"; - let bench_params_file = - File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); - fs::create_dir_all("results/secp256k1").unwrap(); - fs::create_dir_all("data").unwrap(); - let results_path = "results/secp256k1/schnorr_bench.csv"; - let mut fs_results = File::create(results_path).unwrap(); - writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; - - let bench_params_reader = BufReader::new(bench_params_file); - for line in bench_params_reader.lines() { - let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - let k = bench_params.degree; - let (r, s, msg_hash, pubkey) = random_parameters_schnorr_signature(); - println!("---------------------- degree = {k} ------------------------------",); - - let params = gen_srs(k); - println!("{bench_params:?}"); - let circuit = schnorr_signature_circuit( - r, - s, - msg_hash, - pubkey, - bench_params, - CircuitBuilderStage::Keygen, - None, - ); - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = schnorr_signature_circuit( - r, - s, - msg_hash, - pubkey, - bench_params, - CircuitBuilderStage::Prover, - Some(break_points), - ); - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], &mut rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/schnorr_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); - - writeln!( - fs_results, - "{},{},{},{},{},{},{},{:?},{},{:?}", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() - )?; - } - Ok(()) + let input = custom_parameters_schnorr_signature(sk, msg_hash, k); + run_test(input); } diff --git a/hashes/poseidon/Cargo.toml b/hashes/poseidon/Cargo.toml deleted file mode 100644 index efc3b1e8..00000000 --- a/hashes/poseidon/Cargo.toml +++ /dev/null @@ -1,28 +0,0 @@ -[package] -name = "poseidon" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -[dependencies] -array-init = "2.0.0" -rand = "0.8" -itertools = "0.10.3" -lazy_static = "1.4" -log = "0.4" -num-bigint = { version = "0.4" } -halo2-base = { path = "../../halo2-base", default-features = false, features = ["halo2-axiom"] } -rayon = "1.6.1" -poseidon = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/dev", package = "poseidon" } - -[dev-dependencies] -hex = "0.4.3" -itertools = "0.10.1" -pretty_assertions = "1.0.0" -rand_core = "0.6.4" -rand_xorshift = "0.3" -env_logger = "0.10" - -[features] -default = ["display"] -display = ["halo2-base/display"] \ No newline at end of file diff --git a/hashes/poseidon/src/lib.rs b/hashes/poseidon/src/lib.rs deleted file mode 100644 index 952b8288..00000000 --- a/hashes/poseidon/src/lib.rs +++ /dev/null @@ -1,211 +0,0 @@ -// impl taken from https://github.com/scroll-tech/halo2-snark-aggregator/tree/main/halo2-snark-aggregator-api/src/hash - -use ::poseidon::{SparseMDSMatrix, Spec, State}; -use halo2_base::halo2_proofs::plonk::Error; -use halo2_base::{ - gates::GateInstructions, - utils::ScalarField, - AssignedValue, Context, - QuantumCell::{Constant, Existing}, -}; - -pub mod tests; - -struct PoseidonState { - s: [AssignedValue; T], -} - -impl PoseidonState { - fn x_power5_with_constant( - ctx: &mut Context, - gate: &impl GateInstructions, - x: AssignedValue, - constant: &F, - ) -> AssignedValue { - let x2 = gate.mul(ctx, x, x); - let x4 = gate.mul(ctx, x2, x2); - gate.mul_add(ctx, x, x4, Constant(*constant)) - } - - fn sbox_full( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - constants: &[F; T], - ) { - for (x, constant) in self.s.iter_mut().zip(constants.iter()) { - *x = Self::x_power5_with_constant(ctx, gate, *x, constant); - } - } - - fn sbox_part(&mut self, ctx: &mut Context, gate: &impl GateInstructions, constant: &F) { - let x = &mut self.s[0]; - *x = Self::x_power5_with_constant(ctx, gate, *x, constant); - } - - fn absorb_with_pre_constants( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - inputs: Vec>, - pre_constants: &[F; T], - ) { - assert!(inputs.len() < T); - let offset = inputs.len() + 1; - - // Explanation of what's going on: before each round of the poseidon permutation, - // two things have to be added to the state: inputs (the absorbed elements) and - // preconstants. Imagine the state as a list of T elements, the first of which is - // the capacity: |--cap--|--el1--|--el2--|--elR--| - // - A preconstant is added to each of all T elements (which is different for each) - // - The inputs are added to all elements starting from el1 (so, not to the capacity), - // to as many elements as inputs are available. - // - To the first element for which no input is left (if any), an extra 1 is added. - - // adding preconstant to the distinguished capacity element (only one) - self.s[0] = gate.add(ctx, self.s[0], Constant(pre_constants[0])); - - // adding pre-constants and inputs to the elements for which both are available - for ((x, constant), input) in - self.s.iter_mut().skip(1).zip(pre_constants.iter().skip(1)).zip(inputs.iter()) - { - *x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]); - } - - // adding only pre-constants when no input is left - for (i, (x, constant)) in - self.s.iter_mut().skip(offset).zip(pre_constants.iter().skip(offset)).enumerate() - { - *x = gate.add( - ctx, - Existing(*x), - Constant(if i == 0 { F::one() + constant } else { *constant }), - ); - } - } - - fn apply_mds( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - mds: &[[F; T]; T], - ) { - let res = mds - .iter() - .map(|row| { - gate.inner_product(ctx, self.s.iter().copied(), row.iter().map(|c| Constant(*c))) - }) - .collect::>(); - - self.s = res.try_into().unwrap(); - } - - fn apply_sparse_mds( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - mds: &SparseMDSMatrix, - ) { - let sum = - gate.inner_product(ctx, self.s.iter().copied(), mds.row().iter().map(|c| Constant(*c))); - let mut res = vec![sum]; - - for (e, x) in mds.col_hat().iter().zip(self.s.iter().skip(1)) { - res.push(gate.mul_add(ctx, self.s[0], Constant(*e), *x)); - } - - for (x, new_x) in self.s.iter_mut().zip(res.into_iter()) { - *x = new_x - } - } -} - -pub struct PoseidonChip { - init_state: [AssignedValue; T], - state: PoseidonState, - spec: Spec, - absorbing: Vec>, -} - -impl PoseidonChip { - pub fn new(ctx: &mut Context, r_f: usize, r_p: usize) -> Result { - let init_state = State::::default() - .words() - .into_iter() - .map(|x| ctx.load_constant(x)) - .collect::>>(); - Ok(Self { - spec: Spec::new(r_f, r_p), - init_state: init_state.clone().try_into().unwrap(), - state: PoseidonState { s: init_state.try_into().unwrap() }, - absorbing: Vec::new(), - }) - } - - pub fn clear(&mut self) { - self.state = PoseidonState { s: self.init_state }; - self.absorbing.clear(); - } - - pub fn update(&mut self, elements: &[AssignedValue]) { - self.absorbing.extend_from_slice(elements); - } - - pub fn squeeze( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - ) -> Result, Error> { - let mut input_elements = vec![]; - input_elements.append(&mut self.absorbing); - - let mut padding_offset = 0; - - for chunk in input_elements.chunks(RATE) { - padding_offset = RATE - chunk.len(); - self.permutation(ctx, gate, chunk.to_vec()); - } - - if padding_offset == 0 { - self.permutation(ctx, gate, vec![]); - } - - Ok(self.state.s[1]) - } - - fn permutation( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - inputs: Vec>, - ) { - let r_f = self.spec.r_f() / 2; - let mds = &self.spec.mds_matrices().mds().rows(); - - let constants = self.spec.constants().start(); - self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); - for constants in constants.iter().skip(1).take(r_f - 1) { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - - let pre_sparse_mds = &self.spec.mds_matrices().pre_sparse_mds().rows(); - self.state.sbox_full(ctx, gate, constants.last().unwrap()); - self.state.apply_mds(ctx, gate, pre_sparse_mds); - - let sparse_matrices = &self.spec.mds_matrices().sparse_matrices(); - let constants = &self.spec.constants().partial(); - for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { - self.state.sbox_part(ctx, gate, constant); - self.state.apply_sparse_mds(ctx, gate, sparse_mds); - } - - let constants = &self.spec.constants().end(); - for constants in constants.iter() { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - self.state.sbox_full(ctx, gate, &[F::zero(); T]); - self.state.apply_mds(ctx, gate, mds); - } -} diff --git a/hashes/poseidon/src/tests.rs b/hashes/poseidon/src/tests.rs deleted file mode 100644 index fd9625ad..00000000 --- a/hashes/poseidon/src/tests.rs +++ /dev/null @@ -1,119 +0,0 @@ -#[cfg(test)] -mod tests { - use std::{cmp::max, iter::zip}; - - use halo2_base::{ - gates::{builder::GateThreadBuilder, GateChip}, - halo2_proofs::halo2curves::bn256::Fr, - utils::ScalarField, - }; - use poseidon::Poseidon; - use rand::Rng; - - use crate::PoseidonChip; - - // make interleaved calls to absorb and squeeze elements and - // check that the result is the same in-circuit and natively - fn poseidon_compatiblity_verification( - // elements of F to absorb; one sublist = one absorption - mut absorptions: Vec>, - // list of amounts of elements of F that should be squeezed every time - mut squeezings: Vec, - rounds_full: usize, - rounts_partial: usize, - ) { - let mut builder = GateThreadBuilder::prover(); - let gate = GateChip::default(); - - let mut ctx = builder.main(0); - - // constructing native and in-circuit Poseidon sponges - let mut native_sponge = Poseidon::::new(rounds_full, rounts_partial); - let mut circuit_sponge = - PoseidonChip::::new(&mut ctx, rounds_full, rounts_partial) - .expect("Failed to construct Poseidon circuit"); - - // preparing to interleave absorptions and squeezings - let n_iterations = max(absorptions.len(), squeezings.len()); - absorptions.resize(n_iterations, Vec::new()); - squeezings.resize(n_iterations, 0); - - for (absorption, squeezing) in zip(absorptions, squeezings) { - // absorb (if any elements were provided) - native_sponge.update(&absorption); - circuit_sponge.update(&ctx.assign_witnesses(absorption)); - - // squeeze (if any elements were requested) - for _ in 0..squeezing { - let native_squeezed = native_sponge.squeeze(); - let circuit_squeezed = - circuit_sponge.squeeze(&mut ctx, &gate).expect("Failed to squeeze"); - - assert_eq!(native_squeezed, *circuit_squeezed.value()); - } - } - - // even if no squeezings were requested, we squeeze to verify the - // states are the same after all absorptions - let native_squeezed = native_sponge.squeeze(); - let circuit_squeezed = circuit_sponge.squeeze(&mut ctx, &gate).expect("Failed to squeeze"); - - assert_eq!(native_squeezed, *circuit_squeezed.value()); - } - - fn random_nested_list_f(len: usize, max_sub_len: usize) -> Vec> { - let mut rng = rand::thread_rng(); - let mut list = Vec::new(); - for _ in 0..len { - let len = rng.gen_range(0..=max_sub_len); - let mut sublist = Vec::new(); - - for _ in 0..len { - sublist.push(F::random(&mut rng)); - } - list.push(sublist); - } - list - } - - fn random_list_usize(len: usize, max: usize) -> Vec { - let mut rng = rand::thread_rng(); - let mut list = Vec::new(); - for _ in 0..len { - list.push(rng.gen_range(0..=max)); - } - list - } - - #[test] - fn test_poseidon_compatibility_squeezing_only() { - let absorptions = Vec::new(); - let squeezings = random_list_usize(10, 7); - - poseidon_compatiblity_verification::(absorptions, squeezings, 8, 57); - } - - #[test] - fn test_poseidon_compatibility_absorbing_only() { - let absorptions = random_nested_list_f(8, 5); - let squeezings = Vec::new(); - - poseidon_compatiblity_verification::(absorptions, squeezings, 8, 57); - } - - #[test] - fn test_poseidon_compatibility_interleaved() { - let absorptions = random_nested_list_f(10, 5); - let squeezings = random_list_usize(7, 10); - - poseidon_compatiblity_verification::(absorptions, squeezings, 8, 57); - } - - #[test] - fn test_poseidon_compatibility_other_params() { - let absorptions = random_nested_list_f(10, 10); - let squeezings = random_list_usize(10, 10); - - poseidon_compatiblity_verification::(absorptions, squeezings, 8, 120); - } -} diff --git a/hashes/zkevm-keccak/Cargo.toml b/hashes/zkevm-keccak/Cargo.toml deleted file mode 100644 index 3b35b7a3..00000000 --- a/hashes/zkevm-keccak/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "zkevm-keccak" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -[dependencies] -array-init = "2.0.0" -ethers-core = "0.17.0" -rand = "0.8" -itertools = "0.10.3" -lazy_static = "1.4" -log = "0.4" -num-bigint = { version = "0.4" } -halo2-base = { path = "../../halo2-base", default-features = false } -rayon = "1.6.1" - -[dev-dependencies] -criterion = "0.3" -ctor = "0.1.22" -ethers-signers = "0.17.0" -hex = "0.4.3" -itertools = "0.10.1" -pretty_assertions = "1.0.0" -rand_core = "0.6.4" -rand_xorshift = "0.3" -env_logger = "0.10" - -[features] -default = ["halo2-axiom", "display"] -display = ["halo2-base/display"] -halo2-pse = ["halo2-base/halo2-pse"] -halo2-axiom = ["halo2-base/halo2-axiom"] diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm-keccak/src/keccak_packed_multi.rs deleted file mode 100644 index 55be8306..00000000 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ /dev/null @@ -1,2040 +0,0 @@ -use super::util::{ - constraint_builder::BaseConstraintBuilder, - eth_types::Field, - expression::{and, not, select, Expr}, - field_xor, get_absorb_positions, get_num_bits_per_lookup, into_bits, load_lookup_table, - load_normalize_table, load_pack_table, pack, pack_u64, pack_with_base, rotate, scatter, - target_part_sizes, to_bytes, unpack, CHI_BASE_LOOKUP_TABLE, NUM_BYTES_PER_WORD, NUM_ROUNDS, - NUM_WORDS_TO_ABSORB, NUM_WORDS_TO_SQUEEZE, RATE, RATE_IN_BITS, RHO_MATRIX, ROUND_CST, -}; -use crate::halo2_proofs::{ - arithmetic::FieldExt, - circuit::{Layouter, Region, Value}, - plonk::{ - Advice, Challenge, Column, ConstraintSystem, Error, Expression, Fixed, SecondPhase, - TableColumn, VirtualCells, - }, - poly::Rotation, -}; -use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; -use itertools::Itertools; -use log::{debug, info}; -use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use std::env::var; -use std::marker::PhantomData; - -#[cfg(test)] -mod tests; - -const MAX_DEGREE: usize = 3; -const ABSORB_LOOKUP_RANGE: usize = 3; -const THETA_C_LOOKUP_RANGE: usize = 6; -const RHO_PI_LOOKUP_RANGE: usize = 4; -const CHI_BASE_LOOKUP_RANGE: usize = 5; - -pub fn get_num_rows_per_round() -> usize { - var("KECCAK_ROWS") - .unwrap_or_else(|_| "25".to_string()) - .parse() - .expect("Cannot parse KECCAK_ROWS env var as usize") -} - -fn get_num_bits_per_absorb_lookup() -> usize { - get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE) -} - -fn get_num_bits_per_theta_c_lookup() -> usize { - get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE) -} - -fn get_num_bits_per_rho_pi_lookup() -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE)) -} - -fn get_num_bits_per_base_chi_lookup() -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE)) -} - -/// The number of keccak_f's that can be done in this circuit -/// -/// `num_rows` should be number of usable rows without blinding factors -pub fn get_keccak_capacity(num_rows: usize) -> usize { - // - 1 because we have a dummy round at the very beginning of multi_keccak - // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * get_num_rows_per_round()` beyond any row where `q_absorb == 1` - (num_rows / get_num_rows_per_round() - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) -} - -pub fn get_num_keccak_f(byte_length: usize) -> usize { - // ceil( (byte_length + 1) / RATE ) - byte_length / RATE + 1 -} - -/// AbsorbData -#[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct AbsorbData { - from: F, - absorb: F, - result: F, -} - -/// SqueezeData -#[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct SqueezeData { - packed: F, -} - -/// KeccakRow -#[derive(Clone, Debug)] -pub struct KeccakRow { - q_enable: bool, - // q_enable_row: bool, - q_round: bool, - q_absorb: bool, - q_round_last: bool, - q_padding: bool, - q_padding_last: bool, - round_cst: F, - is_final: bool, - cell_values: Vec, - // We have no need for length as RLC equality checks length implicitly - // length: usize, - // SecondPhase values will be assigned separately - // data_rlc: Value, - // hash_rlc: Value, -} - -impl KeccakRow { - pub fn dummy_rows(num_rows: usize) -> Vec { - (0..num_rows) - .map(|idx| KeccakRow { - q_enable: idx == 0, - // q_enable_row: true, - q_round: false, - q_absorb: idx == 0, - q_round_last: false, - q_padding: false, - q_padding_last: false, - round_cst: F::zero(), - is_final: false, - cell_values: Vec::new(), - }) - .collect() - } -} - -/// Part -#[derive(Clone, Debug)] -pub(crate) struct Part { - cell: Cell, - expr: Expression, - num_bits: usize, -} - -/// Part Value -#[derive(Clone, Copy, Debug)] -pub(crate) struct PartValue { - value: F, - rot: i32, - num_bits: usize, -} - -#[derive(Clone, Debug)] -pub(crate) struct KeccakRegion { - pub(crate) rows: Vec>, -} - -impl KeccakRegion { - pub(crate) fn new() -> Self { - Self { rows: Vec::new() } - } - - pub(crate) fn assign(&mut self, column: usize, offset: usize, value: F) { - while offset >= self.rows.len() { - self.rows.push(Vec::new()); - } - let row = &mut self.rows[offset]; - while column >= row.len() { - row.push(F::zero()); - } - row[column] = value; - } -} - -#[derive(Clone, Debug)] -pub(crate) struct Cell { - expression: Expression, - column_expression: Expression, - column: Option>, - column_idx: usize, - rotation: i32, -} - -impl Cell { - pub(crate) fn new( - meta: &mut VirtualCells, - column: Column, - column_idx: usize, - rotation: i32, - ) -> Self { - Self { - expression: meta.query_advice(column, Rotation(rotation)), - column_expression: meta.query_advice(column, Rotation::cur()), - column: Some(column), - column_idx, - rotation, - } - } - - pub(crate) fn new_value(column_idx: usize, rotation: i32) -> Self { - Self { - expression: 0.expr(), - column_expression: 0.expr(), - column: None, - column_idx, - rotation, - } - } - - pub(crate) fn at_offset(&self, meta: &mut ConstraintSystem, offset: i32) -> Self { - let mut expression = 0.expr(); - meta.create_gate("Query cell", |meta| { - expression = meta.query_advice(self.column.unwrap(), Rotation(self.rotation + offset)); - vec![0.expr()] - }); - - Self { - expression, - column_expression: self.column_expression.clone(), - column: self.column, - column_idx: self.column_idx, - rotation: self.rotation + offset, - } - } - - pub(crate) fn assign(&self, region: &mut KeccakRegion, offset: i32, value: F) { - region.assign(self.column_idx, (offset + self.rotation) as usize, value); - } -} - -impl Expr for Cell { - fn expr(&self) -> Expression { - self.expression.clone() - } -} - -impl Expr for &Cell { - fn expr(&self) -> Expression { - self.expression.clone() - } -} - -/// CellColumn -#[derive(Clone, Debug)] -pub(crate) struct CellColumn { - advice: Column, - expr: Expression, -} - -/// CellManager -#[derive(Clone, Debug)] -pub(crate) struct CellManager { - height: usize, - width: usize, - current_row: usize, - columns: Vec>, - // rows[i] gives the number of columns already used in row `i` - rows: Vec, - num_unused_cells: usize, -} - -impl CellManager { - pub(crate) fn new(height: usize) -> Self { - Self { - height, - width: 0, - current_row: 0, - columns: Vec::new(), - rows: vec![0; height], - num_unused_cells: 0, - } - } - - pub(crate) fn query_cell(&mut self, meta: &mut ConstraintSystem) -> Cell { - let (row_idx, column_idx) = self.get_position(); - self.query_cell_at_pos(meta, row_idx as i32, column_idx) - } - - pub(crate) fn query_cell_at_row( - &mut self, - meta: &mut ConstraintSystem, - row_idx: i32, - ) -> Cell { - let column_idx = self.rows[row_idx as usize]; - self.rows[row_idx as usize] += 1; - self.width = self.width.max(column_idx + 1); - self.current_row = (row_idx as usize + 1) % self.height; - self.query_cell_at_pos(meta, row_idx, column_idx) - } - - pub(crate) fn query_cell_at_pos( - &mut self, - meta: &mut ConstraintSystem, - row_idx: i32, - column_idx: usize, - ) -> Cell { - let column = if column_idx < self.columns.len() { - self.columns[column_idx].advice - } else { - assert!(column_idx == self.columns.len()); - let advice = meta.advice_column(); - let mut expr = 0.expr(); - meta.create_gate("Query column", |meta| { - expr = meta.query_advice(advice, Rotation::cur()); - vec![0.expr()] - }); - self.columns.push(CellColumn { advice, expr }); - advice - }; - - let mut cells = Vec::new(); - meta.create_gate("Query cell", |meta| { - cells.push(Cell::new(meta, column, column_idx, row_idx)); - vec![0.expr()] - }); - cells[0].clone() - } - - pub(crate) fn query_cell_value(&mut self) -> Cell { - let (row_idx, column_idx) = self.get_position(); - self.query_cell_value_at_pos(row_idx as i32, column_idx) - } - - pub(crate) fn query_cell_value_at_row(&mut self, row_idx: i32) -> Cell { - let column_idx = self.rows[row_idx as usize]; - self.rows[row_idx as usize] += 1; - self.width = self.width.max(column_idx + 1); - self.current_row = (row_idx as usize + 1) % self.height; - self.query_cell_value_at_pos(row_idx, column_idx) - } - - pub(crate) fn query_cell_value_at_pos(&mut self, row_idx: i32, column_idx: usize) -> Cell { - Cell::new_value(column_idx, row_idx) - } - - fn get_position(&mut self) -> (usize, usize) { - let best_row_idx = self.current_row; - let best_row_pos = self.rows[best_row_idx]; - self.rows[best_row_idx] += 1; - self.width = self.width.max(best_row_pos + 1); - self.current_row = (best_row_idx + 1) % self.height; - (best_row_idx, best_row_pos) - } - - pub(crate) fn get_width(&self) -> usize { - self.width - } - - pub(crate) fn start_region(&mut self) -> usize { - // Make sure all rows start at the same column - let width = self.get_width(); - #[cfg(debug_assertions)] - for row in self.rows.iter() { - self.num_unused_cells += width - *row; - } - self.rows = vec![width; self.height]; - width - } - - pub(crate) fn columns(&self) -> &[CellColumn] { - &self.columns - } - - pub(crate) fn get_num_unused_cells(&self) -> usize { - self.num_unused_cells - } -} - -/// Keccak Table, used to verify keccak hashing from RLC'ed input. -#[derive(Clone, Debug)] -pub struct KeccakTable { - /// True when the row is enabled - pub is_enabled: Column, - /// Byte array input as `RLC(reversed(input))` - pub input_rlc: Column, // RLC of input bytes - // Byte array input length - // pub input_len: Column, - /// RLC of the hash result - pub output_rlc: Column, // RLC of hash of input bytes -} - -impl KeccakTable { - /// Construct a new KeccakTable - pub fn construct(meta: &mut ConstraintSystem) -> Self { - let input_rlc = meta.advice_column_in(SecondPhase); - let output_rlc = meta.advice_column_in(SecondPhase); - meta.enable_equality(input_rlc); - meta.enable_equality(output_rlc); - Self { - is_enabled: meta.advice_column(), - input_rlc, - // input_len: meta.advice_column(), - output_rlc, - } - } -} - -#[cfg(feature = "halo2-axiom")] -type KeccakAssignedValue<'v, F> = AssignedCell<&'v Assigned, F>; -#[cfg(not(feature = "halo2-axiom"))] -type KeccakAssignedValue<'v, F> = AssignedCell; - -pub fn assign_advice_custom<'v, F: Field>( - region: &mut Region, - column: Column, - offset: usize, - value: Value, -) -> KeccakAssignedValue<'v, F> { - #[cfg(feature = "halo2-axiom")] - { - region.assign_advice(column, offset, value) - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) - .unwrap() - } -} - -pub fn assign_fixed_custom( - region: &mut Region, - column: Column, - offset: usize, - value: F, -) { - #[cfg(feature = "halo2-axiom")] - { - region.assign_fixed(column, offset, value); - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_fixed( - || format!("assign fixed {}", offset), - column, - offset, - || Value::known(value), - ) - .unwrap(); - } -} - -/// Recombines parts back together -mod decode { - use super::{Expr, FieldExt, Part, PartValue}; - use crate::halo2_proofs::plonk::Expression; - use crate::util::BIT_COUNT; - - pub(crate) fn expr(parts: Vec>) -> Expression { - parts.iter().rev().fold(0.expr(), |acc, part| { - acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.expr.clone() - }) - } - - pub(crate) fn value(parts: Vec>) -> F { - parts.iter().rev().fold(F::zero(), |acc, part| { - acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.value - }) - } -} - -/// Splits a word into parts -mod split { - use super::{ - decode, BaseConstraintBuilder, CellManager, Expr, Field, FieldExt, KeccakRegion, Part, - PartValue, - }; - use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - use crate::util::{pack, pack_part, unpack, WordParts}; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - meta: &mut ConstraintSystem, - cell_manager: &mut CellManager, - cb: &mut BaseConstraintBuilder, - input: Expression, - rot: usize, - target_part_size: usize, - normalize: bool, - row: Option, - ) -> Vec> { - let word = WordParts::new(target_part_size, rot, normalize); - let mut parts = Vec::with_capacity(word.parts.len()); - for word_part in word.parts { - let cell = if let Some(row) = row { - cell_manager.query_cell_at_row(meta, row as i32) - } else { - cell_manager.query_cell(meta) - }; - parts.push(Part { - num_bits: word_part.bits.len(), - cell: cell.clone(), - expr: cell.expr(), - }); - } - // Input parts need to equal original input expression - cb.require_equal("split", decode::expr(parts.clone()), input); - parts - } - - pub(crate) fn value( - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: F, - rot: usize, - target_part_size: usize, - normalize: bool, - row: Option, - ) -> Vec> { - let input_bits = unpack(input); - debug_assert_eq!(pack::(&input_bits), input); - let word = WordParts::new(target_part_size, rot, normalize); - let mut parts = Vec::with_capacity(word.parts.len()); - for word_part in word.parts { - let value = pack_part(&input_bits, &word_part); - let cell = if let Some(row) = row { - cell_manager.query_cell_value_at_row(row as i32) - } else { - cell_manager.query_cell_value() - }; - cell.assign(region, 0, F::from(value)); - parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: cell.rotation, - value: F::from(value), - }); - } - debug_assert_eq!(decode::value(parts.clone()), input); - parts - } -} - -// Split into parts, but storing the parts in a specific way to have the same -// table layout in `output_cells` regardless of rotation. -mod split_uniform { - use super::{ - decode, target_part_sizes, BaseConstraintBuilder, Cell, CellManager, Expr, FieldExt, - KeccakRegion, Part, PartValue, - }; - use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - use crate::util::{ - eth_types::Field, pack, pack_part, rotate, rotate_rev, unpack, WordParts, BIT_SIZE, - }; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - meta: &mut ConstraintSystem, - output_cells: &[Cell], - cell_manager: &mut CellManager, - cb: &mut BaseConstraintBuilder, - input: Expression, - rot: usize, - target_part_size: usize, - normalize: bool, - ) -> Vec> { - let mut input_parts = Vec::new(); - let mut output_parts = Vec::new(); - let word = WordParts::new(target_part_size, rot, normalize); - - let word = rotate(word.parts, rot, target_part_size); - - let target_sizes = target_part_sizes(target_part_size); - let mut word_iter = word.iter(); - let mut counter = 0; - while let Some(word_part) = word_iter.next() { - if word_part.bits.len() == target_sizes[counter] { - // Input and output part are the same - let part = Part { - num_bits: target_sizes[counter], - cell: output_cells[counter].clone(), - expr: output_cells[counter].expr(), - }; - input_parts.push(part.clone()); - output_parts.push(part); - counter += 1; - } else if let Some(extra_part) = word_iter.next() { - // The two parts combined need to have the expected combined length - debug_assert_eq!( - word_part.bits.len() + extra_part.bits.len(), - target_sizes[counter] - ); - - // Needs two cells here to store the parts - // These still need to be range checked elsewhere! - let part_a = cell_manager.query_cell(meta); - let part_b = cell_manager.query_cell(meta); - - // Make sure the parts combined equal the value in the uniform output - let expr = part_a.expr() - + part_b.expr() - * F::from((BIT_SIZE as u32).pow(word_part.bits.len() as u32) as u64); - cb.require_equal("rot part", expr, output_cells[counter].expr()); - - // Input needs the two parts because it needs to be able to undo the rotation - input_parts.push(Part { - num_bits: word_part.bits.len(), - cell: part_a.clone(), - expr: part_a.expr(), - }); - input_parts.push(Part { - num_bits: extra_part.bits.len(), - cell: part_b.clone(), - expr: part_b.expr(), - }); - // Output only has the combined cell - output_parts.push(Part { - num_bits: target_sizes[counter], - cell: output_cells[counter].clone(), - expr: output_cells[counter].expr(), - }); - counter += 1; - } else { - unreachable!(); - } - } - let input_parts = rotate_rev(input_parts, rot, target_part_size); - // Input parts need to equal original input expression - cb.require_equal("split", decode::expr(input_parts), input); - // Uniform output - output_parts - } - - pub(crate) fn value( - output_cells: &[Cell], - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: F, - rot: usize, - target_part_size: usize, - normalize: bool, - ) -> Vec> { - let input_bits = unpack(input); - debug_assert_eq!(pack::(&input_bits), input); - - let mut input_parts = Vec::new(); - let mut output_parts = Vec::new(); - let word = WordParts::new(target_part_size, rot, normalize); - - let word = rotate(word.parts, rot, target_part_size); - - let target_sizes = target_part_sizes(target_part_size); - let mut word_iter = word.iter(); - let mut counter = 0; - while let Some(word_part) = word_iter.next() { - if word_part.bits.len() == target_sizes[counter] { - let value = pack_part(&input_bits, word_part); - output_cells[counter].assign(region, 0, F::from(value)); - input_parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: output_cells[counter].rotation, - value: F::from(value), - }); - output_parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: output_cells[counter].rotation, - value: F::from(value), - }); - counter += 1; - } else if let Some(extra_part) = word_iter.next() { - debug_assert_eq!( - word_part.bits.len() + extra_part.bits.len(), - target_sizes[counter] - ); - - let part_a = cell_manager.query_cell_value(); - let part_b = cell_manager.query_cell_value(); - - let value_a = pack_part(&input_bits, word_part); - let value_b = pack_part(&input_bits, extra_part); - - part_a.assign(region, 0, F::from(value_a)); - part_b.assign(region, 0, F::from(value_b)); - - let value = value_a + value_b * (BIT_SIZE as u64).pow(word_part.bits.len() as u32); - - output_cells[counter].assign(region, 0, F::from(value)); - - input_parts.push(PartValue { - num_bits: word_part.bits.len(), - value: F::from(value_a), - rot: part_a.rotation, - }); - input_parts.push(PartValue { - num_bits: extra_part.bits.len(), - value: F::from(value_b), - rot: part_b.rotation, - }); - output_parts.push(PartValue { - num_bits: target_sizes[counter], - value: F::from(value), - rot: output_cells[counter].rotation, - }); - counter += 1; - } else { - unreachable!(); - } - } - let input_parts = rotate_rev(input_parts, rot, target_part_size); - debug_assert_eq!(decode::value(input_parts), input); - output_parts - } -} - -// Transform values using a lookup table -mod transform { - use super::{transform_to, CellManager, Field, FieldExt, KeccakRegion, Part, PartValue}; - use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; - use itertools::Itertools; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - name: &'static str, - meta: &mut ConstraintSystem, - cell_manager: &mut CellManager, - lookup_counter: &mut usize, - input: Vec>, - transform_table: [TableColumn; 2], - uniform_lookup: bool, - ) -> Vec> { - let cells = input - .iter() - .map(|input_part| { - if uniform_lookup { - cell_manager.query_cell_at_row(meta, input_part.cell.rotation) - } else { - cell_manager.query_cell(meta) - } - }) - .collect_vec(); - transform_to::expr( - name, - meta, - &cells, - lookup_counter, - input, - transform_table, - uniform_lookup, - ) - } - - pub(crate) fn value( - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: Vec>, - do_packing: bool, - f: fn(&u8) -> u8, - uniform_lookup: bool, - ) -> Vec> { - let cells = input - .iter() - .map(|input_part| { - if uniform_lookup { - cell_manager.query_cell_value_at_row(input_part.rot) - } else { - cell_manager.query_cell_value() - } - }) - .collect_vec(); - transform_to::value(&cells, region, input, do_packing, f) - } -} - -// Transfroms values to cells -mod transform_to { - use super::{Cell, Expr, Field, FieldExt, KeccakRegion, Part, PartValue}; - use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; - use crate::util::{pack, to_bytes, unpack}; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - name: &'static str, - meta: &mut ConstraintSystem, - cells: &[Cell], - lookup_counter: &mut usize, - input: Vec>, - transform_table: [TableColumn; 2], - uniform_lookup: bool, - ) -> Vec> { - let mut output = Vec::with_capacity(input.len()); - for (idx, input_part) in input.iter().enumerate() { - let output_part = cells[idx].clone(); - if !uniform_lookup || input_part.cell.rotation == 0 { - meta.lookup(name, |_| { - vec![ - (input_part.expr.clone(), transform_table[0]), - (output_part.expr(), transform_table[1]), - ] - }); - *lookup_counter += 1; - } - output.push(Part { - num_bits: input_part.num_bits, - cell: output_part.clone(), - expr: output_part.expr(), - }); - } - output - } - - pub(crate) fn value( - cells: &[Cell], - region: &mut KeccakRegion, - input: Vec>, - do_packing: bool, - f: fn(&u8) -> u8, - ) -> Vec> { - let mut output = Vec::new(); - for (idx, input_part) in input.iter().enumerate() { - let input_bits = &unpack(input_part.value)[0..input_part.num_bits]; - let output_bits = input_bits.iter().map(f).collect::>(); - let value = if do_packing { - pack(&output_bits) - } else { - F::from(to_bytes::value(&output_bits)[0] as u64) - }; - let output_part = cells[idx].clone(); - output_part.assign(region, 0, value); - output.push(PartValue { - num_bits: input_part.num_bits, - rot: output_part.rotation, - value, - }); - } - output - } -} - -/// KeccakConfig -#[derive(Clone, Debug)] -pub struct KeccakCircuitConfig { - challenge: Challenge, - q_enable: Column, - // q_enable_row: Column, - q_first: Column, - q_round: Column, - q_absorb: Column, - q_round_last: Column, - q_padding: Column, - q_padding_last: Column, - - pub keccak_table: KeccakTable, - - cell_manager: CellManager, - round_cst: Column, - normalize_3: [TableColumn; 2], - normalize_4: [TableColumn; 2], - normalize_6: [TableColumn; 2], - chi_base_table: [TableColumn; 2], - pack_table: [TableColumn; 2], - _marker: PhantomData, -} - -impl KeccakCircuitConfig { - pub fn challenge(&self) -> Challenge { - self.challenge - } - /// Return a new KeccakCircuitConfig - pub fn new(meta: &mut ConstraintSystem, challenge: Challenge) -> Self { - let q_enable = meta.fixed_column(); - // let q_enable_row = meta.fixed_column(); - let q_first = meta.fixed_column(); - let q_round = meta.fixed_column(); - let q_absorb = meta.fixed_column(); - let q_round_last = meta.fixed_column(); - let q_padding = meta.fixed_column(); - let q_padding_last = meta.fixed_column(); - let round_cst = meta.fixed_column(); - let keccak_table = KeccakTable::construct(meta); - - let is_final = keccak_table.is_enabled; - // let length = keccak_table.input_len; - let data_rlc = keccak_table.input_rlc; - let hash_rlc = keccak_table.output_rlc; - - let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); - let normalize_4 = array_init::array_init(|_| meta.lookup_table_column()); - let normalize_6 = array_init::array_init(|_| meta.lookup_table_column()); - let chi_base_table = array_init::array_init(|_| meta.lookup_table_column()); - let pack_table = array_init::array_init(|_| meta.lookup_table_column()); - - let num_rows_per_round = get_num_rows_per_round(); - let mut cell_manager = CellManager::new(get_num_rows_per_round()); - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let mut total_lookup_counter = 0; - - let start_new_hash = |meta: &mut VirtualCells, rot| { - // A new hash is started when the previous hash is done or on the first row - meta.query_fixed(q_first, rot) + meta.query_advice(is_final, rot) - }; - - // Round constant - let mut round_cst_expr = 0.expr(); - meta.create_gate("Query round cst", |meta| { - round_cst_expr = meta.query_fixed(round_cst, Rotation::cur()); - vec![0u64.expr()] - }); - // State data - let mut s = vec![vec![0u64.expr(); 5]; 5]; - let mut s_next = vec![vec![0u64.expr(); 5]; 5]; - for i in 0..5 { - for j in 0..5 { - let cell = cell_manager.query_cell(meta); - s[i][j] = cell.expr(); - s_next[i][j] = cell.at_offset(meta, num_rows_per_round as i32).expr(); - } - } - // Absorb data - let absorb_from = cell_manager.query_cell(meta); - let absorb_data = cell_manager.query_cell(meta); - let absorb_result = cell_manager.query_cell(meta); - let mut absorb_from_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - let mut absorb_data_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - let mut absorb_result_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - for i in 0..NUM_WORDS_TO_ABSORB { - let rot = ((i + 1) * num_rows_per_round) as i32; - absorb_from_next[i] = absorb_from.at_offset(meta, rot).expr(); - absorb_data_next[i] = absorb_data.at_offset(meta, rot).expr(); - absorb_result_next[i] = absorb_result.at_offset(meta, rot).expr(); - } - - // Store the pre-state - let pre_s = s.clone(); - - // Absorb - // The absorption happening at the start of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 17 of the 24 rounds) a - // single word is absorbed so the work is spread out. The absorption is - // done simply by doing state + data and then normalizing the result to [0,1]. - // We also need to convert the input data into bytes to calculate the input data - // rlc. - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size = get_num_bits_per_absorb_lookup(); - let input = absorb_from.expr() + absorb_data.expr(); - let absorb_fat = - split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); - cell_manager.start_region(); - let absorb_res = transform::expr( - "absorb", - meta, - &mut cell_manager, - &mut lookup_counter, - absorb_fat, - normalize_3, - true, - ); - cb.require_equal("absorb result", decode::expr(absorb_res), absorb_result.expr()); - info!("- Post absorb:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Squeeze - // The squeezing happening at the end of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 4 of the 24 rounds) a - // single word is converted to bytes. - cell_manager.start_region(); - let mut lookup_counter = 0; - // Potential optimization: could do multiple bytes per lookup - let packed_parts = - split::expr(meta, &mut cell_manager, &mut cb, absorb_data.expr(), 0, 8, false, None); - cell_manager.start_region(); - // input_bytes.len() = packed_parts.len() = 64 / 8 = 8 = NUM_BYTES_PER_WORD - let input_bytes = transform::expr( - "squeeze unpack", - meta, - &mut cell_manager, - &mut lookup_counter, - packed_parts, - pack_table.into_iter().rev().collect::>().try_into().unwrap(), - true, - ); - debug_assert_eq!(input_bytes.len(), NUM_BYTES_PER_WORD); - - // Padding data - cell_manager.start_region(); - let is_paddings = input_bytes.iter().map(|_| cell_manager.query_cell(meta)).collect_vec(); - info!("- Post padding:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Theta - // Calculate - // - `c[i] = s[i][0] + s[i][1] + s[i][2] + s[i][3] + s[i][4]` - // - `bc[i] = normalize(c)`. - // - `t[i] = bc[(i + 4) % 5] + rot(bc[(i + 1)% 5], 1)` - // This is done by splitting the bc values in parts in a way - // that allows us to also calculate the rotated value "for free". - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size_c = get_num_bits_per_theta_c_lookup(); - let mut c_parts = Vec::new(); - for s in s.iter() { - // Calculate c and split into parts - let c = s[0].clone() + s[1].clone() + s[2].clone() + s[3].clone() + s[4].clone(); - c_parts.push(split::expr( - meta, - &mut cell_manager, - &mut cb, - c, - 1, - part_size_c, - false, - None, - )); - } - // Now calculate `bc` by normalizing `c` - cell_manager.start_region(); - let mut bc = Vec::new(); - for c in c_parts { - // Normalize c - bc.push(transform::expr( - "theta c", - meta, - &mut cell_manager, - &mut lookup_counter, - c, - normalize_6, - true, - )); - } - // Now do `bc[(i + 4) % 5] + rot(bc[(i + 1) % 5], 1)` using just expressions. - // We don't normalize the result here. We do it as part of the rho/pi step, even - // though we would only have to normalize 5 values instead of 25, because of the - // way the rho/pi and chi steps can be combined it's more efficient to - // do it there (the max value for chi is 4 already so that's the - // limiting factor). - let mut os = vec![vec![0u64.expr(); 5]; 5]; - for i in 0..5 { - let t = decode::expr(bc[(i + 4) % 5].clone()) - + decode::expr(rotate(bc[(i + 1) % 5].clone(), 1, part_size_c)); - for j in 0..5 { - os[i][j] = s[i][j].clone() + t.clone(); - } - } - s = os.clone(); - info!("- Post theta:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Rho/Pi - // For the rotation of rho/pi we split up the words like expected, but in a way - // that allows reusing the same parts in an optimal way for the chi step. - // We can save quite a few columns by not recombining the parts after rho/pi and - // re-splitting the words again before chi. Instead we do chi directly - // on the output parts of rho/pi. For rho/pi specically we do - // `s[j][2 * i + 3 * j) % 5] = normalize(rot(s[i][j], RHOM[i][j]))`. - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size = get_num_bits_per_base_chi_lookup(); - // To combine the rho/pi/chi steps we have to ensure a specific layout so - // query those cells here first. - // For chi we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & s[(i+2)%5][j])`. `j` - // remains static but `i` is accessed in a wrap around manner. To do this using - // multiple rows with lookups in a way that doesn't require any - // extra additional cells or selectors we have to put all `s[i]`'s on the same - // row. This isn't that strong of a requirement actually because we the - // words are split into multipe parts, and so only the parts at the same - // position of those words need to be on the same row. - let target_word_sizes = target_part_sizes(part_size); - let num_word_parts = target_word_sizes.len(); - let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = array_init::array_init(|_| { - array_init::array_init(|_| array_init::array_init(|_| Vec::new())) - }); - let mut num_columns = 0; - let mut column_starts = [0usize; 3]; - for p in 0..3 { - column_starts[p] = cell_manager.start_region(); - let mut row_idx = 0; - num_columns = 0; - for j in 0..5 { - for _ in 0..num_word_parts { - for i in 0..5 { - rho_pi_chi_cells[p][i][j] - .push(cell_manager.query_cell_at_row(meta, row_idx)); - } - if row_idx == 0 { - num_columns += 1; - } - row_idx = (((row_idx as usize) + 1) % num_rows_per_round) as i32; - } - } - } - // Do the transformation, resulting in the word parts also being normalized. - let pi_region_start = cell_manager.start_region(); - let mut os_parts = vec![vec![Vec::new(); 5]; 5]; - for (j, os_part) in os_parts.iter_mut().enumerate() { - for i in 0..5 { - // Split s into parts - let s_parts = split_uniform::expr( - meta, - &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], - &mut cell_manager, - &mut cb, - s[i][j].clone(), - RHO_MATRIX[i][j], - part_size, - true, - ); - // Normalize the data to the target cells - let s_parts = transform_to::expr( - "rho/pi", - meta, - &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], - &mut lookup_counter, - s_parts.clone(), - normalize_4, - true, - ); - os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); - } - } - let pi_region_end = cell_manager.start_region(); - // Pi parts range checks - // To make the uniform stuff work we had to combine some parts together - // in new cells (see split_uniform). Here we make sure those parts are range - // checked. Potential improvement: Could combine multiple smaller parts - // in a single lookup but doesn't save that much. - for c in pi_region_start..pi_region_end { - meta.lookup("pi part range check", |_| { - vec![(cell_manager.columns()[c].expr.clone(), normalize_4[0])] - }); - lookup_counter += 1; - } - info!("- Post rho/pi:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Chi - // In groups of 5 columns, we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & - // s[(i+2)%5][j])` five times, on each row (no selector needed). - // This is calculated by making use of `CHI_BASE_LOOKUP_TABLE`. - let mut lookup_counter = 0; - let part_size_base = get_num_bits_per_base_chi_lookup(); - for idx in 0..num_columns { - // First fetch the cells we wan to use - let mut input: [Expression; 5] = array_init::array_init(|_| 0.expr()); - let mut output: [Expression; 5] = array_init::array_init(|_| 0.expr()); - for c in 0..5 { - input[c] = cell_manager.columns()[column_starts[1] + idx * 5 + c].expr.clone(); - output[c] = cell_manager.columns()[column_starts[2] + idx * 5 + c].expr.clone(); - } - // Now calculate `a ^ ((~b) & c)` by doing `lookup[3 - 2*a + b - c]` - for i in 0..5 { - let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() - + input[(i + 1) % 5].clone() - - input[(i + 2) % 5].clone(); - let output = output[i].clone(); - meta.lookup("chi base", |_| { - vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] - }); - lookup_counter += 1; - } - } - // Now just decode the parts after the chi transformation done with the lookups - // above. - let mut os = vec![vec![0u64.expr(); 5]; 5]; - for (i, os) in os.iter_mut().enumerate() { - for (j, os) in os.iter_mut().enumerate() { - let mut parts = Vec::new(); - for idx in 0..num_word_parts { - parts.push(Part { - num_bits: part_size_base, - cell: rho_pi_chi_cells[2][i][j][idx].clone(), - expr: rho_pi_chi_cells[2][i][j][idx].expr(), - }); - } - *os = decode::expr(parts); - } - } - s = os.clone(); - - // iota - // Simply do the single xor on state [0][0]. - cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(); - let input = s[0][0].clone() + round_cst_expr.clone(); - let iota_parts = - split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); - cell_manager.start_region(); - // Could share columns with absorb which may end up using 1 lookup/column - // fewer... - s[0][0] = decode::expr(transform::expr( - "iota", - meta, - &mut cell_manager, - &mut lookup_counter, - iota_parts, - normalize_3, - true, - )); - // Final results stored in the next row - for i in 0..5 { - for j in 0..5 { - cb.require_equal("next row check", s[i][j].clone(), s_next[i][j].clone()); - } - } - info!("- Post chi:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - let mut lookup_counter = 0; - cell_manager.start_region(); - - // Squeeze data - let squeeze_from = cell_manager.query_cell(meta); - let mut squeeze_from_prev = vec![0u64.expr(); NUM_WORDS_TO_SQUEEZE]; - for (idx, squeeze_from_prev) in squeeze_from_prev.iter_mut().enumerate() { - let rot = (-(idx as i32) - 1) * num_rows_per_round as i32; - *squeeze_from_prev = squeeze_from.at_offset(meta, rot).expr(); - } - // Squeeze - // The squeeze happening at the end of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 4 of the 24 rounds) a - // single word is converted to bytes. - // Potential optimization: could do multiple bytes per lookup - cell_manager.start_region(); - // Unpack a single word into bytes (for the squeeze) - // Potential optimization: could do multiple bytes per lookup - let squeeze_from_parts = - split::expr(meta, &mut cell_manager, &mut cb, squeeze_from.expr(), 0, 8, false, None); - cell_manager.start_region(); - let squeeze_bytes = transform::expr( - "squeeze unpack", - meta, - &mut cell_manager, - &mut lookup_counter, - squeeze_from_parts, - pack_table.into_iter().rev().collect::>().try_into().unwrap(), - true, - ); - info!("- Post squeeze:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // The round constraints that we've been building up till now - meta.create_gate("round", |meta| cb.gate(meta.query_fixed(q_round, Rotation::cur()))); - - // Absorb - meta.create_gate("absorb", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let continue_hash = not::expr(start_new_hash(meta, Rotation::cur())); - let absorb_positions = get_absorb_positions(); - let mut a_slice = 0; - for j in 0..5 { - for i in 0..5 { - if absorb_positions.contains(&(i, j)) { - cb.condition(continue_hash.clone(), |cb| { - cb.require_equal( - "absorb verify input", - absorb_from_next[a_slice].clone(), - pre_s[i][j].clone(), - ); - }); - cb.require_equal( - "absorb result copy", - select::expr( - continue_hash.clone(), - absorb_result_next[a_slice].clone(), - absorb_data_next[a_slice].clone(), - ), - s_next[i][j].clone(), - ); - a_slice += 1; - } else { - cb.require_equal( - "absorb state copy", - pre_s[i][j].clone() * continue_hash.clone(), - s_next[i][j].clone(), - ); - } - } - } - cb.gate(meta.query_fixed(q_absorb, Rotation::cur())) - }); - - // Collect the bytes that are spread out over previous rows - let mut hash_bytes = Vec::new(); - for i in 0..NUM_WORDS_TO_SQUEEZE { - for byte in squeeze_bytes.iter() { - let rot = (-(i as i32) - 1) * num_rows_per_round as i32; - hash_bytes.push(byte.cell.at_offset(meta, rot).expr()); - } - } - - // Squeeze - meta.create_gate("squeeze", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let start_new_hash = start_new_hash(meta, Rotation::cur()); - // The words to squeeze - let hash_words: Vec<_> = - pre_s.into_iter().take(4).map(|a| a[0].clone()).take(4).collect(); - // Verify if we converted the correct words to bytes on previous rows - for (idx, word) in hash_words.iter().enumerate() { - cb.condition(start_new_hash.clone(), |cb| { - cb.require_equal( - "squeeze verify packed", - word.clone(), - squeeze_from_prev[idx].clone(), - ); - }); - } - - let challenge_expr = meta.query_challenge(challenge); - let rlc = - hash_bytes.into_iter().reduce(|rlc, x| rlc * challenge_expr.clone() + x).unwrap(); - cb.require_equal("hash rlc check", rlc, meta.query_advice(hash_rlc, Rotation::cur())); - cb.gate(meta.query_fixed(q_round_last, Rotation::cur())) - }); - - // Some general input checks - meta.create_gate("input checks", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - cb.require_boolean("boolean is_final", meta.query_advice(is_final, Rotation::cur())); - cb.gate(meta.query_fixed(q_enable, Rotation::cur())) - }); - - // Enforce fixed values on the first row - meta.create_gate("first row", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - cb.require_zero( - "is_final needs to be disabled on the first row", - meta.query_advice(is_final, Rotation::cur()), - ); - cb.gate(meta.query_fixed(q_first, Rotation::cur())) - }); - - // Enforce logic for when this block is the last block for a hash - let last_is_padding_in_block = is_paddings.last().unwrap().at_offset( - meta, - -(((NUM_ROUNDS + 1 - NUM_WORDS_TO_ABSORB) * num_rows_per_round) as i32), - ); - meta.create_gate("is final", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - // All absorb rows except the first row - cb.condition( - meta.query_fixed(q_absorb, Rotation::cur()) - - meta.query_fixed(q_first, Rotation::cur()), - |cb| { - cb.require_equal( - "is_final needs to be the same as the last is_padding in the block", - meta.query_advice(is_final, Rotation::cur()), - last_is_padding_in_block.expr(), - ); - }, - ); - // For all the rows of a round, only the first row can have `is_final == 1`. - cb.condition( - (1..num_rows_per_round as i32) - .map(|i| meta.query_fixed(q_enable, Rotation(-i))) - .fold(0.expr(), |acc, elem| acc + elem), - |cb| { - cb.require_zero( - "is_final only when q_enable", - meta.query_advice(is_final, Rotation::cur()), - ); - }, - ); - cb.gate(1.expr()) - }); - - // Padding - // May be cleaner to do this padding logic in the byte conversion lookup but - // currently easier to do it like this. - let prev_is_padding = - is_paddings.last().unwrap().at_offset(meta, -(num_rows_per_round as i32)); - meta.create_gate("padding", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let q_padding = meta.query_fixed(q_padding, Rotation::cur()); - let q_padding_last = meta.query_fixed(q_padding_last, Rotation::cur()); - - // All padding selectors need to be boolean - for is_padding in is_paddings.iter() { - cb.condition(meta.query_fixed(q_enable, Rotation::cur()), |cb| { - cb.require_boolean("is_padding boolean", is_padding.expr()); - }); - } - // This last padding selector will be used on the first round row so needs to be - // zero - cb.condition(meta.query_fixed(q_absorb, Rotation::cur()), |cb| { - cb.require_zero( - "last is_padding should be zero on absorb rows", - is_paddings.last().unwrap().expr(), - ); - }); - // Now for each padding selector - for idx in 0..is_paddings.len() { - // Previous padding selector can be on the previous row - let is_padding_prev = - if idx == 0 { prev_is_padding.expr() } else { is_paddings[idx - 1].expr() }; - let is_first_padding = is_paddings[idx].expr() - is_padding_prev.clone(); - - // Check padding transition 0 -> 1 done only once - cb.condition(q_padding.expr(), |cb| { - cb.require_boolean("padding step boolean", is_first_padding.clone()); - }); - - // Padding start/intermediate/end byte checks - if idx == is_paddings.len() - 1 { - // These can be combined in the future, but currently this would increase the - // degree by one Padding start/intermediate byte, all - // padding rows except the last one - cb.condition( - and::expr([ - q_padding.expr() - q_padding_last.expr(), - is_paddings[idx].expr(), - ]), - |cb| { - // Input bytes need to be zero, or one if this is the first padding byte - cb.require_equal( - "padding start/intermediate byte last byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr(), - ); - }, - ); - // Padding start/end byte, only on the last padding row - cb.condition( - and::expr([q_padding_last.expr(), is_paddings[idx].expr()]), - |cb| { - // The input byte needs to be 128, unless it's also the first padding - // byte then it's 129 - cb.require_equal( - "padding start/end byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr() + 128.expr(), - ); - }, - ); - } else { - // Padding start/intermediate byte - cb.condition(and::expr([q_padding.expr(), is_paddings[idx].expr()]), |cb| { - // Input bytes need to be zero, or one if this is the first padding byte - cb.require_equal( - "padding start/intermediate byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr(), - ); - }); - } - } - cb.gate(1.expr()) - }); - - assert!(num_rows_per_round > NUM_BYTES_PER_WORD, "We require enough rows per round to hold the running RLC of the bytes from the one keccak word absorbed per round"); - // TODO: there is probably a way to only require NUM_BYTES_PER_WORD instead of - // NUM_BYTES_PER_WORD + 1 rows per round, but for simplicity and to keep the - // gate degree at 3, we just do the obvious thing for now Input data rlc - meta.create_gate("data rlc", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - - let q_padding = meta.query_fixed(q_padding, Rotation::cur()); - let start_new_hash_prev = start_new_hash(meta, Rotation(-(num_rows_per_round as i32))); - let data_rlc_prev = meta.query_advice(data_rlc, Rotation(-(num_rows_per_round as i32))); - - // Update the length/data_rlc on rows where we absorb data - cb.condition(q_padding.expr(), |cb| { - let challenge_expr = meta.query_challenge(challenge); - // Use intermediate cells to keep the degree low - let mut new_data_rlc = - data_rlc_prev.clone() * not::expr(start_new_hash_prev.expr()); - let mut data_rlcs = (0..NUM_BYTES_PER_WORD) - .map(|i| meta.query_advice(data_rlc, Rotation(i as i32 + 1))); - let intermed_rlc = data_rlcs.next().unwrap(); - cb.require_equal("initial data rlc", intermed_rlc.clone(), new_data_rlc); - new_data_rlc = intermed_rlc; - for (byte, is_padding) in input_bytes.iter().zip(is_paddings.iter()) { - new_data_rlc = select::expr( - is_padding.expr(), - new_data_rlc.clone(), - new_data_rlc * challenge_expr.clone() + byte.expr.clone(), - ); - if let Some(intermed_rlc) = data_rlcs.next() { - cb.require_equal( - "intermediate data rlc", - intermed_rlc.clone(), - new_data_rlc, - ); - new_data_rlc = intermed_rlc; - } - } - cb.require_equal( - "update data rlc", - meta.query_advice(data_rlc, Rotation::cur()), - new_data_rlc, - ); - }); - // Keep length/data_rlc the same on rows where we don't absorb data - cb.condition( - and::expr([ - meta.query_fixed(q_enable, Rotation::cur()) - - meta.query_fixed(q_first, Rotation::cur()), - not::expr(q_padding), - ]), - |cb| { - cb.require_equal( - "data_rlc equality check", - meta.query_advice(data_rlc, Rotation::cur()), - data_rlc_prev.clone(), - ); - }, - ); - cb.gate(1.expr()) - }); - - info!("Degree: {}", meta.degree()); - info!("Minimum rows: {}", meta.minimum_rows()); - info!("Total Lookups: {}", total_lookup_counter); - #[cfg(feature = "display")] - { - println!("Total Keccak Columns: {}", cell_manager.get_width()); - std::env::set_var("KECCAK_ADVICE_COLUMNS", cell_manager.get_width().to_string()); - } - #[cfg(not(feature = "display"))] - info!("Total Keccak Columns: {}", cell_manager.get_width()); - info!("num unused cells: {}", cell_manager.get_num_unused_cells()); - info!("part_size absorb: {}", get_num_bits_per_absorb_lookup()); - info!("part_size theta: {}", get_num_bits_per_theta_c_lookup()); - info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE)); - info!("part_size theta t: {}", get_num_bits_per_lookup(4)); - info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup()); - info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup()); - info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup())); - - KeccakCircuitConfig { - challenge, - q_enable, - // q_enable_row, - q_first, - q_round, - q_absorb, - q_round_last, - q_padding, - q_padding_last, - keccak_table, - cell_manager, - round_cst, - normalize_3, - normalize_4, - normalize_6, - chi_base_table, - pack_table, - _marker: PhantomData, - } - } -} - -impl KeccakCircuitConfig { - pub fn assign(&self, region: &mut Region<'_, F>, witness: &[KeccakRow]) { - for (offset, keccak_row) in witness.iter().enumerate() { - self.set_row(region, offset, keccak_row); - } - } - - pub fn set_row(&self, region: &mut Region<'_, F>, offset: usize, row: &KeccakRow) { - // Fixed selectors - for (_, column, value) in &[ - ("q_enable", self.q_enable, F::from(row.q_enable)), - ("q_first", self.q_first, F::from(offset == 0)), - ("q_round", self.q_round, F::from(row.q_round)), - ("q_round_last", self.q_round_last, F::from(row.q_round_last)), - ("q_absorb", self.q_absorb, F::from(row.q_absorb)), - ("q_padding", self.q_padding, F::from(row.q_padding)), - ("q_padding_last", self.q_padding_last, F::from(row.q_padding_last)), - ] { - assign_fixed_custom(region, *column, offset, *value); - } - - assign_advice_custom( - region, - self.keccak_table.is_enabled, - offset, - Value::known(F::from(row.is_final)), - ); - - // Cell values - row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { - assign_advice_custom(region, column.advice, offset, Value::known(*bit)); - }); - - // Round constant - assign_fixed_custom(region, self.round_cst, offset, row.round_cst); - } - - pub fn load_aux_tables(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64)?; - load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64)?; - load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64)?; - load_lookup_table( - layouter, - "chi base", - &self.chi_base_table, - get_num_bits_per_base_chi_lookup(), - &CHI_BASE_LOOKUP_TABLE, - )?; - load_pack_table(layouter, &self.pack_table) - } -} - -/// Computes and assigns the input RLC values (but not the output RLC values: -/// see `multi_keccak_phase1`). -pub fn keccak_phase1<'v, F: Field>( - region: &mut Region, - keccak_table: &KeccakTable, - bytes: &[u8], - challenge: Value, - input_rlcs: &mut Vec>, - offset: &mut usize, -) { - let num_chunks = get_num_keccak_f(bytes.len()); - let num_rows_per_round = get_num_rows_per_round(); - - let mut byte_idx = 0; - let mut data_rlc = Value::known(F::zero()); - - for _ in 0..num_chunks { - for round in 0..NUM_ROUNDS + 1 { - if round < NUM_WORDS_TO_ABSORB { - for idx in 0..NUM_BYTES_PER_WORD { - assign_advice_custom( - region, - keccak_table.input_rlc, - *offset + idx + 1, - data_rlc, - ); - if byte_idx < bytes.len() { - data_rlc = - data_rlc * challenge + Value::known(F::from(bytes[byte_idx] as u64)); - } - byte_idx += 1; - } - } - let input_rlc = assign_advice_custom(region, keccak_table.input_rlc, *offset, data_rlc); - if round == NUM_ROUNDS { - input_rlcs.push(input_rlc); - } - - *offset += num_rows_per_round; - } - } -} - -/// Witness generation in `FirstPhase` for a keccak hash digest without -/// computing RLCs, which are deferred to `SecondPhase`. -pub fn keccak_phase0( - rows: &mut Vec>, - squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, - bytes: &[u8], -) { - let mut bits = into_bits(bytes); - let mut s = [[F::zero(); 5]; 5]; - let absorb_positions = get_absorb_positions(); - let num_bytes_in_last_block = bytes.len() % RATE; - let num_rows_per_round = get_num_rows_per_round(); - let two = F::from(2u64); - - // Padding - bits.push(1); - while (bits.len() + 1) % RATE_IN_BITS != 0 { - bits.push(0); - } - bits.push(1); - - let chunks = bits.chunks(RATE_IN_BITS); - let num_chunks = chunks.len(); - - let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); - let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); - let mut hash_words = [F::zero(); NUM_WORDS_TO_SQUEEZE]; - - for (idx, chunk) in chunks.enumerate() { - let is_final_block = idx == num_chunks - 1; - - let mut absorb_rows = Vec::new(); - // Absorb - for (idx, &(i, j)) in absorb_positions.iter().enumerate() { - let absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); - let from = s[i][j]; - s[i][j] = field_xor(s[i][j], absorb); - absorb_rows.push(AbsorbData { from, absorb, result: s[i][j] }); - } - - // better memory management to clear already allocated Vecs - cell_managers.clear(); - regions.clear(); - - for round in 0..NUM_ROUNDS + 1 { - let mut cell_manager = CellManager::new(num_rows_per_round); - let mut region = KeccakRegion::new(); - - let mut absorb_row = AbsorbData::default(); - if round < NUM_WORDS_TO_ABSORB { - absorb_row = absorb_rows[round].clone(); - } - - // State data - for s in &s { - for s in s { - let cell = cell_manager.query_cell_value(); - cell.assign(&mut region, 0, *s); - } - } - - // Absorb data - let absorb_from = cell_manager.query_cell_value(); - let absorb_data = cell_manager.query_cell_value(); - let absorb_result = cell_manager.query_cell_value(); - absorb_from.assign(&mut region, 0, absorb_row.from); - absorb_data.assign(&mut region, 0, absorb_row.absorb); - absorb_result.assign(&mut region, 0, absorb_row.result); - - // Absorb - cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(); - let input = absorb_row.from + absorb_row.absorb; - let absorb_fat = - split::value(&mut cell_manager, &mut region, input, 0, part_size, false, None); - cell_manager.start_region(); - let _absorb_result = transform::value( - &mut cell_manager, - &mut region, - absorb_fat.clone(), - true, - |v| v & 1, - true, - ); - - // Padding - cell_manager.start_region(); - // Unpack a single word into bytes (for the absorption) - // Potential optimization: could do multiple bytes per lookup - let packed = - split::value(&mut cell_manager, &mut region, absorb_row.absorb, 0, 8, false, None); - cell_manager.start_region(); - let input_bytes = - transform::value(&mut cell_manager, &mut region, packed, false, |v| *v, true); - cell_manager.start_region(); - let is_paddings = - input_bytes.iter().map(|_| cell_manager.query_cell_value()).collect::>(); - debug_assert_eq!(is_paddings.len(), NUM_BYTES_PER_WORD); - if round < NUM_WORDS_TO_ABSORB { - for (padding_idx, is_padding) in is_paddings.iter().enumerate() { - let byte_idx = round * NUM_BYTES_PER_WORD + padding_idx; - let padding = is_final_block && byte_idx >= num_bytes_in_last_block; - is_padding.assign(&mut region, 0, F::from(padding)); - } - } - cell_manager.start_region(); - - if round != NUM_ROUNDS { - // Theta - let part_size = get_num_bits_per_theta_c_lookup(); - let mut bcf = Vec::new(); - for s in &s { - let c = s[0] + s[1] + s[2] + s[3] + s[4]; - let bc_fat = - split::value(&mut cell_manager, &mut region, c, 1, part_size, false, None); - bcf.push(bc_fat); - } - cell_manager.start_region(); - let mut bc = Vec::new(); - for bc_fat in bcf { - let bc_norm = transform::value( - &mut cell_manager, - &mut region, - bc_fat.clone(), - true, - |v| v & 1, - true, - ); - bc.push(bc_norm); - } - cell_manager.start_region(); - let mut os = [[F::zero(); 5]; 5]; - for i in 0..5 { - let t = decode::value(bc[(i + 4) % 5].clone()) - + decode::value(rotate(bc[(i + 1) % 5].clone(), 1, part_size)); - for j in 0..5 { - os[i][j] = s[i][j] + t; - } - } - s = os; - cell_manager.start_region(); - - // Rho/Pi - let part_size = get_num_bits_per_base_chi_lookup(); - let target_word_sizes = target_part_sizes(part_size); - let num_word_parts = target_word_sizes.len(); - let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = - array_init::array_init(|_| { - array_init::array_init(|_| array_init::array_init(|_| Vec::new())) - }); - let mut column_starts = [0usize; 3]; - for p in 0..3 { - column_starts[p] = cell_manager.start_region(); - let mut row_idx = 0; - for j in 0..5 { - for _ in 0..num_word_parts { - for i in 0..5 { - rho_pi_chi_cells[p][i][j] - .push(cell_manager.query_cell_value_at_row(row_idx as i32)); - } - row_idx = (row_idx + 1) % num_rows_per_round; - } - } - } - cell_manager.start_region(); - let mut os_parts: [[Vec>; 5]; 5] = - array_init::array_init(|_| array_init::array_init(|_| Vec::new())); - for (j, os_part) in os_parts.iter_mut().enumerate() { - for i in 0..5 { - let s_parts = split_uniform::value( - &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], - &mut cell_manager, - &mut region, - s[i][j], - RHO_MATRIX[i][j], - part_size, - true, - ); - - let s_parts = transform_to::value( - &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], - &mut region, - s_parts.clone(), - true, - |v| v & 1, - ); - os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); - } - } - cell_manager.start_region(); - - // Chi - let part_size_base = get_num_bits_per_base_chi_lookup(); - let three_packed = pack::(&vec![3u8; part_size_base]); - let mut os = [[F::zero(); 5]; 5]; - for j in 0..5 { - for i in 0..5 { - let mut s_parts = Vec::new(); - for ((part_a, part_b), part_c) in os_parts[i][j] - .iter() - .zip(os_parts[(i + 1) % 5][j].iter()) - .zip(os_parts[(i + 2) % 5][j].iter()) - { - let value = - three_packed - two * part_a.value + part_b.value - part_c.value; - s_parts.push(PartValue { - num_bits: part_size_base, - rot: j as i32, - value, - }); - } - os[i][j] = decode::value(transform_to::value( - &rho_pi_chi_cells[2][i][j], - &mut region, - s_parts.clone(), - true, - |v| CHI_BASE_LOOKUP_TABLE[*v as usize], - )); - } - } - s = os; - cell_manager.start_region(); - - // iota - let part_size = get_num_bits_per_absorb_lookup(); - let input = s[0][0] + pack_u64::(ROUND_CST[round]); - let iota_parts = split::value::( - &mut cell_manager, - &mut region, - input, - 0, - part_size, - false, - None, - ); - cell_manager.start_region(); - s[0][0] = decode::value(transform::value( - &mut cell_manager, - &mut region, - iota_parts.clone(), - true, - |v| v & 1, - true, - )); - } - - // The words to squeeze out: this is the hash digest as words with - // NUM_BYTES_PER_WORD (=8) bytes each - for (hash_word, a) in hash_words.iter_mut().zip(s.iter()) { - *hash_word = a[0]; - } - - cell_managers.push(cell_manager); - regions.push(region); - } - - // Now that we know the state at the end of the rounds, set the squeeze data - let num_rounds = cell_managers.len(); - for (idx, word) in hash_words.iter().enumerate() { - let cell_manager = &mut cell_managers[num_rounds - 2 - idx]; - let region = &mut regions[num_rounds - 2 - idx]; - - cell_manager.start_region(); - let squeeze_packed = cell_manager.query_cell_value(); - squeeze_packed.assign(region, 0, *word); - - cell_manager.start_region(); - let packed = split::value(cell_manager, region, *word, 0, 8, false, None); - cell_manager.start_region(); - transform::value(cell_manager, region, packed, false, |v| *v, true); - } - squeeze_digests.push(hash_words); - - for round in 0..NUM_ROUNDS + 1 { - let round_cst = pack_u64(ROUND_CST[round]); - - for row_idx in 0..num_rows_per_round { - rows.push(KeccakRow { - q_enable: row_idx == 0, - // q_enable_row: true, - q_round: row_idx == 0 && round < NUM_ROUNDS, - q_absorb: row_idx == 0 && round == NUM_ROUNDS, - q_round_last: row_idx == 0 && round == NUM_ROUNDS, - q_padding: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, - q_padding_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, - round_cst, - is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, - cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), - }); - #[cfg(debug_assertions)] - { - let mut r = rows.last().unwrap().clone(); - r.cell_values.clear(); - log::trace!("offset {:?} row idx {} row {:?}", rows.len() - 1, row_idx, r); - } - } - log::trace!(" = = = = = = round {} end", round); - } - log::trace!(" ====================== chunk {} end", idx); - } - - #[cfg(debug_assertions)] - { - let hash_bytes = s - .into_iter() - .take(4) - .map(|a| { - pack_with_base::(&unpack(a[0]), 2) - .to_bytes_le() - .into_iter() - .take(8) - .collect::>() - .to_vec() - }) - .collect::>(); - debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); - // debug!("data rlc: {:x?}", data_rlc); - } -} - -/// Computes and assigns the input and output RLC values. -pub fn multi_keccak_phase1<'a, 'v, F: Field>( - region: &mut Region, - keccak_table: &KeccakTable, - bytes: impl IntoIterator, - challenge: Value, - squeeze_digests: Vec<[F; NUM_WORDS_TO_SQUEEZE]>, -) -> (Vec>, Vec>) { - let mut input_rlcs = Vec::with_capacity(squeeze_digests.len()); - let mut output_rlcs = Vec::with_capacity(squeeze_digests.len()); - - let num_rows_per_round = get_num_rows_per_round(); - for idx in 0..num_rows_per_round { - [keccak_table.input_rlc, keccak_table.output_rlc] - .map(|column| assign_advice_custom(region, column, idx, Value::known(F::zero()))); - } - - let mut offset = num_rows_per_round; - for bytes in bytes { - keccak_phase1(region, keccak_table, bytes, challenge, &mut input_rlcs, &mut offset); - } - debug_assert!(input_rlcs.len() <= squeeze_digests.len()); - while input_rlcs.len() < squeeze_digests.len() { - keccak_phase1(region, keccak_table, &[], challenge, &mut input_rlcs, &mut offset); - } - - offset = num_rows_per_round; - for hash_words in squeeze_digests { - offset += num_rows_per_round * NUM_ROUNDS; - let hash_rlc = hash_words - .into_iter() - .flat_map(|a| to_bytes::value(&unpack(a))) - .map(|x| Value::known(F::from(x as u64))) - .reduce(|rlc, x| rlc * challenge + x) - .unwrap(); - let output_rlc = assign_advice_custom(region, keccak_table.output_rlc, offset, hash_rlc); - output_rlcs.push(output_rlc); - offset += num_rows_per_round; - } - - (input_rlcs, output_rlcs) -} - -/// Returns vector of KeccakRow and vector of hash digest outputs. -pub fn multi_keccak_phase0( - bytes: &[Vec], - capacity: Option, -) -> (Vec>, Vec<[F; NUM_WORDS_TO_SQUEEZE]>) { - let num_rows_per_round = get_num_rows_per_round(); - let mut rows = - Vec::with_capacity((1 + capacity.unwrap_or(0) * (NUM_ROUNDS + 1)) * num_rows_per_round); - // Dummy first row so that the initial data is absorbed - // The initial data doesn't really matter, `is_final` just needs to be disabled. - rows.append(&mut KeccakRow::dummy_rows(num_rows_per_round)); - // Actual keccaks - let artifacts = bytes - .par_iter() - .map(|bytes| { - let num_keccak_f = get_num_keccak_f(bytes.len()); - let mut squeeze_digests = Vec::with_capacity(num_keccak_f); - let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); - keccak_phase0(&mut rows, &mut squeeze_digests, bytes); - (rows, squeeze_digests) - }) - .collect::>(); - - let mut squeeze_digests = Vec::with_capacity(capacity.unwrap_or(0)); - for (rows_part, squeezes) in artifacts { - rows.extend(rows_part); - squeeze_digests.extend(squeezes); - } - - if let Some(capacity) = capacity { - // Pad with no data hashes to the expected capacity - while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * get_num_rows_per_round() { - keccak_phase0(&mut rows, &mut squeeze_digests, &[]); - } - // Check that we are not over capacity - if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * get_num_rows_per_round() { - panic!("{:?}", Error::BoundsFailure); - } - } - (rows, squeeze_digests) -} diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs deleted file mode 100644 index d009d044..00000000 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ /dev/null @@ -1,169 +0,0 @@ -use super::*; -use crate::halo2_proofs::{ - circuit::SimpleFloorPlanner, - dev::MockProver, - halo2curves::bn256::Fr, - halo2curves::bn256::{Bn256, G1Affine}, - plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, - plonk::{Circuit, FirstPhase}, - poly::{ - commitment::ParamsProver, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG, ParamsVerifierKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; -use halo2_base::SKIP_FIRST_PASS; -use rand_core::OsRng; - -/// KeccakCircuit -#[derive(Default, Clone, Debug)] -pub struct KeccakCircuit { - inputs: Vec>, - num_rows: Option, - _marker: PhantomData, -} - -#[cfg(any(feature = "test", test))] -impl Circuit for KeccakCircuit { - type Config = KeccakCircuitConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase - meta.advice_column(); - - let challenge = meta.challenge_usable_after(FirstPhase); - KeccakCircuitConfig::new(meta, challenge) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_aux_tables(&mut layouter)?; - let mut challenge = layouter.get_challenge(config.challenge); - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "keccak circuit", - |mut region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let (witness, squeeze_digests) = multi_keccak_phase0(&self.inputs, self.capacity()); - config.assign(&mut region, &witness); - - #[cfg(feature = "halo2-axiom")] - { - region.next_phase(); - challenge = region.get_challenge(config.challenge); - } - multi_keccak_phase1( - &mut region, - &config.keccak_table, - self.inputs.iter().map(|v| v.as_slice()), - challenge, - squeeze_digests, - ); - println!("finished keccak circuit"); - Ok(()) - }, - )?; - - Ok(()) - } -} - -impl KeccakCircuit { - /// Creates a new circuit instance - pub fn new(num_rows: Option, inputs: Vec>) -> Self { - KeccakCircuit { inputs, num_rows, _marker: PhantomData } - } - - /// The number of keccak_f's that can be done in this circuit - pub fn capacity(&self) -> Option { - // Subtract two for unusable rows - self.num_rows.map(|num_rows| num_rows / ((NUM_ROUNDS + 1) * get_num_rows_per_round()) - 2) - } -} - -fn verify(k: u32, inputs: Vec>, _success: bool) { - let circuit = KeccakCircuit::new(Some(2usize.pow(k)), inputs); - - let prover = MockProver::::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); -} - -/// Cmdline: KECCAK_ROWS=28 KECCAK_DEGREE=14 RUST_LOG=info cargo test -- --nocapture packed_multi_keccak_simple -#[test] -fn packed_multi_keccak_simple() { - let _ = env_logger::builder().is_test(true).try_init(); - - let k = 14; - let inputs = vec![ - vec![], - (0u8..1).collect::>(), - (0u8..135).collect::>(), - (0u8..136).collect::>(), - (0u8..200).collect::>(), - ]; - verify::(k, inputs, true); -} - -/// Cmdline: KECCAK_DEGREE=14 RUST_LOG=info cargo test -- --nocapture packed_multi_keccak_prover -#[test] -fn packed_multi_keccak_prover() { - let _ = env_logger::builder().is_test(true).try_init(); - - let k: u32 = var("KECCAK_DEGREE").unwrap_or_else(|_| "14".to_string()).parse().unwrap(); - let params = ParamsKZG::::setup(k, OsRng); - - let inputs = vec![ - vec![], - (0u8..1).collect::>(), - (0u8..135).collect::>(), - (0u8..136).collect::>(), - (0u8..200).collect::>(), - ]; - let circuit = KeccakCircuit::new(Some(2usize.pow(k)), inputs); - - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); - - let verifier_params: ParamsVerifierKZG = params.verifier_params().clone(); - let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); - - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("proof generation should not fail"); - let proof = transcript.finalize(); - - let mut verifier_transcript = Blake2bRead::<_, G1Affine, Challenge255<_>>::init(&proof[..]); - let strategy = SingleStrategy::new(¶ms); - - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(&verifier_params, pk.get_vk(), strategy, &[&[]], &mut verifier_transcript) - .expect("failed to verify bench circuit"); -} diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm-keccak/src/util.rs deleted file mode 100644 index b3e2e2b5..00000000 --- a/hashes/zkevm-keccak/src/util.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Utility traits, functions used in the crate. - -use crate::halo2_proofs::{ - circuit::{Layouter, Value}, - plonk::{Error, TableColumn}, -}; -use itertools::Itertools; -use std::env::var; - -pub mod constraint_builder; -pub mod eth_types; -pub mod expression; - -use eth_types::{Field, ToScalar, Word}; - -pub const NUM_BITS_PER_BYTE: usize = 8; -pub const NUM_BYTES_PER_WORD: usize = 8; -pub const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; -pub const KECCAK_WIDTH: usize = 5 * 5; -pub const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; -pub const NUM_ROUNDS: usize = 24; -pub const NUM_WORDS_TO_ABSORB: usize = 17; -pub const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub const NUM_WORDS_TO_SQUEEZE: usize = 4; -pub const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; -pub const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; -pub const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; -pub const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; -// pub(crate) const THETA_C_WIDTH: usize = 5 * NUM_BITS_PER_WORD; -pub(crate) const RHO_MATRIX: [[usize; 5]; 5] = [ - [0, 36, 3, 41, 18], - [1, 44, 10, 45, 2], - [62, 6, 43, 15, 61], - [28, 55, 25, 21, 56], - [27, 20, 39, 8, 14], -]; -pub(crate) const ROUND_CST: [u64; NUM_ROUNDS + 1] = [ - 0x0000000000000001, - 0x0000000000008082, - 0x800000000000808a, - 0x8000000080008000, - 0x000000000000808b, - 0x0000000080000001, - 0x8000000080008081, - 0x8000000000008009, - 0x000000000000008a, - 0x0000000000000088, - 0x0000000080008009, - 0x000000008000000a, - 0x000000008000808b, - 0x800000000000008b, - 0x8000000000008089, - 0x8000000000008003, - 0x8000000000008002, - 0x8000000000000080, - 0x000000000000800a, - 0x800000008000000a, - 0x8000000080008081, - 0x8000000000008080, - 0x0000000080000001, - 0x8000000080008008, - 0x0000000000000000, // absorb round -]; -// Bit positions that have a non-zero value in `IOTA_ROUND_CST`. -// pub(crate) const ROUND_CST_BIT_POS: [usize; 7] = [0, 1, 3, 7, 15, 31, 63]; - -// The number of bits used in the sparse word representation per bit -pub const BIT_COUNT: usize = 3; -// The base of the bit in the sparse word representation -pub const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); - -// `a ^ ((~b) & c)` is calculated by doing `lookup[3 - 2*a + b - c]` -pub(crate) const CHI_BASE_LOOKUP_TABLE: [u8; 5] = [0, 1, 1, 0, 0]; -// `a ^ ((~b) & c) ^ d` is calculated by doing `lookup[5 - 2*a - b + c - 2*d]` -// pub(crate) const CHI_EXT_LOOKUP_TABLE: [u8; 7] = [0, 0, 1, 1, 0, 0, 1]; - -/// Description of which bits (positions) a part contains -#[derive(Clone, Debug)] -pub struct PartInfo { - /// The bit positions of the part - pub bits: Vec, -} - -/// Description of how a word is split into parts -#[derive(Clone, Debug)] -pub struct WordParts { - /// The parts of the word - pub parts: Vec, -} - -/// Packs bits into bytes -pub mod to_bytes { - pub(crate) fn value(bits: &[u8]) -> Vec { - debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); - let mut bytes = Vec::new(); - for byte_bits in bits.chunks(8) { - let mut value = 0u8; - for (idx, bit) in byte_bits.iter().enumerate() { - value += *bit << idx; - } - bytes.push(value); - } - bytes - } -} - -/// Rotates a word that was split into parts to the right -pub fn rotate(parts: Vec, count: usize, part_size: usize) -> Vec { - let mut rotated_parts = parts; - rotated_parts.rotate_right(get_rotate_count(count, part_size)); - rotated_parts -} - -/// Rotates a word that was split into parts to the left -pub fn rotate_rev(parts: Vec, count: usize, part_size: usize) -> Vec { - let mut rotated_parts = parts; - rotated_parts.rotate_left(get_rotate_count(count, part_size)); - rotated_parts -} - -/// Rotates bits left -pub fn rotate_left(bits: &[u8], count: usize) -> [u8; NUM_BITS_PER_WORD] { - let mut rotated = bits.to_vec(); - rotated.rotate_left(count); - rotated.try_into().unwrap() -} - -/// Scatters a value into a packed word constant -pub mod scatter { - use super::{eth_types::Field, pack}; - use crate::halo2_proofs::plonk::Expression; - - pub(crate) fn expr(value: u8, count: usize) -> Expression { - Expression::Constant(pack(&vec![value; count])) - } -} - -/// The words that absorb data -pub fn get_absorb_positions() -> Vec<(usize, usize)> { - let mut absorb_positions = Vec::new(); - for j in 0..5 { - for i in 0..5 { - if i + j * 5 < 17 { - absorb_positions.push((i, j)); - } - } - } - absorb_positions -} - -/// Converts bytes into bits -pub fn into_bits(bytes: &[u8]) -> Vec { - let mut bits: Vec = vec![0; bytes.len() * 8]; - for (byte_idx, byte) in bytes.iter().enumerate() { - for idx in 0u64..8 { - bits[byte_idx * 8 + (idx as usize)] = (*byte >> idx) & 1; - } - } - bits -} - -/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word -pub fn pack(bits: &[u8]) -> F { - pack_with_base(bits, BIT_SIZE) -} - -/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word with the -/// specified bit base -pub fn pack_with_base(bits: &[u8], base: usize) -> F { - let base = F::from(base as u64); - bits.iter().rev().fold(F::zero(), |acc, &bit| acc * base + F::from(bit as u64)) -} - -/// Decodes the bits using the position data found in the part info -pub fn pack_part(bits: &[u8], info: &PartInfo) -> u64 { - info.bits - .iter() - .rev() - .fold(0u64, |acc, &bit_pos| acc * (BIT_SIZE as u64) + (bits[bit_pos] as u64)) -} - -/// Unpack a sparse keccak word into bits in the range [0,BIT_SIZE[ -pub fn unpack(packed: F) -> [u8; NUM_BITS_PER_WORD] { - let mut bits = [0; NUM_BITS_PER_WORD]; - let packed = Word::from_little_endian(packed.to_bytes_le().as_ref()); - let mask = Word::from(BIT_SIZE - 1); - for (idx, bit) in bits.iter_mut().enumerate() { - *bit = ((packed >> (idx * BIT_COUNT)) & mask).as_u32() as u8; - } - debug_assert_eq!(pack::(&bits), packed.to_scalar().unwrap()); - bits -} - -/// Pack bits stored in a u64 value into a sparse keccak word -pub fn pack_u64(value: u64) -> F { - pack(&((0..NUM_BITS_PER_WORD).map(|i| ((value >> i) & 1) as u8).collect::>())) -} - -/// Calculates a ^ b with a and b field elements -pub fn field_xor(a: F, b: F) -> F { - let mut bytes = [0u8; 32]; - for (idx, (a, b)) in a.to_bytes_le().into_iter().zip(b.to_bytes_le()).enumerate() { - bytes[idx] = a ^ b; - } - F::from_bytes_le(&bytes) -} - -/// Returns the size (in bits) of each part size when splitting up a keccak word -/// in parts of `part_size` -pub fn target_part_sizes(part_size: usize) -> Vec { - let num_full_chunks = NUM_BITS_PER_WORD / part_size; - let partial_chunk_size = NUM_BITS_PER_WORD % part_size; - let mut part_sizes = vec![part_size; num_full_chunks]; - if partial_chunk_size > 0 { - part_sizes.push(partial_chunk_size); - } - part_sizes -} - -/// Gets the rotation count in parts -pub fn get_rotate_count(count: usize, part_size: usize) -> usize { - (count + part_size - 1) / part_size -} - -impl WordParts { - /// Returns a description of how a word will be split into parts - pub fn new(part_size: usize, rot: usize, normalize: bool) -> Self { - let mut bits = (0usize..64).collect::>(); - bits.rotate_right(rot); - - let mut parts = Vec::new(); - let mut rot_idx = 0; - - let mut idx = 0; - let target_sizes = if normalize { - // After the rotation we want the parts of all the words to be at the same - // positions - target_part_sizes(part_size) - } else { - // Here we only care about minimizing the number of parts - let num_parts_a = rot / part_size; - let partial_part_a = rot % part_size; - - let num_parts_b = (64 - rot) / part_size; - let partial_part_b = (64 - rot) % part_size; - - let mut part_sizes = vec![part_size; num_parts_a]; - if partial_part_a > 0 { - part_sizes.push(partial_part_a); - } - - part_sizes.extend(vec![part_size; num_parts_b]); - if partial_part_b > 0 { - part_sizes.push(partial_part_b); - } - - part_sizes - }; - // Split into parts bit by bit - for part_size in target_sizes { - let mut num_consumed = 0; - while num_consumed < part_size { - let mut part_bits: Vec = Vec::new(); - while num_consumed < part_size { - if !part_bits.is_empty() && bits[idx] == 0 { - break; - } - if bits[idx] == 0 { - rot_idx = parts.len(); - } - part_bits.push(bits[idx]); - idx += 1; - num_consumed += 1; - } - parts.push(PartInfo { bits: part_bits }); - } - } - - debug_assert_eq!(get_rotate_count(rot, part_size), rot_idx); - - parts.rotate_left(rot_idx); - debug_assert_eq!(parts[0].bits[0], 0); - - Self { parts } - } -} - -/// Get the degree of the circuit from the KECCAK_DEGREE env variable -pub fn get_degree() -> usize { - var("KECCAK_DEGREE") - .expect("Need to set KECCAK_DEGREE to log_2(rows) of circuit") - .parse() - .expect("Cannot parse KECCAK_DEGREE env var as usize") -} - -/// Returns how many bits we can process in a single lookup given the range of -/// values the bit can have and the height of the circuit. -pub fn get_num_bits_per_lookup(range: usize) -> usize { - let num_unusable_rows = 31; - let degree = get_degree() as u32; - let mut num_bits = 1; - while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(degree) { - num_bits += 1; - } - num_bits as usize -} - -/// Loads a normalization table with the given parameters -pub(crate) fn load_normalize_table( - layouter: &mut impl Layouter, - name: &str, - tables: &[TableColumn; 2], - range: u64, -) -> Result<(), Error> { - let part_size = get_num_bits_per_lookup(range as usize); - layouter.assign_table( - || format!("{name} table"), - |mut table| { - for (offset, perm) in - (0..part_size).map(|_| 0u64..range).multi_cartesian_product().enumerate() - { - let mut input = 0u64; - let mut output = 0u64; - let mut factor = 1u64; - for input_part in perm.iter() { - input += input_part * factor; - output += (input_part & 1) * factor; - factor *= BIT_SIZE as u64; - } - table.assign_cell( - || format!("{name} input"), - tables[0], - offset, - || Value::known(F::from(input)), - )?; - table.assign_cell( - || format!("{name} output"), - tables[1], - offset, - || Value::known(F::from(output)), - )?; - } - Ok(()) - }, - ) -} - -/// Loads the byte packing table -pub(crate) fn load_pack_table( - layouter: &mut impl Layouter, - tables: &[TableColumn; 2], -) -> Result<(), Error> { - layouter.assign_table( - || "pack table", - |mut table| { - for (offset, idx) in (0u64..256).enumerate() { - table.assign_cell( - || "unpacked", - tables[0], - offset, - || Value::known(F::from(idx)), - )?; - let packed: F = pack(&into_bits(&[idx as u8])); - table.assign_cell(|| "packed", tables[1], offset, || Value::known(packed))?; - } - Ok(()) - }, - ) -} - -/// Loads a lookup table -pub(crate) fn load_lookup_table( - layouter: &mut impl Layouter, - name: &str, - tables: &[TableColumn; 2], - part_size: usize, - lookup_table: &[u8], -) -> Result<(), Error> { - layouter.assign_table( - || format!("{name} table"), - |mut table| { - for (offset, perm) in (0..part_size) - .map(|_| 0..lookup_table.len() as u64) - .multi_cartesian_product() - .enumerate() - { - let mut input = 0u64; - let mut output = 0u64; - let mut factor = 1u64; - for input_part in perm.iter() { - input += input_part * factor; - output += (lookup_table[*input_part as usize] as u64) * factor; - factor *= BIT_SIZE as u64; - } - table.assign_cell( - || format!("{name} input"), - tables[0], - offset, - || Value::known(F::from(input)), - )?; - table.assign_cell( - || format!("{name} output"), - tables[1], - offset, - || Value::known(F::from(output)), - )?; - } - Ok(()) - }, - ) -} diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml new file mode 100644 index 00000000..28703f24 --- /dev/null +++ b/hashes/zkevm/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "zkevm-hashes" +version = "0.1.4" +edition = "2021" +license = "MIT OR Apache-2.0" + +[dependencies] +array-init = "2.0.0" +ethers-core = "2.0.8" +rand = "0.8" +itertools = "0.11" +lazy_static = "1.4" +log = "0.4" +num-bigint = { version = "0.4" } +halo2-base = { path = "../../halo2-base", default-features = false, features = [ + "test-utils", +] } +rayon = "1.7" +sha3 = "0.10.8" +# always included but without features to use Native poseidon +snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "develop", default-features = false } +getset = "0.1.2" + +[dev-dependencies] +criterion = "0.3" +ctor = "0.1.22" +ethers-signers = "2.0.8" +hex = "0.4.3" +itertools = "0.11" +pretty_assertions = "1.0.0" +rand_core = "0.6.4" +rand_xorshift = "0.3" +env_logger = "0.10" +test-case = "3.1.0" + +[features] +default = ["halo2-axiom", "display"] +display = ["halo2-base/display", "snark-verifier/display"] +halo2-pse = ["halo2-base/halo2-pse", "snark-verifier/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom", "snark-verifier/halo2-axiom"] +jemallocator = ["halo2-base/jemallocator"] +mimalloc = ["halo2-base/mimalloc"] +asm = ["halo2-base/asm"] diff --git a/hashes/zkevm/src/keccak/README.md b/hashes/zkevm/src/keccak/README.md new file mode 100644 index 00000000..527d671f --- /dev/null +++ b/hashes/zkevm/src/keccak/README.md @@ -0,0 +1,144 @@ +# ZKEVM Keccak + +## Vanilla + +Keccak circuit in vanilla halo2. This implementation starts from [PSE version](https://github.com/privacy-scaling-explorations/zkevm-circuits/tree/main/zkevm-circuits/src/keccak_circuit), then adopts some changes from [this PR](https://github.com/scroll-tech/zkevm-circuits/pull/216) and later updates in PSE version. + +The major differences is that this version directly represent raw inputs and Keccak results as witnesses, while the original version only has RLCs(random linear combination) of raw inputs and Keccak results. Because this version doesn't need RLCs, it doesn't have the 2nd phase or use challenge APIs. + +### Logical Input/Output + +Logically the circuit takes an array of bytes as inputs and Keccak results of these bytes as outputs. + +`keccak::vanilla::witness::multi_keccak` generates the witnesses of the ciruit for a given input. + +### Background Knowledge + +All these items remain consistent across all versions. + +- Keccak process a logical input `keccak_f` by `keccak_f`. +- Each `keccak_f` has `NUM_ROUNDS`(24) rounds. +- The number of rows of a round(`rows_per_round`) is configurable. Usually less rows means less wasted cells. +- Each `keccak_f` takes `(NUM_ROUNDS + 1) * rows_per_round` rows. The last `rows_per_round` rows could be considered as a virtual round for "squeeze". +- Every input is padded to be a multiple of RATE (136 bytes). If the length of the logical input already matches a multiple of RATE, an additional RATE bytes are added as padding. +- Each `keccak_f` absorbs `RATE` bytes, which are splitted into `NUM_WORDS_TO_ABSORB`(17) words. Each word has `NUM_BYTES_PER_WORD`(8) bytes. +- Each of the first `NUM_WORDS_TO_ABSORB`(17) rounds of each `keccak_f` absorbs a word. +- `is_final`(anothe name is `is_enabled`) is meaningful only at the first row of the "squeeze" round. It must be true if this is the last `keccak_f` of an logical input. +- The first round of the circuit is a dummy round, which doesn't crespond to any input. + +### Raw inputs + +- In this version, we added column `word_value`/`bytes_left` to represent raw inputs. +- `word_value` is meaningful only at the first row of the first `NUM_WORDS_TO_ABSORB`(17) rounds. +- `bytes_left` is meaningful only at the first row of each round. +- `word_value` equals to the bytes from the raw input in this round's word in little-endian. +- `bytes_left` equals to the number of bytes, which haven't been absorbed from the raw input before this round. +- More details could be found in comments. + +### Keccak Results + +- In this version, we added column `hash_lo`/`hash_hi` to represent Keccak results. +- `hash_lo`/`hash_hi` of a logical input could be found at the first row of the virtual round of the last `keccak_f`. +- `hash_lo` is the low 128 bits of Keccak results. `hash_hi` is the high 128 bits of Keccak results. + +### Example + +In this version, we care more about the first row of each round(`offset = x * rows_per_round`). So we only show the first row of each round in the following example. +Let's say `rows_per_round = 10` and `inputs = [[], [0x89, 0x88, .., 0x01]]`. The corresponding table is: + +| row | input idx | round | word_value | bytes_left | is_final | hash_lo | hash_hi | +| ------------- | --------- | ----- | -------------------- | ---------- | -------- | ------- | ------- | +| 0 (dummy) | - | - | - | - | false | - | - | +| 10 | 0 | 1 | `0` | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 170 | 0 | 17 | `0` | 0 | - | - | - | +| 180 | 0 | 18 | - | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 250 (squeeze) | 0 | 25 | - | 0 | true | RESULT | RESULT | +| 260 | 1 | 1 | `0x8283848586878889` | 137 | - | - | - | +| 270 | 1 | 2 | `0x7A7B7C7D7E7F8081` | 129 | - | - | - | +| ... | 1 | ... | ... | ... | - | - | - | +| 420 | 1 | 17 | `0x0203040506070809` | 9 | - | - | - | +| 430 | 1 | 18 | - | 1 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 500 (squeeze) | 1 | 25 | - | 0 | false | - | - | +| 510 | 1 | 1 | `0x01` | 1 | - | - | - | +| 520 | 1 | 2 | - | 0 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 750 (squeeze) | 1 | 25 | - | 0 | true | RESULT | RESULT | + +### Change Details + +- Removed column `input_rlc`/`input_len` and related gates. +- Removed column `output_rlc` and related gates. +- Removed challenges. +- Refactored the folder structure to follow [Scroll's repo](https://github.com/scroll-tech/zkevm-circuits/tree/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/keccak_circuit). `mod.rs` and `witness.rs` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/develop/zkevm-circuits/src/keccak_circuit.rs). `KeccakTable` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/table.rs#L1308). +- Imported utilites from [PSE zkevm-circuits repo](https://github.com/privacy-scaling-explorations/zkevm-circuits/blob/588b8b8c55bf639fc5cbf7eae575da922ea7f1fd/zkevm-circuits/src/util/word.rs). + +## Component + +Keccak component circuits and utilities based on halo2-lib. + +### Motivation + +Move expensive Keccak computation into standalone circuits(**Component Circuits**) and circuits with actual business logic(**App Circuits**) can read Keccak results from component circuits. Then we achieve better scalability - the maximum size of a single circuit could be managed and component/app circuits could be proved in paralle. + +### Output + +Logically a component circuit outputs 3 columns `lookup_key`, `hash_lo`, `hash_hi` with `capacity` rows, where `capacity` is a configurable parameter and it means the maximum number of keccak_f this circuit can perform. + +- `lookup_key` can be cheaply derived from a bytes input. Specs can be found at `keccak::component::encode::encode_native_input`. Also `keccak::component::encode` provides some utilities to encode bytes inputs in halo2-lib. +- `hash_lo`/`hash_hi` are low/high 128 bits of the corresponding Keccak result. + +There 2 ways to publish circuit outputs: + +- Publish all these 3 columns as 3 public instance columns. +- Publish the commitment of all these 3 columns as a single public instance. + +Developers can choose either way according to their needs. Specs of these 2 ways can be found at `keccak::component::circuit::shard::KeccakComponentShardCircuit::publish_outputs`. + +`keccak::component::output` provides utilities to compute component circuit outputs for given inputs. App circuits could use these utilities to load Keccak results before witness generation of component circuits. + +### Lookup Key Encode + +For easier understanding specs at `keccak::component::encode::encode_native_input`, here we provide an example of encoding `[0x89, 0x88, .., 0x01]`(137 bytes): +| keccak_f| round | word | witness | Note | +|---------|-------|------|---------| ---- | +| 0 | 1 | `0x8283848586878889` | - | | +| 0 | 2 | `0x7A7B7C7D7E7F8081` | `0x7A7B7C7D7E7F808182838485868788890000000000000089` | [length, word[0], word[1]] | +| 0 | 3 | `0x7273747576777879` | - | | +| 0 | 4 | `0x6A6B6C6D6E6F7071` | - | | +| 0 | 5 | `0x6263646566676869` | `0x62636465666768696A6B6C6D6E6F70717273747576777879` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 0 | 15 | `0x1213141516171819` | - | | +| 0 | 16 | `0x0A0B0C0D0E0F1011` | - | | +| 0 | 17 | `0x0203040506070809` | `0x02030405060708090A0B0C0D0E0F10111213141516171819` | [word[15], word[16], word[17]] | +| 1 | 1 | `0x0000000000000001` | - | | +| 1 | 2 | `0x0000000000000000` | `0x000000000000000000000000000000010000000000000000` | [0, word[0], word[1]] | +| 1 | 3 | `0x0000000000000000` | - | | +| 1 | 4 | `0x0000000000000000` | - | | +| 1 | 5 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 1 | 15 | `0x0000000000000000` | - | | +| 1 | 16 | `0x0000000000000000` | - | | +| 1 | 17 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[15], word[16], word[17]] | + +The raw input is transformed into `payload = [0x7A7B7C7D7E7F808182838485868788890000000000000089, 0x62636465666768696A6B6C6D6E6F70717273747576777879, ... , 0x02030405060708090A0B0C0D0E0F10111213141516171819, 0x000000000000000000000000000000010000000000000000, 0x000000000000000000000000000000000000000000000000, ... , 0x000000000000000000000000000000000000000000000000]`. 2 keccak_fs, 6 witnesses each keecak_f, 12 witnesses in total. + +Finally the lookup key will be `Poseidon(payload)`. + +### Shard Circuit + +Implementation: `keccak::component::circuit::shard::KeccakComponentShardCircuit` + +- Shard circuits are the circuits that actually perform Keccak computation. +- Logically shard circuits take an array of bytes as inputs. +- Shard circuits follow the component output format above. +- Shard circuits have a configurable parameter `capacity`, which is the maximum number of keccak_f this circuit can perform. +- Shard circuits' outputs have Keccak results of all logical inputs. Outputs are padded into `capacity` rows with Keccak results of "". Paddings might be inserted between Keccak results of logical inputs. + +### Aggregation Circuit + +Aggregation circuits aggregate Keccak results of shard circuits and smaller aggregation circuits. Aggregation circuits can bring better scalability. + +Implementation is TODO. diff --git a/hashes/zkevm/src/keccak/component/circuit/mod.rs b/hashes/zkevm/src/keccak/component/circuit/mod.rs new file mode 100644 index 00000000..27f33642 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/mod.rs @@ -0,0 +1,3 @@ +pub mod shard; +#[cfg(test)] +mod tests; diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs new file mode 100644 index 00000000..f818f4d6 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -0,0 +1,479 @@ +use std::cell::RefCell; + +use crate::{ + keccak::{ + component::{ + encode::{ + get_words_to_witness_multipliers, num_poseidon_absorb_per_keccak_f, + num_word_per_witness, + }, + output::{dummy_circuit_output, KeccakCircuitOutput}, + param::*, + }, + vanilla::{ + keccak_packed_multi::get_num_keccak_f, param::*, witness::multi_keccak, + KeccakAssignedRow, KeccakCircuitConfig, KeccakConfigParams, + }, + }, + util::eth_types::Field, +}; +use getset::{CopyGetters, Getters}; +use halo2_base::{ + gates::{ + circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, + flex_gate::MultiPhaseThreadBreakPoints, + GateChip, GateInstructions, + }, + halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, ConstraintSystem, Error}, + }, + poseidon::hasher::{ + spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactOutput, + PoseidonHasher, + }, + safe_types::{SafeBool, SafeTypeChip}, + AssignedValue, Context, + QuantumCell::Constant, +}; +use itertools::Itertools; + +/// Keccak Component Shard Circuit +#[derive(Getters)] +pub struct KeccakComponentShardCircuit { + inputs: Vec>, + + /// Parameters of this circuit. The same parameters always construct the same circuit. + #[getset(get = "pub")] + params: KeccakComponentShardCircuitParams, + + base_circuit_builder: RefCell>, + hasher: RefCell>, + gate_chip: GateChip, +} + +/// Parameters of KeccakComponentCircuit. +#[derive(Default, Clone, CopyGetters)] +pub struct KeccakComponentShardCircuitParams { + /// This circuit has 2^k rows. + #[getset(get_copy = "pub")] + k: usize, + // Number of unusable rows withhold by Halo2. + #[getset(get_copy = "pub")] + num_unusable_row: usize, + /// Max keccak_f this circuits can aceept. The circuit can at most process of inputs + /// with < NUM_BYTES_TO_ABSORB bytes or an input with * NUM_BYTES_TO_ABSORB - 1 bytes. + #[getset(get_copy = "pub")] + capacity: usize, + // If true, publish raw outputs. Otherwise, publish Poseidon commitment of raw outputs. + #[getset(get_copy = "pub")] + publish_raw_outputs: bool, + + // Derived parameters of sub-circuits. + pub keccak_circuit_params: KeccakConfigParams, + pub base_circuit_params: BaseCircuitParams, +} + +impl KeccakComponentShardCircuitParams { + /// Create a new KeccakComponentShardCircuitParams. + pub fn new( + k: usize, + num_unusable_row: usize, + capacity: usize, + publish_raw_outputs: bool, + ) -> Self { + assert!(1 << k > num_unusable_row, "Number of unusable rows must be less than 2^k"); + let max_rows = (1 << k) - num_unusable_row; + // Derived from [crate::keccak::native_circuit::keccak_packed_multi::get_keccak_capacity]. + let rows_per_round = max_rows / (capacity * (NUM_ROUNDS + 1) + 1 + NUM_WORDS_TO_ABSORB); + assert!(rows_per_round > 0, "No enough rows for the speficied capacity"); + let keccak_circuit_params = KeccakConfigParams { k: k as u32, rows_per_round }; + let base_circuit_params = BaseCircuitParams { + k, + lookup_bits: None, + num_instance_columns: if publish_raw_outputs { + OUTPUT_NUM_COL_RAW + } else { + OUTPUT_NUM_COL_COMMIT + }, + ..Default::default() + }; + Self { + k, + num_unusable_row, + capacity, + publish_raw_outputs, + keccak_circuit_params, + base_circuit_params, + } + } +} + +/// Circuit::Config for Keccak Component Shard Circuit. +#[derive(Clone)] +pub struct KeccakComponentShardConfig { + pub base_circuit_config: BaseConfig, + pub keccak_circuit_config: KeccakCircuitConfig, +} + +impl Circuit for KeccakComponentShardCircuit { + type Config = KeccakComponentShardConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = KeccakComponentShardCircuitParams; + + fn params(&self) -> Self::Params { + self.params.clone() + } + + /// Creates a new instance of the [KeccakCoprocessorLeafCircuit] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using [`BaseConfigParams`] + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let keccak_circuit_config = KeccakCircuitConfig::new(meta, params.keccak_circuit_params); + let base_circuit_params = params.base_circuit_params; + // BaseCircuitBuilder::configure_with_params must be called in the end in order to get the correct + // unusable_rows. + let base_circuit_config = + BaseCircuitBuilder::configure_with_params(meta, base_circuit_params.clone()); + Self::Config { base_circuit_config, keccak_circuit_config } + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params"); + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let k = self.params.k; + config.keccak_circuit_config.load_aux_tables(&mut layouter, k as u32)?; + let mut keccak_assigned_rows: Vec> = Vec::default(); + layouter.assign_region( + || "keccak circuit", + |mut region| { + let (keccak_rows, _) = multi_keccak::( + &self.inputs, + Some(self.params.capacity), + self.params.keccak_circuit_params, + ); + keccak_assigned_rows = + config.keccak_circuit_config.assign(&mut region, &keccak_rows); + Ok(()) + }, + )?; + + // Base circuit witness generation. + let loaded_keccak_fs = self.load_keccak_assigned_rows(keccak_assigned_rows); + self.generate_base_circuit_witnesses(&loaded_keccak_fs); + + self.base_circuit_builder.borrow().synthesize(config.base_circuit_config, layouter)?; + + // Reset the circuit to the initial state so synthesize could be called multiple times. + self.base_circuit_builder.borrow_mut().clear(); + self.hasher.borrow_mut().clear(); + Ok(()) + } +} + +/// Witnesses of a keccak_f which are necessary to be loaded into halo2-lib. +#[derive(Clone, Copy, Debug, CopyGetters, Getters)] +pub struct LoadedKeccakF { + /// bytes_left of the first row of the first round of this keccak_f. This could be used to determine the length of the input. + #[getset(get_copy = "pub")] + pub(crate) bytes_left: AssignedValue, + /// Input words (u64) of this keccak_f. + #[getset(get = "pub")] + pub(crate) word_values: [AssignedValue; NUM_WORDS_TO_ABSORB], + /// The output of this keccak_f. is_final/hash_lo/hash_hi come from the first row of the last round(NUM_ROUNDS). + #[getset(get_copy = "pub")] + pub(crate) is_final: SafeBool, + /// The lower 16 bits (in big-endian, 16..) of the output of this keccak_f. + #[getset(get_copy = "pub")] + pub(crate) hash_lo: AssignedValue, + /// The high 16 bits (in big-endian, ..16) of the output of this keccak_f. + #[getset(get_copy = "pub")] + pub(crate) hash_hi: AssignedValue, +} + +impl LoadedKeccakF { + pub fn new( + bytes_left: AssignedValue, + word_values: [AssignedValue; NUM_WORDS_TO_ABSORB], + is_final: SafeBool, + hash_lo: AssignedValue, + hash_hi: AssignedValue, + ) -> Self { + Self { bytes_left, word_values, is_final, hash_lo, hash_hi } + } +} + +impl KeccakComponentShardCircuit { + /// Create a new KeccakComponentShardCircuit. + pub fn new( + inputs: Vec>, + params: KeccakComponentShardCircuitParams, + witness_gen_only: bool, + ) -> Self { + let input_size = inputs.iter().map(|input| get_num_keccak_f(input.len())).sum::(); + assert!(input_size < params.capacity, "Input size exceeds capacity"); + let mut base_circuit_builder = BaseCircuitBuilder::new(witness_gen_only); + base_circuit_builder.set_params(params.base_circuit_params.clone()); + Self { + inputs, + params, + base_circuit_builder: RefCell::new(base_circuit_builder), + hasher: RefCell::new(create_hasher()), + gate_chip: GateChip::new(), + } + } + + /// Get break points of BaseCircuitBuilder. + pub fn base_circuit_break_points(&self) -> MultiPhaseThreadBreakPoints { + self.base_circuit_builder.borrow().break_points() + } + + /// Set break points of BaseCircuitBuilder. + pub fn set_base_circuit_break_points(&self, break_points: MultiPhaseThreadBreakPoints) { + self.base_circuit_builder.borrow_mut().set_break_points(break_points); + } + + pub fn update_base_circuit_params(&mut self, params: &BaseCircuitParams) { + self.params.base_circuit_params = params.clone(); + self.base_circuit_builder.borrow_mut().set_params(params.clone()); + } + + /// Simulate witness generation of the base circuit to determine BaseCircuitParams because the number of columns + /// of the base circuit can only be known after witness generation. + pub fn calculate_base_circuit_params( + params: &KeccakComponentShardCircuitParams, + ) -> BaseCircuitParams { + // Create a simulation circuit to calculate base circuit parameters. + let simulation_circuit = Self::new(vec![], params.clone(), false); + let loaded_keccak_fs = simulation_circuit.mock_load_keccak_assigned_rows(); + simulation_circuit.generate_base_circuit_witnesses(&loaded_keccak_fs); + + let base_circuit_params = simulation_circuit + .base_circuit_builder + .borrow_mut() + .calculate_params(Some(params.num_unusable_row)); + // prevent drop warnings + simulation_circuit.base_circuit_builder.borrow_mut().clear(); + + base_circuit_params + } + + /// Mock loading Keccak assigned rows from Keccak circuit. This function doesn't create any witnesses/constraints. + fn mock_load_keccak_assigned_rows(&self) -> Vec> { + let base_circuit_builder = self.base_circuit_builder.borrow(); + let mut copy_manager = base_circuit_builder.core().copy_manager.lock().unwrap(); + (0..self.params.capacity) + .map(|_| LoadedKeccakF { + bytes_left: copy_manager.mock_external_assigned(F::ZERO), + word_values: core::array::from_fn(|_| copy_manager.mock_external_assigned(F::ZERO)), + is_final: SafeTypeChip::unsafe_to_bool( + copy_manager.mock_external_assigned(F::ZERO), + ), + hash_lo: copy_manager.mock_external_assigned(F::ZERO), + hash_hi: copy_manager.mock_external_assigned(F::ZERO), + }) + .collect_vec() + } + + /// Load needed witnesses into halo2-lib from keccak assigned rows. This function doesn't create any witnesses/constraints. + fn load_keccak_assigned_rows( + &self, + assigned_rows: Vec>, + ) -> Vec> { + let rows_per_round = self.params.keccak_circuit_params.rows_per_round; + let base_circuit_builder = self.base_circuit_builder.borrow(); + let mut copy_manager = base_circuit_builder.core().copy_manager.lock().unwrap(); + assigned_rows + .into_iter() + .step_by(rows_per_round) + // Skip the first round which is dummy. + .skip(1) + .chunks(NUM_ROUNDS + 1) + .into_iter() + .map(|rounds| { + let mut rounds = rounds.collect_vec(); + assert_eq!(rounds.len(), NUM_ROUNDS + 1); + let bytes_left = copy_manager.load_external_assigned(rounds[0].bytes_left.clone()); + let output_row = rounds.pop().unwrap(); + let word_values = core::array::from_fn(|i| { + let assigned_row = &rounds[i]; + copy_manager.load_external_assigned(assigned_row.word_value.clone()) + }); + let is_final = SafeTypeChip::unsafe_to_bool( + copy_manager.load_external_assigned(output_row.is_final), + ); + let hash_lo = copy_manager.load_external_assigned(output_row.hash_lo); + let hash_hi = copy_manager.load_external_assigned(output_row.hash_hi); + LoadedKeccakF { bytes_left, word_values, is_final, hash_lo, hash_hi } + }) + .collect() + } + + /// Generate witnesses of the base circuit. + fn generate_base_circuit_witnesses(&self, loaded_keccak_fs: &[LoadedKeccakF]) { + let gate = &self.gate_chip; + let circuit_final_outputs = { + let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); + let ctx = base_circuit_builder_mut.main(0); + let mut hasher = self.hasher.borrow_mut(); + hasher.initialize_consts(ctx, gate); + + let lookup_key_per_keccak_f = + encode_inputs_from_keccak_fs(ctx, gate, &hasher, loaded_keccak_fs); + Self::generate_circuit_final_outputs( + ctx, + gate, + &lookup_key_per_keccak_f, + loaded_keccak_fs, + ) + }; + self.publish_outputs(&circuit_final_outputs); + } + + /// Combine lookup keys and Keccak results to generate final outputs of the circuit. + pub fn generate_circuit_final_outputs( + ctx: &mut Context, + gate: &impl GateInstructions, + lookup_key_per_keccak_f: &[PoseidonCompactOutput], + loaded_keccak_fs: &[LoadedKeccakF], + ) -> Vec>> { + let KeccakCircuitOutput { + key: dummy_key_val, + hash_lo: dummy_keccak_val_lo, + hash_hi: dummy_keccak_val_hi, + } = dummy_circuit_output::(); + + // Dummy row for keccak_fs with is_final = false. The corresponding logical input is empty. + let dummy_key_witness = ctx.load_constant(dummy_key_val); + let dummy_keccak_lo_witness = ctx.load_constant(dummy_keccak_val_lo); + let dummy_keccak_hi_witness = ctx.load_constant(dummy_keccak_val_hi); + + let mut circuit_final_outputs = Vec::with_capacity(loaded_keccak_fs.len()); + for (compact_output, loaded_keccak_f) in + lookup_key_per_keccak_f.iter().zip_eq(loaded_keccak_fs) + { + let is_final = AssignedValue::from(loaded_keccak_f.is_final); + let key = gate.select(ctx, *compact_output.hash(), dummy_key_witness, is_final); + let hash_lo = + gate.select(ctx, loaded_keccak_f.hash_lo, dummy_keccak_lo_witness, is_final); + let hash_hi = + gate.select(ctx, loaded_keccak_f.hash_hi, dummy_keccak_hi_witness, is_final); + circuit_final_outputs.push(KeccakCircuitOutput { key, hash_lo, hash_hi }); + } + circuit_final_outputs + } + + /// Publish outputs of the circuit as public instances. + fn publish_outputs(&self, outputs: &[KeccakCircuitOutput>]) { + // The length of outputs should always equal to params.capacity. + assert_eq!(outputs.len(), self.params.capacity); + if !self.params.publish_raw_outputs { + let gate = &self.gate_chip; + let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); + let ctx = base_circuit_builder_mut.main(0); + + // TODO: wrap this into a function which should be shared wiht App circuits. + let output_commitment = self.hasher.borrow().hash_fix_len_array( + ctx, + gate, + &outputs + .iter() + .flat_map(|output| [output.key, output.hash_lo, output.hash_hi]) + .collect_vec(), + ); + + let assigned_instances = &mut base_circuit_builder_mut.assigned_instances; + // The commitment should be in the first row. + assert!(assigned_instances[OUTPUT_COL_IDX_COMMIT].is_empty()); + assigned_instances[OUTPUT_COL_IDX_COMMIT].push(output_commitment); + } else { + let assigned_instances = &mut self.base_circuit_builder.borrow_mut().assigned_instances; + + // Outputs should be in the top of instance columns. + assert!(assigned_instances[OUTPUT_COL_IDX_KEY].is_empty()); + assert!(assigned_instances[OUTPUT_COL_IDX_HASH_LO].is_empty()); + assert!(assigned_instances[OUTPUT_COL_IDX_HASH_HI].is_empty()); + for output in outputs { + assigned_instances[OUTPUT_COL_IDX_KEY].push(output.key); + assigned_instances[OUTPUT_COL_IDX_HASH_LO].push(output.hash_lo); + assigned_instances[OUTPUT_COL_IDX_HASH_HI].push(output.hash_hi); + } + } + } +} + +pub(crate) fn create_hasher() -> PoseidonHasher { + // Construct in-circuit Poseidon hasher. + let spec = OptimizedPoseidonSpec::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(); + PoseidonHasher::::new(spec) +} + +/// Encode raw inputs from Keccak circuit witnesses into lookup keys. +/// +/// Each element in the return value corrresponds to a Keccak chunk. If is_final = true, this element is the lookup key of the corresponding logical input. +pub fn encode_inputs_from_keccak_fs( + ctx: &mut Context, + gate: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + loaded_keccak_fs: &[LoadedKeccakF], +) -> Vec> { + // Circuit parameters + let num_poseidon_absorb_per_keccak_f = num_poseidon_absorb_per_keccak_f::(); + let num_word_per_witness = num_word_per_witness::(); + let num_witness_per_keccak_f = POSEIDON_RATE * num_poseidon_absorb_per_keccak_f; + + // Constant witnesses + let one_const = ctx.load_constant(F::ONE); + let zero_const = ctx.load_zero(); + let multipliers_val = get_words_to_witness_multipliers::() + .into_iter() + .map(|multiplier| Constant(multiplier)) + .collect_vec(); + + let mut compact_chunk_inputs = Vec::with_capacity(loaded_keccak_fs.len()); + let mut last_is_final = one_const; + for loaded_keccak_f in loaded_keccak_fs { + // If this keccak_f is the last of a logical input. + let is_final = loaded_keccak_f.is_final; + let mut poseidon_absorb_data = Vec::with_capacity(num_witness_per_keccak_f); + + // First witness of a keccak_f: [, word_values[0], word_values[1], ...] + // is the length of the input if this is the first keccak_f of a logical input. Otherwise 0. + let mut words = Vec::with_capacity(num_word_per_witness); + let input_bytes_len = gate.mul(ctx, loaded_keccak_f.bytes_left, last_is_final); + words.push(input_bytes_len); + words.extend_from_slice(&loaded_keccak_f.word_values); + + // Turn every num_word_per_witness words later into a witness. + for words in words.chunks(num_word_per_witness) { + let mut words = words.to_vec(); + words.resize(num_word_per_witness, zero_const); + let witness = gate.inner_product(ctx, words, multipliers_val.clone()); + poseidon_absorb_data.push(witness); + } + // Pad 0s to make sure poseidon_absorb_data.len() % RATE == 0. + poseidon_absorb_data.resize(num_witness_per_keccak_f, zero_const); + let compact_inputs: Vec<_> = poseidon_absorb_data + .chunks_exact(POSEIDON_RATE) + .map(|chunk| chunk.to_vec().try_into().unwrap()) + .collect_vec(); + debug_assert_eq!(compact_inputs.len(), num_poseidon_absorb_per_keccak_f); + compact_chunk_inputs.push(PoseidonCompactChunkInput::new(compact_inputs, is_final)); + last_is_final = is_final.into(); + } + + initialized_hasher.hash_compact_chunk_inputs(ctx, gate, &compact_chunk_inputs) +} diff --git a/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs new file mode 100644 index 00000000..c77c1a0c --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs @@ -0,0 +1 @@ +pub mod shard; diff --git a/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs new file mode 100644 index 00000000..17726327 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs @@ -0,0 +1,193 @@ +use crate::{ + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::Bn256, + halo2curves::bn256::Fr, + plonk::{keygen_pk, keygen_vk}, + }, + keccak::component::{ + circuit::shard::{KeccakComponentShardCircuit, KeccakComponentShardCircuitParams}, + output::{calculate_circuit_outputs_commit, multi_inputs_to_circuit_outputs}, + }, +}; + +use halo2_base::{ + halo2_proofs::poly::kzg::commitment::ParamsKZG, + utils::testing::{check_proof_with_instances, gen_proof_with_instances}, +}; +use itertools::Itertools; +use rand_core::OsRng; + +#[test] +fn test_mock_shard_circuit_raw_outputs() { + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = true; + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let mut params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); + params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); + + let instances = vec![ + circuit_outputs.iter().map(|o| o.key).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_lo).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_hi).collect_vec(), + ]; + + let prover = MockProver::::run(k as u32, &circuit, instances).unwrap(); + prover.assert_satisfied(); +} + +#[test] +fn test_prove_shard_circuit_raw_outputs() { + let _ = env_logger::builder().is_test(true).try_init(); + + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = true; + + let inputs = vec![]; + let mut circuit_params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); + circuit_params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); + + let params = ParamsKZG::::setup(k as u32, OsRng); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); + let instances: Vec> = vec![ + circuit_outputs.iter().map(|o| o.key).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_lo).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_hi).collect_vec(), + ]; + + let break_points = circuit.base_circuit_break_points(); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params, true); + circuit.set_base_circuit_break_points(break_points); + + let proof = gen_proof_with_instances( + ¶ms, + &pk, + circuit, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + ); + check_proof_with_instances( + ¶ms, + pk.get_vk(), + &proof, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + true, + ); +} + +#[test] +fn test_mock_shard_circuit_commit() { + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = false; + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let mut params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); + params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); + + let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; + + let prover = MockProver::::run(k as u32, &circuit, instances).unwrap(); + prover.assert_satisfied(); +} + +#[test] +fn test_prove_shard_circuit_commit() { + let _ = env_logger::builder().is_test(true).try_init(); + + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = false; + + let inputs = vec![]; + let mut circuit_params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); + circuit_params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); + + let params = ParamsKZG::::setup(k as u32, OsRng); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let break_points = circuit.base_circuit_break_points(); + let circuit = + KeccakComponentShardCircuit::::new(inputs.clone(), circuit_params.clone(), true); + circuit.set_base_circuit_break_points(break_points); + + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); + let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; + + let proof = gen_proof_with_instances( + ¶ms, + &pk, + circuit, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + ); + check_proof_with_instances( + ¶ms, + pk.get_vk(), + &proof, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + true, + ); +} diff --git a/hashes/zkevm/src/keccak/component/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs new file mode 100644 index 00000000..33230bee --- /dev/null +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -0,0 +1,261 @@ +use halo2_base::{ + gates::{GateInstructions, RangeInstructions}, + poseidon::hasher::{PoseidonCompactChunkInput, PoseidonHasher}, + safe_types::{FixLenBytesVec, SafeByte, SafeTypeChip, VarLenBytesVec}, + utils::bit_length, + AssignedValue, Context, + QuantumCell::Constant, +}; +use itertools::Itertools; +use num_bigint::BigUint; +use snark_verifier::loader::native::NativeLoader; + +use crate::{ + keccak::vanilla::{keccak_packed_multi::get_num_keccak_f, param::*}, + util::eth_types::Field, +}; + +use super::param::*; + +// TODO: Abstract this module into a trait for all component circuits. + +/// Module to encode raw inputs into lookup keys for looking up keccak results. The encoding is +/// designed to be efficient in component circuits. + +/// Encode a native input bytes into its corresponding lookup key. This function can be considered as the spec of the encoding. +pub fn encode_native_input(bytes: &[u8]) -> F { + assert!(NUM_BITS_PER_WORD <= u128::BITS as usize); + let multipliers: Vec = get_words_to_witness_multipliers::(); + let num_word_per_witness = num_word_per_witness::(); + let len = bytes.len(); + + // Divide the bytes input into Keccak words(each word has NUM_BYTES_PER_WORD bytes). + let mut words = bytes + .chunks(NUM_BYTES_PER_WORD) + .map(|chunk| { + let mut padded_chunk = [0; u128::BITS as usize / NUM_BITS_PER_BYTE]; + padded_chunk[..chunk.len()].copy_from_slice(chunk); + u128::from_le_bytes(padded_chunk) + }) + .collect_vec(); + // An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + if len % NUM_BYTES_TO_ABSORB == 0 { + words.extend([0; NUM_WORDS_TO_ABSORB]); + } + // 1. Split Keccak words into keccak_fs(each keccak_f has NUM_WORDS_TO_ABSORB). + // 2. Append an extra word into the beginning of each keccak_f. In the first keccak_f, this word is the byte length of the input. Otherwise 0. + let words_per_keccak_f = words + .chunks(NUM_WORDS_TO_ABSORB) + .enumerate() + .map(|(i, chunk)| { + let mut padded_chunk = [0; NUM_WORDS_TO_ABSORB + 1]; + padded_chunk[0] = if i == 0 { len as u128 } else { 0 }; + padded_chunk[1..(chunk.len() + 1)].copy_from_slice(chunk); + padded_chunk + }) + .collect_vec(); + // Compress every num_word_per_witness words into a witness. + let witnesses_per_keccak_f = words_per_keccak_f + .iter() + .map(|chunk| { + chunk + .chunks(num_word_per_witness) + .map(|c| { + c.iter().zip(multipliers.iter()).fold(F::ZERO, |acc, (word, multipiler)| { + acc + F::from_u128(*word) * multipiler + }) + }) + .collect_vec() + }) + .collect_vec(); + // Absorb witnesses keccak_f by keccak_f. + let mut native_poseidon_sponge = + snark_verifier::util::hash::Poseidon::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(&NativeLoader); + for witnesses in witnesses_per_keccak_f { + for absorbing in witnesses.chunks(POSEIDON_RATE) { + // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. + let mut padded_absorb = [F::ZERO; POSEIDON_RATE]; + padded_absorb[..absorbing.len()].copy_from_slice(absorbing); + native_poseidon_sponge.update(&padded_absorb); + } + } + native_poseidon_sponge.squeeze() +} + +/// Encode a VarLenBytesVec into its corresponding lookup key. +pub fn encode_var_len_bytes_vec( + ctx: &mut Context, + range_chip: &impl RangeInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &VarLenBytesVec, +) -> AssignedValue { + let max_len = bytes.max_len(); + let max_num_keccak_f = get_num_keccak_f(max_len); + // num_keccak_f = len / NUM_BYTES_TO_ABSORB + 1 + let num_bits = bit_length(max_len as u64); + let (num_keccak_f, _) = + range_chip.div_mod(ctx, *bytes.len(), BigUint::from(NUM_BYTES_TO_ABSORB), num_bits); + let f_indicator = range_chip.gate().idx_to_indicator(ctx, num_keccak_f, max_num_keccak_f); + + let bytes = bytes.ensure_0_padding(ctx, range_chip.gate()); + let chunk_input_per_f = format_input(ctx, range_chip.gate(), bytes.bytes(), *bytes.len()); + + let chunk_inputs = chunk_input_per_f + .into_iter() + .zip(&f_indicator) + .map(|(chunk_input, is_final)| { + let is_final = SafeTypeChip::unsafe_to_bool(*is_final); + PoseidonCompactChunkInput::new(chunk_input, is_final) + }) + .collect_vec(); + + let compact_outputs = + initialized_hasher.hash_compact_chunk_inputs(ctx, range_chip.gate(), &chunk_inputs); + range_chip.gate().select_by_indicator( + ctx, + compact_outputs.into_iter().map(|o| *o.hash()), + f_indicator, + ) +} + +/// Encode a FixLenBytesVec into its corresponding lookup key. +pub fn encode_fix_len_bytes_vec( + ctx: &mut Context, + gate_chip: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &FixLenBytesVec, +) -> AssignedValue { + // Constant witnesses + let len_witness = ctx.load_constant(F::from(bytes.len() as u64)); + + let chunk_input_per_f = format_input(ctx, gate_chip, bytes.bytes(), len_witness); + let flatten_inputs = chunk_input_per_f + .into_iter() + .flat_map(|chunk_input| chunk_input.into_iter().flatten()) + .collect_vec(); + + initialized_hasher.hash_fix_len_array(ctx, gate_chip, &flatten_inputs) +} + +// For reference, when F is bn254::Fr: +// num_word_per_witness = 3 +// num_witness_per_keccak_f = 6 +// num_poseidon_absorb_per_keccak_f = 3 + +/// Number of Keccak words in each encoded input for Poseidon. +/// When `F` is `bn254::Fr`, this is 3. +pub const fn num_word_per_witness() -> usize { + (F::CAPACITY as usize) / NUM_BITS_PER_WORD +} + +/// Number of witnesses to represent inputs in a keccak_f. +/// +/// Assume the representation of is not longer than a Keccak word. +/// +/// When `F` is `bn254::Fr`, this is 6. +pub const fn num_witness_per_keccak_f() -> usize { + // With , a keccak_f could have NUM_WORDS_TO_ABSORB + 1 words. + // ceil((NUM_WORDS_TO_ABSORB + 1) / num_word_per_witness) + NUM_WORDS_TO_ABSORB / num_word_per_witness::() + 1 +} + +/// Number of Poseidon absorb rounds per keccak_f. +/// +/// When `F` is `bn254::Fr`, with our fixed `POSEIDON_RATE = 2`, this is 3. +pub const fn num_poseidon_absorb_per_keccak_f() -> usize { + // Each absorb round consumes RATE witnesses. + // ceil(num_witness_per_keccak_f / RATE) + (num_witness_per_keccak_f::() - 1) / POSEIDON_RATE + 1 +} + +pub(crate) fn get_words_to_witness_multipliers() -> Vec { + let num_word_per_witness = num_word_per_witness::(); + let mut multiplier_f = F::ONE; + let mut multipliers = Vec::with_capacity(num_word_per_witness); + multipliers.push(multiplier_f); + let base_f = F::from_u128(1u128 << NUM_BITS_PER_WORD); + for _ in 1..num_word_per_witness { + multiplier_f *= base_f; + multipliers.push(multiplier_f); + } + multipliers +} + +pub(crate) fn get_bytes_to_words_multipliers() -> Vec { + let mut multiplier_f = F::ONE; + let mut multipliers = Vec::with_capacity(NUM_BYTES_PER_WORD); + multipliers.push(multiplier_f); + let base_f = F::from_u128(1 << NUM_BITS_PER_BYTE); + for _ in 1..NUM_BYTES_PER_WORD { + multiplier_f *= base_f; + multipliers.push(multiplier_f); + } + multipliers +} + +fn format_input( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec; POSEIDON_RATE]>> { + // Constant witnesses + let zero_const = ctx.load_zero(); + let bytes_to_words_multipliers_val = + get_bytes_to_words_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + let words_to_witness_multipliers_val = + get_words_to_witness_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + + let mut bytes_witnesses = bytes.to_vec(); + // Append a zero to the end because An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + bytes_witnesses.push(SafeTypeChip::unsafe_to_byte(zero_const)); + let words = bytes_witnesses + .chunks(NUM_BYTES_PER_WORD) + .map(|c| { + let len = c.len(); + let multipliers = bytes_to_words_multipliers_val[..len].to_vec(); + gate.inner_product(ctx, c.iter().map(|sb| *sb.as_ref()), multipliers) + }) + .collect_vec(); + + let words_per_f = words + .chunks(NUM_WORDS_TO_ABSORB) + .enumerate() + .map(|(i, words_per_f)| { + let mut buffer = [zero_const; NUM_WORDS_TO_ABSORB + 1]; + buffer[0] = if i == 0 { len } else { zero_const }; + buffer[1..words_per_f.len() + 1].copy_from_slice(words_per_f); + buffer + }) + .collect_vec(); + + let witnesses_per_f = words_per_f + .iter() + .map(|words| { + words + .chunks(num_word_per_witness::()) + .map(|c| { + gate.inner_product(ctx, c.to_vec(), words_to_witness_multipliers_val.clone()) + }) + .collect_vec() + }) + .collect_vec(); + + witnesses_per_f + .iter() + .map(|words| { + words + .chunks(POSEIDON_RATE) + .map(|c| { + let mut buffer = [zero_const; POSEIDON_RATE]; + buffer[..c.len()].copy_from_slice(c); + buffer + }) + .collect_vec() + }) + .collect_vec() +} diff --git a/hashes/zkevm/src/keccak/component/ingestion.rs b/hashes/zkevm/src/keccak/component/ingestion.rs new file mode 100644 index 00000000..cc0b2c3f --- /dev/null +++ b/hashes/zkevm/src/keccak/component/ingestion.rs @@ -0,0 +1,86 @@ +use ethers_core::{types::H256, utils::keccak256}; + +use crate::keccak::vanilla::param::NUM_BYTES_TO_ABSORB; + +/// Fixed length format for one keccak_f. +/// This closely matches [crate::keccak::component::circuit::shard::LoadedKeccakF]. +#[derive(Clone, Debug)] +pub struct KeccakIngestionFormat { + pub bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], + /// In the first keccak_f of a full keccak, this will be the length in bytes of the input. Otherwise 0. + pub byte_len_placeholder: usize, + /// Is this the last keccak_f of a full keccak? Note that the last keccak_f includes input padding. + pub is_final: bool, + /// If `is_final = true`, the output of the full keccak, split into two 128-bit chunks. Otherwise `keccak256([])` in hi-lo form. + pub hash_lo: u128, + pub hash_hi: u128, +} + +impl Default for KeccakIngestionFormat { + fn default() -> Self { + Self::new([0; NUM_BYTES_TO_ABSORB], 0, true, H256(keccak256([]))) + } +} + +impl KeccakIngestionFormat { + fn new( + bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], + byte_len_placeholder: usize, + is_final: bool, + hash: H256, + ) -> Self { + let hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap()); + let hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap()); + Self { bytes_per_keccak_f, byte_len_placeholder, is_final, hash_lo, hash_hi } + } +} + +/// We take all `requests` as a deduplicated ordered list. +/// We split each input into `KeccakIngestionFormat` chunks, one for each keccak_f needed to compute `keccak(input)`. +/// We then resize so there are exactly `capacity` total chunks. +/// +/// Very similar to [crate::keccak::component::encode::encode_native_input] except we do not do the +/// encoding part (that will be done in circuit, not natively). +/// +/// Returns `Err(true_capacity)` if `true_capacity > capacity`, where `true_capacity` is the number of keccak_f needed +/// to compute all requests. +pub fn format_requests_for_ingestion( + requests: impl IntoIterator)>, + capacity: usize, +) -> Result, usize> +where + B: AsRef<[u8]>, +{ + let mut ingestions = Vec::with_capacity(capacity); + for (input, hash) in requests { + let input = input.as_ref(); + let hash = hash.unwrap_or_else(|| H256(keccak256(input))); + let len = input.len(); + for (i, chunk) in input.chunks(NUM_BYTES_TO_ABSORB).enumerate() { + let byte_len = if i == 0 { len } else { 0 }; + let mut bytes_per_keccak_f = [0; NUM_BYTES_TO_ABSORB]; + bytes_per_keccak_f[..chunk.len()].copy_from_slice(chunk); + ingestions.push(KeccakIngestionFormat::new( + bytes_per_keccak_f, + byte_len, + false, + H256::zero(), + )); + } + // An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + if len % NUM_BYTES_TO_ABSORB == 0 { + ingestions.push(KeccakIngestionFormat::default()); + } + let last_mut = ingestions.last_mut().unwrap(); + last_mut.is_final = true; + last_mut.hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap()); + last_mut.hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap()); + } + log::info!("Actual number of keccak_f used = {}", ingestions.len()); + if ingestions.len() > capacity { + Err(ingestions.len()) + } else { + ingestions.resize_with(capacity, Default::default); + Ok(ingestions) + } +} diff --git a/hashes/zkevm/src/keccak/component/mod.rs b/hashes/zkevm/src/keccak/component/mod.rs new file mode 100644 index 00000000..13bbd303 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/mod.rs @@ -0,0 +1,12 @@ +/// Module of Keccak component circuit(s). +pub mod circuit; +/// Module of encoding raw inputs to component circuit lookup keys. +pub mod encode; +/// Module for Rust native processing of input bytes into resized fixed length format to match vanilla circuit LoadedKeccakF +pub mod ingestion; +/// Module of Keccak component circuit output. +pub mod output; +/// Module of Keccak component circuit constant parameters. +pub mod param; +#[cfg(test)] +mod tests; diff --git a/hashes/zkevm/src/keccak/component/output.rs b/hashes/zkevm/src/keccak/component/output.rs new file mode 100644 index 00000000..fa010bbe --- /dev/null +++ b/hashes/zkevm/src/keccak/component/output.rs @@ -0,0 +1,77 @@ +use super::{encode::encode_native_input, param::*}; +use crate::{keccak::vanilla::keccak_packed_multi::get_num_keccak_f, util::eth_types::Field}; +use itertools::Itertools; +use sha3::{Digest, Keccak256}; +use snark_verifier::loader::native::NativeLoader; + +/// Witnesses to be exposed as circuit outputs. +#[derive(Clone, Copy, PartialEq, Debug)] +pub struct KeccakCircuitOutput { + /// Key for App circuits to lookup keccak hash. + pub key: E, + /// Low 128 bits of Keccak hash. + pub hash_lo: E, + /// High 128 bits of Keccak hash. + pub hash_hi: E, +} + +/// Return circuit outputs of the specified Keccak corprocessor circuit for a specified input. +pub fn multi_inputs_to_circuit_outputs( + inputs: &[Vec], + capacity: usize, +) -> Vec> { + assert!(u128::BITS <= F::CAPACITY); + let mut outputs = + inputs.iter().flat_map(|input| input_to_circuit_outputs::(input)).collect_vec(); + assert!(outputs.len() <= capacity); + outputs.resize(capacity, dummy_circuit_output()); + outputs +} + +/// Return corresponding circuit outputs of a native input in bytes. An logical input could produce multiple +/// outputs. The last one is the lookup key and hash of the input. Other outputs are paddings which are the lookup +/// key and hash of an empty input. +pub fn input_to_circuit_outputs(bytes: &[u8]) -> Vec> { + assert!(u128::BITS <= F::CAPACITY); + let len = bytes.len(); + let num_keccak_f = get_num_keccak_f(len); + + let mut output = Vec::with_capacity(num_keccak_f); + output.resize(num_keccak_f - 1, dummy_circuit_output()); + + let key = encode_native_input(bytes); + let hash = Keccak256::digest(bytes); + let hash_lo = F::from_u128(u128::from_be_bytes(hash[16..].try_into().unwrap())); + let hash_hi = F::from_u128(u128::from_be_bytes(hash[..16].try_into().unwrap())); + output.push(KeccakCircuitOutput { key, hash_lo, hash_hi }); + + output +} + +/// Return the dummy circuit output for padding. +pub fn dummy_circuit_output() -> KeccakCircuitOutput { + assert!(u128::BITS <= F::CAPACITY); + let key = encode_native_input(&[]); + // Output of Keccak256::digest is big endian. + let hash = Keccak256::digest([]); + let hash_lo = F::from_u128(u128::from_be_bytes(hash[16..].try_into().unwrap())); + let hash_hi = F::from_u128(u128::from_be_bytes(hash[..16].try_into().unwrap())); + KeccakCircuitOutput { key, hash_lo, hash_hi } +} + +/// Calculate the commitment of circuit outputs. +pub fn calculate_circuit_outputs_commit(outputs: &[KeccakCircuitOutput]) -> F { + let mut native_poseidon_sponge = + snark_verifier::util::hash::Poseidon::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(&NativeLoader); + native_poseidon_sponge.update( + &outputs + .iter() + .flat_map(|output| [output.key, output.hash_lo, output.hash_hi]) + .collect_vec(), + ); + native_poseidon_sponge.squeeze() +} diff --git a/hashes/zkevm/src/keccak/component/param.rs b/hashes/zkevm/src/keccak/component/param.rs new file mode 100644 index 00000000..889d0bd9 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/param.rs @@ -0,0 +1,12 @@ +pub const OUTPUT_NUM_COL_COMMIT: usize = 1; +pub const OUTPUT_NUM_COL_RAW: usize = 3; +pub const OUTPUT_COL_IDX_COMMIT: usize = 0; +pub const OUTPUT_COL_IDX_KEY: usize = 0; +pub const OUTPUT_COL_IDX_HASH_LO: usize = 1; +pub const OUTPUT_COL_IDX_HASH_HI: usize = 2; + +pub const POSEIDON_T: usize = 3; +pub const POSEIDON_RATE: usize = 2; +pub const POSEIDON_R_F: usize = 8; +pub const POSEIDON_R_P: usize = 57; +pub const POSEIDON_SECURE_MDS: usize = 0; diff --git a/hashes/zkevm/src/keccak/component/tests/encode.rs b/hashes/zkevm/src/keccak/component/tests/encode.rs new file mode 100644 index 00000000..df576c66 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/tests/encode.rs @@ -0,0 +1,124 @@ +use ethers_core::k256::elliptic_curve::Field; +use halo2_base::{ + gates::{GateInstructions, RangeChip, RangeInstructions}, + halo2_proofs::halo2curves::bn256::Fr, + safe_types::SafeTypeChip, + utils::testing::base_test, + Context, +}; +use itertools::Itertools; + +use crate::keccak::component::{ + circuit::shard::create_hasher, + encode::{encode_fix_len_bytes_vec, encode_native_input, encode_var_len_bytes_vec}, +}; + +fn build_and_verify_encode_var_len_bytes_vec( + inputs: Vec<(Vec, usize)>, + ctx: &mut Context, + range_chip: &RangeChip, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, range_chip.gate()); + + for (input, max_len) in inputs { + let expected = encode_native_input::(&input); + let len = ctx.load_witness(Fr::from(input.len() as u64)); + let mut witnesses_val = vec![Fr::ZERO; max_len]; + witnesses_val[..input.len()] + .copy_from_slice(&input.iter().map(|b| Fr::from(*b as u64)).collect_vec()); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let var_len_bytes_vec = + SafeTypeChip::unsafe_to_var_len_bytes_vec(input_witnesses, len, max_len); + let encoded = encode_var_len_bytes_vec(ctx, range_chip, &hasher, &var_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +fn build_and_verify_encode_fix_len_bytes_vec( + inputs: Vec>, + ctx: &mut Context, + gate_chip: &impl GateInstructions, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, gate_chip); + + for input in inputs { + let expected = encode_native_input::(&input); + let len = input.len(); + let witnesses_val = input.into_iter().map(|b| Fr::from(b as u64)).collect_vec(); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let fix_len_bytes_vec = SafeTypeChip::unsafe_to_fix_len_bytes_vec(input_witnesses, len); + let encoded = encode_fix_len_bytes_vec(ctx, gate_chip, &hasher, &fix_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +#[test] +fn mock_encode_var_len_bytes_vec() { + let inputs = vec![ + (vec![], 1), + (vec![], 136), + ((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 134), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }) +} + +#[test] +fn prove_encode_var_len_bytes_vec() { + let init_inputs = vec![ + (vec![], 1), + (vec![], 136), + (vec![], 136), + (vec![], 137), + (vec![], 272), + (vec![], 136 * 3), + ]; + let inputs = vec![ + (vec![], 1), + (vec![], 136), + ((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }, + ); +} + +#[test] +fn mock_encode_fix_len_bytes_vec() { + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }); +} + +#[test] +fn prove_encode_fix_len_bytes_vec() { + let init_inputs = + vec![vec![], (2u8..136).collect_vec(), (1u8..137).collect_vec(), (2u8..213).collect_vec()]; + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }, + ); +} diff --git a/hashes/zkevm/src/keccak/component/tests/mod.rs b/hashes/zkevm/src/keccak/component/tests/mod.rs new file mode 100644 index 00000000..520b3573 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/tests/mod.rs @@ -0,0 +1,4 @@ +#[cfg(test)] +mod encode; +#[cfg(test)] +mod output; diff --git a/hashes/zkevm/src/keccak/component/tests/output.rs b/hashes/zkevm/src/keccak/component/tests/output.rs new file mode 100644 index 00000000..c63aa352 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/tests/output.rs @@ -0,0 +1,131 @@ +use crate::keccak::component::output::{ + dummy_circuit_output, input_to_circuit_outputs, multi_inputs_to_circuit_outputs, + KeccakCircuitOutput, +}; +use halo2_base::halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; +use itertools::Itertools; +use lazy_static::lazy_static; + +lazy_static! { + static ref OUTPUT_EMPTY: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x54595a1525d3534a, + 0xf90e160f1b4648ef, + 0x34d557ddfb89da5d, + 0x04ffe3d4b8885928, + ]), + hash_lo: Fr::from_u128(0xe500b653ca82273b7bfad8045d85a470), + hash_hi: Fr::from_u128(0xc5d2460186f7233c927e7db2dcc703c0), + }; + static ref OUTPUT_0: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0xc009f26a12e2f494, + 0xb4a9d43c17609251, + 0x68068b5344cba120, + 0x1531327ea92d38ba, + ]), + hash_lo: Fr::from_u128(0x6612f7b477d66591ff96a9e064bcc98a), + hash_hi: Fr::from_u128(0xbc36789e7a1e281436464229828f817d), + }; + static ref OUTPUT_0_135: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x9a88287adab4da1c, + 0xe9ff61b507cfd8c2, + 0xdbf697a6a3ad66a1, + 0x1eb1d5cc8cdd1532, + ]), + hash_lo: Fr::from_u128(0x290b0e1706f6a82e5a595b9ce9faca62), + hash_hi: Fr::from_u128(0xcbdfd9dee5faad3818d6b06f95a219fd), + }; + static ref OUTPUT_0_136: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x39c1a578acb62676, + 0x0dc19a75e610c062, + 0x3f158e809150a14a, + 0x2367059ac8c80538, + ]), + hash_lo: Fr::from_u128(0xff11fe3e38e17df89cf5d29c7d7f807e), + hash_hi: Fr::from_u128(0x7ce759f1ab7f9ce437719970c26b0a66), + }; + static ref OUTPUT_0_200: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x379bfca638552583, + 0x1bf7bd603adec30e, + 0x05efe90ad5dbd814, + 0x053c729cb8908ccb, + ]), + hash_lo: Fr::from_u128(0xb4543f3d2703c0923c6901c2af57b890), + hash_hi: Fr::from_u128(0xbfb0aa97863e797943cf7c33bb7e880b), + }; +} + +#[test] +fn test_dummy_circuit_output() { + let KeccakCircuitOutput { key, hash_lo, hash_hi } = dummy_circuit_output::(); + assert_eq!(key, OUTPUT_EMPTY.key); + assert_eq!(hash_lo, OUTPUT_EMPTY.hash_lo); + assert_eq!(hash_hi, OUTPUT_EMPTY.hash_hi); +} + +#[test] +fn test_input_to_circuit_outputs_empty() { + let result = input_to_circuit_outputs::(&[]); + assert_eq!(result, vec![*OUTPUT_EMPTY]); +} + +#[test] +fn test_input_to_circuit_outputs_1_keccak_f() { + let result = input_to_circuit_outputs::(&[0]); + assert_eq!(result, vec![*OUTPUT_0]); +} + +#[test] +fn test_input_to_circuit_outputs_1_keccak_f_full() { + let result = input_to_circuit_outputs::(&(0..135).collect_vec()); + assert_eq!(result, vec![*OUTPUT_0_135]); +} + +#[test] +fn test_input_to_circuit_outputs_2_keccak_f_2nd_empty() { + let result = input_to_circuit_outputs::(&(0..136).collect_vec()); + assert_eq!(result, vec![*OUTPUT_EMPTY, *OUTPUT_0_136]); +} + +#[test] +fn test_input_to_circuit_outputs_2_keccak_f() { + let result = input_to_circuit_outputs::(&(0..200).collect_vec()); + assert_eq!(result, vec![*OUTPUT_EMPTY, *OUTPUT_0_200]); +} + +#[test] +fn test_multi_input_to_circuit_outputs() { + let results = multi_inputs_to_circuit_outputs::( + &[(0..135).collect_vec(), (0..200).collect_vec(), vec![], vec![0], (0..136).collect_vec()], + 10, + ); + assert_eq!( + results, + vec![ + *OUTPUT_0_135, + *OUTPUT_EMPTY, + *OUTPUT_0_200, + *OUTPUT_EMPTY, + *OUTPUT_0, + *OUTPUT_EMPTY, + *OUTPUT_0_136, + // Padding + *OUTPUT_EMPTY, + *OUTPUT_EMPTY, + *OUTPUT_EMPTY, + ] + ); +} + +#[test] +#[should_panic] +fn test_multi_input_to_circuit_outputs_exceed_capacity() { + let _ = multi_inputs_to_circuit_outputs::( + &[(0..135).collect_vec(), (0..200).collect_vec(), vec![], vec![0], (0..136).collect_vec()], + 2, + ); +} diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs new file mode 100644 index 00000000..dd9a660b --- /dev/null +++ b/hashes/zkevm/src/keccak/mod.rs @@ -0,0 +1,4 @@ +/// Module for component circuits. +pub mod component; +/// Module for Keccak circuits in vanilla halo2. +pub mod vanilla; diff --git a/hashes/zkevm/src/keccak/vanilla/cell_manager.rs b/hashes/zkevm/src/keccak/vanilla/cell_manager.rs new file mode 100644 index 00000000..04c67a6b --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/cell_manager.rs @@ -0,0 +1,204 @@ +use crate::{ + halo2_proofs::{ + halo2curves::ff::PrimeField, + plonk::{Advice, Column, ConstraintSystem, Expression, VirtualCells}, + poly::Rotation, + }, + util::expression::Expr, +}; + +use super::KeccakRegion; + +#[derive(Clone, Debug)] +pub(crate) struct Cell { + pub(crate) expression: Expression, + pub(crate) column_expression: Expression, + pub(crate) column: Option>, + pub(crate) column_idx: usize, + pub(crate) rotation: i32, +} + +impl Cell { + pub(crate) fn new( + meta: &mut VirtualCells, + column: Column, + column_idx: usize, + rotation: i32, + ) -> Self { + Self { + expression: meta.query_advice(column, Rotation(rotation)), + column_expression: meta.query_advice(column, Rotation::cur()), + column: Some(column), + column_idx, + rotation, + } + } + + pub(crate) fn new_value(column_idx: usize, rotation: i32) -> Self { + Self { + expression: 0.expr(), + column_expression: 0.expr(), + column: None, + column_idx, + rotation, + } + } + + pub(crate) fn at_offset(&self, meta: &mut ConstraintSystem, offset: i32) -> Self { + let mut expression = 0.expr(); + meta.create_gate("Query cell", |meta| { + expression = meta.query_advice(self.column.unwrap(), Rotation(self.rotation + offset)); + vec![0.expr()] + }); + + Self { + expression, + column_expression: self.column_expression.clone(), + column: self.column, + column_idx: self.column_idx, + rotation: self.rotation + offset, + } + } + + pub(crate) fn assign(&self, region: &mut KeccakRegion, offset: i32, value: F) { + region.assign(self.column_idx, (offset + self.rotation) as usize, value); + } +} + +impl Expr for Cell { + fn expr(&self) -> Expression { + self.expression.clone() + } +} + +impl Expr for &Cell { + fn expr(&self) -> Expression { + self.expression.clone() + } +} + +/// CellColumn +#[derive(Clone, Debug)] +pub(crate) struct CellColumn { + pub(crate) advice: Column, + pub(crate) expr: Expression, +} + +/// CellManager +#[derive(Clone, Debug)] +pub(crate) struct CellManager { + height: usize, + width: usize, + current_row: usize, + columns: Vec>, + // rows[i] gives the number of columns already used in row `i` + rows: Vec, + num_unused_cells: usize, +} + +impl CellManager { + pub(crate) fn new(height: usize) -> Self { + Self { + height, + width: 0, + current_row: 0, + columns: Vec::new(), + rows: vec![0; height], + num_unused_cells: 0, + } + } + + pub(crate) fn query_cell(&mut self, meta: &mut ConstraintSystem) -> Cell { + let (row_idx, column_idx) = self.get_position(); + self.query_cell_at_pos(meta, row_idx as i32, column_idx) + } + + pub(crate) fn query_cell_at_row( + &mut self, + meta: &mut ConstraintSystem, + row_idx: i32, + ) -> Cell { + let column_idx = self.rows[row_idx as usize]; + self.rows[row_idx as usize] += 1; + self.width = self.width.max(column_idx + 1); + self.current_row = (row_idx as usize + 1) % self.height; + self.query_cell_at_pos(meta, row_idx, column_idx) + } + + pub(crate) fn query_cell_at_pos( + &mut self, + meta: &mut ConstraintSystem, + row_idx: i32, + column_idx: usize, + ) -> Cell { + let column = if column_idx < self.columns.len() { + self.columns[column_idx].advice + } else { + assert!(column_idx == self.columns.len()); + let advice = meta.advice_column(); + let mut expr = 0.expr(); + meta.create_gate("Query column", |meta| { + expr = meta.query_advice(advice, Rotation::cur()); + vec![0.expr()] + }); + self.columns.push(CellColumn { advice, expr }); + advice + }; + + let mut cells = Vec::new(); + meta.create_gate("Query cell", |meta| { + cells.push(Cell::new(meta, column, column_idx, row_idx)); + vec![0.expr()] + }); + cells[0].clone() + } + + pub(crate) fn query_cell_value(&mut self) -> Cell { + let (row_idx, column_idx) = self.get_position(); + self.query_cell_value_at_pos(row_idx as i32, column_idx) + } + + pub(crate) fn query_cell_value_at_row(&mut self, row_idx: i32) -> Cell { + let column_idx = self.rows[row_idx as usize]; + self.rows[row_idx as usize] += 1; + self.width = self.width.max(column_idx + 1); + self.current_row = (row_idx as usize + 1) % self.height; + self.query_cell_value_at_pos(row_idx, column_idx) + } + + pub(crate) fn query_cell_value_at_pos(&mut self, row_idx: i32, column_idx: usize) -> Cell { + Cell::new_value(column_idx, row_idx) + } + + fn get_position(&mut self) -> (usize, usize) { + let best_row_idx = self.current_row; + let best_row_pos = self.rows[best_row_idx]; + self.rows[best_row_idx] += 1; + self.width = self.width.max(best_row_pos + 1); + self.current_row = (best_row_idx + 1) % self.height; + (best_row_idx, best_row_pos) + } + + pub(crate) fn get_width(&self) -> usize { + self.width + } + + pub(crate) fn start_region(&mut self) -> usize { + // Make sure all rows start at the same column + let width = self.get_width(); + #[cfg(debug_assertions)] + for row in self.rows.iter() { + self.num_unused_cells += width - *row; + } + self.rows = vec![width; self.height]; + width + } + + pub(crate) fn columns(&self) -> &[CellColumn] { + &self.columns + } + + pub(crate) fn get_num_unused_cells(&self) -> usize { + self.num_unused_cells + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs new file mode 100644 index 00000000..5a76d248 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs @@ -0,0 +1,559 @@ +use super::{cell_manager::*, param::*, table::*}; +use crate::{ + halo2_proofs::{ + circuit::Value, + halo2curves::ff::PrimeField, + plonk::{Advice, Column, ConstraintSystem, Expression}, + }, + util::{ + constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::Expr, word::Word, + }, +}; +use halo2_base::utils::halo2::Halo2AssignedCell; + +pub(crate) fn get_num_bits_per_absorb_lookup(k: u32) -> usize { + get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE, k) +} + +pub(crate) fn get_num_bits_per_theta_c_lookup(k: u32) -> usize { + get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k) +} + +pub(crate) fn get_num_bits_per_rho_pi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) +} + +pub(crate) fn get_num_bits_per_base_chi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) +} + +/// The number of keccak_f's that can be done in this circuit +/// +/// `num_rows` should be number of usable rows without blinding factors +pub fn get_keccak_capacity(num_rows: usize, rows_per_round: usize) -> usize { + // - 1 because we have a dummy round at the very beginning of multi_keccak + // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * num_rows_per_round` beyond any row where `q_absorb == 1` + (num_rows / rows_per_round - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) +} + +pub fn get_num_keccak_f(byte_length: usize) -> usize { + // ceil( (byte_length + 1) / RATE ) + byte_length / RATE + 1 +} + +/// AbsorbData +#[derive(Clone, Default, Debug, PartialEq)] +pub(crate) struct AbsorbData { + pub(crate) from: F, + pub(crate) absorb: F, + pub(crate) result: F, +} + +/// SqueezeData +#[derive(Clone, Default, Debug, PartialEq)] +pub(crate) struct SqueezeData { + packed: F, +} + +/// KeccakRow. Field definitions could be found in [KeccakCircuitConfig]. +#[derive(Clone, Debug)] +pub struct KeccakRow { + pub(crate) q_enable: bool, + pub(crate) q_round: bool, + pub(crate) q_absorb: bool, + pub(crate) q_round_last: bool, + pub(crate) q_input: bool, + pub(crate) q_input_last: bool, + pub(crate) round_cst: F, + pub(crate) is_final: bool, + pub(crate) cell_values: Vec, + pub(crate) hash: Word>, + pub(crate) bytes_left: F, + // A keccak word(NUM_BYTES_PER_WORD bytes) + pub(crate) word_value: F, +} + +impl KeccakRow { + pub fn dummy_rows(num_rows: usize) -> Vec { + (0..num_rows) + .map(|idx| KeccakRow { + q_enable: idx == 0, + q_round: false, + q_absorb: idx == 0, + q_round_last: false, + q_input: false, + q_input_last: false, + round_cst: F::ZERO, + is_final: false, + cell_values: Vec::new(), + hash: Word::default().into_value(), + bytes_left: F::ZERO, + word_value: F::ZERO, + }) + .collect() + } +} + +/// Part +#[derive(Clone, Debug)] +pub(crate) struct Part { + pub(crate) cell: Cell, + pub(crate) expr: Expression, + pub(crate) num_bits: usize, +} + +/// Part Value +#[derive(Clone, Copy, Debug)] +pub(crate) struct PartValue { + pub(crate) value: F, + pub(crate) rot: i32, + pub(crate) num_bits: usize, +} + +#[derive(Clone, Debug)] +pub(crate) struct KeccakRegion { + pub(crate) rows: Vec>, +} + +impl KeccakRegion { + pub(crate) fn new() -> Self { + Self { rows: Vec::new() } + } + + pub(crate) fn assign(&mut self, column: usize, offset: usize, value: F) { + while offset >= self.rows.len() { + self.rows.push(Vec::new()); + } + let row = &mut self.rows[offset]; + while column >= row.len() { + row.push(F::ZERO); + } + row[column] = value; + } +} + +/// Keccak Table, used to verify keccak hashing from RLC'ed input. +#[derive(Clone, Debug)] +pub struct KeccakTable { + /// True when the row is enabled + pub is_enabled: Column, + /// Keccak hash of input + pub output: Word>, + /// Raw keccak words(NUM_BYTES_PER_WORD bytes) of inputs + pub word_value: Column, + /// Number of bytes left of a input + pub bytes_left: Column, +} + +impl KeccakTable { + /// Construct a new KeccakTable + pub fn construct(meta: &mut ConstraintSystem) -> Self { + let is_enabled = meta.advice_column(); + let word_value = meta.advice_column(); + let bytes_left = meta.advice_column(); + let hash_lo = meta.advice_column(); + let hash_hi = meta.advice_column(); + meta.enable_equality(is_enabled); + meta.enable_equality(word_value); + meta.enable_equality(bytes_left); + meta.enable_equality(hash_lo); + meta.enable_equality(hash_hi); + Self { is_enabled, output: Word::new([hash_lo, hash_hi]), word_value, bytes_left } + } +} + +pub(crate) type KeccakAssignedValue<'v, F> = Halo2AssignedCell<'v, F>; + +/// Recombines parts back together +pub(crate) mod decode { + use super::{Expr, Part, PartValue, PrimeField}; + use crate::{halo2_proofs::plonk::Expression, keccak::vanilla::param::*}; + + pub(crate) fn expr(parts: Vec>) -> Expression { + parts.iter().rev().fold(0.expr(), |acc, part| { + acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.expr.clone() + }) + } + + pub(crate) fn value(parts: Vec>) -> F { + parts.iter().rev().fold(F::ZERO, |acc, part| { + acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.value + }) + } +} + +/// Splits a word into parts +pub(crate) mod split { + use super::{ + decode, BaseConstraintBuilder, CellManager, Expr, Field, KeccakRegion, Part, PartValue, + PrimeField, + }; + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, Expression}, + keccak::vanilla::util::{pack, pack_part, unpack, WordParts}, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + meta: &mut ConstraintSystem, + cell_manager: &mut CellManager, + cb: &mut BaseConstraintBuilder, + input: Expression, + rot: usize, + target_part_size: usize, + normalize: bool, + row: Option, + ) -> Vec> { + let word = WordParts::new(target_part_size, rot, normalize); + let mut parts = Vec::with_capacity(word.parts.len()); + for word_part in word.parts { + let cell = if let Some(row) = row { + cell_manager.query_cell_at_row(meta, row as i32) + } else { + cell_manager.query_cell(meta) + }; + parts.push(Part { + num_bits: word_part.bits.len(), + cell: cell.clone(), + expr: cell.expr(), + }); + } + // Input parts need to equal original input expression + cb.require_equal("split", decode::expr(parts.clone()), input); + parts + } + + pub(crate) fn value( + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: F, + rot: usize, + target_part_size: usize, + normalize: bool, + row: Option, + ) -> Vec> { + let input_bits = unpack(input); + debug_assert_eq!(pack::(&input_bits), input); + let word = WordParts::new(target_part_size, rot, normalize); + let mut parts = Vec::with_capacity(word.parts.len()); + for word_part in word.parts { + let value = pack_part(&input_bits, &word_part); + let cell = if let Some(row) = row { + cell_manager.query_cell_value_at_row(row as i32) + } else { + cell_manager.query_cell_value() + }; + cell.assign(region, 0, F::from(value)); + parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: cell.rotation, + value: F::from(value), + }); + } + debug_assert_eq!(decode::value(parts.clone()), input); + parts + } +} + +// Split into parts, but storing the parts in a specific way to have the same +// table layout in `output_cells` regardless of rotation. +pub(crate) mod split_uniform { + use super::decode; + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, Expression}, + keccak::vanilla::{ + param::*, + target_part_sizes, + util::{pack, pack_part, rotate, rotate_rev, unpack, WordParts}, + BaseConstraintBuilder, Cell, CellManager, Expr, KeccakRegion, Part, PartValue, + PrimeField, + }, + util::eth_types::Field, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + meta: &mut ConstraintSystem, + output_cells: &[Cell], + cell_manager: &mut CellManager, + cb: &mut BaseConstraintBuilder, + input: Expression, + rot: usize, + target_part_size: usize, + normalize: bool, + ) -> Vec> { + let mut input_parts = Vec::new(); + let mut output_parts = Vec::new(); + let word = WordParts::new(target_part_size, rot, normalize); + + let word = rotate(word.parts, rot, target_part_size); + + let target_sizes = target_part_sizes(target_part_size); + let mut word_iter = word.iter(); + let mut counter = 0; + while let Some(word_part) = word_iter.next() { + if word_part.bits.len() == target_sizes[counter] { + // Input and output part are the same + let part = Part { + num_bits: target_sizes[counter], + cell: output_cells[counter].clone(), + expr: output_cells[counter].expr(), + }; + input_parts.push(part.clone()); + output_parts.push(part); + counter += 1; + } else if let Some(extra_part) = word_iter.next() { + // The two parts combined need to have the expected combined length + debug_assert_eq!( + word_part.bits.len() + extra_part.bits.len(), + target_sizes[counter] + ); + + // Needs two cells here to store the parts + // These still need to be range checked elsewhere! + let part_a = cell_manager.query_cell(meta); + let part_b = cell_manager.query_cell(meta); + + // Make sure the parts combined equal the value in the uniform output + let expr = part_a.expr() + + part_b.expr() + * F::from((BIT_SIZE as u32).pow(word_part.bits.len() as u32) as u64); + cb.require_equal("rot part", expr, output_cells[counter].expr()); + + // Input needs the two parts because it needs to be able to undo the rotation + input_parts.push(Part { + num_bits: word_part.bits.len(), + cell: part_a.clone(), + expr: part_a.expr(), + }); + input_parts.push(Part { + num_bits: extra_part.bits.len(), + cell: part_b.clone(), + expr: part_b.expr(), + }); + // Output only has the combined cell + output_parts.push(Part { + num_bits: target_sizes[counter], + cell: output_cells[counter].clone(), + expr: output_cells[counter].expr(), + }); + counter += 1; + } else { + unreachable!(); + } + } + let input_parts = rotate_rev(input_parts, rot, target_part_size); + // Input parts need to equal original input expression + cb.require_equal("split", decode::expr(input_parts), input); + // Uniform output + output_parts + } + + pub(crate) fn value( + output_cells: &[Cell], + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: F, + rot: usize, + target_part_size: usize, + normalize: bool, + ) -> Vec> { + let input_bits = unpack(input); + debug_assert_eq!(pack::(&input_bits), input); + + let mut input_parts = Vec::new(); + let mut output_parts = Vec::new(); + let word = WordParts::new(target_part_size, rot, normalize); + + let word = rotate(word.parts, rot, target_part_size); + + let target_sizes = target_part_sizes(target_part_size); + let mut word_iter = word.iter(); + let mut counter = 0; + while let Some(word_part) = word_iter.next() { + if word_part.bits.len() == target_sizes[counter] { + let value = pack_part(&input_bits, word_part); + output_cells[counter].assign(region, 0, F::from(value)); + input_parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: output_cells[counter].rotation, + value: F::from(value), + }); + output_parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: output_cells[counter].rotation, + value: F::from(value), + }); + counter += 1; + } else if let Some(extra_part) = word_iter.next() { + debug_assert_eq!( + word_part.bits.len() + extra_part.bits.len(), + target_sizes[counter] + ); + + let part_a = cell_manager.query_cell_value(); + let part_b = cell_manager.query_cell_value(); + + let value_a = pack_part(&input_bits, word_part); + let value_b = pack_part(&input_bits, extra_part); + + part_a.assign(region, 0, F::from(value_a)); + part_b.assign(region, 0, F::from(value_b)); + + let value = value_a + value_b * (BIT_SIZE as u64).pow(word_part.bits.len() as u32); + + output_cells[counter].assign(region, 0, F::from(value)); + + input_parts.push(PartValue { + num_bits: word_part.bits.len(), + value: F::from(value_a), + rot: part_a.rotation, + }); + input_parts.push(PartValue { + num_bits: extra_part.bits.len(), + value: F::from(value_b), + rot: part_b.rotation, + }); + output_parts.push(PartValue { + num_bits: target_sizes[counter], + value: F::from(value), + rot: output_cells[counter].rotation, + }); + counter += 1; + } else { + unreachable!(); + } + } + let input_parts = rotate_rev(input_parts, rot, target_part_size); + debug_assert_eq!(decode::value(input_parts), input); + output_parts + } +} + +// Transform values using a lookup table +pub(crate) mod transform { + use super::{transform_to, CellManager, Field, KeccakRegion, Part, PartValue, PrimeField}; + use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; + use itertools::Itertools; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + name: &'static str, + meta: &mut ConstraintSystem, + cell_manager: &mut CellManager, + lookup_counter: &mut usize, + input: Vec>, + transform_table: [TableColumn; 2], + uniform_lookup: bool, + ) -> Vec> { + let cells = input + .iter() + .map(|input_part| { + if uniform_lookup { + cell_manager.query_cell_at_row(meta, input_part.cell.rotation) + } else { + cell_manager.query_cell(meta) + } + }) + .collect_vec(); + transform_to::expr( + name, + meta, + &cells, + lookup_counter, + input, + transform_table, + uniform_lookup, + ) + } + + pub(crate) fn value( + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: Vec>, + do_packing: bool, + f: fn(&u8) -> u8, + uniform_lookup: bool, + ) -> Vec> { + let cells = input + .iter() + .map(|input_part| { + if uniform_lookup { + cell_manager.query_cell_value_at_row(input_part.rot) + } else { + cell_manager.query_cell_value() + } + }) + .collect_vec(); + transform_to::value(&cells, region, input, do_packing, f) + } +} + +// Transfroms values to cells +pub(crate) mod transform_to { + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, TableColumn}, + keccak::vanilla::{ + util::{pack, to_bytes, unpack}, + Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField, + }, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + name: &'static str, + meta: &mut ConstraintSystem, + cells: &[Cell], + lookup_counter: &mut usize, + input: Vec>, + transform_table: [TableColumn; 2], + uniform_lookup: bool, + ) -> Vec> { + let mut output = Vec::with_capacity(input.len()); + for (idx, input_part) in input.iter().enumerate() { + let output_part = cells[idx].clone(); + if !uniform_lookup || input_part.cell.rotation == 0 { + meta.lookup(name, |_| { + vec![ + (input_part.expr.clone(), transform_table[0]), + (output_part.expr(), transform_table[1]), + ] + }); + *lookup_counter += 1; + } + output.push(Part { + num_bits: input_part.num_bits, + cell: output_part.clone(), + expr: output_part.expr(), + }); + } + output + } + + pub(crate) fn value( + cells: &[Cell], + region: &mut KeccakRegion, + input: Vec>, + do_packing: bool, + f: fn(&u8) -> u8, + ) -> Vec> { + let mut output = Vec::new(); + for (idx, input_part) in input.iter().enumerate() { + let input_bits = &unpack(input_part.value)[0..input_part.num_bits]; + let output_bits = input_bits.iter().map(f).collect::>(); + let value = if do_packing { + pack(&output_bits) + } else { + F::from(to_bytes::value(&output_bits)[0] as u64) + }; + let output_part = cells[idx].clone(); + output_part.assign(region, 0, value); + output.push(PartValue { + num_bits: input_part.num_bits, + rot: output_part.rotation, + value, + }); + } + output + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs new file mode 100644 index 00000000..8018142f --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -0,0 +1,891 @@ +use self::{cell_manager::*, keccak_packed_multi::*, param::*, table::*, util::*}; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{Column, ConstraintSystem, Error, Expression, Fixed, TableColumn, VirtualCells}, + poly::Rotation, + }, + util::{ + constraint_builder::BaseConstraintBuilder, + eth_types::{self, Field}, + expression::{and, from_bytes, not, select, sum, Expr}, + word::{self, Word, WordExpr}, + }, +}; +use halo2_base::utils::halo2::{raw_assign_advice, raw_assign_fixed}; +use itertools::Itertools; +use log::{debug, info}; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use std::marker::PhantomData; + +pub mod cell_manager; +pub mod keccak_packed_multi; +pub mod param; +pub mod table; +#[cfg(test)] +mod tests; +pub mod util; +/// Module for witness generation. +pub mod witness; + +/// Configuration parameters to define [`KeccakCircuitConfig`] +#[derive(Copy, Clone, Debug, Default)] +pub struct KeccakConfigParams { + /// The circuit degree, i.e., circuit has 2k rows + pub k: u32, + /// The number of rows to use for each round in the keccak_f permutation + pub rows_per_round: usize, +} + +/// KeccakConfig +#[derive(Clone, Debug)] +pub struct KeccakCircuitConfig { + // Bool. True on 1st row of each round. + q_enable: Column, + // Bool. True on 1st row. + q_first: Column, + // Bool. True on 1st row of all rounds except last rounds. + q_round: Column, + // Bool. True on 1st row of last rounds. + q_absorb: Column, + // Bool. True on 1st row of last rounds. + q_round_last: Column, + // Bool. True on 1st row of rounds which might contain inputs. + // Note: first NUM_WORDS_TO_ABSORB rounds of each chunk might contain inputs. + // It "might" contain inputs because it's possible that a round only have paddings. + q_input: Column, + // Bool. True on 1st row of all last input round. + q_input_last: Column, + + pub keccak_table: KeccakTable, + + cell_manager: CellManager, + round_cst: Column, + normalize_3: [TableColumn; 2], + normalize_4: [TableColumn; 2], + normalize_6: [TableColumn; 2], + chi_base_table: [TableColumn; 2], + pack_table: [TableColumn; 2], + + // config parameters for convenience + pub parameters: KeccakConfigParams, + + _marker: PhantomData, +} + +impl KeccakCircuitConfig { + /// Return a new KeccakCircuitConfig + pub fn new(meta: &mut ConstraintSystem, parameters: KeccakConfigParams) -> Self { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + + let q_enable = meta.fixed_column(); + let q_first = meta.fixed_column(); + let q_round = meta.fixed_column(); + let q_absorb = meta.fixed_column(); + let q_round_last = meta.fixed_column(); + let q_input = meta.fixed_column(); + let q_input_last = meta.fixed_column(); + let round_cst = meta.fixed_column(); + let keccak_table = KeccakTable::construct(meta); + + let is_final = keccak_table.is_enabled; + let hash_word = keccak_table.output; + + let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); + let normalize_4 = array_init::array_init(|_| meta.lookup_table_column()); + let normalize_6 = array_init::array_init(|_| meta.lookup_table_column()); + let chi_base_table = array_init::array_init(|_| meta.lookup_table_column()); + let pack_table = array_init::array_init(|_| meta.lookup_table_column()); + + let mut cell_manager = CellManager::new(num_rows_per_round); + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let mut total_lookup_counter = 0; + + let start_new_hash = |meta: &mut VirtualCells, rot| { + // A new hash is started when the previous hash is done or on the first row + meta.query_fixed(q_first, rot) + meta.query_advice(is_final, rot) + }; + + // Round constant + let mut round_cst_expr = 0.expr(); + meta.create_gate("Query round cst", |meta| { + round_cst_expr = meta.query_fixed(round_cst, Rotation::cur()); + vec![0u64.expr()] + }); + // State data + let mut s = vec![vec![0u64.expr(); 5]; 5]; + let mut s_next = vec![vec![0u64.expr(); 5]; 5]; + for i in 0..5 { + for j in 0..5 { + let cell = cell_manager.query_cell(meta); + s[i][j] = cell.expr(); + s_next[i][j] = cell.at_offset(meta, num_rows_per_round as i32).expr(); + } + } + // Absorb data + let absorb_from = cell_manager.query_cell(meta); + let absorb_data = cell_manager.query_cell(meta); + let absorb_result = cell_manager.query_cell(meta); + let mut absorb_from_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + let mut absorb_data_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + let mut absorb_result_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + for i in 0..NUM_WORDS_TO_ABSORB { + let rot = ((i + 1) * num_rows_per_round) as i32; + absorb_from_next[i] = absorb_from.at_offset(meta, rot).expr(); + absorb_data_next[i] = absorb_data.at_offset(meta, rot).expr(); + absorb_result_next[i] = absorb_result.at_offset(meta, rot).expr(); + } + + // Store the pre-state + let pre_s = s.clone(); + + // Absorb + // The absorption happening at the start of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 17 of the 24 rounds) a + // single word is absorbed so the work is spread out. The absorption is + // done simply by doing state + data and then normalizing the result to [0,1]. + // We also need to convert the input data into bytes to calculate the input data + // rlc. + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size = get_num_bits_per_absorb_lookup(k); + let input = absorb_from.expr() + absorb_data.expr(); + let absorb_fat = + split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); + cell_manager.start_region(); + let absorb_res = transform::expr( + "absorb", + meta, + &mut cell_manager, + &mut lookup_counter, + absorb_fat, + normalize_3, + true, + ); + cb.require_equal("absorb result", decode::expr(absorb_res), absorb_result.expr()); + info!("- Post absorb:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Squeeze + // The squeezing happening at the end of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 4 of the 24 rounds) a + // single word is converted to bytes. + cell_manager.start_region(); + let mut lookup_counter = 0; + // Potential optimization: could do multiple bytes per lookup + let packed_parts = + split::expr(meta, &mut cell_manager, &mut cb, absorb_data.expr(), 0, 8, false, None); + cell_manager.start_region(); + // input_bytes.len() = packed_parts.len() = 64 / 8 = 8 = NUM_BYTES_PER_WORD + let input_bytes = transform::expr( + "squeeze unpack", + meta, + &mut cell_manager, + &mut lookup_counter, + packed_parts, + pack_table.into_iter().rev().collect::>().try_into().unwrap(), + true, + ); + debug_assert_eq!(input_bytes.len(), NUM_BYTES_PER_WORD); + + // Padding data + cell_manager.start_region(); + let is_paddings = input_bytes.iter().map(|_| cell_manager.query_cell(meta)).collect_vec(); + info!("- Post padding:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Theta + // Calculate + // - `c[i] = s[i][0] + s[i][1] + s[i][2] + s[i][3] + s[i][4]` + // - `bc[i] = normalize(c)`. + // - `t[i] = bc[(i + 4) % 5] + rot(bc[(i + 1)% 5], 1)` + // This is done by splitting the bc values in parts in a way + // that allows us to also calculate the rotated value "for free". + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size_c = get_num_bits_per_theta_c_lookup(k); + let mut c_parts = Vec::new(); + for s in s.iter() { + // Calculate c and split into parts + let c = s[0].clone() + s[1].clone() + s[2].clone() + s[3].clone() + s[4].clone(); + c_parts.push(split::expr( + meta, + &mut cell_manager, + &mut cb, + c, + 1, + part_size_c, + false, + None, + )); + } + // Now calculate `bc` by normalizing `c` + cell_manager.start_region(); + let mut bc = Vec::new(); + for c in c_parts { + // Normalize c + bc.push(transform::expr( + "theta c", + meta, + &mut cell_manager, + &mut lookup_counter, + c, + normalize_6, + true, + )); + } + // Now do `bc[(i + 4) % 5] + rot(bc[(i + 1) % 5], 1)` using just expressions. + // We don't normalize the result here. We do it as part of the rho/pi step, even + // though we would only have to normalize 5 values instead of 25, because of the + // way the rho/pi and chi steps can be combined it's more efficient to + // do it there (the max value for chi is 4 already so that's the + // limiting factor). + let mut os = vec![vec![0u64.expr(); 5]; 5]; + for i in 0..5 { + let t = decode::expr(bc[(i + 4) % 5].clone()) + + decode::expr(rotate(bc[(i + 1) % 5].clone(), 1, part_size_c)); + for j in 0..5 { + os[i][j] = s[i][j].clone() + t.clone(); + } + } + s = os.clone(); + info!("- Post theta:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Rho/Pi + // For the rotation of rho/pi we split up the words like expected, but in a way + // that allows reusing the same parts in an optimal way for the chi step. + // We can save quite a few columns by not recombining the parts after rho/pi and + // re-splitting the words again before chi. Instead we do chi directly + // on the output parts of rho/pi. For rho/pi specically we do + // `s[j][2 * i + 3 * j) % 5] = normalize(rot(s[i][j], RHOM[i][j]))`. + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size = get_num_bits_per_base_chi_lookup(k); + // To combine the rho/pi/chi steps we have to ensure a specific layout so + // query those cells here first. + // For chi we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & s[(i+2)%5][j])`. `j` + // remains static but `i` is accessed in a wrap around manner. To do this using + // multiple rows with lookups in a way that doesn't require any + // extra additional cells or selectors we have to put all `s[i]`'s on the same + // row. This isn't that strong of a requirement actually because we the + // words are split into multipe parts, and so only the parts at the same + // position of those words need to be on the same row. + let target_word_sizes = target_part_sizes(part_size); + let num_word_parts = target_word_sizes.len(); + let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = array_init::array_init(|_| { + array_init::array_init(|_| array_init::array_init(|_| Vec::new())) + }); + let mut num_columns = 0; + let mut column_starts = [0usize; 3]; + for p in 0..3 { + column_starts[p] = cell_manager.start_region(); + let mut row_idx = 0; + num_columns = 0; + for j in 0..5 { + for _ in 0..num_word_parts { + for i in 0..5 { + rho_pi_chi_cells[p][i][j] + .push(cell_manager.query_cell_at_row(meta, row_idx)); + } + if row_idx == 0 { + num_columns += 1; + } + row_idx = (((row_idx as usize) + 1) % num_rows_per_round) as i32; + } + } + } + // Do the transformation, resulting in the word parts also being normalized. + let pi_region_start = cell_manager.start_region(); + let mut os_parts = vec![vec![Vec::new(); 5]; 5]; + for (j, os_part) in os_parts.iter_mut().enumerate() { + for i in 0..5 { + // Split s into parts + let s_parts = split_uniform::expr( + meta, + &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], + &mut cell_manager, + &mut cb, + s[i][j].clone(), + RHO_MATRIX[i][j], + part_size, + true, + ); + // Normalize the data to the target cells + let s_parts = transform_to::expr( + "rho/pi", + meta, + &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], + &mut lookup_counter, + s_parts.clone(), + normalize_4, + true, + ); + os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); + } + } + let pi_region_end = cell_manager.start_region(); + // Pi parts range checks + // To make the uniform stuff work we had to combine some parts together + // in new cells (see split_uniform). Here we make sure those parts are range + // checked. Potential improvement: Could combine multiple smaller parts + // in a single lookup but doesn't save that much. + for c in pi_region_start..pi_region_end { + meta.lookup("pi part range check", |_| { + vec![(cell_manager.columns()[c].expr.clone(), normalize_4[0])] + }); + lookup_counter += 1; + } + info!("- Post rho/pi:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Chi + // In groups of 5 columns, we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & + // s[(i+2)%5][j])` five times, on each row (no selector needed). + // This is calculated by making use of `CHI_BASE_LOOKUP_TABLE`. + let mut lookup_counter = 0; + let part_size_base = get_num_bits_per_base_chi_lookup(k); + for idx in 0..num_columns { + // First fetch the cells we wan to use + let mut input: [Expression; 5] = array_init::array_init(|_| 0.expr()); + let mut output: [Expression; 5] = array_init::array_init(|_| 0.expr()); + for c in 0..5 { + input[c] = cell_manager.columns()[column_starts[1] + idx * 5 + c].expr.clone(); + output[c] = cell_manager.columns()[column_starts[2] + idx * 5 + c].expr.clone(); + } + // Now calculate `a ^ ((~b) & c)` by doing `lookup[3 - 2*a + b - c]` + for i in 0..5 { + let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() + + input[(i + 1) % 5].clone() + - input[(i + 2) % 5].clone(); + let output = output[i].clone(); + meta.lookup("chi base", |_| { + vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] + }); + lookup_counter += 1; + } + } + // Now just decode the parts after the chi transformation done with the lookups + // above. + let mut os = vec![vec![0u64.expr(); 5]; 5]; + for (i, os) in os.iter_mut().enumerate() { + for (j, os) in os.iter_mut().enumerate() { + let mut parts = Vec::new(); + for idx in 0..num_word_parts { + parts.push(Part { + num_bits: part_size_base, + cell: rho_pi_chi_cells[2][i][j][idx].clone(), + expr: rho_pi_chi_cells[2][i][j][idx].expr(), + }); + } + *os = decode::expr(parts); + } + } + s = os.clone(); + + // iota + // Simply do the single xor on state [0][0]. + cell_manager.start_region(); + let part_size = get_num_bits_per_absorb_lookup(k); + let input = s[0][0].clone() + round_cst_expr.clone(); + let iota_parts = + split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); + cell_manager.start_region(); + // Could share columns with absorb which may end up using 1 lookup/column + // fewer... + s[0][0] = decode::expr(transform::expr( + "iota", + meta, + &mut cell_manager, + &mut lookup_counter, + iota_parts, + normalize_3, + true, + )); + // Final results stored in the next row + for i in 0..5 { + for j in 0..5 { + cb.require_equal("next row check", s[i][j].clone(), s_next[i][j].clone()); + } + } + info!("- Post chi:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + let mut lookup_counter = 0; + cell_manager.start_region(); + + // Squeeze data + let squeeze_from = cell_manager.query_cell(meta); + let mut squeeze_from_prev = vec![0u64.expr(); NUM_WORDS_TO_SQUEEZE]; + for (idx, squeeze_from_prev) in squeeze_from_prev.iter_mut().enumerate() { + let rot = (-(idx as i32) - 1) * num_rows_per_round as i32; + *squeeze_from_prev = squeeze_from.at_offset(meta, rot).expr(); + } + // Squeeze + // The squeeze happening at the end of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 4 of the 24 rounds) a + // single word is converted to bytes. + // Potential optimization: could do multiple bytes per lookup + cell_manager.start_region(); + // Unpack a single word into bytes (for the squeeze) + // Potential optimization: could do multiple bytes per lookup + let squeeze_from_parts = + split::expr(meta, &mut cell_manager, &mut cb, squeeze_from.expr(), 0, 8, false, None); + cell_manager.start_region(); + let squeeze_bytes = transform::expr( + "squeeze unpack", + meta, + &mut cell_manager, + &mut lookup_counter, + squeeze_from_parts, + pack_table.into_iter().rev().collect::>().try_into().unwrap(), + true, + ); + info!("- Post squeeze:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // The round constraints that we've been building up till now + meta.create_gate("round", |meta| cb.gate(meta.query_fixed(q_round, Rotation::cur()))); + + // Absorb + meta.create_gate("absorb", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let continue_hash = not::expr(start_new_hash(meta, Rotation::cur())); + let absorb_positions = get_absorb_positions(); + let mut a_slice = 0; + for j in 0..5 { + for i in 0..5 { + if absorb_positions.contains(&(i, j)) { + cb.condition(continue_hash.clone(), |cb| { + cb.require_equal( + "absorb verify input", + absorb_from_next[a_slice].clone(), + pre_s[i][j].clone(), + ); + }); + cb.require_equal( + "absorb result copy", + select::expr( + continue_hash.clone(), + absorb_result_next[a_slice].clone(), + absorb_data_next[a_slice].clone(), + ), + s_next[i][j].clone(), + ); + a_slice += 1; + } else { + cb.require_equal( + "absorb state copy", + pre_s[i][j].clone() * continue_hash.clone(), + s_next[i][j].clone(), + ); + } + } + } + cb.gate(meta.query_fixed(q_absorb, Rotation::cur())) + }); + + // Collect the bytes that are spread out over previous rows + let mut hash_bytes = Vec::new(); + for i in 0..NUM_WORDS_TO_SQUEEZE { + for byte in squeeze_bytes.iter() { + let rot = (-(i as i32) - 1) * num_rows_per_round as i32; + hash_bytes.push(byte.cell.at_offset(meta, rot).expr()); + } + } + + // Squeeze + meta.create_gate("squeeze", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let start_new_hash = start_new_hash(meta, Rotation::cur()); + // The words to squeeze + let hash_words: Vec<_> = + pre_s.into_iter().take(4).map(|a| a[0].clone()).take(4).collect(); + // Verify if we converted the correct words to bytes on previous rows + for (idx, word) in hash_words.iter().enumerate() { + cb.condition(start_new_hash.clone(), |cb| { + cb.require_equal( + "squeeze verify packed", + word.clone(), + squeeze_from_prev[idx].clone(), + ); + }); + } + + let hash_bytes_le = hash_bytes.into_iter().rev().collect::>(); + cb.condition(start_new_hash, |cb| { + cb.require_equal_word( + "output check", + word::Word32::new(hash_bytes_le.try_into().expect("32 limbs")).to_word(), + hash_word.map(|col| meta.query_advice(col, Rotation::cur())), + ); + }); + cb.gate(meta.query_fixed(q_round_last, Rotation::cur())) + }); + + // Some general input checks + meta.create_gate("input checks", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + cb.require_boolean("boolean is_final", meta.query_advice(is_final, Rotation::cur())); + cb.gate(meta.query_fixed(q_enable, Rotation::cur())) + }); + + // Enforce fixed values on the first row + meta.create_gate("first row", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + cb.require_zero( + "is_final needs to be disabled on the first row", + meta.query_advice(is_final, Rotation::cur()), + ); + cb.gate(meta.query_fixed(q_first, Rotation::cur())) + }); + + // some utility query functions + let q = |col: Column, meta: &mut VirtualCells<'_, F>| { + meta.query_fixed(col, Rotation::cur()) + }; + /* + eg: + data: + get_num_rows_per_round: 18 + input: "12345678abc" + table: + Note[1]: be careful: is_paddings is not column here! It is [Cell; 8] and it will be constrained later. + Note[2]: only first row of each round has constraints on bytes_left. This example just shows how witnesses are filled. + offset word_value bytes_left is_paddings q_enable q_input_last + 18 0x87654321 11 0 1 0 // 1st round begin + 19 0 10 0 0 0 + 20 0 9 0 0 0 + 21 0 8 0 0 0 + 22 0 7 0 0 0 + 23 0 6 0 0 0 + 24 0 5 0 0 0 + 25 0 4 0 0 0 + 26 0 4 NA 0 0 + ... + 35 0 4 NA 0 0 // 1st round end + 36 0xcba 3 0 1 1 // 2nd round begin + 37 0 2 0 0 0 + 38 0 1 0 0 0 + 39 0 0 1 0 0 + 40 0 0 1 0 0 + 41 0 0 1 0 0 + 42 0 0 1 0 0 + 43 0 0 1 0 0 + */ + + meta.create_gate("word_value", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let masked_input_bytes = input_bytes + .iter() + .zip_eq(is_paddings.clone()) + .map(|(input_byte, is_padding)| { + input_byte.expr.clone() * not::expr(is_padding.expr().clone()) + }) + .collect_vec(); + let input_word = from_bytes::expr(&masked_input_bytes); + cb.require_equal( + "word value", + input_word, + meta.query_advice(keccak_table.word_value, Rotation::cur()), + ); + cb.gate(q(q_input, meta)) + }); + meta.create_gate("bytes_left", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let bytes_left_expr = meta.query_advice(keccak_table.bytes_left, Rotation::cur()); + + // bytes_left is 0 in the absolute first `rows_per_round` of the entire circuit, i.e., the first dummy round. + cb.condition(q(q_first, meta), |cb| { + cb.require_zero( + "bytes_left needs to be zero on the absolute first dummy round", + meta.query_advice(keccak_table.bytes_left, Rotation::cur()), + ); + }); + // is_final ==> bytes_left == 0. + // Note: is_final = true only in the last round, which doesn't have any data to absorb. + cb.condition(meta.query_advice(is_final, Rotation::cur()), |cb| { + cb.require_zero("bytes_left should be 0 when is_final", bytes_left_expr.clone()); + }); + // q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] + cb.condition(q(q_input, meta), |cb| { + // word_len = NUM_BYTES_PER_WORD - sum(is_paddings) + let word_len = NUM_BYTES_PER_WORD.expr() - sum::expr(is_paddings.clone()); + let bytes_left_next_expr = + meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); + cb.require_equal( + "if there is a word in this round, bytes_left[curr + num_rows_per_round] + word_len == bytes_left[curr]", + bytes_left_expr.clone(), + bytes_left_next_expr + word_len, + ); + }); + // Logically here we want !q_input[cur] && !start_new_hash(cur) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] + // In practice, in order to save a degree we use !(q_input[cur] ^ start_new_hash(cur)) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] + // When q_input[cur] is true, the above constraint q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] has + // already been enabled. Even is_final in start_new_hash(cur) is true, it's just over-constrainted. + // Note: At the first row of any round except the last round, is_final could be either true or false. + cb.condition(not::expr(q(q_input, meta) + start_new_hash(meta, Rotation::cur())), |cb| { + let bytes_left_next_expr = + meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); + cb.require_equal( + "if no input and not starting new hash, bytes_left should keep the same", + bytes_left_expr, + bytes_left_next_expr, + ); + }); + + cb.gate(q(q_enable, meta)) + }); + + // Enforce logic for when this block is the last block for a hash + let last_is_padding_in_block = is_paddings.last().unwrap().at_offset( + meta, + -(((NUM_ROUNDS + 1 - NUM_WORDS_TO_ABSORB) * num_rows_per_round) as i32), + ); + meta.create_gate("is final", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + // All absorb rows except the first row + cb.condition( + meta.query_fixed(q_absorb, Rotation::cur()) + - meta.query_fixed(q_first, Rotation::cur()), + |cb| { + cb.require_equal( + "is_final needs to be the same as the last is_padding in the block", + meta.query_advice(is_final, Rotation::cur()), + last_is_padding_in_block.expr(), + ); + }, + ); + // For all the rows of a round, only the first row can have `is_final == 1`. + cb.condition( + (1..num_rows_per_round as i32) + .map(|i| meta.query_fixed(q_enable, Rotation(-i))) + .fold(0.expr(), |acc, elem| acc + elem), + |cb| { + cb.require_zero( + "is_final only when q_enable", + meta.query_advice(is_final, Rotation::cur()), + ); + }, + ); + cb.gate(1.expr()) + }); + + // Padding + // May be cleaner to do this padding logic in the byte conversion lookup but + // currently easier to do it like this. + let prev_is_padding = + is_paddings.last().unwrap().at_offset(meta, -(num_rows_per_round as i32)); + meta.create_gate("padding", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let q_input = meta.query_fixed(q_input, Rotation::cur()); + let q_input_last = meta.query_fixed(q_input_last, Rotation::cur()); + + // All padding selectors need to be boolean + for is_padding in is_paddings.iter() { + cb.condition(meta.query_fixed(q_enable, Rotation::cur()), |cb| { + cb.require_boolean("is_padding boolean", is_padding.expr()); + }); + } + // This last padding selector will be used on the first round row so needs to be + // zero + cb.condition(meta.query_fixed(q_absorb, Rotation::cur()), |cb| { + cb.require_zero( + "last is_padding should be zero on absorb rows", + is_paddings.last().unwrap().expr(), + ); + }); + // Now for each padding selector + for idx in 0..is_paddings.len() { + // Previous padding selector can be on the previous row + let is_padding_prev = + if idx == 0 { prev_is_padding.expr() } else { is_paddings[idx - 1].expr() }; + let is_first_padding = is_paddings[idx].expr() - is_padding_prev.clone(); + + // Check padding transition 0 -> 1 done only once + cb.condition(q_input.expr(), |cb| { + cb.require_boolean("padding step boolean", is_first_padding.clone()); + }); + + // Padding start/intermediate/end byte checks + if idx == is_paddings.len() - 1 { + // These can be combined in the future, but currently this would increase the + // degree by one Padding start/intermediate byte, all + // padding rows except the last one + cb.condition( + and::expr([q_input.expr() - q_input_last.expr(), is_paddings[idx].expr()]), + |cb| { + // Input bytes need to be zero, or one if this is the first padding byte + cb.require_equal( + "padding start/intermediate byte last byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr(), + ); + }, + ); + // Padding start/end byte, only on the last padding row + cb.condition(and::expr([q_input_last.expr(), is_paddings[idx].expr()]), |cb| { + // The input byte needs to be 128, unless it's also the first padding + // byte then it's 129 + cb.require_equal( + "padding start/end byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr() + 128.expr(), + ); + }); + } else { + // Padding start/intermediate byte + cb.condition(and::expr([q_input.expr(), is_paddings[idx].expr()]), |cb| { + // Input bytes need to be zero, or one if this is the first padding byte + cb.require_equal( + "padding start/intermediate byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr(), + ); + }); + } + } + cb.gate(1.expr()) + }); + + info!("Degree: {}", meta.degree()); + info!("Minimum rows: {}", meta.minimum_rows()); + info!("Total Lookups: {}", total_lookup_counter); + #[cfg(feature = "display")] + { + println!("Total Keccak Columns: {}", cell_manager.get_width()); + std::env::set_var("KECCAK_ADVICE_COLUMNS", cell_manager.get_width().to_string()); + } + #[cfg(not(feature = "display"))] + info!("Total Keccak Columns: {}", cell_manager.get_width()); + info!("num unused cells: {}", cell_manager.get_num_unused_cells()); + info!("part_size absorb: {}", get_num_bits_per_absorb_lookup(k)); + info!("part_size theta: {}", get_num_bits_per_theta_c_lookup(k)); + info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k)); + info!("part_size theta t: {}", get_num_bits_per_lookup(4, k)); + info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup(k)); + info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup(k)); + info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup(k))); + + KeccakCircuitConfig { + q_enable, + q_first, + q_round, + q_absorb, + q_round_last, + q_input, + q_input_last, + keccak_table, + cell_manager, + round_cst, + normalize_3, + normalize_4, + normalize_6, + chi_base_table, + pack_table, + parameters, + _marker: PhantomData, + } + } +} + +#[derive(Clone)] +pub struct KeccakAssignedRow<'v, F: Field> { + pub is_final: KeccakAssignedValue<'v, F>, + pub hash_lo: KeccakAssignedValue<'v, F>, + pub hash_hi: KeccakAssignedValue<'v, F>, + pub bytes_left: KeccakAssignedValue<'v, F>, + pub word_value: KeccakAssignedValue<'v, F>, + pub _marker: PhantomData<&'v ()>, +} + +impl KeccakCircuitConfig { + /// Returns vector of `is_final`, `length`, `hash.lo`, `hash.hi` for assigned rows + pub fn assign<'v>( + &self, + region: &mut Region, + witness: &[KeccakRow], + ) -> Vec> { + witness + .iter() + .enumerate() + .map(|(offset, keccak_row)| self.set_row(region, offset, keccak_row)) + .collect() + } + + /// Output is `is_final`, `length`, `hash.lo`, `hash.hi` at that row + pub fn set_row<'v>( + &self, + region: &mut Region, + offset: usize, + row: &KeccakRow, + ) -> KeccakAssignedRow<'v, F> { + // Fixed selectors + for (_, column, value) in &[ + ("q_enable", self.q_enable, F::from(row.q_enable)), + ("q_first", self.q_first, F::from(offset == 0)), + ("q_round", self.q_round, F::from(row.q_round)), + ("q_round_last", self.q_round_last, F::from(row.q_round_last)), + ("q_absorb", self.q_absorb, F::from(row.q_absorb)), + ("q_input", self.q_input, F::from(row.q_input)), + ("q_input_last", self.q_input_last, F::from(row.q_input_last)), + ] { + raw_assign_fixed(region, *column, offset, *value); + } + + // Keccak data + let [is_final, hash_lo, hash_hi, bytes_left, word_value] = [ + ("is_final", self.keccak_table.is_enabled, Value::known(F::from(row.is_final))), + ("hash_lo", self.keccak_table.output.lo(), row.hash.lo()), + ("hash_hi", self.keccak_table.output.hi(), row.hash.hi()), + ("bytes_left", self.keccak_table.bytes_left, Value::known(row.bytes_left)), + ("word_value", self.keccak_table.word_value, Value::known(row.word_value)), + ] + .map(|(_name, column, value)| raw_assign_advice(region, column, offset, value)); + + // Cell values + row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { + raw_assign_advice(region, column.advice, offset, Value::known(*bit)); + }); + + // Round constant + raw_assign_fixed(region, self.round_cst, offset, row.round_cst); + + KeccakAssignedRow { + is_final, + hash_lo, + hash_hi, + bytes_left, + word_value, + _marker: PhantomData, + } + } + + pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { + load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64, k)?; + load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64, k)?; + load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64, k)?; + load_lookup_table( + layouter, + "chi base", + &self.chi_base_table, + get_num_bits_per_base_chi_lookup(k), + &CHI_BASE_LOOKUP_TABLE, + )?; + load_pack_table(layouter, &self.pack_table) + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/param.rs b/hashes/zkevm/src/keccak/vanilla/param.rs new file mode 100644 index 00000000..abecd264 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/param.rs @@ -0,0 +1,68 @@ +#![allow(dead_code)] +pub(crate) const MAX_DEGREE: usize = 3; +pub(crate) const ABSORB_LOOKUP_RANGE: usize = 3; +pub(crate) const THETA_C_LOOKUP_RANGE: usize = 6; +pub(crate) const RHO_PI_LOOKUP_RANGE: usize = 4; +pub(crate) const CHI_BASE_LOOKUP_RANGE: usize = 5; + +pub const NUM_BITS_PER_BYTE: usize = 8; +pub const NUM_BYTES_PER_WORD: usize = 8; +pub const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; +pub const KECCAK_WIDTH: usize = 5 * 5; +pub const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; +pub const NUM_ROUNDS: usize = 24; +pub const NUM_WORDS_TO_ABSORB: usize = 17; +pub const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub const NUM_WORDS_TO_SQUEEZE: usize = 4; +pub const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; +pub const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; +pub const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; +pub const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; +// pub(crate) const THETA_C_WIDTH: usize = 5 * NUM_BITS_PER_WORD; +pub(crate) const RHO_MATRIX: [[usize; 5]; 5] = [ + [0, 36, 3, 41, 18], + [1, 44, 10, 45, 2], + [62, 6, 43, 15, 61], + [28, 55, 25, 21, 56], + [27, 20, 39, 8, 14], +]; +pub(crate) const ROUND_CST: [u64; NUM_ROUNDS + 1] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808a, + 0x8000000080008000, + 0x000000000000808b, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008a, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000a, + 0x000000008000808b, + 0x800000000000008b, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800a, + 0x800000008000000a, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + 0x0000000000000000, // absorb round +]; +// Bit positions that have a non-zero value in `IOTA_ROUND_CST`. +// pub(crate) const ROUND_CST_BIT_POS: [usize; 7] = [0, 1, 3, 7, 15, 31, 63]; + +// The number of bits used in the sparse word representation per bit +pub(crate) const BIT_COUNT: usize = 3; +// The base of the bit in the sparse word representation +pub(crate) const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); + +// `a ^ ((~b) & c)` is calculated by doing `lookup[3 - 2*a + b - c]` +pub(crate) const CHI_BASE_LOOKUP_TABLE: [u8; 5] = [0, 1, 1, 0, 0]; +// `a ^ ((~b) & c) ^ d` is calculated by doing `lookup[5 - 2*a - b + c - 2*d]` +// pub(crate) const CHI_EXT_LOOKUP_TABLE: [u8; 7] = [0, 0, 1, 1, 0, 0, 1]; diff --git a/hashes/zkevm/src/keccak/vanilla/table.rs b/hashes/zkevm/src/keccak/vanilla/table.rs new file mode 100644 index 00000000..2249005d --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/table.rs @@ -0,0 +1,126 @@ +use super::{param::*, util::*}; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::{Error, TableColumn}, + }, + util::eth_types::Field, +}; +use itertools::Itertools; + +/// Returns how many bits we can process in a single lookup given the range of +/// values the bit can have and the height of the circuit. +pub fn get_num_bits_per_lookup(range: usize, k: u32) -> usize { + let num_unusable_rows = 31; + let mut num_bits = 1; + while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(k) { + num_bits += 1; + } + num_bits as usize +} + +/// Loads a normalization table with the given parameters +pub(crate) fn load_normalize_table( + layouter: &mut impl Layouter, + name: &str, + tables: &[TableColumn; 2], + range: u64, + k: u32, +) -> Result<(), Error> { + let part_size = get_num_bits_per_lookup(range as usize, k); + layouter.assign_table( + || format!("{name} table"), + |mut table| { + for (offset, perm) in + (0..part_size).map(|_| 0u64..range).multi_cartesian_product().enumerate() + { + let mut input = 0u64; + let mut output = 0u64; + let mut factor = 1u64; + for input_part in perm.iter() { + input += input_part * factor; + output += (input_part & 1) * factor; + factor *= BIT_SIZE as u64; + } + table.assign_cell( + || format!("{name} input"), + tables[0], + offset, + || Value::known(F::from(input)), + )?; + table.assign_cell( + || format!("{name} output"), + tables[1], + offset, + || Value::known(F::from(output)), + )?; + } + Ok(()) + }, + ) +} + +/// Loads the byte packing table +pub(crate) fn load_pack_table( + layouter: &mut impl Layouter, + tables: &[TableColumn; 2], +) -> Result<(), Error> { + layouter.assign_table( + || "pack table", + |mut table| { + for (offset, idx) in (0u64..256).enumerate() { + table.assign_cell( + || "unpacked", + tables[0], + offset, + || Value::known(F::from(idx)), + )?; + let packed: F = pack(&into_bits(&[idx as u8])); + table.assign_cell(|| "packed", tables[1], offset, || Value::known(packed))?; + } + Ok(()) + }, + ) +} + +/// Loads a lookup table +pub(crate) fn load_lookup_table( + layouter: &mut impl Layouter, + name: &str, + tables: &[TableColumn; 2], + part_size: usize, + lookup_table: &[u8], +) -> Result<(), Error> { + layouter.assign_table( + || format!("{name} table"), + |mut table| { + for (offset, perm) in (0..part_size) + .map(|_| 0..lookup_table.len() as u64) + .multi_cartesian_product() + .enumerate() + { + let mut input = 0u64; + let mut output = 0u64; + let mut factor = 1u64; + for input_part in perm.iter() { + input += input_part * factor; + output += (lookup_table[*input_part as usize] as u64) * factor; + factor *= BIT_SIZE as u64; + } + table.assign_cell( + || format!("{name} input"), + tables[0], + offset, + || Value::known(F::from(input)), + )?; + table.assign_cell( + || format!("{name} output"), + tables[1], + offset, + || Value::known(F::from(output)), + )?; + } + Ok(()) + }, + ) +} diff --git a/hashes/zkevm/src/keccak/vanilla/tests.rs b/hashes/zkevm/src/keccak/vanilla/tests.rs new file mode 100644 index 00000000..f79aa4b7 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/tests.rs @@ -0,0 +1,293 @@ +use super::{witness::*, *}; +use crate::halo2_proofs::{ + circuit::SimpleFloorPlanner, + dev::MockProver, + halo2curves::bn256::Fr, + halo2curves::bn256::{Bn256, G1Affine}, + plonk::Circuit, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG, ParamsVerifierKZG}, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, +}; +use halo2_base::{ + halo2_proofs::halo2curves::ff::FromUniformBytes, utils::value_to_option, SKIP_FIRST_PASS, +}; +use rand_core::OsRng; +use sha3::{Digest, Keccak256}; +use test_case::test_case; + +/// KeccakCircuit +#[derive(Default, Clone, Debug)] +pub struct KeccakCircuit { + config: KeccakConfigParams, + inputs: Vec>, + num_rows: Option, + verify_output: bool, + _marker: PhantomData, +} + +#[cfg(any(feature = "test", test))] +impl Circuit for KeccakCircuit { + type Config = KeccakCircuitConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = KeccakConfigParams; + + fn params(&self) -> Self::Params { + self.config + } + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase + meta.advice_column(); + + KeccakCircuitConfig::new(meta, params) + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let params = config.parameters; + config.load_aux_tables(&mut layouter, params.k)?; + let mut first_pass = SKIP_FIRST_PASS; + layouter.assign_region( + || "keccak circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + let (witness, _) = multi_keccak( + &self.inputs, + self.num_rows.map(|nr| get_keccak_capacity(nr, params.rows_per_round)), + params, + ); + let assigned_rows = config.assign(&mut region, &witness); + if self.verify_output { + self.verify_output_witnesses(&assigned_rows); + self.verify_input_witnesses(&assigned_rows); + } + Ok(()) + }, + )?; + + Ok(()) + } +} + +impl KeccakCircuit { + /// Creates a new circuit instance + pub fn new( + config: KeccakConfigParams, + num_rows: Option, + inputs: Vec>, + verify_output: bool, + ) -> Self { + KeccakCircuit { config, inputs, num_rows, _marker: PhantomData, verify_output } + } + + fn verify_output_witnesses(&self, assigned_rows: &[KeccakAssignedRow]) { + let mut input_offset = 0; + // only look at last row in each round + // first round is dummy, so ignore + // only look at last round per absorb of RATE_IN_BITS + for assigned_row in + assigned_rows.iter().step_by(self.config.rows_per_round).step_by(NUM_ROUNDS + 1).skip(1) + { + let KeccakAssignedRow { is_final, hash_lo, hash_hi, .. } = assigned_row.clone(); + let is_final_val = extract_value(is_final).ne(&F::ZERO); + let hash_lo_val = extract_u128(hash_lo); + let hash_hi_val = extract_u128(hash_hi); + + if input_offset < self.inputs.len() && is_final_val { + // out is in big endian. + let out = Keccak256::digest(&self.inputs[input_offset]); + let lo = u128::from_be_bytes(out[16..].try_into().unwrap()); + let hi = u128::from_be_bytes(out[..16].try_into().unwrap()); + assert_eq!(lo, hash_lo_val); + assert_eq!(hi, hash_hi_val); + input_offset += 1; + } + } + } + + fn verify_input_witnesses(&self, assigned_rows: &[KeccakAssignedRow]) { + let rows_per_round = self.config.rows_per_round; + let mut input_offset = 0; + let mut input_byte_offset = 0; + // first round is dummy, so ignore + for absorb_chunk in &assigned_rows.chunks(rows_per_round).skip(1).chunks(NUM_ROUNDS + 1) { + let mut absorbed = false; + for (round_idx, assigned_rows) in absorb_chunk.enumerate() { + for (row_idx, assigned_row) in assigned_rows.iter().enumerate() { + let KeccakAssignedRow { is_final, word_value, bytes_left, .. } = + assigned_row.clone(); + let is_final_val = extract_value(is_final).ne(&F::ZERO); + let word_value_val = extract_u128(word_value); + let bytes_left_val = extract_u128(bytes_left); + // Padded inputs - all empty. + if input_offset >= self.inputs.len() { + assert_eq!(word_value_val, 0); + assert_eq!(bytes_left_val, 0); + continue; + } + let input_len = self.inputs[input_offset].len(); + if round_idx == NUM_ROUNDS && row_idx == 0 && is_final_val { + absorbed = true; + } + if row_idx == 0 { + assert_eq!(bytes_left_val, input_len as u128 - input_byte_offset as u128); + // Only these rows could contain inputs. + let end = if round_idx < NUM_WORDS_TO_ABSORB { + std::cmp::min(input_byte_offset + NUM_BYTES_PER_WORD, input_len) + } else { + input_byte_offset + }; + let mut expected_val_le_bytes = + self.inputs[input_offset][input_byte_offset..end].to_vec().clone(); + expected_val_le_bytes.resize(NUM_BYTES_PER_WORD, 0); + assert_eq!( + word_value_val, + u64::from_le_bytes(expected_val_le_bytes.try_into().unwrap()) as u128, + ); + input_byte_offset = end; + } + } + } + if absorbed { + input_offset += 1; + input_byte_offset = 0; + } + } + } +} + +fn verify>( + config: KeccakConfigParams, + inputs: Vec>, + _success: bool, +) { + let k = config.k; + let circuit = KeccakCircuit::new(config, Some(2usize.pow(k) - 109), inputs, true); + + let prover = MockProver::::run(k, &circuit, vec![]).unwrap(); + prover.assert_satisfied(); +} + +fn extract_value(assigned_value: KeccakAssignedValue) -> F { + #[cfg(feature = "halo2-axiom")] + let assigned = **value_to_option(assigned_value.value()).unwrap(); + #[cfg(not(feature = "halo2-axiom"))] + let assigned = *value_to_option(assigned_value.value()).unwrap(); + match assigned { + halo2_base::halo2_proofs::plonk::Assigned::Zero => F::ZERO, + halo2_base::halo2_proofs::plonk::Assigned::Trivial(f) => f, + _ => panic!("value should be trival"), + } +} + +fn extract_u128(assigned_value: KeccakAssignedValue) -> u128 { + let le_bytes = extract_value(assigned_value).to_bytes_le(); + let hi = u128::from_le_bytes(le_bytes[16..].try_into().unwrap()); + assert_eq!(hi, 0); + u128::from_le_bytes(le_bytes[..16].try_into().unwrap()) +} + +#[test_case(14, 28; "k: 14, rows_per_round: 28")] +#[test_case(12, 5; "k: 12, rows_per_round: 5")] +fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { + let _ = env_logger::builder().is_test(true).try_init(); + { + // First input is empty. + let inputs = vec![ + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); + } + { + // First input is not empty. + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + ]; + verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); + } +} + +#[test_case(14, 25 ; "k: 14, rows_per_round: 25")] +#[test_case(18, 9 ; "k: 18, rows_per_round: 9")] +fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { + let _ = env_logger::builder().is_test(true).try_init(); + + let params = ParamsKZG::::setup(k, OsRng); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + ]; + let circuit = KeccakCircuit::new( + KeccakConfigParams { k, rows_per_round }, + Some(2usize.pow(k)), + inputs, + false, + ); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let verifier_params: ParamsVerifierKZG = params.verifier_params().clone(); + let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); + + let start = std::time::Instant::now(); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255>, + _, + >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("proof generation should not fail"); + let proof = transcript.finalize(); + dbg!(start.elapsed()); + + let mut verifier_transcript = Blake2bRead::<_, G1Affine, Challenge255<_>>::init(&proof[..]); + let strategy = SingleStrategy::new(¶ms); + + verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(&verifier_params, pk.get_vk(), strategy, &[&[]], &mut verifier_transcript) + .expect("failed to verify bench circuit"); +} diff --git a/hashes/zkevm/src/keccak/vanilla/util.rs b/hashes/zkevm/src/keccak/vanilla/util.rs new file mode 100644 index 00000000..f76d7099 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/util.rs @@ -0,0 +1,251 @@ +//! Utility traits, functions used in the crate. +use super::param::*; +use crate::util::eth_types::{Field, ToScalar, Word}; + +/// Description of which bits (positions) a part contains +#[derive(Clone, Debug)] +pub struct PartInfo { + /// The bit positions of the part + pub bits: Vec, +} + +/// Description of how a word is split into parts +#[derive(Clone, Debug)] +pub struct WordParts { + /// The parts of the word + pub parts: Vec, +} + +impl WordParts { + /// Returns a description of how a word will be split into parts + pub fn new(part_size: usize, rot: usize, normalize: bool) -> Self { + let mut bits = (0usize..64).collect::>(); + bits.rotate_right(rot); + + let mut parts = Vec::new(); + let mut rot_idx = 0; + + let mut idx = 0; + let target_sizes = if normalize { + // After the rotation we want the parts of all the words to be at the same + // positions + target_part_sizes(part_size) + } else { + // Here we only care about minimizing the number of parts + let num_parts_a = rot / part_size; + let partial_part_a = rot % part_size; + + let num_parts_b = (64 - rot) / part_size; + let partial_part_b = (64 - rot) % part_size; + + let mut part_sizes = vec![part_size; num_parts_a]; + if partial_part_a > 0 { + part_sizes.push(partial_part_a); + } + + part_sizes.extend(vec![part_size; num_parts_b]); + if partial_part_b > 0 { + part_sizes.push(partial_part_b); + } + + part_sizes + }; + // Split into parts bit by bit + for part_size in target_sizes { + let mut num_consumed = 0; + while num_consumed < part_size { + let mut part_bits: Vec = Vec::new(); + while num_consumed < part_size { + if !part_bits.is_empty() && bits[idx] == 0 { + break; + } + if bits[idx] == 0 { + rot_idx = parts.len(); + } + part_bits.push(bits[idx]); + idx += 1; + num_consumed += 1; + } + parts.push(PartInfo { bits: part_bits }); + } + } + + debug_assert_eq!(get_rotate_count(rot, part_size), rot_idx); + + parts.rotate_left(rot_idx); + debug_assert_eq!(parts[0].bits[0], 0); + + Self { parts } + } +} + +/// Rotates a word that was split into parts to the right +pub fn rotate(parts: Vec, count: usize, part_size: usize) -> Vec { + let mut rotated_parts = parts; + rotated_parts.rotate_right(get_rotate_count(count, part_size)); + rotated_parts +} + +/// Rotates a word that was split into parts to the left +pub fn rotate_rev(parts: Vec, count: usize, part_size: usize) -> Vec { + let mut rotated_parts = parts; + rotated_parts.rotate_left(get_rotate_count(count, part_size)); + rotated_parts +} + +/// Rotates bits left +pub fn rotate_left(bits: &[u8], count: usize) -> [u8; NUM_BITS_PER_WORD] { + let mut rotated = bits.to_vec(); + rotated.rotate_left(count); + rotated.try_into().unwrap() +} + +/// The words that absorb data +pub fn get_absorb_positions() -> Vec<(usize, usize)> { + let mut absorb_positions = Vec::new(); + for j in 0..5 { + for i in 0..5 { + if i + j * 5 < 17 { + absorb_positions.push((i, j)); + } + } + } + absorb_positions +} + +/// Converts bytes into bits +pub fn into_bits(bytes: &[u8]) -> Vec { + let mut bits: Vec = vec![0; bytes.len() * 8]; + for (byte_idx, byte) in bytes.iter().enumerate() { + for idx in 0u64..8 { + bits[byte_idx * 8 + (idx as usize)] = (*byte >> idx) & 1; + } + } + bits +} + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word +pub fn pack(bits: &[u8]) -> F { + pack_with_base(bits, BIT_SIZE) +} + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word with the +/// specified bit base +pub fn pack_with_base(bits: &[u8], base: usize) -> F { + let base = F::from(base as u64); + bits.iter().rev().fold(F::ZERO, |acc, &bit| acc * base + F::from(bit as u64)) +} + +/// Decodes the bits using the position data found in the part info +pub fn pack_part(bits: &[u8], info: &PartInfo) -> u64 { + info.bits + .iter() + .rev() + .fold(0u64, |acc, &bit_pos| acc * (BIT_SIZE as u64) + (bits[bit_pos] as u64)) +} + +/// Unpack a sparse keccak word into bits in the range [0,BIT_SIZE[ +pub fn unpack(packed: F) -> [u8; NUM_BITS_PER_WORD] { + let mut bits = [0; NUM_BITS_PER_WORD]; + let packed = Word::from_little_endian(packed.to_bytes_le().as_ref()); + let mask = Word::from(BIT_SIZE - 1); + for (idx, bit) in bits.iter_mut().enumerate() { + *bit = ((packed >> (idx * BIT_COUNT)) & mask).as_u32() as u8; + } + debug_assert_eq!(pack::(&bits), packed.to_scalar().unwrap()); + bits +} + +/// Pack bits stored in a u64 value into a sparse keccak word +pub fn pack_u64(value: u64) -> F { + pack(&((0..NUM_BITS_PER_WORD).map(|i| ((value >> i) & 1) as u8).collect::>())) +} + +/// Calculates a ^ b with a and b field elements +pub fn field_xor(a: F, b: F) -> F { + let mut bytes = [0u8; 32]; + for (idx, (a, b)) in a.to_bytes_le().into_iter().zip(b.to_bytes_le()).enumerate() { + bytes[idx] = a ^ b; + } + F::from_bytes_le(&bytes) +} + +/// Returns the size (in bits) of each part size when splitting up a keccak word +/// in parts of `part_size` +pub fn target_part_sizes(part_size: usize) -> Vec { + let num_full_chunks = NUM_BITS_PER_WORD / part_size; + let partial_chunk_size = NUM_BITS_PER_WORD % part_size; + let mut part_sizes = vec![part_size; num_full_chunks]; + if partial_chunk_size > 0 { + part_sizes.push(partial_chunk_size); + } + part_sizes +} + +/// Gets the rotation count in parts +pub fn get_rotate_count(count: usize, part_size: usize) -> usize { + (count + part_size - 1) / part_size +} + +/// Encodes the data using rlc +pub mod compose_rlc { + use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; + + #[allow(dead_code)] + pub(crate) fn expr(expressions: &[Expression], r: F) -> Expression { + let mut rlc = expressions[0].clone(); + let mut multiplier = r; + for expression in expressions[1..].iter() { + rlc = rlc + expression.clone() * multiplier; + multiplier *= r; + } + rlc + } +} + +/// Packs bits into bytes +pub mod to_bytes { + use crate::util::eth_types::Field; + use crate::util::expression::Expr; + use halo2_base::halo2_proofs::plonk::Expression; + + pub fn expr(bits: &[Expression]) -> Vec> { + debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); + let mut bytes = Vec::new(); + for byte_bits in bits.chunks(8) { + let mut value = 0.expr(); + let mut multiplier = F::ONE; + for byte in byte_bits.iter() { + value = value + byte.expr() * multiplier; + multiplier *= F::from(2); + } + bytes.push(value); + } + bytes + } + + pub fn value(bits: &[u8]) -> Vec { + debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); + let mut bytes = Vec::new(); + for byte_bits in bits.chunks(8) { + let mut value = 0u8; + for (idx, bit) in byte_bits.iter().enumerate() { + value += *bit << idx; + } + bytes.push(value); + } + bytes + } +} + +/// Scatters a value into a packed word constant +pub mod scatter { + use super::pack; + use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; + + pub(crate) fn expr(value: u8, count: usize) -> Expression { + Expression::Constant(pack(&vec![value; count])) + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/witness.rs b/hashes/zkevm/src/keccak/vanilla/witness.rs new file mode 100644 index 00000000..d97d487d --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/witness.rs @@ -0,0 +1,418 @@ +// This file is moved out from mod.rs. +use super::*; + +/// Witness generation for multiple keccak hashes of little-endian `bytes`. +pub fn multi_keccak( + bytes: &[Vec], + capacity: Option, + parameters: KeccakConfigParams, +) -> (Vec>, Vec<[F; NUM_WORDS_TO_SQUEEZE]>) { + let num_rows_per_round = parameters.rows_per_round; + let mut rows = + Vec::with_capacity((1 + capacity.unwrap_or(0) * (NUM_ROUNDS + 1)) * num_rows_per_round); + // Dummy first row so that the initial data is absorbed + // The initial data doesn't really matter, `is_final` just needs to be disabled. + rows.append(&mut KeccakRow::dummy_rows(num_rows_per_round)); + // Actual keccaks + let artifacts = bytes + .par_iter() + .map(|bytes| { + let num_keccak_f = get_num_keccak_f(bytes.len()); + let mut squeeze_digests = Vec::with_capacity(num_keccak_f); + let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); + keccak(&mut rows, &mut squeeze_digests, bytes, parameters); + (rows, squeeze_digests) + }) + .collect::>(); + + let mut squeeze_digests = Vec::with_capacity(capacity.unwrap_or(0)); + for (rows_part, squeezes) in artifacts { + rows.extend(rows_part); + squeeze_digests.extend(squeezes); + } + + if let Some(capacity) = capacity { + // Pad with no data hashes to the expected capacity + while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { + keccak(&mut rows, &mut squeeze_digests, &[], parameters); + } + // Check that we are not over capacity + if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { + panic!("{:?}", Error::BoundsFailure); + } + } + (rows, squeeze_digests) +} +/// Witness generation for keccak hash of little-endian `bytes`. +fn keccak( + rows: &mut Vec>, + squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, + bytes: &[u8], + parameters: KeccakConfigParams, +) { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + + let mut bits = into_bits(bytes); + let mut s = [[F::ZERO; 5]; 5]; + let absorb_positions = get_absorb_positions(); + let num_bytes_in_last_block = bytes.len() % RATE; + let two = F::from(2u64); + + // Padding + bits.push(1); + while (bits.len() + 1) % RATE_IN_BITS != 0 { + bits.push(0); + } + bits.push(1); + + // running length of absorbed input in bytes + let mut length = 0; + let chunks = bits.chunks(RATE_IN_BITS); + let num_chunks = chunks.len(); + + let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); + let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); + // keeps track of running lengths over all rounds in an absorb step + let mut round_lengths = Vec::with_capacity(NUM_ROUNDS + 1); + let mut hash_words = [F::ZERO; NUM_WORDS_TO_SQUEEZE]; + let mut hash = Word::default(); + + for (idx, chunk) in chunks.enumerate() { + let is_final_block = idx == num_chunks - 1; + + let mut absorb_rows = Vec::new(); + // Absorb + for (idx, &(i, j)) in absorb_positions.iter().enumerate() { + let absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); + let from = s[i][j]; + s[i][j] = field_xor(s[i][j], absorb); + absorb_rows.push(AbsorbData { from, absorb, result: s[i][j] }); + } + + // better memory management to clear already allocated Vecs + cell_managers.clear(); + regions.clear(); + round_lengths.clear(); + + for round in 0..NUM_ROUNDS + 1 { + let mut cell_manager = CellManager::new(num_rows_per_round); + let mut region = KeccakRegion::new(); + + let mut absorb_row = AbsorbData::default(); + if round < NUM_WORDS_TO_ABSORB { + absorb_row = absorb_rows[round].clone(); + } + + // State data + for s in &s { + for s in s { + let cell = cell_manager.query_cell_value(); + cell.assign(&mut region, 0, *s); + } + } + + // Absorb data + let absorb_from = cell_manager.query_cell_value(); + let absorb_data = cell_manager.query_cell_value(); + let absorb_result = cell_manager.query_cell_value(); + absorb_from.assign(&mut region, 0, absorb_row.from); + absorb_data.assign(&mut region, 0, absorb_row.absorb); + absorb_result.assign(&mut region, 0, absorb_row.result); + + // Absorb + cell_manager.start_region(); + let part_size = get_num_bits_per_absorb_lookup(k); + let input = absorb_row.from + absorb_row.absorb; + let absorb_fat = + split::value(&mut cell_manager, &mut region, input, 0, part_size, false, None); + cell_manager.start_region(); + let _absorb_result = transform::value( + &mut cell_manager, + &mut region, + absorb_fat.clone(), + true, + |v| v & 1, + true, + ); + + // Padding + cell_manager.start_region(); + // Unpack a single word into bytes (for the absorption) + // Potential optimization: could do multiple bytes per lookup + let packed = + split::value(&mut cell_manager, &mut region, absorb_row.absorb, 0, 8, false, None); + cell_manager.start_region(); + let input_bytes = + transform::value(&mut cell_manager, &mut region, packed, false, |v| *v, true); + cell_manager.start_region(); + let is_paddings = + input_bytes.iter().map(|_| cell_manager.query_cell_value()).collect::>(); + debug_assert_eq!(is_paddings.len(), NUM_BYTES_PER_WORD); + if round < NUM_WORDS_TO_ABSORB { + for (padding_idx, is_padding) in is_paddings.iter().enumerate() { + let byte_idx = round * NUM_BYTES_PER_WORD + padding_idx; + let padding = if is_final_block && byte_idx >= num_bytes_in_last_block { + true + } else { + length += 1; + false + }; + is_padding.assign(&mut region, 0, F::from(padding)); + } + } + cell_manager.start_region(); + + if round != NUM_ROUNDS { + // Theta + let part_size = get_num_bits_per_theta_c_lookup(k); + let mut bcf = Vec::new(); + for s in &s { + let c = s[0] + s[1] + s[2] + s[3] + s[4]; + let bc_fat = + split::value(&mut cell_manager, &mut region, c, 1, part_size, false, None); + bcf.push(bc_fat); + } + cell_manager.start_region(); + let mut bc = Vec::new(); + for bc_fat in bcf { + let bc_norm = transform::value( + &mut cell_manager, + &mut region, + bc_fat.clone(), + true, + |v| v & 1, + true, + ); + bc.push(bc_norm); + } + cell_manager.start_region(); + let mut os = [[F::ZERO; 5]; 5]; + for i in 0..5 { + let t = decode::value(bc[(i + 4) % 5].clone()) + + decode::value(rotate(bc[(i + 1) % 5].clone(), 1, part_size)); + for j in 0..5 { + os[i][j] = s[i][j] + t; + } + } + s = os; + cell_manager.start_region(); + + // Rho/Pi + let part_size = get_num_bits_per_base_chi_lookup(k); + let target_word_sizes = target_part_sizes(part_size); + let num_word_parts = target_word_sizes.len(); + let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = + array_init::array_init(|_| { + array_init::array_init(|_| array_init::array_init(|_| Vec::new())) + }); + let mut column_starts = [0usize; 3]; + for p in 0..3 { + column_starts[p] = cell_manager.start_region(); + let mut row_idx = 0; + for j in 0..5 { + for _ in 0..num_word_parts { + for i in 0..5 { + rho_pi_chi_cells[p][i][j] + .push(cell_manager.query_cell_value_at_row(row_idx as i32)); + } + row_idx = (row_idx + 1) % num_rows_per_round; + } + } + } + cell_manager.start_region(); + let mut os_parts: [[Vec>; 5]; 5] = + array_init::array_init(|_| array_init::array_init(|_| Vec::new())); + for (j, os_part) in os_parts.iter_mut().enumerate() { + for i in 0..5 { + let s_parts = split_uniform::value( + &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], + &mut cell_manager, + &mut region, + s[i][j], + RHO_MATRIX[i][j], + part_size, + true, + ); + + let s_parts = transform_to::value( + &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], + &mut region, + s_parts.clone(), + true, + |v| v & 1, + ); + os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); + } + } + cell_manager.start_region(); + + // Chi + let part_size_base = get_num_bits_per_base_chi_lookup(k); + let three_packed = pack::(&vec![3u8; part_size_base]); + let mut os = [[F::ZERO; 5]; 5]; + for j in 0..5 { + for i in 0..5 { + let mut s_parts = Vec::new(); + for ((part_a, part_b), part_c) in os_parts[i][j] + .iter() + .zip(os_parts[(i + 1) % 5][j].iter()) + .zip(os_parts[(i + 2) % 5][j].iter()) + { + let value = + three_packed - two * part_a.value + part_b.value - part_c.value; + s_parts.push(PartValue { + num_bits: part_size_base, + rot: j as i32, + value, + }); + } + os[i][j] = decode::value(transform_to::value( + &rho_pi_chi_cells[2][i][j], + &mut region, + s_parts.clone(), + true, + |v| CHI_BASE_LOOKUP_TABLE[*v as usize], + )); + } + } + s = os; + cell_manager.start_region(); + + // iota + let part_size = get_num_bits_per_absorb_lookup(k); + let input = s[0][0] + pack_u64::(ROUND_CST[round]); + let iota_parts = split::value::( + &mut cell_manager, + &mut region, + input, + 0, + part_size, + false, + None, + ); + cell_manager.start_region(); + s[0][0] = decode::value(transform::value( + &mut cell_manager, + &mut region, + iota_parts.clone(), + true, + |v| v & 1, + true, + )); + } + + // Assign the hash result + let is_final = is_final_block && round == NUM_ROUNDS; + hash = if is_final { + let hash_bytes_le = s + .into_iter() + .take(4) + .flat_map(|a| to_bytes::value(&unpack(a[0]))) + .rev() + .collect::>(); + + let word: Word> = + Word::from(eth_types::Word::from_little_endian(hash_bytes_le.as_slice())) + .map(Value::known); + word + } else { + Word::default().into_value() + }; + + // The words to squeeze out: this is the hash digest as words with + // NUM_BYTES_PER_WORD (=8) bytes each + for (hash_word, a) in hash_words.iter_mut().zip(s.iter()) { + *hash_word = a[0]; + } + + round_lengths.push(length); + + cell_managers.push(cell_manager); + regions.push(region); + } + + // Now that we know the state at the end of the rounds, set the squeeze data + let num_rounds = cell_managers.len(); + for (idx, word) in hash_words.iter().enumerate() { + let cell_manager = &mut cell_managers[num_rounds - 2 - idx]; + let region = &mut regions[num_rounds - 2 - idx]; + + cell_manager.start_region(); + let squeeze_packed = cell_manager.query_cell_value(); + squeeze_packed.assign(region, 0, *word); + + cell_manager.start_region(); + let packed = split::value(cell_manager, region, *word, 0, 8, false, None); + cell_manager.start_region(); + transform::value(cell_manager, region, packed, false, |v| *v, true); + } + squeeze_digests.push(hash_words); + + for round in 0..NUM_ROUNDS + 1 { + let round_cst = pack_u64(ROUND_CST[round]); + + for row_idx in 0..num_rows_per_round { + let word_value = if round < NUM_WORDS_TO_ABSORB && row_idx == 0 { + let byte_idx = (idx * NUM_WORDS_TO_ABSORB + round) * NUM_BYTES_PER_WORD; + if byte_idx >= bytes.len() { + 0 + } else { + let end = std::cmp::min(byte_idx + NUM_BYTES_PER_WORD, bytes.len()); + let mut word_bytes = bytes[byte_idx..end].to_vec().clone(); + word_bytes.resize(NUM_BYTES_PER_WORD, 0); + u64::from_le_bytes(word_bytes.try_into().unwrap()) + } + } else { + 0 + }; + let byte_idx = if round < NUM_WORDS_TO_ABSORB { + round * NUM_BYTES_PER_WORD + std::cmp::min(row_idx, NUM_BYTES_PER_WORD - 1) + } else { + NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD + } + idx * NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; + let bytes_left = if byte_idx >= bytes.len() { 0 } else { bytes.len() - byte_idx }; + rows.push(KeccakRow { + q_enable: row_idx == 0, + q_round: row_idx == 0 && round < NUM_ROUNDS, + q_absorb: row_idx == 0 && round == NUM_ROUNDS, + q_round_last: row_idx == 0 && round == NUM_ROUNDS, + q_input: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, + q_input_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, + round_cst, + is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, + cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), + hash, + bytes_left: F::from_u128(bytes_left as u128), + word_value: F::from_u128(word_value as u128), + }); + #[cfg(debug_assertions)] + { + let mut r = rows.last().unwrap().clone(); + r.cell_values.clear(); + log::trace!("offset {:?} row idx {} row {:?}", rows.len() - 1, row_idx, r); + } + } + log::trace!(" = = = = = = round {} end", round); + } + log::trace!(" ====================== chunk {} end", idx); + } + + #[cfg(debug_assertions)] + { + let hash_bytes = s + .into_iter() + .take(4) + .map(|a| { + pack_with_base::(&unpack(a[0]), 2) + .to_bytes_le() + .into_iter() + .take(8) + .collect::>() + .to_vec() + }) + .collect::>(); + debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); + assert_eq!(length, bytes.len()); + } +} diff --git a/hashes/zkevm-keccak/src/lib.rs b/hashes/zkevm/src/lib.rs similarity index 74% rename from hashes/zkevm-keccak/src/lib.rs rename to hashes/zkevm/src/lib.rs index e51bd006..272e4bf8 100644 --- a/hashes/zkevm-keccak/src/lib.rs +++ b/hashes/zkevm/src/lib.rs @@ -4,8 +4,6 @@ use halo2_base::halo2_proofs; /// Keccak packed multi -pub mod keccak_packed_multi; +pub mod keccak; /// Util pub mod util; - -pub use keccak_packed_multi::KeccakCircuitConfig as KeccakConfig; diff --git a/hashes/zkevm-keccak/src/util/constraint_builder.rs b/hashes/zkevm/src/util/constraint_builder.rs similarity index 81% rename from hashes/zkevm-keccak/src/util/constraint_builder.rs rename to hashes/zkevm/src/util/constraint_builder.rs index bae9f4a4..a93a1802 100644 --- a/hashes/zkevm-keccak/src/util/constraint_builder.rs +++ b/hashes/zkevm/src/util/constraint_builder.rs @@ -1,5 +1,5 @@ -use super::expression::Expr; -use crate::halo2_proofs::{arithmetic::FieldExt, plonk::Expression}; +use super::{expression::Expr, word::Word}; +use crate::halo2_proofs::{halo2curves::ff::PrimeField, plonk::Expression}; #[derive(Default)] pub struct BaseConstraintBuilder { @@ -8,7 +8,7 @@ pub struct BaseConstraintBuilder { pub condition: Option>, } -impl BaseConstraintBuilder { +impl BaseConstraintBuilder { pub(crate) fn new(max_degree: usize) -> Self { BaseConstraintBuilder { constraints: Vec::new(), max_degree, condition: None } } @@ -17,6 +17,18 @@ impl BaseConstraintBuilder { self.add_constraint(name, constraint); } + pub(crate) fn require_equal_word( + &mut self, + name: &'static str, + lhs: Word>, + rhs: Word>, + ) { + let (lhs_lo, lhs_hi) = lhs.to_lo_hi(); + let (rhs_lo, rhs_hi) = rhs.to_lo_hi(); + self.add_constraint(name, lhs_lo - rhs_lo); + self.add_constraint(name, lhs_hi - rhs_hi); + } + pub(crate) fn require_equal( &mut self, name: &'static str, diff --git a/hashes/zkevm-keccak/src/util/eth_types.rs b/hashes/zkevm/src/util/eth_types.rs similarity index 100% rename from hashes/zkevm-keccak/src/util/eth_types.rs rename to hashes/zkevm/src/util/eth_types.rs diff --git a/hashes/zkevm-keccak/src/util/expression.rs b/hashes/zkevm/src/util/expression.rs similarity index 55% rename from hashes/zkevm-keccak/src/util/expression.rs rename to hashes/zkevm/src/util/expression.rs index fa0ee216..57e2511b 100644 --- a/hashes/zkevm-keccak/src/util/expression.rs +++ b/hashes/zkevm/src/util/expression.rs @@ -1,34 +1,34 @@ -use crate::halo2_proofs::{arithmetic::FieldExt, plonk::Expression}; +use crate::halo2_proofs::{halo2curves::ff::PrimeField, plonk::Expression}; /// Returns the sum of the passed in cells pub mod sum { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression for the sum of the list of expressions. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { inputs.into_iter().fold(0.expr(), |acc, input| acc + input.expr()) } /// Returns the sum of the given list of values within the field. - pub fn value(values: &[u8]) -> F { - values.iter().fold(F::zero(), |acc, value| acc + F::from(*value as u64)) + pub fn value(values: &[u8]) -> F { + values.iter().fold(F::ZERO, |acc, value| acc + F::from(*value as u64)) } } /// Returns `1` when `expr[0] && expr[1] && ... == 1`, and returns `0` /// otherwise. Inputs need to be boolean pub mod and { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that evaluates to 1 only if all the expressions in /// the given list are 1, else returns 0. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { inputs.into_iter().fold(1.expr(), |acc, input| acc * input.expr()) } /// Returns the product of all given values. - pub fn value(inputs: Vec) -> F { - inputs.iter().fold(F::one(), |acc, input| acc * input) + pub fn value(inputs: Vec) -> F { + inputs.iter().fold(F::ONE, |acc, input| acc * input) } } @@ -36,16 +36,16 @@ pub mod and { /// otherwise. Inputs need to be boolean pub mod or { use super::{and, not}; - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that evaluates to 1 if any expression in the given /// list is 1. Returns 0 if all the expressions were 0. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { not::expr(and::expr(inputs.into_iter().map(not::expr))) } /// Returns the value after passing all given values through the OR gate. - pub fn value(inputs: Vec) -> F { + pub fn value(inputs: Vec) -> F { not::value(and::value(inputs.into_iter().map(not::value).collect())) } } @@ -53,31 +53,31 @@ pub mod or { /// Returns `1` when `b == 0`, and returns `0` otherwise. /// `b` needs to be boolean pub mod not { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that represents the NOT of the given expression. - pub fn expr>(b: E) -> Expression { + pub fn expr>(b: E) -> Expression { 1.expr() - b.expr() } /// Returns a value that represents the NOT of the given value. - pub fn value(b: F) -> F { - F::one() - b + pub fn value(b: F) -> F { + F::ONE - b } } /// Returns `a ^ b`. /// `a` and `b` needs to be boolean pub mod xor { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that represents the XOR of the given expression. - pub fn expr>(a: E, b: E) -> Expression { + pub fn expr>(a: E, b: E) -> Expression { a.expr() + b.expr() - 2.expr() * a.expr() * b.expr() } /// Returns a value that represents the XOR of the given value. - pub fn value(a: F, b: F) -> F { + pub fn value(a: F, b: F) -> F { a + b - F::from(2u64) * a * b } } @@ -85,11 +85,11 @@ pub mod xor { /// Returns `when_true` when `selector == 1`, and returns `when_false` when /// `selector == 0`. `selector` needs to be boolean. pub mod select { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns the `when_true` expression when the selector is true, else /// returns the `when_false` expression. - pub fn expr( + pub fn expr( selector: Expression, when_true: Expression, when_false: Expression, @@ -99,18 +99,18 @@ pub mod select { /// Returns the `when_true` value when the selector is true, else returns /// the `when_false` value. - pub fn value(selector: F, when_true: F, when_false: F) -> F { - selector * when_true + (F::one() - selector) * when_false + pub fn value(selector: F, when_true: F, when_false: F) -> F { + selector * when_true + (F::ONE - selector) * when_false } /// Returns the `when_true` word when selector is true, else returns the /// `when_false` word. - pub fn value_word( + pub fn value_word( selector: F, when_true: [u8; 32], when_false: [u8; 32], ) -> [u8; 32] { - if selector == F::one() { + if selector == F::ONE { when_true } else { when_false @@ -118,9 +118,38 @@ pub mod select { } } +/// Decodes a field element from its byte representation in little endian order +pub mod from_bytes { + use super::{Expr, Expression, PrimeField}; + + pub fn expr>(bytes: &[E]) -> Expression { + let mut value = 0.expr(); + let mut multiplier = F::ONE; + for byte in bytes.iter() { + value = value + byte.expr() * multiplier; + multiplier *= F::from(256); + } + value + } + + pub fn value(bytes: &[u8]) -> F { + let mut value = F::ZERO; + let mut multiplier = F::ONE; + let two_pow_64 = F::from_u128(1u128 << 64); + let two_pow_128 = two_pow_64 * two_pow_64; + for u128_chunk in bytes.chunks(u128::BITS as usize / u8::BITS as usize) { + let mut buffer = [0; 16]; + buffer[..u128_chunk.len()].copy_from_slice(u128_chunk); + value += F::from_u128(u128::from_le_bytes(buffer)) * multiplier; + multiplier *= two_pow_128; + } + value + } +} + /// Trait that implements functionality to get a constant expression from /// commonly used types. -pub trait Expr { +pub trait Expr { /// Returns an expression for the type. fn expr(&self) -> Expression; } @@ -129,7 +158,7 @@ pub trait Expr { #[macro_export] macro_rules! impl_expr { ($type:ty) => { - impl Expr for $type { + impl Expr for $type { #[inline] fn expr(&self) -> Expression { Expression::Constant(F::from(*self as u64)) @@ -137,7 +166,7 @@ macro_rules! impl_expr { } }; ($type:ty, $method:path) => { - impl Expr for $type { + impl Expr for $type { #[inline] fn expr(&self) -> Expression { Expression::Constant(F::from($method(self) as u64)) @@ -151,43 +180,30 @@ impl_expr!(u8); impl_expr!(u64); impl_expr!(usize); -impl Expr for Expression { +impl Expr for Expression { #[inline] fn expr(&self) -> Expression { self.clone() } } -impl Expr for &Expression { +impl Expr for &Expression { #[inline] fn expr(&self) -> Expression { (*self).clone() } } -impl Expr for i32 { +impl Expr for i32 { #[inline] fn expr(&self) -> Expression { Expression::Constant( - F::from(self.unsigned_abs() as u64) - * if self.is_negative() { -F::one() } else { F::one() }, + F::from(self.unsigned_abs() as u64) * if self.is_negative() { -F::ONE } else { F::ONE }, ) } } -/// Given a bytes-representation of an expression, it computes and returns the -/// single expression. -pub fn expr_from_bytes>(bytes: &[E]) -> Expression { - let mut value = 0.expr(); - let mut multiplier = F::one(); - for byte in bytes.iter() { - value = value + byte.expr() * multiplier; - multiplier *= F::from(256); - } - value -} - -/// Returns 2**by as FieldExt -pub fn pow_of_two(by: usize) -> F { - F::from(2).pow(&[by as u64, 0, 0, 0]) +/// Returns 2**by as PrimeField +pub fn pow_of_two(by: usize) -> F { + F::from(2).pow([by as u64]) } diff --git a/hashes/zkevm/src/util/mod.rs b/hashes/zkevm/src/util/mod.rs new file mode 100644 index 00000000..e5f9463e --- /dev/null +++ b/hashes/zkevm/src/util/mod.rs @@ -0,0 +1,4 @@ +pub mod constraint_builder; +pub mod eth_types; +pub mod expression; +pub mod word; diff --git a/hashes/zkevm/src/util/word.rs b/hashes/zkevm/src/util/word.rs new file mode 100644 index 00000000..1d417fbb --- /dev/null +++ b/hashes/zkevm/src/util/word.rs @@ -0,0 +1,328 @@ +//! Define generic Word type with utility functions +// Naming Convesion +// - Limbs: An EVM word is 256 bits **big-endian**. Limbs N means split 256 into N limb. For example, N = 4, each +// limb is 256/4 = 64 bits + +use super::{ + eth_types::{self, Field, ToLittleEndian, H160, H256}, + expression::{from_bytes, not, or, Expr}, +}; +use crate::halo2_proofs::{ + circuit::Value, + plonk::{Advice, Column, Expression, VirtualCells}, + poly::Rotation, +}; +use itertools::Itertools; + +/// evm word 32 bytes, half word 16 bytes +const N_BYTES_HALF_WORD: usize = 16; + +/// The EVM word for witness +#[derive(Clone, Debug, Copy)] +pub struct WordLimbs { + /// The limbs of this word. + pub limbs: [T; N], +} + +pub(crate) type Word2 = WordLimbs; + +#[allow(dead_code)] +pub(crate) type Word4 = WordLimbs; + +#[allow(dead_code)] +pub(crate) type Word32 = WordLimbs; + +impl WordLimbs { + /// Constructor + pub fn new(limbs: [T; N]) -> Self { + Self { limbs } + } + /// The number of limbs + pub fn n() -> usize { + N + } +} + +impl WordLimbs, N> { + /// Query advice of WordLibs of columns advice + pub fn query_advice( + &self, + meta: &mut VirtualCells, + at: Rotation, + ) -> WordLimbs, N> { + WordLimbs::new(self.limbs.map(|column| meta.query_advice(column, at))) + } +} + +impl WordLimbs { + /// Convert WordLimbs of u8 to WordLimbs of expressions + pub fn to_expr(&self) -> WordLimbs, N> { + WordLimbs::new(self.limbs.map(|v| Expression::Constant(F::from(v as u64)))) + } +} + +impl Default for WordLimbs { + fn default() -> Self { + Self { limbs: [(); N].map(|_| T::default()) } + } +} + +impl WordLimbs { + /// Check if zero + pub fn is_zero_vartime(&self) -> bool { + self.limbs.iter().all(|limb| limb.is_zero_vartime()) + } +} + +/// Get the word expression +pub trait WordExpr { + /// Get the word expression + fn to_word(&self) -> Word>; +} + +/// `Word`, special alias for Word2. +#[derive(Clone, Debug, Copy, Default)] +pub struct Word(Word2); + +impl Word { + /// Construct the word from 2 limbs + pub fn new(limbs: [T; 2]) -> Self { + Self(WordLimbs::::new(limbs)) + } + /// The high 128 bits limb + pub fn hi(&self) -> T { + self.0.limbs[1].clone() + } + /// the low 128 bits limb + pub fn lo(&self) -> T { + self.0.limbs[0].clone() + } + /// number of limbs + pub fn n() -> usize { + 2 + } + /// word to low and high 128 bits + pub fn to_lo_hi(&self) -> (T, T) { + (self.0.limbs[0].clone(), self.0.limbs[1].clone()) + } + + /// Extract (move) lo and hi values + pub fn into_lo_hi(self) -> (T, T) { + let [lo, hi] = self.0.limbs; + (lo, hi) + } + + /// Wrap `Word` into `Word` + pub fn into_value(self) -> Word> { + let [lo, hi] = self.0.limbs; + Word::new([Value::known(lo), Value::known(hi)]) + } + + /// Map the word to other types + pub fn map(&self, mut func: impl FnMut(T) -> T2) -> Word { + Word(WordLimbs::::new([func(self.lo()), func(self.hi())])) + } +} + +impl std::ops::Deref for Word { + type Target = WordLimbs; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PartialEq for Word { + fn eq(&self, other: &Self) -> bool { + self.lo() == other.lo() && self.hi() == other.hi() + } +} + +impl From for Word { + /// Construct the word from u256 + fn from(value: eth_types::Word) -> Self { + let bytes = value.to_le_bytes(); + Word::new([ + from_bytes::value(&bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +impl From for Word { + /// Construct the word from H256 + fn from(h: H256) -> Self { + let le_bytes = { + let mut b = h.to_fixed_bytes(); + b.reverse(); + b + }; + Word::new([ + from_bytes::value(&le_bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&le_bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +impl From for Word { + /// Construct the word from u64 + fn from(value: u64) -> Self { + let bytes = value.to_le_bytes(); + Word::new([from_bytes::value(&bytes), F::from(0)]) + } +} + +impl From for Word { + /// Construct the word from u8 + fn from(value: u8) -> Self { + Word::new([F::from(value as u64), F::from(0)]) + } +} + +impl From for Word { + fn from(value: bool) -> Self { + Word::new([F::from(value as u64), F::from(0)]) + } +} + +impl From for Word { + /// Construct the word from h160 + fn from(value: H160) -> Self { + let mut bytes = *value.as_fixed_bytes(); + bytes.reverse(); + Word::new([ + from_bytes::value(&bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +// impl Word> { +// /// Assign advice +// pub fn assign_advice( +// &self, +// region: &mut Region<'_, F>, +// annotation: A, +// column: Word>, +// offset: usize, +// ) -> Result>, Error> +// where +// A: Fn() -> AR, +// AR: Into, +// { +// let annotation: String = annotation().into(); +// let lo = region.assign_advice(|| &annotation, column.lo(), offset, || self.lo())?; +// let hi = region.assign_advice(|| &annotation, column.hi(), offset, || self.hi())?; + +// Ok(Word::new([lo, hi])) +// } +// } + +impl Word> { + /// Query advice of Word of columns advice + pub fn query_advice( + &self, + meta: &mut VirtualCells, + at: Rotation, + ) -> Word> { + self.0.query_advice(meta, at).to_word() + } +} + +impl Word> { + /// create word from lo limb with hi limb as 0. caller need to guaranteed to be 128 bits. + pub fn from_lo_unchecked(lo: Expression) -> Self { + Self(WordLimbs::, 2>::new([lo, 0.expr()])) + } + /// zero word + pub fn zero() -> Self { + Self(WordLimbs::, 2>::new([0.expr(), 0.expr()])) + } + + /// one word + pub fn one() -> Self { + Self(WordLimbs::, 2>::new([1.expr(), 0.expr()])) + } + + /// select based on selector. Here assume selector is 1/0 therefore no overflow check + pub fn select + Clone>( + selector: T, + when_true: Word, + when_false: Word, + ) -> Word> { + let (true_lo, true_hi) = when_true.to_lo_hi(); + + let (false_lo, false_hi) = when_false.to_lo_hi(); + Word::new([ + selector.expr() * true_lo.expr() + (1.expr() - selector.expr()) * false_lo.expr(), + selector.expr() * true_hi.expr() + (1.expr() - selector.expr()) * false_hi.expr(), + ]) + } + + /// Assume selector is 1/0 therefore no overflow check + pub fn mul_selector(&self, selector: Expression) -> Self { + Word::new([self.lo() * selector.clone(), self.hi() * selector]) + } + + /// No overflow check on lo/hi limbs + pub fn add_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() + rhs.lo(), self.hi() + rhs.hi()]) + } + + /// No underflow check on lo/hi limbs + pub fn sub_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() - rhs.lo(), self.hi() - rhs.hi()]) + } + + /// No overflow check on lo/hi limbs + pub fn mul_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() * rhs.lo(), self.hi() * rhs.hi()]) + } +} + +impl WordExpr for Word> { + fn to_word(&self) -> Word> { + self.clone() + } +} + +impl WordLimbs, N1> { + /// to_wordlimbs will aggregate nested expressions, which implies during expression evaluation + /// it need more recursive call. if the converted limbs word will be used in many places, + /// consider create new low limbs word, have equality constrain, then finally use low limbs + /// elsewhere. + // TODO static assertion. wordaround https://github.com/nvzqz/static-assertions-rs/issues/40 + pub fn to_word_n(&self) -> WordLimbs, N2> { + assert_eq!(N1 % N2, 0); + let limbs = self + .limbs + .chunks(N1 / N2) + .map(|chunk| from_bytes::expr(chunk)) + .collect_vec() + .try_into() + .unwrap(); + WordLimbs::, N2>::new(limbs) + } + + /// Equality expression + // TODO static assertion. wordaround https://github.com/nvzqz/static-assertions-rs/issues/40 + pub fn eq(&self, others: &WordLimbs, N2>) -> Expression { + assert_eq!(N1 % N2, 0); + not::expr(or::expr( + self.limbs + .chunks(N1 / N2) + .map(|chunk| from_bytes::expr(chunk)) + .zip(others.limbs.clone()) + .map(|(expr1, expr2)| expr1 - expr2) + .collect_vec(), + )) + } +} + +impl WordExpr for WordLimbs, N1> { + fn to_word(&self) -> Word> { + Word(self.to_word_n()) + } +} + +// TODO unittest diff --git a/rust-toolchain b/rust-toolchain index 51ab4759..ee2d639b 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2022-10-28 \ No newline at end of file +nightly-2023-08-12 \ No newline at end of file