Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper ZK treatment in plonky2 #1625

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
35 changes: 32 additions & 3 deletions plonky2/src/batch_fri/oracle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(not(feature = "std"))]
use alloc::{format, vec::Vec};
use alloc::{format, vec, vec::Vec};

use itertools::Itertools;
use plonky2_field::extension::Extendable;
Expand All @@ -19,6 +19,7 @@ use crate::hash::batch_merkle_tree::BatchMerkleTree;
use crate::hash::hash_types::RichField;
use crate::iop::challenger::Challenger;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkOracle;
use crate::timed;
use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
Expand Down Expand Up @@ -151,9 +152,15 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
// where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum.
// There are usually two batches for the openings at `zeta` and `g * zeta`.
// The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`.
for FriBatchInfo { point, polynomials } in &instance.batches {
for (idx, FriBatchInfo { point, polynomials }) in instance.batches.iter().enumerate() {
let is_zk = fri_params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
.sum();
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
// Collect the coefficients of all the polynomials in `polynomials`.
let polys_coeff = polynomials.iter().map(|fri_poly| {
let polys_coeff = polynomials[..last_poly].iter().map(|fri_poly| {
&oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index]
});
let composition_poly = timed!(
Expand All @@ -165,6 +172,28 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
quotient.coeffs.push(F::Extension::ZERO); // pad back to power of two
alpha.shift_poly(&mut final_poly);
final_poly += quotient;

if is_zk && idx == 0 {
let degree = 1 << degree_bits[i];
let mut composition_poly = PolynomialCoeffs::empty();
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, fri_poly)| {
let mut cur_coeffs = oracles[fri_poly.oracle_index].polynomials
[fri_poly.polynomial_index]
.coeffs
.clone();
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; degree * i]);
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; 2 * degree - cur_coeffs.len()]);
composition_poly += PolynomialCoeffs { coeffs: cur_coeffs };
});

alpha.shift_poly(&mut final_poly);
final_poly += composition_poly.to_extension();
}
}

assert_eq!(final_poly.len(), 1 << degree_bits[i]);
Expand Down
10 changes: 8 additions & 2 deletions plonky2/src/batch_fri/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pub fn batch_fri_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>,
fri_params: &FriParams,
timing: &mut TimingTree,
) -> FriProof<F, C::Hasher, D> {
let n = lde_polynomial_coeffs.len();
assert_eq!(lde_polynomial_values[0].len(), n);
let mut n = lde_polynomial_coeffs.len();
assert_eq!(lde_polynomial_values[0].len(), lde_polynomial_coeffs.len());
// The polynomial vectors should be sorted by degree, from largest to smallest, with no duplicate degrees.
assert!(lde_polynomial_values
.windows(2)
Expand All @@ -49,6 +49,12 @@ pub fn batch_fri_proof<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>,
}
assert_eq!(cur_poly_index, lde_polynomial_values.len());

// In the zk case, the final polynomial polynomial to be reduced has degree double that
// of the original batch FRI polynomial.
if fri_params.hiding {
n /= 2;
}

// Commit phase
let (trees, final_coeffs) = timed!(
timing,
Expand Down
54 changes: 49 additions & 5 deletions plonky2/src/batch_fri/recursive_verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use alloc::{format, vec::Vec};

use itertools::Itertools;
use plonky2_field::types::Field;

use crate::field::extension::Extendable;
use crate::fri::proof::{
Expand All @@ -15,6 +16,7 @@ use crate::iop::ext_target::{flatten_target, ExtensionTarget};
use crate::iop::target::{BoolTarget, Target};
use crate::plonk::circuit_builder::CircuitBuilder;
use crate::plonk::config::{AlgebraicHasher, GenericConfig};
use crate::plonk::plonk_common::PlonkOracle;
use crate::util::reducing::ReducingFactorTarget;
use crate::with_context;

Expand Down Expand Up @@ -62,7 +64,8 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
PrecomputedReducedOpeningsTarget::from_os_and_alpha(
opn,
challenges.fri_alpha,
self
self,
params.hiding,
)
);
precomputed_reduced_evals.push(pre);
Expand Down Expand Up @@ -165,13 +168,24 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let mut alpha = ReducingFactorTarget::new(alpha);
let mut sum = self.zero_extension();

for (batch, reduced_openings) in instance[index]
for (idx, (batch, reduced_openings)) in instance[index]
.batches
.iter()
.zip(&precomputed_reduced_evals.reduced_openings_at_point)
.enumerate()
{
// If we are in the zk case, the `R` polynomial (the last polynomials in the first batch) is added to
// the batch polynomial independently, without being quotiented. So the final polynomial becomes:
// `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i) + alpha^n R(X)`, where `n` is the degree
// of the batch polynomial.
let FriBatchInfoTarget { point, polynomials } = batch;
let evals = polynomials
let is_zk = params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
.sum();
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
let evals = polynomials[..last_poly]
.iter()
.map(|p| {
let poly_blinding = instance[index].oracles[p.oracle_index].blinding;
Expand All @@ -184,6 +198,31 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
let denominator = self.sub_extension(subgroup_x, *point);
sum = alpha.shift(sum, self);
sum = self.div_add_extension(numerator, denominator, sum);

// If we are in the zk case, we still have to add `R(X)` to the batch.
if is_zk && idx == 0 {
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, p)| {
let poly_blinding = instance[index].oracles[p.oracle_index].blinding;
let salted = params.hiding && poly_blinding;
let eval = proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted);
sum = alpha.shift(sum, self);
let val = self
.constant_extension(F::Extension::from_canonical_u32((i == 0) as u32));
let power =
self.exp_power_of_2_extension(subgroup_x, i * params.degree_bits);
let pi =
self.constant_extension(F::Extension::from_canonical_u32(i as u32));
let power = self.mul_extension(power, pi);
let shift_val = self.add_extension(val, power);

let eval_extension = eval.to_ext_target(self.zero());
let tmp = self.mul_extension(eval_extension, shift_val);
sum = self.add_extension(sum, tmp);
});
}
}

sum
Expand All @@ -210,7 +249,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
Self::assert_noncanonical_indices_ok(&params.config);
let mut x_index_bits = self.low_bits(x_index, n, F::BITS);

let cap_index =
let initial_cap_index =
self.le_sum(x_index_bits[x_index_bits.len() - params.config.cap_height..].iter());
with_context!(
self,
Expand All @@ -221,7 +260,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
&x_index_bits,
&round_proof.initial_trees_proof,
initial_merkle_caps,
cap_index
initial_cap_index
)
);

Expand Down Expand Up @@ -252,6 +291,11 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilder<F, D> {
);
batch_index += 1;

// In case of zk, the finaly polynomial's degree bits is increased by 1.
let cap_index = self.le_sum(
x_index_bits[x_index_bits.len() + params.hiding as usize - params.config.cap_height..]
.iter(),
);
for (i, &arity_bits) in params.reduction_arity_bits.iter().enumerate() {
let evals = &round_proof.steps[i].evals;

Expand Down
36 changes: 33 additions & 3 deletions plonky2/src/batch_fri/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::hash::hash_types::RichField;
use crate::hash::merkle_proofs::{verify_batch_merkle_proof_to_cap, verify_merkle_proof_to_cap};
use crate::hash::merkle_tree::MerkleCap;
use crate::plonk::config::{GenericConfig, Hasher};
use crate::plonk::plonk_common::PlonkOracle;
use crate::util::reducing::ReducingFactor;
use crate::util::reverse_bits;

Expand Down Expand Up @@ -46,7 +47,8 @@ pub fn verify_batch_fri_proof<

let mut precomputed_reduced_evals = Vec::with_capacity(openings.len());
for opn in openings {
let pre = PrecomputedReducedOpenings::from_os_and_alpha(opn, challenges.fri_alpha);
let pre =
PrecomputedReducedOpenings::from_os_and_alpha(opn, challenges.fri_alpha, params.hiding);
precomputed_reduced_evals.push(pre);
}
let degree_bits = degree_bits
Expand Down Expand Up @@ -123,13 +125,24 @@ fn batch_fri_combine_initial<
let mut alpha = ReducingFactor::new(alpha);
let mut sum = F::Extension::ZERO;

for (batch, reduced_openings) in instances[index]
// If we are in the zk case, the `R` polynomial (the last polynomials in the first batch) is added to
// the batch polynomial independently, without being quotiented. So the final polynomial becomes:
// `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i) + alpha^n R(X)`, where `n` is the degree
// of the batch polynomial.
for (idx, (batch, reduced_openings)) in instances[index]
.batches
.iter()
.zip(&precomputed_reduced_evals.reduced_openings_at_point)
.enumerate()
{
let FriBatchInfo { point, polynomials } = batch;
let evals = polynomials
let is_zk = params.hiding;
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
.sum();
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
let evals = polynomials[..last_poly]
.iter()
.map(|p| {
let poly_blinding = instances[index].oracles[p.oracle_index].blinding;
Expand All @@ -142,6 +155,23 @@ fn batch_fri_combine_initial<
let denominator = subgroup_x - *point;
sum = alpha.shift(sum);
sum += numerator / denominator;

// If we are in the zk case, we still have to add `R(X)` to the batch.
if is_zk && idx == 0 {
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, p)| {
let poly_blinding = instances[index].oracles[p.oracle_index].blinding;
let salted = params.hiding && poly_blinding;
let eval = proof.unsalted_eval(p.oracle_index, p.polynomial_index, salted);
sum = alpha.shift(sum);
let shift_val = F::Extension::from_canonical_usize((i == 0) as usize)
+ subgroup_x.exp_power_of_2(i * params.degree_bits)
* F::Extension::from_canonical_usize(i);
sum += F::Extension::from_basefield(eval) * shift_val;
});
}
}

sum
Expand Down
5 changes: 3 additions & 2 deletions plonky2/src/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ impl FriConfig {
self.rate_bits,
self.cap_height,
self.num_query_rounds,
hiding,
);
FriParams {
config: self.clone(),
Expand Down Expand Up @@ -87,7 +88,7 @@ pub struct FriParams {

impl FriParams {
pub fn total_arities(&self) -> usize {
self.reduction_arity_bits.iter().sum()
self.reduction_arity_bits.iter().sum::<usize>()
}

pub(crate) fn max_arity_bits(&self) -> Option<usize> {
Expand All @@ -103,7 +104,7 @@ impl FriParams {
}

pub fn final_poly_bits(&self) -> usize {
self.degree_bits - self.total_arities()
self.degree_bits + self.hiding as usize - self.total_arities()
}

pub fn final_poly_len(&self) -> usize {
Expand Down
46 changes: 42 additions & 4 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(not(feature = "std"))]
use alloc::{format, vec::Vec};
use alloc::{format, vec, vec::Vec};

use itertools::Itertools;
use plonky2_field::types::Field;
Expand All @@ -17,6 +17,7 @@ use crate::hash::hash_types::RichField;
use crate::hash::merkle_tree::MerkleTree;
use crate::iop::challenger::Challenger;
use crate::plonk::config::GenericConfig;
use crate::plonk::plonk_common::PlonkOracle;
use crate::timed;
use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
Expand Down Expand Up @@ -194,9 +195,23 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
// where the `k_i`s are chosen such that each power of `alpha` appears only once in the final sum.
// There are usually two batches for the openings at `zeta` and `g * zeta`.
// The oracles used in Plonky2 are given in `FRI_ORACLES` in `plonky2/src/plonk/plonk_common.rs`.
for FriBatchInfo { point, polynomials } in &instance.batches {
// Collect the coefficients of all the polynomials in `polynomials`.
let polys_coeff = polynomials.iter().map(|fri_poly| {
//
// If we are in the zk case, the `R` polynomial (the last polynomials in the first batch) is added to
// the batch polynomial independently, without being quotiented. So the final polynomial becomes:
// `final_poly = sum_i alpha^(k_i) (F_i(X) - F_i(z_i))/(X-z_i) + alpha^n R(X)`, where `n` is the degree
// of the batch polynomial.
// Then, since the degree of `R` is double that of the batch polynomial in our cimplementation, we need to
// compute one extra step in FRI to reach the correct degree.
let is_zk = fri_params.hiding;

for (idx, FriBatchInfo { point, polynomials }) in instance.batches.iter().enumerate() {
let nb_r_polys: usize = polynomials
.iter()
.map(|p| (p.oracle_index == PlonkOracle::R.index) as usize)
.sum();
let last_poly = polynomials.len() - nb_r_polys * (idx == 0) as usize;
// Collect the coefficients of all the polynomials in `polynomials` until `last_poly`.
let polys_coeff = polynomials[..last_poly].iter().map(|fri_poly| {
&oracles[fri_poly.oracle_index].polynomials[fri_poly.polynomial_index]
});
let composition_poly = timed!(
Expand All @@ -208,6 +223,29 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
quotient.coeffs.push(F::Extension::ZERO); // pad back to power of two
alpha.shift_poly(&mut final_poly);
final_poly += quotient;

// If we are in the zk case, we still have to add `R(X)` to the batch.
if is_zk && idx == 0 {
let degree = 1 << oracles[0].degree_log;
let mut composition_poly = PolynomialCoeffs::empty();
polynomials[last_poly..]
.iter()
.enumerate()
.for_each(|(i, fri_poly)| {
let mut cur_coeffs = oracles[fri_poly.oracle_index].polynomials
[fri_poly.polynomial_index]
.coeffs
.clone();
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; degree * i]);
cur_coeffs.reverse();
cur_coeffs.extend(vec![F::ZERO; 2 * degree - cur_coeffs.len()]);
composition_poly += PolynomialCoeffs { coeffs: cur_coeffs };
});

alpha.shift_poly(&mut final_poly);
final_poly += composition_poly.to_extension();
}
}

let lde_final_poly = final_poly.lde(fri_params.config.rate_bits);
Expand Down
Loading