Skip to content

Commit

Permalink
More zip with (microsoft#158)
Browse files Browse the repository at this point in the history
* Expand zip_with macros to accommodate more use.

* Clippy.

* Move Spartan macros to own module.

* Only use zip_with_fn to implement convenience macros.

* Use convenience zip_with macros more.

* Remove for_each variants of zip_with macros.

* Remove flat_map variants of zip_with.

---------

Co-authored-by: porcuquine <porcuquine@users.noreply.github.com>
  • Loading branch information
2 people authored and huitseeker committed Dec 8, 2023
1 parent 30d454b commit cdda2c2
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 327 deletions.
127 changes: 47 additions & 80 deletions src/spartan/batched.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::{
snark::{BatchedRelaxedR1CSSNARKTrait, DigestHelperTrait},
Engine, TranscriptEngineTrait,
},
CommitmentKey,
zip_with, zip_with_into_par_iter, zip_with_iter, zip_with_par_iter, CommitmentKey,
};

/// A succinct proof of knowledge of a witness to a batch of relaxed R1CS instances
Expand Down Expand Up @@ -155,8 +155,7 @@ where
.collect::<Result<Vec<_>, _>>()?;

// Pad (W,E) for each instance
let W =
zip_with!((W.iter(), S.iter()), |w, s| w.pad(s)).collect::<Vec<RelaxedR1CSWitness<E>>>();
let W = zip_with_iter!((W, S), |w, s| w.pad(s)).collect::<Vec<RelaxedR1CSWitness<E>>>();

let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK");

Expand All @@ -168,12 +167,8 @@ where
let (polys_W, polys_E): (Vec<_>, Vec<_>) = W.into_iter().map(|w| (w.W, w.E)).unzip();

// Append public inputs to W: Z = [W, u, X]
let polys_Z = zip_with!((polys_W.iter(), U.iter()), |w, u| [
w.clone(),
vec![u.u],
u.X.clone()
]
.concat())
let polys_Z = zip_with_iter!((polys_W, U), |w, u| [w.clone(), vec![u.u], u.X.clone()]
.concat())
.collect::<Vec<Vec<_>>>();

let (num_rounds_x, num_rounds_y): (Vec<_>, Vec<_>) = S
Expand All @@ -196,22 +191,17 @@ where

// Compute MLEs of Az, Bz, Cz, uCz + E
let (polys_Az, polys_Bz, polys_Cz): (Vec<_>, Vec<_>, Vec<_>) =
zip_with!((S.par_iter(), polys_Z.par_iter()), |s, poly_Z| {
zip_with_par_iter!((S, polys_Z), |s, poly_Z| {
let (poly_Az, poly_Bz, poly_Cz) = s.multiply_vec(poly_Z)?;
Ok((poly_Az, poly_Bz, poly_Cz))
})
.collect::<Result<Vec<_>, NovaError>>()?
.into_iter()
.multiunzip();

let polys_uCz_E = zip_with!(
(U.par_iter(), polys_E.par_iter(), polys_Cz.par_iter()),
|u, poly_E, poly_Cz| {
zip_with!((poly_Cz.par_iter(), poly_E.par_iter()), |cz, e| u.u * cz
+ e)
.collect::<Vec<E::Scalar>>()
}
)
let polys_uCz_E = zip_with_par_iter!((U, polys_E, polys_Cz), |u, poly_E, poly_Cz| {
zip_with_par_iter!((poly_Cz, poly_E), |cz, e| u.u * cz + e).collect::<Vec<E::Scalar>>()
})
.collect::<Vec<_>>();

let comb_func_outer =
Expand Down Expand Up @@ -253,14 +243,8 @@ where
.collect::<Vec<_>>();

// Extract evaluations of Az, Bz from Sumcheck and Cz, E at r_x
let (evals_Az_Bz_Cz, evals_E): (Vec<_>, Vec<_>) = zip_with!(
(
claims_outer[1].par_iter(),
claims_outer[2].par_iter(),
polys_Cz.par_iter(),
polys_E.par_iter(),
r_x.par_iter()
),
let (evals_Az_Bz_Cz, evals_E): (Vec<_>, Vec<_>) = zip_with_par_iter!(
(claims_outer[1], claims_outer[2], polys_Cz, polys_E, r_x),
|eval_Az, eval_Bz, poly_Cz, poly_E, r_x| {
let (eval_Cz, eval_E) = rayon::join(
|| MultilinearPolynomial::evaluate_with(poly_Cz, r_x),
Expand Down Expand Up @@ -295,18 +279,14 @@ where
M_evals_Bs: Vec<E::Scalar>,
M_evals_Cs: Vec<E::Scalar>|
-> Vec<E::Scalar> {
zip_with!(
(
M_evals_As.into_par_iter(),
M_evals_Bs.into_par_iter(),
M_evals_Cs.into_par_iter()
),
zip_with_into_par_iter!(
(M_evals_As, M_evals_Bs, M_evals_Cs),
|eval_A, eval_B, eval_C| eval_A + inner_r * eval_B + inner_r_square * eval_C
)
.collect::<Vec<_>>()
};

zip_with!((S.par_iter(), r_x.par_iter()), |s, r_x| {
zip_with_par_iter!((S, r_x), |s, r_x| {
let evals_rx = EqPolynomial::evals_from_points(r_x);
let (eval_A, eval_B, eval_C) = compute_eval_table_sparse(s, &evals_rx);
MultilinearPolynomial::new(inner(eval_A, eval_B, eval_C))
Expand Down Expand Up @@ -346,7 +326,7 @@ where
})
.collect::<Vec<_>>();

let evals_W = zip_with!((polys_W.par_iter(), r_y.par_iter()), |poly, r_y| {
let evals_W = zip_with_par_iter!((polys_W, r_y), |poly, r_y| {
MultilinearPolynomial::evaluate_with(poly, &r_y[1..])
})
.collect::<Vec<_>>();
Expand All @@ -356,14 +336,13 @@ where
let mut w_vec = Vec::with_capacity(2 * num_instances);
let mut u_vec = Vec::with_capacity(2 * num_instances);
w_vec.extend(polys_W.into_iter().map(|poly| PolyEvalWitness { p: poly }));
u_vec.extend(zip_with!(
(evals_W.iter(), U.iter(), r_y),
|eval, u, r_y| PolyEvalInstance {
u_vec.extend(zip_with_iter!((evals_W, U, r_y), |eval, u, r_y| {
PolyEvalInstance {
c: u.comm_W,
x: r_y[1..].to_vec(),
e: *eval,
}
));
}));

w_vec.extend(polys_E.into_iter().map(|poly| PolyEvalWitness { p: poly }));
u_vec.extend(zip_with!(
Expand Down Expand Up @@ -455,12 +434,12 @@ where

// Extract evaluations into a vector [(Azᵢ, Bzᵢ, Czᵢ, Eᵢ)]
// TODO: This is a multizip, simplify
let ABCE_evals = zip_with!(
let ABCE_evals = zip_with_iter!(
(
self.evals_E.iter(),
self.claims_outer.0.iter(),
self.claims_outer.1.iter(),
self.claims_outer.2.iter()
self.evals_E,
self.claims_outer.0,
self.claims_outer.1,
self.claims_outer.2
),
|eval_E, claim_Az, claim_Bz, claim_Cz| (*claim_Az, *claim_Bz, *claim_Cz, *eval_E)
)
Expand All @@ -477,8 +456,7 @@ where
});

// Evaluate τ(rₓ) for each instance
let evals_tau = zip_with!((polys_tau.iter(), r_x.iter()), |poly_tau, r_x| poly_tau
.evaluate(r_x));
let evals_tau = zip_with_iter!((polys_tau, r_x), |poly_tau, r_x| poly_tau.evaluate(r_x));

// Compute expected claim for all instances ∑ᵢ rⁱ⋅τ(rₓ)⋅(Azᵢ⋅Bzᵢ − uᵢ⋅Czᵢ − Eᵢ)
let claim_outer_final_expected = zip_with!(
Expand Down Expand Up @@ -527,25 +505,22 @@ where

// Compute evaluations of Zᵢ = [Wᵢ, uᵢ, Xᵢ] at r_y
// Zᵢ(r_y) = (1−r_y[0])⋅W(r_y[1..]) + r_y[0]⋅MLE([uᵢ, Xᵢ])(r_y[1..])
let evals_Z = zip_with!(
(self.evals_W.iter(), U.iter(), r_y.iter()),
|eval_W, U, r_y| {
let eval_X = {
// constant term
let mut poly_X = vec![(0, U.u)];
//remaining inputs
poly_X.extend(
U.X
.iter()
.enumerate()
.map(|(i, x_i)| (i + 1, *x_i))
.collect::<Vec<(usize, E::Scalar)>>(),
);
SparsePolynomial::new(r_y.len() - 1, poly_X).evaluate(&r_y[1..])
};
(E::Scalar::ONE - r_y[0]) * eval_W + r_y[0] * eval_X
}
)
let evals_Z = zip_with_iter!((self.evals_W, U, r_y), |eval_W, U, r_y| {
let eval_X = {
// constant term
let mut poly_X = vec![(0, U.u)];
//remaining inputs
poly_X.extend(
U.X
.iter()
.enumerate()
.map(|(i, x_i)| (i + 1, *x_i))
.collect::<Vec<(usize, E::Scalar)>>(),
);
SparsePolynomial::new(r_y.len() - 1, poly_X).evaluate(&r_y[1..])
};
(E::Scalar::ONE - r_y[0]) * eval_W + r_y[0] * eval_X
})
.collect::<Vec<_>>();

// compute evaluations of R1CS matrices M(r_x, r_y) = eq(r_y)ᵀ⋅M⋅eq(r_x)
Expand Down Expand Up @@ -579,14 +554,8 @@ where
};

// Compute inner claim ∑ᵢ r³ⁱ⋅(Aᵢ(r_x, r_y) + r⋅Bᵢ(r_x, r_y) + r²⋅Cᵢ(r_x, r_y))⋅Zᵢ(r_y)
let claim_inner_final_expected = zip_with!(
(
vk.S.iter(),
r_x.iter(),
r_y.iter(),
evals_Z.iter(),
inner_r_powers.iter()
),
let claim_inner_final_expected = zip_with_iter!(
(vk.S, r_x, r_y, evals_Z, inner_r_powers),
|S, r_x, r_y, eval_Z, r_i| {
let evals = multi_evaluate(&[&S.A, &S.B, &S.C], r_x, r_y);
let eval = evals[0] + inner_r * evals[1] + inner_r_square * evals[2];
Expand All @@ -602,23 +571,21 @@ where
// Create evaluation instances for W(r_y[1..]) and E(r_x)
let u_vec = {
let mut u_vec = Vec::with_capacity(2 * num_instances);
u_vec.extend(zip_with!(
(self.evals_W.iter(), U.iter(), r_y.iter()),
|eval, u, r_y| PolyEvalInstance {
u_vec.extend(zip_with_iter!((self.evals_W, U, r_y), |eval, u, r_y| {
PolyEvalInstance {
c: u.comm_W,
x: r_y[1..].to_vec(),
e: *eval,
}
));
}));

u_vec.extend(zip_with!(
(self.evals_E.iter(), U.iter(), r_x.iter()),
|eval, u, r_x| PolyEvalInstance {
u_vec.extend(zip_with_iter!((self.evals_E, U, r_x), |eval, u, r_x| {
PolyEvalInstance {
c: u.comm_E,
x: r_x.to_vec(),
e: *eval,
}
));
}));
u_vec
};

Expand Down
Loading

0 comments on commit cdda2c2

Please sign in to comment.